In [13]:
import pandas as pd
import numpy as np

import os

import nltk
from nltk.tokenize import TweetTokenizer
from nltk.corpus import stopwords 

from transformers import BertForSequenceClassification, BertTokenizer, BertForMaskedLM

from simpletransformers.language_modeling import LanguageModelingModel

from sklearn.metrics.pairwise import cosine_similarity, paired_euclidean_distances
from sklearn.metrics.pairwise import euclidean_distances


from tqdm import tqdm
import torch

import os
import pickle

import time

import torch
from torch.nn import CosineSimilarity

import itertools

from functools import partial

stop_words = set(stopwords.words('english')) 

In [2]:
dataFolder = '/data1/roshansk/covid_data/'
fileList = os.listdir(dataFolder)

df = pd.read_csv(os.path.join(dataFolder, fileList[0]), nrows = 500000)

In [3]:
model = BertForSequenceClassification.from_pretrained('/data1/roshansk/Exp1/checkpoint-141753-epoch-1', output_hidden_states= True)

In [4]:
tokenizer = BertTokenizer.from_pretrained('/data1/roshansk/Exp1/checkpoint-141753-epoch-1')

In [None]:
import nu

### V1

In [6]:
def getSimilarWords(model, df, symptom, embList, similarityThreshold = 0.3, numThreshold = 10000):
    
     
    output = []
    
    for i in tqdm(range(numThreshold)):
        
        if symptom in df.iloc[i]['message'].lower():
                 
            tokens = tokenizer.encode(df.iloc[i]['message'].lower())
            decoded = tokenizer.decode(tokens).split(" ")
            logits, hidden_states = model(torch.Tensor(tokens).unsqueeze(0).long())

            hidden_states = torch.stack(hidden_states).squeeze(1).permute(1,0,2)
            
            
            hidden_states = hidden_states[:,9:13,:]
            hidden_states = torch.sum(hidden_states,1).detach().cpu().numpy()
            
            similarity = cosine_similarity(hidden_states, embList.reshape(1,-1)).reshape(-1)

                            
            index = np.where([similarity> similarityThreshold])[1]

            selectTokens = np.array(tokens)[index]
            selectSim = similarity[index]
                      


            for j in range(len(index)):
                token = tokenizer.ids_to_tokens[selectTokens[j]]
                sim = selectSim[j]
                output.append((token, sim,i))

            
        if i==numThreshold:
            break
            
    return output



In [9]:
file = 'fatigue_16342_Emb.npy'
symptom = ''

embList = np.load(os.path.join('EmbFolder/',file))
embList = np.mean(embList,0)

startTime = time.time()

out1 = getSimilarWords(model, df, symptom, embList, similarityThreshold = 0.3, numThreshold = 10000)


print(f"Time taken : {time.time() - startTime}")

  1%|          | 10000/1500000 [25:11<62:33:01,  6.62it/s]

Time taken : 1511.28932762146





### V2

In [28]:
def getSimilarWords(model, tokenizer, combinedOutputFolder, symptom, meanEmb, similarityThreshold = 0.3, numThreshold = 150000, numComp = 10000):
        
    output = []

    symptomToken = tokenizer.encode(symptom)[1]

    fileList = os.listdir(combinedOutputFolder)
    
    cos = CosineSimilarity(dim=1, eps=1e-6)

    examineCount = 0

    for i in tqdm(range(len(fileList))):

        if examineCount >= numThreshold:
            break


        filename = os.path.join(combinedOutputFolder, f"{i}.pkl")
        subDict = pickle.load(open(filename,'rb'))

        IDList = subDict['id']
        tokenList = subDict['token']
        embList = subDict['emb']

#         sim = np.round(cosine_similarity(embList, meanEmb.reshape(1,-1)).reshape(-1),4)
        
#         arrA = torch.from_numpy(meanEmb.reshape(1,-1))
#         arrB = torch.from_numpy(embList)
        
        arrA = torch.from_numpy(meanEmb.reshape(1,-1)).cuda()
        arrB = torch.from_numpy(embList).cuda()
        
        sim = cos(arrA,arrB).cpu().numpy().reshape(-1)
        
        sim = np.round(sim,4)

        index= np.where([sim> similarityThreshold])[1]

        tokenList_ = tokenList[index]
        IDList_ = IDList[index]
        simList = sim[index]

        out = [(x,y,z) for x,y,z in zip(tokenList_, simList, IDList_)]

        output += out

        examineCount += numComp
        
    del arrA
    del arrB

    return output

In [30]:
combinedOutputFolder = '/data2/roshansk/ADRModel_DataStore_10000/'
numComp = 10000
numThreshold = 160000


file = 'fatigue_16342_Emb.npy'
embList = np.load(os.path.join('EmbFolder/',file))
meanEmb = np.mean(embList,0)
symptom = ''


startTime = time.time()

output1 = getSimilarWords(model, tokenizer, combinedOutputFolder, symptom, 
                meanEmb, similarityThreshold = 0.3, numThreshold = numThreshold, numComp = numComp)

print(f"Time taken : {time.time() - startTime}")

 40%|████      | 16/40 [03:38<05:27, 13.66s/it]

Time taken : 218.6913194656372





In [49]:
((48.45/160000)*1500000)/60

7.570312500000001

In [16]:
# combinedOutputFolder = '/data2/roshansk/ADRModel_DataStore_10000/'
# numComp = 10000
# numThreshold = 40000


# file = 'fatigue_16342_Emb.npy'
# embList = np.load(os.path.join('EmbFolder/',file))
# meanEmb = np.mean(embList,0)
# symptom = ''


numThreshold = 60000


startTime = time.time()

output1 = getSimilarWords(model, tokenizer, combinedOutputFolder, symptom, 
                meanEmb, similarityThreshold = 0.3, numThreshold = numThreshold, numComp = numComp)

print(f"Time taken : {time.time() - startTime}")

 15%|█▌        | 6/40 [01:30<08:34, 15.13s/it]

Time taken : 90.91519069671631





In [22]:
90/60000

0.0015

In [23]:
(100000*0.0015)/60

2.5

In [19]:
90/6

15.0

In [7]:
def convertToDf(data):
    
    df_ = pd.DataFrame(data)
    df_.columns = ['token','sim','numCount']

    meanVal = df_.groupby('token')['sim'].mean().reset_index()
    countVal = df_.groupby('token')['sim'].count().reset_index()
    countVal.columns = ['token','numCount']

    df_.sort_values('sim',inplace=True,ascending=False)

    outDf = meanVal.merge(countVal, on='token')
    outDf.sort_values('sim',inplace=True,ascending=False)
    
    return outDf

In [17]:
outDf = convertToDf(out1)


In [18]:
outDf.head(10)

Unnamed: 0,token,sim,numCount
442,fatigue,0.854485,2
515,headache,0.536889,3
453,fever,0.534045,13
168,anger,0.531018,1
183,asthma,0.528599,2
364,discouraged,0.522323,1
171,anxiety,0.52123,19
658,misery,0.518782,1
697,obesity,0.510591,1
725,paranoia,0.495413,2


In [40]:
outDf = convertToDf(output)

outDf.head(10)

Unnamed: 0,token,sim,numCount
441,fatigue,0.8545,2
514,headache,0.5369,3
452,fever,0.534046,13
168,anger,0.531,1
183,asthma,0.5286,2
363,discouraged,0.5223,1
171,anxiety,0.521232,19
657,misery,0.5188,1
696,obesity,0.5106,1
724,paranoia,0.4954,2


In [9]:
import numba

In [13]:
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [45]:
%lprun -f getSimilarWords getSimilarWords(model, tokenizer, combinedOutputFolder, symptom, meanEmb, similarityThreshold = 0.3, numThreshold = numThreshold, numComp = numComp)

In [15]:
fileList = os.listdir(combinedOutputFolder)

filename = os.path.join(combinedOutputFolder, f"{0}.pkl")
subDict = pickle.load(open(filename,'rb'))

IDList = subDict['id']
tokenList = subDict['token']
embList = subDict['emb']



In [None]:

        
sim = cosine_similarity(embList, meanEmb).reshape(-1)

In [27]:
arrA = torch.from_numpy(meanEmb.reshape(1,-1)).cuda()

In [28]:
arrB = torch.from_numpy(embList).cuda()

(353311,)

In [26]:
help(CosineSimilarity)

Help on class CosineSimilarity in module torch.nn.modules.distance:

class CosineSimilarity(torch.nn.modules.module.Module)
 |  Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim.
 |  
 |  .. math ::
 |      \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}.
 |  
 |  Args:
 |      dim (int, optional): Dimension where cosine similarity is computed. Default: 1
 |      eps (float, optional): Small value to avoid division by zero.
 |          Default: 1e-8
 |  Shape:
 |      - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`
 |      - Input2: :math:`(\ast_1, D, \ast_2)`, same shape as the Input1
 |      - Output: :math:`(\ast_1, \ast_2)`
 |  Examples::
 |      >>> input1 = torch.randn(100, 128)
 |      >>> input2 = torch.randn(100, 128)
 |      >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
 |      >>> output = cos(input1, input2)
 |  
 |  Method resolution order:
 |      CosineSimilarity
 |

### V3

In [6]:
from numba import jit, njit, prange

In [7]:
@njit(parallel=True)
def getSimilarWords(model, tokenizer, combinedOutputFolder, symptom, meanEmb, similarityThreshold = 0.3, numThreshold = 150000, numComp = 10000):
        
    output = []

    symptomToken = tokenizer.encode(symptom)[1]

    fileList = os.listdir(combinedOutputFolder)
    
    cos = CosineSimilarity(dim=1, eps=1e-6)

    examineCount = 0

    for i in prange(len(fileList)):

        if examineCount >= numThreshold:
            break


        filename = os.path.join(combinedOutputFolder, f"{i}.pkl")
        subDict = pickle.load(open(filename,'rb'))

        IDList = subDict['id']
        tokenList = subDict['token']
        embList = subDict['emb']

#         sim = np.round(cosine_similarity(embList, meanEmb.reshape(1,-1)).reshape(-1),4)

        arrA = torch.from_numpy(meanEmb.reshape(1,-1)).cuda()
        arrB = torch.from_numpy(embList).cuda()
        
        sim = cos(arrA,arrB).cpu().numpy().reshape(-1)
        
        sim = np.round(sim,4)

        index= np.where([sim> similarityThreshold])[1]

        tokenList_ = tokenList[index]
        IDList_ = IDList[index]
        simList = sim[index]

        out = [(x,y,z) for x,y,z in zip(tokenList_, simList, IDList_)]

        output += out

        examineCount += numComp

    return output

In [None]:
numComp = 10000
numThreshold = 20000

startTime= time.time()
out2 = getSimilarWords(model, tokenizer, combinedOutputFolder, symptom, meanEmb, similarityThreshold = 0.3, numThreshold = numThreshold, numComp = numComp)

print(f"Time taken : {time.time() - startTime}")

### V4

In [5]:
from dask.distributed import Client

In [75]:
client = Client(n_workers=6, threads_per_worker=2, processes=False)

In [76]:
client.dashboard_link

'http://128.91.252.35:8787/status'

In [77]:
def getSimilarWords(model, tokenizer, combinedOutputFolder, symptom, meanEmb, similarityThreshold = 0.3, numThreshold = 150000, numComp = 10000):
        
    output = []

    symptomToken = tokenizer.encode(symptom)[1]

    fileList = os.listdir(combinedOutputFolder)
    
    cos = CosineSimilarity(dim=1, eps=1e-6)
    
    computeTaskPart_ = partial(computeTask_, symptom = '',combinedOutputFolder=combinedOutputFolder, 
                               meanEmb=meanEmb,similarityThreshold = similarityThreshold)

    examineCount = 0
    
    totalSteps = int(numThreshold/numComp)
    
    mult = 10
    
    totalIters = totalSteps//mult if totalSteps%mult==0 else totalSteps//mult + 1
    
    
    finalOutput = []
    
    for i in tqdm(range(totalIters)):
        
        start = i*mult
        end = min((i+1)*mult, totalSteps)
    
        tasks = list(range(start,end))
        
        sent = client.map(computeTaskPart_, tasks)

        result = client.gather(sent)

        print("Tasks complete")

        result = list(itertools.chain.from_iterable(result))
        
        finalOutput.append(result)
    
    
    finalOutput = list(itertools.chain.from_iterable(finalOutput))
    
    return finalOutput




def computeTask_(index, symptom, combinedOutputFolder,meanEmb, similarityThreshold):

    symptomToken = tokenizer.encode(symptom)[1]

    cos = CosineSimilarity(dim=1, eps=1e-6)

    filename = os.path.join(combinedOutputFolder, f"{index+6}.pkl")
    subDict = pickle.load(open(filename,'rb'))

    IDList = subDict['id']
    tokenList = subDict['token']
    embList = subDict['emb']

#         sim = np.round(cosine_similarity(embList, meanEmb.reshape(1,-1)).reshape(-1),4)

    arrA = torch.from_numpy(meanEmb.reshape(1,-1))
    arrB = torch.from_numpy(embList)

#         arrA = torch.from_numpy(meanEmb.reshape(1,-1)).cuda()
#         arrB = torch.from_numpy(embList).cuda()

    sim = cos(arrA,arrB).cpu().numpy().reshape(-1)

    sim = np.round(sim,4)

    index= np.where([sim> similarityThreshold])[1]

    tokenList_ = tokenList[index]
    IDList_ = IDList[index]
    simList = sim[index]

    out = [(x,y,z) for x,y,z in zip(tokenList_, simList, IDList_)]

    return out

In [78]:
combinedOutputFolder = '/data2/roshansk/ADRModel_DataStore_10000/'
numComp = 10000
numThreshold = 250000


file = 'fatigue_16342_Emb.npy'
embList = np.load(os.path.join('EmbFolder/',file))
meanEmb = np.mean(embList,0)
symptom = ''


In [79]:
startTime = time.time()

result = getSimilarWords(model, tokenizer, combinedOutputFolder, symptom, 
                meanEmb, similarityThreshold = 0.3, numThreshold = numThreshold, numComp = numComp)

print(f"Time taken : {time.time() - startTime}")



 33%|███▎      | 1/3 [00:29<00:59, 29.57s/it]

Tasks complete


 67%|██████▋   | 2/3 [02:32<00:57, 57.51s/it]

Tasks complete


100%|██████████| 3/3 [04:20<00:00, 86.98s/it]

Tasks complete
Time taken : 260.9436800479889





In [72]:
client.shutdown()




distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
concurrent.futures._base.CancelledError


In [80]:
260/60

4.333333333333333

In [81]:
4.33*6

25.98

In [83]:
client.shutdown()

distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
concurrent.futures._base.CancelledError


### New ADR Model  v4

In [82]:

def computeTask_(index, symptom, combinedOutputFolder,meanEmb, similarityThreshold):

    symptomToken = tokenizer.encode(symptom)[1]

    cos = CosineSimilarity(dim=1, eps=1e-6)

    filename = os.path.join(combinedOutputFolder, f"{index+6}.pkl")
    subDict = pickle.load(open(filename,'rb'))

    IDList = subDict['id']
    tokenList = subDict['token']
    embList = subDict['emb']

#         sim = np.round(cosine_similarity(embList, meanEmb.reshape(1,-1)).reshape(-1),4)

    arrA = torch.from_numpy(meanEmb.reshape(1,-1))
    arrB = torch.from_numpy(embList)

#         arrA = torch.from_numpy(meanEmb.reshape(1,-1)).cuda()
#         arrB = torch.from_numpy(embList).cuda()

    sim = cos(arrA,arrB).cpu().numpy().reshape(-1)

    sim = np.round(sim,4)

    index= np.where([sim> similarityThreshold])[1]

    tokenList_ = tokenList[index]
    IDList_ = IDList[index]
    simList = sim[index]

    out = [(x,y,z) for x,y,z in zip(tokenList_, simList, IDList_)]

    return out


class ADRModel(object):
    
    def __init__(self, df, model, tokenizer, graph, outputFolder, combinedOutputFolder, modelOutputFolder = './', queue=None, useMasterEmb = False, 
                 masterContrib = 0.5, embeddingType='last4sum',
                 numThreshold= 10000, saveEveryDepth = False,
                numComp = 10000):
        
        self.df = df
        self.model = model
        self.tokenizer = tokenizer
        self.graph = graph
        self.outputFolder = outputFolder
        self.combinedOutputFolder = combinedOutputFolder
        self.embeddingType = embeddingType
        self.numThreshold = numThreshold
        self.saveEveryDepth = saveEveryDepth
        self.modelOutputFolder = modelOutputFolder
        self.numComp = numComp
        
        if queue is None:
            self.q = deque()
        else:
            self.q = queue
            
            
        self.masterEmb = None
        
        self.useMasterEmb = useMasterEmb
        self.masterContrib = masterContrib
        
        self.masterEmbList = []
        
        self.client = Client(n_workers=6, threads_per_worker=2, processes=False)
        
        self.generateStates()
        
        
    def generateStates(self):
        
        
        for i in tqdm(range(len(self.df))):
            
            if os.path.exists(os.path.join(self.outputFolder, f"{i}.msh")):
                continue


            tokens = self.tokenizer.encode(self.df.iloc[i]['message'].lower())
            decoded = self.tokenizer.decode(tokens).split(" ")
            logits, hidden_states = self.model(torch.Tensor(tokens).unsqueeze(0).long())

            hidden_states = torch.stack(hidden_states).squeeze(1).permute(1,0,2)

            
            if self.embeddingType == 'last4sum':
                embedding = torch.sum(hidden_states[:,9:13,:],1)
            elif self.embeddingType =='last4concat':
                embedding = hidden_states[tokenIndex,9:13,:].reshape(-1)
            elif self.embeddingType == 'secondlast':
                embedding = hidden_states[tokenIndex,-2,:]
            else:
                embedding = hidden_states[tokenIndex,-1,:]
                    
                    
            embedding = embedding.detach().cpu().numpy()
            
            marshal.dump(embedding.tolist(), open(os.path.join(self.outputFolder, f"{i}.msh"), 'wb'))
        
        
        
        
    def getSymptomEmbedding(self, symptom, subset = None):
    
        embeddingList = []
        messageList = []

#         if subset is not None:
#             self.df = self.df.iloc[subset]

#         if type(df) == pd.Series:
#             self.df = pd.DataFrame(self.df).T

#         symptomToken = self.tokenizer.encode(symptom)[1]
        symptomToken = self.tokenizer.convert_tokens_to_ids(symptom)

        for i in range(len(self.df)):

            if symptomToken in self.tokenizer.encode(self.df.iloc[i]['message'].lower()):

                tokens = self.tokenizer.encode(self.df.iloc[i]['message'].lower())
                decoded = self.tokenizer.decode(tokens).split(" ")

                hidden_states = np.array(marshal.load( open(os.path.join(self.outputFolder, f"{i}.msh"), 'rb') ))

                try:
                    tokenIndex = tokens.index(symptomToken)
                except:
                    a= 1
                    continue

 
                embedding = hidden_states[tokenIndex,:]

                embeddingList.append(embedding)
                messageList.append(self.df.iloc[i]['message'].lower())

                if len(embeddingList)==30:
                    break



        return embeddingList, messageList
    
    def getSimilarWords(self, symptom, meanEmb, similarityThreshold = 0.3):
        
        output = []

        symptomToken = self.tokenizer.encode(symptom)[1]

        fileList = os.listdir(self.combinedOutputFolder)

        cos = CosineSimilarity(dim=1, eps=1e-6)

        computeTaskPart_ = partial(computeTask_, symptom = symptom,combinedOutputFolder=self.combinedOutputFolder, 
                                   meanEmb=meanEmb,similarityThreshold = similarityThreshold)

        examineCount = 0

        totalSteps = int(self.numThreshold/self.numComp)

        mult = 10

        totalIters = totalSteps//mult if totalSteps%mult==0 else totalSteps//mult + 1


        finalOutput = []

        for i in tqdm(range(totalIters)):

            start = i*mult
            end = min((i+1)*mult, totalSteps)

            tasks = list(range(start,end))

            sent = self.client.map(computeTaskPart_, tasks)

            result = self.client.gather(sent)

            print("Tasks complete")

            result = list(itertools.chain.from_iterable(result))

            finalOutput.append(result)


        finalOutput = list(itertools.chain.from_iterable(finalOutput))

        return finalOutput

        
    
    
    def getOutput(self, out):
    
        output = out

        outMap = {}

        for i in range(len(output)):
            if output[i][0] in outMap:
                outMap[output[i][0]].append(output[i][1])
            else:
                outMap[output[i][0]] = [output[i][1]]


        outMap_ = {}

        for i in range(len(output)):
            if output[i][0] in outMap_:
                outMap_[output[i][0]].append(output[i][2])
            else:
                outMap_[output[i][0]] = [output[i][2]]


        outputDf = []

        for key in outMap.keys():
            length = len(outMap[key])
            mean = np.mean(outMap[key])

            outputDf.append([key, length, mean])

        outputDf = pd.DataFrame(outputDf)
        outputDf.columns = ['word','counts','mean_sim']
        outputDf = outputDf.sort_values('mean_sim', ascending=False)

        return outputDf, outMap, outMap_
    
    
    
    
    def exploreNode(self, word, depth, maxDepth = 3, topk = 5):

    
        self.graph.addNode(word,0,depth)

        print(f"Depth : {depth} Exploring {word}")

        if depth == maxDepth:
            print("Reached max depth")
            return

        keyWord = word

        token = self.tokenizer.encode(keyWord)[1]

        if self.graph[word].vector is None:

            inEdgeList = self.graph[word].edges_in

            if len(inEdgeList)==0:
                textIDList = None
            else:
                textIDList = []

                for edge in inEdgeList:
                    textIDList.append(self.graph.edgeList[edge].textID)

                textIDList = list(set(list(itertools.chain.from_iterable(textIDList))))

            
            embList,msgList = self.getSymptomEmbedding(keyWord, subset = textIDList)

            meanEmb = np.array(embList)
            meanEmb = np.mean(meanEmb,0)


            self.graph[word].vector = meanEmb
            
            if self.masterEmb is None:
                self.masterEmb = meanEmb
            
            dist = getCosineDist(meanEmb, self.masterEmb)
            
            self.graph[word].masterDist = dist

        else:
            meanEmb = self.graph[word].vector
            
            if self.masterEmb is None:
                self.masterEmb = meanEmb
                
            dist = getCosineDist(meanEmb, self.masterEmb)
            
            self.graph[word].masterDist = dist


        symptom_ =''
        embList_ = meanEmb

        if self.useMasterEmb:
            
            finalEmb = self.masterContrib*self.masterEmb + (1 - self.masterContrib)*meanEmb
            
            out = self.getSimilarWords( symptom_, finalEmb , similarityThreshold = 0.3)
        else:
            out = self.getSimilarWords( symptom_, meanEmb, similarityThreshold = 0.3)

        outputDf, outMap, outMap_ = self.getOutput(out)

        outputDf = outputDf[outputDf.word!=keyWord]
    #     outputDf = outputDf[~outputDf.word.isin(list(graph.wordMap.keys()))]
        outputDf = outputDf.sort_values('mean_sim', ascending=False)
        outputDf = outputDf.head(topk)

        outputDf = outputDf[outputDf.mean_sim>0.4]

        print(outputDf)
        print("-----------------------")

        for i in range(len(outputDf)):

            word = outputDf.iloc[i]['word']
            numCount = outputDf.iloc[i]['counts']
            weight = outputDf.iloc[i]['mean_sim']
            textIDs = outMap_[word]

            wordList = set(self.graph.wordMap.keys())

            self.graph.addNode(word,0,depth+1)
            self.graph[word].textIDList.append(textIDs)
            self.graph.addEdge(keyWord, word, numCount, weight, textIDs)

            if word in wordList:
                continue

#             if "#" in word:
#                 continue


            self.q.append((word, depth+1))
            
            
    def trainModel(self, maxDepth = 3, topk = 5):
        
        currDepth = 0
        
        while len(self.q)>0:
            token, depth = self.q.popleft()
            
            if depth> currDepth:
                
                if self.saveEveryDepth:
                    filepath = os.path.join( self.modelOutputFolder, f"depth_{currDepth}.pkl")
                    self.saveModel(filepath)
                
                self.masterEmbList.append(self.masterEmb.copy())
                self.getMeanEmbedding(depth-1)
                currDepth += 1
            
            self.exploreNode(word = token, depth = depth, maxDepth=maxDepth, topk=topk)
        
        #Saving final model
        filepath = os.path.join(self.modelOutputFolder, "final.pkl")
        self.saveModel(filepath)


            
    def getMeanEmbedding(self, depth, topk = 3):
        
        candidates = self.graph.depthMap[depth]
        
        vals = [self.graph[x].masterDist for x in candidates]
        
        vals = [(x,y) for x,y in zip(candidates,vals)]
        
        vals = sorted(vals, key = lambda x : -x[1])
        
        meanEmb = self.masterEmb
        
        selectedWords = []
        for i in range(min(topk, len(vals)) ):
            meanEmb += self.graph[ vals[i][0] ].vector
            selectedWords.append(vals[i][0])
            
        meanEmb = meanEmb/(topk+1)
        
        self.masterEmb = meanEmb
        
        for i in range(len(selectedWords)):
            print(selectedWords[i])
        print("Master Embedding updated.")
        print("-----------------")
        
        
    
    def plotGraph(self):
        
        edgeList, nodeList, nodeValues, nodeCount, nodeText, nodeSize = getGraphComponents(self.graph)

        G=nx.Graph()

        G.add_nodes_from(nodeList)
        G.add_edges_from(edgeList)

        edge_trace, node_trace1, node_trace = getPlotlyComponents(G, nodeList, nodeSize, nodeValues, nodeText)


        fig = go.Figure(data=[edge_trace, node_trace1, node_trace],
             layout=go.Layout(
                title='<br>Network graph made with Python',
                titlefont_size=16,
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20,l=5,r=5,t=50),

                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                )
        
        fig.show()
        
        
    def saveModel(self,filename):
        
        classDict = self.__dict__.copy()
        classDict.pop('model')
        classDict.pop('tokenizer')
        classDict.pop('df')
        
        pickle.dump( classDict, open( filename, "wb" ) )
        
        
    def loadModel(self, filename):
        
        classDict = pickle.load(open(filename, 'rb'))
        
        for key in list(classDict.keys()):
            self.__dict__[key] = classDict[key]
        
        
