In [None]:
import os
import lightning as L
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import TQDMProgressBar, ModelCheckpoint, EarlyStopping

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchmetrics
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchmetrics.classification import (
    AUROC,
)
from torchmetrics import (
    PearsonCorrCoef,
    SpearmanCorrCoef,
    R2Score
)

import ast

plt.rcParams["savefig.bbox"] = 'tight'

# Check data table

In [126]:
data = pd.read_csv("/home/user11/data/data_processed/data.tsv", sep="\t", names=["peptide", "score", "hla"])
embeddings_table = pd.read_csv("/home/user11/data/embeddings_proteins/wide_data.tsv", sep="\t")

i = 1

train = pd.read_csv(f"/home/user11/data/data_processed/train{i}", sep="\t", names=["peptide", "score", "hla"])
train.hla = train.hla.str.replace("_", "")
train_data = pd.merge(train, embeddings_table, on=["peptide", "score", "hla"])

val = pd.read_csv(f"/home/user11/data/data_processed/test{i}", sep="\t", names=["peptide", "score", "hla"])
val.hla = val.hla.str.replace("_", "")
val_data = pd.merge(val, embeddings_table, on=["peptide", "score", "hla"])

In [127]:
train_data.loc[0]

peptide                                                DLDKKETVWHLEE
score                                                            0.0
hla                                            HLA-DPA10103-DPB10201
alpha_id                                                    DPA10103
beta_id                                                     DPB10201
alpha_seq          MRPEDRMFHIRAVILRALSLAFLLSLRGAGAIKADHVSTYAAFVQT...
beta_seq           MMVLQVSAAPRTVALTALLMVLLTSVVQGRATPENYLFQGRQECYA...
alpha_path         /home/user11/data/embeddings_proteins/emb_esmc...
beta_path          /home/user11/data/embeddings_proteins/emb_esmc...
interface                         YAFFMFSGGAILNTLFGQFEYFDIEEVRMHLGMT
peptide_path       /home/user11/data/embeddings_proteins/emb_esmc...
alpha_positions    [39, 41, 52, 54, 61, 82, 83, 88, 89, 91, 95, 9...
beta_positions     [37, 39, 38, 52, 54, 56, 73, 83, 93, 96, 97, 1...
Name: 0, dtype: object

# Базовый датасет

In [230]:
def collate_fn(batch):
    proteins, peptides, lengths, scores = zip(*batch)
    
    proteins = torch.stack(proteins)        # [B, 34, 1152]
    peptides = torch.stack(peptides)        # [B, 21, 1152] — уже паддинг
    lengths = torch.tensor(lengths)         # [B]
    scores = torch.tensor(scores).unsqueeze(1)  # [B, 1]
    
    return proteins, peptides, lengths, scores


class MHCSequenceDataset(Dataset):
    def __init__(self, df):
        self.df = df
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        _, score, _, _, _, _, _, alpha_path, beta_path, _, peptide_path, alpha_positions, beta_positions = self.df.loc[idx]
        
        alpha_embeddings = np.load(alpha_path)[:, ast.literal_eval(alpha_positions), :].squeeze(0)
        beta_embeddings = np.load(beta_path)[:, ast.literal_eval(beta_positions), :].squeeze(0)
        peptide_embeddings = torch.FloatTensor(np.load(peptide_path)).squeeze(0)

        peptide_len = peptide_embeddings.shape[0]

        # Паддинг по центру до 21
        total_pad = 21 - peptide_len
        left_pad = total_pad // 2
        right_pad = total_pad - left_pad
        peptide_padded = F.pad(peptide_embeddings, (0, 0, left_pad, right_pad), 'constant', value=0)
        protein = torch.FloatTensor(np.concatenate([alpha_embeddings, beta_embeddings], axis=0))

        return protein, peptide_padded, peptide_len, torch.tensor(score, dtype=torch.float)
    
    

# Модификация датасет

In [207]:
def collate_fn(batch):
    proteins, peptides, lengths, scores = zip(*batch)
    
    # Convert to tensors
    proteins = torch.stack(proteins)        # [B, 34, 1152]
    peptides = torch.stack(peptides)        # [B, 21, 1152]
    lengths = torch.tensor(lengths)         # [B]
    scores = torch.tensor(scores, dtype=torch.float32).unsqueeze(1)  # [B, 1]
    
    # Create mask for peptides (1 for real, 0 for padding)
    max_len = peptides.size(1)
    mask = torch.arange(max_len).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    mask = mask.float().unsqueeze(-1)  # [B, 21, 1]
    
    return proteins, peptides, lengths, scores, mask

class MHCSequenceDataset(Dataset):
    def __init__(self, df, max_peptide_len=21):
        self.df = df
        self.max_len = max_peptide_len
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load embeddings
        alpha_emb = np.load(row['alpha_path'])[:, ast.literal_eval(row['alpha_positions']), :].squeeze(0)
        beta_emb = np.load(row['beta_path'])[:, ast.literal_eval(row['beta_positions']), :].squeeze(0)
        peptide_emb = torch.FloatTensor(np.load(row['peptide_path'])).squeeze(0)
        
        # Original peptide length (before padding)
        peptide_len = peptide_emb.shape[0]
        
        # Center padding
        total_pad = self.max_len - peptide_len
        left_pad = total_pad // 2
        right_pad = total_pad - left_pad
        peptide_padded = F.pad(peptide_emb, (0, 0, left_pad, right_pad), 'constant', 0)
        
        # Combine protein chains
        protein = torch.FloatTensor(np.concatenate([alpha_emb, beta_emb], axis=0))
        
        return protein, peptide_padded, peptide_len, torch.tensor(row['score'], dtype=torch.float32)

# Задаем датасет и лоадер

In [231]:
train_dataset = MHCSequenceDataset(train_data)
val_dataset = MHCSequenceDataset(val_data)

train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=8, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=8, collate_fn=collate_fn)

#len(train_dataset), train_dataset[30][1].shape

In [218]:
batch = next(iter(val_dataloader))

In [220]:
len(batch)

5

# GNN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import ast
import pandas as pd

from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool, LayerNorm


# ---------- Dataset Class ----------
class ProteinPeptideGNNDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def create_chain_edges(self, length, offset=0):
        edges = []
        for i in range(length - 1):
            edges.append((i + offset, i + 1 + offset))
            edges.append((i + 1 + offset, i + offset))
        return edges

    def create_full_bipartite_edges(self, protein_len, peptide_len):
        edges = []
        for i in range(protein_len):
            for j in range(peptide_len):
                edges.append((i, j + protein_len))
                edges.append((j + protein_len, i))
        return edges

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        _, score, _, _, _, _, _, alpha_path, beta_path, _, peptide_path, alpha_positions, beta_positions = row

        alpha_embeddings = np.load(alpha_path)[:, ast.literal_eval(alpha_positions), :].squeeze(0)
        beta_embeddings = np.load(beta_path)[:, ast.literal_eval(beta_positions), :].squeeze(0)
        protein_embeddings = np.concatenate([alpha_embeddings, beta_embeddings], axis=0)
        peptide_embeddings = np.load(peptide_path).squeeze(0)

        protein_embeddings = torch.FloatTensor(protein_embeddings)
        peptide_embeddings = torch.FloatTensor(peptide_embeddings)

        x = torch.cat([protein_embeddings, peptide_embeddings], dim=0)
        protein_len = protein_embeddings.size(0)
        peptide_len = peptide_embeddings.size(0)

        seq_edges = self.create_chain_edges(protein_len) + self.create_chain_edges(peptide_len, offset=protein_len)
        int_edges = self.create_full_bipartite_edges(protein_len, peptide_len)
        edge_index = torch.tensor(seq_edges + int_edges, dtype=torch.long).t().contiguous()

        peptide_mask = torch.zeros(x.size(0), dtype=torch.bool)
        peptide_mask[protein_len:] = True
        y = torch.tensor([score], dtype=torch.float)

        data = Data(x=x, edge_index=edge_index, y=y)
        data.peptide_mask = peptide_mask
        data.protein_len = protein_len
        data.peptide_len = peptide_len
        return data


# ---------- GNN Model ----------
class GNNModel(nn.Module):
    def __init__(self, input_dim=1152, hidden_dim=128, dropout=0.2):
        super().__init__()

        self.conv1 = GATConv(input_dim, hidden_dim, heads=2, concat=False)
        self.norm1 = LayerNorm(hidden_dim)
        self.conv2 = GATConv(hidden_dim, hidden_dim, heads=2, concat=False)
        self.norm2 = LayerNorm(hidden_dim)

        self.dropout = nn.Dropout(dropout)

        self.out = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1)
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        batch = data.batch
        peptide_mask = data.peptide_mask

        x = F.elu(self.norm1(self.conv1(x, edge_index)))
        x = self.dropout(x)
        x = F.elu(self.norm2(self.conv2(x, edge_index)))
        x = self.dropout(x)

        prot_repr = global_mean_pool(x[~peptide_mask], batch[~peptide_mask])
        pep_repr = global_mean_pool(x[peptide_mask], batch[peptide_mask])

        combined = torch.cat([prot_repr, pep_repr], dim=1)
        return self.out(combined)


In [303]:
train_dataset = ProteinPeptideGNNDataset(train_data)
val_dataset = ProteinPeptideGNNDataset(val_data)

train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=8)
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=8)

#len(train_dataset), train_dataset[30][1].shape

# Models

### 1 (Best)

In [291]:

class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)  # [1, max_len, dim]

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

class CrossAttentionIC50Model(nn.Module):
    def __init__(self, d_model=1152, nhead=8, dim_feedforward=2048, dropout=0.2):
        super().__init__()
        self.protein_pos = PositionalEncoding(d_model, max_len=34)
        self.peptide_pos = PositionalEncoding(d_model, max_len=21)

        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, 1),
            #nn.Sigmoid(),  # т.к. IC50 нормализован от 0 до 1
        )

    def forward(self, protein, peptide):
        """
        protein: [B, 34, 1152]
        peptide: [B, L, 1152] (L ∈ [9, 21])
        """

        B, L, D = peptide.size()

        # позиционная кодировка
        protein = self.protein_pos(protein)
        peptide_fwd = self.peptide_pos(peptide)
        peptide_rev = self.peptide_pos(torch.flip(peptide, dims=[1]))

        # cross-attention (protein queries, peptide keys/values)
        attn_out_fwd, _ = self.cross_attn(query=protein, key=peptide_fwd, value=peptide_fwd)
        attn_out_rev, _ = self.cross_attn(query=protein, key=peptide_rev, value=peptide_rev)

        # Инвариантность ориентации — усреднение
        attn_out = (attn_out_fwd + attn_out_rev) / 2  # [B, 34, 1152]

        # Пулинг по белку (например, mean pooling)
        pooled = attn_out.mean(dim=1)  # [B, 1152]

        return self.mlp(pooled)  # [B, 1]


#### 1.2

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class PositionalEncoding(nn.Module):
#     def __init__(self, dim, max_len=1000):
#         super().__init__()
#         position = torch.arange(0, max_len).unsqueeze(1)
#         div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))
#         pe = torch.zeros(max_len, dim)
#         pe[:, 0::2] = torch.sin(position * div_term)
#         pe[:, 1::2] = torch.cos(position * div_term)
#         self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, dim]

#     def forward(self, x):
#         return x + self.pe[:, :x.size(1), :]

# class CrossAttentionIC50Model(nn.Module):
#     def __init__(self, d_model=1152, nhead=8, dim_feedforward=2048, dropout=0.1):
#         super().__init__()
#         self.protein_pos = PositionalEncoding(d_model, max_len=34)
#         self.peptide_pos = PositionalEncoding(d_model, max_len=21)

#         self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True)

#         self.norm1 = nn.LayerNorm(d_model)
#         self.norm2 = nn.LayerNorm(d_model)

#         self.feedforward = nn.Sequential(
#             nn.Linear(d_model, dim_feedforward),
#             nn.ReLU(),
#             nn.Dropout(dropout),
#             nn.Linear(dim_feedforward, d_model),
#         )

#         self.out = nn.Sequential(
#             nn.Linear(d_model, dim_feedforward // 2),
#             nn.ReLU(),
#             nn.Dropout(dropout),
#             nn.Linear(dim_feedforward // 2, 1),
#             nn.Sigmoid(),  # Output is normalized IC50 ∈ [0, 1]
#         )

#     def forward(self, protein, peptide):
#         """
#         protein: [B, 34, 1152]
#         peptide: [B, L, 1152] (L ∈ [9, 21])
#         """
#         B, L, D = peptide.size()

#         # Positional encoding
#         protein = self.protein_pos(protein)
#         peptide_fwd = self.peptide_pos(peptide)
#         peptide_rev = self.peptide_pos(torch.flip(peptide, dims=[1]))

#         # Cross-attention (protein queries, peptide keys/values)
#         attn_out_fwd, _ = self.cross_attn(query=protein, key=peptide_fwd, value=peptide_fwd)
#         attn_out_rev, _ = self.cross_attn(query=protein, key=peptide_rev, value=peptide_rev)

#         # Orientation invariance: average
#         attn_out = (attn_out_fwd + attn_out_rev) / 2  # [B, 34, D]

#         # Add & Norm
#         x = self.norm1(protein + attn_out)

#         # Feedforward block (like Transformer block)
#         ff_out = self.feedforward(x)
#         x = self.norm2(x + ff_out)

#         # Mean pooling over protein sequence
#         pooled = x.mean(dim=1)  # [B, D]

#         return self.out(pooled)  # [B, 1]


#### 1.3

In [267]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProteinPeptideInteractionModel(nn.Module):
    def __init__(self, embedding_dim=1152, hidden_dim=512, num_heads=4):
        super(ProteinPeptideInteractionModel, self).__init__()
        
        self.peptide_proj = nn.Linear(embedding_dim, hidden_dim)
        self.protein_proj = nn.Linear(embedding_dim, hidden_dim)
        
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)
        
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            #nn.Sigmoid()  # так как выход от 0 до 1
        )
        
    def forward(self, protein, peptide):
        # protein: [B, 34, 1152]
        # peptide: [B, 21, 1152]
        
        # Проецируем эмбеддинги
        protein_proj = self.protein_proj(protein)   # [B, 34, hidden]
        peptide_proj = self.peptide_proj(peptide)   # [B, 21, hidden]
        
        # Кросс-аттеншн: пептид (query) взаимодействует с белком (key, value)
        attn_output, _ = self.cross_attn(query=peptide_proj,
                                         key=protein_proj,
                                         value=protein_proj)
        # Агрегируем: берем среднее по всем позициям пептида
        attn_repr = attn_output.mean(dim=1)        # [B, hidden]
        pep_repr = peptide_proj.mean(dim=1)        # [B, hidden]
        
        combined = torch.cat([attn_repr, pep_repr], dim=1)  # [B, hidden*2]
        
        output = self.fc(combined)  # [B, 1], от 0 до 1
        
        return output


### 2

In [190]:
class ImprovedCrossAttention(nn.Module):
    def __init__(self, d_model=1152, nhead=8, dim_feedforward=2048):
        super().__init__()
        self.protein_pos = PositionalEncoding(d_model, max_len=34)
        self.peptide_pos = PositionalEncoding(d_model, max_len=21)
        
        # Multi-head attention с layer norm
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # FFN с residual connection
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(dim_feedforward, d_model)
        )
        
        # Final MLP с улучшениями
        self.mlp = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(dim_feedforward, 1),
            nn.Sigmoid()
        )

    def forward(self, protein, peptide):
        # Позиционные кодировки
        protein = self.protein_pos(protein)
        peptide = self.peptide_pos(peptide)
        
        # Cross-attention с residual
        attn_out, _ = self.attn(
            query=protein,
            key=peptide,
            value=peptide,
            need_weights=False
        )
        attn_out = self.norm1(protein + attn_out)
        
        # FFN с residual
        ffn_out = self.ffn(attn_out)
        ffn_out = self.norm2(attn_out + ffn_out)
        
        # Усреднение с инвариантностью к ориентации
        rev_out, _ = self.attn(
            query=protein,
            key=torch.flip(peptide, [1]),
            value=torch.flip(peptide, [1]),
            need_weights=False
        )
        combined = (ffn_out + rev_out) / 2
        
        # Улучшенный пулинг
        pooled = combined.max(dim=1)[0]  # Max pooling вместо mean
        return self.mlp(pooled)

### 3

In [175]:


class PairwiseEnergyModel(nn.Module):
    def __init__(self, embed_dim=1152, hidden_dim=512, dropout=0.3):
        super(PairwiseEnergyModel, self).__init__()
        # MLP для вычисления энергии E(i,j) для пар (protein[i], peptide[j])
        self.energy_mlp = nn.Sequential(
            nn.Linear(embed_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )
        # Финальный слой для предсказания IC50
        self.final_mlp = nn.Sequential(
            nn.Linear(34 * 21, hidden_dim),  # 34 (protein) x 21 (peptide)
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, proteins, peptides, lengths):
        batch_size = proteins.size(0)
        protein_len, peptide_max_len = 34, 21  # Фиксированные размеры из датасета

        # Подготовка для pairwise-взаимодействий
        proteins = proteins.unsqueeze(2).expand(-1, -1, peptide_max_len, -1)  # [B, 34, 21, 1152]
        peptides = peptides.unsqueeze(1).expand(-1, protein_len, -1, -1)     # [B, 34, 21, 1152]
        
        # Конкатенация парного представления
        pairwise_input = torch.cat([proteins, peptides], dim=-1)  # [B, 34, 21, 1152*2]
        
        # Вычисление энергий E(i,j)
        energies = self.energy_mlp(pairwise_input).squeeze(-1)  # [B, 34, 21]
        
        # Softmax для нормировки энергий
        energies = F.softmax(energies.view(batch_size, -1), dim=-1)  # [B, 34*21]
        
        # Финальное предсказание
        score = self.final_mlp(energies)  # [B, 1]
        return score

# # Пример использования
# model = PairwiseEnergyModel(embed_dim=1152, hidden_dim=512, dropout=0.3)
# proteins, peptides, lengths, scores = next(iter(DataLoader(dataset, batch_size=32, collate_fn=collate_fn)))
# output = model(proteins, peptides, lengths)  # [B, 1]

### 4

In [173]:

class DualEncoderModel(nn.Module):
    def __init__(self, embed_dim=1152, num_heads=8, num_layers=2, hidden_dim=512, dropout=0.3):
        super(DualEncoderModel, self).__init__()
        # Трансформер-энкодер для белка
        self.protein_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout),
            num_layers=num_layers
        )
        # Трансформер-энкодер для пептида
        self.peptide_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout),
            num_layers=num_layers
        )
        # Interaction Head
        self.interaction_head = nn.Sequential(
            nn.Linear(embed_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, proteins, peptides, lengths):
        batch_size = proteins.size(0)
        
        # Трансформер для белка: [B, 34, 1152] -> [B, 34, 1152]
        protein_out = self.protein_encoder(proteins)  # [B, 34, 1152]
        protein_out = protein_out.mean(dim=1)  # Pooling: [B, 1152]
        
        # Трансформер для пептида: [B, 21, 1152] -> [B, 21, 1152]
        peptide_out = self.peptide_encoder(peptides)  # [B, 21, 1152]
        peptide_out = peptide_out.mean(dim=1)  # Pooling: [B, 1152]
        
        # Объединение представлений
        interaction = torch.cat([protein_out, peptide_out], dim=-1)  # [B, 1152*2]
        
        # Предсказание IC50
        score = self.interaction_head(interaction)  # [B, 1]
        return score

# # Пример использования
# model = DualEncoderModel(embed_dim=1152, num_heads=8, num_layers=2, hidden_dim=512, dropout=0.3)
# proteins, peptides, lengths, scores = next(iter(DataLoader(dataset, batch_size=32, collate_fn=collate_fn)))
# output = model(proteins, peptides, lengths)  # [B, 1]

### 5

In [186]:
class CNNInteractionModel(nn.Module):
    def __init__(self, embed_dim=1152, num_filters=128, kernel_size=3, hidden_dim=512, dropout=0.3):
        super(CNNInteractionModel, self).__init__()
        
        # CNN-энкодер для белка
        self.protein_cnn = nn.Sequential(
            nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=kernel_size, padding=kernel_size//2),
            nn.ReLU(),
            nn.BatchNorm1d(num_filters),
            nn.Conv1d(in_channels=num_filters, out_channels=num_filters, kernel_size=kernel_size, padding=kernel_size//2),
            nn.ReLU(),
            nn.BatchNorm1d(num_filters),
            nn.AdaptiveMaxPool1d(1)  # Global Max Pooling
        )
        
        # CNN-энкодер для пептида
        self.peptide_cnn = nn.Sequential(
            nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=kernel_size, padding=kernel_size//2),
            nn.ReLU(),
            nn.BatchNorm1d(num_filters),
            nn.Conv1d(in_channels=num_filters, out_channels=num_filters, kernel_size=kernel_size, padding=kernel_size//2),
            nn.ReLU(),
            nn.BatchNorm1d(num_filters),
            nn.AdaptiveMaxPool1d(1)  # Global Max Pooling
        )
        
        # Interaction Head
        self.interaction_head = nn.Sequential(
            nn.Linear(num_filters * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, proteins, peptides, lengths):
        batch_size = proteins.size(0)
        
        # Подготовка данных: [B, 34, 1152] -> [B, 1152, 34] для CNN
        proteins = proteins.permute(0, 2, 1)  # [B, 1152, 34]
        peptides = peptides.permute(0, 2, 1)  # [B, 1152, 21]
        
        # CNN для белка
        protein_out = self.protein_cnn(proteins).squeeze(-1)  # [B, num_filters]
        
        # CNN для пептида
        peptide_out = self.peptide_cnn(peptides).squeeze(-1)  # [B, num_filters]
        
        # Объединение представлений
        interaction = torch.cat([protein_out, peptide_out], dim=-1)  # [B, num_filters*2]
        
        # Предсказание IC50
        score = self.interaction_head(interaction)  # [B, 1]
        return score

# Пример использования
# model = CNNInteractionModel(embed_dim=1152, num_filters=128, kernel_size=3, hidden_dim=512, dropout=0.3)
# proteins, peptides, lengths, scores = next(iter(DataLoader(dataset, batch_size=32, collate_fn=collate_fn)))
# output = model(proteins, peptides, lengths)  # [B, 1]

### 6

# Create lightning module

In [319]:
class LModel(L.LightningModule):
    def __init__(self, model, learning_rate, weight_decay):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        self.model = model

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        self.train_metrics_regression = self._make_metrics_regression("train_")
        self.validation_metrics_regression = self._make_metrics_regression("validation_")
        self.train_metrics_classification = self._make_metrics_classification("train_")
        self.validation_metrics_classification = self._make_metrics_classification("validation_")

        #self.loss_fn = nn.MSELoss()
        #self.loss_fn = nn.HuberLoss()
        self.loss_fn = nn.BCEWithLogitsLoss()

        self.cutoff = 1.0 - np.log(500) / np.log(50000)

    def _make_metrics_classification(self, prefix):
        metrics = torchmetrics.MetricCollection(
            {
               "auroc": AUROC(num_classes=2, task="binary")
            },
            prefix=prefix)
        return metrics

    def _make_metrics_regression(self, prefix):
        metrics = torchmetrics.MetricCollection(
            {
                "pcc": PearsonCorrCoef(),
                #"srcc": SpearmanCorrCoef(),  
                "r2": R2Score(),             
            },
            prefix=prefix)
        return metrics


    # def forward(self, mhc_embeddings, peptide_embeddings):
    #     return self.model(mhc_embeddings, peptide_embeddings)

    def forward(self, data):
        return self.model(data)


    def _evaluate(self, batch, stage=None):

        # mhc_embeddings, peptide_embeddings, length, scores = batch
        # binary_scores = (scores >= self.cutoff).float()
        # logits = self.forward(mhc_embeddings, peptide_embeddings)#.squeeze()        
        # probs = logits.sigmoid()
        # #loss = self.loss_fn(logits, binary_scores) # For BCE
        # loss = self.loss_fn(probs, scores) # For regression

        scores = batch.y
        binary_scores = (scores >= self.cutoff).float()
        logits = self.forward(batch).squeeze()
        probs = logits.sigmoid()
        #loss = self.loss_fn(probs, scores)
        loss = self.loss_fn(logits, binary_scores) # For BCE


        metrics_dict = {f"{stage}_loss": loss}

        if stage == 'validation':
            val_metrics_regression = self.validation_metrics_regression(probs, scores)
            val_metrics_classification = self.validation_metrics_classification(probs, binary_scores)
            metrics_dict.update(val_metrics_regression)
            metrics_dict.update(val_metrics_classification)
        elif stage == 'train':
            train_metrics_regression = self.train_metrics_regression(probs, scores)
            train_metrics_classification = self.train_metrics_classification(probs, binary_scores)
            metrics_dict.update(train_metrics_regression)
            metrics_dict.update(train_metrics_classification)

            self.log_dict(metrics_dict, 
                          on_step=True, 
                          on_epoch=False, 
                          sync_dist=True, 
                          prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        loss = self._evaluate(batch, stage='train')
        return loss

    def on_train_epoch_end(self):
        self.train_metrics_classification.reset()
        self.train_metrics_regression.reset()

    def validation_step(self, batch, batch_idx):
        _ = self._evaluate(batch, stage='validation')        

    def on_validation_epoch_end(self):
        # Логируем валидационные метрики
        self.log_dict(self.validation_metrics_regression.compute(), 
                      on_step=False, 
                      on_epoch=True, 
                      sync_dist=True, 
                      prog_bar=True)
        self.validation_metrics_regression.reset()

        self.log_dict(self.validation_metrics_classification.compute(), 
                      on_step=False, 
                      on_epoch=True, 
                      sync_dist=True, 
                      prog_bar=True)
        self.validation_metrics_classification.reset()


    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), 
                                      lr=self.learning_rate, 
                                      weight_decay=self.weight_decay)
        

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,  # или конкретное число, например 50
            eta_min=1e-6
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 20,
            }
        }
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        #     optimizer,
        #     mode='max',
        #     factor=0.5,
        #     patience=5,
        #     min_lr=1e-6
        # )
        
        # return {
        #     "optimizer": optimizer,
        #     "lr_scheduler": {
        #         "scheduler": scheduler,
        #         "monitor": "validation_auroc",
        #         "interval": "epoch",
        #         "frequency": 1,
        #     },
        # }
    

In [165]:
log_path = '/home/user11/results/logs/'
log_csv_path = '/home/user11/results/logs_csv/'
checkpoints_path = '/home/user11/results/models/'
EPOCHS = 100


In [320]:

#model_name = 'CrossAttentionPairwiseModel'


# model_name = 'PairwiseEnergyModel'
# obj_model = PairwiseEnergyModel(embed_dim=1152, hidden_dim=512, dropout=0.3)

# model_name = 'DualEncoderModel'
# obj_model = DualEncoderModel(embed_dim=1152, num_heads=4, num_layers=2, hidden_dim=256, dropout=0.3)

# model_name = 'CNNInteractionModel'
# obj_model = CNNInteractionModel(embed_dim=1152, num_filters=128, kernel_size=3, hidden_dim=512, dropout=0.3)

# model_name = 'ImprovedCrossAttention'
# obj_model = ImprovedCrossAttention(d_model=1152, nhead=8, dim_feedforward=2048)
    
# model_name = 'CrossAttentionIC50Model'
# obj_model = CrossAttentionIC50Model()

# model_name = 'CrossAttentionIC50Model'
# obj_model = CrossAttentionIC50Model(d_model=1152, nhead=4, dim_feedforward=2048, dropout=0.2)


# best
# model_name = 'ProteinPeptideInteractionModel'
# obj_model = ProteinPeptideInteractionModel(embedding_dim=1152, hidden_dim=512, num_heads=4)

model_name = 'GNNModel'
obj_model = GNNModel(input_dim=1152, hidden_dim=256, dropout=0.3)


model = LModel(obj_model,
               learning_rate=1e-4,
               weight_decay=1e-3,
               )

logger = pl_loggers.TensorBoardLogger(name=f"{model_name}", save_dir=log_path)
logger_csv = pl_loggers.CSVLogger(name=f"{model_name}", save_dir=log_csv_path)


checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="model-epoch={epoch:02d}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

best_iou_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="best_model_epoch={epoch:02d}-auroc={validation_auroc:.4f}",
    monitor="validation_auroc",
    mode="max",
    save_top_k=1,
    save_on_train_epoch_end=True,
)

early_stop = EarlyStopping(monitor="validation_auroc", patience=10, mode="max")

trainer = L.Trainer(
    max_epochs=EPOCHS,
    devices=[0],
    default_root_dir=f'{checkpoints_path}/{model_name}',
    logger=[logger, logger_csv],
    accelerator="gpu",
    precision="16-mixed",
    callbacks=[best_iou_callback, early_stop, TQDMProgressBar(refresh_rate=1)],
    log_every_n_steps=1
)



Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                              | Type              | Params | Mode 
--------------------------------------------------------------------------------
0 | model                             | GNNModel          | 790 K  | train
1 | train_metrics_regression          | MetricCollection  | 0      | train
2 | validation_metrics_regression     | MetricCollection  | 0      | train
3 | train_metrics_classification      | MetricCollection  | 0      | train
4 | validation_metrics_classification | MetricCollection  | 0      | train
5 | loss_fn                           | BCEWithLogitsLoss | 0      | train
--------------------------------------------------------------------------------
790 K     Trainable params
0         Non-trainable params
790 K     Total params
3.161     Total estimated model params size (MB)
26        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                       | 0/? [00:00<?, ?it/s]

Training: |                                                              | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

In [318]:
trainer.callback_metrics

{'train_loss': tensor(0.0150),
 'train_pcc': tensor(0.7141),
 'train_r2': tensor(0.4999),
 'train_auroc': tensor(0.8438),
 'validation_pcc': tensor(0.6301),
 'validation_r2': tensor(0.3795),
 'validation_auroc': tensor(0.8139)}