In [179]:
# ! pip install pytorch-metric-learning

In [186]:
import os
import lightning as L
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import TQDMProgressBar, ModelCheckpoint, EarlyStopping

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchmetrics
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchmetrics.classification import (
    AUROC,
    Accuracy
)
from torchmetrics import (
    PearsonCorrCoef,
    SpearmanCorrCoef,
    R2Score
)

import ast

import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from pytorch_metric_learning.samplers import MPerClassSampler
import torch.nn.functional as F
from ast import literal_eval


from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.miners import TripletMarginMiner


plt.rcParams["savefig.bbox"] = 'tight'

# Загружаем данные

In [172]:
data = pd.read_csv("/home/user11/data/data_processed/data.tsv", sep="\t", names=["peptide", "score", "hla"])
#embeddings_table = pd.read_csv("/home/user11/data/embeddings_proteins/wide_data.tsv", sep="\t")
embeddings_table = pd.read_csv("/home/user11/data/embeddings_proteins/wide_data_32float.tsv", sep="\t")

i = 1

train = pd.read_csv(f"/home/user11/data/data_processed/train{i}", sep="\t", names=["peptide", "score", "hla"])
train.hla = train.hla.str.replace("_", "")
train_data = pd.merge(train, embeddings_table, on=["peptide", "score", "hla"])
train_data = train_data[train_data['peptide'].apply(len) <= 21]
train_data = train_data[train_data['peptide'].apply(len) >= 9]


val = pd.read_csv(f"/home/user11/data/data_processed/test{i}", sep="\t", names=["peptide", "score", "hla"])
val.hla = val.hla.str.replace("_", "")
val_data = pd.merge(val, embeddings_table, on=["peptide", "score", "hla"])
val_data = val_data[val_data['peptide'].apply(len) <= 21]
val_data = val_data[val_data['peptide'].apply(len) >= 9]

test_data = pd.read_csv("/home/user11/data/embeddings_proteins/wide_data_test.tsv", sep="\t")
test_data = pd.read_csv("/home/user11/data/embeddings_proteins/wide_data_test_32float.tsv", sep="\t")
print(len(test_data))
test_data = test_data[test_data['peptide'].apply(len) <= 21]
test_data = test_data[test_data['peptide'].apply(len) >= 9]
test_data = test_data.reset_index(drop=True)
print(len(test_data))


2052
2009


In [190]:
from collections import Counter

filter_ = []
for k, val in Counter(embeddings_table['hla']).items():
    if val <= 3:
        filter_.append(k)
filter_ = set(filter_)

In [193]:
import pandas as pd
from sklearn.model_selection import train_test_split
from collections import Counter

embeddings_table = pd.read_csv("/home/user11/data/embeddings_proteins/greate_data.tsv", sep="\t")


filter_ = []
for k, val in Counter(embeddings_table['hla']).items():
    if val <= 5:
        filter_.append(k)
filter_ = set(filter_)

embeddings_table = embeddings_table[[True if not i in filter_ else False for i in embeddings_table['hla']]]

# Сначала делим на train и temp (временно для val и test)
train_data, temp_df = train_test_split(embeddings_table, 
                                       test_size=0.3, 
                                       random_state=42, 
                                       shuffle=True,
                                       stratify=embeddings_table['hla'])

# Теперь делим temp на val и test
val_data, test_data = train_test_split(temp_df, 
                                       test_size=0.5, 
                                       random_state=42, 
                                       shuffle=True,
                                       stratify=temp_df['hla'])

# Проверка размеров
print(f"Train size: {len(train_data)}")
print(f"Validation size: {len(val_data)}")
print(f"Test size: {len(test_data)}")


Train size: 25216
Validation size: 5403
Test size: 5404


### Применение PCA + SVM

In [5]:
def pracess_table(data):
    featuries = []
    labels = []
    for index, row in data.iterrows():
        alpha_path = row['alpha_path']
        beta_path = row['beta_path']
        peptide_path = row['peptide_path']
        # alpha_positions = row['alpha_positions']
        # beta_positions = row['beta_positions']
        score = row['score']

        alpha_embeddings = np.load(alpha_path)[:,1:-1].squeeze(0).mean(axis=0)
        beta_embeddings = np.load(beta_path)[:,1:-1].squeeze(0).mean(axis=0)
        peptide_embeddings = torch.FloatTensor(np.load(peptide_path))[:,1:-1].squeeze(0).mean(axis=0)
        featuries.append(np.concat([alpha_embeddings, beta_embeddings, peptide_embeddings]))
        labels.append(score > 0.496)
    return featuries, labels



In [8]:
test_features, test_labels = pracess_table(test_data)
train_features, train_labels = pracess_table(train_data)

In [None]:
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
)

# Предобработка
# test_features, test_labels = pracess_table(test_data)
# train_features, train_labels = pracess_table(train_data)

# Масштабирование признаков
scaler = StandardScaler()
train_scaled = scaler.fit_transform(train_features)
test_scaled = scaler.transform(test_features)

# PCA: уменьшаем размерность (например, до 20 компонентов — можно изменить)
pca = PCA(n_components=20)
train_pca = pca.fit_transform(train_scaled)
test_pca = pca.transform(test_scaled)

# Обучение SVM
model = SVC(kernel='rbf', C=1.0, gamma='scale', probability=True)
model.fit(train_pca, train_labels)

# Предсказания
predictions = model.predict(test_pca)
probas = model.predict_proba(test_pca)

# AUC ROC (только для бинарной классификации)
if len(model.classes_) == 2:
    auc = roc_auc_score(test_labels, probas[:, 1])
    print("AUC ROC:", auc)
else:
    print("AUC ROC доступен только для бинарной классификации.")

# Метрики
print("Accuracy:", accuracy_score(test_labels, predictions))
print("Confusion Matrix:\n", confusion_matrix(test_labels, predictions))
print("Classification Report:\n", classification_report(test_labels, predictions))


In [None]:
def is_exists(path):
    return os.path.exists(path)

: 

In [None]:
np.sum(test_data['peptide_path'].apply(is_exists)) == len(test_data)

np.True_

: 

In [None]:
np.load(train_data.iloc[10]['peptide_path'])

array([[[ 0.01533592,  0.04636236, -0.02179497, ...,  0.00812898,
          0.01680294, -0.01285871],
        [-0.00108992,  0.02830932,  0.00663236, ..., -0.00385393,
         -0.01219377,  0.01918724],
        [-0.00031568,  0.03600387, -0.02954051, ...,  0.01349472,
         -0.01393486, -0.00657987],
        ...,
        [-0.00483224,  0.0498512 , -0.0490288 , ..., -0.00414214,
          0.01086233, -0.00033358],
        [-0.01201398,  0.05209634, -0.0242684 , ..., -0.01939956,
         -0.00191659, -0.00660642],
        [-0.0449412 ,  0.0454271 ,  0.00334062, ...,  0.00191115,
          0.00306173, -0.00519145]]], shape=(1, 15, 1152), dtype=float32)

: 

### Пути и глобальные переменные

In [173]:
log_path = '/home/user11/results/logs/'
log_csv_path = '/home/user11/results/logs_csv/'
checkpoints_path = '/home/user11/results/models/'
EPOCHS = 100


## Работа с моделями на основе Attention

### Базовый датасет

In [216]:
def collate_fn(batch):
    proteins, peptides, lengths, scores = zip(*batch)
    
    proteins = torch.stack(proteins)        # [B, 34, 1152]
    peptides = torch.stack(peptides)        # [B, 21, 1152] — уже паддинг
    lengths = torch.tensor(lengths)         # [B]
    scores = torch.tensor(scores)#.unsqueeze(1)  # [B, 1]
    
    return proteins, peptides, lengths, scores


class MHCSequenceDataset(Dataset):
    def __init__(self, df):
        self.df = df
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.loc[idx]
        score = row['score']
        alpha_path = row['alpha_path']
        beta_path = row['beta_path']
        peptide_path = row['peptide_path']
        alpha_positions = row['alpha_positions']
        beta_positions = row['beta_positions']
                
        alpha_embeddings = np.load(alpha_path)[:, ast.literal_eval(alpha_positions), :].squeeze(0)
        beta_embeddings = np.load(beta_path)[:, ast.literal_eval(beta_positions), :].squeeze(0)
        peptide_embeddings = torch.FloatTensor(np.load(peptide_path))[:,1:-1].squeeze(0)

        peptide_len = peptide_embeddings.shape[0]

        # Паддинг по центру до 21
        total_pad = 21 - peptide_len
        left_pad = total_pad // 2
        right_pad = total_pad - left_pad
        peptide_padded = F.pad(peptide_embeddings, (0, 0, left_pad, right_pad), 'constant', value=0)
        protein = torch.FloatTensor(np.concatenate([alpha_embeddings, beta_embeddings], axis=0))

        # peptide_mask = torch.zeros(peptide_padded.size(0), dtype=torch.bool)
        # peptide_mask[peptide_len:] = True

        return protein, peptide_padded, peptide_len, torch.tensor(score, dtype=torch.float)
    
    

###  Датасет с маской

In [224]:
def collate_fn(batch):
    proteins, peptides, masks, scores = zip(*batch)

    proteins = torch.stack(proteins)        # [B, 34, 1152]
    peptides = torch.stack(peptides)        # [B, 21, 1152]
    masks = torch.stack(masks)              # [B, 21] — attention mask
    scores = torch.tensor(scores)           # [B]

    return proteins, peptides, masks, scores


class MHCSequenceDataset(Dataset):
    def __init__(self, df):
        self.df = df
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        score = row['score']
        alpha_path = row['alpha_path']
        beta_path = row['beta_path']
        peptide_path = row['peptide_path']
        alpha_positions = row['alpha_positions']
        beta_positions = row['beta_positions']
                
        alpha_embeddings = np.load(alpha_path)[:, ast.literal_eval(alpha_positions), :].squeeze(0)
        beta_embeddings = np.load(beta_path)[:, ast.literal_eval(beta_positions), :].squeeze(0)
        peptide_embeddings = torch.FloatTensor(np.load(peptide_path))[:,1:-1].squeeze(0)

        peptide_len = peptide_embeddings.shape[0]

        # Паддинг по центру до 21
        total_pad = 21 - peptide_len
        left_pad = total_pad // 2
        right_pad = total_pad - left_pad
        peptide_padded = F.pad(peptide_embeddings, (0, 0, left_pad, right_pad), 'constant', value=0)

        # Маска: 1 — реальные токены, 0 — паддинг
        mask = torch.zeros(21, dtype=torch.bool)
        mask[left_pad:left_pad + peptide_len] = True

        protein = torch.FloatTensor(np.concatenate([alpha_embeddings, beta_embeddings], axis=0))

        return protein, peptide_padded, mask, torch.tensor(score, dtype=torch.float)


### Датасет с использованием маски для данных, где эмбендинг считался по белку эпитопа

In [209]:
def collate_fn(batch):
    proteins, peptides, masks, scores = zip(*batch)

    proteins = torch.stack(proteins)        # [B, 34, 1152]
    peptides = torch.stack(peptides)        # [B, 21, 1152]
    masks = torch.stack(masks)              # [B, 21] — attention mask
    scores = torch.tensor(scores)           # [B]

    return proteins, peptides, masks, scores


class MHCSequenceDataset(Dataset):
    def __init__(self, df):
        self.df = df
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        score = row['score']
        alpha_path = row['alpha_path']
        beta_path = row['beta_path']
        peptide_path = row['epitop_full_sequence_emb_path']
        start_epitop = int(row['start_epitop'] + 1) # +1 к как первый вектор технический
        end_epitop = int(row['end_epitop'] + 1)# +1 к как первый вектор технический
        alpha_positions = row['alpha_positions']
        beta_positions = row['beta_positions']
                

        alpha_embeddings = np.load(alpha_path)[:, ast.literal_eval(alpha_positions), :].squeeze(0)
        beta_embeddings = np.load(beta_path)[:, ast.literal_eval(beta_positions), :].squeeze(0)
        peptide_embeddings = torch.FloatTensor(np.load(peptide_path)[:, start_epitop:end_epitop, :].squeeze(0))

        peptide_len = peptide_embeddings.shape[0]

        # Паддинг по центру до 21
        total_pad = 21 - peptide_len
        left_pad = total_pad // 2
        right_pad = total_pad - left_pad
        peptide_padded = F.pad(peptide_embeddings, (0, 0, left_pad, right_pad), 'constant', value=0)

        # Маска: 1 — реальные токены, 0 — паддинг
        mask = torch.zeros(21, dtype=torch.bool)
        mask[left_pad:left_pad + peptide_len] = True

        protein = torch.FloatTensor(np.concatenate([alpha_embeddings, beta_embeddings], axis=0))

        return protein, peptide_padded, mask, torch.tensor(score, dtype=torch.float)


### Обьявляем наши датасеты

#### Получен из базы данных + была подготовка эмбендингов и др-е.

In [195]:
BATCH_SIZE = 512

train_dataset = MHCSequenceDataset(train_data)
val_dataset = MHCSequenceDataset(val_data)
test_dataset = MHCSequenceDataset(test_data)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, collate_fn=collate_fn)

#len(train_dataset), train_dataset[30][1].shape

#### Был извлечен эмбенддинг для целого белка эпитопа

In [225]:
BATCH_SIZE = 512

train_dataset = MHCSequenceDataset(train_data)
val_dataset = MHCSequenceDataset(val_data)
test_dataset = MHCSequenceDataset(test_data)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, collate_fn=collate_fn)

#len(train_dataset), train_dataset[30][1].shape

In [220]:
train_data.iloc[0]['epitop_full_sequence_emb_path']

'/home/user11/data/embeddings_proteins/emb_esmc_full_epitopes_650m//Q40960.npy'

In [200]:
peptide_embeddings = np.load(train_data.iloc[0]['epitop_full_sequence_emb_path'])
peptide_embeddings.shape

(1, 314, 1152)

In [202]:
peptide_embeddings[:,1:100,:]

array([[[ 0.03336792,  0.03972925, -0.00214513, ..., -0.02084067,
          0.00078157, -0.02744721],
        [ 0.00825546,  0.05325219, -0.00796509, ..., -0.01479157,
          0.01584252, -0.03361255],
        [-0.03011555,  0.05421186,  0.01487346, ..., -0.02082266,
          0.03606271, -0.02635617],
        ...,
        [ 0.00183306,  0.00074739,  0.03007515, ...,  0.01752948,
          0.01187115, -0.02337635],
        [-0.0229768 ,  0.03160378, -0.01471132, ...,  0.01785925,
          0.00329863, -0.01348982],
        [ 0.00669834,  0.02358298,  0.00841137, ...,  0.00549607,
          0.00671421, -0.00211226]]], shape=(1, 99, 1152), dtype=float32)

In [212]:
batch = next(iter(train_dataloader))

In [None]:
batch[1].shape

torch.Size([512, 21, 1152])

: 

### Модели с attention

#### 1) ProteinPeptideInteractionModel

In [222]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProteinPeptideInteractionModel(nn.Module):
    def __init__(self, embedding_dim=1152, hidden_dim=512, num_heads=4, dropout=0.3):
        super(ProteinPeptideInteractionModel, self).__init__()
        
        self.peptide_proj = nn.Linear(embedding_dim, hidden_dim)
        self.protein_proj = nn.Linear(embedding_dim, hidden_dim)
        
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, 
                                                num_heads=num_heads, 
                                                batch_first=True,
                                                dropout=dropout)
        
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        
    def forward(self, protein, peptide):
        # protein: [B, 34, 1152]
        # peptide: [B, 21, 1152]
        
        # Проецируем эмбеддинги
        protein_proj = self.protein_proj(protein)   # [B, 34, hidden]
        peptide_proj = self.peptide_proj(peptide)   # [B, 21, hidden]
        
        # Кросс-аттеншн: пептид (query) взаимодействует с белком (key, value)
        attn_output, _ = self.cross_attn(query=peptide_proj,
                                         key=protein_proj,
                                         value=protein_proj)

        # Агрегируем: берем среднее по всем позициям пептида
        attn_repr = attn_output.mean(dim=1)        # [B, hidden]
        pep_repr = peptide_proj.mean(dim=1)        # [B, hidden]
        
        combined = torch.cat([attn_repr, pep_repr], dim=1)  # [B, hidden*2]
        
        output = self.fc(combined)
        
        return output
    

# 2
class ProteinPeptideInteractionModelMask(nn.Module):
    def __init__(self, embedding_dim=1152, hidden_dim=512, num_heads=4, dropout=0.3):
        super(ProteinPeptideInteractionModelMask, self).__init__()
        
        self.peptide_proj = nn.Linear(embedding_dim, hidden_dim)
        self.protein_proj = nn.Linear(embedding_dim, hidden_dim)
        
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, 
                                                num_heads=num_heads, 
                                                batch_first=True,
                                                dropout=dropout)
        
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        
    def forward(self, protein, peptide, peptide_mask):
        # protein: [B, 34, 1152]
        # peptide: [B, 21, 1152]
        # peptide_mask: [B, 21] — bool tensor, True = настоящий токен, False = паддинг
        
        # Проецируем
        protein_proj = self.protein_proj(protein)   # [B, 34, hidden]
        peptide_proj = self.peptide_proj(peptide)   # [B, 21, hidden]
        
        # Кросс-аттеншн: пептид (query) ↔ белок (key/value)
        attn_output, _ = self.cross_attn(
            query=peptide_proj,
            key=protein_proj,
            value=protein_proj,
            key_padding_mask=None  # можно добавить, если у белка будет паддинг
        )

        # Маскированное среднее по пептиду
        mask = peptide_mask.unsqueeze(-1).float()  # [B, 21, 1]
        attn_repr = (attn_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        pep_repr = (peptide_proj * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        
        combined = torch.cat([attn_repr, pep_repr], dim=1)  # [B, hidden*2]
        
        output = self.fc(combined)  # [B, 1]
        return output


# 3
class ProteinPeptideInteractionModelWithSmartGate(nn.Module):
    def __init__(self, embedding_dim=1152, hidden_dim=512, num_heads=4, dropout=0.3):
        super(ProteinPeptideInteractionModelWithSmartGate, self).__init__()
        
        self.peptide_proj = nn.Linear(embedding_dim, hidden_dim)
        self.protein_proj = nn.Linear(embedding_dim, hidden_dim)
        
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, 
                                                num_heads=num_heads, 
                                                batch_first=True,
                                                dropout=dropout)
        
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        
        self.gate  = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, protein, peptide):
        # protein: [B, 34, 1152]
        # peptide: [B, 21, 1152]
        
        # Проецируем эмбеддинги
        protein_proj = self.protein_proj(protein)   # [B, 34, hidden]
        peptide_proj = self.peptide_proj(peptide)   # [B, 21, hidden]
        
        # Кросс-аттеншн: пептид (query) взаимодействует с белком (key, value)
        attn_output, attn_weights = self.cross_attn(query=peptide_proj,
                                         key=protein_proj,
                                         value=protein_proj,
                                        average_attn_weights=False)
        # Агрегируем: берем среднее по всем позициям пептида
        attn_repr = attn_output.mean(dim=1)        # [B, hidden]
        pep_repr = peptide_proj.mean(dim=1)        # [B, hidden]
        
        combined = torch.cat([attn_repr, pep_repr], dim=1)  # [B, hidden*2]
        alpha, beta = self.gate(combined).chunk(chunks=2, dim=-1)

        combined_2 = torch.cat([attn_repr * alpha, pep_repr * beta], dim=1)
        output = self.fc(combined_2)

        return output
    

# 4
class ProteinPeptideInteractionModelWithSmartGateMask(nn.Module):
    def __init__(self, embedding_dim=1152, hidden_dim=512, num_heads=4, dropout=0.3):
        super(ProteinPeptideInteractionModelWithSmartGateMask, self).__init__()
        
        self.peptide_proj = nn.Linear(embedding_dim, hidden_dim)
        self.protein_proj = nn.Linear(embedding_dim, hidden_dim)
        
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, 
            num_heads=num_heads, 
            batch_first=True,
            dropout=dropout
        )
        
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, protein, peptide, peptide_mask):
        """
        protein:       [B, 34, 1152]
        peptide:       [B, 21, 1152]
        peptide_mask:  [B, 21] — bool tensor (True = настоящий токен, False = паддинг)
        """
        # Проекции
        protein_proj = self.protein_proj(protein)   # [B, 34, hidden]
        peptide_proj = self.peptide_proj(peptide)   # [B, 21, hidden]
        
        # Кросс-аттеншн
        attn_output, _ = self.cross_attn(
            query=peptide_proj,
            key=protein_proj,
            value=protein_proj,
            key_padding_mask=None,  # не нужен, так как белок фиксированной длины
            average_attn_weights=False
        )

        # Маска: [B, 21, 1]
        mask = peptide_mask.unsqueeze(-1).float()         # float: True → 1.0, False → 0.0
        mask_sum = mask.sum(dim=1).clamp(min=1.0)         # избегаем деления на 0

        # Маскированное среднее
        attn_repr = (attn_output * mask).sum(dim=1) / mask_sum  # [B, hidden]
        pep_repr  = (peptide_proj * mask).sum(dim=1) / mask_sum # [B, hidden]

        # Smart Gate
        combined = torch.cat([attn_repr, pep_repr], dim=1)         # [B, hidden*2]
        alpha, beta = self.gate(combined).chunk(2, dim=-1)         # [B, hidden], [B, hidden]
        gated = torch.cat([attn_repr * alpha, pep_repr * beta], dim=1)  # [B, hidden*2]

        return self.fc(gated)  # [B, 1]
    


class ProteinPeptideInteractionModelWithSmartGateAndWeighting(nn.Module):
    def __init__(self, embedding_dim=1152, hidden_dim=512, num_heads=4, dropout=0.3):
        super(ProteinPeptideInteractionModelWithSmartGateAndWeighting, self).__init__()
        
        self.peptide_proj = nn.Linear(embedding_dim, hidden_dim)
        self.protein_proj = nn.Linear(embedding_dim, hidden_dim)
        
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, 
                                                num_heads=num_heads, 
                                                batch_first=True,
                                                dropout=dropout)
        
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        
        self.gate  = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2)
        )

        self.peptide_importance = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 1)
        )

    def forward(self, protein, peptide, peptide_mask):
        # protein: [B, 34, 1152]
        # peptide: [B, 21, 1152]
        
        # Проецируем эмбеддинги
        protein_proj = self.protein_proj(protein)   # [B, 34, hidden]
        peptide_proj = self.peptide_proj(peptide)   # [B, 21, hidden]
        
        # Кросс-аттеншн: пептид (query) взаимодействует с белком (key, value)
        attn_output, attn_weights = self.cross_attn(query=peptide_proj,
                                        key=protein_proj,
                                        value=protein_proj,
                                        average_attn_weights=False)
        # Агрегируем: берем среднее по всем позициям пептида
        attn_repr = attn_output.mean(dim=1)        # [B, hidden]

        pep_weights = self.peptide_importance(peptide_proj) 
        pep_repr = (peptide_proj * pep_weights.softmax(dim=1)).sum(dim=1) # [B, hidden]

        combined = torch.cat([attn_repr, pep_repr], dim=1)  # [B, hidden*2]
        alpha, beta = self.gate(combined).chunk(chunks=2, dim=-1)

        combined_2 = torch.cat([attn_repr * alpha, pep_repr * beta], dim=1)
        output = self.fc(combined_2)

        return output



#### 2) CrossAttentionModel (Модель с масками и позиционным кодированием)

In [26]:

class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)  # [1, max_len, dim]

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

class CrossAttentionModel(nn.Module):
    def __init__(self, d_model=1152, nhead=8, dim_feedforward=2048, dropout=0.2):
        super().__init__()
        self.protein_pos = PositionalEncoding(d_model, max_len=34)
        self.peptide_pos = PositionalEncoding(d_model, max_len=21)

        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, 1),
        )

    def forward(self, protein, peptide, peptide_mask):

        #B, L, D = peptide.size()

        # позиционная кодировка
        protein = self.protein_pos(protein)
        peptide_fwd = self.peptide_pos(peptide)
        peptide_rev = self.peptide_pos(torch.flip(peptide, dims=[1]))

        peptide_mask_rev = torch.flip(peptide_mask, dims=[1])

        # cross-attention (protein queries, peptide keys/values)
        attn_out_fwd, _ = self.cross_attn(
            query=protein,
            key=peptide_fwd,
            value=peptide_fwd,
            key_padding_mask=~peptide_mask  # игнорировать паддинг
        )

        attn_out_rev, _ = self.cross_attn(
            query=protein,
            key=peptide_rev,
            value=peptide_rev,
            key_padding_mask=~peptide_mask_rev
        )

        # Инвариантность ориентации — усреднение
        attn_out = (attn_out_fwd + attn_out_rev) / 2  # [B, 34, 1152]

        # Пулинг по белку
        pooled = attn_out.mean(dim=1)  # [B, 1152]

        return self.mlp(pooled)  # [B, 1]


In [140]:
class ProteinPeptideInteractionTransformer(nn.Module):
    def __init__(self, embedding_dim=1152, hidden_dim=512, embedding_out_dim=128, num_heads=8, dropout=0.3):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # Проекция
        self.protein_proj = nn.Linear(embedding_dim, hidden_dim)
        self.peptide_proj = nn.Linear(embedding_dim, hidden_dim)

        # Первый слой cross-attention
        self.cross_attn1 = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_dim)

        # Второй слой cross-attention
        self.cross_attn2 = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Агрегация и объединение представлений
        self.encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, embedding_out_dim)
        )

        # Предсказание биндинга
        self.score_head = nn.Sequential(
            nn.Linear(embedding_out_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, protein, peptide):
        """
        protein: [B, P_len, 1152]
        peptide: [B, pep_len, 1152]
        Returns:
            embedding: [B, embedding_out_dim]
            score:     [B]
        """
        B = protein.shape[0]

        # Проекция эмбеддингов
        protein_proj = self.protein_proj(protein)     # [B, P_len, hidden_dim]
        peptide_proj = self.peptide_proj(peptide)     # [B, pep_len, hidden_dim]

        # Первый cross-attention
        attn_out1, _ = self.cross_attn1(query=peptide_proj, key=protein_proj, value=protein_proj)
        attn_out1 = self.norm1(attn_out1 + peptide_proj)  # skip-connection

        # Второй cross-attention
        attn_out2, _ = self.cross_attn2(query=attn_out1, key=protein_proj, value=protein_proj)
        attn_out2 = self.norm2(attn_out2 + attn_out1)     # skip-connection

        # Усреднение
        peptide_repr = attn_out2.mean(dim=1)              # [B, hidden_dim]
        protein_repr = protein_proj.mean(dim=1)           # [B, hidden_dim]

        # Объединение
        combined = torch.cat([protein_repr, peptide_repr], dim=1)  # [B, hidden_dim * 2]

        # Получение эмбеддинга
        embedding = self.encoder(combined)                # [B, embedding_out_dim]
        embedding = F.normalize(embedding, p=2, dim=1)

        # Предсказание бинарного класса
        score = self.score_head(embedding).squeeze(-1)    # [B]

        return score


### Класс Lightning  для обучения и логирования модели

In [None]:

class LModelA(L.LightningModule):
    def __init__(self, model, learning_rate, weight_decay):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        self.model = model

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Метрики для каждой стадии
        self.train_metrics_regression = self._make_metrics_regression("train_")
        self.validation_metrics_regression = self._make_metrics_regression("validation_")
        self.test_metrics_regression = self._make_metrics_regression("test_")

        self.train_metrics_classification = self._make_metrics_classification("train_")
        self.validation_metrics_classification = self._make_metrics_classification("validation_")
        self.test_metrics_classification = self._make_metrics_classification("test_")

        # self.loss_fn = nn.MSELoss()
        self.loss_fn = nn.HuberLoss()
        #self.loss_fn = nn.BCEWithLogitsLoss()

        self.cutoff = 1.0 - np.log(500) / np.log(50000)

    def _make_metrics_classification(self, prefix):
        metrics = torchmetrics.MetricCollection(
            {
                "auroc": AUROC(task="binary"),
                "accuracy": Accuracy(task="binary")
            },
            prefix=prefix
        )
        return metrics

    def _make_metrics_regression(self, prefix):
        metrics = torchmetrics.MetricCollection(
            {
                "pcc": PearsonCorrCoef(),
                "r2": R2Score()
            },
            prefix=prefix
        )
        return metrics


    def forward(self, mhc_embeddings, peptide_embeddings):
        return self.model(mhc_embeddings, peptide_embeddings)
        

    def _evaluate(self, batch, stage=None):

        mhc_embeddings, peptide_embeddings, _, scores = batch
        binary_scores = (scores >= self.cutoff).float()
        logits = self.forward(mhc_embeddings, peptide_embeddings).squeeze()        
        probs = logits.sigmoid()
        #loss = self.loss_fn(logits, binary_scores) # For BCE
        loss = self.loss_fn(probs, scores) # For regression

        metrics_dict = {f"{stage}_loss": loss}

        if stage == 'train':
            metrics_dict.update(self.train_metrics_regression(probs, scores))
            metrics_dict.update(self.train_metrics_classification(probs, binary_scores))
        elif stage == 'validation':
            metrics_dict.update(self.validation_metrics_regression(probs, scores))
            metrics_dict.update(self.validation_metrics_classification(probs, binary_scores))
        elif stage == 'test':
            metrics_dict.update(self.test_metrics_regression(probs, scores))
            metrics_dict.update(self.test_metrics_classification(probs, binary_scores))

        self.log_dict(metrics_dict, 
                      on_step=(stage == 'train'), 
                      on_epoch=True, 
                      prog_bar=True, 
                      sync_dist=True,
                      batch_size=BATCH_SIZE)

        return loss

    def training_step(self, batch, batch_idx):
        return self._evaluate(batch, stage='train')

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, stage='validation')

    def test_step(self, batch, batch_idx):
        self._evaluate(batch, stage='test')

    def on_train_epoch_end(self):
        self.train_metrics_classification.reset()
        self.train_metrics_regression.reset()

    def on_validation_epoch_end(self):
        self.log_dict(self.validation_metrics_regression.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.validation_metrics_regression.reset()

        self.log_dict(self.validation_metrics_classification.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.validation_metrics_classification.reset()

    def on_test_epoch_end(self):
        self.log_dict(self.test_metrics_regression.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.test_metrics_regression.reset()

        self.log_dict(self.test_metrics_classification.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.test_metrics_classification.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), 
                                      lr=self.learning_rate, 
                                      weight_decay=self.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 20,
            }
        }


### Класс Lightning  для обучения и логирования модели (с маской)

In [177]:

class LModelA(L.LightningModule):
    def __init__(self, model, learning_rate, weight_decay):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        self.model = model

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Метрики для каждой стадии
        self.train_metrics_regression = self._make_metrics_regression("train_")
        self.validation_metrics_regression = self._make_metrics_regression("validation_")
        self.test_metrics_regression = self._make_metrics_regression("test_")

        self.train_metrics_classification = self._make_metrics_classification("train_")
        self.validation_metrics_classification = self._make_metrics_classification("validation_")
        self.test_metrics_classification = self._make_metrics_classification("test_")

        # self.loss_fn = nn.MSELoss()
        self.loss_fn = nn.HuberLoss()
        #self.loss_fn = nn.BCEWithLogitsLoss()

        self.cutoff = 1.0 - np.log(500) / np.log(50000)

    def _make_metrics_classification(self, prefix):
        metrics = torchmetrics.MetricCollection(
            {
                "auroc": AUROC(task="binary"),
                "accuracy": Accuracy(task="binary")
            },
            prefix=prefix
        )
        return metrics

    def _make_metrics_regression(self, prefix):
        metrics = torchmetrics.MetricCollection(
            {
                "pcc": PearsonCorrCoef(),
                "r2": R2Score()
            },
            prefix=prefix
        )
        return metrics


    def forward(self, mhc_embeddings, peptide_embeddings, mask):
        return self.model(mhc_embeddings, peptide_embeddings, mask)
        

    def _evaluate(self, batch, stage=None):

        mhc_embeddings, peptide_embeddings, mask, scores = batch
        binary_scores = (scores >= self.cutoff).float()
        logits = self.forward(mhc_embeddings, peptide_embeddings, mask).squeeze()        
        probs = logits.sigmoid()
        #loss = self.loss_fn(logits, binary_scores) # For BCE
        loss = self.loss_fn(probs, scores) # For regression

        metrics_dict = {f"{stage}_loss": loss}

        if stage == 'train':
            metrics_dict.update(self.train_metrics_regression(probs, scores))
            metrics_dict.update(self.train_metrics_classification(probs, binary_scores))
        elif stage == 'validation':
            metrics_dict.update(self.validation_metrics_regression(probs, scores))
            metrics_dict.update(self.validation_metrics_classification(probs, binary_scores))
        elif stage == 'test':
            metrics_dict.update(self.test_metrics_regression(probs, scores))
            metrics_dict.update(self.test_metrics_classification(probs, binary_scores))

        self.log_dict(metrics_dict, 
                      on_step=(stage == 'train'), 
                      on_epoch=True, 
                      prog_bar=True, 
                      sync_dist=True,
                      batch_size=BATCH_SIZE)

        return loss

    def training_step(self, batch, batch_idx):
        return self._evaluate(batch, stage='train')

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, stage='validation')

    def test_step(self, batch, batch_idx):
        self._evaluate(batch, stage='test')

    def on_train_epoch_end(self):
        self.train_metrics_classification.reset()
        self.train_metrics_regression.reset()

    def on_validation_epoch_end(self):
        self.log_dict(self.validation_metrics_regression.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.validation_metrics_regression.reset()

        self.log_dict(self.validation_metrics_classification.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.validation_metrics_classification.reset()

    def on_test_epoch_end(self):
        self.log_dict(self.test_metrics_regression.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.test_metrics_regression.reset()

        self.log_dict(self.test_metrics_classification.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.test_metrics_classification.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), 
                                      lr=self.learning_rate, 
                                      weight_decay=self.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 20,
            }
        }


### Обучаем модели с Attention

In [226]:
# best
model_name = 'ProteinPeptideInteractionModelMask'
obj_model = ProteinPeptideInteractionModelMask(embedding_dim=1152, hidden_dim=512, num_heads=4, dropout=0.3)


model = LModelA(obj_model, 
                learning_rate=3e-4,
                weight_decay=1e-4)

logger = pl_loggers.TensorBoardLogger(name=f"{model_name}", save_dir=log_path)
logger_csv = pl_loggers.CSVLogger(name=f"{model_name}", save_dir=log_csv_path)


checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="model-epoch={epoch:02d}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

best_iou_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="best_model_epoch={epoch:02d}-auroc={validation_auroc:.4f}",
    monitor="validation_auroc",
    mode="max",
    save_top_k=1,
    save_on_train_epoch_end=True,
)

early_stop = EarlyStopping(monitor="validation_auroc", patience=10, mode="max")

trainer = L.Trainer(
    max_epochs=EPOCHS,
    devices=[0],
    default_root_dir=f'{checkpoints_path}/{model_name}',
    logger=[logger, logger_csv],
    accelerator="gpu",
    precision="32",
    callbacks=[best_iou_callback, early_stop, TQDMProgressBar(refresh_rate=1)],
    log_every_n_steps=1
)


trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
#trainer.test(model, dataloaders=test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                              | Type                               | Params | Mode 
-------------------------------------------------------------------------------------------------
0 | model                             | ProteinPeptideInteractionModelMask | 2.8 M  | train
1 | train_metrics_regression          | MetricCollection                   | 0      | train
2 | validation_metrics_regression     | MetricCollection                   | 0      | train
3 | test_metrics_regression           | MetricCollection                   | 0      | train
4 | train_metrics_classification      | MetricCollection                   | 0      | train
5 | validation_metrics_classification | MetricCollection                   | 0      | train
6 | test_metrics_classification       | MetricCollection                   | 0      | train
7

Sanity Checking: |                                                       | 0/? [00:00<?, ?it/s]

Training: |                                                              | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [215]:
best_model_path = best_iou_callback.best_model_path
model = LModelA.load_from_checkpoint(best_model_path, model=obj_model)
trainer.test(model, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: |                                                               | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.6502590775489807
       test_auroc           0.7415729761123657
        test_loss          0.016708191484212875
        test_pcc            0.5298307538032532
         test_r2            0.27862781286239624
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.016708191484212875,
  'test_pcc': 0.5298307538032532,
  'test_r2': 0.27862781286239624,
  'test_accuracy': 0.6502590775489807,
  'test_auroc': 0.7415729761123657}]

In [184]:
model_name = 'ProteinPeptideInteractionModelWithSmartGateAndWeighting'
obj_model = ProteinPeptideInteractionModelWithSmartGateAndWeighting(embedding_dim=1152, hidden_dim=512, num_heads=8, dropout=0.3)


model = LModelA(obj_model, 
                learning_rate=1e-4,
                weight_decay=1e-4)

logger = pl_loggers.TensorBoardLogger(name=f"{model_name}", save_dir=log_path)
logger_csv = pl_loggers.CSVLogger(name=f"{model_name}", save_dir=log_csv_path)


checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="model-epoch={epoch:02d}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

best_iou_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="best_model_epoch={epoch:02d}-auroc={validation_auroc:.4f}",
    monitor="validation_auroc",
    mode="max",
    save_top_k=1,
    save_on_train_epoch_end=True,
)

early_stop = EarlyStopping(monitor="validation_auroc", patience=10, mode="max")

trainer = L.Trainer(
    max_epochs=EPOCHS,
    devices=[0],
    default_root_dir=f'{checkpoints_path}/{model_name}',
    logger=[logger, logger_csv],
    accelerator="gpu",
    precision="32",
    callbacks=[best_iou_callback, early_stop, TQDMProgressBar(refresh_rate=1)],
    log_every_n_steps=1
)


trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
#trainer.test(model, dataloaders=test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                              | Type                                                    | Params | Mode 
----------------------------------------------------------------------------------------------------------------------
0 | model                             | ProteinPeptideInteractionModelWithSmartGateAndWeighting | 3.3 M  | train
1 | train_metrics_regression          | MetricCollection                                        | 0      | train
2 | validation_metrics_regression     | MetricCollection                                        | 0      | train
3 | test_metrics_regression           | MetricCollection                                        | 0      | train
4 | train_metrics_classification      | MetricCollection                                        | 0      | train
5 | validation_metrics_classification 

Sanity Checking: |                                                       | 0/? [00:00<?, ?it/s]

Training: |                                                              | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

In [185]:
best_model_path = best_iou_callback.best_model_path
model = LModelA.load_from_checkpoint(best_model_path, model=obj_model)
trainer.test(model, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: |                                                               | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.7292185425758362
       test_auroc           0.5902154445648193
        test_loss          0.038490358740091324
        test_pcc            0.2269923835992813
         test_r2           -0.40101897716522217
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.038490358740091324,
  'test_pcc': 0.2269923835992813,
  'test_r2': -0.40101897716522217,
  'test_accuracy': 0.7292185425758362,
  'test_auroc': 0.5902154445648193}]

In [29]:
# best
model_name = 'CrossAttentionModel'
obj_model = CrossAttentionModel(d_model=1152, nhead=8, dim_feedforward=2048, dropout=0.2)


model = LModelA(obj_model, 
                learning_rate=3e-4,
                weight_decay=1e-4)

logger = pl_loggers.TensorBoardLogger(name=f"{model_name}", save_dir=log_path)
logger_csv = pl_loggers.CSVLogger(name=f"{model_name}", save_dir=log_csv_path)


checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="model-epoch={epoch:02d}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

best_iou_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="best_model_epoch={epoch:02d}-auroc={validation_auroc:.4f}",
    monitor="validation_auroc",
    mode="max",
    save_top_k=1,
    save_on_train_epoch_end=True,
)

early_stop = EarlyStopping(monitor="validation_auroc", patience=10, mode="max")

trainer = L.Trainer(
    max_epochs=EPOCHS,
    devices=[0],
    default_root_dir=f'{checkpoints_path}/{model_name}',
    logger=[logger, logger_csv],
    accelerator="gpu",
    precision="32",
    callbacks=[best_iou_callback, early_stop, TQDMProgressBar(refresh_rate=1)],
    log_every_n_steps=1
)


trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
#trainer.test(model, dataloaders=test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                              | Type                | Params | Mode 
----------------------------------------------------------------------------------
0 | model                             | CrossAttentionModel | 7.7 M  | train
1 | train_metrics_regression          | MetricCollection    | 0      | train
2 | validation_metrics_regression     | MetricCollection    | 0      | train
3 | test_metrics_regression           | MetricCollection    | 0      | train
4 | train_metrics_classification      | MetricCollection    | 0      | train
5 | validation_metrics_classification | MetricCollection    | 0      | train
6 | test_metrics_classification       | MetricCollection    | 0      | train
7 | loss_fn                           | HuberLoss           | 0      | train
-----------------------------------------------------------

Sanity Checking: |                                                       | 0/? [00:00<?, ?it/s]

Training: |                                                              | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

In [31]:
best_model_path = best_iou_callback.best_model_path
obj_model = CrossAttentionModel(d_model=1152, nhead=8, dim_feedforward=2048, dropout=0.2)
model = LModelA.load_from_checkpoint(best_model_path, model=obj_model)
trainer.test(model, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: |                                                               | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.7376804351806641
       test_auroc           0.5802174210548401
        test_loss           0.03555288910865784
        test_pcc            0.24202650785446167
         test_r2           -0.29383599758148193
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.03555288910865784,
  'test_pcc': 0.24202650785446167,
  'test_r2': -0.29383599758148193,
  'test_accuracy': 0.7376804351806641,
  'test_auroc': 0.5802174210548401}]

In [165]:
model_name = 'TensorFusionInteractionModel'
obj_model = TensorFusionInteractionModel(embedding_dim=1152, hidden_dim=512, dropout=0.2)


model = LModelA(obj_model, 
                learning_rate=3e-4,
                weight_decay=1e-3)

logger = pl_loggers.TensorBoardLogger(name=f"{model_name}", save_dir=log_path)
logger_csv = pl_loggers.CSVLogger(name=f"{model_name}", save_dir=log_csv_path)


checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="model-epoch={epoch:02d}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

best_iou_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="best_model_epoch={epoch:02d}-auroc={validation_auroc:.4f}",
    monitor="validation_auroc",
    mode="max",
    save_top_k=1,
    save_on_train_epoch_end=True,
)

early_stop = EarlyStopping(monitor="validation_auroc", patience=10, mode="max")

trainer = L.Trainer(
    max_epochs=EPOCHS,
    devices=[0],
    default_root_dir=f'{checkpoints_path}/{model_name}',
    logger=[logger, logger_csv],
    accelerator="gpu",
    precision="32",
    callbacks=[best_iou_callback, early_stop, TQDMProgressBar(refresh_rate=1)],
    log_every_n_steps=1
)


trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
best_model_path = best_iou_callback.best_model_path
model = LModelA.load_from_checkpoint(best_model_path, model=obj_model)
#trainer.test(model, dataloaders=test_dataloader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                              | Type                         | Params | Mode 
-------------------------------------------------------------------------------------------
0 | model                             | TensorFusionInteractionModel | 1.4 M  | train
1 | train_metrics_regression          | MetricCollection             | 0      | train
2 | validation_metrics_regression     | MetricCollection             | 0      | train
3 | test_metrics_regression           | MetricCollection             | 0      | train
4 | train_metrics_classification      | MetricCollection             | 0      | train
5 | validation_metrics_classification | MetricCollection             | 0      | train
6 | test_metrics_classification       | MetricCollection             | 0      | train
7 | loss_fn                           | HuberLoss      

Sanity Checking: |                                                       | 0/? [00:00<?, ?it/s]

Training: |                                                              | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


In [166]:
best_model_path = best_iou_callback.best_model_path
model = LModelA.load_from_checkpoint(best_model_path, model=obj_model)
trainer.test(model, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: |                                                               | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.6993529200553894
       test_auroc           0.5953564643859863
        test_loss          0.039996568113565445
        test_pcc            0.2469816356897354
         test_r2           -0.45708346366882324
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.039996568113565445,
  'test_pcc': 0.2469816356897354,
  'test_r2': -0.45708346366882324,
  'test_accuracy': 0.6993529200553894,
  'test_auroc': 0.5953564643859863}]

# Использование моделей на основе графов

## Датасет для графов

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import ast
import pandas as pd

from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, GATConv, global_mean_pool, LayerNorm


class ProteinPeptideGNNDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def create_chain_edges(self, length, offset=0):
        edges = []
        for i in range(length - 1):
            edges.append((i + offset, i + 1 + offset))
            edges.append((i + 1 + offset, i + offset))
        return edges

    def create_full_bipartite_edges(self, protein_len, peptide_len):
        edges = []
        for i in range(protein_len):
            for j in range(peptide_len):
                edges.append((i, j + protein_len))
                edges.append((j + protein_len, i))
        return edges

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        score = row['score']
        alpha_path = row['alpha_path']
        beta_path = row['beta_path']
        peptide_path = row['peptide_path']
        alpha_positions = row['alpha_positions']
        beta_positions = row['beta_positions']

        alpha_embeddings = np.load(alpha_path)[:, ast.literal_eval(alpha_positions), :].squeeze(0)
        beta_embeddings = np.load(beta_path)[:, ast.literal_eval(beta_positions), :].squeeze(0)
        protein_embeddings = np.concatenate([alpha_embeddings, beta_embeddings], axis=0)
        peptide_embeddings = np.load(peptide_path)[:,1:-1].squeeze(0)

        protein_embeddings = torch.FloatTensor(protein_embeddings)
        peptide_embeddings = torch.FloatTensor(peptide_embeddings)

        x = torch.cat([protein_embeddings, peptide_embeddings], dim=0)
        protein_len = protein_embeddings.size(0)
        peptide_len = peptide_embeddings.size(0)

        seq_edges = self.create_chain_edges(protein_len) + self.create_chain_edges(peptide_len, offset=protein_len)
        int_edges = self.create_full_bipartite_edges(protein_len, peptide_len)
        edge_index = torch.tensor(seq_edges + int_edges, dtype=torch.long).t().contiguous()

        peptide_mask = torch.zeros(x.size(0), dtype=torch.bool)
        peptide_mask[protein_len:] = True
        y = torch.tensor([score], dtype=torch.float)

        data = Data(x=x, edge_index=edge_index, y=y)
        data.peptide_mask = peptide_mask
        data.protein_len = protein_len
        data.peptide_len = peptide_len
        return data

: 

### Объявляем датасеты

In [None]:
BATCH_SIZE = 512

train_dataset = ProteinPeptideGNNDataset(train_data)
val_dataset = ProteinPeptideGNNDataset(val_data)
test_dataset = ProteinPeptideGNNDataset(test_data)


train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)


: 

### Модели на графах

#### 1)

In [None]:
# ---------- GNN Model ----------
class GNNModel(nn.Module):
    def __init__(self, input_dim=1152, hidden_dim=128, dropout=0.2):
        super().__init__()

        self.conv1 = GATConv(input_dim, hidden_dim, heads=2, concat=False)
        self.norm1 = LayerNorm(hidden_dim)
        self.conv2 = GATConv(hidden_dim, hidden_dim, heads=2, concat=False)
        self.norm2 = LayerNorm(hidden_dim)

        self.dropout = nn.Dropout(dropout)

        self.out = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1)
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        batch = data.batch
        peptide_mask = data.peptide_mask

        x = F.elu(self.norm1(self.conv1(x, edge_index)))
        x = self.dropout(x)
        x = F.elu(self.norm2(self.conv2(x, edge_index)))
        x = self.dropout(x)

        prot_repr = global_mean_pool(x[~peptide_mask], batch[~peptide_mask])
        pep_repr = global_mean_pool(x[peptide_mask], batch[peptide_mask])

        combined = torch.cat([prot_repr, pep_repr], dim=1)
        score = self.out(combined)

        return score




: 

#### 2)

In [None]:
class GINModel(nn.Module):
    def __init__(self, input_dim=1152, hidden_dim=128, dropout=0.3):
        super().__init__()

        # MLP для GINConv (обычно состоит из 2-х линейных слоёв)
        def gin_mlp(in_dim, out_dim):
            return nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.ReLU(),
                nn.Linear(out_dim, out_dim)
            )

        self.gin1 = GINConv(gin_mlp(input_dim, hidden_dim))
        self.norm1 = LayerNorm(hidden_dim)

        self.gin2 = GINConv(gin_mlp(hidden_dim, hidden_dim))
        self.norm2 = LayerNorm(hidden_dim)

        self.dropout = nn.Dropout(dropout)

        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1)
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        batch = data.batch
        peptide_mask = data.peptide_mask

        x = F.relu(self.norm1(self.gin1(x, edge_index)))
        x = self.dropout(x)

        x = F.relu(self.norm2(self.gin2(x, edge_index)))
        x = self.dropout(x)

        prot_repr = global_mean_pool(x[~peptide_mask], batch[~peptide_mask])
        pep_repr  = global_mean_pool(x[peptide_mask], batch[peptide_mask])

        combined = torch.cat([prot_repr, pep_repr], dim=1)
        score = self.mlp(combined)
        return score

: 

### Lightning для графов

In [None]:

class LModelGM(L.LightningModule):
    def __init__(self, model, learning_rate, weight_decay):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        self.model = model

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Метрики для каждой стадии
        self.train_metrics_regression = self._make_metrics_regression("train_")
        self.validation_metrics_regression = self._make_metrics_regression("validation_")
        self.test_metrics_regression = self._make_metrics_regression("test_")

        self.train_metrics_classification = self._make_metrics_classification("train_")
        self.validation_metrics_classification = self._make_metrics_classification("validation_")
        self.test_metrics_classification = self._make_metrics_classification("test_")

        # self.loss_fn = nn.MSELoss()
        self.loss_fn = nn.HuberLoss()
        #self.loss_fn = nn.BCEWithLogitsLoss()

        self.cutoff = 1.0 - np.log(500) / np.log(50000)

    def _make_metrics_classification(self, prefix):
        metrics = torchmetrics.MetricCollection(
            {
                "auroc": AUROC(task="binary"),
                "accuracy": Accuracy(task="binary")
            },
            prefix=prefix
        )
        return metrics

    def _make_metrics_regression(self, prefix):
        metrics = torchmetrics.MetricCollection(
            {
                "pcc": PearsonCorrCoef(),
                "r2": R2Score()
            },
            prefix=prefix
        )
        return metrics


    def forward(self, batch):
        return self.model(batch)

    def _evaluate(self, batch, stage=None):
        scores = batch.y
        binary_scores = (scores >= self.cutoff).float()
        logits = self.forward(batch)
        probs = logits.sigmoid()        
        loss = self.loss_fn(probs, scores)

        metrics_dict = {f"{stage}_loss": loss}
        if stage == 'train':
            metrics_dict.update(self.train_metrics_regression(probs, scores))
            metrics_dict.update(self.train_metrics_classification(probs, binary_scores))
        elif stage == 'validation':
            metrics_dict.update(self.validation_metrics_regression(probs, scores))
            metrics_dict.update(self.validation_metrics_classification(probs, binary_scores))
        elif stage == 'test':
            metrics_dict.update(self.test_metrics_regression(probs, scores))
            metrics_dict.update(self.test_metrics_classification(probs, binary_scores))

        self.log_dict(metrics_dict, 
                      on_step=(stage == 'train'), 
                      on_epoch=True, 
                      prog_bar=True, 
                      sync_dist=True,
                      batch_size=BATCH_SIZE)

        return loss

    def training_step(self, batch, batch_idx):
        return self._evaluate(batch, stage='train')

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, stage='validation')

    def test_step(self, batch, batch_idx):
        self._evaluate(batch, stage='test')

    def on_train_epoch_end(self):
        self.train_metrics_classification.reset()
        self.train_metrics_regression.reset()

    def on_validation_epoch_end(self):
        self.log_dict(self.validation_metrics_regression.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.validation_metrics_regression.reset()

        self.log_dict(self.validation_metrics_classification.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.validation_metrics_classification.reset()

    def on_test_epoch_end(self):
        self.log_dict(self.test_metrics_regression.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.test_metrics_regression.reset()

        self.log_dict(self.test_metrics_classification.compute(),
                      on_step=False, on_epoch=True, 
                      prog_bar=True, sync_dist=True, 
                      batch_size=BATCH_SIZE)
        self.test_metrics_classification.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), 
                                      lr=self.learning_rate, 
                                      weight_decay=self.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 20,
            }
        }


: 

In [None]:
model_name = 'GNNModel'
obj_model = GNNModel(input_dim=1152, hidden_dim=128, dropout=0.2)

model = LModelGM(obj_model, 
                 learning_rate=3e-4, 
                 weight_decay=1e-3)

logger = pl_loggers.TensorBoardLogger(name=f"{model_name}", save_dir=log_path)
logger_csv = pl_loggers.CSVLogger(name=f"{model_name}", save_dir=log_csv_path)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="model-epoch={epoch:02d}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

best_iou_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="best_model_epoch={epoch:02d}-auroc={validation_auroc:.4f}",
    monitor="validation_auroc",
    mode="max",
    save_top_k=1,
    save_on_train_epoch_end=True,
)

early_stop = EarlyStopping(monitor="validation_auroc", patience=10, mode="max")

trainer = L.Trainer(
    max_epochs=EPOCHS,
    devices=[0],
    default_root_dir=f'{checkpoints_path}/{model_name}',
    logger=[logger, logger_csv],
    accelerator="gpu",
    precision="16-mixed",
    callbacks=[best_iou_callback, early_stop, TQDMProgressBar(refresh_rate=1)],
    log_every_n_steps=1
)


trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
trainer.test(model, dataloaders=test_dataloader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                              | Type             | Params | Mode 
-------------------------------------------------------------------------------
0 | model                             | GNNModel         | 362 K  | train
1 | train_metrics_regression          | MetricCollection | 0      | train
2 | validation_metrics_regression     | MetricCollection | 0      | train
3 | test_metrics_regression           | MetricCollection | 0      | train
4 | train_metrics_classification      | MetricCollection | 0      | train
5 | validation_metrics_classification | MetricCollection | 0      | train
6 | test_metrics_classification       | MetricCollection | 0      | train
7 | loss_fn                           | HuberLoss        | 0      | train
---------------------------------------------

Sanity Checking: |                                                       | 0/? [00:00<?, ?it/s]

Training: |                                                              | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: |                                                               | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.7147834897041321
       test_auroc           0.5235203504562378
        test_loss          0.042379897087812424
        test_pcc            0.10439755022525787
         test_r2            -0.536857008934021
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.042379897087812424,
  'test_pcc': 0.10439755022525787,
  'test_r2': -0.536857008934021,
  'test_accuracy': 0.7147834897041321,
  'test_auroc': 0.5235203504562378}]

: 

In [None]:
model_name = 'GINModel'
obj_model = GINModel(input_dim=1152, hidden_dim=128, dropout=0.2)

model = LModelGM(obj_model, 
                 learning_rate=3e-4, 
                 weight_decay=1e-3)

logger = pl_loggers.TensorBoardLogger(name=f"{model_name}", save_dir=log_path)
logger_csv = pl_loggers.CSVLogger(name=f"{model_name}", save_dir=log_csv_path)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="model-epoch={epoch:02d}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

best_iou_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="best_model_epoch={epoch:02d}-auroc={validation_auroc:.4f}",
    monitor="validation_auroc",
    mode="max",
    save_top_k=1,
    save_on_train_epoch_end=True,
)

early_stop = EarlyStopping(monitor="validation_auroc", patience=10, mode="max")

trainer = L.Trainer(
    max_epochs=EPOCHS,
    devices=[0],
    default_root_dir=f'{checkpoints_path}/{model_name}',
    logger=[logger, logger_csv],
    accelerator="gpu",
    precision="16-mixed",
    callbacks=[best_iou_callback, early_stop, TQDMProgressBar(refresh_rate=1)],
    log_every_n_steps=1
)


trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
trainer.test(model, dataloaders=test_dataloader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                              | Type             | Params | Mode 
-------------------------------------------------------------------------------
0 | model                             | GINModel         | 230 K  | train
1 | train_metrics_regression          | MetricCollection | 0      | train
2 | validation_metrics_regression     | MetricCollection | 0      | train
3 | test_metrics_regression           | MetricCollection | 0      | train
4 | train_metrics_classification      | MetricCollection | 0      | train
5 | validation_metrics_classification | MetricCollection | 0      | train
6 | test_metrics_classification       | MetricCollection | 0      | train
7 | loss_fn                           | HuberLoss        | 0      | train
---------------------------------------------

Sanity Checking: |                                                       | 0/? [00:00<?, ?it/s]

Training: |                                                              | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: |                                                               | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.7028372287750244
       test_auroc           0.5785893797874451
        test_loss          0.042902134358882904
        test_pcc            0.18665660917758942
         test_r2            -0.5628758668899536
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.042902134358882904,
  'test_pcc': 0.18665660917758942,
  'test_r2': -0.5628758668899536,
  'test_accuracy': 0.7028372287750244,
  'test_auroc': 0.5785893797874451}]

: 

# Embeddings

### Модели для работы с эмбендингами

In [92]:
def collate_fn(batch):
    proteins, peptides, labels = zip(*batch)
    
    proteins = torch.stack(proteins)        # [B, 34, 1152]
    peptides = torch.stack(peptides)        # [B, 21, 1152] — уже паддинг
    labels = torch.tensor(labels).unsqueeze(0)
    
    return proteins, peptides, labels


class MHCSequenceDatasetEmb(Dataset):
    def __init__(self, df, labels):
        self.df = df
        self.labels = labels
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        label = self.labels[idx]
        alpha_path = row['alpha_path']
        beta_path = row['beta_path']
        peptide_path = row['peptide_path']
        alpha_positions = row['alpha_positions']
        beta_positions = row['beta_positions']
        
                
        alpha_embeddings = np.load(alpha_path)[:, ast.literal_eval(alpha_positions), :].squeeze(0)
        beta_embeddings = np.load(beta_path)[:, ast.literal_eval(beta_positions), :].squeeze(0)
        peptide_embeddings = torch.FloatTensor(np.load(peptide_path))[:,1:-1].squeeze(0)

        peptide_len = peptide_embeddings.shape[0]

        # Паддинг по центру до 21
        total_pad = 21 - peptide_len
        left_pad = total_pad // 2
        right_pad = total_pad - left_pad
        peptide_padded = F.pad(peptide_embeddings, (0, 0, left_pad, right_pad), 'constant', value=0)
        protein = torch.FloatTensor(np.concatenate([alpha_embeddings, beta_embeddings], axis=0))

        return protein, peptide_padded, int(label)
    
    

In [110]:
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.samplers import MPerClassSampler
from pytorch_metric_learning.miners import TripletMarginMiner


cutoff = 1.0 - np.log(500) / np.log(50000)
train_labels = np.array((train_data['score'] >= cutoff), dtype=np.int64)

train_dataset = MHCSequenceDatasetEmb(train_data, train_labels)
sampler = MPerClassSampler(train_labels, m=256, length_before_new_iter=len(train_labels))
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=8)

cutoff = 1.0 - np.log(500) / np.log(50000)
test_labels = np.array((test_data['score'] >= cutoff), dtype=np.int32)

test_dataset = MHCSequenceDatasetEmb(test_data, test_labels)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=8)

cutoff = 1.0 - np.log(500) / np.log(50000)
val_labels = np.array((val_data['score'] >= cutoff), dtype=np.int32)

val_dataset = MHCSequenceDatasetEmb(val_data, val_labels)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=8)



In [108]:
batch = next(iter(train_dataloader))

In [109]:

(batch[2] == 1).sum()

tensor(256)

In [117]:
class ProteinPeptideInteractionEmbeddingModel(nn.Module):
    def __init__(self, embedding_dim=1152, hidden_dim=512, embedding_out_dim=128, num_heads=8, dropout=0.3):
        super().__init__()

        # Проекция эмбеддингов
        self.protein_proj = nn.Linear(embedding_dim, hidden_dim)
        self.peptide_proj = nn.Linear(embedding_dim, hidden_dim)

        # Кросс-аттеншн: пептид (query) ↔ белок (key/value)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Голова для эмбеддингов (triplet/metric learning)
        self.embedding_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embedding_out_dim)
        )

        # Голова для предсказания score (IC50, аффинность и т.д.)
        self.score_head = nn.Sequential(
            nn.Linear(embedding_out_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, protein, peptide):
        """
        protein: [B, P_len, 1152]
        peptide: [B, pep_len, 1152]
        Returns:
            embedding: [B, embedding_out_dim]
            score:     [B] — регрессия
        """

        # 1. Проекция
        protein_proj = self.protein_proj(protein)   # [B, P_len, hidden_dim]
        peptide_proj = self.peptide_proj(peptide)   # [B, pep_len, hidden_dim]

        # 2. Cross-attention: пептид спрашивает белок
        attn_out, _ = self.cross_attn(query=peptide_proj,
                                      key=protein_proj,
                                      value=protein_proj)  # [B, pep_len, hidden_dim]

        # 3. Усреднение представлений
        attn_repr = attn_out.mean(dim=1)            # [B, hidden_dim]
        pep_repr = peptide_proj.mean(dim=1)         # [B, hidden_dim]

        # 4. Объединение
        combined = torch.cat([attn_repr, pep_repr], dim=1)  # [B, hidden_dim * 2]

        # 5. Эмбеддинг (L2-нормализованный)
        embedding = self.embedding_head(combined)           # [B, embedding_out_dim]
        embedding = F.normalize(embedding, p=2, dim=1)      # L2 нормировка

        # 6. Предсказание score
        score = self.score_head(embedding).squeeze(-1)      # [B]

        return embedding, score


In [130]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProteinPeptideInteractionTransformer(nn.Module):
    def __init__(self, embedding_dim=1152, hidden_dim=512, embedding_out_dim=128, num_heads=8, dropout=0.3):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # Проекция
        self.protein_proj = nn.Linear(embedding_dim, hidden_dim)
        self.peptide_proj = nn.Linear(embedding_dim, hidden_dim)

        # Первый слой cross-attention
        self.cross_attn1 = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_dim)

        # Второй слой cross-attention
        self.cross_attn2 = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Агрегация и объединение представлений
        self.encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, embedding_out_dim)
        )

        # Предсказание биндинга
        self.score_head = nn.Sequential(
            nn.Linear(embedding_out_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, protein, peptide):
        """
        protein: [B, P_len, 1152]
        peptide: [B, pep_len, 1152]
        Returns:
            embedding: [B, embedding_out_dim]
            score:     [B]
        """
        B = protein.shape[0]

        # Проекция эмбеддингов
        protein_proj = self.protein_proj(protein)     # [B, P_len, hidden_dim]
        peptide_proj = self.peptide_proj(peptide)     # [B, pep_len, hidden_dim]

        # Первый cross-attention
        attn_out1, _ = self.cross_attn1(query=peptide_proj, key=protein_proj, value=protein_proj)
        attn_out1 = self.norm1(attn_out1 + peptide_proj)  # skip-connection

        # Второй cross-attention
        attn_out2, _ = self.cross_attn2(query=attn_out1, key=protein_proj, value=protein_proj)
        attn_out2 = self.norm2(attn_out2 + attn_out1)     # skip-connection

        # Усреднение
        peptide_repr = attn_out2.mean(dim=1)              # [B, hidden_dim]
        protein_repr = protein_proj.mean(dim=1)           # [B, hidden_dim]

        # Объединение
        combined = torch.cat([protein_repr, peptide_repr], dim=1)  # [B, hidden_dim * 2]

        # Получение эмбеддинга
        embedding = self.encoder(combined)                # [B, embedding_out_dim]
        embedding = F.normalize(embedding, p=2, dim=1)

        # Предсказание бинарного класса
        score = self.score_head(embedding).squeeze(-1)    # [B]

        return embedding, score


### Lightning для Emb

In [134]:

class LModelEmb(L.LightningModule):
    def __init__(self, model, learning_rate, weight_decay):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        self.model = model

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay


        self.loss_bce = nn.BCEWithLogitsLoss()
        self.loss_fn = TripletMarginLoss(margin=2.)
        self.miner = TripletMarginMiner(margin=2., type_of_triplets="semi-hard")


    def forward(self, mhc_embeddings, peptide_embeddings):
        return self.model(mhc_embeddings, peptide_embeddings)
        

    def _evaluate(self, batch, stage=None):
        mhc_embeddings, peptide_embeddings, labels = batch
        embeddings, scores = self.forward(mhc_embeddings, peptide_embeddings)
        hard_triplets = self.miner(embeddings, labels)    
        loss_tml = self.loss_fn(embeddings, labels, hard_triplets)
        loss_bce = self.loss_bce(scores, labels.float())

        metrics_dict = {f"{stage}_loss_tml": loss_tml,
                        f"{stage}_loss_bce": loss_bce,}
        # with torch.no_grad():
        #     dist_matrix = torch.cdist(embeddings, embeddings, p=2)
        #     print(f"{stage}: mean dist = {dist_matrix.mean():.4f}, max = {dist_matrix.max():.4f}, min = {dist_matrix.min():.4f}")
        self.log_dict(metrics_dict, 
                      on_step=(stage == 'train'), 
                      on_epoch=True, 
                      prog_bar=True, 
                      sync_dist=True,
                      batch_size=BATCH_SIZE)

        return loss_tml + loss_bce

    def training_step(self, batch, batch_idx):
        return self._evaluate(batch, stage='train')

    def validation_step(self, batch, batch_idx):
        self._evaluate(batch, stage='validation')

    def test_step(self, batch, batch_idx):
        self._evaluate(batch, stage='test')

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), 
                                      lr=self.learning_rate, 
                                      weight_decay=self.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 20,
            }
        }


In [None]:
model_name = 'ProteinPeptideInteractionTransformer'
obj_model = ProteinPeptideInteractionTransformer(embedding_dim=1152, hidden_dim=512, 
                                                    embedding_out_dim=128, num_heads=8, dropout=0.3)

model = LModelEmb(obj_model, 
                 learning_rate=1e-, 
                 weight_decay=1e-3)

logger = pl_loggers.TensorBoardLogger(name=f"{model_name}", save_dir=log_path)
logger_csv = pl_loggers.CSVLogger(name=f"{model_name}", save_dir=log_csv_path)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="model-epoch={epoch:02d}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
)

best_iou_callback = ModelCheckpoint(
    dirpath=os.path.join(checkpoints_path, f"{model_name}", f"version_{logger_csv.version}", "checkpoints"),
    filename="best_model_epoch={epoch:02d}-loss={validation_loss_tml:.4f}",
    monitor="validation_loss_tml",
    mode="min",
    save_top_k=1,
    save_on_train_epoch_end=True,
)

early_stop = EarlyStopping(monitor="validation_loss_tml", patience=10, mode="min")

trainer = L.Trainer(
    max_epochs=EPOCHS,
    devices=[1],
    default_root_dir=f'{checkpoints_path}/{model_name}',
    logger=[logger, logger_csv],
    accelerator="gpu",
    precision="32",
    callbacks=[best_iou_callback, early_stop, TQDMProgressBar(refresh_rate=1)],
    log_every_n_steps=1
)


trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
trainer.test(model, dataloaders=test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name     | Type                                 | Params | Mode 
--------------------------------------------------------------------------
0 | model    | ProteinPeptideInteractionTransformer | 3.9 M  | train
1 | loss_bce | BCEWithLogitsLoss                    | 0      | train
2 | loss_fn  | TripletMarginLoss                    | 0      | train
3 | miner    | TripletMarginMiner                   | 0      | train
--------------------------------------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
15.535    Total estimated model params size (MB)
25        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                       | 0/? [00:00<?, ?it/s]

Training: |                                                              | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                            | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined