In [None]:
#!/usr/bin/env python3
import math 
import numpy as np
import pandas as pd
import random
import time
import datetime
import matplotlib.pyplot as plt
import json
import re
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from transformers import BertModel
from torch.utils.data import DataLoader, Dataset, TensorDataset,random_split, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, BertConfig
from transformers import DistilBertModel, DistilBertTokenizer
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import classification_report
from torch.utils.data import IterableDataset
from sklearn.model_selection import cross_val_score, train_test_split, GridSearchCV

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The experiments were pinned to the second GPU (index 1). torch.cuda.set_device
# raises on CPU-only hosts and on single-GPU machines, so only select that
# device when it actually exists; otherwise the default device is used.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random generators with ``seed``.

    When CUDA is available the generators of every visible GPU are seeded too.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def flat_accuracy(preds, labels):
    """Return the fraction of rows whose arg-max column equals the label.

    preds  -- per-class scores, one row per example (arg-max over axis 1).
    labels -- array of class ids; flattened before comparison.
    """
    predicted = np.argmax(preds, axis=1).ravel()
    expected = labels.ravel()
    hits = (predicted == expected).sum()
    return hits / len(expected)

def format_time(elapsed):
    """Render a duration given in seconds as an ``h:mm:ss`` string.

    The value is first rounded to the nearest whole second using Python's
    built-in ``round`` (round-half-to-even).
    """
    whole_seconds = int(round(elapsed))
    span = datetime.timedelta(seconds=whole_seconds)
    return str(span)

def compute_metrics(pred):
    """Compute accuracy, precision, recall and F1 for a binary classifier.

    Args:
        pred: object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class scores); the arg-max over the last axis is used as the
            predicted label. Presumably a Hugging Face ``EvalPrediction`` —
            verify against the caller.

    Returns:
        dict with keys 'accuracy', 'f1', 'precision' and 'recall'.
    """
    # These sklearn helpers were used without ever being imported (NameError
    # on every call); import them here so the function is self-contained.
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

def report_average(report_list):
    """Average the per-class rows of several sklearn classification reports.

    Each report string is split on blank lines; the first section supplies
    the column names and the second the per-class rows (class label followed
    by one value per column). The label column is dropped, and the remaining
    values are summed element-wise and divided by the number of reports.

    Returns:
        pandas.DataFrame of averaged per-class metrics.
    """
    frames = []
    for text in report_list:
        sections = [' '.join(chunk.split()) for chunk in text.split('\n\n')]
        columns = sections[0].split(' ')
        table = np.array(sections[1].split(' ')).reshape(-1, len(columns) + 1)
        values = np.delete(table, 0, 1).astype(float)
        frames.append(pd.DataFrame(values, columns=columns))
    total = pd.DataFrame()
    for frame in frames:
        total = total.add(frame, fill_value=0)
    return total / len(frames)


class CustomDataset(Dataset):
    """Tweet classification dataset with BERT inputs and hand-crafted features.

    Supported ``name`` values:
      * 'covid'        -- TSV at ``filename`` with 'Text'/'Label' columns
                          (INFORMATIVE -> 1, UNINFORMATIVE -> 0).
      * 'crisislext6'  -- CSV with ' tweet'/' label' columns
                          (on-topic -> 1, off-topic -> 0).
      * 'crisislext26' -- annotation CSV joined with user metadata read from
                          a JSON-lines dump ('Related and informative' -> 1,
                          'Related - but not informative' -> 0; rows labelled
                          'Not related' / 'Not applicable' are dropped).

    The CrisisLex datasets read the fixed paths below and ignore ``filename``.
    Each item is ``(input_ids, attention_mask, label, hand_features)``, all as
    tensors moved to the module-level ``device``.
    """

    # Fixed input locations of the CrisisLex datasets.
    CRISISLEXT26_JSON = "/home/joao/crisisLexT26.json"
    CRISISLEXT26_CSV = "/home/joao/crisisLexT26.csv"
    CRISISLEXT6_CSV = "/home/joao/crisisLexT6.csv"

    # URL regex shared by the has-link ('blink') and link-count ('nlink') features.
    URL_PATTERN = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'

    # Hand-crafted feature columns returned by __getitem__, in this order.
    HAND_FEATURE_COLUMNS = ['nchars', 'nwords', 'bhash', 'nhash', 'blink', 'nlink',
                            'bat', 'nat', 'rt', 'slang', 'dlex']

    # US internet slang; matched case-insensitively as whole words.
    _SLANG = ['ASAP','BBIAB','BBL','BBS','BF','BFF','BFFL','BRB','CYA','DS','FAQ','FB','FITBLR','FLBP','FML','FTFY','FTW','FYI','G2G','GF','GR8','GTFO','HBIC','HML','HRU','HTH','IDK','IGHT','IMO','IMHO','IMY','IRL','ISTG','JK','JMHO','KTHX','L8R','LMAO','LMFAO','LMK','LOL','MWF','NM','NOOB','NP','NSFW','OOAK','OFC','OMG','ORLY','OTOH','RN','ROFL','RUH','SFW','SOML','SOZ','STFU','TFTI','TIL','TMI','TTFN','TTYL','TWSS','U','W/','WB','W/O','WYD','WTH','WTF','WYM','WYSIWYG','Y','YMMV','YW','YWA']
    _HAPPY_EMOJIS = [':)', ';)', '(:']
    _SAD_EMOJIS = [':(', ';(', '):']

    # (regex, replacement) pairs applied in order to the lower-cased text.
    # Order matters: retweet markers are removed before generic @mentions, and
    # whitespace is collapsed before punctuation is split off words.
    _CLEANING_STEPS = [
        (r'http(\S)+', ''),
        (r'http ...', ''),
        (r'(RT|rt)[ ]*@[ ]*[\S]+', ''),
        (r'@[\S]+', ''),
        (r'_[\S]?', ''),
        (r'[ ]{2,}', ' '),   # was '{2, }', which Python re treats as literal text
        (r'&amp;?', 'and'),
        (r'&lt;', '<'),
        (r'&gt;', '>'),
        (r'([\w\d]+)([^\w\d ]+)', r'\1 \2'),
        (r'([^\w\d ]+)([\w\d]+)', r'\1 \2'),
    ]

    def __init__(self, filename, name):
        """Load dataset ``name``, derive hand-crafted features and clean the text.

        Args:
            filename: path of the COVID-19 TSV; ignored for the CrisisLex sets.
            name: one of 'covid', 'crisislext6', 'crisislext26'.

        Raises:
            ValueError: if ``name`` is not a supported dataset.
        """
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        if name == 'covid':
            self.df = self._load_covid(filename)
        elif name == 'crisislext26':
            self.df = self._load_crisislext26()
        elif name == 'crisislext6':
            self.df = self._load_crisislext6()
        else:
            # Previously an unknown name surfaced as an AttributeError on self.df.
            raise ValueError(f"unknown dataset name: {name!r}")
        # .copy() so the feature columns are added to an independent frame,
        # not a view of the loaded one (avoids SettingWithCopyWarning).
        self.df = self.df[['sentence', 'label']].copy()
        self._add_hand_features()
        self.hand_features = self.df[self.HAND_FEATURE_COLUMNS]
        self.hand_features_DF = pd.DataFrame(self.hand_features)
        self.df['sentence'] = self._clean_sentences(self.df['sentence'])
        self.sentences = self.df['sentence']
        self.labels = self.df['label'].values
        self.maxlen = self._compute_maxlen(name)

    @staticmethod
    def _load_covid(filename):
        """Read the COVID-19 tweet TSV and map its labels to 1/0."""
        df = pd.read_csv(filename, delimiter='\t', encoding='utf-8')
        df = df.rename(columns={'Text': 'sentence', 'Label': 'label'})
        df['label'] = df['label'].replace({'INFORMATIVE': 1, 'UNINFORMATIVE': 0})
        return df

    def _load_crisislext6(self):
        """Read the CrisisLexT6 CSV and map its labels to 1/0."""
        df = pd.read_csv(self.CRISISLEXT6_CSV, encoding='utf-8')
        df = df.rename(columns={' tweet': 'sentence', ' label': 'label'})
        df['label'] = df['label'].replace({'on-topic': 1, 'off-topic': 0})
        return df

    def _load_crisislext26(self):
        """Join CrisisLexT26 annotations with per-tweet user metadata.

        Side effects: stores the user metadata in ``self.usr_features`` (and
        its row Series in ``self.listOfSeries``) and the joined, unfiltered
        frame in ``self.fd``.

        Raises:
            KeyError: if a tweet from the JSON dump is missing from the CSV.
        """
        columns = ['tweet_id', 'usr_id', 'usr_vrf', 'num_followers', 'num_friends', 'num_tweets']
        self.listOfSeries = []
        with open(self.CRISISLEXT26_JSON, encoding='utf-8') as json_file:
            for line in json_file:
                data = json.loads(line)
                user_obj = data['user']
                # Count features are scaled as log10(1 + count).
                self.listOfSeries.append(pd.Series([
                    int(data['id_str']),
                    user_obj['id_str'],
                    int(user_obj['verified']),
                    math.log10(1 + user_obj['followers_count']),
                    math.log10(1 + user_obj['friends_count']),
                    math.log10(1 + user_obj['statuses_count']),
                ], index=columns))
        # DataFrame.append was removed in pandas 2.0; build the frame directly.
        # Only ids and the verified flag are integers: casting every column to
        # int (as before) truncated the log-scaled counts.
        self.usr_features = pd.DataFrame(self.listOfSeries, columns=columns).astype({
            'tweet_id': int, 'usr_id': int, 'usr_vrf': int,
            'num_followers': float, 'num_friends': float, 'num_tweets': float,
        })

        annotations = pd.read_csv(self.CRISISLEXT26_CSV, encoding='utf-8')
        annotations = annotations.rename(columns={'Tweet ID': 'tweet_id'})
        annotations = annotations.drop([' Information Source', ' Information Type'], axis=1)
        # Relabel the column titles to remove leading white space.
        annotations = annotations.rename(columns={' Tweet Text': 'sentence', ' Informativeness': 'label'})

        # Index label of the first annotation row per tweet id; replaces a
        # full-column scan for every user record.
        first_row = {}
        for row_label, tweet_id in zip(annotations.index, annotations['tweet_id']):
            first_row.setdefault(tweet_id, row_label)

        feats = self.usr_features
        rows = []
        # zip over columns instead of iterrows(): iterrows would upcast the
        # int64 tweet ids to float64 in these mixed-dtype rows and lose digits.
        for tweet_id, usr_vrf, num_followers, num_friends, num_tweets in zip(
                feats['tweet_id'], feats['usr_vrf'], feats['num_followers'],
                feats['num_friends'], feats['num_tweets']):
            if tweet_id not in first_row:
                raise KeyError(f"tweet {tweet_id} from {self.CRISISLEXT26_JSON} "
                               f"not found in {self.CRISISLEXT26_CSV}")
            _, sentence, label = annotations.loc[first_row[tweet_id]]
            rows.append([sentence, usr_vrf, num_followers, num_friends, num_tweets, label])
        self.fd = pd.DataFrame(rows, columns=['sentence', 'usr_vrf', 'num_followers',
                                              'num_friends', 'num_tweets', 'label'])

        df = self.fd.reset_index(drop=True)
        df = df[(df.label != 'Not related') & (df.label != 'Not applicable')].copy()
        df['label'] = df['label'].replace({'Related and informative': 1,
                                           'Related - but not informative': 0})
        return df.reset_index(drop=True)

    def _add_hand_features(self):
        """Add surface-feature columns to self.df and lower-case 'sentence'.

        Length/hashtag/link/mention/retweet counts and lexical diversity are
        computed on the original text; slang and emoticon flags on the
        lower-cased text.
        """
        text = self.df['sentence']
        self.df['nchars'] = np.log10(1 + text.str.len())
        self.df['nwords'] = np.log10(1 + text.str.split().str.len())
        self.df['bhash'] = text.str.contains(pat='#', flags=re.IGNORECASE, regex=True).astype(int)
        self.df['nhash'] = text.str.count('#')
        self.df['blink'] = text.str.contains(pat=self.URL_PATTERN, flags=re.IGNORECASE, regex=True).astype(int)
        self.df['nlink'] = text.str.count(pat=self.URL_PATTERN, flags=re.IGNORECASE)
        self.df['bat'] = text.str.contains(pat='@', flags=re.IGNORECASE, regex=True).astype(int)
        self.df['nat'] = text.str.count(pat='@')
        self.df['rt'] = text.str.contains(pat='@rt|rt@', flags=re.IGNORECASE, regex=True).astype(int)
        self.df['dlex'] = text.apply(self.lexical_diversity)

        lowered = text.str.lower()
        self.df['sentence'] = lowered
        slang = [word.lower() for word in self._SLANG]
        slang_pattern = r'\b(?:{})\b'.format('|'.join(map(re.escape, slang)))
        self.df['slang'] = lowered.str.contains(slang_pattern).astype(int)
        # Emoticons begin/end with punctuation, so they are not wrapped in \b:
        # that required an adjacent word character and missed e.g. "great :)".
        happy_pattern = '|'.join(map(re.escape, self._HAPPY_EMOJIS))
        sad_pattern = '|'.join(map(re.escape, self._SAD_EMOJIS))
        self.df['hemojis'] = lowered.str.contains(happy_pattern, regex=True).astype(int)
        self.df['semojis'] = lowered.str.contains(sad_pattern, regex=True).astype(int)

    @classmethod
    def _clean_sentences(cls, sentences):
        """Apply _CLEANING_STEPS to a Series of text, then lower-case and strip it.

        regex=True is passed explicitly: pandas >= 2.0 defaults to literal
        matching, which silently turned every pattern below into a no-op.
        """
        for pattern, replacement in cls._CLEANING_STEPS:
            sentences = sentences.str.replace(pattern, replacement, regex=True)
        return sentences.str.lower().str.strip()

    def _compute_maxlen(self, name):
        """Return the padded length: 256 for 'covid', else the longest encoding.

        The result is capped at the tokenizer's ``model_max_length`` so padding
        and truncation never exceed what the tokenizer declares it supports.
        """
        if name == 'covid':
            maxlen = 256
        else:
            maxlen = max((len(self.tokenizer.encode(sent, add_special_tokens=True))
                          for sent in self.sentences), default=0)
        return min(maxlen, self.tokenizer.model_max_length)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (input_ids, attention_mask, label, hand_features) for row ``idx``."""
        sentence = self.df.loc[idx, 'sentence']
        label = self.df.loc[idx, 'label']
        h_features = self.hand_features_DF.loc[idx, self.HAND_FEATURE_COLUMNS]
        h_features_tensor = torch.tensor(h_features.to_numpy(dtype=float)).to(device)
        tokens = self.tokenizer.tokenize(sentence)
        # padding/truncation replace the deprecated pad_to_max_length; without
        # truncation, texts longer than maxlen produced ragged batches.
        encoded_dict = self.tokenizer.encode_plus(
            tokens, add_special_tokens=True, max_length=self.maxlen,
            padding='max_length', truncation=True, return_attention_mask=True)
        tokens_ids_tensor = torch.tensor(encoded_dict['input_ids']).to(device)
        attn_mask_tensor = torch.tensor(encoded_dict['attention_mask']).to(device)
        label_tensor = torch.tensor(label).to(device)
        return tokens_ids_tensor, attn_mask_tensor, label_tensor, h_features_tensor

    def lexical_diversity(self, text):
        """Return distinct / total whitespace-separated words (0.0 for no words)."""
        words = text.split()
        if not words:
            # Previously raised ZeroDivisionError for empty or blank text.
            return 0.0
        return len(set(words)) / len(words)

def initialize_model(dataloader, epochs=10, lr=0.001, num_warmup_steps=10000):
    """Build an ourBertBinaryClassifier with its SGD optimizer and LR schedule.

    Args:
        dataloader: training DataLoader; only ``len(dataloader)`` is used, to
            size the schedule (total steps = len(dataloader) * epochs).
        epochs: number of epochs the schedule spans.
        lr: peak SGD learning rate (momentum 0.9).
        num_warmup_steps: length of the linear warm-up. NOTE(review): the
            default 10000 can exceed the total number of steps on small
            datasets, in which case the learning rate never reaches ``lr``.

    Returns:
        (model, optimizer, scheduler)
    """
    bert_classifier = ourBertBinaryClassifier()
    # .to(device) instead of .cuda() so this also works on CPU-only hosts.
    bert_classifier.to(device)
    param_optimizer = list(bert_classifier.named_parameters())
    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9)
    total_steps = len(dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,
                                                num_training_steps=total_steps)
    return bert_classifier, optimizer, scheduler


class ourBertBinaryClassifier(nn.Module):
    """BERT first-token embedding -> one-step LSTM -> dropout -> linear head.

    Args:
        input_size: width of the BERT hidden states fed to the LSTM.
        dimension: stored as ``hidden_layer_size``; not used by any layer.
        output_size: number of classes (logits per example).
        num_layers: number of stacked LSTM layers.
        hidden_size: LSTM hidden width (defaults to the previous hard-coded 64).
    """

    def __init__(self, input_size=768, dimension=10, output_size=2, num_layers=1, hidden_size=64):
        super(ourBertBinaryClassifier, self).__init__()
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')
        self.hidden_layer_size = dimension
        # batch_first=True: the LSTM input below is (batch, 1, input_size).
        # With the default (seq, batch, feature) layout the whole batch was
        # treated as a single sequence, leaking state between examples.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.drop = nn.Dropout(p=0.2)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, seq, attn_masks, hand_features):
        """Return class logits of shape (batch, output_size).

        ``hand_features`` is accepted for compatibility with the training loops
        but is not used by this model.
        """
        outputs = self.bert_layer(input_ids=seq, attention_mask=attn_masks)
        # Hidden state of the first ([CLS]) token: (batch, input_size).
        first_token = outputs[0][:, 0, :]
        # Feed it as a length-1 sequence: (batch, 1, input_size).
        lstm_out, _ = self.lstm(first_token.unsqueeze(1))
        last_step = self.drop(lstm_out[:, -1, :])
        # self.fc used to be defined but never applied, so the raw 64-d LSTM
        # output was returned and CrossEntropyLoss saw 64 "classes".
        return self.fc(last_step)

def train(model, train_dataloader, validation_dataloader, epochs=10, evaluation=False,
          optimizer=None, scheduler=None, loss_fn=None,
          save_path="/home/joao/bert_tuned_lstm.pt"):
    """Train ``model`` for ``epochs`` epochs, optionally validating each epoch.

    Batches must be (input_ids, attention_mask, labels, hand_features); the
    model is called as ``model(input_ids, attention_mask, hand_features)``.
    When ``evaluation`` is true, the validation loss and a classification
    report are printed every epoch, and a checkpoint is written to
    ``save_path`` whenever the validation loss improves.

    optimizer, scheduler and loss_fn default to module-level variables of the
    same names (how earlier versions obtained them); loss_fn falls back to
    CrossEntropyLoss when no such variable exists.

    Raises:
        ValueError: if no optimizer or scheduler is available.
    """
    module_scope = globals()
    if optimizer is None:
        optimizer = module_scope.get('optimizer')
    if scheduler is None:
        scheduler = module_scope.get('scheduler')
    if loss_fn is None:
        loss_fn = module_scope.get('loss_fn')
        if loss_fn is None:
            loss_fn = nn.CrossEntropyLoss()
    if optimizer is None or scheduler is None:
        raise ValueError("train() needs an optimizer and a scheduler; "
                         "pass the ones returned by initialize_model")

    best_valid_loss = float("Inf")
    for epoch_i in range(epochs):
        t0 = time.time()
        total_train_loss = 0
        model.train()
        for batch in train_dataloader:
            b_input_ids, b_attn_mask, b_labels, h_features = (t.to(device) for t in batch)
            model.zero_grad()
            logits = model(b_input_ids, b_attn_mask, h_features)
            loss = loss_fn(logits, b_labels)
            total_train_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
        avg_train_loss = total_train_loss / len(train_dataloader)
        training_time = format_time(time.time() - t0)
        print({'epoch': epoch_i + 1, 'Training Loss': avg_train_loss, 'Training Time': training_time})

        if not evaluation:
            continue
        # =======================================
        #               Evaluation
        # =======================================
        model.eval()
        # Separate timer: validation time previously included training time.
        t_val = time.time()
        total_eval_loss = 0
        pred_chunks, label_chunks = [], []
        for batch in validation_dataloader:
            b_input_ids, b_attn_mask, b_labels, h_features = (t.to(device) for t in batch)
            with torch.no_grad():
                logits = model(b_input_ids, b_attn_mask, h_features)
                loss = loss_fn(logits, b_labels)
            total_eval_loss += loss.item()
            pred_chunks.append(np.argmax(logits.cpu().numpy(), axis=1).flatten())
            label_chunks.append(b_labels.cpu().numpy().flatten())
        predictions = np.concatenate(pred_chunks) if pred_chunks else np.array([])
        true_labels = np.concatenate(label_chunks) if label_chunks else np.array([])
        avg_val_loss = total_eval_loss / len(validation_dataloader)
        validation_time = format_time(time.time() - t_val)
        print({'epoch': epoch_i + 1, 'Validation Loss': avg_val_loss, 'Validation Time': validation_time})
        print(classification_report(true_labels, predictions))
        if avg_val_loss < best_valid_loss:
            best_valid_loss = avg_val_loss
            state_dict = {'model_state_dict': model.state_dict(),
                          'optimizer_state_dict': optimizer.state_dict(),
                          'valid_loss': avg_val_loss}
            torch.save(state_dict, save_path)
            print(f'Model saved to ==> {save_path}')
    print("Training complete!")



def train_and_evalCV(model, dataset, cv=10, epochs=10, batch_size = 32, loss_fn=None):
    """Run K-fold cross-validation, training a fresh model for every fold.

    Args:
        model: ignored; initialize_model builds a new model, optimizer and
            scheduler for each fold. Kept for backward compatibility.
        dataset: map-style dataset yielding (input_ids, attention_mask, label,
            hand_features) tuples.
        cv: number of (unshuffled) KFold splits.
        epochs: training epochs per fold.
        batch_size: batch size for the train and validation loaders.
        loss_fn: loss taking (logits, labels). Defaults to a module-level
            ``loss_fn`` if defined (earlier behaviour), else CrossEntropyLoss.

    Returns:
        (avg_precision, avg_recall, avg_f1): macro scores averaged over folds.
    """
    # These were used here without ever being imported (NameError).
    from sklearn.metrics import precision_recall_fscore_support
    from sklearn.model_selection import KFold

    if loss_fn is None:
        loss_fn = globals().get('loss_fn')
        if loss_fn is None:
            loss_fn = nn.CrossEntropyLoss()

    fold_scores = []
    splitter = KFold(n_splits=cv)
    # KFold only needs the number of samples, so split a plain index array.
    for train_index, test_index in splitter.split(np.arange(len(dataset))):
        train_subset = torch.utils.data.Subset(dataset, train_index)
        val_subset = torch.utils.data.Subset(dataset, test_index)
        train_dataloader = DataLoader(train_subset, sampler=SequentialSampler(train_subset), batch_size=batch_size)
        val_dataloader = DataLoader(val_subset, sampler=SequentialSampler(val_subset), batch_size=batch_size)
        model, optimizer, scheduler = initialize_model(train_dataloader, epochs)
        model.to(device)
        for epoch_i in range(epochs):
            t0 = time.time()
            total_train_loss = 0
            model.train()
            for batch in train_dataloader:
                b_input_ids, b_attn_mask, b_labels, h_features = (t.to(device) for t in batch)
                model.zero_grad()
                logits = model(b_input_ids, b_attn_mask, h_features)
                loss = loss_fn(logits, b_labels)
                total_train_loss += loss.item()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
            avg_train_loss = total_train_loss / len(train_dataloader)
            training_time = format_time(time.time() - t0)
            print({'epoch': epoch_i + 1, 'Training Loss': avg_train_loss, 'Training Time': training_time})

        model.eval()
        pred_chunks, label_chunks = [], []
        for batch in val_dataloader:
            b_input_ids, b_attn_mask, b_labels, h_features = (t.to(device) for t in batch)
            with torch.no_grad():
                logits = model(b_input_ids, b_attn_mask, h_features)
            pred_chunks.append(np.argmax(logits.cpu().numpy(), axis=1).flatten())
            label_chunks.append(b_labels.cpu().numpy().flatten())
        predictions = np.concatenate(pred_chunks)
        true_labels = np.concatenate(label_chunks)
        scores = precision_recall_fscore_support(true_labels, predictions, average='macro')
        fold_scores.append(scores)
        print(scores)
        print(classification_report(true_labels, predictions))

    avgP = sum(s[0] for s in fold_scores) / len(fold_scores)
    avgR = sum(s[1] for s in fold_scores) / len(fold_scores)
    avgF = sum(s[2] for s in fold_scores) / len(fold_scores)
    # Distinct positional fields: the old format used {0} three times and so
    # printed the precision in place of recall and F1.
    print("P: {0:.3f}, R: {1:.3f}, F: {2:.3f} ".format(avgP, avgR, avgF))
    return avgP, avgR, avgF

    
# Saving and Loading Functions
def save_model(path, model, optimizer, valid_loss):
    """Checkpoint model and optimizer state together with the validation loss.

    Writes a dict with keys 'model_state_dict', 'optimizer_state_dict' and
    'valid_loss' to ``path`` via torch.save. Does nothing when ``path`` is None.
    """
    # The body previously referenced an undefined ``save_path`` instead of the
    # ``path`` parameter, raising NameError on every call.
    if path is None:
        return
    state_dict = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'valid_loss': valid_loss}
    torch.save(state_dict, path)
    print(f'Model saved to ==> {path}')

def load_model(path, model, optimizer, map_location=None):
    """Restore model and optimizer state from a save_model checkpoint.

    Args:
        path: checkpoint file; when None nothing is loaded.
        model, optimizer: objects whose state is restored in place.
        map_location: forwarded to torch.load; defaults to the module-level
            ``device``.

    Returns:
        The checkpoint's 'valid_loss', or None when ``path`` is None.
    """
    # The body previously referenced an undefined ``load_path`` instead of the
    # ``path`` parameter, raising NameError on every call.
    if path is None:
        return None
    state_dict = torch.load(path, map_location=device if map_location is None else map_location)
    print(f'Model loaded from <== {path}')
    model.load_state_dict(state_dict['model_state_dict'])
    optimizer.load_state_dict(state_dict['optimizer_state_dict'])
    return state_dict['valid_loss']
                 
## Method to plot the loss curve
def plot_loss_curve(df_stats, nepochs, filename):
    """Plot the training loss per epoch and save the figure to ``filename``.

    Args:
        df_stats: DataFrame with a 'Training Loss' column, plotted against its
            index.
        nepochs: number of epochs; x ticks are placed at 0 .. nepochs-1.
        filename: output path passed to savefig.

    Returns:
        The matplotlib Figure, so callers can display or close it.
    """
    # A dedicated 12x6 figure per call: the previous version drew on whatever
    # figure was current (accumulating curves across calls) and changed the
    # global rcParams figure size as a side effect.
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.grid(True)
    ax.plot(df_stats['Training Loss'], 'b-o', label="Training Loss")
    ax.set_title("Training Loss")
    ax.set_xlabel("# Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.set_xticks(np.arange(nepochs))
    fig.savefig(filename)
    return fig



def main():
    """Run 10-fold cross-validation on each benchmark dataset in turn."""
    # train()/train_and_evalCV() may read ``loss_fn`` from module scope, so
    # publish it there; a plain local assignment was invisible to them.
    global loss_fn
    loss_fn = nn.CrossEntropyLoss()

    batch_size = 32
    for data in ['covid', 'crisislext6', 'crisislext26']:
        if data == 'covid':
            # Merge the official train/valid splits; CV re-splits them.
            dataset = torch.utils.data.ConcatDataset([
                CustomDataset("/home/joao/COVID19Tweet-master/train.tsv", "covid"),
                CustomDataset("/home/joao/COVID19Tweet-master/valid.tsv", "covid"),
            ])
        else:
            # The CrisisLex loaders read their own fixed file paths.
            dataset = CustomDataset(None, data)

        set_seed(42)    # Set seed for reproducibility
        # train_and_evalCV creates a new model for every fold and ignores its
        # ``model`` argument, so none is built here (previously a full BERT
        # classifier was loaded onto the GPU only to be discarded).
        train_and_evalCV(None, dataset, cv=10, epochs=10, batch_size=batch_size)

  


In [None]:
###

#def evaluate(model, val_dataloader):
#    predictions , true_labels = [], []
#    model.eval()
#    for step, batch in enumerate(val_dataloader):
#        b_input_ids = batch[0].to(device)
#        b_input_mask = batch[1].to(device)
#        b_labels = batch[2].to(device)
#        with torch.no_grad():
#            loss, logits, hidden_states) = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask,labels=b_labels)
#        logits = logits.detach().cpu().numpy()
#        label_ids = b_labels.to('cpu').numpy()
#        pred_flat = np.argmax(logits, axis=1).flatten()
#        labels_flat = label_ids.flatten()
#        true_labels = np.concatenate((true_labels,labels_flat))
#        predictions = np.concatenate((predictions,pred_flat))
#    print('Classification Report:')
#    print(classification_report(true_labels, predictions, labels=[1,0], digits=4))
#    cm = confusion_matrix(true_labels, predictions, labels=[1,0])
#    ax= plt.subplot()
#    plt.imshow(cm, cmap='hot', interpolation='nearest')
#    ax.set_title('Confusion Matrix')
#    ax.set_xlabel('Predicted Labels')
#    ax.set_ylabel('True Labels')
    #sns.heatmap(cm, annot=True, ax = ax, cmap='Blues', fmt="d")
    #ax.xaxis.set_ticklabels(['FAKE', 'REAL'])
    #ax.yaxis.set_ticklabels(['FAKE', 'REAL'])
    
  
# Display floats with two decimal places.
#df_stats.set_option('Training Loss', 3)

# Create a DataFrame from our training statistics.
#df_stats = pd.DataFrame(data=stats)

# Use the 'epoch' as the row index.
#df_stats = df_stats.set_index('epoch')

# A hack to force the column headers to wrap.
#df_stats = df_stats.style.set_table_styles([dict(selector="th",props=[('max-width', '70px')])])

# Display the table.
#df_stats


class LSTM(nn.Module):
    """Stacked LSTM followed by a linear projection of the final time step.

    Args:
        input_dim: feature size of each time step.
        hidden_dim: LSTM hidden-state width.
        batch_size: fixed batch size the input is reshaped to.
        output_dim: width of the linear output layer.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_dim, hidden_dim, batch_size, output_dim=2, num_layers=2):
        super(LSTM, self).__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.batch_size, self.num_layers = batch_size, num_layers
        # Created in this order (LSTM, then Linear) so parameter
        # initialisation consumes the RNG in the same sequence.
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers)
        self.linear = nn.Linear(self.hidden_dim, output_dim)

    def init_hidden(self):
        """Return a zero (h_0, c_0) pair, each (num_layers, batch_size, hidden_dim)."""
        shape = (self.num_layers, self.batch_size, self.hidden_dim)
        return torch.zeros(*shape), torch.zeros(*shape)

    def forward(self, input):
        """Run ``input`` through the LSTM and project its last time step.

        The input is viewed as (len(input), batch_size, -1); the final (h, c)
        pair is stored on ``self.hidden`` and the projected output is
        flattened to 1-D. NOTE(review): a previously stored ``self.hidden`` is
        not fed back into the LSTM.
        """
        steps = input.view(len(input), self.batch_size, -1)
        lstm_out, self.hidden = self.lstm(steps)
        final_step = lstm_out[-1].view(self.batch_size, -1)
        projected = self.linear(final_step)
        return projected.view(-1)

# Regression setup for the LSTM class. lstm_input_size, h1, num_train,
# output_dim, num_layers and num_epochs are assumed to be defined in an
# earlier cell (not visible here) — confirm before running.
# NOTE(review): this rebinds the module-level ``loss_fn`` that train() and
# train_and_evalCV() may fall back to.
# reduction='sum' is the non-deprecated spelling of size_average=False.
loss_fn = torch.nn.MSELoss(reduction='sum')

# The model must exist before the optimiser can take its parameters; the
# optimiser used to be built from ``model`` before this assignment.
model = LSTM(lstm_input_size, h1, batch_size=num_train, output_dim=output_dim, num_layers=num_layers)

optimiser = torch.optim.Adam(model.parameters(), lr=0.009)
hist = np.zeros(num_epochs)


for t in range(num_epochs):
    # Clear stored gradient
    model.zero_grad()
    model.hidden = model.init_hidden()
    y_pred = model(X_train)
    