# Train CNN Classifier on human_enhancers_cohn dataset

The dataset comes from the [Genomic Benchmarks](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks) collection. The best results achieved so far are reported in these [tables](https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main/experiments).

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from genomic_benchmarks.data_check import info
import optuna

In [2]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Get dataset

In [3]:
info("human_enhancers_cohn", 0)

Dataset `human_enhancers_cohn` has 2 classes: negative, positive.

All lengths of genomic intervals equals 500.

Totally 27791 sequences have been found, 20843 for training and 6948 for testing.


| class    | train | test |
|----------|-------|------|
| negative | 10422 | 3474 |
| positive | 10421 | 3474 |


In [4]:
dataset = load_dataset("katarinagresova/Genomic_Benchmarks_human_enhancers_cohn")


In [5]:
dataset['train'][0]

{'seq': 'TGGTGGTACTTGTCAGGACTTGGAGCAGCAGGTGCAAGATTTAGTGGGTTGGTTTTAGAATATCTGCTTGGAAAGTGGAAAAACTCAATGGATCATCTAGACTTTGGAATTTATCTCCTTCCCCACTTCTCCACTCCCCCAACAACAACAACAACAACAATGACAACAAAAACACCTGGAATAAACAGGTCATACAACGAGGTAGTTGATAGAATAATGTACTTTCCTTTCAGGCACCCCTTGGAGGAGGCAGATTCTGCCCTTTAAGCTGAATCTGCCTTTCCTGCATTTCCTGAAACTCCTGCATTTCCTGAAATCTTCCTGTATTTTCCTGAAATTTCCTGCCATTCCTGAAACTTTAAGGTAACTGTGTCATTAAAGGAAGGAGAGAAGGGAAGTATTAGGACTGCAGATTTGGGGTGCATGATCAGCCTGGCTCTGAGCTTGCAGACTCCCAGAGTCAGGGAAGGGAGGAGCCACCAGCAACCTTGTGGCTTACT',
 'label': 0}

## Encode and split dataset

In [6]:
def one_hot_encode(sequence, max_length=500):
    one_hot = torch.zeros((4, max_length), dtype=torch.float32)
    
    mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    
    for i, nucleotide in enumerate(sequence[:max_length]):
        if nucleotide in mapping:
            one_hot[mapping[nucleotide], i] = 1.0

    return one_hot
    
class DNADataset(Dataset):
    def __init__(self, data):
        self.dataset = data
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        seq = self.dataset[idx]['seq']
        label = self.dataset[idx]['label']
        encoded_seq = one_hot_encode(seq)
        return encoded_seq, label
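To make the encoding concrete, here is a minimal torch-free sketch that mirrors the logic of `one_hot_encode` with plain lists. The helper name `one_hot_rows` and the short `max_length=8` are assumptions chosen for illustration only. It shows two consequences of the `if nucleotide in mapping` check: any character outside `A`, `C`, `G`, `T` (for example `N`, or a lowercase base) becomes an all-zero column, and sequences shorter than `max_length` are zero-padded on the right.

```python
# Same mapping as in one_hot_encode above.
MAPPING = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

def one_hot_rows(sequence, max_length=8):
    """Hypothetical list-based mirror of one_hot_encode: 4 rows (A, C, G, T) x max_length columns."""
    rows = [[0.0] * max_length for _ in range(4)]
    for i, nucleotide in enumerate(sequence[:max_length]):
        if nucleotide in MAPPING:  # unknown characters are skipped
            rows[MAPPING[nucleotide]][i] = 1.0
    return rows

encoded = one_hot_rows("ACGTNa")
for base, row in zip("ACGT", encoded):
    print(base, row)

# Each column sums to 1 for a recognised base and to 0 otherwise.
column_sums = [sum(col) for col in zip(*encoded)]
print(column_sums)
```

If a dataset contained soft-masked (lowercase) sequence, calling `sequence.upper()` before encoding would be one way to avoid silently zeroing those positions; whether that applies here depends on the data.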

In [7]:
ds = dataset["train"].with_format("torch")
ds = DNADataset(ds)

train_size = int(0.8 * len(ds))
val_size = len(ds) - train_size

train_ds, val_ds = torch.utils.data.random_split(ds, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)
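The split sizes and batch counts that result from the cell above can be checked with plain arithmetic. This sketch only uses the training-split size reported by `info` (20843) and the batch size of 32; the variable names are illustrative.

```python
import math

n_train_split = 20843   # size of the "train" split reported above
batch_size = 32

train_size = int(0.8 * n_train_split)
val_size = n_train_split - train_size

train_batches = math.ceil(train_size / batch_size)
val_batches = math.ceil(val_size / batch_size)

# Size of the final, partially filled batch in each loader.
last_train_batch = train_size % batch_size or batch_size
last_val_batch = val_size % batch_size or batch_size

print(train_size, val_size)              # 16674 4169
print(train_batches, val_batches)        # 522 131
print(last_train_batch, last_val_batch)  # 2 9
```

Note that `random_split` draws a fresh random partition on every run. It accepts a `generator` argument, so passing a seeded `torch.Generator` is one way to make the train/validation split reproducible.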

## Define model

In [8]:
# Define a simple CNN for binary classification of DNA sequences
class DNAClassifierCNN(nn.Module):
    def __init__(self, kernel_size=5):
        super(DNAClassifierCNN, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=4, out_channels=16, kernel_size=kernel_size, stride=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=kernel_size, stride=1)

        self.relu = nn.LeakyReLU()        
        self.fc1 = nn.Linear(self.count_flatten_size(), 64)
        self.fc2 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def count_flatten_size(self):
        dummy_input = torch.zeros(1, 4, 500)
        dummy_output = self.pool(self.conv2(self.pool(self.conv1((dummy_input)))))
        return dummy_output.view(-1).size(0)
        
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.reshape(x.size(0), -1)  # Flatten for fully connected layer
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x
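`count_flatten_size` finds the input width of `fc1` by pushing a dummy tensor through the convolution and pooling layers. The same number can be derived by hand from the standard output-length formula for an unpadded 1D convolution or pooling layer, `(length - kernel_size) // stride + 1`. The sketch below applies that formula to the layer stack above; the helper names are assumptions for illustration.

```python
def conv_out(length, kernel_size, stride=1):
    """Output length of an unpadded, undilated 1D conv or pooling layer."""
    return (length - kernel_size) // stride + 1

def flatten_size(kernel_size, seq_len=500, out_channels=32):
    length = conv_out(seq_len, kernel_size)   # conv1
    length = conv_out(length, 2, stride=2)    # max pool
    length = conv_out(length, kernel_size)    # conv2
    length = conv_out(length, 2, stride=2)    # max pool
    return out_channels * length

for k in (3, 4, 5):
    print(k, flatten_size(k))
# 3 3936
# 4 3904
# 5 3904
```

For the default `kernel_size=5` the lengths go 500 → 496 → 248 → 244 → 122, so `fc1` receives 32 × 122 = 3904 features.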


In [9]:
# Training loop
def train_model(model, train_loader, optimizer, criterion):
    model.train()
    for batch in train_loader:
        inputs, labels = batch
        labels = labels.float().to(DEVICE)
        optimizer.zero_grad()
        
        outputs = model(inputs.to(DEVICE))
        # squeeze(1) drops only the output dimension, so a final batch of size 1 keeps its batch axis
        loss = criterion(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()

In [10]:
def evaluate_model(model, test_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = batch
            labels = labels.float().to(DEVICE)
            
            # squeeze(1) drops only the output dimension, so a batch of size 1 keeps its batch axis
            outputs = model(inputs.to(DEVICE)).squeeze(1)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
    
    avg_loss = total_loss / len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    return avg_loss, accuracy

In [11]:
# Run model training and evaluation after each epoch
def evaluation_loop(model, epochs, lr):
    
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.BCELoss()
    
    for epoch in range(epochs):
        train_model(model, train_loader, optimizer, criterion)
        avg_loss, accuracy = evaluate_model(model, val_loader, criterion)
        print(f'Epoch {epoch + 1}/{epochs}, Validation Loss: {avg_loss}, Accuracy: {accuracy}')
    
    avg_loss, accuracy = evaluate_model(model, val_loader, criterion)

    print(f'Loss: {avg_loss}, Accuracy: {accuracy}\n')
    
    return accuracy

## Perform training

In [12]:
model = DNAClassifierCNN().to(DEVICE)
evaluation_loop(model, epochs=5, lr=0.001)

Epoch 1/5, Validation Loss: 0.5962875371670905, Accuracy: 0.6764212041256896
Epoch 2/5, Validation Loss: 0.5881965183119737, Accuracy: 0.6807387862796834
Epoch 3/5, Validation Loss: 0.5886761841883186, Accuracy: 0.6848165027584553
Epoch 4/5, Validation Loss: 0.6245531885678531, Accuracy: 0.6449988006716239
Epoch 5/5, Validation Loss: 0.6462481829501291, Accuracy: 0.6555528903813864
Loss: 0.6462481829501291, Accuracy: 0.6555528903813864



0.6555528903813864
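In the log above, the validation loss is lowest at epoch 2 and rises afterwards, which suggests the model starts to overfit early, while `evaluation_loop` returns the accuracy of the last epoch. One common remedy is early stopping with a patience counter. The sketch below is not part of the notebook; it replays the logged validation losses through a hypothetical helper, `best_epoch_with_patience`, to show where such a rule would have stopped.

```python
def best_epoch_with_patience(val_losses, patience=2):
    """Return (best_epoch, stopped_at), both 1-based.

    Training stops once `patience` consecutive epochs fail to improve
    on the best validation loss seen so far.
    """
    best_loss = float("inf")
    best_epoch = 0
    stopped_at = len(val_losses)
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            stopped_at = epoch
            break
    return best_epoch, stopped_at

# Validation losses copied from the five-epoch run above.
logged_losses = [
    0.5962875371670905,
    0.5881965183119737,
    0.5886761841883186,
    0.6245531885678531,
    0.6462481829501291,
]

result = best_epoch_with_patience(logged_losses, patience=2)
print(result)  # (2, 4)
```

In a real training loop the same rule would be paired with saving the model weights at the best epoch, so that the reported metric comes from that checkpoint rather than from the final epoch.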

## Hyperparam optimization

Let's try to optimize the learning rate, the number of training epochs, and the size of the convolution kernel.

In [15]:
def objective(trial):
    lr = trial.suggest_float('learning_rate', 0.00001, 0.01)
    epochs = trial.suggest_int('epochs', 5, 10)
    kernel_size = trial.suggest_int('kernel_size', 3, 5)

    print(f"LR: {lr}, Epochs: {epochs}, Kernel size: {kernel_size}")

    model = DNAClassifierCNN(kernel_size=kernel_size).to(DEVICE)

    acc = evaluation_loop(model, epochs, lr)
    return acc
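The objective above samples the learning rate uniformly on a linear scale between `1e-5` and `1e-2`. Because that range spans three orders of magnitude, linear sampling rarely proposes small values. The sketch below computes, analytically, what fraction of samples would fall below `1e-3` under linear versus log-uniform sampling; the threshold is an arbitrary choice for illustration.

```python
import math

low, high = 1e-5, 1e-2   # search range used in the objective above
threshold = 1e-3         # illustrative cut-off

# Probability that a uniform sample on [low, high] is below the threshold.
linear_fraction = (threshold - low) / (high - low)

# Same probability when log(lr) is sampled uniformly instead.
log_fraction = (math.log(threshold) - math.log(low)) / (math.log(high) - math.log(low))

print(f"linear: {linear_fraction:.3f}")  # linear: 0.099
print(f"log:    {log_fraction:.3f}")     # log:    0.667
```

Optuna's `trial.suggest_float` accepts a `log=True` argument for this purpose. Switching to it would change the sampled values, so the trial results reported in this notebook would no longer be reproduced exactly.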

In [16]:
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=5)

[I 2024-11-14 19:19:19,019] A new study created in memory with name: no-name-902363e9-c481-422a-a596-0fd84c9221f4


LR: 0.008775394190993218, Epochs: 8, Kernel size: 5
Epoch 1/8, Validation Loss: 0.5948588561465722, Accuracy: 0.6754617414248021
Epoch 2/8, Validation Loss: 0.5907631500531699, Accuracy: 0.6948908611177741
Epoch 3/8, Validation Loss: 0.5856786151878707, Accuracy: 0.6869752938354522
Epoch 4/8, Validation Loss: 0.5866358623704837, Accuracy: 0.6816982489805709
Epoch 5/8, Validation Loss: 0.5891822826771336, Accuracy: 0.6862556968097865
Epoch 6/8, Validation Loss: 0.6052291557079054, Accuracy: 0.6816982489805709
Epoch 7/8, Validation Loss: 0.6305244030388257, Accuracy: 0.6785799952026865
Epoch 8/8, Validation Loss: 0.6614234452029221, Accuracy: 0.6629887263132646


[I 2024-11-14 19:25:52,086] Trial 0 finished with value: 0.6629887263132646 and parameters: {'learning_rate': 0.008775394190993218, 'epochs': 8, 'kernel_size': 5}. Best is trial 0 with value: 0.6629887263132646.


Loss: 0.6614234452029221, Accuracy: 0.6629887263132646

LR: 0.0019691162809866544, Epochs: 10, Kernel size: 5
Epoch 1/10, Validation Loss: 0.63472282909255, Accuracy: 0.6543535620052771
Epoch 2/10, Validation Loss: 0.5907825411276053, Accuracy: 0.6852962341088991
Epoch 3/10, Validation Loss: 0.5995748266008974, Accuracy: 0.6884144878867834
Epoch 4/10, Validation Loss: 0.6072489173357724, Accuracy: 0.6713840249460302
Epoch 5/10, Validation Loss: 0.616197821520667, Accuracy: 0.6668265771168146
Epoch 6/10, Validation Loss: 0.6421644132555896, Accuracy: 0.6639481890141521
Epoch 7/10, Validation Loss: 0.7454319737339747, Accuracy: 0.6529143679539458
Epoch 8/10, Validation Loss: 0.8840712372583287, Accuracy: 0.6406812185176302
Epoch 9/10, Validation Loss: 1.1574187790619508, Accuracy: 0.629407531782202
Epoch 10/10, Validation Loss: 1.8198769188109245, Accuracy: 0.6378028304149677


[I 2024-11-14 19:34:01,652] Trial 1 finished with value: 0.6378028304149677 and parameters: {'learning_rate': 0.0019691162809866544, 'epochs': 10, 'kernel_size': 5}. Best is trial 0 with value: 0.6629887263132646.


Loss: 1.8198769188109245, Accuracy: 0.6378028304149677

LR: 0.002985461977556319, Epochs: 10, Kernel size: 5
Epoch 1/10, Validation Loss: 0.594495553779238, Accuracy: 0.6742624130486927
Epoch 2/10, Validation Loss: 0.5822400102633556, Accuracy: 0.6905732789637803
Epoch 3/10, Validation Loss: 0.5824706706836933, Accuracy: 0.6920124730151115
Epoch 4/10, Validation Loss: 0.5930814285769718, Accuracy: 0.6934516670664428
Epoch 5/10, Validation Loss: 0.5961155054223446, Accuracy: 0.6812185176301271
Epoch 6/10, Validation Loss: 0.6533085828974047, Accuracy: 0.6656272487407052
Epoch 7/10, Validation Loss: 0.7593945276191216, Accuracy: 0.6627488606380427
Epoch 8/10, Validation Loss: 1.126364355323879, Accuracy: 0.6502758455265052
Epoch 9/10, Validation Loss: 1.2706463275519946, Accuracy: 0.6339649796114176
Epoch 10/10, Validation Loss: 1.5852241702662169, Accuracy: 0.6406812185176302


[I 2024-11-14 19:42:08,979] Trial 2 finished with value: 0.6406812185176302 and parameters: {'learning_rate': 0.002985461977556319, 'epochs': 10, 'kernel_size': 5}. Best is trial 0 with value: 0.6629887263132646.


Loss: 1.5852241702662169, Accuracy: 0.6406812185176302

LR: 0.0026759141652755746, Epochs: 10, Kernel size: 3
Epoch 1/10, Validation Loss: 0.6024388545797071, Accuracy: 0.6637083233389302
Epoch 2/10, Validation Loss: 0.5963589330665938, Accuracy: 0.666106980091149
Epoch 3/10, Validation Loss: 0.6459926032383023, Accuracy: 0.6548332933557208
Epoch 4/10, Validation Loss: 0.6000486842093576, Accuracy: 0.673782681698249
Epoch 5/10, Validation Loss: 0.6373377804082768, Accuracy: 0.662269129287599
Epoch 6/10, Validation Loss: 0.6635607164779692, Accuracy: 0.6577116814583833
Epoch 7/10, Validation Loss: 0.7803482347317324, Accuracy: 0.6471575917486208
Epoch 8/10, Validation Loss: 0.8862832798302629, Accuracy: 0.6452386663468458
Epoch 9/10, Validation Loss: 1.1663809441428148, Accuracy: 0.633245382585752
Epoch 10/10, Validation Loss: 1.4030394840786475, Accuracy: 0.6315663228591989


[I 2024-11-14 19:50:16,491] Trial 3 finished with value: 0.6315663228591989 and parameters: {'learning_rate': 0.0026759141652755746, 'epochs': 10, 'kernel_size': 3}. Best is trial 0 with value: 0.6629887263132646.


Loss: 1.4030394840786475, Accuracy: 0.6315663228591989

LR: 0.008592914787533152, Epochs: 8, Kernel size: 4
Epoch 1/8, Validation Loss: 0.6934835692398421, Accuracy: 0.49796114176061407
Epoch 2/8, Validation Loss: 0.6931513379548342, Accuracy: 0.49796114176061407
Epoch 3/8, Validation Loss: 0.6931708172987435, Accuracy: 0.502038858239386
Epoch 4/8, Validation Loss: 0.6933394183639352, Accuracy: 0.502038858239386
Epoch 5/8, Validation Loss: 0.6935275451827595, Accuracy: 0.49796114176061407
Epoch 6/8, Validation Loss: 0.6934432792299576, Accuracy: 0.502038858239386
Epoch 7/8, Validation Loss: 0.69342215689084, Accuracy: 0.502038858239386
Epoch 8/8, Validation Loss: 0.6931612737306202, Accuracy: 0.49796114176061407


[I 2024-11-14 19:56:50,487] Trial 4 finished with value: 0.49796114176061407 and parameters: {'learning_rate': 0.008592914787533152, 'epochs': 8, 'kernel_size': 4}. Best is trial 0 with value: 0.6629887263132646.


Loss: 0.6931612737306202, Accuracy: 0.49796114176061407
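The last trial is worth a closer look: its validation loss stays near 0.693 and its accuracy flips between two values close to 0.5. A binary cross-entropy of ln 2 ≈ 0.693 is what a model gets when it outputs a probability of about 0.5 for every sample, and the two accuracy values correspond to predicting a single class for the whole validation set. The check below uses the validation size 4169 implied by the 80/20 split of 20843 training sequences; this reading of the log is an interpretation, not something the notebook states.

```python
import math

# BCE for a constant prediction of 0.5, regardless of the true label.
chance_loss = -math.log(0.5)
print(round(chance_loss, 4))  # 0.6931

val_size = 20843 - int(0.8 * 20843)   # 4169 validation sequences

# The two accuracy values that alternate in the trial 4 log.
acc_low = 0.49796114176061407
acc_high = 0.502038858239386

count_low = round(acc_low * val_size)
count_high = round(acc_high * val_size)
print(count_low, count_high, count_low + count_high)  # 2076 2093 4169
```

The two counts add up to the full validation set, which is consistent with the model predicting one class for every sequence in some epochs and the other class in the rest. A learning rate this high can leave the network stuck in such a degenerate state.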



In [17]:
print(f"Best hyperparameters: {study.best_params}")
print(f"Best value (validation accuracy): {study.best_value}")

Best hyperparameters: {'learning_rate': 0.008775394190993218, 'epochs': 8, 'kernel_size': 5}
Best value (validation accuracy): 0.6629887263132646
