In [2]:
import transformers
from transformers import XLNetTokenizer, XLNetModel
from transformers import AutoModel, BertTokenizerFast
from transformers import AdamW

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn.functional as F

from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, classification_report
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.metrics.pairwise import cosine_similarity, pairwise_distances
from sklearn.linear_model import LogisticRegression
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer, StandardScaler

from torch.nn import TripletMarginLoss

import matplotlib.pyplot as plt
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
from tqdm import tqdm
from scipy.stats import spearmanr
# import mplcursors
import time
import random
import pandas as pd
import numpy as np
import warnings
import re
import json
import networkx as nx
import obonet
from collections import Counter
import collections
import pickle


seed = 42

torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

random.seed(seed)
np.random.seed(seed)


In [4]:
class LateFusion(nn.Module):
    def __init__(self, in1=768, in2=200, hidden_size=500, out = 768):
        super(Gene2VecFusionModel, self).__init__()
        
        self.embedding1 = nn.Linear(in1, hidden_size) #BERT
        
        
        self.embedding2 = nn.Linear(in2, in2) 
        self.unet1 = nn.Linear(in2, hidden_size)
        self.unet2 = nn.Linear(hidden_size, hidden_size)
        self.unet3 = nn.Linear(hidden_size,hidden_size)
        
        
        self.fc1 = nn.Linear(hidden_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out)
        self.relu = nn.ReLU()


    def forward(self, x1, x2):

        #For BERT
        x1 = self.embedding1(x1)
        x1 = self.relu(x1)
                
        
        #For other modalities
        x2 = self.embedding2(x2)
        x2 = self.relu(x2)
        
        x2 = self.unet1(x2)
        x2 = self.relu(x2)
        
        x2 = self.unet2(x2)
        x2 = self.relu(x2)
        
        x2 = self.unet3(x2)
        x2 = self.relu(x2)

        
        z = torch.cat((x1, x2), dim=1)
        z = self.fc1(z)
        z = self.relu(z)
        z = self.fc2(z)
        z = self.relu(z)
        
        return z

class FusionModel(nn.Module):
    def __init__(self, input_size=768, hidden_size=500, n_labels=1):
        super(FusionModel, self).__init__()
        
        self.embedding1 = nn.Linear(input_size, hidden_size)
        self.embedding2 = nn.Linear(input_size, hidden_size)
        
        self.fc1 = nn.Linear(hidden_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_labels)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):

        x1 = self.embedding1(x1)
        x2 = self.embedding2(x2)
        
        z = torch.cat((x1, x2), dim=1)
        z = self.fc1(z)
        z = self.relu(z)
        z = self.fc2(z)
        z = self.sigmoid(z)

        return z

class Gene2VecFusionModel(nn.Module):
    def __init__(self, in1=768, in2=200, hidden_size=500, out = 768):
        super(Gene2VecFusionModel, self).__init__()
        
        self.embedding1 = nn.Linear(in1, hidden_size)
        self.embedding2 = nn.Linear(in2, hidden_size)
        
        self.fc1 = nn.Linear(hidden_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out)
        self.relu = nn.ReLU()


    def forward(self, x1, x2):

        x1 = self.embedding1(x1)
        x2 = self.embedding2(x2)
        
        z = torch.cat((x1, x2), dim=1)
        z = self.fc1(z)
        z = self.relu(z)
        z = self.fc2(z)
        z = self.relu(z)
        
        return z

class MultiLabelFocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(MultiLabelFocalLoss, self).__init__()
        
        self.alpha = nn.Parameter(torch.tensor(0.25, requires_grad=True, device="cuda"))  
        self.gamma = nn.Parameter(torch.tensor(2.0, requires_grad=True, device="cuda"))  

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss) 
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        return F_loss.mean()

class FineTunedBERT(nn.Module):

    def __init__(self, pool="mean", model_name= "bert-base-cased", bert_state_dict=None,
                 task_type = None, n_labels = None, drop_rate = None,
                 gene2vec_flag=None, gene2vec_hidden = 200, device ="cuda"):
        
        """
            task_type : regression or classification.
        
        """
      
        super(FineTunedBERT, self).__init__()

#         assert (task_type == 'unsupervised' and n_labels == 1) or (task_type == 'regression' and n_labels == 1) or (task_type == 'classification' and n_labels>1) or (task_type == 'multilabel' and n_labels>1), \
#             f"Invalid combination of task_type and n_labels: {task_type} and {n_labels}"  

#         assert gene2vec_flag is not None, f"gene2vec_flag cannot be None: {gene2vec_flag}"

        
        self.model_name = model_name
        self.pool = pool
        
        
        
        
        if "xlnet" in model_name:
            self.bert = XLNetModel.from_pretrained(model_name).to(device)
        
        else:
            if bert_state_dict: 
                self.bert = AutoModel.from_pretrained(model_name)
                self.bert.load_state_dict(bert_state_dict) #.to(device)
            else:
                self.bert = AutoModel.from_pretrained(model_name)#.to(device)
                
        
#         layers_to_train = ["encoder.layer.11", "encoder.layer.10", "encoder.layer.9",
#                            "encoder.layer.8","pooler"]
        for name, param in self.bert.named_parameters():

            if name.startswith("encoder.layer.11") or name.startswith("encoder.layer.10") or\
            name.startswith("encoder.layer.9") or name.startswith("pooler") or\
            name.startswith("encoder.layer.8") : 
                param.requires_grad = True
            else:
                param.requires_grad = False
    
        
        
        self.bert_hidden = self.bert.config.hidden_size
        self.gene2vecFusion = Gene2VecFusionModel(in1=self.bert_hidden,
                                                  in2=gene2vec_hidden,
                                                  hidden_size=500,
                                                  out = self.bert_hidden)
        
        if task_type.lower() == "classification":
            
            assert  n_labels > 1, f"Invalid combination of task_type and n_labels: {task_type} and {n_labels}"
            self.pipeline = nn.Sequential(nn.Linear(self.bert_hidden, n_labels))
#             if gene2vec_flag:
#                 self.pipeline = nn.Sequential(
#                     nn.Linear(self.bert_hidden+gene2vec_hidden, n_labels)
#                 )
                
#             else:
#                 self.pipeline = nn.Sequential(nn.Linear(self.bert_hidden, n_labels))
    
        elif task_type.lower() == "multilabel":
            assert  n_labels > 1, f"Invalid combination of task_type and n_labels: {task_type} and {n_labels}"
            self.pipeline = nn.Sequential(
                    nn.Linear(self.bert_hidden, n_labels),
                    nn.Sigmoid()
                )

            
#             assert  n_labels > 1, f"Invalid combination of task_type and n_labels: {task_type} and {n_labels}"
            
#             if gene2vec_flag:
#                 self.pipeline = nn.Sequential(
#                     nn.Linear(self.bert_hidden+gene2vec_hidden, n_labels),
#                     nn.Sigmoid()
#                 )
                
#             else:
#                 self.pipeline = nn.Sequential(
#                     nn.Linear(self.bert_hidden, n_labels),
#                     nn.Sigmoid()
#                 )

        elif task_type.lower() == "regression":
            assert  n_labels == 1, f"Invalid combination of task_type and n_labels: {task_type} and {n_labels}"
            self.pipeline = nn.Sequential(nn.Linear(self.bert_hidden, 1))
            

#             if gene2vec_flag:
#                 self.pipeline = nn.Sequential(
#                 nn.Linear(self.bert_hidden+gene2vec_hidden, 1))
                
#             else:            
#                 self.pipeline = nn.Sequential(nn.Linear(self.bert_hidden, 1))
        
        elif task_type.lower() =="interaction":
            if gene2vec_flag:
                raise ValueError(f"gene2vec_flag must be False when task_type is set to {task_type}: {gene2vec_flag=}")
                
            else:      
                self.pipeline = nn.Sequential(nn.Linear(self.bert_hidden, 1))
            
        
        elif task_type.lower() == "unsupervised":
            self.pipeline = nn.Sequential(nn.Linear(self.bert_hidden, 1))
            #No need for an assert, labels will not be used during unsupervised.

#             if gene2vec_flag:
#                 raise ValueError(f"gene2vec_flag must be False when task_type is set to {task_type}: {gene2vec_flag=}")
                
#             else:      
#                 self.pipeline = nn.Sequential(nn.Linear(self.bert_hidden, 1))

        else:
            raise ValueError(f"Key Error task_type : {task_type} ")
          
    def forward(self, input_ids_, attention_mask_, gene2vec=None):
        
        
        # retrieving the hidden state embeddings
        if "xlnet" in self.model_name:
            output = self.bert(input_ids = input_ids_,
                               attention_mask=attention_mask_)

            hiddenState, ClsPooled = output.last_hidden_state, output.last_hidden_state[:,0, :]

        else:
            hiddenState, ClsPooled = self.bert(input_ids = input_ids_,
                                               attention_mask=attention_mask_).values()

        
        # perform pooling on the hidden state embeddings
        if self.pool.lower() == "max":
            embeddings = self.max_pooling(hiddenState, attention_mask_)
            
        elif self.pool.lower() == "cls":
            embeddings = ClsPooled
                
        elif self.pool.lower() == "mean":
            embeddings = self.mean_pooling(hiddenState, attention_mask_)

        else:
            raise ValueError('Pooling value error.')
        
        
        if gene2vec is not None:
            embeddings = self.gene2vecFusion(embeddings, gene2vec)
            #embeddings = torch.cat((embeddings, gene2vec), dim=1)
      

        return embeddings, hiddenState, self.pipeline(embeddings)

    def max_pooling(self, hidden_state, attention_mask):
        
        #CLS: First element of model_output contains all token embeddings
        token_embeddings = hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
        
        pooled_embeddings = torch.max(token_embeddings, 1)[0]
        return pooled_embeddings
    
    def mean_pooling (self, hidden_state, attention_mask):
        
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
        pooled_embeddings = torch.sum(hidden_state * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 
        
        return pooled_embeddings
    
def getEmbeddings(text,
                  model = None,
                  model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
                  max_length=512,
                  batch_size=1000,
                  gene2vec_flag=False,
                  gene2vec_hidden=200,
                  pool ="mean"):
    
    
    
    
    

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    

    if isinstance(model, FineTunedBERT):
        print("Loading a pretrained model ...")
        
        model_name = model.model_name            
        tokenizer = BertTokenizerFast.from_pretrained(model_name)

    
    elif isinstance(model, collections.OrderedDict):
        print("Loading a pretrained model from a state dictionary ...")
        
        state_dict = model.copy() 
        
        model = FineTunedBERT(pool= pool,model_name = model_name,
                              gene2vec_flag= gene2vec_flag,
                              gene2vec_hidden = gene2vec_hidden,
                              task_type="unsupervised",
                              n_labels = 1,
                              device = device).to(device)

        
        for s in model.state_dict().keys():
            if s not in state_dict.keys():
                print(s)
        
        
        model.load_state_dict(state_dict)
        tokenizer = BertTokenizerFast.from_pretrained(model_name)
            
    else:
        print("Creating a new pretrained model ...")
        
        #BERT-base embeddings
        model = FineTunedBERT(pool= pool,
                              model_name=model_name,
                              task_type="unsupervised",
                              gene2vec_flag = False,
                              n_labels = 1).to(device)
  
        tokenizer = BertTokenizerFast.from_pretrained(model_name)


    model = nn.DataParallel(model)
    print("Tokenization ...")
    tokens = tokenizer.batch_encode_plus(text, max_length = max_length,
                                         padding="max_length",truncation=True,
                                         return_tensors="pt")
    
    
    dataset = TensorDataset(tokens["input_ids"] , tokens["attention_mask"])
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    print("Tokenization Done.")
    
    print("Get Embeddings ...")
    
    embeddings=[]
    model.eval()
    for batch_input_ids, batch_attention_mask in tqdm(dataloader):
        with torch.no_grad():
            pooled_embeddings, _, _ = model(batch_input_ids.to(device) ,
                                            batch_attention_mask.to(device))
            embeddings.append(pooled_embeddings)
    
    
    concat_embeddings = torch.cat(embeddings, dim=0)
    
    print(concat_embeddings.size())
    
    return concat_embeddings

def getEmbeddingsWithGene2Vec(dataloader, model):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
    embeddings=[]
    model.eval()
    for batch in tqdm(dataloader):
        with torch.no_grad():
            batch_inputs_a, batch_masks_a, gene2vec_a  =  batch[0].to(device) , batch[1].to(device), batch[2].to(device)
            pooled_embeddings, _ , _ = model(batch_inputs_a, batch_masks_a, gene2vec =gene2vec_a)
            embeddings.append(pooled_embeddings)
    
    
    concat_embeddings = torch.cat(embeddings, dim=0)
    
    print(concat_embeddings.size())
    
    return concat_embeddings


In [5]:
def getEmbeddings(text,
                  model = None,
                  model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
                  max_length=512,
                  batch_size=1000,
                  gene2vec_flag=False,
                  gene2vec_hidden=200,
                  pool ="mean"):
    
    
    
    
    

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    

    if isinstance(model, FineTunedBERT):
        print("Loading a pretrained model ...")
        
        model_name = model.model_name            
        tokenizer = BertTokenizerFast.from_pretrained(model_name)

    
    elif isinstance(model, collections.OrderedDict):
        print("Loading a pretrained model from a state dictionary ...")
        
        state_dict = model.copy() 
        
        model = FineTunedBERT(pool= pool,model_name = model_name,
                              gene2vec_flag= gene2vec_flag,
                              gene2vec_hidden = gene2vec_hidden,
                              task_type="unsupervised",
                              n_labels = 1,
                              device = device).to(device)

        
        for s in model.state_dict().keys():
            if s not in state_dict.keys():
                print(s)
        
        
        model.load_state_dict(state_dict)
        tokenizer = BertTokenizerFast.from_pretrained(model_name)
            
    else:
        print("Creating a new pretrained model ...")
        
        #BERT-base embeddings
        model = FineTunedBERT(pool= pool,
                              model_name=model_name,
                              task_type="unsupervised",
                              gene2vec_flag = False,
                              n_labels = 1).to(device)
  
        tokenizer = BertTokenizerFast.from_pretrained(model_name)


    model = nn.DataParallel(model)
    print("Tokenization ...")
    tokens = tokenizer.batch_encode_plus(text, max_length = max_length,
                                         padding="max_length",truncation=True,
                                         return_tensors="pt")
    
    
    dataset = TensorDataset(tokens["input_ids"] , tokens["attention_mask"])
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    print("Tokenization Done.")
    
    print("Get Embeddings ...")
    
    embeddings=[]
    model.eval()
    for batch_input_ids, batch_attention_mask in tqdm(dataloader):
        with torch.no_grad():
            pooled_embeddings, _, _ = model(batch_input_ids.to(device) ,
                                            batch_attention_mask.to(device))
            embeddings.append(pooled_embeddings)
    
    
    concat_embeddings = torch.cat(embeddings, dim=0)
    
    print(concat_embeddings.size())
    
    return concat_embeddings

In [14]:
import pandas as pd
df_obesity = pd.read_csv('/data_link/macaulay/GeneLLM2/diseaseAssociation/cosineSimilarity/cosine_similarities_not_in_obesity.csv')    
df_obesity = df_obesity.iloc[:100]

df_asthma = pd.read_csv('/data_link/macaulay/GeneLLM2/diseaseAssociation/cosineSimilarity/cosine_similarities_not_in_asthma.csv')    
df_asthma = df_asthma.iloc[:100]

df_schizophrenia = pd.read_csv('/data_link/macaulay/GeneLLM2/diseaseAssociation/cosineSimilarity/cosine_similarities_not_in_schizophrenia.csv')    
df_schizophrenia = df_schizophrenia.iloc[:100]

df_hypertension = pd.read_csv('/data_link/macaulay/GeneLLM2/diseaseAssociation/cosineSimilarity/cosine_similarities_not_in_hypertension.csv')    
df_hypertension = df_hypertension.iloc[:100]

df_obesity['Disease'] = 'obesity'
df_asthma['Disease'] = 'asthma'
df_schizophrenia['Disease'] = 'schizophrenia'
df_hypertension['Disease'] = 'hypertension'

df = pd.concat([df_obesity, df_asthma, df_schizophrenia, df_hypertension], axis=0)
df = df.reset_index(drop=True)
df = df.dropna()
df

Unnamed: 0,Gene,Cosine Similarity,Disease
0,LEP,0.758912,obesity
1,THRSP,0.741904,obesity
2,JAZF1,0.720466,obesity
3,C1QTNF12,0.711311,obesity
4,LEPROT,0.708792,obesity
...,...,...,...
395,TRPC3,0.560457,hypertension
396,GPR182,0.560197,hypertension
397,KCNK18,0.558525,hypertension
398,CD36,0.558066,hypertension


In [17]:
import pandas as pd
import torch

def predict_pathway(df_relation):
    # Assuming the file path and column names are correct
    genes = pd.read_csv("/data/macaulay/GeneLLM2/data/clean_genes.csv")
    all_genes = genes["Gene name"].to_list()
    print("1")
    
    # Merge based on 'Gene' column from df_relation and 'Gene name' from genes
    df_merge = pd.merge(df_relation, genes, left_on='Gene', right_on='Gene name', how='left')
    display(df_merge)

    # Filter out rows without a summary
    df_merge = df_merge.dropna(subset=['Summary'])
    gene_to_summary = df_merge.set_index('Gene name')['Summary'].to_dict()

    gene_names = list(gene_to_summary.keys())
    gene_text = list(gene_to_summary.values())
    
    # Load your model and get embeddings
    loaded_model = torch.load("state_dict_0.pth")
    embed = getEmbeddings(gene_text, loaded_model, batch_size=2000).detach().cpu().numpy()
    
    # Create a DataFrame to map gene names to embeddings
    embeddings_df = pd.DataFrame(embed, index=gene_names)
    embeddings_df.reset_index(inplace=True)
    embeddings_df.rename(columns={'index': 'Gene name'}, inplace=True)
    
    # Now, embeddings_df contains each gene name and its corresponding embedding
    display(embeddings_df)
    return embeddings_df

    # Optional: Save the embeddings DataFrame to a CSV file
    

embeddings_df = predict_pathway(df)


1


Unnamed: 0,Gene,Cosine Similarity,Disease,Gene name,Summary
0,LEP,0.758912,obesity,LEP,This gene encodes a protein that is secreted b...
1,THRSP,0.741904,obesity,THRSP,The protein encoded by this gene is similar to...
2,JAZF1,0.720466,obesity,JAZF1,Acts as a transcriptional corepressor of orph...
3,C1QTNF12,0.711311,obesity,C1QTNF12,Predicted to enable hormone activity. Predicte...
4,LEPROT,0.708792,obesity,LEPROT,LEPROT is associated with the Golgi complex an...
...,...,...,...,...,...
395,TRPC3,0.560457,hypertension,TRPC3,The protein encoded by this gene is a membrane...
396,GPR182,0.560197,hypertension,GPR182,Adrenomedullin is a potent vasodilator peptide...
397,KCNK18,0.558525,hypertension,KCNK18,Potassium channels play a role in many cellula...
398,CD36,0.558066,hypertension,CD36,The protein encoded by this gene is the fourth...


Loading a pretrained model from a state dictionary ...
Tokenization ...
Tokenization Done.
Get Embeddings ...


100%|██████████| 1/1 [00:00<00:00, 72.17it/s]

torch.Size([395, 768])





Unnamed: 0,Gene name,0,1,2,3,4,5,6,7,8,...,758,759,760,761,762,763,764,765,766,767
0,LEP,0.401820,0.373210,0.138294,0.064681,-0.121166,0.039754,0.193633,0.331752,0.811940,...,-0.732070,0.192608,-0.063601,0.002144,-0.088668,0.084305,0.073466,-0.386494,-0.024886,0.503482
1,THRSP,0.233719,0.633899,-0.172188,0.674709,-0.197604,0.549218,0.074492,-0.007133,0.798039,...,-0.511492,0.181888,-0.223601,-0.159075,-0.110663,0.207832,-0.225697,-0.693335,0.079949,0.099870
2,JAZF1,0.092980,0.966914,0.071201,0.498121,-0.086072,0.286590,-0.025821,0.413296,1.255126,...,-0.805195,0.045931,-0.176672,-0.204219,-0.303561,0.253543,-0.412748,-0.677840,0.167411,0.162782
3,C1QTNF12,0.271985,0.343538,0.124735,0.357906,0.040200,0.301222,0.076676,0.309229,1.288095,...,-0.966583,0.094286,-0.166362,-0.227300,-0.174461,0.367383,-0.178256,-0.759278,-0.012383,0.497670
4,LEPROT,0.453366,0.293181,0.210744,0.342967,0.594061,-0.077142,0.312388,-0.207635,0.787018,...,-0.800582,-0.149457,0.062632,0.167698,0.137014,0.341324,0.098344,-0.614071,-0.015764,0.314205
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
390,TRPC3,-0.279923,-0.540060,-0.319850,0.274256,0.073394,-0.680006,-0.231466,0.120466,0.662462,...,-0.209290,-0.043997,0.445040,0.152601,0.232559,-0.050196,0.578297,0.077571,0.096527,0.375377
391,GPR182,-0.062727,-0.056829,0.341810,-0.242273,0.103200,0.044457,0.217844,-0.085437,0.738598,...,-0.210477,0.065162,-0.331619,-0.164482,-0.029032,-0.335184,-0.015981,-0.456751,0.284081,0.057050
392,KCNK18,0.631262,-0.484671,-0.535359,0.339540,0.364236,-0.858806,0.043966,0.367130,0.219604,...,0.088608,-0.376143,0.835560,0.205990,0.275130,-0.097203,0.798008,-0.437516,0.387648,0.728651
393,CD36,0.517931,0.427913,-0.437248,-0.078829,0.301689,0.215548,0.164973,0.246903,0.603171,...,-0.172504,0.283440,-0.299242,0.018467,0.021558,-0.230992,-0.114485,-0.222148,0.064239,0.415916


In [22]:
merged = pd.merge(df, embeddings_df, left_on='Gene', right_on= 'Gene name', how='left')
merged = merged.dropna(subset=[0])
merged = merged.drop(columns=['Gene name', 'Cosine Similarity'])
merged

Unnamed: 0,Gene,Disease,0,1,2,3,4,5,6,7,...,758,759,760,761,762,763,764,765,766,767
0,LEP,obesity,0.401820,0.373210,0.138294,0.064681,-0.121166,0.039754,0.193633,0.331752,...,-0.732070,0.192608,-0.063601,0.002144,-0.088668,0.084305,0.073466,-0.386494,-0.024886,0.503482
1,THRSP,obesity,0.233719,0.633899,-0.172188,0.674709,-0.197604,0.549218,0.074492,-0.007133,...,-0.511492,0.181888,-0.223601,-0.159075,-0.110663,0.207832,-0.225697,-0.693335,0.079949,0.099870
2,JAZF1,obesity,0.092980,0.966914,0.071201,0.498121,-0.086072,0.286590,-0.025821,0.413296,...,-0.805195,0.045931,-0.176672,-0.204219,-0.303561,0.253543,-0.412748,-0.677840,0.167411,0.162782
3,C1QTNF12,obesity,0.271985,0.343538,0.124735,0.357906,0.040200,0.301222,0.076676,0.309229,...,-0.966583,0.094286,-0.166362,-0.227300,-0.174461,0.367383,-0.178256,-0.759278,-0.012383,0.497670
4,LEPROT,obesity,0.453366,0.293181,0.210744,0.342967,0.594061,-0.077142,0.312388,-0.207635,...,-0.800582,-0.149457,0.062632,0.167698,0.137014,0.341324,0.098344,-0.614071,-0.015764,0.314205
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
395,TRPC3,hypertension,-0.279923,-0.540060,-0.319850,0.274256,0.073394,-0.680006,-0.231466,0.120466,...,-0.209290,-0.043997,0.445040,0.152601,0.232559,-0.050196,0.578297,0.077571,0.096527,0.375377
396,GPR182,hypertension,-0.062727,-0.056829,0.341810,-0.242273,0.103200,0.044457,0.217844,-0.085437,...,-0.210477,0.065162,-0.331619,-0.164482,-0.029032,-0.335184,-0.015981,-0.456751,0.284081,0.057050
397,KCNK18,hypertension,0.631262,-0.484671,-0.535359,0.339540,0.364236,-0.858806,0.043966,0.367130,...,0.088608,-0.376143,0.835560,0.205990,0.275130,-0.097203,0.798008,-0.437516,0.387648,0.728651
398,CD36,hypertension,0.517931,0.427913,-0.437248,-0.078829,0.301689,0.215548,0.164973,0.246903,...,-0.172504,0.283440,-0.299242,0.018467,0.021558,-0.230992,-0.114485,-0.222148,0.064239,0.415916


In [24]:
df = merged
df.drop_duplicates(subset=[0], inplace=True)
gene_solubility = df.iloc[:, :2]  # This holds the gene name and solubility
embeddings = df.iloc[:, 2:]
  # This holds the embeddings

# Step 4: Remove Duplicate Embeddings
# Dropping duplicate embeddings and keeping the first occurrence
unique_embeddings = embeddings.drop_duplicates()

# Step 5: Rejoin Data
# Since we want to join it back with the corresponding gene names and solubility,
# we need to find the indices of the rows in the original dataframe that match the unique embeddings.
# Then use these indices to extract the corresponding gene names and solubility.
unique_indices = unique_embeddings.index
unique_gene_solubility = gene_solubility.loc[unique_indices]

# Now, combine the gene names/solubility with the unique embeddings
final_df = pd.concat([unique_gene_solubility.reset_index(drop=True), unique_embeddings.reset_index(drop=True)], axis=1)


final_df.to_csv("/data/macaulay/GeneLLM2/data/obesity_disease_gene_embeddings.csv", index=False)
final_df

Unnamed: 0,Gene,Disease,0,1,2,3,4,5,6,7,...,758,759,760,761,762,763,764,765,766,767
0,LEP,obesity,0.401820,0.373210,0.138294,0.064681,-0.121166,0.039754,0.193633,0.331752,...,-0.732070,0.192608,-0.063601,0.002144,-0.088668,0.084305,0.073466,-0.386494,-0.024886,0.503482
1,THRSP,obesity,0.233719,0.633899,-0.172188,0.674709,-0.197604,0.549218,0.074492,-0.007133,...,-0.511492,0.181888,-0.223601,-0.159075,-0.110663,0.207832,-0.225697,-0.693335,0.079949,0.099870
2,JAZF1,obesity,0.092980,0.966914,0.071201,0.498121,-0.086072,0.286590,-0.025821,0.413296,...,-0.805195,0.045931,-0.176672,-0.204219,-0.303561,0.253543,-0.412748,-0.677840,0.167411,0.162782
3,C1QTNF12,obesity,0.271985,0.343538,0.124735,0.357906,0.040200,0.301222,0.076676,0.309229,...,-0.966583,0.094286,-0.166362,-0.227300,-0.174461,0.367383,-0.178256,-0.759278,-0.012383,0.497670
4,LEPROT,obesity,0.453366,0.293181,0.210744,0.342967,0.594061,-0.077142,0.312388,-0.207635,...,-0.800582,-0.149457,0.062632,0.167698,0.137014,0.341324,0.098344,-0.614071,-0.015764,0.314205
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
389,TRPC3,hypertension,-0.279923,-0.540060,-0.319850,0.274256,0.073394,-0.680006,-0.231466,0.120466,...,-0.209290,-0.043997,0.445040,0.152601,0.232559,-0.050196,0.578297,0.077571,0.096527,0.375377
390,GPR182,hypertension,-0.062727,-0.056829,0.341810,-0.242273,0.103200,0.044457,0.217844,-0.085437,...,-0.210477,0.065162,-0.331619,-0.164482,-0.029032,-0.335184,-0.015981,-0.456751,0.284081,0.057050
391,KCNK18,hypertension,0.631262,-0.484671,-0.535359,0.339540,0.364236,-0.858806,0.043966,0.367130,...,0.088608,-0.376143,0.835560,0.205990,0.275130,-0.097203,0.798008,-0.437516,0.387648,0.728651
392,CD36,hypertension,0.517931,0.427913,-0.437248,-0.078829,0.301689,0.215548,0.164973,0.246903,...,-0.172504,0.283440,-0.299242,0.018467,0.021558,-0.230992,-0.114485,-0.222148,0.064239,0.415916


In [None]:
MESH:D001249
MESH:D006973
MESH:D009765
MESH:D012559

df_obesity['Disease'] = 'obesity'
df_asthma['Disease'] = 'asthma'
df_schizophrenia['Disease'] = 'schizophrenia'
df_hypertension['Disease'] = 'hypertension'

In [22]:
import pandas as pd

df_mgi = pd.read_csv('disease_summary_MGI2.txt', sep='\t', header=None)
df_mgi.columns = ['DO Disease Name', 'id', 'did','summary']
df_mgi = df_mgi.dropna().reset_index(drop=True)
# df_mgi
# asthma_rows = df_mgi[df_mgi['DO Disease Name'].str.contains("schizophrenia", case=False)]

# # Display the filtered rows
# asthma_rows

mesh_ids = ['MESH:D001249', 'MESH:D006973', 'MESH:D009765', 'MESH:D012559']

# Filter df_mgi to keep rows where the 'id' column matches any of the MESH IDs in the list
df = df_mgi[df_mgi['id'].isin(mesh_ids)]
df

Unnamed: 0,DO Disease Name,id,did,summary
1041,hypertension,MESH:D006973,DOID:10763,An artery disease characterized by chronic ele...
1479,asthma,MESH:D001249,DOID:2841,A bronchial disease that is characterized by c...
1723,schizophrenia,MESH:D012559,DOID:5419,A psychotic disorder that is characterized by ...
1939,obesity,MESH:D009765,DOID:9970,An overnutrition that is characterized by exce...


In [23]:
diseas_list = df['summary'].to_list()
diseas_name = df['DO Disease Name'].to_list()
diseas_name

['hypertension', 'asthma', 'schizophrenia', 'obesity']

In [26]:
import pandas as pd
import torch
import time

def predict_pathway(gene_names,  gene_text):
   
    # Load your model and get embeddings
    loaded_model = torch.load("state_dict_0.pth")
    embed = getEmbeddings(gene_text, loaded_model, batch_size=2000).detach().cpu().numpy()
    
    # Create a DataFrame to map gene names to embeddings
    embeddings_df = pd.DataFrame(embed, index=gene_names)
    embeddings_df.reset_index(inplace=True)
    embeddings_df.rename(columns={'index': 'Disease name'}, inplace=True)
    
    # Now, embeddings_df contains each gene name and its corresponding embedding
    display(embeddings_df)
    return embeddings_df

    # Optional: Save the embeddings DataFrame to a CSV file
    

embeddings_df = predict_pathway(diseas_name, diseas_list)
embeddings_df.to_csv("disease_embeddings.csv", index=False)


Loading a pretrained model from a state dictionary ...
Tokenization ...
Tokenization Done.
Get Embeddings ...


100%|██████████| 1/1 [00:00<00:00, 89.19it/s]

torch.Size([4, 768])





Unnamed: 0,Disease name,0,1,2,3,4,5,6,7,8,...,758,759,760,761,762,763,764,765,766,767
0,hypertension,0.708679,0.435822,-0.507633,0.396858,-0.458109,-0.739938,0.288066,0.29628,0.622939,...,0.007841,0.167867,0.02446,-0.409624,-0.27554,0.31254,-0.274273,-0.483903,0.965353,0.040116
1,asthma,0.910803,0.233318,-0.27474,0.028964,0.055083,0.483768,0.692197,0.567663,0.574138,...,0.198966,0.69991,0.999969,-0.298612,0.424925,0.174062,-0.094743,0.419234,0.248716,1.413107
2,schizophrenia,0.044523,0.361688,0.748994,-0.423147,0.599872,-0.297578,0.061553,-0.588028,0.093129,...,-0.449921,0.010691,0.524088,-0.710619,-0.053389,0.068494,0.714823,0.128198,0.710103,0.513897
3,obesity,0.421315,0.004165,0.185262,-0.039173,0.35373,0.337592,0.10168,-0.259099,0.380311,...,-1.013368,0.303655,-0.180275,0.215209,-0.780538,0.356815,-0.196825,-0.991935,0.548198,0.831418


In [1]:
import pandas as pd
df = pd.read_csv('go_terms_summaary.csv')
df

Unnamed: 0,GO Term ID,Def Summary
0,GO:0000001,"The distribution of mitochondria, including th..."
1,GO:0000002,The maintenance of the structure and integrity...
2,GO:0000003,The production of new individuals that contain...
3,GO:0000005,OBSOLETE. Assists in the correct assembly of r...
4,GO:0000006,Enables the transfer of zinc ions (Zn2+) from ...
...,...,...
47590,GO:2001313,The chemical reactions and pathways involving ...
47591,GO:2001314,The chemical reactions and pathways resulting ...
47592,GO:2001315,The chemical reactions and pathways resulting ...
47593,GO:2001316,The chemical reactions and pathways involving ...


In [2]:
go_term_list = df['Def Summary'].to_list()
go_term_name = df['GO Term ID'].to_list()
go_term_name

['GO:0000001',
 'GO:0000002',
 'GO:0000003',
 'GO:0000005',
 'GO:0000006',
 'GO:0000007',
 'GO:0000008',
 'GO:0000009',
 'GO:0000010',
 'GO:0000011',
 'GO:0000012',
 'GO:0000014',
 'GO:0000015',
 'GO:0000016',
 'GO:0000017',
 'GO:0000018',
 'GO:0000019',
 'GO:0000020',
 'GO:0000022',
 'GO:0000023',
 'GO:0000024',
 'GO:0000025',
 'GO:0000026',
 'GO:0000027',
 'GO:0000028',
 'GO:0000030',
 'GO:0000031',
 'GO:0000032',
 'GO:0000033',
 'GO:0000034',
 'GO:0000035',
 'GO:0000036',
 'GO:0000038',
 'GO:0000039',
 'GO:0000041',
 'GO:0000044',
 'GO:0000045',
 'GO:0000047',
 'GO:0000048',
 'GO:0000049',
 'GO:0000050',
 'GO:0000051',
 'GO:0000052',
 'GO:0000053',
 'GO:0000054',
 'GO:0000055',
 'GO:0000056',
 'GO:0000059',
 'GO:0000060',
 'GO:0000061',
 'GO:0000062',
 'GO:0000064',
 'GO:0000067',
 'GO:0000070',
 'GO:0000072',
 'GO:0000073',
 'GO:0000075',
 'GO:0000076',
 'GO:0000077',
 'GO:0000078',
 'GO:0000079',
 'GO:0000080',
 'GO:0000082',
 'GO:0000083',
 'GO:0000084',
 'GO:0000085',
 'GO:00000

In [6]:
import pandas as pd
import torch
import time

def predict_pathway(gene_names,  gene_text):
   
    # Load your model and get embeddings
    loaded_model = torch.load("state_dict_0.pth")
    embed = getEmbeddings(gene_text, loaded_model, batch_size=2000).detach().cpu().numpy()
    
    # Create a DataFrame to map gene names to embeddings
    embeddings_df = pd.DataFrame(embed, index=gene_names)
    embeddings_df.reset_index(inplace=True)
    embeddings_df.rename(columns={'index': 'Disease name'}, inplace=True)
    
    # Now, embeddings_df contains each gene name and its corresponding embedding
    display(embeddings_df)
    return embeddings_df

    # Optional: Save the embeddings DataFrame to a CSV file
    

embeddings_df = predict_pathway(go_term_name, go_term_list)
embeddings_df.to_csv("go_terms_embeddings.csv", index=False)
display(embeddings_df)


Loading a pretrained model from a state dictionary ...
Tokenization ...
Tokenization Done.
Get Embeddings ...


100%|██████████| 24/24 [05:09<00:00, 12.89s/it]


torch.Size([47595, 768])


Unnamed: 0,Disease name,0,1,2,3,4,5,6,7,8,...,758,759,760,761,762,763,764,765,766,767
0,GO:0000001,-1.168093,-0.355214,0.265877,-0.710051,0.515028,-0.525165,-0.186588,-0.161192,0.186984,...,-1.350874,-0.991801,-0.648123,-0.361629,-0.914965,-0.506993,0.389760,0.207266,0.070705,0.938594
1,GO:0000002,-1.185879,-0.098765,0.388240,-0.295556,0.327296,-0.119842,0.399882,-0.035890,0.853417,...,-1.086927,-0.842870,-0.385764,0.175797,-1.223772,-0.999628,0.101473,-0.051212,0.048775,0.780469
2,GO:0000003,0.063323,-0.199995,0.151511,-0.942141,0.109313,0.015316,0.633298,0.507875,0.665548,...,0.174185,0.351648,0.138497,0.119273,-0.295167,-0.331179,0.102570,-0.524301,-0.139264,0.761573
3,GO:0000005,0.163135,0.301527,0.219680,0.094342,-0.129769,0.225696,0.357577,0.819992,0.852388,...,-0.084025,-0.291103,-0.003621,0.245929,-0.443244,0.229245,-0.685159,-0.725621,0.285964,0.313210
4,GO:0000006,-0.641113,-0.541363,0.413941,0.699345,0.461507,-0.497388,-0.044589,-0.655766,-0.596647,...,-0.561434,0.246475,-0.029871,-0.212828,-0.985273,0.677472,0.582681,0.299317,-0.131577,0.739702
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47590,GO:2001313,0.174428,0.194728,-0.284376,0.282102,-0.713190,-0.272055,0.121190,0.129901,-0.983496,...,0.500545,0.429651,-0.292929,-0.464941,-0.740187,0.179149,-0.960807,-0.746958,1.069112,-0.848182
47591,GO:2001314,0.025886,0.306214,-0.254303,0.253673,-0.533680,-0.269355,0.150939,-0.229323,-1.078991,...,0.042979,0.134560,-0.356661,-0.381828,-0.638337,0.077176,-0.788312,-0.683442,1.087031,-0.593092
47592,GO:2001315,0.027134,0.241391,-0.227353,0.317366,-0.726656,-0.197968,0.045653,0.038912,-0.954113,...,0.349853,0.370059,-0.144607,-0.493184,-0.655063,0.217335,-0.841272,-0.821077,1.036363,-0.836614
47593,GO:2001316,0.139543,0.028883,0.899480,0.152932,0.576852,0.330342,0.916943,0.012306,-0.020316,...,-0.354748,-0.083168,0.043640,-0.663565,0.543016,-0.652230,-1.427881,-0.985257,1.673561,0.109659


Unnamed: 0,Disease name,0,1,2,3,4,5,6,7,8,...,758,759,760,761,762,763,764,765,766,767
0,GO:0000001,-1.168093,-0.355214,0.265877,-0.710051,0.515028,-0.525165,-0.186588,-0.161192,0.186984,...,-1.350874,-0.991801,-0.648123,-0.361629,-0.914965,-0.506993,0.389760,0.207266,0.070705,0.938594
1,GO:0000002,-1.185879,-0.098765,0.388240,-0.295556,0.327296,-0.119842,0.399882,-0.035890,0.853417,...,-1.086927,-0.842870,-0.385764,0.175797,-1.223772,-0.999628,0.101473,-0.051212,0.048775,0.780469
2,GO:0000003,0.063323,-0.199995,0.151511,-0.942141,0.109313,0.015316,0.633298,0.507875,0.665548,...,0.174185,0.351648,0.138497,0.119273,-0.295167,-0.331179,0.102570,-0.524301,-0.139264,0.761573
3,GO:0000005,0.163135,0.301527,0.219680,0.094342,-0.129769,0.225696,0.357577,0.819992,0.852388,...,-0.084025,-0.291103,-0.003621,0.245929,-0.443244,0.229245,-0.685159,-0.725621,0.285964,0.313210
4,GO:0000006,-0.641113,-0.541363,0.413941,0.699345,0.461507,-0.497388,-0.044589,-0.655766,-0.596647,...,-0.561434,0.246475,-0.029871,-0.212828,-0.985273,0.677472,0.582681,0.299317,-0.131577,0.739702
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47590,GO:2001313,0.174428,0.194728,-0.284376,0.282102,-0.713190,-0.272055,0.121190,0.129901,-0.983496,...,0.500545,0.429651,-0.292929,-0.464941,-0.740187,0.179149,-0.960807,-0.746958,1.069112,-0.848182
47591,GO:2001314,0.025886,0.306214,-0.254303,0.253673,-0.533680,-0.269355,0.150939,-0.229323,-1.078991,...,0.042979,0.134560,-0.356661,-0.381828,-0.638337,0.077176,-0.788312,-0.683442,1.087031,-0.593092
47592,GO:2001315,0.027134,0.241391,-0.227353,0.317366,-0.726656,-0.197968,0.045653,0.038912,-0.954113,...,0.349853,0.370059,-0.144607,-0.493184,-0.655063,0.217335,-0.841272,-0.821077,1.036363,-0.836614
47593,GO:2001316,0.139543,0.028883,0.899480,0.152932,0.576852,0.330342,0.916943,0.012306,-0.020316,...,-0.354748,-0.083168,0.043640,-0.663565,0.543016,-0.652230,-1.427881,-0.985257,1.673561,0.109659
