In [1]:
import pandas as pd
import numpy as np

import os

import nltk
from nltk.tokenize import TweetTokenizer
from nltk.corpus import stopwords 

from transformers import BertForSequenceClassification, BertTokenizer, BertForMaskedLM

from simpletransformers.language_modeling import LanguageModelingModel

from sklearn.metrics.pairwise import cosine_similarity, paired_euclidean_distances
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import normalize, StandardScaler, MinMaxScaler

from tqdm import tqdm
import torch

import networkx as nx

import matplotlib.pyplot as plt
%matplotlib inline

import plotly.graph_objects as go
from functools import partial

import pickle

from collections import deque

stop_words = set(stopwords.words('english')) 


%load_ext autoreload

%autoreload 2

# from utils import *
# from plotting import *

import time

import feather

import _pickle as cPickle

import json
import codecs
import marshal

In [2]:
inputFolder = '/data1/roshansk/ADRModel_DataStore/'

In [6]:
covidData = '/data1/roshansk/covid_data/'
os.listdir(covidData)

df = pd.read_csv(os.path.join(covidData, 'messages_cm_mar1_apr23_noRT.csv'), nrows = 300000)

df = df[['message_id','user_id','message']]

In [4]:
tokenizer = BertTokenizer.from_pretrained('/data1/roshansk/Exp1/checkpoint-141753-epoch-1')

model = BertForSequenceClassification.from_pretrained('/data1/roshansk/Exp1/checkpoint-141753-epoch-1', output_hidden_states= True)

### Generating individual files

In [8]:
outputFolder = '/data1/roshansk/ADRModel_DataStore/'
embeddingType = 'last4sum'

for i in tqdm(range(150000,150100)):
            
    if os.path.exists(os.path.join(outputFolder, f"{i}.msh")):
        continue


    tokens = tokenizer.encode(df.iloc[i]['message'].lower())
    decoded = tokenizer.decode(tokens).split(" ")
    logits, hidden_states = model(torch.Tensor(tokens).unsqueeze(0).long())

    hidden_states = torch.stack(hidden_states).squeeze(1).permute(1,0,2)


    if embeddingType == 'last4sum':
        embedding = torch.sum(hidden_states[:,9:13,:],1)
    elif embeddingType =='last4concat':
        embedding = hidden_states[tokenIndex,9:13,:].reshape(-1)
    elif embeddingType == 'secondlast':
        embedding = hidden_states[tokenIndex,-2,:]
    else:
        embedding = hidden_states[tokenIndex,-1,:]


    embedding = embedding.detach().cpu().numpy()

    marshal.dump(embedding.tolist(), open(os.path.join(outputFolder, f"{i}.msh"), 'wb'))

100%|██████████| 100/100 [00:15<00:00,  6.33it/s]


### Generating Agg files

In [17]:
def aggFiles(index, numComp, df, tokenizer, inputFolder, outputFolder):

    IDList = []
    tokenList = []
    embList = []

    for i in tqdm(range(index*numComp, (index+1)*numComp)):
        text = df.iloc[i]['message']

        tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))


        emb = np.array(marshal.load(open(os.path.join(inputFolder, f"{i}.msh"),'rb' )))

        IDList += [i]*len(tokens)
        tokenList += tokens

        embList.append(emb)

    IDList = np.array(IDList)
    tokenList = np.array(tokenList)
    embList = np.concatenate(embList,axis=0)

    subDict = {'id':IDList, 'token':tokenList,'emb':embList}


    
    filename = os.path.join(outputFolder, f"{index}.pkl")
    pickle.dump(subDict, open(filename,'wb'))

In [None]:
aggFiles(1, numComp, df, tokenizer, inputFolder, outputFolder)

In [14]:
os.listdir('/data1/roshansk/ADRModel_DataStore_10000')

['89999.pkl',
 '119999.pkl',
 '9999.pkl',
 '59999.pkl',
 '149999.pkl',
 '99999.pkl',
 '139999.pkl',
 '39999.pkl',
 '49999.pkl',
 '129999.pkl',
 '29999.pkl',
 '19999.pkl',
 '109999.pkl',
 '79999.pkl',
 '69999.pkl']

In [18]:
numComp = 10000
outputFolder = '/data1/roshansk/ADRModel_DataStore_10000'
inputFolder = '/data1/roshansk/ADRModel_DataStore/'

for i in range(0,15):
    
    aggFiles(i, numComp, df, tokenizer, inputFolder, outputFolder)
    

100%|██████████| 10000/10000 [02:48<00:00, 59.45it/s]
100%|██████████| 10000/10000 [02:53<00:00, 57.60it/s]
100%|██████████| 10000/10000 [03:13<00:00, 51.69it/s]
100%|██████████| 10000/10000 [02:57<00:00, 56.49it/s]
100%|██████████| 10000/10000 [03:16<00:00, 50.90it/s]
100%|██████████| 10000/10000 [03:24<00:00, 48.82it/s]
100%|██████████| 10000/10000 [03:17<00:00, 50.68it/s]
100%|██████████| 10000/10000 [03:14<00:00, 51.38it/s]
100%|██████████| 10000/10000 [03:13<00:00, 51.61it/s]
100%|██████████| 10000/10000 [03:20<00:00, 49.91it/s] 
100%|██████████| 10000/10000 [03:07<00:00, 53.37it/s] 
100%|██████████| 10000/10000 [02:53<00:00, 57.70it/s]
100%|██████████| 10000/10000 [02:43<00:00, 60.99it/s]
100%|██████████| 10000/10000 [02:46<00:00, 60.03it/s]
100%|██████████| 10000/10000 [04:30<00:00, 36.97it/s]
