In [2]:
import sys

# Prepend the parent directory (presumably the repo root containing `src/`)
# to the module search path so the `src.*` imports below resolve.
sys.path[:0] = ['..']

In [3]:
import os
import subprocess
import numpy as np
import os
import subprocess
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve

from src.utils.synthetic_seqdata import *
from src.utils.datasets import DNASequenceDataset
from src.models.deepbind_cnn import BasicCNN
from src.trainer import Trainer

In [6]:
from src.utils.synthetic_seqdata import load_data, download_data

# Directory the synthetic sequence files are downloaded into and read back from.
savedir = "./data"

download_data(savedir)               # fetch the raw files (see the download log below)
Xs, Ys = load_data(savedir=savedir)  # per-split containers keyed 'train' / 'valid' / 'test'

--2023-07-26 09:33:57--  https://www.dropbox.com/s/drnyowfdv1lbjz6/train_sequences.txt
Resolving www.dropbox.com (www.dropbox.com)... 162.125.4.18, 2620:100:6019:18::a27d:412
Connecting to www.dropbox.com (www.dropbox.com)|162.125.4.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /s/raw/drnyowfdv1lbjz6/train_sequences.txt [following]
--2023-07-26 09:33:57--  https://www.dropbox.com/s/raw/drnyowfdv1lbjz6/train_sequences.txt
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucb40e74a0319073ea4d53404482.dl.dropboxusercontent.com/cd/0/inline/CAmCSM68D8Zanst7J4xdSiy-BTp95s8unNQ_W0PaIzJzKl5MVMfWDFDo1kqfVQQ05eJhuzzNSirVLtu5QvLOvbjTTjQqX26i9QLV2ZsFZnanRbiF9SrVbo6Mpp1O6ANz5sM/file# [following]
--2023-07-26 09:33:57--  https://ucb40e74a0319073ea4d53404482.dl.dropboxusercontent.com/cd/0/inline/CAmCSM68D8Zanst7J4xdSiy-BTp95s8unNQ_W0PaIzJzKl5MVMfWDFDo1kqfVQQ05eJhuzzNSirVLtu5QvLOvbjTTjQqX26i9QLV2Z

In [8]:
# Sanity check: training-set size, the label arrays, then the label count of each split.
(len(Xs['train']), Ys, *(len(Ys[split]) for split in ('train', 'valid', 'test')))

(14000,
 {'train': array([1., 1., 0., ..., 1., 0., 1.], dtype=float32),
  'valid': array([0., 1., 1., ..., 0., 0., 1.], dtype=float32),
  'test': array([0., 1., 0., ..., 1., 1., 1.], dtype=float32)},
 14000,
 2000,
 4000)

In [10]:
# Run configuration shared by the data loaders and the Trainer
# (and logged as run metadata when wandb tracking is enabled).
config = dict(
    batch_size=32,
    learning_rate=0.001,
    architecture="deepbind",
    dataset="synthetic data",
    epochs=35,
    patience=3,
)

In [11]:
# Wrap each split in a DNASequenceDataset and build one DataLoader per split;
# only the training loader shuffles.
alphabet = "ACGT"
train_dataset, valid_dataset, test_dataset = (
    DNASequenceDataset(Xs[split], Ys[split], alphabet=alphabet)
    for split in ("train", "valid", "test")
)

loaders = {
    split: DataLoader(dataset, batch_size=config['batch_size'], shuffle=(split == 'train'))
    for split, dataset in (
        ('train', train_dataset),
        ('valid', valid_dataset),
        ('test', test_dataset),
    )
}

In [None]:
# Number of batches per split.
tuple(len(loaders[split]) for split in ('train', 'valid', 'test'))

In [12]:
# Optional Weights & Biases tracking. It is disabled by default so the notebook
# runs without a wandb account or network access; set USE_WANDB = True to start
# a run. Afterwards `wandb` is always defined (the module, or None), so later
# cells that reference it do not hit a NameError on a fresh kernel.
USE_WANDB = False
wandb = None
if USE_WANDB:
    import wandb
    # start a new wandb run to track this script
    wandb.init(
        # set the wandb project where this run will be logged
        project="deepbind_cnn_synthetic1",
        # track hyperparameters and run metadata
        config=config,
    )


In [None]:
# Prepare data and model.
# Hyperparameters are read from `config` so the manual training loop below and
# the Trainer (which receives `config`) cannot silently use different values.
input_size = len(alphabet)  # One-hot encoding of DNA bases A, C, G, T
output_size = 1  # Single output unit; the labels in Ys are 0/1 values
learning_rate = config["learning_rate"]
num_epochs = config["epochs"]
patience = config["patience"]  # Number of epochs without improvement before early stopping
current_patience = 0

# Seed torch's global RNG so weight initialisation is reproducible. The shuffled
# train DataLoader also draws from this RNG when iterated, since no generator
# was given to it.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BasicCNN(input_size, output_size).to(device)
# NOTE(review): labels are binary and evaluated with AUROC; if BasicCNN returns
# raw logits, nn.BCEWithLogitsLoss may be a better fit than MSE — confirm.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Display the model architecture via its rich repr (last expression of the cell).
model

In [None]:
# Train with the project's Trainer class.
# `wandb` only exists if the optional wandb cell defined it; otherwise None is
# passed instead of raising a NameError.
# NOTE(review): confirm Trainer treats wandb=None as "no logging".
# NOTE(review): this trains `model`/`optimizer` in place; the manual training
# loop further down reuses the same objects and continues from these weights.
trainer = Trainer(config=config,
                  device=device,
                  model=model,
                  criterion=criterion,
                  optimizer=optimizer,
                  loaders=loaders,
                  wandb=globals().get("wandb"))

trainer.train()

In [None]:
# Training function
def train(model, loader, criterion, optimizer, device=None):
    """Run one training epoch and return the sample-weighted mean loss.

    Args:
        model: network to optimise; it is put into training mode.
        loader: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimiser stepped once per batch.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, so callers no longer depend on a global.

    Returns:
        float: sum over batches of ``loss * batch_size`` divided by
        ``len(loader.dataset)`` (the epoch mean for a mean-reduced criterion).
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # squeeze(-1) drops only the trailing output dim. A bare squeeze() turns
        # a (1, 1) output from a final batch of one sample into a 0-d tensor that
        # no longer matches the (1,) labels, which makes the loss broadcast.
        loss = criterion(outputs.squeeze(-1), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(loader.dataset)

# Validation function
def evaluate(model, loader, criterion, device=None):
    """Return the sample-weighted mean loss of ``model`` over ``loader`` without training.

    Args:
        model: network to evaluate; it is put into eval mode and left there.
        loader: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, so callers no longer depend on a global.

    Returns:
        float: sum over batches of ``loss * batch_size`` divided by
        ``len(loader.dataset)`` (the dataset mean for a mean-reduced criterion).
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # squeeze(-1) (not squeeze()) keeps a single-sample batch 1-D so it
            # matches `labels` instead of broadcasting in the loss.
            loss = criterion(outputs.squeeze(-1), labels)
            running_loss += loss.item() * inputs.size(0)
    return running_loss / len(loader.dataset)

In [None]:
# Training loop with validation and early stopping on the validation loss.
# NOTE(review): this reuses `model` and `optimizer`; if the Trainer cell above
# was run, training continues from those weights rather than from scratch.
best_model_path = "best_model.pt"
# `wandb` is only defined if the optional wandb cell set it; log only when available.
wandb_logger = globals().get("wandb")

train_losses = []
valid_losses = []
# Early-stopping state is initialised once, *before* the loop. Resetting
# best_valid_loss inside the loop made every epoch look like an improvement,
# so early stopping never fired and the checkpoint was overwritten every epoch.
# current_patience is reset here so re-running this cell starts clean.
best_valid_loss = float('inf')
current_patience = 0
saved_best = False

for epoch in range(num_epochs):
    train_loss = train(model, loaders['train'], criterion, optimizer)
    valid_loss = evaluate(model, loaders['valid'], criterion)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    if wandb_logger is not None:
        wandb_logger.log({"train_loss": train_loss, "valid_loss": valid_loss})
    print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}")

    # Early stopping
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        current_patience = 0
        # Checkpoint the best-so-far weights
        torch.save(model.state_dict(), best_model_path)
        saved_best = True
    else:
        current_patience += 1
        if current_patience >= patience:
            print("Early stopping! Validation loss hasn't improved in the last", patience, "epochs.")
            break

# Test the checkpoint with the best validation loss, not the last epoch's weights.
if saved_best:
    model.load_state_dict(torch.load(best_model_path, map_location=device))

test_loss = evaluate(model, loaders['test'], criterion)
print(f"Test Loss: {test_loss:.4f}")


In [None]:
# Collect model scores and true labels on the test set (AUROC is computed in
# the next cell), then plot the per-epoch training/validation loss curves.
# NOTE(review): assumes a single output unit per sample, as configured above.
model.eval()
test_scores = []
test_labels = []
with torch.no_grad():
    for inputs, labels in loaders['test']:
        outputs = model(inputs.to(device))  # labels stay on CPU; they are only listed
        # reshape(-1) always yields a 1-D tensor, so .tolist() returns a list even
        # for a final batch of one sample. squeeze() would give a 0-d tensor whose
        # .tolist() is a bare float, and list.extend would raise TypeError.
        test_scores.extend(outputs.cpu().reshape(-1).tolist())
        test_labels.extend(labels.reshape(-1).tolist())

# Plot training/validation loss
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(valid_losses) + 1), valid_losses, label='Validation Loss')
ax.set(xlabel='Epochs', ylabel='Loss', title='Training and Validation Loss')
ax.legend()
plt.show()

In [None]:
# Threshold-free ranking quality of the collected test scores.
test_auroc = roc_auc_score(test_labels, test_scores)
print(f"Test AUROC: {test_auroc:.4f}")

# ROC curve, drawn against the chance diagonal.
fpr, tpr, thresholds = roc_curve(test_labels, test_scores)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='b')
ax.plot([0, 1], [0, 1], color='gray', linestyle='--')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve')
plt.show()