In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import random
import torch
import optuna
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from optuna.integration import PyTorchLightningPruningCallback
from torchmetrics.functional.classification import binary_confusion_matrix
from dataset_module import SeqDatasetModule
from classification_attention_model import SequenceEncoder
from sklearn.model_selection import train_test_split
from dataclasses import dataclass

In [2]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set seed for reproducibility
seed=17
set_seed(seed)

#DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Load data

In [3]:
data_dir="./data"
filename= "classification_SeqLen6_MotifLen3_MotifsNum50.csv"
df= pd.read_csv(os.path.join(data_dir, filename))
df

Unnamed: 0,seq,label,motif,positions,mult_motifs
0,AAAAPA,1,A...PA,"(1, 2, 3)",0
1,AAACPA,1,A...PA,"(1, 2, 3)",0
2,AAADPA,1,A...PA,"(1, 2, 3)",0
3,AAAEPA,1,A...PA,"(1, 2, 3)",0
4,AAAFPA,1,A...PA,"(1, 2, 3)",0
...,...,...,...,...,...
396303,YYWSSW,1,..WS.W,"(0, 1, 4)",0
396304,YYWSTW,1,..WS.W,"(0, 1, 4)",0
396305,YYWSVW,1,..WS.W,"(0, 1, 4)",0
396306,YYWSWW,1,..WS.W,"(0, 1, 4)",0


In [4]:

train_val_dataset, test_dataset = train_test_split(df, test_size=0.2, random_state=seed)
train_val_dataset


Unnamed: 0,seq,label,motif,positions,mult_motifs
194320,KRASAM,1,..ASA.,"(0, 1, 5)",0
193740,KQQDKA,0,..QD.A,"(0, 1, 4)",0
378379,WQLPMS,0,..LP.S,"(0, 1, 4)",0
71416,CTSRAN,1,..SRA.,"(0, 1, 5)",0
214519,MASAHQ,1,.AS..Q,"(0, 3, 4)",0
...,...,...,...,...,...
251821,PQSPDS,1,PQ.P..,"(2, 4, 5)",0
125680,GHGRFQ,0,G.GR..,"(1, 4, 5)",0
304441,RTDCWA,1,..DCW.,"(0, 1, 5)",0
297103,RMPPSG,1,R.PP..,"(1, 4, 5)",0


Optimize hyperparameters

In [5]:
@dataclass
class ModelConfig:
    seq_len: int = 6 # max sequence length
    #vocab_size: int = 20 # number of tokens
    n_layer: int = 1 # number of layers
    n_head: int = 1 # number of heads
    #n_embd: int = 20 # embedding dimension
    dim_feedforward: int = 2048
    dropout: float = 0.015
    bias: bool = False
    layer_norm_eps: float = 1e-5
    d_model: int = 120 # seq_len *n_embd
    batch_size: int = 64
    num_workers: int = 10


def get_predictions(trainer, best_model, final_data_module, test_dataset):
    train_predictions =trainer.predict(model=best_model, datamodule=final_data_module)
    train_predictions=torch.concat(train_predictions).cpu().detach().numpy()

    test_data_module = SeqDatasetModule(test_dataset, test_dataset, batch_size=best_model.config.batch_size, num_workers=best_model.config.num_workers)
    test_predictions =trainer.predict(model=best_model, datamodule=test_data_module)
    test_predictions=torch.concat(test_predictions).cpu().detach().numpy()

    
    return train_predictions, test_predictions

def refit_best_model(params, train_val_dataset, test_dataset):

    max_epochs=5
    every_n_train_steps =1000

    model_config = ModelConfig(**params)


    checkpoint_callback = ModelCheckpoint(
                                            dirpath='./best_model_logs/checkpoints/',  # directory to save checkpoints
                                            filename=f'encoder_model-layers={model_config.n_layer}-heads={model_config.n_head}-dropout={model_config.dropout:.5f}-'+ '{epoch:02d}-{val_loss:.4f}', 
                                            save_top_k=1,  # save the best checkpoint
                                            every_n_train_steps= every_n_train_steps
                                            )
    
    best_model_logger = CSVLogger("best_model_logs")

    final_data_module = SeqDatasetModule(train_val_dataset, test_dataset, batch_size=model_config.batch_size, num_workers=model_config.num_workers)

    best_model = SequenceEncoder(model_config)
    trainer = L.Trainer( callbacks=[checkpoint_callback], logger=best_model_logger, max_epochs=max_epochs)
    trainer.fit(model=best_model, datamodule=final_data_module)
    
    latest_checkpoint_path = checkpoint_callback.best_model_path
    print(f'Best checkpoint path: {latest_checkpoint_path}')

    trainer.test(model=best_model, datamodule=final_data_module)

    print("predicting label probabilities")

    train_predictions, test_predictions= get_predictions(trainer, best_model, final_data_module, test_dataset)

    return train_predictions, test_predictions



def optimize_models(train_val_dataset, n_trials, timeout=1000, n_startup_trials=2, max_epochs=5):    

    def objective(trial):

        model_config = ModelConfig()


        #model_config.n_layer = trial.suggest_int("n_layer", 1, 2, step=1)
        #model_config.n_head =  trial.suggest_int("n_head", 2, 4, step=2)
        model_config.dropout=trial.suggest_float('dropout', 0.0023, 0.0025)
     
        
        train_dataset, val_dataset = train_test_split(train_val_dataset, test_size=0.2)
        
        data_module = SeqDatasetModule(train_dataset, val_dataset, 
                                    batch_size=model_config.batch_size, 
                                    num_workers=model_config.num_workers)

        model = SequenceEncoder(model_config)
        # train model
        logger = CSVLogger("logs")
        trainer = L.Trainer(
            logger=logger,
            max_epochs=max_epochs,
            callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss")])

        # hyperparameters = dict(num_layers=num_layers, dropout_prob=dropout_prob, num_heads=nhead)
        # trainer.logger.log_hyperparams(hyperparameters)
        
        # Train the model.
        #with mlflow.start_run() as run:
        trainer.fit(model=model, datamodule=data_module )
        val_loss=trainer.callback_metrics["val_loss"].item()

        return val_loss

    pruner = optuna.pruners.MedianPruner(n_startup_trials=n_startup_trials)

    study = optuna.create_study(direction="minimize", 
                                storage="sqlite:///db.sqlite3",
                                pruner=pruner)
    study.optimize(objective, n_trials=n_trials, timeout=timeout)

    return study


In [6]:
study = optimize_models(train_val_dataset, n_trials=3)

[I 2024-06-25 10:56:30,590] A new study created in RDB with name: no-name-995e7bfc-edcb-4321-b47f-48de1f270406
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                | Params
----------------------------------------------
0 | model | SequenceTransformer | 549 K 
----------------------------------------------
549 K     Trainable params
0         Non-trainable params
549 K     Total params
2.198     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
[I 2024-06-25 10:58:25,679] Trial 0 finished with value: 0.0012378619285300374 and parameters: {'dropout': 0.0024217735400829743}. Best is trial 0 with value: 0.0012378619285300374.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                | Params
----------------------------------------------
0 | model | SequenceTransformer | 549 K 
----------------------------------------------
549 K     Trainable params
0         Non-trainable params
549 K     Total params
2.198     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
[I 2024-06-25 11:00:16,988] Trial 1 finished with value: 0.0024394700303673744 and parameters: {'dropout': 0.0024713488075353896}. Best is trial 0 with value: 0.0012378619285300374.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                | Params
----------------------------------------------
0 | model | SequenceTransformer | 549 K 
----------------------------------------------
549 K     Trainable params
0         Non-trainable params
549 K     Total params
2.198     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2024-06-25 11:01:07,109] Trial 2 pruned. Trial was pruned at epoch 1.


In [7]:
print("Number of finished trials: {}".format(len(study.trials)))

print("Best trial:")
best_trial = study.best_trial

print("  Value: {}".format(best_trial.value))

print("  Params: ")
for key, value in best_trial.params.items():
    print("    {}: {}".format(key, value))


Number of finished trials: 3
Best trial:
  Value: 0.0012378619285300374
  Params: 
    dropout: 0.0024217735400829743


In [8]:
train_predictions, test_predictions = refit_best_model(best_trial.params, train_val_dataset, test_dataset)   


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                | Params
----------------------------------------------
0 | model | SequenceTransformer | 549 K 
----------------------------------------------
549 K     Trainable params
0         Non-trainable params
549 K     Total params
2.198     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Best checkpoint path: /home/rick/hd2/Platt Lab Dropbox/Rick Farouni/sequence_models/best_model_logs/checkpoints/encoder_model-layers=1-heads=1-dropout=0.00242-epoch=04-val_loss=0.0005.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9999621510505676
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
predicting label probabilities


Predicting: |          | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

Assess model fit

In [9]:
train_val_dataset.loc[:, "pred_prob"] = train_predictions 
test_dataset.loc[:, "pred_prob"] = test_predictions

In [16]:
test_dataset

Unnamed: 0,seq,label,motif,positions,mult_motifs,pred_prob,split
238782,PGNGRG,0,.G.GR.,"(0, 2, 5)",0,0.000008,test
69243,CRTKAQ,1,C...AQ,"(1, 2, 3)",0,0.999996,test
277252,QRSPML,0,QRS...,"(3, 4, 5)",0,0.000068,test
311511,SDWCCW,1,SDW...,"(3, 4, 5)",0,0.999995,test
280503,QSDCWW,1,..DCW.,"(0, 1, 5)",0,0.999996,test
...,...,...,...,...,...,...,...
216596,MELPQS,0,..LP.S,"(0, 1, 4)",0,0.000014,test
83895,DMDAAQ,1,D.DA..,"(1, 4, 5)",0,0.999991,test
293577,RILYRR,0,R...RR,"(1, 2, 3)",0,0.000010,test
254469,PRMERL,0,P...RL,"(1, 2, 3)",0,0.000003,test


In [10]:
train_val_dataset['pred_label'] = (train_val_dataset['pred_prob'] > 0.5).astype(int)
filtered_train_val_dataset = train_val_dataset[train_val_dataset['label'] != train_val_dataset['pred_label']]
filtered_train_val_dataset

Unnamed: 0,seq,label,motif,positions,mult_motifs,pred_prob,pred_label
220058,MPRAQG,1,.P.AQ.,"(0, 2, 5)",0,0.186599,0
277849,QRSRAH,0,QRS...,"(3, 4, 5)",1,0.55095,1
53433,CDRPAQ,1,C...AQ,"(1, 2, 3)",0,0.223282,0
377113,WPRAQC,1,.P.AQ.,"(0, 2, 5)",0,0.043108,0
377391,WPWAQA,1,.P.AQ.,"(0, 2, 5)",0,0.339056,0
278214,QRSRWN,0,QRS...,"(3, 4, 5)",0,0.509417,1
277860,QRSRAV,0,QRS...,"(3, 4, 5)",1,0.534969,1
251481,PQRPRR,1,PQ.P..,"(2, 4, 5)",0,0.227421,0
269597,QPRAQC,1,.P.AQ.,"(0, 2, 5)",0,0.31554,0
269601,QPRAQG,1,.P.AQ.,"(0, 2, 5)",0,0.307436,0


In [17]:
test_dataset['pred_label'] = (test_dataset['pred_prob'] > 0.5).astype(int)
filtered_test_dataset = test_dataset[test_dataset['label'] != test_dataset['pred_label']]
filtered_test_dataset

Unnamed: 0,seq,label,motif,positions,mult_motifs,pred_prob,split,pred_label
34164,APRAQG,1,.P.AQ.,"(0, 2, 5)",0,0.382034,test,0
278218,QRSRWS,0,QRS...,"(3, 4, 5)",1,0.956633,test,1
2526,AAWSDW,0,.A.SD.,"(0, 2, 5)",1,0.870069,test,1


In [18]:
binary_confusion_matrix(torch.tensor(train_val_dataset['pred_prob'].values), torch.tensor(train_val_dataset['label'].values))

tensor([[159551,     11],
        [    10, 157474]])

In [19]:
binary_confusion_matrix(torch.tensor(test_dataset['pred_prob'].values), torch.tensor(test_dataset['label'].values))

tensor([[39697,     2],
        [    1, 39562]])

In [12]:
 0.002392348637241081

0.002392348637241081

Save data with predictions

In [20]:
test_dataset["split"] ="test"
train_val_dataset["split"] ="train"
df_predicted=pd.concat([train_val_dataset, test_dataset ])
df_predicted

Unnamed: 0,seq,label,motif,positions,mult_motifs,pred_prob,pred_label,split
194320,KRASAM,1,..ASA.,"(0, 1, 5)",0,0.999998,1,train
193740,KQQDKA,0,..QD.A,"(0, 1, 4)",0,0.000013,0,train
378379,WQLPMS,0,..LP.S,"(0, 1, 4)",0,0.000092,0,train
71416,CTSRAN,1,..SRA.,"(0, 1, 5)",0,0.999993,1,train
214519,MASAHQ,1,.AS..Q,"(0, 3, 4)",0,0.999987,1,train
...,...,...,...,...,...,...,...,...
216596,MELPQS,0,..LP.S,"(0, 1, 4)",0,0.000014,0,test
83895,DMDAAQ,1,D.DA..,"(1, 4, 5)",0,0.999991,1,test
293577,RILYRR,0,R...RR,"(1, 2, 3)",0,0.000010,0,test
254469,PRMERL,0,P...RL,"(1, 2, 3)",0,0.000003,0,test


In [21]:
filename_predicted= "predicted_" + filename
df_predicted.to_csv(os.path.join(data_dir, filename_predicted), index=None)

In [22]:
filename_predicted

'predicted_classification_SeqLen6_MotifLen3_MotifsNum50.csv'

In [None]:
import re
MOTIFS= df['motif'].unique()
def find_motif(seq):
    # negative matching is first (presence of low fitness motif preculdes further enrichment) 
    for motif in MOTIFS:
        if re.match(motif, seq):
            return  motif 
