In [1]:
import torch
from transformers import ElectraConfig, ElectraForSequenceClassification

# Load the configuration from json file
config = ElectraConfig.from_json_file('discriminator.json')
config.num_labels = 1  # Adjust the number of output labels
model = ElectraForSequenceClassification(config)


# Initialize the model
model = ElectraForSequenceClassification(config)

# Load the model weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_weights = torch.load('discriminator.pth', map_location=device)

# Create a new state dict with the weights for the sequence classification model
sequence_classification_weights = {k: v for k, v in pretrained_weights.items() if k.startswith("electra.")}

# Add the classifier weights manually
sequence_classification_weights["classifier.dense.weight"] = pretrained_weights["discriminator_predictions.dense.weight"]
sequence_classification_weights["classifier.dense.bias"] = pretrained_weights["discriminator_predictions.dense.bias"]
sequence_classification_weights["classifier.out_proj.weight"] = pretrained_weights["discriminator_predictions.dense_prediction.weight"]
sequence_classification_weights["classifier.out_proj.bias"] = pretrained_weights["discriminator_predictions.dense_prediction.bias"]

# Load the sequence classification weights into the model
model.load_state_dict(sequence_classification_weights)

# Move the model to the correct device
model = model.to(device)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from genomicDataset import GenomicsDataset
from torch.utils.data import DataLoader, random_split
# Create the directory-label dictionary
dir_label_dict = {
    "../preprocessing/output/atacseq_only": 0,
    "../preprocessing/output/chipseq_only": 1,
    "../preprocessing/output/intersecting": 1,
    "../preprocessing/output/non_intersecting": 0,
}

min_val, max_val = 0, 366.0038259577389

# Create the Dataset
dataset = GenomicsDataset(dir_label_dict, min_val, max_val)

# Determine the lengths of splits
total_samples = len(dataset)
train_len = int(0.7 * total_samples)
valid_len = int(0.15 * total_samples)
test_len = total_samples - train_len - valid_len

# Create the random splits
train_data, valid_data, test_data = random_split(dataset, lengths=[train_len, valid_len, test_len])

# Create DataLoaders for each split
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=True)



In [3]:
from transformers import AdamW
from torch.nn import BCEWithLogitsLoss
import torch.nn.functional as F

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)

# Specify loss function
criterion = BCEWithLogitsLoss()

# Number of training epochs
epochs = 20

# Training loop

for epoch in range(epochs):
    # Set model to training mode
    model.train()

    # Initialize running loss and accuracy
    running_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Iterate over batches in the training data loader
    for i, data in enumerate(train_loader):

        # Move tensors to the device
        input_ids = data['input_ids'].to(device)
        position_ids = data['position_ids'].to(device)
        labels = data['labels'].to(device)
        chromosome = data['chromosome'].to(device)
        reads = data['reads'].to(device)

        # Clear gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(input_ids, position_ids=position_ids, reads=reads, chromosome=chromosome)

        # Calculate loss
        logits = outputs.logits.squeeze()
        loss = criterion(logits, labels.float())

        # Get the predictions
        preds = (torch.sigmoid(logits) > 0.5).long()

        # Calculate the number of correct predictions
        correct = (preds == labels).sum().item()
        total_correct += correct
        total_samples += labels.size(0)

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()

        # Update running loss
        running_loss += loss.item()

    # Calculate average loss and accuracy over the training epoch
    avg_train_loss = running_loss / len(train_loader)
    train_accuracy = total_correct / total_samples

    print(f"Training Epoch: {epoch+1}/{epochs}, Loss: {avg_train_loss:.6f}, Accuracy: {train_accuracy:.2f}")

    # Validation phase
    model.eval()
    running_valid_loss = 0.0
    total_valid_correct = 0
    total_valid_samples = 0

    with torch.no_grad(): # No gradient calculation for validation, saves memory
        for i, data in enumerate(valid_loader):
            # Move tensors to the device
            input_ids = data['input_ids'].to(device)
            position_ids = data['position_ids'].to(device)
            labels = data['labels'].to(device)
            chromosome = data['chromosome'].to(device)
            reads = data['reads'].to(device)

            # Forward pass
            outputs = model(input_ids, position_ids=position_ids, reads=reads, chromosome=chromosome, nlahtdj="blgajobi")

            # Calculate loss
            logits = outputs.logits.squeeze()
            loss = criterion(logits, labels.float())

            # Get the predictions
            preds = (torch.sigmoid(logits) > 0.5).long()

            # Calculate the number of correct predictions
            correct = (preds == labels).sum().item()
            total_valid_correct += correct
            total_valid_samples += labels.size(0)

            # Update running loss
            running_valid_loss += loss.item()

        # Calculate average loss and accuracy over the validation epoch
        avg_valid_loss = running_valid_loss / len(valid_loader)
        valid_accuracy = total_valid_correct / total_valid_samples

        print(f"Validation Epoch: {epoch+1}/{epochs}, Loss: {avg_valid_loss:.6f}, Accuracy: {valid_accuracy:.2f}")






reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])
reads torch.Size([16, 512])


KeyboardInterrupt: 

In [None]:
# Testing phase
model.eval()
running_test_loss = 0.0
total_test_correct = 0
total_test_samples = 0

with torch.no_grad():  # No gradient calculation for testing, saves memory
    for i, data in enumerate(test_loader):
        # Move tensors to the device
        input_ids = data['input_ids'].to(device)
        position_ids = data['position_ids'].to(device)
        labels = data['labels'].to(device)
        chromosome = data['chromosome'].to(device)
        reads = data['reads'].to(device)

        # Forward pass
        outputs = model(input_ids, position_ids=position_ids)

        # Calculate loss
        logits = outputs.logits.squeeze()
        loss = criterion(logits, labels.float())

        # Get the predictions
        preds = (torch.sigmoid(logits) > 0.5).long()

        # Calculate the number of correct predictions
        correct = (preds == labels).sum().item()
        total_test_correct += correct
        total_test_samples += labels.size(0)

        # Update running loss
        running_test_loss += loss.item()

    # Calculate average loss and accuracy over the testing set
    avg_test_loss = running_test_loss / len(test_loader)
    test_accuracy = total_test_correct / total_test_samples

    print(f"Testing Loss: {avg_test_loss:.6f}, Accuracy: {test_accuracy:.2f}")


Testing Loss: 0.816694, Accuracy: 0.63
