In [None]:
import os
import pandas as pd
import numpy as np
import random
import tqdm
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import matplotlib.dates as mdates
import seaborn as sns
import math
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch
from collections import defaultdict
from sklearn.metrics import mean_absolute_error, mean_squared_error
from abc import ABC, abstractmethod
import re
from transformers import AutoModel, AutoTokenizer
import pickle
import time
import itertools
from collections import Counter

In [None]:
import warnings

# ``SettingWithCopyWarning`` is exposed from ``pandas.errors`` in newer pandas
# releases, while older releases only offer the private ``pandas.core.common``
# location. Try the public location first, then the legacy one, and skip the
# filter when the class is unavailable in the installed pandas.
try:
    from pandas.errors import SettingWithCopyWarning
except ImportError:
    try:
        from pandas.core.common import SettingWithCopyWarning
    except ImportError:
        SettingWithCopyWarning = None

# Silence numeric RuntimeWarnings (e.g. divide-by-zero during z-scoring).
warnings.filterwarnings("ignore", category=RuntimeWarning)
if SettingWithCopyWarning is not None:
    warnings.simplefilter(action="ignore", category=SettingWithCopyWarning)

In [None]:
# Set the TOKENIZERS_PARALLELISM environment variable to "false" for this process.
# NOTE(review): presumably meant to disable parallelism in the HuggingFace
# tokenizers library (and its fork warnings) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Feature Engineering and Train Test Split:

* This notebook contains the code used to preprocess user timelines and extract features for our forecasting models.
* It also contains the code used to create the train, validation, and test splits.

## Extractors

In [None]:
class FeatureExtractor(ABC):
    """Abstract interface implemented by every per-bin feature extractor."""

    @abstractmethod
    def extract(self, bin_ts, **kwargs):
        """
        Compute features for one time-step bin of a user sequence.

        * bin_ts : the bin object handed to the extractor
        * kwargs : optional extractor specific arguments

        Concrete subclasses must override this method; the base
        implementation does nothing and returns None.
        """
        return None

class EngagementCountExtractor(FeatureExtractor):
    """Counts engagements per partisan score (-3 .. 3) inside one bin."""

    def __init__(self):
        """
        Build the lookup from partisan score (-3 .. 3) to vector slot (0 .. 6).
        """
        self.ps_index_map = dict(zip(range(-3, 4), range(7)))

    def extract(self, bin_ts):
        """
        Return a 7-slot count vector of engagements for a time step bin.

        * Every slot starts at 1.0, so [1,1,1,1,1,1,1] means "no engagement"
        * The +1 shift exists so that the evaluation metrics (smape and nmse)
          work properly
        * Each value found in the "matched_partisans" lists of
          bin_ts["eng_dataframe"] increments its slot by one
        """
        eng_df = bin_ts["eng_dataframe"]
        counts = [1.0] * 7
        if len(eng_df) == 0:
            return counts

        for partisan in flatten(eng_df["matched_partisans"].tolist()):
            slot = self.ps_index_map[partisan]
            counts[slot] += 1

        return counts

class EngagementTypeCountExtractor(FeatureExtractor):
    """Counts engagement tweets in a bin by how they matched (url / mention / both)."""

    def __init__(self):
        """
        No configuration is required.
        """

    def get_engagement_type(self,urow):
        """
        Classify one engagement row.

        * urow : object exposing `matched_urls` and `matched_mentions`
                 (both sized containers)

        Returns
        * 0 : only urls matched
        * 1 : only mentions matched
        * 2 : both urls and mentions matched
        * None : neither matched
        """
        has_urls = len(urow.matched_urls)>0
        has_mentions = len(urow.matched_mentions)>0

        if has_urls and not has_mentions:
            return 0
        if has_mentions and not has_urls:
            return 1
        if has_urls and has_mentions:
            return 2
        return None

    def extract(self,bin_ts):
        """
        Return [n_url_only, n_mention_only, n_both] for bin_ts["eng_dataframe"].

        * Rows that match neither urls nor mentions are skipped; previously
          they produced a None index and raised a TypeError.
        * The bin dataframe is no longer modified (no "eng_type" column is
          written onto what is usually a slice of the user timeline).
        """
        df = bin_ts["eng_dataframe"]
        engtype_vec = [0.0]*3
        if len(df)>0:
            rows = df[["matched_urls","matched_mentions"]].itertuples(index=False)
            for urow in rows:
                eng_type = self.get_engagement_type(urow)
                if eng_type is not None:
                    engtype_vec[eng_type]+=1

        return engtype_vec
        
        
    
class EngagementPublicMetricsExtractor(FeatureExtractor):
    """Sums the public metrics of the engagement tweets of a bin."""

    def __init__(self):
        """
        No configuration is required.
        """
        pass

    def extract(self, bin_ts):
        """
        Return [retweets, replies, likes, quotes] summed over
        bin_ts["eng_dataframe"].

        * Each entry of the "tweet_public_metrics" column is indexed with
          "retweet_count", "reply_count", "like_count" and "quote_count"
        * An empty bin yields [0,0,0,0]
        """
        eng_df = bin_ts["eng_dataframe"]
        if len(eng_df) == 0:
            return [0, 0, 0, 0]

        metric_keys = ("retweet_count", "reply_count", "like_count", "quote_count")
        per_tweet = eng_df["tweet_public_metrics"].tolist()
        # one np.sum per metric, in the fixed order given by metric_keys
        return [np.sum([pm[key] for pm in per_tweet]) for key in metric_keys]
        

class NonEngagementPublicMetricsExtractor(FeatureExtractor):
    """Sums the public metrics of the non-engagement tweets of a bin."""

    def __init__(self):
        """
        No configuration is required.
        """
        pass

    def extract(self, bin_ts):
        """
        Return [retweets, replies, likes, quotes] summed over
        bin_ts["noeng_dataframe"].

        * Each entry of the "tweet_public_metrics" column is indexed with
          "retweet_count", "reply_count", "like_count" and "quote_count"
        * An empty bin yields [0,0,0,0]
        """
        noeng_df = bin_ts["noeng_dataframe"]
        if len(noeng_df) == 0:
            return [0, 0, 0, 0]

        totals = []
        per_tweet = noeng_df["tweet_public_metrics"].tolist()
        for key in ("retweet_count", "reply_count", "like_count", "quote_count"):
            totals.append(np.sum([pm[key] for pm in per_tweet]))
        return totals

class DRIdentifier(object):
    """Flags tweets whose text is a direct retweet ("RT @handle") of a known news handle."""

    def __init__(self,news_df):
        """
        * news_df : dataframe with list-like "Twitter Handle" and "URL" columns;
                    both are exploded and the string handles are kept

        Attributes
        * news_twh : list of news source handles (strings only)
        * match_pattern : compiled regex locating the "RT @handle" prefix
        """
        ex_news_df = news_df.explode("Twitter Handle").reset_index(drop=True).explode("URL").reset_index(drop=True)
        ex_news_twh = ex_news_df["Twitter Handle"].tolist()
        self.news_twh = [e for e in ex_news_twh if type(e)==str]
        # set copy for O(1) membership checks; news_twh stays a list for callers
        self._news_twh_lookup = set(self.news_twh)

        # Handles may contain underscores, so "_" is part of the character
        # class (it used to be cut off, e.g. "fox_news" -> "fox").
        # The leading \b stops "RT" from matching inside a longer word.
        self.match_pattern = re.compile(r"(\bRT\s*@[a-zA-Z0-9_]*)")

    def check_drt(self,tweet_text):
        """
        Return True when tweet_text contains "RT @<handle>" and <handle> is
        one of the news handles, otherwise False.

        * The handle comparison is case sensitive
        """
        match = self.match_pattern.search(tweet_text)
        if match is None:
            return False

        acc = match.group(0).replace("RT","").replace(":","").replace("@","").strip()
        return acc in self._news_twh_lookup

    def identify(self,df):
        """
        Add a boolean "direct_retweet" column computed from df["text"].

        * df is modified in place and also returned
        """
        df["direct_retweet"] = df["text"].apply(self.check_drt)

        return df

class EngagementDRCountsExtractor(FeatureExtractor):
    """Counts direct retweets of news handles per partisan score (-3 .. 3)."""

    def __init__(self, news_df):
        """
        * news_df : news source dataframe forwarded to DRIdentifier
        """
        self.drt_identifier = DRIdentifier(news_df)

    def extract(self, bin_ts):
        """
        Return a 7-slot vector; slot k holds the number of engagement tweets
        flagged as direct retweets whose "matched_partisans" contains the
        score k-3.

        * An empty bin yields all zeros
        * The identifier adds a "direct_retweet" column to the bin dataframe
        """
        eng_df = bin_ts["eng_dataframe"]
        drt_count_vec = [0.0] * 7
        if len(eng_df) == 0:
            return drt_count_vec

        eng_df = self.drt_identifier.identify(eng_df)

        # one boolean row mask per partisan score, in slot order
        score_masks = [
            eng_df["matched_partisans"].apply(lambda labels: score in labels)
            for score in range(-3, 4, 1)
        ]

        for slot, score_mask in enumerate(score_masks):
            score_df = eng_df[score_mask]
            if len(score_df) == 0:
                continue
            drt_counts = score_df["direct_retweet"].value_counts()
            if True in drt_counts:
                drt_count_vec[slot] += drt_counts[True]

        return drt_count_vec
                
                

class TotalTweetcountExtractor(FeatureExtractor):
    """Counts all tweets (engagement + non-engagement) of a bin."""

    def __init__(self):
        """
        No configuration is required.
        """
        pass

    def extract(self, bin_ts):
        """
        Return the number of rows of bin_ts["eng_dataframe"] plus the number
        of rows of bin_ts["noeng_dataframe"].
        """
        n_engagement = bin_ts["eng_dataframe"].shape[0]
        n_no_engagement = bin_ts["noeng_dataframe"].shape[0]
        return n_engagement + n_no_engagement

class EngagementNSDistExtractor(FeatureExtractor):
    """Marks which news sources were engaged with inside a bin."""

    def __init__(self, news_df):
        """
        * news_df : dataframe whose "Source" column lists the news sources;
                    the position in that column is the slot in the output
        """
        self.news_df = news_df
        self.news_sources = self.news_df["Source"].tolist()
        self.ns_map = {}
        for position, source in enumerate(self.news_sources):
            self.ns_map[source] = position

    def extract(self, bin_ts):
        """
        Return a vector with one slot per news source; a slot is set to 1 when
        the source occurs in the "matched_sources" lists of
        bin_ts["eng_dataframe"], otherwise it stays 0.0.

        * A source missing from the news dataframe raises a KeyError
        """
        eng_df = bin_ts["eng_dataframe"]
        ns_vec = [0.0] * len(self.news_sources)
        if len(eng_df) == 0:
            return ns_vec

        for source in flatten(eng_df["matched_sources"].tolist()):
            ns_vec[self.ns_map[source]] = 1

        return ns_vec


def get_hashtags(urow):
    """
    Collect the distinct hashtags (without the leading "#") of a tweet row.

    * urow : object exposing `text` (str) and `referenced_text` (sequence of str)
    * Hashtags are read from `text` and, when present, from the first entry
      of `referenced_text` only
    * Returns a de-duplicated list; its order is not defined
    """
    hashtag_regex = r'(?<![#\w])#(\w+)'

    sources = [urow.text]
    if len(urow.referenced_text)>0:
        sources.append(urow.referenced_text[0])

    unique_tags = set()
    for source_text in sources:
        unique_tags.update(re.findall(hashtag_regex,source_text))

    return list(unique_tags)


def identity_tok(x):
    """
    Return the items of x that still contain at least one character after
    all non-ASCII characters are stripped (the items themselves are returned
    unchanged).
    """
    return [item for item in x if len(re.sub(r'[^\x00-\x7f]',r'', item))>0]



class MentionsGeneralExtractor(FeatureExtractor):
    """Multi-hot vector of vocabulary mentions that occur in a bin."""

    def __init__(self,vectorizer_path):
        """
        * vectorizer_path : path of a pickled, fitted vectorizer that defines
                            the mention vocabulary

        NOTE: pickle executes arbitrary code while loading; only load
        vectorizer files from a trusted source.
        """
        # context manager so the file handle is always closed
        with open(vectorizer_path,"rb") as vec_file:
            self.vectorizer = pickle.load(vec_file)

        # newer scikit-learn exposes get_feature_names_out(); fall back to the
        # older get_feature_names() for vectorizers from old versions
        if hasattr(self.vectorizer,"get_feature_names_out"):
            self.feature_names = list(self.vectorizer.get_feature_names_out())
        else:
            self.feature_names = list(self.vectorizer.get_feature_names())
        self.feature_map = {x:i for i,x in enumerate(self.feature_names)}


    def extract(self,bin_ts):
        """
        * initialize a zero vector with the vocabulary size
        * gather mentions ("extracted_mentions") from both the engagement and
          the non-engagement tweets of the bin
        * set the slot of every mention found in the vocabulary to 1
          (binary presence, not a count); unknown mentions are ignored
        """
        df_eng = bin_ts["eng_dataframe"]
        df_noeng = bin_ts["noeng_dataframe"]

        men_vec = [0]*len(self.feature_names)

        mentions = df_eng["extracted_mentions"].tolist() + df_noeng["extracted_mentions"].tolist()

        for m in flatten(mentions):
            if m in self.feature_map:
                men_vec[self.feature_map[m]]=1

        return men_vec



class HashtagGeneralExtractor(FeatureExtractor):
    """Multi-hot vector of vocabulary hashtags that occur in a bin."""

    def __init__(self,vectorizer_path):
        """
        * vectorizer_path : path of a pickled, fitted vectorizer that defines
                            the hashtag vocabulary

        NOTE: pickle executes arbitrary code while loading; only load
        vectorizer files from a trusted source.
        """
        # context manager so the file handle is always closed
        with open(vectorizer_path,"rb") as vec_file:
            self.vectorizer = pickle.load(vec_file)

        # newer scikit-learn exposes get_feature_names_out(); fall back to the
        # older get_feature_names() for vectorizers from old versions
        if hasattr(self.vectorizer,"get_feature_names_out"):
            self.feature_names = list(self.vectorizer.get_feature_names_out())
        else:
            self.feature_names = list(self.vectorizer.get_feature_names())
        self.feature_map = {x:i for i,x in enumerate(self.feature_names)}

    def extract(self,bin_ts):
        """
        * initialize a zero vector with the vocabulary size
        * gather hashtags ("extracted_hashtags") from both the engagement and
          the non-engagement tweets of the bin
        * set the slot of every hashtag found in the vocabulary to 1
          (binary presence, not a count); unknown hashtags are ignored
        """
        df_eng = bin_ts["eng_dataframe"]
        df_noeng = bin_ts["noeng_dataframe"]

        hash_vec = [0]*len(self.feature_names)

        hashtags = df_eng["extracted_hashtags"].tolist() + df_noeng["extracted_hashtags"].tolist()

        for h in flatten(hashtags):
            if h in self.feature_map:
                hash_vec[self.feature_map[h]]=1

        return hash_vec

class TopHashtagExtractor(FeatureExtractor):
    """Tokenizes the most frequent hashtags of a bin."""

    def __init__(self, tokenizer, topk=100):
        """
        * tokenizer : object providing `batch_encode_plus`
        * topk : number of most frequent hashtags to keep
        """
        self.tokenizer = tokenizer
        self.topk = topk

    def extract(self, bin_ts):
        """
        * count the hashtags ("extracted_hashtags") of the engagement and
          non-engagement tweets of the bin
        * keep the `topk` most common ones (a single "" when there are none)
        * return them tokenized to length 8 (padded / truncated, no special
          tokens) as 'pt' tensors
        """
        per_tweet_tags = (
            bin_ts["eng_dataframe"]["extracted_hashtags"].tolist()
            + bin_ts["noeng_dataframe"]["extracted_hashtags"].tolist()
        )
        tag_counts = Counter(flatten(per_tweet_tags))

        top_hashes = [tag for tag, _count in tag_counts.most_common(self.topk)]
        if not top_hashes:
            top_hashes = [""]

        return self.tokenizer.batch_encode_plus(top_hashes,
                                                padding="max_length",
                                                max_length=8,
                                                truncation=True,
                                                return_tensors='pt',
                                                add_special_tokens=False)

class TopMentionsExtractor(FeatureExtractor):
    """Tokenizes the most frequent mentions of a bin."""

    def __init__(self, tokenizer, topk=100):
        """
        * tokenizer : object providing `batch_encode_plus`
        * topk : number of most frequent mentions to keep
        """
        self.tokenizer = tokenizer
        self.topk = topk

    def extract(self, bin_ts):
        """
        * count the mentions ("extracted_mentions") of the engagement and
          non-engagement tweets of the bin
        * keep the `topk` most common ones (a single "" when there are none)
        * return them tokenized to length 8 (padded / truncated, no special
          tokens) as 'pt' tensors
        """
        per_tweet_mentions = (
            bin_ts["eng_dataframe"]["extracted_mentions"].tolist()
            + bin_ts["noeng_dataframe"]["extracted_mentions"].tolist()
        )
        mention_counts = Counter(flatten(per_tweet_mentions))

        top_mentions = [name for name, _count in mention_counts.most_common(self.topk)]
        if not top_mentions:
            top_mentions = [""]

        return self.tokenizer.batch_encode_plus(top_mentions,
                                                padding="max_length",
                                                max_length=8,
                                                truncation=True,
                                                return_tensors='pt',
                                                add_special_tokens=False)



class EngagementTextExtractorT1(FeatureExtractor):
    """
    # * tokenizes the `num_tweets` most recent tweets of a bin, one row per tweet
    # * downstream code aggregates the per-tweet representations
    """

    def __init__(self,tokenizer,num_tweets=25,max_length=20):
        """
        * tokenizer : object providing `batch_encode_plus`
        * num_tweets : number of tweet rows returned per bin
        * max_length : token length every tweet is padded / truncated to
        """
        self.tokenizer = tokenizer
        self.num_tweets = num_tweets
        self.max_length = max_length

    def _tokenize(self,texts,add_special_tokens):
        """Tokenize `texts` to fixed length 'pt' tensors."""
        return self.tokenizer.batch_encode_plus(texts,
                                                padding="max_length",
                                                max_length=self.max_length,
                                                truncation=True,
                                                return_tensors="pt",
                                                add_special_tokens=add_special_tokens)

    def get_tweets(self,eng_bin_df):
        """
        Tokenize the most recent tweets of one bin dataframe.

        * empty bin : `num_tweets` empty strings, tokenized WITHOUT special
          tokens, bin_masks all 0 (the placeholder count used to be a
          hard-coded 25, which broke any other `num_tweets`)
        * otherwise : the last `num_tweets` rows (after sorting by index) of
          the "text" column, right-padded with "" up to `num_tweets`,
          tokenized WITH special tokens; bin_masks is 1 for real tweets and
          0 for padding

        Returns the tokenizer output with an extra "bin_masks" list.

        NOTE(review): the two branches differ in `add_special_tokens`; this is
        kept as-is because downstream empty-bin detection may rely on it.
        """
        if eng_bin_df.shape[0] == 0:
            most_recent_tweets = [""]*self.num_tweets
            bin_masks = [0]*self.num_tweets
            tokenized_out = self._tokenize(most_recent_tweets,add_special_tokens=False)
        else:
            eng_bin_df = eng_bin_df.sort_index()
            most_recent_tweets = eng_bin_df["text"].tail(self.num_tweets).tolist()
            num_padding = self.num_tweets - len(most_recent_tweets)
            bin_masks = [1]*len(most_recent_tweets) + [0]*num_padding
            most_recent_tweets = most_recent_tweets + [""]*num_padding
            tokenized_out = self._tokenize(most_recent_tweets,add_special_tokens=True)

        tokenized_out["bin_masks"] = bin_masks
        return tokenized_out

    def extract(self,bin_ts):
        """
        Return (tokenized engagement tweets, tokenized non-engagement tweets)
        for bin_ts["eng_dataframe"] and bin_ts["noeng_dataframe"].
        """
        tokenized_tweets_eng = self.get_tweets(bin_ts["eng_dataframe"])
        tokenized_tweets_neng = self.get_tweets(bin_ts["noeng_dataframe"])

        return tokenized_tweets_eng, tokenized_tweets_neng
    


class EngagementTextExtractorT2(FeatureExtractor):
    """
    # * joins the most recent tweets of a bin into one string
    # * tokenizes that string as a single sequence
    """

    def __init__(self,tokenizer,num_tweets=25,max_length=512):
        """
        * tokenizer : callable tokenizer
        * num_tweets : number of most recent tweets joined together
        * max_length : token length the joined text is padded / truncated to
        """
        self.tokenizer = tokenizer
        self.num_tweets = num_tweets
        self.max_length = max_length


    def get_tweets(self,eng_bin_df):
        """
        Tokenize the joined text of one bin dataframe.

        * empty bin : a single empty string, tokenized WITHOUT special tokens
        * otherwise : the last `num_tweets` rows (after sorting by index) of
          the "text" column joined with ". ", tokenized WITH special tokens

        Both branches now honour `self.max_length` (512 used to be
        hard-coded, so the constructor argument was ignored).
        """
        if eng_bin_df.shape[0] == 0:
            tokenized_out = self.tokenizer([""],
                                           padding="max_length",
                                           max_length=self.max_length,
                                           truncation=True,
                                           return_tensors='pt',
                                           add_special_tokens=False)

            return tokenized_out

        eng_bin_df = eng_bin_df.sort_index()
        most_recent_tweets = eng_bin_df["text"].tail(self.num_tweets).tolist()
        all_str = ". ".join(most_recent_tweets)
        tokenized_out = self.tokenizer(all_str,
                                       padding="max_length",
                                       max_length=self.max_length,
                                       truncation=True,
                                       return_tensors="pt",
                                       add_special_tokens=True)

        return tokenized_out


    def extract(self,bin_ts):
        """
        Return (tokenized engagement text, tokenized non-engagement text)
        for bin_ts["eng_dataframe"] and bin_ts["noeng_dataframe"].
        """
        tokenized_tweets_eng = self.get_tweets(bin_ts["eng_dataframe"])
        tokenized_tweets_neng = self.get_tweets(bin_ts["noeng_dataframe"])

        return tokenized_tweets_eng, tokenized_tweets_neng

In [None]:
def flatten(list_):
    """
    Flatten one level of nesting: concatenate the items of every element of
    list_ into a single new list (element order is preserved).
    """
    return list(itertools.chain.from_iterable(list_))


class FEChainExtractor():
    """Builds every feature extractor once and runs them all on a bin."""

    def __init__(self,news_df,ments_vecpath,hash_vecpath,num_eng=10,num_neng=3,
                 tokenizer_path="/home/kshvaram/bert_tweet/twhin-bert-base/"):
        """
        * news_df : news source dataframe (for the direct-retweet and news
                    source distribution extractors)
        * ments_vecpath : path of the pickled mentions vectorizer
        * hash_vecpath : path of the pickled hashtags vectorizer
        * num_eng : number of most recent tweets used by the text extractors
        * num_neng : accepted for backward compatibility; currently unused
        * tokenizer_path : location passed to AutoTokenizer.from_pretrained;
                           defaults to the previously hard-coded path
        """
        start = time.time()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        stop = time.time()
        print(f"Tokenizer loaded in : {round((stop-start)/60,2)} mins")

        self.__extractor1 = EngagementCountExtractor()
        self.__extractor2 = EngagementTypeCountExtractor()
        self.__extractor3 = EngagementPublicMetricsExtractor()
        self.__extractor4 = NonEngagementPublicMetricsExtractor()
        self.__extractor5 = EngagementDRCountsExtractor(news_df)
        self.__extractor6 = TotalTweetcountExtractor()
        self.__extractor7 = EngagementNSDistExtractor(news_df)
        self.__extractor8 = MentionsGeneralExtractor(ments_vecpath)
        self.__extractor9 = HashtagGeneralExtractor(hash_vecpath)
        self.__extractor10 = EngagementTextExtractorT1(num_tweets=num_eng,tokenizer=self.tokenizer)
        # built but not registered in the factory below ("text_t2" is disabled)
        self.__extractor11 = EngagementTextExtractorT2(num_tweets=num_eng,max_length=512,tokenizer=self.tokenizer)
        self.__extractor12 = TopHashtagExtractor(topk=50,tokenizer=self.tokenizer)
        self.__extractor13 = TopMentionsExtractor(topk=50,tokenizer=self.tokenizer)

        self.fe_factory = {}
        self.__load_factory__()

    def __load_factory__(self):
        """
        Register the active extractors under their feature names.
        """
        self.fe_factory = {"engagement_counts":self.__extractor1,
                           "engagement_types":self.__extractor2,
                           "engagement_publicmetrics":self.__extractor3,
                           "nonengagement_publicmetrics":self.__extractor4,
                           "engagement_drt":self.__extractor5,
                           "total_tweet_counts":self.__extractor6,
                           "engagement_nsdist": self.__extractor7,
                           "mentions_usertime":self.__extractor8,
                           "hashtags_usertime":self.__extractor9,
                           "text_t1":self.__extractor10,
                           "top_hashes":self.__extractor12,
                           "top_mentions":self.__extractor13}

    def extract_feats(self,bin_ts):
        """
        Run every registered extractor on bin_ts.

        Returns a dict mapping feature name -> extractor output.
        """
        return {name:extractor.extract(bin_ts) for name,extractor in self.fe_factory.items()}


class BinUsers(object):
    """
    Splits every user timeline into consecutive time bins and extracts the
    features of each bin.
    """
    def __init__(self,user_paths,bin_freq="3MS",start_date='01/01/2014',end_date='9/01/2021'):
        """
        * user_paths : file paths of user timelines
        * bin_freq : size of each bin in terms of time
        * start_date : which date to start the binning from
        * end_date : which date to stop binning
        """
        self.user_paths = user_paths
        self.bin_freq = bin_freq
        self.start_date = start_date
        self.end_date = end_date

    def date_range_init(self):
        """
        Build self.df_dates with one row per bin:

        * start / end : UTC timestamps of consecutive bin edges
        * quarter : calendar quarter (1-4) of the bin start

        The quarter used to be written as `[1,2,3,4]*years + [1,2,3]`, which
        only has the right length for one particular date range (it raised a
        ValueError with this class's own default dates). Deriving it from the
        bin start gives the same values for quarterly bins that start in
        January and works for any range.
        """
        index = pd.date_range(start=self.start_date,
                              end=self.end_date,
                              freq=self.bin_freq,
                              inclusive="both",
                              normalize=True)

        self.df_dates = pd.DataFrame()
        # consecutive edges: bin i spans [index[i], index[i+1])
        self.df_dates["start"]=pd.to_datetime(index[:-1],utc=True)
        self.df_dates["end"]=pd.to_datetime(index[1:],utc=True)
        self.df_dates["quarter"] = self.df_dates["start"].dt.quarter.astype(int)


    def plot_empty_bin_dist(self,avg_bin_dist):
        """
        Show a box plot of the average number of empty bins per user.

        * avg_bin_dist : sequence with one value per user
        """
        fig,ax = plt.subplots(1,1,figsize=(3,4))
        sns.boxplot(y=avg_bin_dist,
                    ax=ax,
                    orient="v",
                    width=0.3,
                    showmeans=True,
                    color="tab:red",
                    meanprops={"marker":"o",
                               "markerfacecolor":"white", 
                               "markeredgecolor":"black",
                              "markersize":"10"})
        ax.set_ylabel("Avg Empty Bins Per User")
        ax.set_xlabel("Users")
        plt.show()


    def bin_users(self,feats_extractor):
        """
        Bin every user timeline and extract the features of every bin.

        * feats_extractor : object with `extract_feats(bin_dict)` returning a
          dict that contains a "text_t1" pair (engagement, non-engagement)

        Returns a list with one entry per user; each entry is the list of
        feature dicts of that user's bins in chronological order. Every dict
        additionally holds "eng_text_t1", "neng_text_t1", "quarter",
        "date_range" and "username" ("text_t1" is removed).
        """
        all_user_binned_seqs = []

        self.date_range_init()

        for up in tqdm.tqdm(self.user_paths):

            df_u = pd.read_pickle(up)
            df_u["extracted_hashtags"] = df_u.apply(lambda x: get_hashtags(x),axis=1)
            df_u["created_at"] = pd.to_datetime(df_u["created_at"])
            df_u = df_u.loc[df_u["created_at"]>=self.start_date]
            df_u = df_u.set_index("created_at")

            feat_bins = []
            for index,row in self.df_dates.iterrows():
                # tweets inside the half-open interval [start, end)
                df_bin = df_u.loc[(df_u.index>=row["start"]) & (df_u.index< row["end"])]
                # a tweet is an engagement when it matched at least one partisan source
                df_bin_engagement = df_bin.loc[df_bin["matched_partisans"].str.len()>=1]
                df_bin_no_engagement = df_bin.loc[df_bin["matched_partisans"].str.len()<=0]

                df_bin_dict = {"eng_dataframe":df_bin_engagement,
                               "noeng_dataframe":df_bin_no_engagement,
                               "quarter":row["quarter"],
                               "date_range":(row["start"],row["end"])}

                feats_dict = feats_extractor.extract_feats(df_bin_dict)

                # split the (engagement, non-engagement) text pair into two keys
                feats_dict["eng_text_t1"] = feats_dict["text_t1"][0]
                feats_dict["neng_text_t1"] = feats_dict["text_t1"][1]
                del feats_dict["text_t1"]

                feats_dict["quarter"] = df_bin_dict["quarter"]
                feats_dict["date_range"] = df_bin_dict["date_range"]
                feats_dict["username"] = up.split("/")[-1]

                feat_bins.append(feats_dict)

            all_user_binned_seqs.append(feat_bins)

        return all_user_binned_seqs

In [None]:

# Load the news source table from a pickle and build the feature extractor chain.
# NOTE: paths are relative to the notebook's working directory.
news_df = pd.read_pickle("../../../Data/all_news_sources/all_news_sources.pkl")

# num_eng is forwarded to the text extractors; num_neng is accepted by
# FEChainExtractor but not used by it.
feats_extractor = FEChainExtractor(news_df,
                                   ments_vecpath="vocabs/mentions_general_vectorizer.pk",
                                   hash_vecpath="vocabs/hashtags_general_vectorizer.pk",
                                   num_eng=25,num_neng=5)

In [None]:
# Bin every user timeline into "3MS" bins between the two dates and extract
# the features of each bin.
# NOTE(review): `filtered_user_paths` is not defined in any cell visible here —
# confirm it is created earlier so the notebook survives Restart & Run All.
users_binned  =  BinUsers(filtered_user_paths,
                          bin_freq="3MS",
                          start_date='01/01/2015',
                          end_date='10/01/2021').bin_users(feats_extractor)

## Train, Val, Test Split

In [None]:

def trunc_datetime(someDate):
    """
    Snap a timestamp to the first day of its month at midnight and drop its
    timezone, so that dates can be compared at month granularity.
    """
    month_start = someDate.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
    naive_month_start = month_start.tz_localize(None)
    return naive_month_start

def split_all_user_sequences(all_sequences):
    """
    * all_sequences: list of list of dicts, where each list item represents a user
    * and the user list contains dicts where each of them represents a bin/timestep
      (each dict needs a "date_range" = (start, end) pair of timestamps)
    * returns four lists (run 1 .. run 4), each holding one bin list per user;
      run k keeps the bins that lie fully inside its 4 year window
      (2015-2019, 2016-2020, 2017-2021, 2018-2022)
    * raises AssertionError when a user does not have the same number of bins
      in runs 1-3, or when run 4 is not exactly one bin shorter than run 1
    """
    # "%m" is the month directive; the previous "%M" means *minute* and only
    # produced the right dates because the month is January and the minute
    # is zeroed again by trunc_datetime.
    date_format = '%Y-%m-%d'
    window_bounds = {1:("2015-01-01","2019-01-01"),
                     2:("2016-01-01","2020-01-01"),
                     3:("2017-01-01","2021-01-01"),
                     4:("2018-01-01","2022-01-01")}
    dates = {run:(trunc_datetime(pd.to_datetime(lower,format=date_format)),
                  trunc_datetime(pd.to_datetime(upper,format=date_format)))
             for run,(lower,upper) in window_bounds.items()}

    run_obs = {run:[] for run in dates}

    for user_sequence in tqdm.tqdm(all_sequences):

        binned = {run:[] for run in dates}

        for user_bin_dict in user_sequence:

            start_date = trunc_datetime(user_bin_dict["date_range"][0])
            end_date = trunc_datetime(user_bin_dict["date_range"][1])

            # windows overlap, so one bin can belong to several runs
            for run,(lower,upper) in dates.items():
                if start_date >= lower and end_date <= upper:
                    binned[run].append(user_bin_dict)

        assert len(binned[1]) == len(binned[2]) == len(binned[3])
        assert len(binned[1])-1 ==len(binned[4])

        for run in dates:
            run_obs[run].append(binned[run])

    return run_obs[1], run_obs[2], run_obs[3], run_obs[4]

def use_sliding_window(seq,window_size=9):
    """
    Return every contiguous slice of `seq` that is `window_size` long, moving
    one position at a time.

    * seq : sequence of timestep bins
    * a sequence shorter than window_size yields an empty list
    """
    num_windows = len(seq)-window_size+1
    return [seq[first:first+window_size] for first in range(num_windows)]


def get_train_val_test_split(user_sequences,train_val_split=[0.8,0.2]):
    """
    Split windowed user sequences into train / val / test.

    * user_sequences : one list of windows per user
    * train_val_split : [train fraction, val fraction]; only the train
      fraction is used, every remaining user goes to val

    * test : windows from index 4 onwards of EVERY user
    * train / val : the first 4 windows of the users sampled into each group

    The user assignment is reproducible: the global `random` module is
    seeded with 42 before sampling.

    Returns (train, val, test) and prints the number of users in each.
    """
    test = [useq[4:] for useq in user_sequences]

    num_users = len(user_sequences)
    train_size = int(train_val_split[0]*num_users)

    random.seed(42)
    train_user_indices = random.sample([i for i in range(num_users)],train_size)
    # set lookup keeps the complement linear instead of quadratic
    train_lookup = set(train_user_indices)
    val_user_indices = [j for j in range(num_users) if j not in train_lookup]

    train = [user_sequences[i][:4] for i in train_user_indices]
    val = [user_sequences[i][:4] for i in val_user_indices]

    print(f"No of users in Train : {len(train)}")
    print(f"No of users in Val : {len(val)}")
    print(f"No of users in Test : {len(test)}\n")

    return train, val, test

In [None]:
# Assign every user's bins to the four overlapping run windows.
run1_split, run2_split, run3_split, run4_split = split_all_user_sequences(users_binned)

In [None]:
# Expand every user's bin sequence of each run into overlapping windows of 9 bins.
run1_split_sw, run2_split_sw, run3_split_sw, run4_split_sw = [
    [use_sliding_window(useq,window_size=9) for useq in run_split]
    for run_split in (run1_split, run2_split, run3_split, run4_split)
]

In [None]:
# Split the windowed sequences of every run into train / val / test.
# The same seed is used inside each call, so runs with the same number of
# users get the same user assignment.
r1_train, r1_val, r1_test = get_train_val_test_split(run1_split_sw,train_val_split=[0.8,0.2])
r2_train, r2_val, r2_test = get_train_val_test_split(run2_split_sw,train_val_split=[0.8,0.2])
r3_train, r3_val, r3_test = get_train_val_test_split(run3_split_sw,train_val_split=[0.8,0.2])
r4_train, r4_val, r4_test = get_train_val_test_split(run4_split_sw,train_val_split=[0.8,0.2])

In [None]:
# Merge the per-user window lists of every split into one flat list of windows.
# NOTE: this cell rebinds its own inputs, so running it twice flattens twice.
r1_train, r1_val, r1_test = [flatten(part) for part in (r1_train, r1_val, r1_test)]

r2_train, r2_val, r2_test = [flatten(part) for part in (r2_train, r2_val, r2_test)]

r3_train, r3_val, r3_test = [flatten(part) for part in (r3_train, r3_val, r3_test)]

r4_train, r4_val, r4_test = [flatten(part) for part in (r4_train, r4_val, r4_test)]

## Postprocessing Train, Val, Test - Precomputing Features and saving to Disk

In [None]:
def zscore_obs(sequence_counts):
    """
    Z-score a sequence using statistics of all timesteps except the last.

    * means / std_devs are computed over sequence_counts[:-1] (axis 0)
    * the whole sequence, including the last timestep, is normalized with them
    * nan and +/-inf results (e.g. zero std dev) are replaced by 0

    Returns (normalized sequence, means, std_devs).
    """
    full_seq = np.array(sequence_counts)
    history = full_seq[:-1]

    means = np.mean(history,axis=0)
    std_devs = np.std(history,axis=0)

    scaled = (full_seq - means)/std_devs
    normalized_seq = np.where(np.isfinite(scaled),scaled,0)

    return normalized_seq , means, std_devs

def one_hot_encode_quarters(quarter_ids):
    """
    One-hot encode quarter ids into 4-slot vectors.

    * quarter id q sets slot q-1 (1 -> [1,0,0,0], ..., 4 -> [0,0,0,1])
    * returns one vector per id, in input order
    """
    def encode(quarter):
        slots = [0]*4
        slots[quarter-1]=1
        return slots

    return [encode(q) for q in quarter_ids]

def combine_across_time(sequence_mat):
    """
    (timesteps,feature dim) -> (feature dim,)

    * Sum every feature over all timesteps except the last one
    * Return 1 where that sum is greater than 1, else 0

    NOTE(review): the threshold in the code is "> 1" although this function
    was previously documented as "> 0" — confirm which one is intended.
    """
    history = sequence_mat[:-1,:]
    totals = np.sum(history,axis=0)
    exceeds_one = totals > 1
    return exceeds_one.astype(int)

def get_bin_masks(text_sequence):
    """
    Flag the rows of a 2-D token tensor whose values sum to the row width.

    * text_sequence : tensor of shape (rows, width)
    * returns a float tensor with 1.0 for flagged rows and 0.0 otherwise
    * NOTE(review): presumably a flagged row is an all-padding (empty) bin
      with pad token id 1 — confirm against the tokenizer

    The loop now iterates over `range(rows)`; it used to iterate over the
    integer `shape[0]` itself, which raises a TypeError.
    """
    row_width = text_sequence.shape[-1]

    bin_masks = []
    for i in range(text_sequence.shape[0]):
        row_total = torch.sum(text_sequence[i,:])
        bin_masks.append(bool(row_total==row_width))

    return torch.Tensor(bin_masks)

def remove_empty_bins(text_sequence,attnmask):
    """
    Drop the rows whose token values sum to the row width, together with the
    matching attention-mask rows.

    * text_sequence : tensor of shape (rows, width)
    * attnmask : tensor with the same number of rows
    * returns (filtered_text, filtered_attn); both keep the original row
      order and have 0 rows when every row is dropped
    * NOTE(review): presumably a dropped row is an all-padding (empty) bin
      with pad token id 1 — confirm against the tokenizer

    Fixes: the previous version iterated over the integer `shape[0]`, called
    `unsqueeze()` without a dimension and misspelled it once, so it could
    never run; it also mixed list and tensor return types.
    """
    row_width = text_sequence.shape[-1]
    # True for rows to keep
    keep_rows = torch.sum(text_sequence,dim=-1) != row_width

    filtered_text = text_sequence[keep_rows]
    filtered_attn = attnmask[keep_rows]

    return filtered_text, filtered_attn
    
    
def get_reps(eng_text, eng_attn, model,dim=1):
    """
    Run `model(eng_text, eng_attn)` with gradient tracking disabled and
    return the `last_hidden_state` of its output.

    * dim : accepted but not used
    """
    with torch.no_grad():
        model_out = model(eng_text,eng_attn)

    hidden_states = model_out.last_hidden_state
    return hidden_states
    
    
def get_text_reps_cls_text(tokens,attn,bin_masks,model):
    """
    Placeholder - not implemented.

    * Accepts (tokens, attn, bin_masks, model) but performs no work
    * Always returns None
    """
    return None
    
def get_text_reps_cls(tokens,attn,model,ttype=1):
    """
    Encode tokenised tweets with `model` and pool the output per time step.
    The last time step is dropped in both modes before encoding.

    type 1 (ttype == 1):
    * tokens / attn are indexed as (timesteps, num_tweets, max_tweetlength)
    * every tweet is encoded on its own in a single forward pass
    * per tweet: first-position (cls) vector concatenated with the mean of
      the remaining token vectors -> 2 * hidden size
    * per time step: mean of these fused tweet vectors
    * result: {"mean": (timesteps-1, 2 * hidden size)}

    type 2 (any other ttype):
    * tokens / attn are indexed as (timesteps, sequence length)
    * result: {"cls": first-position vector,
               "sum" / "min" / "max" / "mean": pooled over the remaining positions},
      each of shape (timesteps-1, hidden size)

    * `model(tokens, attn)` must return an object with a `last_hidden_state`
      attribute indexed as (batch, position, hidden size)
    * the hidden size is read from the model output, so encoders of any
      width work (it used to be fixed at 768)
    * all returned tensors are detached and moved to the CPU
    * NOTE(review): the means run over every position, padding included;
      `attn` is only forwarded to the model - confirm this is intended
    """
    res_dict = {}

    if ttype==1:

        tokens = tokens[:-1,:,:]
        attn = attn[:-1,:,:]

        num_timesteps = tokens.shape[0]
        num_tweets = tokens.shape[1]
        max_tweetlength = tokens.shape[2]

        # flatten (timesteps, tweets) into one batch axis for the encoder
        tokens = torch.reshape(tokens,(num_timesteps*num_tweets,max_tweetlength))
        attn = torch.reshape(attn,(num_timesteps*num_tweets,max_tweetlength))

        with torch.set_grad_enabled(False):
            out = model(tokens,attn)
            hidden_rep = out.last_hidden_state

        # take the width from the output instead of assuming a 768-wide encoder
        hidden_dim = hidden_rep.shape[-1]
        hidden_rep = torch.reshape(hidden_rep,(num_timesteps,num_tweets,max_tweetlength,hidden_dim))

        cls_reps = hidden_rep[:,:,0,:] # first position -> (timesteps,num_tweets,hidden)
        other_token_reps = hidden_rep[:,:,1:,:] # remaining positions -> (timesteps,num_tweets,max_tweetlength-1,hidden)
        avg_token_reps = torch.mean(other_token_reps,dim=2) # -> (timesteps,num_tweets,hidden)
        fused_reps = torch.concat([cls_reps,avg_token_reps],dim=-1) # -> (timesteps,num_tweets,2*hidden)

        # aggregate the tweets inside each time step
        res_dict["mean"] = torch.mean(fused_reps,dim=1).detach().cpu()

    else:
        tokens = tokens[:-1,:]
        attn = attn[:-1,:]

        with torch.set_grad_enabled(False):
            out = model(tokens,attn)
            hidden_rep = out.last_hidden_state

        # one token sequence per time step -> (timesteps,sequence length,hidden)
        cls_rep = hidden_rep[:,0,:]
        other_token_reps = hidden_rep[:,1:,:]

        res_dict["cls"] = cls_rep.detach().cpu()
        res_dict["sum"] = torch.sum(other_token_reps,dim=1).detach().cpu()
        res_dict["min"] = torch.min(other_token_reps,dim=1)[0].detach().cpu()
        res_dict["max"] = torch.max(other_token_reps,dim=1)[0].detach().cpu()
        res_dict["mean"] = torch.mean(other_token_reps,dim=1).detach().cpu()

    return res_dict

def get_hash_mens_reps(tokens,attn,model):
    """
    Encode a batch of tokenised items and pool them into a single vector.

    * tokens, attn : passed positionally to `model`; the output's
      `last_hidden_state` is indexed as (item, position, hidden)
    * per item: first-position (cls) vector concatenated with the mean of
      the remaining position vectors -> (items, 2 * hidden)
    * result: mean of these fused vectors over the items -> (2 * hidden,)
    * gradients are disabled; the result stays on the device the model produced it on
    * NOTE(review): callers presumably pass (50, 8) token matrices of top
      hashtags / mentions - confirm
    """
    with torch.set_grad_enabled(False):

        token_states = model(tokens,attn).last_hidden_state

        first_position = token_states[:,0,:]
        rest_mean = token_states[:,1:,:].mean(dim=1)

        per_item = torch.concat([first_position,rest_mean],dim=-1)

        return per_item.mean(dim=0)

        
        

def extract_sequence_feats(dataset, model_path="/home/kshvaram/bert_tweet/twhin-bert-base/"):
    """
    Turn every user sequence in `dataset` into a dictionary of model-ready features.

    * dataset : iterable of sequences; each sequence is an iterable of
      per-time-step observation dicts (keys read below)
    * model_path : checkpoint handed to `AutoModel.from_pretrained`.
      NOTE(review): the default is a machine-specific absolute path kept for
      backward compatibility - pass your own location when running elsewhere
    * The model is wrapped in `nn.DataParallel`, moved to CUDA when available
      (CPU otherwise) and put in eval mode

    For each sequence:
    * all steps but the last form the model input; the last step supplies
      the label ("label_zscored", "label_original") and the output time features
    * count features are normalised with `zscore_obs`
    * tweet text is encoded with `get_text_reps_cls(..., ttype=1)`
    * top hashtags / mentions are encoded from the second-to-last step
      (i.e. the last input step) with `get_hash_mens_reps`
    * every tensor stored in the result lives on the CPU, so the list can be
      pickled and loaded on a machine without a GPU

    Returns a list with one feature dict per sequence, in dataset order.
    """

    seq_feats = []

    model = AutoModel.from_pretrained(model_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = nn.DataParallel(model)
    model.to(device)

    model.eval()

    for seqd in tqdm.tqdm(dataset):

        eng_counts = []
        eng_type_counts = []
        eng_pm_counts = []
        neng_pm_counts = []
        eng_drt_counts = []
        tt_counts = []
        eng_ns = []
        mentions = []
        hashtags = []

        quarter_ids = []
        date_range = []
        usernames = []

        eng1_text = []
        eng1_attn = []

        neng1_text = []
        neng1_attn = []

        top_mens_tokens = []
        top_hash_tokens = []
        top_mens_attn = []
        top_hash_attn = []

        # gather every per-step feature into one list per feature
        for obs in seqd:

            eng_counts.append(obs["engagement_counts"])
            eng_type_counts.append(obs["engagement_types"])
            eng_pm_counts.append(obs["engagement_publicmetrics"])
            neng_pm_counts.append(obs["nonengagement_publicmetrics"])
            eng_drt_counts.append(obs["engagement_drt"])
            tt_counts.append(obs["total_tweet_counts"])
            eng_ns.append(obs["engagement_nsdist"])
            mentions.append(obs["mentions_usertime"])
            hashtags.append(obs["hashtags_usertime"])
            quarter_ids.append(obs["quarter"])
            date_range.append(obs["date_range"])
            usernames.append(obs["username"])

            # unsqueeze(0) adds the time-step axis the lists are concatenated on
            eng1_text.append(obs["eng_text_t1"]["input_ids"].unsqueeze(0))
            eng1_attn.append(obs["eng_text_t1"]["attention_mask"].unsqueeze(0))

            neng1_text.append(obs["neng_text_t1"]["input_ids"].unsqueeze(0))
            neng1_attn.append(obs["neng_text_t1"]["attention_mask"].unsqueeze(0))

            top_mens_tokens.append(obs["top_mentions"]["input_ids"].unsqueeze(0))
            top_hash_tokens.append(obs["top_hashes"]["input_ids"].unsqueeze(0))
            top_mens_attn.append(obs["top_mentions"]["attention_mask"].unsqueeze(0))
            top_hash_attn.append(obs["top_hashes"]["attention_mask"].unsqueeze(0))

        # engagement counts - 1x7
        eng_counts_norm, c_means, c_std_dev = zscore_obs(eng_counts)

        # engagement type counts - url,mentions,both - 1x3
        eng_type_norm,_,_ = zscore_obs(eng_type_counts)

        # engagement and non-engagement public metrics - retweeted,liked,shared,quote - 1x4
        eng_pm_norm,_,_ = zscore_obs(eng_pm_counts)
        neng_pm_norm,_,_ = zscore_obs(neng_pm_counts)

        # total_tweet_count in timestep bin - 1x1
        tt_counts_norm,_,_ = zscore_obs(tt_counts)
        tt_counts_norm = tt_counts_norm.reshape(-1,1)

        # engagement drt counts per stance - 1x7
        eng_drt_counts_norm,_,_ = zscore_obs(eng_drt_counts)

        # engaged news sources - binary vec - 1x522
        eng_ns = np.array(eng_ns)

        # top mentioned user accounts - binary vec - 1x5000
        mentions = np.array(mentions)

        # top used hashtags - binary vec - 1x5000
        hashtags = np.array(hashtags)

        # time features - quarter one hot encoded - 1x4
        quarter_ids_enc = np.array(one_hot_encode_quarters(quarter_ids))
        quarter_ids = np.array(quarter_ids)

        # text representations
        eng1_text = torch.concat(eng1_text,dim=0).to(device)
        eng1_attn = torch.concat(eng1_attn,dim=0).to(device)

        neng1_text = torch.concat(neng1_text,dim=0).to(device)
        neng1_attn = torch.concat(neng1_attn,dim=0).to(device)

        eng1 = get_text_reps_cls(eng1_text,eng1_attn,model,ttype=1)
        neng1 = get_text_reps_cls(neng1_text,neng1_attn,model,ttype=1)

        # top hashes, mentions: index -2 is the last time step of the input
        # sequence (index -1 is the step being predicted)
        hash_toks = top_hash_tokens[-2].to(device)
        hash_attn = top_hash_attn[-2].to(device)
        mens_toks = top_mens_tokens[-2].to(device)
        mens_attn = top_mens_attn[-2].to(device)

        # moved to the CPU like the other text features, so the stored result
        # neither pins GPU memory nor needs CUDA to be unpickled
        top_hash_reps = get_hash_mens_reps(hash_toks.squeeze(0),
                                           hash_attn.squeeze(0),
                                           model).detach().cpu()

        top_mens_reps = get_hash_mens_reps(mens_toks.squeeze(0),
                                           mens_attn.squeeze(0),
                                           model).detach().cpu()

        # drop the references to the device-side inputs before the next sequence
        del eng1_text, eng1_attn, neng1_text, neng1_attn
        del hash_toks, hash_attn, mens_toks, mens_attn

        output_dict = {"eng_counts":eng_counts_norm[:-1,:].astype(np.float16),
                       "eng_type":eng_type_norm[:-1,:].astype(np.float16),
                       "eng_pm":eng_pm_norm[:-1,:].astype(np.float16),
                       "neng_pm":neng_pm_norm[:-1,:].astype(np.float16),
                       "total_count":tt_counts_norm[:-1,:].astype(np.float16),
                       "drt_count":eng_drt_counts_norm[:-1,:].astype(np.float16),
                       "ns_feat":eng_ns[:-1,:].astype(np.float16),
                       "mentions_feat":mentions[:-1,:].astype(np.float16),
                       "hashtag_feat":hashtags[:-1,:].astype(np.float16),
                       "engagement_text_feats_1":eng1,
                       "nengagement_text_feats_1":neng1,
                       "tophashes":top_hash_reps,
                       "topmentions":top_mens_reps,
                       "input_time_feat":quarter_ids_enc[:-1,:].astype(np.float16),
                       "output_time_feat":quarter_ids_enc[-1,:].astype(np.float16),
                       "output_quarter":quarter_ids[-1],
                       "label_zscored":eng_counts_norm[-1,:].astype(np.float16),
                       "label_original":np.array(eng_counts)[-1,:].astype(np.float16),
                       "means":c_means.astype(np.float16),
                       "stddevs":c_std_dev.astype(np.float16),
                       "original_counts":np.array(eng_counts),
                       "date_ranges":date_range,
                       "username":usernames}

        seq_feats.append(output_dict)

    return seq_feats

    

In [None]:
# Extract, pickle and release the r1 splits one at a time.
# NOTE: each name is rebound to its features and then deleted, so this cell
# cannot be re-run without first re-creating r1_train / r1_val / r1_test.

# open(..., "wb") does not create missing directories; make sure the target
# exists before the long-running extraction starts
os.makedirs("pickled_data", exist_ok=True)

r1_train = extract_sequence_feats(r1_train)

with open("pickled_data/r1_train.pkl","wb") as wp:
    pickle.dump(r1_train,wp)

del r1_train  # release the features before processing the next split

r1_val = extract_sequence_feats(r1_val)

with open("pickled_data/r1_val.pkl","wb") as wp:
    pickle.dump(r1_val,wp)

del r1_val

r1_test = extract_sequence_feats(r1_test)

with open("pickled_data/r1_test.pkl","wb") as wp:
    pickle.dump(r1_test,wp)

del r1_test

In [None]:
# Extract and pickle the r2 splits.
# NOTE: each name is rebound from the raw sequences to the extracted features,
# so re-running this cell would feed features back into the extractor.

# open(..., "wb") does not create missing directories; make sure the target
# exists before the long-running extraction starts
os.makedirs("pickled_data", exist_ok=True)

r2_train = extract_sequence_feats(r2_train)

with open("pickled_data/r2_train.pkl","wb") as wp:
    pickle.dump(r2_train,wp)

r2_val = extract_sequence_feats(r2_val)

with open("pickled_data/r2_val.pkl","wb") as wp:
    pickle.dump(r2_val,wp)

r2_test = extract_sequence_feats(r2_test)

with open("pickled_data/r2_test.pkl","wb") as wp:
    pickle.dump(r2_test,wp)

In [None]:
# Extract and pickle the r3 splits.
# NOTE: each name is rebound from the raw sequences to the extracted features,
# so re-running this cell would feed features back into the extractor.

# open(..., "wb") does not create missing directories; make sure the target
# exists before the long-running extraction starts
os.makedirs("pickled_data", exist_ok=True)

r3_train = extract_sequence_feats(r3_train)

with open("pickled_data/r3_train.pkl","wb") as wp:
    pickle.dump(r3_train,wp)

r3_val = extract_sequence_feats(r3_val)

with open("pickled_data/r3_val.pkl","wb") as wp:
    pickle.dump(r3_val,wp)

r3_test = extract_sequence_feats(r3_test)

with open("pickled_data/r3_test.pkl","wb") as wp:
    pickle.dump(r3_test,wp)

In [None]:
# Extract and pickle the r4 splits.
# NOTE: each name is rebound from the raw sequences to the extracted features,
# so re-running this cell would feed features back into the extractor.

# open(..., "wb") does not create missing directories; make sure the target
# exists before the long-running extraction starts
os.makedirs("pickled_data", exist_ok=True)

r4_train = extract_sequence_feats(r4_train)

with open("pickled_data/r4_train.pkl","wb") as wp:
    pickle.dump(r4_train,wp)

r4_val = extract_sequence_feats(r4_val)

with open("pickled_data/r4_val.pkl","wb") as wp:
    pickle.dump(r4_val,wp)

r4_test = extract_sequence_feats(r4_test)

with open("pickled_data/r4_test.pkl","wb") as wp:
    pickle.dump(r4_test,wp)