In [6]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import joblib
from src.dataset import ProteinDataset
from src.utils import train_model, test_model
import torch
from src.model import ChemicalShiftsPredictor, ChemicalShiftsPredictorAttention
from src.utils import packed_padded_collate

from tqdm.notebook import tqdm

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np

In [2]:
# Load data: paths to the chemical-shift table and to the per-model embedding
# files, then restrict the table to the proteins listed in the test split.
csv_file = 'data/strict.csv'
prott5_file = 'data/embeddings/unfiltered_all_prott5.h5'
prott5_res_file = 'data/embeddings/unfiltered_all_prott5_res.h5'
prostt5_file = 'data/embeddings/prostt5.h5'
esm_file = 'data/embeddings/unfiltered_all_esm2_3b.h5'
esm_res_file = 'data/embeddings/unfiltered_all_esm2_3b_res.h5'

# The ID file holds one ID per line; surrounding whitespace/newlines are stripped.
with open("pdb_matched/final_test_ids.txt", "r") as ids_handle:
    test_ids = [raw_line.strip() for raw_line in ids_handle]

chemical_shifts_df = pd.read_csv(csv_file)
in_test_split = chemical_shifts_df['ID'].isin(test_ids)
chemical_shifts_df = chemical_shifts_df.loc[in_test_split]

In [3]:
# Inference on the test split: standardize the targets with the fitted scaler,
# load the trained weights and collect the model's (standardized) predictions.
target_column = 'H'

scaler = joblib.load('scaler_h.joblib')

# Standardize a *copy* of the targets for the dataset. `chemical_shifts_df` keeps
# the raw values, so re-running this cell does not scale twice and any later
# export of the target column stays in the same units as the inverse-transformed
# predictions.
scaled_shifts_df = chemical_shifts_df.copy()
scaled_shifts_df[target_column] = scaler.transform(
    chemical_shifts_df[target_column].to_numpy().reshape(-1, 1)
).ravel()

test_dataset = ProteinDataset([target_column], scaled_shifts_df, prott5_file, prott5_res_file,
                              prostt5_file, esm_res_file, esm_file)

# Training hyperparameters, kept for reference; only `batch_size` is used below.
learning_rate = 0.001
weight_decay = 1e-5
patience = 10
batch_size = 128
num_epochs = 5

# Architecture flags. These presumably must match the configuration that
# 'best_model_H.pth' was trained with, or load_state_dict will reject the
# checkpoint -- confirm against the training run.
use_prostt5 = True
use_protein_mean = True
use_attention = True

if not use_attention:
    # The loop below only implements the attention forward pass; without this
    # guard it would silently produce no predictions.
    raise NotImplementedError("Inference is only implemented for use_attention=True")

# Resolve the device before loading, so the checkpoint can be mapped onto the
# CPU when no GPU is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = ChemicalShiftsPredictor(use_prostt5=use_prostt5, use_protein_mean=use_protein_mean, use_attention=use_attention)
model.load_state_dict(torch.load('best_model_H.pth', map_location=device))
model = model.to(device)
# Evaluation mode: dropout off, batch-norm uses its running statistics.
model.eval()

# shuffle=False keeps dataset order, so predictions line up with chemical_shifts_df rows.
# NOTE(review): packed_padded_collate presumably pads the per-protein embedding
# stacks used by the attention model -- confirm in src.utils.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=6,
    collate_fn=packed_padded_collate if use_attention else None,
)

all_predictions = []

# No autograd graph is needed for inference.
with torch.no_grad():
    for inputs in tqdm(test_loader):
        (amino_acid_prott5_emb, amino_acid_prostt5_emb, amino_acid_esm2_emb,
         protein_prott5_emb, protein_prostt5_emb_stack, targets) = [x.to(device) for x in inputs]

        # Concatenation order presumably must match the order used in training.
        embeddings = [amino_acid_prott5_emb]
        if use_prostt5:
            embeddings.append(amino_acid_prostt5_emb)
        if use_protein_mean:
            embeddings.append(protein_prott5_emb)
        embeddings.append(amino_acid_esm2_emb)
        concatenated_embeddings = torch.cat(embeddings, dim=1)

        # NOTE(review): if protein_prostt5_emb_stack is (batch, seq, dim), this
        # yields ONE length for the whole batch (the largest non-zero position
        # over all proteins), not one per protein. size(2) would then be the
        # embedding dim rather than the sequence length. The computation is left
        # unchanged so inference matches training -- verify against the
        # training loop.
        sequence_lengths = protein_prostt5_emb_stack.abs().sum(dim=-1).nonzero()[:, 1].max(dim=-1, keepdim=True)[0] + 1
        max_seq_len = protein_prostt5_emb_stack.size(2)
        mask = torch.arange(max_seq_len, device=device)[None, :] < sequence_lengths

        outputs = model(concatenated_embeddings, protein_prostt5_emb_stack, mask)
        all_predictions.extend(outputs.cpu().numpy())



  0%|          | 0/101 [00:00<?, ?it/s]

In [9]:
# Map the standardized model outputs back to the target's original scale,
# using the same scaler that standardized the targets.
stacked_predictions = np.asarray(all_predictions).reshape(-1, 1)
chemical_shifts_df["H_our"] = scaler.inverse_transform(stacked_predictions)

In [None]:
# Export ground truth alongside the model's predictions. Selecting several
# columns needs a list inside the brackets: df["H", "H_our"] would look up a
# single tuple-named column and raise KeyError. The header defaults to the
# selected column names ("H", "H_our").
chemical_shifts_df[["H", "H_our"]].to_csv("H_our.csv", index=False)