In [1]:
# One-time environment setup (disabled). Uncomment to install the packages this notebook uses;
# prefer the %pip magic so the install targets the running kernel's environment.
#%pip install torch pandas numpy h5py tqdm scikit-learn tensorboard

In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import joblib
from src.dataset import ProteinDataset
from src.utils import train_model, test_model
import torch
from src.model import ChemicalShiftsPredictor, ChemicalShiftsPredictorAttention

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
# Input locations, all relative to the notebook's working directory.

# Chemical-shift table (one CSV).
csv_file = 'data/strict.csv'

# ProtT5 / ProstT5 embedding archives (HDF5).
# The "_res" files presumably hold per-residue embeddings — TODO confirm against src.dataset.
prott5_file = 'data/embeddings/unfiltered_all_prott5.h5'
prott5_res_file = 'data/embeddings/unfiltered_all_prott5_res.h5'
prostt5_file = 'data/embeddings/prostt5.h5'

# ESM-2 embedding archives (HDF5).
esm_file = 'data/embeddings/unfiltered_all_esm2_3b.h5'
esm_res_file = 'data/embeddings/unfiltered_all_esm2_3b_res.h5'

# Read the shift table into memory.
# For a quick sanity check of the raw values, run: chemical_shifts_df.describe()
chemical_shifts_df = pd.read_csv(csv_file)

In [3]:
# Collect the held-out IDs, one per line of the text file, with surrounding whitespace removed.
with open("pdb_matched/final_test_ids.txt", "r") as f:
    test_ids = [line.strip() for line in f]

In [4]:
# Keep only rows whose ID is NOT one of the held-out test IDs.
is_test_row = chemical_shifts_df['ID'].isin(test_ids)
chemical_shifts_df = chemical_shifts_df.loc[~is_test_row]

In [5]:
# Guard flag for the data-preparation cell below: the targets are standardised only while
# this is False, and that cell flips it to True afterwards so a re-run does not scale twice.
# Re-running THIS cell resets the guard, so only do that after reloading the raw table.
scaler_applied = False

In [6]:
# Shift types to predict. Alternative full set: ['C', 'CA', 'CB', 'HA', 'H', 'N', 'HB']
target_columns = ['H']

# Discard rows without a measured value for the chosen target(s). The result is copied so
# that the column assignment below writes to an independent frame rather than to a slice
# of the table filtered in the earlier cell.
chemical_shifts_df = chemical_shifts_df.dropna(subset=target_columns).copy()

# Fit, persist and apply the target scaler exactly once per loaded table.
# Everything sits behind the guard on purpose: re-running this cell must not re-fit the
# scaler on already standardised values (which would overwrite 'scaler_h.joblib' with a
# near-identity scaler) and must not standardise the targets a second time.
# If target_columns is changed, reload the raw table and reset scaler_applied first.
if not scaler_applied:
    scaler = StandardScaler()

    # training with whole dataset here: the scaler statistics come from every remaining row
    train_targets = chemical_shifts_df[target_columns]
    scaler.fit(train_targets)

    joblib.dump(scaler, 'scaler_h.joblib')

    chemical_shifts_df[target_columns] = scaler.transform(train_targets)
    scaler_applied = True

# Split after scaling: row selection depends only on the row count and random_state, and the
# scaling is element-wise, so both subsets carry already standardised targets.
train_df, val_df = train_test_split(chemical_shifts_df, test_size=0.2, random_state=42)

# Create datasets.
# NOTE(review): the training dataset is deliberately built from the WHOLE table, so the rows
# in val_df are also seen during training and train_df is currently unused. Validation
# metrics are therefore optimistic; switch the first dataset to train_df for an honest
# hold-out estimate.
train_dataset = ProteinDataset(target_columns, chemical_shifts_df, prott5_file, prott5_res_file, prostt5_file, esm_res_file, esm_file)
val_dataset = ProteinDataset(target_columns, val_df, prott5_file, prott5_res_file, prostt5_file, esm_res_file, esm_file)

In [7]:
# Optimiser settings forwarded to train_model.
learning_rate = 1e-3
weight_decay = 1e-5

# Training-loop settings forwarded to train_model.
num_epochs = 5
batch_size = 128
patience = 10  # presumably the early-stopping patience in epochs — confirm in src.utils

# Model / input switches forwarded to train_model.
use_attention = True
use_prostt5 = True
use_protein_mean = True

In [8]:
# Optional warm start (disabled): build the predictor with the switches configured above,
# load previously saved weights from 'Full_1e-4.pth', and move the model to the GPU.
#model = ChemicalShiftsPredictor(use_prostt5=use_prostt5, use_protein_mean=use_protein_mean, use_attention=use_attention)
#model.load_state_dict(torch.load('Full_1e-4.pth'))

#model = model.cuda()

In [10]:
# Train with the settings from the configuration cell. The fitted target scaler is passed
# along as well (presumably so train_model can report errors in original units — confirm
# in src.utils).
trained_model = train_model(
    train_dataset,
    val_dataset,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
    weight_decay=weight_decay,
    patience=patience,
    batch_size=batch_size,
    use_prostt5=use_prostt5,
    use_protein_mean=use_protein_mean,
    scaler=scaler,
    use_attention=use_attention,
)

ChemicalShiftsPredictor(
  (light_attention): LightAttention(
    (feature_convolution): Conv1d(1024, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
    (attention_convolution): Conv1d(1024, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
    (softmax): Softmax(dim=-1)
    (dropout): Dropout(p=0.25, inplace=False)
  )
  (fc_layers): Sequential(
    (0): Linear(in_features=7680, out_features=7680, bias=True)
    (1): ReLU()
    (2): Linear(in_features=7680, out_features=1, bias=True)
  )
)


Epoch 1/5: 100%|██████████| 1413/1413 [19:05<00:00,  1.23batch/s]


Epoch 1, Train Loss: 0.7511, Validation Loss: 0.4135
Epoch 1, Train RMSE: 0.5075, Validation RMSE: 0.4428


Epoch 2/5: 100%|██████████| 1413/1413 [20:52<00:00,  1.13batch/s]


Epoch 2, Train Loss: 0.5650, Validation Loss: 0.3802
Epoch 2, Train RMSE: 0.4520, Validation RMSE: 0.4248


Epoch 3/5: 100%|██████████| 1413/1413 [21:06<00:00,  1.12batch/s]


Epoch 3, Train Loss: 0.5335, Validation Loss: 0.3393
Epoch 3, Train RMSE: 0.4348, Validation RMSE: 0.4008


Epoch 4/5: 100%|██████████| 1413/1413 [20:42<00:00,  1.14batch/s]


Epoch 4, Train Loss: 0.4841, Validation Loss: 0.4322
Epoch 4, Train RMSE: 0.4075, Validation RMSE: 0.4528


Epoch 5/5: 100%|██████████| 1413/1413 [20:31<00:00,  1.15batch/s]


Epoch 5, Train Loss: 1.1505, Validation Loss: 0.4899
Epoch 5, Train RMSE: 0.5582, Validation RMSE: 0.4819


In [11]:
# Persist only the learned parameters (the state dict), not the module object itself.
# NOTE(review): the file name says "1e-4" while learning_rate in the configuration cell is
# 0.001 — confirm the name is still meaningful.
model_weights = trained_model.state_dict()
torch.save(model_weights, 'Full_1e-4_H.pth')

In [27]:
# Three-letter -> one-letter codes for the 20 standard amino acids.
three_to_one = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLU': 'E', 'GLN': 'Q', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
}


def process_ucb_output(dataframe, id):
    """Reshape one raw UCBShift prediction table into the notebook's common layout.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Table with the columns "RESNUM", "RESNAME", "H_UCBShift" and "N_UCBShift".
    id : str
        Identifier written into every row of the result.

    Returns
    -------
    pandas.DataFrame
        Columns "ID", "seq_index" (RESNUM shifted so the smallest value becomes 1),
        "seq" (one-letter residue code), "H" and "N", one row per input row.

    Raises
    ------
    KeyError
        If "RESNAME" contains a value that is not one of the 20 codes in
        ``three_to_one``; the message lists the offending names and the ID.
    """
    residue_names = dataframe["RESNAME"]

    # Fail early with context instead of a bare KeyError on the first unknown residue.
    unknown_names = sorted({str(name) for name in residue_names if name not in three_to_one})
    if unknown_names:
        raise KeyError(
            f"Unknown residue name(s) {unknown_names} in UCBShift output for ID {id!r}"
        )

    residue_numbers = dataframe["RESNUM"]
    return pd.DataFrame({
        "ID": id,
        # Series.min() is vectorised and ignores missing values, unlike the builtin min().
        "seq_index": residue_numbers - residue_numbers.min() + 1,
        "seq": [three_to_one[name] for name in residue_names],
        "H": dataframe["H_UCBShift"],
        "N": dataframe["N_UCBShift"],
    })

In [32]:
# Convert every raw UCBShift table to the common layout and stack them into one DataFrame.
# A comprehension keeps the loop variables out of the notebook namespace, so the builtin
# `id` is not shadowed for later cells.
# NOTE(review): `ucb_predictions` is not created in any cell visible in this notebook; it is
# used here as a mapping of ID -> raw UCBShift DataFrame — confirm where it is loaded.
all_ucb_predictions = pd.concat(
    [
        process_ucb_output(ucb_table, protein_id)
        for protein_id, ucb_table in ucb_predictions.items()
    ]
)

In [34]:
# Write the combined UCBShift predictions to disk; the DataFrame index is not exported.
all_ucb_predictions.to_csv(path_or_buf='all_ucb_predictions.csv', index=False)

In [None]:
# Longer training run followed by evaluation on the held-out set.
# Release cached GPU memory before starting.
torch.cuda.empty_cache()

# NOTE(review): this call passes use_esm2 and omits scaler / use_attention, unlike the
# earlier train_model call in this notebook — confirm it matches train_model's signature.
trained_model = train_model(
    train_dataset,
    val_dataset,
    learning_rate=learning_rate,
    num_epochs=50,
    weight_decay=weight_decay,
    patience=patience,
    batch_size=2048,
    use_prostt5=True,
    use_protein_mean=True,
    use_esm2=False,
)

# NOTE(review): `test_dataset` is not defined in any cell visible in this notebook — build
# it before running this line.
test_model(
    trained_model,
    test_dataset,
    batch_size=batch_size,
    use_prostt5=True,
    use_protein_mean=True,
    use_esm2=False,
)

In [13]:
# Release GPU memory held in PyTorch's caching allocator now that training is finished.
torch.cuda.empty_cache()