In [4]:
import sys
# Put the local 'src' directory first on the module search path.
# NOTE(review): presumably this is what makes the `tot` package importable
# (data paths in this notebook point at src/tot/...) — confirm.
sys.path.insert(0, 'src') 
import os

In [5]:
%reload_ext autoreload
%autoreload 2

In [6]:
import argparse
from tot.methods.bfs import solve
from tot.tasks.bio_name import Bio_Name

In [7]:
# Run configuration handed to the solver.
# Alternative backend: use backend='gpt-4-1106-preview' with every other setting unchanged.
args = argparse.Namespace(
    backend='gpt-3.5-turbo-1106',
    temperature=0.7,
    task='bio_name',
    naive_run=False,
    prompt_sample=None,
    method_generate='sample_bionames',
    method_evaluate='votes_for_bionames',
    method_select='greedy',
    n_generate_sample=3,
    n_evaluate_sample=2,
    n_select_sample=2,
)
task = Bio_Name()

In [8]:
# import pandas as pd
# filename = 'src/tot/data/gene_sets/gene_sets.csv'
# df = pd.read_csv(filename, header=None, encoding='latin1')
# df.dropna(inplace=True)
# df.columns = ['_', '_', 'genes', 'count', 'process']

In [9]:
# x = df['genes'].tolist()
# y = df['process'].tolist()
# with open('src/tot/data/gene_sets/x.txt', 'w') as f:
#     for el in x:
#         f.write(el + '\n')
        
# with open('src/tot/data/gene_sets/y.txt', 'w') as f:
#     for el in y:
#         f.write(el + '\n')

In [10]:
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import torch

# Load the pretrained tokenizer and model for the
# 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext' checkpoint; both are used
# later in this notebook to embed answers for similarity scoring.
SapBERT_tokenizer = AutoTokenizer.from_pretrained('cambridgeltl/SapBERT-from-PubMedBERT-fulltext')
SapBERT_model = AutoModel.from_pretrained('cambridgeltl/SapBERT-from-PubMedBERT-fulltext')

In [11]:
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings along dimension 1, weighted by the attention mask.

    Args:
        model_output: Indexable model result whose first element holds the
            token embeddings.
        attention_mask: Mask tensor; it is broadcast to the shape of the token
            embeddings, so positions with a 0 contribute nothing to the mean.
            (Assumes the usual (batch, seq) / (batch, seq, hidden) layout — TODO confirm.)

    Returns:
        Tensor of mask-weighted mean embeddings, with dimension 1 reduced away.
    """
    embeddings = model_output[0]
    # Stretch the mask across the trailing embedding dimension so it can weight every component.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    weighted_sum = torch.sum(embeddings * mask, 1)
    # Clamp the denominator so a fully masked row divides by 1e-9 instead of 0.
    mask_total = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / mask_total

In [12]:
def getSentenceEmbedding(sentence, tokenizer, model):
    """Embed `sentence` by tokenizing it, running the model, and mean-pooling the tokens.

    Args:
        sentence: Text passed straight to `tokenizer`.
        tokenizer: Callable tokenizer; invoked with padding, truncation and
            PyTorch tensor output enabled.
        model: Model called with the tokenizer's output as keyword arguments.

    Returns:
        The result of `mean_pooling` over the model output, using the
        tokenizer's attention mask.
    """
    inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')

    # Forward pass only; gradient tracking is switched off for the model call.
    with torch.no_grad():
        outputs = model(**inputs)

    return mean_pooling(outputs, inputs['attention_mask'])

In [13]:
def getSentenceSimilarity(sentence1, sentence2, tokenizer, model, simMetric):
    """Embed two sentences and score how similar their embeddings are.

    Args:
        sentence1: First text to embed.
        sentence2: Second text to embed.
        tokenizer: Tokenizer forwarded to `getSentenceEmbedding`.
        model: Model forwarded to `getSentenceEmbedding`.
        simMetric: Name of the similarity metric. Only "cosine_similarity"
            is currently supported.

    Returns:
        A tuple `(similarity, sentence1_embedding, sentence2_embedding)`.

    Raises:
        ValueError: If `simMetric` is not a supported metric name.
    """
    # Validate before embedding: an unknown metric previously fell through the
    # `if` and crashed with UnboundLocalError at the return, after two model
    # forward passes had already been paid for.
    # TODO: add other simMetrics (e.g. a primitive dot-product / norm based cosine).
    if simMetric != "cosine_similarity":
        raise ValueError(
            f"Unsupported simMetric {simMetric!r}; expected 'cosine_similarity'"
        )

    sentence1_embedding = getSentenceEmbedding(sentence1, tokenizer, model)
    sentence2_embedding = getSentenceEmbedding(sentence2, tokenizer, model)

    # cosine_similarity returns a pairwise matrix; [0][0] picks the single pair's score.
    sentenceSim = cosine_similarity(sentence1_embedding, sentence2_embedding)[0][0]

    return sentenceSim, sentence1_embedding, sentence2_embedding

In [14]:
def similarity_score(x, y):
    """Return the cosine similarity between the SapBERT embeddings of `x` and `y`.

    Delegates to `getSentenceSimilarity` with the notebook-level SapBERT
    tokenizer and model, and discards the two embeddings it also returns.
    """
    score, _, _ = getSentenceSimilarity(
        x, y, SapBERT_tokenizer, SapBERT_model, "cosine_similarity"
    )
    return score

In [15]:
def test_example(args, task, idx):
    """Solve one task example and print the prediction next to its label.

    Args:
        args: Run configuration forwarded to `solve`.
        task: Task object; must provide `get_label(idx)` and is forwarded to `solve`.
        idx: Index of the example to run.

    Prints the label, the solver's final answer, the label again, and the
    similarity score between the two. Returns nothing.
    """
    y = task.get_label(idx)
    # Show the reference label up front, before the solver's step-by-step output.
    # (This line used to be mislabelled 'Final answer:' although it prints the label.)
    print('True answer:', y)
    final_answer, _, _ = solve(args, task, idx)
    print('Final answer:', final_answer)
    print('True answer:', y)
    print('Similarity score:', similarity_score(final_answer, y))


In [16]:
import os
# Set TOKENIZERS_PARALLELISM to "false" for this process.
# NOTE(review): presumably this disables parallelism in the Hugging Face
# tokenizers library to avoid its fork warning — confirm. It is set after the
# tokenizer has already been loaded above; verify that ordering is intended.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [36]:
# Run the solver on the example at index 2 and print its answer, the label, and their similarity.
test_example(args, task, 2)

Final answer: circadian rhythm

-- step 0 --
 -- select ids --
[0, 1]
-- paths --
Gene Expression
Cell Signaling
Metabolic Pathways
-- new_ys --: ['Gene Expression', 'Cell Signaling', 'Metabolic Pathways']
-- sol values --: (1, 1, 0)
-- choices --: ['Gene Expression', 'Cell Signaling']

-- step 1 --
 -- select ids --
[0, 3]
 -- select relations --
[['', 'is a'], ['', 'is a']]
 -- omit relations --
[['', 'part of'], ['', 'part of'], ['', 'regulates'], ['', 'part of']]
-- paths --
Gene Expression -> Transcription
Gene Expression -> Translation
Gene Expression -> RNA Splicing
Cell Signaling -> G Protein-Coupled Receptor Signaling
Cell Signaling -> Circadian Rhythm Signaling
Cell Signaling -> Insulin Signaling Pathway
-- new_ys --: ['Transcription', 'G Protein-Coupled Receptor Signaling', 'Translation', 'RNA Splicing', 'Circadian Rhythm Signaling', 'Insulin Signaling Pathway']
-- sol values --: (1, 1, 0, 0, 0, 0)
-- choices --: ['Transcription', 'G Protein-Coupled Receptor Signaling']

-- 

In [None]:
# final_answer, ys, infos = solve(args, task, 0)
# print(ys[0])

In [None]:
# ys, infos = solve(args, task, 0)
# print(ys[0])