# **Define Model**

### Instructions

**1.** Put the model ```.py``` file in this directory


**2.** If you want to use the code as is and only change ```input_data_path```,


Put your data in a CSV file with the following format (as seen in ```/mnt/data/GeneLLM/data/solubility.csv```):


    Gene name, StrLabel, Label, Summary

    LIME1, Membrane, 0, "This genomic region..." 
    
    TMEM219, Membrane, 0, "Ceramidases (EC 3.5.1.23) ..." 


Otherwise, load your data yourself into a list of strings ```sentences``` and a list of integers ```labels```, and set ```n_labels``` to the number of labels


**3.** Check model hyperparameters


**4.** Run the next block

In [2]:
from shapleyAnalysis import *
import sys
import os

## IMPORT MODEL AND TOKENIZER ##
from BERT import * 
from transformers import BertTokenizerFast
# Path of the saved weights that are loaded into the model below.
state_dict_path = 'best_model_state_dict.pth'

## DATA ##
input_data_path = "clean_genes.csv"
task_type = "classification"
gene_loaded_data = pd.read_csv(input_data_path)
n_labels = 2
sentences = gene_loaded_data["Summary"].tolist()
geneNames = gene_loaded_data["Gene name"].tolist()
# Maps each gene name to its row position, which is also its index into
# `sentences`. If a gene name appears more than once, the last row wins.
gene_to_idx = {gene: idx for idx, gene in enumerate(geneNames)}

## HYPERPARAMETERS ##
pool = "mean"
drop_rate = 0.1
model_name = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
gene2vec_flag = False
gene2vec_hidden = 200
device = "cuda"
tokenizer_max_length = 512

## INITIALIZE MODEL AND TOKENIZER ##
model = FineTunedBERT(pool=pool,
                      task_type=task_type,
                      n_labels=n_labels,
                      drop_rate=drop_rate,
                      model_name=model_name,
                      gene2vec_flag=gene2vec_flag,
                      gene2vec_hidden=gene2vec_hidden).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads when `device` is changed (e.g. to "cpu").
model.load_state_dict(torch.load(state_dict_path, map_location=device))

# The model is only used for analysis from here on, so switch to evaluation
# mode to disable dropout.
# NOTE(review): assumes FineTunedBERT is a torch.nn.Module — confirm.
model.eval()

tokenizer = BertTokenizerFast.from_pretrained(model.model_name)

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/dandreas/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


# **Shapley Analysis**

### Instructions

The analysis will save two files to the current directory.

Change the ```dataset_name``` to match your dataset. This will be the prefix of the files saved by the analysis.

In [3]:
# loading: one text file per cluster, read as one gene name per line
data_dir = 'enrichment_analysis/'
data_file_root_name = 'GeneLLM_all_cluster'

# saving
save_dir = '/home/dandreas/GeneLLM2/data/BIAS_enrichment_analysis_GeneLLM_clusters/'

# Row indices (into `sentences` / `geneNames`) of every gene seen in any cluster.
all_idxs = []
for i in range(26):
    print(i)
    file_path = data_dir + data_file_root_name + str(i) + '.txt'
    with open(file_path, 'r') as clusterfile:
        # Skip blank lines so a stray empty line does not become a lookup of ''.
        cluster = [line.strip() for line in clusterfile if line.strip()]
    idxs = [gene_to_idx[g] for g in cluster]
    all_idxs += idxs
    clusterSentences = [sentences[idx] for idx in idxs]
    dataset_name = 'cluster' + str(i)
    # token_analysis, word_analysis = getShapleyAnalysis_classification(clusterSentences, model, tokenizer, tokenizer_max_length, device, dataset_name, save_dir)

# Sanity check: number of genes collected from the cluster files, followed by
# the number of genes in the loaded dataset.
# NOTE(review): the original second print was `print(len())`, which raises
# TypeError; comparing against `geneNames` is the presumed intent — confirm.
print(len(all_idxs))
print(len(geneNames))

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
12955


In [None]:
### AGGREGATE SHAP VALUES ###
def concat(x,y):
    """Build a new shap.Explanation holding x's entries followed by y's.

    ``values`` and ``base_values`` are joined along axis 0 with
    ``np.concatenate``; ``data`` is joined with the ``+`` operator.
    Neither input is modified.
    """
    joined_values = np.concatenate([x.values, y.values], axis=0)
    joined_base_values = np.concatenate([x.base_values, y.base_values], axis=0)
    # NOTE(review): `+` appends only if `.data` is a tuple/list — confirm.
    joined_data = x.data + y.data

    # Start from an empty Explanation and fill its fields in afterwards.
    merged = shap.Explanation(values=[])
    merged.values = joined_values
    merged.base_values = joined_base_values
    merged.data = joined_data
    return merged

# Accumulate the per-cluster SHAP results into a single Explanation.
shap_values_all = None
for i in range(26):
    file_path = f"{save_dir}cluster{i}_shap_values.pkl"
    shap_values = pickleLoad(file_path)

    if shap_values_all is not None:
        # Every later cluster is appended to the running aggregate.
        shap_values_all = concat(shap_values_all, shap_values)
        continue

    # First cluster: seed the aggregate with a fresh Explanation built from
    # this cluster's fields.
    shap_values_all = shap.Explanation(
        values=shap_values.values,
        base_values=shap_values.base_values,
        data=shap_values.data,
    )

pickleSave(shap_values_all, save_dir, 'all_shap_values.pkl')

# dataset_name = 'all'
# token_analysis, token_analysis_indexes = getShapValuesDictsAndIndexes(shap_values_all)
# pickleSave(token_analysis,save_dir,dataset_name+'_token_analysis.pkl')
# pickleSave(token_analysis_indexes,save_dir,dataset_name+'_token_analysis_indexes.pkl')

# shap_values_grouped_by_word_all = getShapValuesGroupedByWord(shap_values_all)
# word_analysis, word_analysis_indexes = getShapValuesDictsAndIndexes_groupedByWordSum(shap_values_grouped_by_word_all)
# pickleSave(word_analysis,save_dir,dataset_name+'_word_analysis.pkl')
# pickleSave(word_analysis_indexes,save_dir,dataset_name+'_word_analysis_indexes.pkl')

# **Plot**

### Instructions

Pass the analysis you want to plot (i.e., either ```token_analysis``` or ```word_analysis``` above)

In [None]:
# default values, uncomment to change and pass as keyword argument to generatePlots()
# nToPlot = 100 
# percentile=90 
# minOccurances = 10 
# collapsePlural = True 
# minStringLength = 1 
# stops = nltkStopwords 
# saveName=None # change savename to a string in order to save plots

# NOTE(review): in the cells above, `word_analysis` is only assigned in lines
# that are currently commented out, so it must already exist in the session
# (or be loaded from a saved analysis) before this call — TODO confirm.
generatePlots(word_analysis)