In [8]:
import os
import lightning as L
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import TQDMProgressBar, ModelCheckpoint, EarlyStopping

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchmetrics
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchmetrics.classification import (
    AUROC,
)
from torchmetrics import (
    PearsonCorrCoef,
    SpearmanCorrCoef,
    R2Score
)

import ast

# Save figures with a tight bounding box by default.
plt.rcParams.update({"savefig.bbox": "tight"})

# Load and inspect the data tables

In [3]:
# Column names of the raw (headerless) split files.
RAW_COLUMNS = ["peptide", "score", "hla"]

data = pd.read_csv("/home/user11/data/data_processed/data.tsv", sep="\t", names=RAW_COLUMNS)
embeddings_table = pd.read_csv("/home/user11/data/embeddings_proteins/wide_data.tsv", sep="\t")


def load_split(path, embeddings):
    """Read one raw split file and join it with the embeddings table.

    Args:
        path: headerless TSV whose columns are peptide, score and hla.
        embeddings: table to inner-join on (peptide, score, hla).

    Returns:
        (raw, merged): the split with underscores stripped from the HLA names,
        and its inner join with ``embeddings``.
    """
    import warnings

    raw = pd.read_csv(path, sep="\t", names=RAW_COLUMNS)
    # Strip underscores from HLA names (presumably to match the naming used in
    # the embeddings table — confirm).
    raw["hla"] = raw["hla"].str.replace("_", "", regex=False)
    merged = pd.merge(raw, embeddings, on=RAW_COLUMNS)
    # An inner join silently drops rows without an embedding entry (or
    # duplicates rows with several matches); make any change in size visible.
    if len(merged) != len(raw):
        warnings.warn(
            f"{path}: merge with embeddings changed row count {len(raw)} -> {len(merged)}"
        )
    return raw, merged


# Index of the split pair to load (files train{i} / test{i}).
i = 1

train, train_data = load_split(f"/home/user11/data/data_processed/train{i}", embeddings_table)
# NOTE(review): the "test" file is used as the validation set, which later drives
# early stopping, LR scheduling and best-checkpoint selection — metrics on it are
# therefore optimistic if it is also reported as the held-out test score.
val, val_data = load_split(f"/home/user11/data/data_processed/test{i}", embeddings_table)

In [63]:
# Show every column of the first merged training row (label-based lookup of index 0).
train_data.loc[0, :]

peptide                                                DLDKKETVWHLEE
score                                                            0.0
hla                                            HLA-DPA10103-DPB10201
alpha_id                                                    DPA10103
beta_id                                                     DPB10201
alpha_seq          MRPEDRMFHIRAVILRALSLAFLLSLRGAGAIKADHVSTYAAFVQT...
beta_seq           MMVLQVSAAPRTVALTALLMVLLTSVVQGRATPENYLFQGRQECYA...
alpha_path         /home/user11/data/embeddings_proteins/emb_esmc...
beta_path          /home/user11/data/embeddings_proteins/emb_esmc...
interface                         YAFFMFSGGAILNTLFGQFEYFDIEEVRMHLGMT
peptide_path       /home/user11/data/embeddings_proteins/emb_esmc...
alpha_positions    [39, 41, 52, 54, 61, 82, 83, 88, 89, 91, 95, 9...
beta_positions     [37, 39, 38, 52, 54, 56, 73, 83, 93, 96, 97, 1...
Name: 0, dtype: object

In [80]:
class MHCSequenceDataset(Dataset):
    def __init__(self, df):
        self.df = df
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        _, score, _, _, _, _, _, alpha_path, beta_path, _, peptide_path, alpha_positions, beta_positions = train_data.loc[idx]
        
        alpha_embeddings = np.load(alpha_path)[:, ast.literal_eval(alpha_positions), :].squeeze(0)
        beta_embeddings = np.load(beta_path)[:, ast.literal_eval(beta_positions), :].squeeze(0)
        peptide_embeddings = torch.FloatTensor(np.load(peptide_path))
        #binding = torch.tensor(score > 0.5).to(torch.long)

        return torch.FloatTensor(np.concatenate([alpha_embeddings, beta_embeddings], axis=0)), \
                F.pad(peptide_embeddings, (0, 0, (21 - peptide_embeddings.shape[1]) // 2, 21 - peptide_embeddings.shape[1] - (21 - peptide_embeddings.shape[1]) // 2), 'constant', value=0).squeeze(0), \
                torch.tensor(score, dtype=torch.float)

In [None]:
train_dataset = MHCSequenceDataset(train)
val_dataset = MHCSequenceDataset(val)

train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=8)
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=8)

#len(train_dataset), train_dataset[30][1].shape

In [82]:
# dtype of the score tensor (last element of the sample tuple) of the first validation item.
val_dataset[0][-1].dtype

torch.float32

# Models

In [94]:

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a batch of sequences.

    Even feature columns get ``sin(pos * w_k)``, odd columns ``cos(pos * w_k)``,
    with ``w_k = 10000 ** (-2k / dim)``.

    Args:
        dim: feature size of the inputs.
        max_len: longest sequence length the encoding table covers.
    """

    def __init__(self, dim, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd ``dim`` there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: dim // 2])
        # A buffer moves with the module on .to()/.cuda(), so forward() does not copy
        # the table to the device on every call. persistent=False keeps it out of
        # state_dict, so existing checkpoints still load.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, max_len, dim]

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: tensor shaped [batch, seq_len, dim] with ``seq_len <= max_len``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :].to(x.device)

class CrossAttentionIC50Model(nn.Module):
    """Cross-attention regressor from MHC and peptide embeddings to a score in [0, 1].

    MHC positions are the attention queries and the peptide positions are the
    keys/values. The peptide is attended in its given order and reversed, and the
    two outputs are averaged so the prediction does not depend on peptide
    orientation. The result is mean-pooled over MHC positions and fed to an MLP
    ending in a sigmoid.

    Args:
        d_model: embedding size of both inputs.
        nhead: number of attention heads (must divide ``d_model``).
        dim_feedforward: hidden size of the output MLP.
        protein_len: maximum number of MHC positions (positional table size).
        peptide_len: maximum peptide length (positional table size).
    """

    def __init__(self, d_model=1152, nhead=8, dim_feedforward=2048, protein_len=34, peptide_len=21):
        super().__init__()
        self.protein_pos = PositionalEncoding(d_model, max_len=protein_len)
        self.peptide_pos = PositionalEncoding(d_model, max_len=peptide_len)

        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, 1),
            nn.Sigmoid(),  # the IC50 target is normalized to [0, 1]
        )

    def forward(self, protein, peptide, peptide_padding_mask=None):
        """Predict one score per (MHC, peptide) pair.

        Args:
            protein: [B, 34, 1152] MHC embeddings.
            peptide: [B, L, 1152] peptide embeddings, L in [9, 21].
            peptide_padding_mask: optional bool [B, L]; True marks padded peptide
                positions that attention should ignore. At least one position
                per row must stay unmasked. With the default (None) every
                position, including zero padding, is attended.

        Returns:
            [B, 1] tensor of scores in [0, 1].
        """
        # Positional encodings.
        protein = self.protein_pos(protein)
        peptide_fwd = self.peptide_pos(peptide)
        peptide_rev = self.peptide_pos(torch.flip(peptide, dims=[1]))

        # The reversed peptide needs a mask reversed the same way.
        mask_fwd = peptide_padding_mask
        mask_rev = None if peptide_padding_mask is None else torch.flip(peptide_padding_mask, dims=[1])

        # Cross-attention: protein queries, peptide keys/values.
        attn_out_fwd, _ = self.cross_attn(query=protein, key=peptide_fwd, value=peptide_fwd,
                                          key_padding_mask=mask_fwd)
        attn_out_rev, _ = self.cross_attn(query=protein, key=peptide_rev, value=peptide_rev,
                                          key_padding_mask=mask_rev)

        # Orientation invariance: average the forward and reversed passes.
        attn_out = (attn_out_fwd + attn_out_rev) / 2  # [B, 34, 1152]

        # Mean pooling over MHC positions.
        pooled = attn_out.mean(dim=1)  # [B, 1152]

        return self.mlp(pooled)  # [B, 1]


# Create the Lightning module

In [95]:
class LModel(L.LightningModule):
    """Lightning wrapper that trains ``model`` to regress a normalized binding score.

    ``model(mhc_embeddings, peptide_embeddings)`` must return a [B, 1] tensor of
    predicted scores. Training minimizes MSE against the target score. Both
    training and validation also track regression metrics (Pearson, Spearman,
    R²) and a binary AUROC in which the *targets* are binarized at
    ``self.cutoff``.

    Args:
        model: the wrapped network (excluded from the saved hyperparameters).
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
    """

    def __init__(self, model, learning_rate, weight_decay):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        self.model = model

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        self.train_metrics_regression = self._make_metrics_regression("train_")
        self.validation_metrics_regression = self._make_metrics_regression("validation_")
        self.train_metrics_classification = self._make_metrics_classification("train_")
        self.validation_metrics_classification = self._make_metrics_classification("validation_")

        self.loss_fn = nn.MSELoss()

        # Binder threshold on the normalized score: 1 - log(500) / log(50000).
        # NOTE(review): presumably targets are 1 - log(IC50 nM) / log(50000), making
        # this the usual 500 nM binder cutoff — confirm against the preprocessing.
        self.cutoff = 1.0 - np.log(500) / np.log(50000)

    def _make_metrics_classification(self, prefix):
        """Binary AUROC collection; the logged key is ``f"{prefix}auroc"``."""
        return torchmetrics.MetricCollection(
            {"auroc": AUROC(task="binary")},
            prefix=prefix,
        )

    def _make_metrics_regression(self, prefix):
        """Pearson / Spearman / R² collection; keys are ``f"{prefix}pcc"`` etc."""
        return torchmetrics.MetricCollection(
            {
                "pcc": PearsonCorrCoef(),
                "srcc": SpearmanCorrCoef(),
                "r2": R2Score(),
            },
            prefix=prefix,
        )

    def forward(self, mhc_embeddings, peptide_embeddings):
        return self.model(mhc_embeddings, peptide_embeddings)

    def _evaluate(self, batch, stage=None):
        """Compute the loss for one batch and update/log the stage's metrics.

        Args:
            batch: ``(mhc_embeddings, peptide_embeddings, scores)``.
            stage: ``'train'`` or ``'validation'``. Any other value only computes the loss.

        Returns:
            The MSE loss tensor.
        """
        mhc_embeddings, peptide_embeddings, scores = batch
        # Under "16-mixed" precision the model output may be float16; compute the
        # loss and metrics in float32.
        preds = self(mhc_embeddings, peptide_embeddings).squeeze(1).float()
        scores = scores.float()
        loss = self.loss_fn(preds, scores)

        # AUROC takes continuous predictions and integer {0, 1} labels. The cutoff
        # therefore binarizes the *targets*; the predictions stay continuous.
        binary_targets = (scores >= self.cutoff).long()

        if stage == 'validation':
            # Only accumulate state here; epoch-level values are computed, logged
            # and reset in on_validation_epoch_end.
            self.validation_metrics_regression.update(preds, scores)
            self.validation_metrics_classification.update(preds, binary_targets)
            self.log(f"{stage}_loss", loss,
                     on_step=False,
                     on_epoch=True,
                     sync_dist=True,
                     prog_bar=True)
        elif stage == 'train':
            metrics_dict = {f"{stage}_loss": loss}
            # Calling a collection returns this batch's values (and updates its state).
            metrics_dict.update(self.train_metrics_regression(preds, scores))
            metrics_dict.update(self.train_metrics_classification(preds, binary_targets))
            self.log_dict(metrics_dict,
                          on_step=True,
                          on_epoch=False,
                          sync_dist=True,
                          prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._evaluate(batch, stage='train')

    def on_train_epoch_end(self):
        # Train metrics are logged per step; clear accumulated state each epoch.
        self.train_metrics_classification.reset()
        self.train_metrics_regression.reset()

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, stage='validation')

    def on_validation_epoch_end(self):
        # Log validation metrics accumulated over the whole epoch, then reset them.
        self.log_dict(self.validation_metrics_regression.compute(),
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True,
                      prog_bar=True)
        self.validation_metrics_regression.reset()

        self.log_dict(self.validation_metrics_classification.compute(),
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True,
                      prog_bar=True)
        self.validation_metrics_classification.reset()

    def configure_optimizers(self):
        """AdamW, with the LR halved when validation AUROC plateaus for 5 epochs."""
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=self.learning_rate,
                                      weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=5,
            min_lr=1e-6
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "validation_auroc",
                "interval": "epoch",
                "frequency": 1,
            },
        }
    

In [29]:
# Output locations: TensorBoard logs, CSV logs and model checkpoints.
log_path, log_csv_path, checkpoints_path = (
    '/home/user11/results/logs/',
    '/home/user11/results/logs_csv/',
    '/home/user11/results/models/',
)
# Upper bound on training epochs (EarlyStopping may end training earlier).
EPOCHS = 100


In [97]:
model_name = 'CrossAttentionIC50Model'

# Seed Python / NumPy / torch (and DataLoader workers) before building the model
# so that weight initialization and shuffling are reproducible.
L.seed_everything(42, workers=True)

model = LModel(CrossAttentionIC50Model(),
               learning_rate=3e-4,
               weight_decay=1e-3,
               )

logger = pl_loggers.TensorBoardLogger(name=f"{model_name}", save_dir=log_path)
logger_csv = pl_loggers.CSVLogger(name=f"{model_name}", save_dir=log_csv_path)

# Checkpoints are stored under the CSV logger's version number so runs line up.
checkpoint_dir = os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints")

# ModelCheckpoint's default auto_insert_metric_name=True already prefixes every
# "{name}" placeholder with "name=", so placeholders are written bare
# (this yields "model-epoch=03.ckpt", not "model-epoch=epoch=03.ckpt").
# NOTE: this every-epoch callback is not passed to the Trainer below.
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="model-{epoch:02d}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

# Keeps the single best checkpoint by validation AUROC (despite the "iou" name).
# The placeholder must be the logged key "validation_auroc"; Lightning formats a
# placeholder that is not a logged metric as 0.
best_iou_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="best_model-{epoch:02d}-{validation_auroc:.4f}",
    monitor="validation_auroc",
    mode="max",
    save_top_k=1,
    save_on_train_epoch_end=True,
)

early_stop = EarlyStopping(monitor="validation_auroc", patience=10, mode="max")

trainer = L.Trainer(
    max_epochs=EPOCHS,
    devices=[0],
    default_root_dir=os.path.join(checkpoints_path, model_name),
    logger=[logger, logger_csv],
    accelerator="gpu",
    precision="16-mixed",
    callbacks=[best_iou_callback, early_stop, TQDMProgressBar(refresh_rate=1)],
    log_every_n_steps=1
)



Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [98]:
# Fit the Lightning module using the train / validation dataloaders built above.
trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                              | Type                    | Params | Mode 
--------------------------------------------------------------------------------------
0 | model                             | CrossAttentionIC50Model | 7.7 M  | train
1 | train_metrics_regression          | MetricCollection        | 0      | train
2 | validation_metrics_regression     | MetricCollection        | 0      | train
3 | train_metrics_classification      | MetricCollection        | 0      | train
4 | validation_metrics_classification | MetricCollection        | 0      | train
5 | loss_fn                           | MSELoss                 | 0      | train
--------------------------------------------------------------------------------------
7.7 M     Trainable params
0         Non-trainable params
7.7 M     Total params
30.706    Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                       | 0/? [00:00<?, ?it/s]

Training: |                                                              | 0/? [00:00<?, ?it/s]



Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]