In [22]:
import pandas as pd 
import torch   
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm 
from covNeut_esm2_struct import CovNeut_ftESM 
from covbind_esm2_struct import CovBind_ftESM 
#from covBoth_esm2_struct import CovBoth_ftESM 
from Pretrained.ESM2_MLM_Struct.tokenizer import ESM2_Tokenizer  
from sklearn.manifold import TSNE  
#from silhouette import silhouette_score  
from sklearn.metrics import classification_report, accuracy_score, f1_score  

In [5]:
class SequenceDataset(Dataset):
    """Paired VH/VL antibody sequences with a precomputed target embedding and a label.

    Each item is ``(vh_ids, vl_ids, target_emb, label)``:
      * ``vh_ids`` / ``vl_ids`` -- ``tokenizer.encode(seq, max_length=...)`` of the
        heavy / light chain sequence;
      * ``target_emb`` -- row 0 of the tensor stored under the example's target key
        in the embedding file;
      * ``label`` -- the example's label as a float32 scalar tensor.

    Args:
        vh_seqs: heavy-chain sequences, one per example.
        vl_seqs: light-chain sequences, one per example.
        targets: keys into the target-embedding mapping, one per example.
        labels: labels (0/1 in this notebook's CSV), one per example.
        tg_emb_path: file holding the target-key -> embedding mapping, read with
            ``torch.load`` onto the CPU.
        vh_max_length: ``max_length`` passed to the tokenizer for VH sequences.
        vl_max_length: ``max_length`` passed to the tokenizer for VL sequences.
        tokenizer: object exposing ``encode(seq, max_length=...)``; when None a new
            ``ESM2_Tokenizer()`` is created.

    Raises:
        ValueError: if ``vh_seqs``, ``vl_seqs``, ``targets`` and ``labels`` differ
            in length (otherwise the mismatch would only surface at index time).
    """

    def __init__(self, vh_seqs, vl_seqs, targets, labels,
                 tg_emb_path='data/target_embeddings.pt',
                 vh_max_length=228, vl_max_length=217, tokenizer=None):
        if not (len(vh_seqs) == len(vl_seqs) == len(targets) == len(labels)):
            raise ValueError(
                "vh_seqs, vl_seqs, targets and labels must have equal length, got "
                f"{len(vh_seqs)}, {len(vl_seqs)}, {len(targets)}, {len(labels)}")
        self.vh_seqs = vh_seqs
        self.vl_seqs = vl_seqs
        self.targets = targets
        self.labels = labels
        # map_location='cpu' lets embeddings saved from a GPU session load on a
        # CPU-only machine; batches are moved with .to(device) by the consumers.
        self.tg_embs = torch.load(tg_emb_path, map_location='cpu')
        self.vh_max_length = vh_max_length
        self.vl_max_length = vl_max_length
        self.tok = ESM2_Tokenizer() if tokenizer is None else tokenizer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        vh_ids = self.tok.encode(self.vh_seqs[idx], max_length=self.vh_max_length)
        vl_ids = self.tok.encode(self.vl_seqs[idx], max_length=self.vl_max_length)
        # [0, :] selects the first row, so each stored embedding must be at least 2-D.
        target_emb = self.tg_embs[self.targets[idx]][0, :]
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return vh_ids, vl_ids, target_emb, label
    
def get_combined_embs(model, data_loader, device, verbose=False):
    """Run ``model`` in eval mode over ``data_loader`` and collect combined embeddings.

    For every batch this calls
    ``model(vh_ids, vl_ids, ag_embs, return_combined_emb=True)`` under
    ``torch.no_grad()``.

    Args:
        model: module accepting ``return_combined_emb=True``; moved to ``device``
            and switched to eval mode.
        data_loader: iterable of ``(vh_ids, vl_ids, ag_embs, labels)`` batches.
        device: device on which the model is run.
        verbose: if True, print each batch's embedding shape (debugging aid).

    Returns:
        ``(embeddings, labels)``: per-batch model outputs concatenated along dim 0
        and moved to the CPU, and the batch labels concatenated along dim 0.
    """
    model.to(device)
    model.eval()
    emb_batches = []
    label_batches = []
    with torch.no_grad():
        for vh_ids, vl_ids, ag_embs, labels in tqdm(data_loader, desc="Testing"):
            combined = model(vh_ids.to(device), vl_ids.to(device), ag_embs.to(device),
                             return_combined_emb=True)
            if verbose:
                print(combined.shape, flush=True)
            # Move each batch off the accelerator right away so a large dataset does
            # not accumulate every embedding in GPU memory; this also puts the result
            # on the same device as the labels (ready for numpy/sklearn consumers).
            emb_batches.append(combined.cpu())
            label_batches.append(labels)
    return torch.cat(emb_batches, dim=0), torch.cat(label_batches, dim=0)

def test_run(model, data_loader, device):
    model.to(device)
    model.eval()
    all_predictions = [] 
    actual_class = []
    with torch.no_grad():
        for vh_ids,  vl_ids, ag_embs, labels in tqdm(data_loader, desc="Test"):
            vh_ids = vh_ids.to(device) 
            vl_ids = vl_ids.to(device)
            ag_embs = ag_embs.to(device) 
            labels = labels.to(device)
            logits = model(vh_ids, vl_ids, ag_embs)
            predictions = torch.sigmoid(logits).round() 
            all_predictions.extend(predictions.view(-1).tolist())
            actual_class.extend(labels.detach().tolist())
    report = classification_report(actual_class, all_predictions, target_names=['Neg', 'Pos'], digits=4)
    print(report)
    return all_predictions 

### Binding

In [6]:
# Binding evaluation data: the 'Binding' column supplies the labels.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
df_test = pd.read_csv('data/test_data.csv')

# create dataloader (no shuffling, so predictions line up with df_test rows)
dataset_test = SequenceDataset(
    vh_seqs=df_test['Antibody VH'].tolist(),
    vl_seqs=df_test['Antibody VL'].tolist(),
    targets=df_test['Target'].tolist(),
    labels=df_test['Binding'].tolist(),
)
data_loader_test = DataLoader(dataset_test, batch_size=12, shuffle=False)


In [7]:
# load model: build CovBind_ftESM with load_weights=False, then restore the
# fine-tuned checkpoint. The load_state_dict(...) result is the displayed output.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model_weight_path = 'saved_models/covbind_ESM2_Struct_ft.pth'
model = CovBind_ftESM(load_weights=False)
bind_state_dict = torch.load(model_weight_path, map_location=device)
model.load_state_dict(bind_state_dict)

<All keys matched successfully>

In [8]:
# Evaluate the binding model; returns one 0.0/1.0 prediction per df_test row.
pred_binds = test_run(model=model, data_loader=data_loader_test, device=device)

Test: 100%|██████████| 84/84 [00:16<00:00,  5.12it/s]

              precision    recall  f1-score   support

         Neg     0.8564    0.8357    0.8460       414
         Pos     0.8859    0.9010    0.8934       586

    accuracy                         0.8740      1000
   macro avg     0.8712    0.8684    0.8697      1000
weighted avg     0.8737    0.8740    0.8738      1000






In [10]:
# Peek at the first rows of the test split.
df_test.head(n=5)

Unnamed: 0,Antibody,Antibody VH,Antibody VL,Target,Binding,Neutralizing,Target Sequence,Label
0,BD57-0226,QEQLVESGGGVVQPGRSLRLSCAASGFTFSHYGMHWVRQAPGKGLE...,QSVLTQPPSASGTPGQRVTISCSGSSSNIGSNFVHWYQQLPGTAPK...,sars-cov2-omicron-ba5,1,0,MFVFLVLLPLV----SSQCVNLITRTQ---SYTNSFTRGVYYPDKV...,1
1,BD56-697,QVQLQESGPGLVKPSQTLSLTCTVSGDSISSGGYYWSWIRQRPGKG...,SYELTQPPSVSVSPGQTARITCSGDALPKQHAYWYQQKSGQAPVLV...,sars-cov2-wt,0,0,MFVFLVLLPLV----SSQCVNLTTRTQLPPAYTNSFTRGVYYPDKV...,2
2,V016,QVQLVESGGGVVQPGRSLRLSCAASGFTFSNYGMHWVRQAPGKGLE...,DIQMTQSPSTLSASVGDRVTITCRASQSISSWLAWYQQKPGKAPKL...,sars-cov2-wt,1,0,MFVFLVLLPLV----SSQCVNLTTRTQLPPAYTNSFTRGVYYPDKV...,1
3,BD56-210,VQLVQSGAEVKKPGASVKISCKASGYTFSNSYLHWVRQAPGQGLEW...,EIVLTQSPATLSLSPGERATLSCRASQSVSSYVAWYQQKPGQAPRL...,sars-cov2-omicron-xbb1,0,0,LYLLGMLVASV----LAQCVNLITRTQ---SYTNSFTRGVYYPDKV...,2
4,BD55-4348,QVQLVESGGGVVQPGRSLRLSCAASGFTFRSYSMQWVRQAPGEGLE...,SYVLTQPPSVSVAPGKTARITCGGDNIGSYSVHWYQQKPGQAPVLV...,sars-cov2-omicron-ba2.12.1,0,0,MFVFLVLLPLV----SSQCVNLITRTQ---SYTNSFTRGVYYPDKV...,2


In [11]:
# Attach the binding predictions (row order matches the unshuffled loader).
df_test = df_test.assign(pred_binding=pred_binds)

In [16]:
# Ground truth next to prediction for the binding task.
binding_cols = ['Antibody', 'Target', 'Binding', 'pred_binding']
df_test.loc[:, binding_cols]

Unnamed: 0,Antibody,Target,Binding,pred_binding
0,BD57-0226,sars-cov2-omicron-ba5,1,0.0
1,BD56-697,sars-cov2-wt,0,0.0
2,V016,sars-cov2-wt,1,1.0
3,BD56-210,sars-cov2-omicron-xbb1,0,0.0
4,BD55-4348,sars-cov2-omicron-ba2.12.1,0,0.0
...,...,...,...,...
995,BD-695,sars-cov2-omicron-ba2,0,0.0
996,BD55-6705,sars-cov2-omicron-ba5,0,1.0
997,BD56-447,sars-cov2-omicron-ba2,1,1.0
998,BD56-700,sars-cov2-omicron-ba4,1,1.0


In [18]:
# Distinct targets, in order of first appearance.
unq_targets = pd.unique(df_test['Target'])

In [24]:
# Per-target F1 of the binding predictions, keyed by target name.
binding_f1 = {}
for target_name in unq_targets:
    target_rows = df_test.loc[df_test['Target'] == target_name]
    binding_f1[target_name] = f1_score(target_rows['Binding'], target_rows['pred_binding'])

In [25]:
# Display the per-target binding F1 scores.
binding_f1

{'sars-cov2-omicron-ba5': 0.831858407079646,
 'sars-cov2-wt': 0.9237668161434978,
 'sars-cov2-omicron-xbb1': 0.5714285714285714,
 'sars-cov2-omicron-ba2.12.1': 0.9375,
 'sars-cov2-delta': 0.9545454545454546,
 'sars-cov2-omicron-ba2': 0.954954954954955,
 'sars-cov2-omicron-ba2.75': 0.9577464788732394,
 'sars-cov-1': 0.5833333333333334,
 'sars-cov2-omicron-ba1.1': 0.8888888888888888,
 'sars-cov2-omicron-ba1': 0.9230769230769231,
 'sars-cov2-omicron-ba4': 0.7555555555555555,
 'sars-cov2-omicron-ba2.13': 1.0,
 'sars-cov2-omicron-ba3': 1.0,
 'sars-cov2-beta': 0.8235294117647058}

### Neutralizing

In [29]:
# Neutralization evaluation data: same test split, labels from 'Neutralizing'.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
df_test = pd.read_csv('data/test_data.csv')

# create dataloader (no shuffling, so predictions line up with df_test rows)
dataset_test = SequenceDataset(
    vh_seqs=df_test['Antibody VH'].tolist(),
    vl_seqs=df_test['Antibody VL'].tolist(),
    targets=df_test['Target'].tolist(),
    labels=df_test['Neutralizing'].tolist(),
)
data_loader_test2 = DataLoader(dataset_test, batch_size=12, shuffle=False)

In [28]:
# load model: build CovNeut_ftESM with load_weights=False, then restore the
# fine-tuned checkpoint. The load_state_dict(...) result is the displayed output.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model_weight_path = 'saved_models/covneut_ESM2_Struct_ft.pth'
model2 = CovNeut_ftESM(load_weights=False)
neut_state_dict = torch.load(model_weight_path, map_location=device)
model2.load_state_dict(neut_state_dict)

<All keys matched successfully>

In [31]:
# Evaluate the neutralization model; one 0.0/1.0 prediction per df_test row.
pred_neut = test_run(model=model2, data_loader=data_loader_test2, device=device)

Test:   0%|          | 0/84 [00:00<?, ?it/s]

Test: 100%|██████████| 84/84 [00:15<00:00,  5.35it/s]

              precision    recall  f1-score   support

         Neg     0.9112    0.9203    0.9157       602
         Pos     0.8776    0.8643    0.8709       398

    accuracy                         0.8980      1000
   macro avg     0.8944    0.8923    0.8933      1000
weighted avg     0.8978    0.8980    0.8979      1000






In [32]:
# Attach the neutralization predictions (row order matches the unshuffled loader).
df_test = df_test.assign(pred_neut=pred_neut)

In [33]:
# Distinct targets, in order of first appearance.
unq_targets = pd.unique(df_test['Target'])

In [34]:
# Per-target F1 of the neutralization predictions, keyed by target name.
neut_f1 = {}
for target_name in unq_targets:
    target_rows = df_test.loc[df_test['Target'] == target_name]
    neut_f1[target_name] = f1_score(target_rows['Neutralizing'], target_rows['pred_neut'])

In [35]:
# Display the per-target neutralization F1 scores.
neut_f1

{'sars-cov2-omicron-ba5': 0.8846153846153846,
 'sars-cov2-wt': 0.8493150684931506,
 'sars-cov2-omicron-xbb1': 0.6666666666666666,
 'sars-cov2-omicron-ba2.12.1': 1.0,
 'sars-cov2-delta': 1.0,
 'sars-cov2-omicron-ba2': 0.8837209302325582,
 'sars-cov2-omicron-ba2.75': 0.9428571428571428,
 'sars-cov-1': 0.47058823529411764,
 'sars-cov2-omicron-ba1.1': 1.0,
 'sars-cov2-omicron-ba1': 0.8292682926829268,
 'sars-cov2-omicron-ba4': 0.7567567567567568,
 'sars-cov2-omicron-ba2.13': 0.9473684210526315,
 'sars-cov2-omicron-ba3': 0.9302325581395349,
 'sars-cov2-beta': 0.75}

In [36]:
# Summary table of per-target F1: index and columns are filled by the next cells.
df_tg = pd.DataFrame()

In [37]:
# One row per target, in binding_f1's key order.
df_tg.index = list(binding_f1)

In [38]:
# Binding F1 per target (same order as the index, which came from binding_f1).
df_tg['Binding'] = list(binding_f1.values())

In [39]:
# Neutralization F1 per target, aligned by target name rather than by position:
# neut_f1 was built in a separate pass from binding_f1, so relying on identical
# dict insertion order could silently attach scores to the wrong target.
# Targets absent from neut_f1 become NaN; keys not in df_tg.index are dropped.
df_tg['Neutralizing'] = pd.Series(neut_f1)

In [40]:
# Side-by-side per-target F1 for the binding and neutralization models.
df_tg

Unnamed: 0,Binding,Neutralizing
sars-cov2-omicron-ba5,0.831858,0.884615
sars-cov2-wt,0.923767,0.849315
sars-cov2-omicron-xbb1,0.571429,0.666667
sars-cov2-omicron-ba2.12.1,0.9375,1.0
sars-cov2-delta,0.954545,1.0
sars-cov2-omicron-ba2,0.954955,0.883721
sars-cov2-omicron-ba2.75,0.957746,0.942857
sars-cov-1,0.583333,0.470588
sars-cov2-omicron-ba1.1,0.888889,1.0
sars-cov2-omicron-ba1,0.923077,0.829268
