# In [None]:
#!/usr/bin/env python3
import csv
import datetime
import gensim
from gensim import corpora
from gensim.models import TfidfModel
from gensim.parsing.preprocessing import remove_stopwords
from gensim.utils import simple_preprocess
from gensim.models import Word2Vec
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from gensim.parsing.preprocessing import remove_stopwords
import json
import logging
import math 
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import numpy as np
import pandas as pd
import random
import re
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
from sklearn.neural_network import MLPClassifier
from sklearn import svm
from sklearn.svm import SVC
from sklearn.gaussian_process.kernels import RBF
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import KFold, StratifiedKFold, StratifiedShuffleSplit
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, classification_report, confusion_matrix
from sklearn.feature_extraction.text import TfidfVectorizer
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split,SequentialSampler, IterableDataset
import transformers
from transformers import BertTokenizer, BertModel, BertForSequenceClassification, AdamW, BertConfig, get_linear_schedule_with_warmup
import warnings

# NLTK corpora used below: 'words' (English vocabulary), 'punkt' (word_tokenize)
# and 'stopwords' (read into stop_words; previously never downloaded).
nltk.download('words')
nltk.download('punkt')
nltk.download('stopwords')

logging.basicConfig(level=logging.ERROR)
warnings.filterwarnings('ignore')

# Module-level globals read by the text-cleaning helpers
# (remove_non_english uses `words`, remove_stop_words uses `stop_words`).
words = set(nltk.corpus.words.words())
stop_words = set(stopwords.words('english'))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Prefer the second GPU when there is one; on CPU-only or single-GPU machines
# an unconditional set_device(1) would raise, so fall back to the default.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)

# Loss shared by the training and validation loops.
loss_fn = nn.CrossEntropyLoss()

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs (plus every CUDA device if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
        
def flat_accuracy(preds, labels):
    """Return the fraction of rows whose arg-max over axis 1 of *preds* equals *labels*.

    *preds* is a 2-D score array (one row per example); *labels* is an ndarray
    of any shape holding the same number of class indices.
    """
    expected = labels.ravel()
    predicted = np.argmax(preds, axis=1).ravel()
    correct = np.sum(predicted == expected)
    return correct / expected.size

def format_time(elapsed):
    """Render a duration given in seconds as ``h:mm:ss`` (rounded to whole seconds)."""
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

def remove_digit(text):
    """Return *text* with every run of decimal digits deleted."""
    digit_runs = re.compile(r'\d+')
    return digit_runs.sub('', text)

def remove_non_english(text, vocabulary=None):
    """Keep only tokens that are English vocabulary words or are non-alphabetic.

    The text is split with ``nltk.wordpunct_tokenize`` and the kept tokens are
    re-joined with single spaces.

    Args:
        text: input string.
        vocabulary: set of lower-case words to accept; defaults to the
            module-level ``words`` set (NLTK's English word list).

    Returns:
        The filtered, space-joined string.
    """
    vocab = words if vocabulary is None else vocabulary
    # The old test `w in text` was always true (every token comes from text),
    # so nothing was ever removed; check the vocabulary instead.
    kept = [w for w in nltk.wordpunct_tokenize(text) if w.lower() in vocab or not w.isalpha()]
    return ' '.join(kept)

def remove_special_chars(text):
    """Replace every run of digits and/or non-word characters with one space."""
    special_runs = re.compile(r"(\d|\W)+")
    return special_runs.sub(" ", text)

def remove_shortwords(text):
    """Drop tokens of at most two characters; re-join the rest with single spaces."""
    kept = []
    for token in word_tokenize(text):
        if len(token) <= 2:
            continue
        kept.append(token)
    return ' '.join(kept)

def remove_nonUTF8(data):
    """Return *data* with every character that cannot be encoded as UTF-8 removed.

    For a Python ``str`` the only such characters are lone surrogate code
    points; all other text comes back unchanged.

    Raises:
        TypeError: if *data* is not a ``str``.
    """
    # 'ignore' must apply while *encoding*: encoded bytes are always valid UTF-8,
    # so ignoring errors on decode removed nothing and lone surrogates raised.
    return bytes(data, 'utf-8', 'ignore').decode('utf-8')

def preprocess(df):
    """Clean ``df['sentence']`` in place and return *df*.

    Steps, in order:
    - strip URLs, retweet markers, @-mentions and underscore fragments;
    - collapse repeated spaces and replace the ``&amp;``/``&lt;``/``&gt;`` entities;
    - separate punctuation from adjacent words, then lowercase and strip;
    - remove gensim stop words, digits, non-English tokens, special characters,
      non-UTF-8 characters and quote characters, and finally drop short tokens.
    """
    # (pattern, replacement) pairs applied in order. regex=True is explicit
    # because pandas >= 2.0 treats the pattern as a literal string by default.
    substitutions = [
        (r'http(\S)+', r''),
        (r'http ...', r''),
        (r'(RT|rt)[ ]*@[ ]*[\S]+', r''),
        (r'@[\S]+', r''),
        (r'_[\S]?', r''),
        # '{2, }' (with a space) is literal text in Python's re; '{2,}' is the quantifier.
        (r'[ ]{2,}', r' '),
        (r'&amp;?', r'and'),
        (r'&lt;', r'<'),
        (r'&gt;', r'>'),
        (r'([\w\d]+)([^\w\d ]+)', r'\1 \2'),
        (r'([^\w\d ]+)([\w\d]+)', r'\1 \2'),
    ]
    sentences = df['sentence']
    for pattern, replacement in substitutions:
        sentences = sentences.str.replace(pattern, replacement, regex=True)
    sentences = sentences.str.lower().str.strip()
    for step in (remove_stopwords, remove_digit, remove_non_english,
                 remove_special_chars, remove_nonUTF8):
        sentences = sentences.apply(step)
    sentences = sentences.str.replace("'", "", regex=False).str.replace('"', "", regex=False)
    sentences = sentences.apply(remove_shortwords)
    df['sentence'] = sentences
    return df

def bow_features(fullsetdf):
    """Return a dense TF-IDF bag-of-words DataFrame built from ``fullsetdf['sentence']``.

    *fullsetdf* is cleaned in place by ``preprocess``. It also gains a
    ``tokenized_text`` column holding gensim ``simple_preprocess`` tokens.

    Returns:
        DataFrame with one row per input row (default RangeIndex) and one
        column per vocabulary term kept by the vectorizer.
    """
    df = preprocess(fullsetdf)
    df['tokenized_text'] = [simple_preprocess(line, deacc=True) for line in df['sentence']]
    # max_df is an int, so it is an absolute document count, not a proportion.
    vectorizer = TfidfVectorizer(analyzer='word', strip_accents='ascii', smooth_idf=True,
                                 use_idf=True, max_df=10000, min_df=5, stop_words='english')
    X = vectorizer.fit_transform(df['sentence'])
    # get_feature_names() was removed in scikit-learn 1.2; use its replacement when present.
    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = vectorizer.get_feature_names_out()
    else:
        feature_names = vectorizer.get_feature_names()
    return pd.DataFrame(X.toarray(), columns=feature_names)

def initialize_model(dataloader, epochs=10, lr=0.001, num_warmup_steps=10000):
    """Build a fresh classifier, its SGD optimizer and a linear warmup/decay schedule.

    Args:
        dataloader: training loader; ``len(dataloader) * epochs`` is the total
            number of scheduler steps.
        epochs: number of training epochs the schedule spans.
        lr: peak learning rate for SGD (momentum 0.9).
        num_warmup_steps: linear warmup length in optimizer steps. If this
            exceeds the total step count, the schedule never reaches ``lr``.

    Returns:
        ``(model, optimizer, scheduler)``, with the model on the global ``device``.
    """
    model = customBertBinaryClassifier()
    # Follow the module-level device choice instead of assuming CUDA.
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    total_steps = len(dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,
                                                num_training_steps=total_steps)
    return model, optimizer, scheduler

def remove_special_chars(text):
    """Collapse each run of digits / non-word characters into a single space."""
    separator = " "
    return re.sub(r"(\d|\W)+", separator, text)

def remove_punctuation(text):
    """Return *text* with every character in ``string.punctuation`` deleted."""
    # `string` was used without ever being imported (NameError on first call).
    import string
    return text.translate(str.maketrans('', '', string.punctuation))

def remove_shortwords(text):
    """Keep only tokens longer than two characters, joined by single spaces."""
    return ' '.join(filter(lambda token: len(token) > 2, word_tokenize(text)))

def remove_nonUTF8(data):
    """Return *data* with every character that cannot be encoded as UTF-8 removed.

    For a Python ``str`` the only such characters are lone surrogate code
    points; all other text comes back unchanged.

    Raises:
        TypeError: if *data* is not a ``str``.
    """
    # 'ignore' must apply while *encoding*: encoded bytes are always valid UTF-8,
    # so ignoring errors on decode removed nothing and lone surrogates raised.
    return bytes(data, 'utf-8', 'ignore').decode('utf-8')

def remove_stop_words(text):
    """Drop tokens found in the module-level ``stop_words`` set; re-join with spaces."""
    return ' '.join(filter(lambda token: token not in stop_words, word_tokenize(text)))

def stem_words(text):
    """Porter-stem every token of *text* and re-join the stems with single spaces."""
    # PorterStemmer was used without ever being imported (NameError on first call).
    from nltk.stem import PorterStemmer
    stemmer = PorterStemmer()
    # Local name avoids shadowing the module-level `words` vocabulary set.
    return ' '.join(stemmer.stem(token) for token in word_tokenize(text))

def remove_digit(text):
    """Delete every sequence of decimal digits from *text*."""
    empty = ''
    return re.sub(r'\d+', empty, text)

def remove_non_english(text, vocabulary=None):
    """Keep only tokens that are English vocabulary words or are non-alphabetic.

    The text is split with ``nltk.wordpunct_tokenize`` and the kept tokens are
    re-joined with single spaces.

    Args:
        text: input string.
        vocabulary: set of lower-case words to accept; defaults to the
            module-level ``words`` set (NLTK's English word list).

    Returns:
        The filtered, space-joined string.
    """
    vocab = words if vocabulary is None else vocabulary
    # The old test `w in text` was always true (every token comes from text),
    # so nothing was ever removed; check the vocabulary instead.
    kept = [w for w in nltk.wordpunct_tokenize(text) if w.lower() in vocab or not w.isalpha()]
    return ' '.join(kept)

def _train_one_epoch(model, optimizer, scheduler, dataloader, epoch_i):
    """Run one training pass over *dataloader* and return that epoch's stats dict."""
    t0 = time.time()
    total_train_loss = 0
    model.train()
    for batch in dataloader:
        b_input_ids, b_attn_mask, hand_features, b_labels = tuple(t.to(device) for t in batch)
        optimizer.zero_grad()
        logits = model(b_input_ids, b_attn_mask, hand_features)
        loss = loss_fn(logits, b_labels)
        total_train_loss += loss.item()
        loss.backward()
        # Clip the global gradient norm to 1.0 before each update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
    avg_train_loss = total_train_loss / len(dataloader)
    training_time = format_time(time.time() - t0)
    return {'epoch': epoch_i + 1, 'Training Loss': avg_train_loss, 'Training Time': training_time}


def _evaluate(model, dataloader):
    """Predict on *dataloader*; return ``(true_labels, predictions, mean_loss)``."""
    model.eval()
    true_labels, predictions = [], []
    total_val_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            b_input_ids, b_attn_mask, hand_features, b_labels = tuple(t.to(device) for t in batch)
            logits = model(b_input_ids, b_attn_mask, hand_features)
            total_val_loss += loss_fn(logits, b_labels).item()
            predictions.append(np.argmax(logits.cpu().numpy(), axis=1).flatten())
            true_labels.append(b_labels.cpu().numpy().flatten())
    return np.concatenate(true_labels), np.concatenate(predictions), total_val_loss / len(dataloader)


def train_and_evalCV(model, dataset, cv, epochs, batch_size):
    """Cross-validate a freshly initialised classifier with ``cv`` KFold splits.

    KFold is used without shuffling, so every fold is a contiguous slice of the
    dataset in its stored order. For each fold a new model, optimizer and
    scheduler come from ``initialize_model``; the model is trained for
    ``epochs`` epochs and then evaluated on the held-out slice. Per-epoch
    stats, per-fold macro P/R/F and a classification report are printed,
    followed by the macro P/R/F averaged over folds.

    Args:
        model: ignored, since a new model is built for every fold; kept so
            existing callers keep working.
        dataset: map-style dataset yielding
            ``(input_ids, attention_mask, hand_features, label)`` tensors.
        cv: number of folds.
        epochs: training epochs per fold.
        batch_size: mini-batch size for both training and validation.

    Returns:
        List of per-fold ``precision_recall_fscore_support(..., average='macro')``
        tuples ``(precision, recall, fscore, support)``.
    """
    training_stats = []
    error_stats = []
    stats = []
    kf = KFold(n_splits=cv)
    for train_index, test_index in kf.split(dataset):
        train_dataset = torch.utils.data.Subset(dataset, train_index)
        val_dataset = torch.utils.data.Subset(dataset, test_index)
        train_dataloader = DataLoader(train_dataset, sampler=SequentialSampler(train_dataset), batch_size=batch_size)
        val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size)
        model, optimizer, scheduler = initialize_model(train_dataloader, epochs)
        model.to(device)
        for epoch_i in range(epochs):
            epoch_stats = _train_one_epoch(model, optimizer, scheduler, train_dataloader, epoch_i)
            training_stats.append(epoch_stats)
            print(epoch_stats)
        true_labels, predictions, avg_val_loss = _evaluate(model, val_dataloader)
        error_stats.append({'test_index': test_index, 'true_labels': true_labels, 'predictions': predictions})
        fold_scores = precision_recall_fscore_support(true_labels, predictions, average='macro')
        stats.append(fold_scores)
        print({'Validation Loss': avg_val_loss})
        print(fold_scores)
        print(classification_report(true_labels, predictions))
    print()
    print(stats)
    print()
    precisions = [scores[0] for scores in stats]
    recalls = [scores[1] for scores in stats]
    fscores = [scores[2] for scores in stats]
    # Scores are fractions in [0, 1]; the old extra "/100.0" shrank the reported spread 100x.
    print("P: {:.3f}, R: {:.3f}, F: {:.3f} (+/- {:.3f}) ".format(
        np.mean(precisions), np.mean(recalls), np.mean(fscores), np.std(fscores) * 2))
    return stats
    
class FullDataset():
    """Map-style dataset of cleaned tweets, 12 hand-crafted features and labels.

    Loads a tab-separated file and drops its first column. That column is
    presumably a saved pandas index such as ``Unnamed: 0``; verify this
    against the data files. The ``sentence`` column is then cleaned, and the
    ``HAND_FEATURES`` columns are kept as floats. ``__getitem__`` returns
    ``(input_ids, attention_mask, hand_features, label)`` tensors on ``device``.

    Args:
        filename: path of the TSV file.
        name: dataset identifier, one of ``'crisismmd'``, ``'covid'``,
            ``'crisislext6'`` or ``'crisislext26'``.

    Raises:
        ValueError: if *name* is not one of the known identifiers.
    """

    # Hand-crafted feature columns, in the order they are fed to the model.
    HAND_FEATURES = ['nchars', 'nwords', 'bhash', 'nhash', 'blink', 'nlink',
                     'bat', 'nat', 'brt', 'bslang', 'bintj', 'tlex']

    # Per dataset: True = disable quoting and skip malformed rows when parsing.
    _LENIENT_PARSING = {
        'crisismmd': True,
        'covid': True,
        'crisislext6': False,
        'crisislext26': False,
    }

    # (pattern, replacement) regex pairs applied in order to each sentence.
    _SUBSTITUTIONS = [
        (r'http(\S)+', r''),
        (r'http ...', r''),
        (r'(RT|rt)[ ]*@[ ]*[\S]+', r''),
        (r'@[\S]+', r''),
        (r'_[\S]?', r''),
        # '{2, }' (with a space) is literal text in Python's re; '{2,}' is the quantifier.
        (r'[ ]{2,}', r' '),
        (r'&amp;?', r'and'),
        (r'&lt;', r'<'),
        (r'&gt;', r'>'),
        (r'([\w\d]+)([^\w\d ]+)', r'\1 \2'),
        (r'([^\w\d ]+)([\w\d]+)', r'\1 \2'),
    ]

    def __init__(self, filename, name):
        if name not in self._LENIENT_PARSING:
            raise ValueError('unknown dataset name: {!r}'.format(name))
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        read_kwargs = {'delimiter': '\t', 'encoding': 'utf-8', 'lineterminator': '\n'}
        if self._LENIENT_PARSING[name]:
            read_kwargs['quoting'] = csv.QUOTE_NONE
            read_kwargs.update(self._skip_bad_lines_option())
        self.df = pd.read_csv(filename, **read_kwargs)
        # Raw sentences, captured before the column is replaced by its cleaned version.
        self.sentences = self.df['sentence']
        self.df.drop(self.df.columns[[0]], axis=1, inplace=True)
        self.labels = self.df['label'].values
        self.df['sentence'] = self._clean_sentences(self.df['sentence'])
        self.hand_crafted_features = self.df[self.HAND_FEATURES]
        self.hand_crafted_features_DF = self.hand_crafted_features.astype(float)
        # Fixed padding/truncation length passed to encode_plus.
        self.maxlen = 80

    @staticmethod
    def _skip_bad_lines_option():
        """Return read_csv kwargs that skip malformed rows on old and new pandas."""
        major, minor = (int(part) for part in pd.__version__.split('.')[:2])
        # error_bad_lines was deprecated in pandas 1.3 and removed in 2.0.
        if (major, minor) >= (1, 3):
            return {'on_bad_lines': 'skip'}
        return {'error_bad_lines': False}

    @classmethod
    def _clean_sentences(cls, sentences):
        """Return a cleaned copy of the *sentences* Series.

        Steps, in order:
        - apply the regex substitutions, then lowercase and strip;
        - remove special characters, digits and gensim stop words;
        - remove quote characters;
        - remove non-English tokens, non-UTF-8 characters and short tokens.
        """
        for pattern, replacement in cls._SUBSTITUTIONS:
            # regex=True is explicit: pandas >= 2.0 defaults to literal matching.
            sentences = sentences.str.replace(pattern, replacement, regex=True)
        sentences = sentences.str.lower().str.strip()
        for step in (remove_special_chars, remove_digit, remove_stopwords):
            sentences = sentences.apply(step)
        sentences = sentences.str.replace("'", "", regex=False).str.replace('"', "", regex=False)
        for step in (remove_non_english, remove_nonUTF8, remove_shortwords):
            sentences = sentences.apply(step)
        return sentences

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, hand_features, label)`` tensors for row *idx*.

        Ids and mask are padded to ``self.maxlen``; all tensors are on ``device``.
        """
        # .loc is label-based; read_csv's default RangeIndex makes labels equal positions.
        sentence = self.df.loc[idx, 'sentence']
        label = self.df.loc[idx, 'label']
        # Hand torch a plain float array rather than a labelled pandas Series.
        h_features = self.hand_crafted_features_DF.loc[idx, :].to_numpy(dtype=float)
        h_tensor = torch.tensor(h_features).to(device)
        tokens = self.tokenizer.tokenize(sentence)
        if not tokens:
            # Sentences cleaned down to nothing still get a single (empty) token.
            tokens = ['']
        encoded_dict = self.tokenizer.encode_plus(tokens, add_special_tokens=True, max_length=self.maxlen,
                                                  pad_to_max_length=True, return_attention_mask=True)
        tokens_ids_tensor = torch.tensor(encoded_dict['input_ids']).to(device)
        attn_mask_tensor = torch.tensor(encoded_dict['attention_mask']).to(device)
        label_tensor = torch.tensor(label).to(device)
        return tokens_ids_tensor, attn_mask_tensor, h_tensor, label_tensor

class customBertBinaryClassifier(nn.Module):
    """Two-class classifier fusing BERT's first-token embedding with hand-crafted features.

    The position-0 vector of BERT's last hidden state goes through dropout and
    ``linear_B``; the hand-crafted feature vector goes through ``linear_H``.
    The two results are concatenated and mapped to 2 logits by ``classifier``.
    """
    # B_in/B_out: width of the BERT branch before/after linear_B (768 is the
    # hidden size of bert-base-uncased). H_in must equal the number of
    # hand-crafted features per example (FullDataset yields 12); H_out is the
    # width of that branch after linear_H.
    # NOTE: keep the submodule construction order: nn.Linear draws its initial
    # weights from the torch RNG, so reordering changes the seeded initialisation.
    def __init__(self, B_in = 768, B_out = 768, H_in = 12, H_out = 12):
        super(customBertBinaryClassifier, self).__init__()
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')
        C_in = B_out + H_out
        C_out = 2  # number of output classes (logits)
        self.linear_B = nn.Linear(B_in, B_out)
        self.linear_H = nn.Linear(H_in, H_out)
        self.drop = nn.Dropout(p=0.5)  # applied to the BERT branch only
        #self.norm = nn.LayerNorm(H_in,eps = 1e-12, elementwise_affine = True)
        self.classifier = nn.Linear(C_in, C_out)
    def forward(self, seq, attn_masks, hand_features):
        """Return (batch, 2) logits.

        seq / attn_masks: token ids and attention mask for BERT, presumably
        (batch, seq_len); confirm against the dataset. hand_features must be 2-D
        (batch, H_in), because it is concatenated with the 2-D BERT branch on dim 1.
        """
        outputs = self.bert_layer(input_ids=seq,attention_mask=attn_masks)
        # outputs[0] is the last hidden state; take the first token of every sequence.
        last_hidden_state_cls = outputs[0][:, 0, :]
        logits_B = self.drop(last_hidden_state_cls.float())
        logits_B = self.linear_B(logits_B.float())
        #logits_H = self.norm(hand_features.float())
        # Dataset features arrive as float64 (or int); cast to float32 for the Linear layer.
        logits_H = self.linear_H(hand_features.float())
        cat_features = torch.cat([logits_B.float(),logits_H.float()], dim=1)
        #feat = self.drop(cat_features.float())
        output = self.classifier(cat_features.float())
        return output

   
def main(split='subset'):
    """Cross-validate the BERT + hand-crafted-feature model on every dataset.

    Args:
        split: file variant to load for each dataset, i.e.
            ``/home/joao/<dataset>.<split>.tsv``. ``'subset'`` (the default)
            is the file the loop ended up using; ``'ORG'`` selects the full files.
    """
    batch_size = 16
    n_epochs = 20
    cv = 10
    for data in ['covid', 'crisislext6', 'crisislext26', 'crisismmd']:
        print("=== {} ===".format(data))
        # Training previously sat outside this loop (dedented to module level),
        # where `dataset` was undefined.
        dataset = FullDataset('/home/joao/{}.{}.tsv'.format(data, split), data)
        set_seed(42)    # Set seed for reproducibility
        # train_and_evalCV builds a fresh model for every fold and never uses its
        # model argument. A throw-away model built here would only hold (GPU)
        # memory for the whole run.
        train_and_evalCV(None, dataset, cv=cv, epochs=n_epochs, batch_size=batch_size)


main()

# In [None]:
###
### SUBSET containing USER based features
###
class SubsetDataset():
    """Map-style dataset of cleaned tweets, 16 hand-crafted/user features and labels.

    Loads a tab-separated file and drops its first column. That column is
    presumably a saved pandas index such as ``Unnamed: 0``; verify this
    against the data files. The ``sentence`` column is then cleaned, and the
    ``HAND_FEATURES`` columns (tweet features plus user features) are kept.
    ``__getitem__`` returns ``(input_ids, attention_mask, hand_features, label)``
    tensors on ``device``.

    Args:
        filename: path of the TSV file.
        name: dataset identifier, one of ``'crisismmd'``, ``'covid'``,
            ``'crisislext6'`` or ``'crisislext26'``.

    Raises:
        ValueError: if *name* is not one of the known identifiers.
    """

    # Hand-crafted + user feature columns, in the order they are fed to the model.
    HAND_FEATURES = ['nchars', 'nwords', 'bhash', 'nhash', 'blink', 'nlink', 'bat', 'nat',
                     'brt', 'bslang', 'bintj', 'tlex', 'usr_vrf', 'num_followers',
                     'num_friends', 'num_tweets']

    # Per dataset: True = disable quoting and skip malformed rows when parsing.
    _LENIENT_PARSING = {
        'crisismmd': True,
        'covid': True,
        'crisislext6': False,
        'crisislext26': False,
    }

    # (pattern, replacement) regex pairs applied in order to each sentence.
    _SUBSTITUTIONS = [
        (r'http(\S)+', r''),
        (r'http ...', r''),
        (r'(RT|rt)[ ]*@[ ]*[\S]+', r''),
        (r'@[\S]+', r''),
        (r'_[\S]?', r''),
        # '{2, }' (with a space) is literal text in Python's re; '{2,}' is the quantifier.
        (r'[ ]{2,}', r' '),
        (r'&amp;?', r'and'),
        (r'&lt;', r'<'),
        (r'&gt;', r'>'),
        (r'([\w\d]+)([^\w\d ]+)', r'\1 \2'),
        (r'([^\w\d ]+)([\w\d]+)', r'\1 \2'),
    ]

    # bert-base-uncased has 512 position embeddings; longer inputs cannot be encoded.
    _BERT_MAX_LEN = 512

    def __init__(self, filename, name):
        if name not in self._LENIENT_PARSING:
            raise ValueError('unknown dataset name: {!r}'.format(name))
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        read_kwargs = {'delimiter': '\t', 'encoding': 'utf-8', 'lineterminator': '\n'}
        if self._LENIENT_PARSING[name]:
            read_kwargs['quoting'] = csv.QUOTE_NONE
            read_kwargs.update(self._skip_bad_lines_option())
        self.df = pd.read_csv(filename, **read_kwargs)
        # Raw sentences, captured before the column is replaced by its cleaned version.
        self.sentences = self.df['sentence']
        self.df.drop(self.df.columns[[0]], axis=1, inplace=True)
        self.labels = self.df['label'].values
        self.df['sentence'] = self._clean_sentences(self.df['sentence'])
        self.hand_crafted_features = self.df[self.HAND_FEATURES]
        self.hand_crafted_features_DF = pd.DataFrame(self.hand_crafted_features)
        if name == 'covid':
            self.maxlen = 256
        else:
            # Longest raw sentence measured in BERT tokens, special tokens included.
            self.maxlen = max((len(self.tokenizer.encode(sent, add_special_tokens=True))
                               for sent in self.sentences), default=0)
        self.maxlen = min(self.maxlen, self._BERT_MAX_LEN)

    @staticmethod
    def _skip_bad_lines_option():
        """Return read_csv kwargs that skip malformed rows on old and new pandas."""
        major, minor = (int(part) for part in pd.__version__.split('.')[:2])
        # error_bad_lines was deprecated in pandas 1.3 and removed in 2.0.
        if (major, minor) >= (1, 3):
            return {'on_bad_lines': 'skip'}
        return {'error_bad_lines': False}

    @classmethod
    def _clean_sentences(cls, sentences):
        """Return a cleaned copy of the *sentences* Series.

        Steps, in order:
        - apply the regex substitutions, then lowercase and strip;
        - remove special characters, digits and gensim stop words;
        - remove quote characters;
        - remove non-English tokens, non-UTF-8 characters and short tokens.
        """
        for pattern, replacement in cls._SUBSTITUTIONS:
            # regex=True is explicit: pandas >= 2.0 defaults to literal matching.
            sentences = sentences.str.replace(pattern, replacement, regex=True)
        sentences = sentences.str.lower().str.strip()
        for step in (remove_special_chars, remove_digit, remove_stopwords):
            sentences = sentences.apply(step)
        sentences = sentences.str.replace("'", "", regex=False).str.replace('"', "", regex=False)
        for step in (remove_non_english, remove_nonUTF8, remove_shortwords):
            sentences = sentences.apply(step)
        return sentences

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, hand_features, label)`` tensors for row *idx*.

        Ids and mask are padded to ``self.maxlen``; all tensors are on ``device``.
        """
        # .loc is label-based; read_csv's default RangeIndex makes labels equal positions.
        sentence = self.df.loc[idx, 'sentence']
        label = self.df.loc[idx, 'label']
        # Convert explicitly to a float array so mixed-dtype rows (e.g. boolean
        # user flags) become a numeric tensor rather than a labelled Series.
        h_features = self.hand_crafted_features_DF.loc[idx, :].to_numpy(dtype=float)
        h_tensor = torch.tensor(h_features).to(device)
        tokens = self.tokenizer.tokenize(sentence)
        if not tokens:
            # Sentences cleaned down to nothing still get a single (empty) token.
            tokens = ['']
        encoded_dict = self.tokenizer.encode_plus(tokens, add_special_tokens=True, max_length=self.maxlen,
                                                  pad_to_max_length=True, return_attention_mask=True)
        tokens_ids_tensor = torch.tensor(encoded_dict['input_ids']).to(device)
        attn_mask_tensor = torch.tensor(encoded_dict['attention_mask']).to(device)
        label_tensor = torch.tensor(label).to(device)
        return tokens_ids_tensor, attn_mask_tensor, h_tensor, label_tensor
    
class customBertBinaryClassifier(nn.Module):
    """Two-class classifier fusing BERT's first-token embedding with hand-crafted/user features.

    The position-0 vector of BERT's last hidden state goes through dropout and
    ``linear_B``. The feature vector is layer-normalised and passed through
    ``linear_H``. The two results are concatenated and mapped to 2 logits by
    ``classifier``.
    """
    # B_in/B_out: width of the BERT branch before/after linear_B (768 is the
    # hidden size of bert-base-uncased). H_in must equal the number of features
    # per example (SubsetDataset yields 16); H_out is that branch's output width.
    # NOTE: keep the submodule construction order: nn.Linear draws its initial
    # weights from the torch RNG, so reordering changes the seeded initialisation.
    def __init__(self, B_in = 768, B_out = 768, H_in = 16, H_out = 16):
        super(customBertBinaryClassifier, self).__init__()
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')
        C_in = B_out + H_out
        C_out = 2  # number of output classes (logits)
        self.linear_B = nn.Linear(B_in, B_out)
        self.linear_H = nn.Linear(H_in, H_out)
        self.drop = nn.Dropout(p=0.1)  # applied to the BERT branch only
        # Normalises raw features (e.g. follower counts) of very different scales.
        self.norm = nn.LayerNorm(H_in,eps = 1e-12, elementwise_affine = True)
        self.classifier = nn.Linear(C_in, C_out)
    def forward(self, seq, attn_masks, hand_features):
        """Return (batch, 2) logits.

        seq / attn_masks: token ids and attention mask for BERT, presumably
        (batch, seq_len); confirm against the dataset. hand_features must be 2-D
        (batch, H_in), because it is concatenated with the 2-D BERT branch on dim 1.
        """
        outputs = self.bert_layer(input_ids=seq,attention_mask=attn_masks)
        # outputs[0] is the last hidden state; take the first token of every sequence.
        last_hidden_state_cls = outputs[0][:, 0, :]
        logits_B = self.drop(last_hidden_state_cls.float())
        logits_B = self.linear_B(logits_B)
        # Cast to float32 before LayerNorm; dataset features may be float64 or int.
        logits_H = self.norm(hand_features.float())
        logits_H = self.linear_H(logits_H)
        cat_features = torch.cat([logits_B.float(),logits_H.float()], dim=1)
        #feat = self.drop(cat_features.float())
        output = self.classifier(cat_features.float())
        return output
    
def main():
    """Cross-validate the BERT + hand-crafted/user-feature model on every subset file.

    Each dataset is read from ``/home/joao/<dataset>.subset.tsv``, which is
    the variant that carries the user-based feature columns.
    """
    batch_size = 16
    n_epochs = 10
    cv = 10
    for data in ['covid', 'crisislext6', 'crisislext26', 'crisismmd']:
        print("=== {} ===".format(data))
        dataset = SubsetDataset('/home/joao/{}.subset.tsv'.format(data), data)
        set_seed(42)    # Set seed for reproducibility
        # train_and_evalCV builds a fresh model for every fold and never uses its
        # model argument. The throw-away model previously built here (with
        # epochs=batch_size) only held (GPU) memory for the whole run.
        train_and_evalCV(None, dataset, cv=cv, epochs=n_epochs, batch_size=batch_size)
main()