In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import joblib
from src.dataset import ProteinDataset
from src.utils import train_model, test_model
import torch
from src.model import ChemicalShiftsPredictor, ChemicalShiftsPredictorAttention
from src.utils import packed_padded_collate

from tqdm.notebook import tqdm

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [3]:
# Load and prepare data
csv_file = 'data/strict.csv'
prott5_file = 'data/embeddings/unfiltered_all_prott5.h5'
prott5_res_file = 'data/embeddings/unfiltered_all_prott5_res.h5'
prostt5_file = 'data/embeddings/prostt5.h5'
esm_file = 'data/embeddings/unfiltered_all_esm2_3b.h5'
esm_res_file = 'data/embeddings/unfiltered_all_esm2_3b_res.h5'
chemical_shifts_df = pd.read_csv(csv_file)

test_ids = []
with open("pdb_matched/final_test_ids.txt", "r") as f:
    for line in f:
        test_ids.append(line.strip())
        
        
chemical_shifts_df = chemical_shifts_df[chemical_shifts_df['ID'].isin(test_ids)]

In [25]:
target_column = 'H'

scaler = joblib.load(f'scaler_h.joblib')
#h_filtered = chemical_shifts_df.dropna(subset=[target_column])
chemical_shifts_df[target_column] = scaler.transform(chemical_shifts_df[target_column].values.reshape(-1, 1))

test_dataset = ProteinDataset([target_column], chemical_shifts_df, prott5_file, prott5_res_file, prostt5_file, esm_res_file, esm_file)


learning_rate = 0.001
weight_decay = 1e-5
patience = 10
batch_size = 128
num_epochs = 5

use_prostt5 = True
use_protein_mean = True
use_attention = True

model = ChemicalShiftsPredictor(use_prostt5=use_prostt5, use_protein_mean=use_protein_mean, use_attention=use_attention)
model.load_state_dict(torch.load('Full_1e-4_H.pth'))

model = model.cuda()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if use_attention:
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=6, collate_fn=packed_padded_collate)
else:
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=6)

all_predictions = []

for inputs in tqdm(test_loader):
        amino_acid_prott5_emb, amino_acid_prostt5_emb, amino_acid_esm2_emb, protein_prott5_emb, protein_prostt5_emb_stack, targets = [x.to(device) for x in inputs]
        embeddings = [amino_acid_prott5_emb]
        if use_prostt5:
            embeddings.append(amino_acid_prostt5_emb)
        if use_protein_mean:
            embeddings.append(protein_prott5_emb)
            # if use_prostt5:
            #     embeddings.append(protein_prostt5_emb)
        embeddings.append(amino_acid_esm2_emb)
        concatenated_embeddings = torch.cat(embeddings, dim=1)
        if use_attention:
            sequence_lengths = protein_prostt5_emb_stack.abs().sum(dim=-1).nonzero()[:, 1].max(dim=-1, keepdim=True)[0] + 1

            # Create mask based on actual sequence lengths
            max_seq_len = protein_prostt5_emb_stack.size(2)
            mask = torch.arange(max_seq_len, device=device)[None, :] < sequence_lengths.to(device)
            outputs = model(concatenated_embeddings, protein_prostt5_emb_stack.to(device), mask.to(device))
            all_predictions.extend(outputs.cpu().detach().numpy())



  0%|          | 0/101 [00:00<?, ?it/s]

In [26]:
scaler.inverse_transform(np.array(all_predictions).reshape(-1, 1))

array([[8.160605],
       [8.013003],
       [8.667084],
       ...,
       [9.202308],
       [9.353414],
       [8.575436]], dtype=float32)

In [27]:
chemical_shifts_df["H_our"]

447       5.350244
448       5.297827
449       5.783499
450       5.552297
451       5.647673
            ...   
215543    7.764286
215544    7.869600
215545    7.871858
215546    7.893819
215547    7.905555
Name: H_our, Length: 12892, dtype: float32

In [22]:
chemical_shifts_df["H_our"] = scaler.inverse_transform(np.array(all_predictions).reshape(-1, 1))

In [5]:
target_column = 'N'

scaler = joblib.load(f'scaler_n.joblib')
#h_filtered = chemical_shifts_df.dropna(subset=[target_column])
chemical_shifts_df[target_column] = scaler.transform(chemical_shifts_df[target_column].values.reshape(-1, 1))

test_dataset = ProteinDataset([target_column], chemical_shifts_df, prott5_file, prott5_res_file, prostt5_file, esm_res_file, esm_file)


learning_rate = 0.001
weight_decay = 1e-5
patience = 10
batch_size = 128
num_epochs = 5

use_prostt5 = True
use_protein_mean = True
use_attention = True

model = ChemicalShiftsPredictor(use_prostt5=use_prostt5, use_protein_mean=use_protein_mean, use_attention=use_attention)
model.load_state_dict(torch.load('Full_1e-4_N.pth'))

model = model.cuda()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if use_attention:
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=6, collate_fn=packed_padded_collate)
else:
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=6)

all_predictions = []

for inputs in tqdm(test_loader):
        amino_acid_prott5_emb, amino_acid_prostt5_emb, amino_acid_esm2_emb, protein_prott5_emb, protein_prostt5_emb_stack, targets = [x.to(device) for x in inputs]
        embeddings = [amino_acid_prott5_emb]
        if use_prostt5:
            embeddings.append(amino_acid_prostt5_emb)
        if use_protein_mean:
            embeddings.append(protein_prott5_emb)
            # if use_prostt5:
            #     embeddings.append(protein_prostt5_emb)
        embeddings.append(amino_acid_esm2_emb)
        concatenated_embeddings = torch.cat(embeddings, dim=1)
        if use_attention:
            sequence_lengths = protein_prostt5_emb_stack.abs().sum(dim=-1).nonzero()[:, 1].max(dim=-1, keepdim=True)[0] + 1

            # Create mask based on actual sequence lengths
            max_seq_len = protein_prostt5_emb_stack.size(2)
            mask = torch.arange(max_seq_len, device=device)[None, :] < sequence_lengths.to(device)
            outputs = model(concatenated_embeddings, protein_prostt5_emb_stack.to(device), mask.to(device))
            all_predictions.extend(outputs.cpu().detach().numpy())
            
            
chemical_shifts_df["N_our"] = scaler.inverse_transform(np.array(all_predictions).reshape(-1, 1))



  0%|          | 0/101 [00:00<?, ?it/s]

In [23]:
chemical_shifts_df

Unnamed: 0.1,Unnamed: 0,ID,entryID,stID,entity_assemID,entityID,seq_index,seq,k,zscores,pscores,C,CA,CB,HA,H,N,HB,N_our,H_our
447,6,30161_1_1_1,30161,1,1,1,1,M,7,,,175.986,55.600,33.138,4.401,,121.816,1.9350,119.632263,5.350244
448,6,30161_1_1_1,30161,1,1,1,2,I,14,11.1609,0.0947,174.316,61.588,38.193,3.830,-0.243453,123.431,1.7160,127.461349,5.297827
449,6,30161_1_1_1,30161,1,1,1,3,R,21,14.1291,0.0736,176.156,57.343,32.755,4.716,2.210679,127.627,1.8845,127.118889,5.783499
450,6,30161_1_1_1,30161,1,1,1,4,T,21,15.1705,0.0332,173.925,58.964,71.096,5.291,0.513443,110.835,4.3560,117.856674,5.552297
451,6,30161_1_1_1,30161,1,1,1,5,I,21,14.7280,0.0497,173.317,59.542,43.057,4.875,0.133551,120.009,1.2950,124.396027,5.647673
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215543,1893,36334_1_1_1,36334,1,1,1,106,H,1,3.8761,0.0000,,,,,,,,115.665993,7.764286
215544,1893,36334_1_1_1,36334,1,1,1,107,H,2,2.3427,0.4918,,,,,,,,117.746872,7.869600
215545,1893,36334_1_1_1,36334,1,1,1,108,H,2,2.3427,0.4918,,,30.053,,-0.091785,,,118.197594,7.871858
215546,1893,36334_1_1_1,36334,1,1,1,109,H,2,,,,,,,,,,118.341019,7.893819


In [24]:
# select N_our, H_our, ID, entryID, seq_index, seq and save to csv
df = chemical_shifts_df[['N_our', 'H_our', 'ID', 'entryID', 'seq_index', 'seq']]
df.to_csv('data/our_predictions.csv', index=False)