In [None]:
import os
import pandas as pd
import torch
from tqdm import tqdm
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from pytorch_metric_learning.losses import SelfSupervisedLoss, NTXentLoss

from proteinbind_new import EmbeddingDataset, DualEmbeddingDataset
from proteinbind_new import create_proteinbind
from transformers import pipeline
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
batch_size = 1024
test_num = 10


def _train_test_loaders(dataset, test_num, batch_size):
    """Split ``dataset`` by position and wrap each part in a DataLoader.

    The last ``test_num`` items become the test split and everything before
    them becomes the train split. Neither loader shuffles.

    Args:
        dataset: Any sized, indexable dataset.
        test_num: Number of trailing items to hold out for testing.
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``.
    """
    total = len(dataset)
    # An explicit split point avoids the negative-slice pitfall where
    # ``indices[:-0]`` is empty and ``indices[-0:]`` is the whole list.
    split_at = min(max(total - test_num, 0), total)
    train_indices = list(range(split_at))
    test_indices = list(range(split_at, total))

    train_loader = DataLoader(Subset(dataset, train_indices), batch_size=batch_size)
    test_loader = DataLoader(Subset(dataset, test_indices), batch_size=batch_size)
    return train_loader, test_loader


# Each dataset pairs amino-acid embeddings with the embeddings of one other
# modality; the third constructor argument is the modality name.

# GO
go_dataset = DualEmbeddingDataset('GO_files/go_AA_embeddings.pt', 'GO_files/GO_embeddings.pt', "go")
go_train_dataloader, go_test_dataloader = _train_test_loaders(go_dataset, test_num, batch_size)

# DNA
dna_dataset = DualEmbeddingDataset('DNA_files/dna_AA_embeddings.pt', 'DNA_files/uniprot_nuc_embeddings.pt', "dna")
dna_train_dataloader, dna_test_dataloader = _train_test_loaders(dna_dataset, test_num, batch_size)

# Text
text_dataset = DualEmbeddingDataset('Text_files/text_AA_embeddings.pt', 'Text_files/protein_whole_text_embedding.pt', "text")
text_train_dataloader, text_test_dataloader = _train_test_loaders(text_dataset, test_num, batch_size)

# MSA
msa_dataset = DualEmbeddingDataset('MSA_files/msa_AA-embeddings.pt', 'MSA_files/msa_embeddings.pt', "msa")
msa_train_dataloader, msa_test_dataloader = _train_test_loaders(msa_dataset, test_num, batch_size)

# PDB
pdb_dataset = DualEmbeddingDataset('Struct_files/pdb_AAs-embedding.pt', 'Struct_files/pdb_embeddings.pt', "pdb")
pdb_train_dataloader, pdb_test_dataloader = _train_test_loaders(pdb_dataset, test_num, batch_size)



In [None]:
# Contrastive objective: NT-Xent wrapped so it can be called directly on two
# batches of paired embeddings, e.g. loss_func(embeddings_a, embeddings_b).
loss_func = SelfSupervisedLoss(NTXentLoss())

# Joint embedding model built by the project's factory function.
model = create_proteinbind()

In [None]:
num_epochs = 1000
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
losses = []

# (modality key, train loader, test loader), in the order each modality is
# visited within a training step and within an evaluation pass.
modality_loaders = (
    ("go", go_train_dataloader, go_test_dataloader),
    ("dna", dna_train_dataloader, dna_test_dataloader),
    ("text", text_train_dataloader, text_test_dataloader),
    ("msa", msa_train_dataloader, msa_test_dataloader),
    ("pdb", pdb_train_dataloader, pdb_test_dataloader),
)


def _symmetric_contrastive_loss(batch, modality):
    """Return the two-way contrastive loss between 'aa' and ``modality``.

    ``batch`` must hold the keys ``modality`` and ``'aa'``. Both entries are
    cast to float32, moved to ``device`` and passed through the global
    ``model``; the loss is ``loss_func(aa, m) + loss_func(m, aa)``.
    """
    input_data = {
        modality: batch[modality].type(torch.float32).to(device),
        'aa': batch['aa'].type(torch.float32).to(device),
    }
    output = model(input_data)
    return loss_func(output["aa"], output[modality]) + loss_func(output[modality], output["aa"])


# Get the total number of mini-batches per epoch based on the smallest data loader
num_batches = min(len(train_loader) for _, train_loader, _ in modality_loaders)

for epoch in tqdm(range(num_epochs)):

    # Start every epoch with a clean list so the reported average covers
    # exactly one epoch, even when the report below is skipped.
    losses = []

    train_iters = [(modality, iter(train_loader)) for modality, train_loader, _ in modality_loaders]

    for batch_idx in range(num_batches):
        # One optimizer step per modality per mini-batch.
        for modality, loader_iter in train_iters:
            loss = _symmetric_contrastive_loss(next(loader_iter), modality)
            losses.append(loss.item())

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    if epoch != 0 and (epoch + 1) % 1 == 0:
        # "Joint" loss = per-modality losses summed, then averaged over batches.
        print(f"Epoch {epoch+1}: joint train loss {sum(losses) / num_batches}")
        losses = []

        model.eval()

        # Evaluation needs no gradients; only the first batch of each test
        # loader is scored.
        with torch.no_grad():
            for modality, _, test_loader in modality_loaders:
                # Each modality is paired with the 'aa' embeddings from its
                # own test batch.
                test_batch = next(iter(test_loader))
                losses.append(_symmetric_contrastive_loss(test_batch, modality).item())

        print(f"Epoch {epoch+1}: joint test loss {sum(losses)}")
        losses = []

        # Optional downstream evaluations (defined in later cells; run those
        # cells first before enabling):
        #print(f"Epoch {epoch+1}: Linear Cls {eval_linear_task(model)} ")
        #print(f"Epoch {epoch+1}: DMS Correlation Cls {dms_eval_task(model)} ")

        model.train()




In [None]:
DOWNSTREAM_FILE_PATHS = 'downstream_tasks/downstream_files/'


def _read_pathogenicity_labels(path):
    """Read a newline-separated label file and map 'PLP' -> 1, 'BLB' -> 0.

    Lines holding anything else (including blank lines) are skipped. The
    number of raw lines is printed as a quick sanity check.
    """
    with open(path, "r") as label_file:
        entries = label_file.read().split("\n")
    print(len(entries))

    label_ids = {'PLP': 1, 'BLB': 0}
    # strip() tolerates trailing whitespace or '\r' from CRLF line endings.
    stripped = (entry.strip() for entry in entries)
    return [label_ids[entry] for entry in stripped if entry in label_ids]


def eval_linear_task(model):
    """Evaluate the model's 'aa' embeddings with a linear pathogenicity probe.

    Wild-type and mutant amino-acid embeddings are loaded from
    ``DOWNSTREAM_FILE_PATHS`` and passed through ``model``. The wild-type and
    mutant outputs are concatenated along dim 1 and used to fit a logistic
    regression on the train split. The test-split confusion matrix is
    printed.

    Args:
        model: Model called with ``{'aa': tensor}`` and returning a mapping
            with an ``'aa'`` entry. It is left in eval mode on return.

    Returns:
        Mean accuracy of the classifier on the test split.
    """
    from sklearn import metrics

    model.eval()

    ##### Pathogenicity Task

    # Labels
    train_labels = _read_pathogenicity_labels(
        f"{DOWNSTREAM_FILE_PATHS}/train_mt/seq-cleaned-train_mt_labels.txt")
    test_labels = _read_pathogenicity_labels(
        f"{DOWNSTREAM_FILE_PATHS}/test_mt/seq-cleaned-test_mt-labels.txt")

    # Pretrained amino-acid embeddings (wild type / mutant)
    train_wt_aa = torch.load(f'{DOWNSTREAM_FILE_PATHS}/train_wt/seq-cleaned-train_wt-embeddings.pt')
    train_mt_aa = torch.load(f'{DOWNSTREAM_FILE_PATHS}/train_mt/seq-cleaned-train_mt_embeddings.pt')

    test_wt_aa = torch.load(f'{DOWNSTREAM_FILE_PATHS}/test_wt/seq-cleaned-test_wt-embeddings.pt')
    test_mt_aa = torch.load(f'{DOWNSTREAM_FILE_PATHS}/test_mt/seq-cleaned-test_mt-embeddings.pt')

    print(train_wt_aa.size(), train_mt_aa.size(), test_wt_aa.size(), test_mt_aa.size())

    def embed(aa):
        # Project the embeddings through the model's 'aa' branch.
        return model({'aa': aa.type(torch.float32).to(device)})['aa']

    # Inference only: no autograd graph is needed for the probe features.
    with torch.no_grad():
        train_concat = torch.cat((embed(train_wt_aa), embed(train_mt_aa)), 1)
        test_concat = torch.cat((embed(test_wt_aa), embed(test_mt_aa)), 1)

    ###### Linear Classifier ######
    print(train_concat.size(), test_concat.size())

    # scikit-learn needs host memory, so move off the GPU before converting.
    train_features = train_concat.detach().cpu().numpy()
    test_features = test_concat.detach().cpu().numpy()

    clf = LogisticRegression(random_state=0).fit(train_features, train_labels)

    predictions = clf.predict(test_features)
    score = clf.score(test_features, test_labels)

    cm = metrics.confusion_matrix(test_labels, predictions)
    print(cm)
    return score

In [None]:
def dms_eval_task(model):
    """Compare wild-type and mutant 'aa' embeddings for the DMS data.

    Wild-type and mutant amino-acid embeddings are loaded from
    ``DOWNSTREAM_FILE_PATHS``, passed through ``model``, and compared with
    cosine similarity along dim 1.

    Args:
        model: Model called with ``{'aa': tensor}`` and returning a mapping
            with an ``'aa'`` entry. It is left in eval mode on return.

    Returns:
        Tensor of cosine similarities between the wild-type and mutant
        outputs (presumably one value per variant row; confirm against the
        embedding files). No correlation against measured DMS scores is
        computed here; callers must do that themselves.
    """
    model.eval()
    ##### DMS Task

    dms_wt_aa = torch.load(f'{DOWNSTREAM_FILE_PATHS}/DMS/data_wt/data-cleaned-wt-embeddings.pt')
    dms_mt_aa = torch.load(f'{DOWNSTREAM_FILE_PATHS}/DMS/data_mt/data-cleaned-mt-embeddings.pt')

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        output_wt = model({
            'aa': dms_wt_aa.type(torch.float32).to(device)
        })
        output_mt = model({
            'aa': dms_mt_aa.type(torch.float32).to(device)
        })

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    output = cos(output_wt['aa'], output_mt['aa'])
    # Previously the similarity was computed and then dropped by a bare return.
    return output