In [1]:
import pandas as pd
import numpy as np

import os

import nltk
from nltk.tokenize import TweetTokenizer
from nltk.corpus import stopwords 

from transformers import BertForSequenceClassification, BertTokenizer, BertForMaskedLM

from simpletransformers.language_modeling import LanguageModelingModel

from sklearn.metrics.pairwise import cosine_similarity, paired_euclidean_distances
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import normalize, StandardScaler, MinMaxScaler

from tqdm import tqdm
import torch

import networkx as nx

import matplotlib.pyplot as plt
%matplotlib inline

import plotly.graph_objects as go
from functools import partial

import pickle

from collections import deque

stop_words = set(stopwords.words('english')) 

import time

%load_ext autoreload

%autoreload 2

from utils import *
from plotting import *

import marshal

In [2]:
from utils import *
import time

In [3]:
model = BertForSequenceClassification.from_pretrained('/data1/roshansk/Exp1/checkpoint-141753-epoch-1', output_hidden_states= True)

tokenizer = BertTokenizer.from_pretrained('/data1/roshansk/Exp1/checkpoint-141753-epoch-1')

In [14]:
class ADRModel(object):
    
    def __init__(self, df, model, tokenizer, graph, outputFolder, combinedOutputFolder, modelOutputFolder = './', queue=None, useMasterEmb = False, 
                 masterContrib = 0.5, embeddingType='last4sum',
                 numThreshold= 10000, saveEveryDepth = False,
                numComp = 10000):
        
        self.df = df
        self.model = model
        self.tokenizer = tokenizer
        self.graph = graph
        self.outputFolder = outputFolder
        self.combinedOutputFolder = combinedOutputFolder
        self.embeddingType = embeddingType
        self.numThreshold = numThreshold
        self.saveEveryDepth = saveEveryDepth
        self.modelOutputFolder = modelOutputFolder
        self.numComp = numComp
        
        if queue is None:
            self.q = deque()
        else:
            self.q = queue
            
            
        self.masterEmb = None
        
        self.useMasterEmb = useMasterEmb
        self.masterContrib = masterContrib
        
        self.generateStates()
        
        
    def generateStates(self):
        
        
        for i in tqdm(range(len(self.df))):
            
            if os.path.exists(os.path.join(self.outputFolder, f"{i}.msh")):
                continue


            tokens = self.tokenizer.encode(self.df.iloc[i]['message'].lower())
            decoded = self.tokenizer.decode(tokens).split(" ")
            logits, hidden_states = self.model(torch.Tensor(tokens).unsqueeze(0).long())

            hidden_states = torch.stack(hidden_states).squeeze(1).permute(1,0,2)

            
            if self.embeddingType == 'last4sum':
                embedding = torch.sum(hidden_states[:,9:13,:],1)
            elif self.embeddingType =='last4concat':
                embedding = hidden_states[tokenIndex,9:13,:].reshape(-1)
            elif self.embeddingType == 'secondlast':
                embedding = hidden_states[tokenIndex,-2,:]
            else:
                embedding = hidden_states[tokenIndex,-1,:]
                    
                    
            embedding = embedding.detach().cpu().numpy()
            
            marshal.dump(embedding.tolist(), open(os.path.join(self.outputFolder, f"{i}.msh"), 'wb'))
        
        
        
        
    def getSymptomEmbedding(self, symptom, subset = None):
    
        embeddingList = []
        messageList = []

#         if subset is not None:
#             self.df = self.df.iloc[subset]

#         if type(df) == pd.Series:
#             self.df = pd.DataFrame(self.df).T

#         symptomToken = self.tokenizer.encode(symptom)[1]
        symptomToken = self.tokenizer.convert_tokens_to_ids(symptom)

        for i in range(len(self.df)):

            if symptomToken in self.tokenizer.encode(self.df.iloc[i]['message'].lower()):

                tokens = self.tokenizer.encode(self.df.iloc[i]['message'].lower())
                decoded = self.tokenizer.decode(tokens).split(" ")

                hidden_states = np.array(marshal.load( open(os.path.join(self.outputFolder, f"{i}.msh"), 'rb') ))

                try:
                    tokenIndex = tokens.index(symptomToken)
                except:
                    a= 1
                    continue

 
                embedding = hidden_states[tokenIndex,:]

                embeddingList.append(embedding)
                messageList.append(self.df.iloc[i]['message'].lower())

                if len(embeddingList)==30:
                    break



        return embeddingList, messageList
    
    
    
    def getSimilarWords(self, symptom, meanEmb, similarityThreshold = 0.3):
        
        output = []
        
        symptomToken = self.tokenizer.encode(symptom)[1]
        
        fileList = os.listdir(self.combinedOutputFolder)
        
        examineCount = 0
        
        for i in range(len(fileList)):
            
            if examineCount >= self.numThreshold:
                break
            
            
            filename = os.path.join(self.combinedOutputFolder, f"{i}.pkl")
            subDict = pickle.load(open(filename,'rb'))
            
            IDList = subDict['id']
            tokenList = subDict['token']
            embList = subDict['emb']

            sim = np.round(cosine_similarity(embList, meanEmb.reshape(1,-1)).reshape(-1),4)

            index= np.where([sim> similarityThreshold])[1]

            tokenList_ = tokenList[index]
            IDList_ = IDList[index]
            simList = sim[index]

            out = [(x,y,z) for x,y,z in zip(tokenList_, simList, IDList_)]
            
            output += out
            
            examineCount += self.numComp
            
        return output
        
    
    
    def getOutput(self, out):
    
        output = out

        outMap = {}

        for i in range(len(output)):
            if output[i][0] in outMap:
                outMap[output[i][0]].append(output[i][1])
            else:
                outMap[output[i][0]] = [output[i][1]]


        outMap_ = {}

        for i in range(len(output)):
            if output[i][0] in outMap_:
                outMap_[output[i][0]].append(output[i][2])
            else:
                outMap_[output[i][0]] = [output[i][2]]


        outputDf = []

        for key in outMap.keys():
            length = len(outMap[key])
            mean = np.mean(outMap[key])

            outputDf.append([key, length, mean])

        outputDf = pd.DataFrame(outputDf)
        outputDf.columns = ['word','counts','mean_sim']
        outputDf = outputDf.sort_values('mean_sim', ascending=False)

        return outputDf, outMap, outMap_
    
    
    
    
    def exploreNode(self, word, depth, maxDepth = 3, topk = 5):

    
        self.graph.addNode(word,0,depth)

        print(f"Depth : {depth} Exploring {word}")

        if depth == maxDepth:
            print("Reached max depth")
            return

        keyWord = word

        token = self.tokenizer.encode(keyWord)[1]

        if self.graph[word].vector is None:

            inEdgeList = self.graph[word].edges_in

            if len(inEdgeList)==0:
                textIDList = None
            else:
                textIDList = []

                for edge in inEdgeList:
                    textIDList.append(self.graph.edgeList[edge].textID)

                textIDList = list(set(list(itertools.chain.from_iterable(textIDList))))

            
            embList,msgList = self.getSymptomEmbedding(keyWord, subset = textIDList)

            meanEmb = np.array(embList)
            meanEmb = np.mean(meanEmb,0)


            self.graph[word].vector = meanEmb
            
            if self.masterEmb is None:
                self.masterEmb = meanEmb
            
            dist = getCosineDist(meanEmb, self.masterEmb)
            
            self.graph[word].masterDist = dist

        else:
            meanEmb = self.graph[word].vector
            
            if self.masterEmb is None:
                self.masterEmb = meanEmb
                
            dist = getCosineDist(meanEmb, self.masterEmb)
            
            self.graph[word].masterDist = dist


        symptom_ =''
        embList_ = meanEmb

        if self.useMasterEmb:
            
            finalEmb = self.masterContrib*self.masterEmb + (1 - self.masterContrib)*meanEmb
            
            out = self.getSimilarWords( symptom_, finalEmb , similarityThreshold = 0.3)
        else:
            out = self.getSimilarWords( symptom_, meanEmb, similarityThreshold = 0.3)

        outputDf, outMap, outMap_ = self.getOutput(out)

        outputDf = outputDf[outputDf.word!=keyWord]
    #     outputDf = outputDf[~outputDf.word.isin(list(graph.wordMap.keys()))]
        outputDf = outputDf.sort_values('mean_sim', ascending=False)
        outputDf = outputDf.head(topk)

        outputDf = outputDf[outputDf.mean_sim>0.4]

        print(outputDf)
        print("-----------------------")

        for i in range(len(outputDf)):

            word = outputDf.iloc[i]['word']
            numCount = outputDf.iloc[i]['counts']
            weight = outputDf.iloc[i]['mean_sim']
            textIDs = outMap_[word]

            wordList = set(self.graph.wordMap.keys())

            self.graph.addNode(word,0,depth+1)
            self.graph[word].textIDList.append(textIDs)
            self.graph.addEdge(keyWord, word, numCount, weight, textIDs)

            if word in wordList:
                continue

#             if "#" in word:
#                 continue


            self.q.append((word, depth+1))
            
            
    def trainModel(self, maxDepth = 3, topk = 5):
        
        currDepth = 0
        
        while len(self.q)>0:
            token, depth = self.q.popleft()
            
            if depth> currDepth:
                
                if self.saveEveryDepth:
                    filepath = os.path.join( self.modelOutputFolder, f"depth_{currDepth}.pkl")
                    self.saveModel(filepath)
                
                self.getMeanEmbedding(depth-1)
                currDepth += 1
            
            self.exploreNode(word = token, depth = depth, maxDepth=maxDepth, topk=topk)
        
        #Saving final model
        filepath = os.path.join(self.modelOutputFolder, "final.pkl")
        self.saveModel(filepath)


            
    def getMeanEmbedding(self, depth, topk = 3):
        
        candidates = self.graph.depthMap[depth]
        
        vals = [self.graph[x].masterDist for x in candidates]
        
        vals = [(x,y) for x,y in zip(candidates,vals)]
        
        vals = sorted(vals, key = lambda x : -x[1])
        
        meanEmb = self.masterEmb
        
        for i in range(min(topk, len(vals)) ):
            meanEmb += self.graph[ vals[i][0] ].vector
            
        meanEmb = meanEmb/(topk+1)
        
        self.masterEmb = meanEmb
        
        print("Master Embedding updated")
        
        
    
    def plotGraph(self):
        
        edgeList, nodeList, nodeValues, nodeCount, nodeText, nodeSize = getGraphComponents(self.graph)

        G=nx.Graph()

        G.add_nodes_from(nodeList)
        G.add_edges_from(edgeList)

        edge_trace, node_trace1, node_trace = getPlotlyComponents(G, nodeList, nodeSize, nodeValues, nodeText)


        fig = go.Figure(data=[edge_trace, node_trace1, node_trace],
             layout=go.Layout(
                title='<br>Network graph made with Python',
                titlefont_size=16,
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20,l=5,r=5,t=50),

                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                )
        
        fig.show()
        
        
    def saveModel(self,filename):
        
        classDict = self.__dict__.copy()
        classDict.pop('model')
        classDict.pop('tokenizer')
        classDict.pop('df')
        
        pickle.dump( classDict, open( filename, "wb" ) )
        
        
    def loadModel(self, filename):
        
        classDict = pickle.load(open(filename, 'rb'))
        
        for key in list(classDict.keys()):
            self.__dict__[key] = classDict[key]
        
        


### Model Training

### Covid Testing

In [5]:
covidData = '/data1/roshansk/covid_data/'
os.listdir(covidData)

['messages_cm_mar1_apr23_noRT.csv',
 'messages_cm_mar1_apr23.csv',
 'msgs_orig_latest.csv',
 'CovidTrainDf_1.csv',
 'covid19_msgs_2020_02.csv',
 'messages_cm_mar1_apr23_noRT_PET.csv']

In [6]:
df = pd.read_csv(os.path.join(covidData, 'messages_cm_mar1_apr23_noRT.csv'), nrows = 10000)

df = df[['message_id','user_id','message']]

In [11]:
df.head()

Unnamed: 0,message_id,user_id,message
0,1238220897720336390,790013818999074820,loving how i predicted the #coronavirus in my ...
1,1238220898307432448,320608440,The markets is hemorrhaging #trump and his pal...
2,1238220912530522123,305058336,Amar es prevenirn#coronavirus
3,1238220933384593413,137437056,"Chris's glib, histrionic commentary on COVID-1..."
4,1238220933766230017,1012237740,If only COVID-19 started in Madagascar


In [12]:
embList, msgList = getSymptomEmbedding(model,tokenizer, df, 'cough',0)

In [13]:
for i in range(len(msgList)):
    print(f"{i} | {msgList[i]}")
    print("------------")

0 | per the state health department: nnsymptoms of covid-19 can include fever, cough and breathing trouble. most develoâ€¦ https://t.co/u1cnlahiw1
------------
1 | this coronavirus shit is so crazy that i'm waiting for an announcement that if you cough or sneeze in public, you wâ€¦ https://t.co/dt4l43mows
------------
2 | a real story, from the middle of iowa:  person has fever, cough, aches. person calls dr. they facetime. nurse saysâ€¦ https://t.co/sv6pptihks
------------
3 | yâ€™all really made toilet paper a commodity when the symptoms of covid-19 are shortness of breath, fever, and cough according to the cdc ðÿ¥´
------------
4 | someone go grab mitch mcconnell cough in his face and kick his ass into his senate seat.
------------
5 | me when anybody sneezes or coughs around me #covid_19 https://t.co/0jdjvgaprn
------------
6 | this has truly been the worst time to be asian and have a cough ðÿ™ƒnn#coronavirus #covid19
------------
7 | getting on the @cta and scanning for the seat n

In [14]:
indexList = [0,2,3,12,13,25]

In [15]:
meanEmb = np.array(embList)[indexList,:]
meanEmb = np.mean(meanEmb,0)

(768,)

#### Model Training

In [15]:
# from ADRModel import *

graph = Graph()

# graph.addNode('cough',0,0)
# graph['cough'].vector = meanEmb

outputFolder = '/data1/roshansk/ADRModel_DataStore/'
combinedOutputFolder = '/data1/roshansk/ADRModel_DataStore_10000/'
modelFolder = './ModelFolder/Covid_Test/'

q = deque()
q.append(('cough',0))

ADR = ADRModel(df, model, tokenizer, graph, outputFolder, combinedOutputFolder, modelOutputFolder = modelFolder, 
               queue = q,  useMasterEmb=True, masterContrib=0.3, numThreshold=10000)

100%|██████████| 10000/10000 [00:00<00:00, 147483.71it/s]


In [16]:
startTime = time.time()

ADR.trainModel(maxDepth=2, topk=3)

print(time.time() - startTime)

Depth : 0 Exploring cough
          word  counts  mean_sim
1235   coughed       2  0.653850
215      fever      13  0.603154
952   coughing       6  0.599517
-----------------------
Master Embedding updated
Depth : 1 Exploring coughed
          word  counts  mean_sim
339      cough      19  0.646658
1538  coughing       6  0.642000
3268    shouts       1  0.584400
-----------------------
Depth : 1 Exploring fever
             word  counts  mean_sim
114         cough      19  0.649416
1050       asthma       2  0.578950
83    respiratory       7  0.538300
-----------------------
Depth : 1 Exploring coughing
         word  counts  mean_sim
4405  coughed       2  0.715050
931     cough      19  0.662916
6896   shouts       1  0.598000
-----------------------
Master Embedding updated
Depth : 2 Exploring shouts
Reached max depth
Depth : 2 Exploring asthma
Reached max depth
Depth : 2 Exploring respiratory
Reached max depth
107.44656777381897


In [18]:
from ADRModel import *

graph = Graph()

# graph.addNode('cough',0,0)
# graph['cough'].vector = meanEmb

outputFolder = '/data1/roshansk/ADRModel_DataStore/'
combinedOutputFolder = '/data1/roshansk/ADRModel_DataStore_10000/'
modelFolder = './ModelFolder/Covid_Test/'

q = deque()
q.append(('cough',0))

ADR = ADRModel(df, model, tokenizer, graph, outputFolder, modelOutputFolder = modelFolder, 
               queue = q,  useMasterEmb=True, masterContrib=0.3, numThreshold=10000)

100%|██████████| 10000/10000 [00:00<00:00, 157237.26it/s]


In [19]:
startTime = time.time()

ADR.trainModel(maxDepth=2, topk=3)

print(time.time() - startTime)

Depth : 0 Exploring cough
          word  counts  mean_sim
1236   coughed       2  0.653836
215      fever      13  0.603151
953   coughing       6  0.599497
-----------------------
Master Embedding updated
Depth : 1 Exploring coughed
          word  counts  mean_sim
340      cough      19  0.646646
1541  coughing       6  0.642004
3270    shouts       1  0.584375
-----------------------
Depth : 1 Exploring fever
             word  counts  mean_sim
114         cough      19  0.649410
1051       asthma       2  0.578951
83    respiratory       7  0.538298
-----------------------
Depth : 1 Exploring coughing
         word  counts  mean_sim
4406  coughed       2  0.715081
931     cough      19  0.662899
6896   shouts       1  0.598000
-----------------------
Master Embedding updated
Depth : 2 Exploring shouts
Reached max depth
Depth : 2 Exploring asthma
Reached max depth
Depth : 2 Exploring respiratory
Reached max depth
717.622035741806


In [11]:
176/46

3.8260869565217392

In [20]:
717/107

6.700934579439252

### GetSimilar Testing  

In [9]:
outputFolder = '/data1/roshansk/ADRModel_DataStore/'


def getSimilarWords(model, tokenizer, df, symptom, embList, similarityThreshold = 0.3):
    
     
    output = []


    symptomToken = tokenizer.encode(symptom)[1]
#         symptomToken = self.tokenizer.convert_tokens_to_ids(symptom)

    for i in tqdm(range(len(df))):

        tokens = tokenizer.encode(df.iloc[i]['message'].lower())

        if symptomToken in tokens:


            hidden_states = np.array(marshal.load( open(os.path.join(outputFolder, f"{i}.msh"), 'rb') ))

            similarity = cosine_similarity(hidden_states, embList.reshape(1,-1)).reshape(-1)


            index = np.where([similarity> similarityThreshold])[1]

            try:
                selectTokens = np.array(tokens)[index]
            except:
                print(i)
                print(index)
                print(hidden_states.shape)
                print(len(tokens))
                print(len(self.tokenizer.encode(self.df.iloc[i]['message'].lower())))
                print(self.df.iloc[i]['message'])
                break

            similarityValues = similarity[index]
            
            wordValues = np.array(tokenizer.convert_ids_to_tokens(tokens))[index]
            
            rownumValues = [i]*len(index)
            
            output = [(x,y,z) for x,y,z in zip(wordValues,similarityValues,rownumValues)]


#             for j in range(len(index)):
#                 token = tokenizer.ids_to_tokens[selectTokens[j]]
#                 sim = selectSim[j]
#                 output.append((token, sim,i))


        if i==100000:
            break

    return output

In [89]:
import time

startTime = time.time()

out = getSimilarWords(model, tokenizer,df.iloc[0:1000],'',ADR.masterEmb)


print(time.time() - startTime)

100%|██████████| 1000/1000 [00:17<00:00, 55.77it/s]

17.935619354248047





In [103]:
import time

startTime = time.time()

out = getSimilarWords(model, tokenizer,df.iloc[0:1000],'',ADR.masterEmb)


print(time.time() - startTime)

100%|██████████| 1000/1000 [00:18<00:00, 55.17it/s]

18.130207777023315





In [94]:
import time

startTime = time.time()

out = getSimilarWords(model, tokenizer,df.iloc[0:10],'',ADR.masterEmb)


print(time.time() - startTime)

100%|██████████| 10/10 [00:00<00:00, 43.38it/s]

0.23391938209533691





In [100]:
a = np.array([1,2,34])
b = np.array([4,5,6])
c = np.array([5,5,5])

In [105]:
%load_ext line_profiler

In [106]:
def test(x):

    out = 0
    
    for i in range(x):
        a = np.random.random()
        out+=a
        
    return out

In [108]:
%lprun -f getSimilarWords getSimilarWords(model, tokenizer,df.iloc[0:10],'',ADR.masterEmb)

100%|██████████| 10/10 [00:00<00:00, 38.23it/s]


In [10]:
outList = []
for i in range(10):
    hidden_states = np.array(marshal.load( open(os.path.join(outputFolder, f"{i}.msh"), 'rb') ))
    outList.append(hidden_states)
    

In [12]:
outList[0].shape

(29, 768)

In [13]:
outList[1].shape

(47, 768)

### BMIN Dataset

In [26]:
os.listdir('/data1/roshansk')
df = pd.read_csv('/data1/roshansk/Statin_Data.csv')

In [27]:
df = df[['tweet_id','text','Category','Subcategory']]
df.rename({'text':'message'},axis = 1, inplace = True)
df.head()

Unnamed: 0,tweet_id,message,Category,Subcategory
0,9.808937e+17,"Q4: I recently had a ""mild stroke"" and was pre...",u,q
1,8.940068e+17,@WSJ Interesting my muscle specimens got wors...,u,d
2,1.02701e+18,"I have normal LDL, my neurologist tried to pot...",u,f
3,1.034315e+18,"@eitch_kay @SBakerMD No, when I was diagnosed ...",u,n
4,9.719656e+17,@SassyPharmD @DrBabyFace7 I just had a pt go f...,h,d


In [28]:
sub = df[df.Subcategory=='a']

for i in range(30):
    print(sub.iloc[i]['message'])
    print("-------------")

@kendra_bond got free samples of crestor last year from pvt doc i paid for, va sent 4 different meds since, cause muscle pain..crestor=$212
-------------
@Questar1959Ron Crestor was the only statin I could take that didn't cause me to have muscle atrophy. My insurance denied me at first, then my doctor insisted, and the co-pay is 125 a month for 30 pills. It'd be 750 without insurance.
-------------
@medscape the memory loss seen is proportional to the dosage as seen in my patients, less with atorvastatin with ezetemibe @improveit
-------------
@AlexBThomson @eoinmccarthy Stopped one this week, 96yo on 10m atorvastatin, repeat falls from aching legs 🙄🙄
-------------
@evanackermann crestor 40 as good chol lowering of any psk9, cheaper and better cv data.  had patient get rhabdo last week.  good for him
-------------
@drkristieleong @qunol_coq10 my husband lost memory for 14 months as a result of atorvastatin 40mg for 4 weeks
-------------
zOmg, very first time witnessed uncle cramping h

In [29]:
len(sub)

331

In [42]:
df1 = sub.copy()


import random

df2 = df[df.Subcategory!='a']
random.seed(123)
df2 = df2.sample(1500, replace = False,random_state = 123)

In [43]:
finalDf = pd.concat([df1,df2],axis = 0)

#### Extracting Seed

In [44]:
embList, msgList = getSymptomEmbedding(model,tokenizer, finalDf, 'pain',0)

In [45]:
for i in range(len(msgList)):
    print(f"{i} | {msgList[i]}")
    print("------------")

0 | @kendra_bond got free samples of crestor last year from pvt doc i paid for, va sent 4 different meds since, cause muscle pain..crestor=$212
------------
1 | @georgiaedemd @amyv_ntp @tuckergoodrich @ldlskeptic @vernersviews @lanacares @dramerling @draseemmalhotra @proftimnoakes @fructoseno @jedipd @dietdoctor1 @tednaiman @drjamesdinic @aapsonline @robertlustigmd @annchildersmd @loukasmarios @clunesm suspect simvastatin caused heart attack that led to cabg - because my 2yrs+ of stable angina turned into coronary artery spasms within weeks of #statins  atorvastatin 80mg finished me off after cabg, dementia, peripheral neuropathy, type 2 diabetes, chronic thigh muscle pain...
------------
2 | the next week, i meet w/doc to talk side effects. he gives me crestor 10 mg. pain subsides for ~2 days, then comes back. 6/n #landryslife
------------
3 | @theheartorg @drmarthagulati the only molecule i tolerate is crestor and only 5 mg , otherwise  i become like 20 older by muscle' s pain
------

In [38]:
indexList = [0,1,4,9,11,12,13,15]
meanEmb = np.array(embList)[indexList,:]
meanEmb = np.mean(meanEmb,0)

In [46]:
len(finalDf)

1831

#### Model Training

In [39]:
!mkdir /data1/roshansk/BMIN_DataStore

In [67]:
from ADRModel import *

graph = Graph()

graph.addNode('pain',0,0)
graph['pain'].vector = meanEmb

outputFolder = '/data1/roshansk/BMIN_DataStore/'
modelFolder = './ModelFolder/BMIN_Model/'

q = deque()
q.append(('pain',0))

ADR = ADRModel(finalDf, model, tokenizer, graph, outputFolder, modelOutputFolder = modelFolder, 
               queue = q,  useMasterEmb=True, masterContrib=0.3, numThreshold=2000,saveEveryDepth=True)

100%|██████████| 1831/1831 [00:00<00:00, 121269.75it/s]


In [68]:
ADR.trainModel(maxDepth=4, topk=5)

Depth : 0 Exploring pain
           word  counts  mean_sim
133  discomfort       1  0.698737
371     hurting       2  0.643920
158      ##amps      15  0.630992
375        hurt       4  0.621371
239     lesions       1  0.612525
-----------------------
Master Embedding updated
Depth : 1 Exploring discomfort
        word  counts  mean_sim
4       pain      62  0.687144
214  lesions       1  0.651498
513     sore       1  0.634436
142   ##amps      15  0.612099
668   injury       3  0.576561
-----------------------
Depth : 1 Exploring hurting
           word  counts  mean_sim
5          pain      63  0.689491
537        hurt       4  0.676999
171  discomfort       1  0.610711
579       hurts       1  0.601236
326     lesions       1  0.570270
-----------------------
Depth : 1 Exploring ##amps
            word  counts  mean_sim
6           pain      62  0.698183
137   discomfort       1  0.690847
258      lesions       1  0.649865
477      muscles       5  0.612069
1506       ##cts       

In [69]:
ADR.plotGraph()

In [65]:
ADR.graph.describeNode('sore')

Exploring sore
discomfort -> sore       | 1 |  0.623 | [257]
--------------------


In [66]:
finalDf.iloc[257]['message']

'Dear Lipitor, thanks for trying to kill me with one dose. Good lord! Possible side effects lists were kidding. Muscle and joint soreness? Took one dose Thurs night and felt like whole body got instant rheumatoid arthritis! Just now easing off. F this stuff! #lipitor'

In [81]:
ADR.graph.describeNode('clothes')

Exploring clothes
calves     -> clothes    | 1 |   0.58 | [1024]
--------------------


In [82]:
finalDf.iloc[1024]['message']

'@Pseudologichunt @RichLucido @jaketapper That’s not true it’s good. He is on a high cholesterol medication. His clothes drawl levels without that would be dangerously high and so it is blood pressure! He’s on a Statin drug called Rosuvastatin. If he wasn’t on it his Cholesterol levels would be dangerously high! https://t.co/s9NSiKndBK'

In [107]:
ADR.graph.describeNode('swelling')

Exploring swelling
inflammation -> swelling   | 2 |   0.61 | [280, 1589]
--------------------


In [108]:
finalDf.iloc[280]['message']

'@johnraysta thank you. used to get severe allergic reaction (swelling, itchy skin) when using simvastatin before, but so far none w/ crestor'

In [114]:
def evaluateText(text, model, tokenizer, compareEmb):
    tokens = tokenizer.encode(text)
    temp = tokenizer.convert_ids_to_tokens(tokens)

    decoded = tokenizer.decode(tokens).split(" ")
    logits, hidden_states = model(torch.Tensor(tokens).unsqueeze(0).long())
    hidden_states = torch.stack(hidden_states).squeeze(1).permute(1,0,2)

    emb = torch.sum(hidden_states[:,9:13,:],1).detach().cpu().numpy()

    sim = cosine_similarity(emb, compareEmb.reshape(1,-1)).reshape(-1)

    sim = cosine_similarity(emb, compareEmb.reshape(1,-1)).reshape(-1)

    for i in range(len(temp)):
        print(f"{temp[i]:10s} : {str(np.round(sim[i],3))}")

In [115]:
evaluateText(finalDf.iloc[280]['message'], model, tokenizer, ADR.masterEmb)

[CLS]      : 0.073
@          : 0.212
john       : 0.274
##ray      : 0.261
##sta      : 0.241
thank      : 0.242
you        : 0.191
.          : 0.117
used       : 0.304
to         : 0.271
get        : 0.357
severe     : 0.387
allergic   : 0.47
reaction   : 0.469
(          : 0.336
swelling   : 0.631
,          : 0.358
it         : 0.31
##chy      : 0.482
skin       : 0.506
)          : 0.27
when       : 0.33
using      : 0.272
sim        : 0.228
##vas      : 0.229
##tat      : 0.301
##in       : 0.318
before     : 0.264
,          : 0.265
but        : 0.276
so         : 0.156
far        : 0.116
none       : 0.269
w          : 0.289
/          : 0.268
crest      : 0.269
##or       : 0.228
[SEP]      : 0.073
