In [1]:
import os
import time
import shutil
import time
import json
import random
import numpy as np
from easydict import EasyDict as edict
import argparse
from sklearn.metrics import classification_report,f1_score
import pickle
## torch packages
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn as nn

## for visualisation
import matplotlib.pyplot as plt

## custom
from select_model_input import select_model,select_input
import dataset
from label_dict import emo_label_map,label_emo_map,class_names,class_indices
from xai_emo_rec import explain_model

# Fix the seed of every random number generator used in this file (PyTorch,
# NumPy and the stdlib `random` module) to 0.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
# NOTE(review): cuDNN determinism is explicitly left off, so GPU runs may still
# differ slightly between executions despite the fixed seeds -- confirm this is
# intended.
torch.backends.cudnn.deterministic = False

def get_pred_softmax(logits):
    """Return the softmax of ``logits`` taken along dimension 1."""
    return F.softmax(logits, dim=1)

def eval_model(model, val_iter, loss_fn,config,mode="train",explain=False):
    """Evaluate ``model`` on ``val_iter`` or, in explain mode, run the explainer.

    Batches whose size differs from ``config.batch_size`` are skipped. The
    model and inputs are moved to the GPU only when CUDA is available.

    Args:
        model: network called as ``model(text)``; the predicted class is the
            argmax of its output along dimension 1.
        val_iter: iterable of batches understood by ``select_input``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar
            tensor.
        config: object exposing ``confusion``, ``per_class``, ``output_size``,
            ``batch_size`` and ``arch_name``.
        mode: ``"explain"`` calls ``explain_model`` on every batch and gathers
            no metrics; any other value computes loss/accuracy/F1.
        explain: unused; kept so existing callers keep working.

    Returns:
        ``(mean_loss, mean_accuracy_percent, macro_f1)`` when ``mode`` is not
        ``"explain"``; ``None`` otherwise. The means are taken over the
        batches that were actually evaluated.
    """
    confusion = config.confusion
    per_class = config.per_class
    use_cuda = torch.cuda.is_available()
    y_true = []
    y_pred = []
    total_epoch_loss = 0
    total_epoch_acc = 0
    num_batches = 0  # batches that were really evaluated (skipped ones excluded)

    if confusion:
        conf_matrix = torch.zeros(config.output_size, config.output_size)
    if per_class:
        class_correct = [0.] * config.output_size
        class_total = [0.] * config.output_size

    # Move the model once, and only when a GPU exists, instead of calling
    # .cuda() unconditionally on every iteration.
    if use_cuda:
        model = model.cuda()
    model.eval()
    with torch.no_grad():
        for idx, batch in enumerate(val_iter):
            text, target = select_input(batch,config)
            target = target.long()

            # Value comparison: `is not` on ints only works by accident for
            # small integers.
            if target.size(0) != config.batch_size:
                continue

            if use_cuda:
                if config.arch_name in ("sl_bert", "a_bert"):
                    # These architectures take a pair of tensors as input.
                    text = [text[0].cuda(),text[1].cuda()]
                else:
                    text = text.cuda()
                target = target.cuda()

            prediction = model(text)
            pred_ind = torch.max(prediction, 1)[1].view(target.size())

            if mode == "explain":
                pred_softmax = get_pred_softmax(prediction)
                # The value returned by explain_model is not stored here.
                explain_model(model,text,target.data,batch["utterance_data_str"],pred_ind,pred_softmax)
                continue

            true_labels = target.cpu().tolist()
            pred_labels = pred_ind.cpu().tolist()

            if confusion:
                for t, p in zip(true_labels, pred_labels):
                    conf_matrix[t, p] += 1

            if per_class:
                for t, p in zip(true_labels, pred_labels):
                    class_correct[t] += float(t == p)
                    class_total[t] += 1

            loss = loss_fn(prediction, target)

            num_corrects = (pred_ind == target).sum()
            y_true.extend(true_labels)
            y_pred.extend(pred_labels)

            acc = 100.0 * num_corrects/config.batch_size
            total_epoch_loss += loss.item()
            total_epoch_acc += acc.item()
            num_batches += 1

    if confusion:
        import seaborn as sns
        sns.heatmap(conf_matrix, annot=True,xticklabels=list(emo_label_map.keys()),yticklabels=list(emo_label_map.keys()))
        plt.show()
    if per_class:
        for i in range(config.output_size):
            if not class_total[i]:
                # Avoid dividing by zero for classes absent from the data.
                print('Test Accuracy of %5s: N/A (no samples)' % label_emo_map[i])
                continue
            print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
                label_emo_map[i], 100 * class_correct[i] / class_total[i],
                class_correct[i], class_total[i]))

    if mode != "explain":
        f1_score_e = f1_score(y_true, y_pred, labels=class_indices,average='macro')
        # Average over evaluated batches only; skipped partial batches would
        # otherwise drag the means down. max(..., 1) keeps the empty case at 0.
        denom = max(num_batches, 1)
        return total_epoch_loss/denom, total_epoch_acc/denom,f1_score_e



def load_model(resume,model,optimizer):
    """Load weights from the checkpoint at ``resume`` into ``model``.

    The checkpoint must hold the keys ``'epoch'`` and ``'state_dict'``. The
    model is moved to the GPU when CUDA is available (otherwise the checkpoint
    is mapped to the CPU) and is switched to eval mode.

    Args:
        resume: path to the checkpoint file.
        model: module whose ``load_state_dict`` receives the stored weights.
        optimizer: returned unchanged; restoring its state is still a TODO.

    Returns:
        ``(model, optimizer, start_epoch)``.
    """
    use_cuda = torch.cuda.is_available()
    # Without a GPU, remap CUDA-saved tensors to the CPU so loading succeeds.
    checkpoint = torch.load(resume, map_location=None if use_cuda else "cpu")
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    if use_cuda:
        model = model.cuda()
    model.eval()
    # optimizer.load_state_dict(checkpoint['optimizer']) ## during retrain TODO

    return model,optimizer,start_epoch
    

def call_eval(resume_path,mode,rem_epoch=10,patience=10):
    """Rebuild a saved model from its checkpoint and run it in ``mode``.

    The hyper-parameters are read from the ``log.json`` whose path is obtained
    by replacing ``model_best.pth.tar`` with ``log.json`` in ``resume_path``.

    Args:
        resume_path: path to the checkpoint.
        mode: ``"retrain"`` continues training via ``train.train_model``;
            ``"eval"`` evaluates on the test iterator and writes the macro F1
            back into the log file under ``"f1_score"``; ``"explain"`` runs
            ``eval_model`` in explain mode. Any other value only loads the
            data and the model.
        rem_epoch: number of epochs used for ``"retrain"``.
        patience: patience value used for ``"retrain"``.
    """
    ## Load the resume model parameters
    log_path = resume_path.replace("model_best.pth.tar","log.json")
    with open(log_path,'r') as f:
        log = json.load(f)

    ## Initialising parameters
    params = log["param"]
    learning_rate = params["learning_rate"]
    batch_size = params["batch_size"]
    input_type = params["input_type"]
    arch_name = params["arch_name"]
    tokenizer = params["tokenizer"]
    embedding_type = params["embedding_type"]

    ## Loading data
    print('Loading dataset')
    start_time = time.time()
    vocab_size, word_embeddings,train_iter, valid_iter ,test_iter= dataset.get_dataloader(batch_size,tokenizer,embedding_type,arch_name)
    finish_time = time.time()
    print('Finished loading. Time taken:{:06.3f} sec'.format(finish_time-start_time))

    eval_config = edict(params)
    eval_config.resume_path = resume_path

    if mode == "explain":
        model = select_model(eval_config,vocab_size,word_embeddings,grad_check=False)
    else:
        model = select_model(eval_config,vocab_size,word_embeddings)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,eval_config.step_size, gamma=0.5)

    model,optimizer,start_epoch = load_model(resume_path,model,optimizer)

    if mode == "retrain": ## retrain from checkpoint TODO
        from train import train_model
        # SummaryWriter was referenced without ever being imported; import it
        # locally so this branch no longer fails with a NameError.
        from torch.utils.tensorboard import SummaryWriter

        eval_config.patience = patience
        eval_config.nepoch = rem_epoch
        eval_config.confusion = False
        eval_config.per_class = True

        data  = (train_iter,valid_iter,test_iter)
        model_run_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
        writer = SummaryWriter('./runs/'+input_type+"/"+arch_name+"/")
        save_home = "./save/"+input_type+"/"+arch_name+"/"+model_run_time

        train_model(eval_config,data,model,loss_fn,optimizer,lr_scheduler,writer,save_home)

    elif mode == "eval":

        print(f'Train Acc: {log["train_acc"]:.3f}%, Valid Acc: {log["valid_acc"]:.3f}%, Test Acc: {log["test_acc"]:.3f}%')

        eval_config.confusion = True
        eval_config.per_class = True

        # val_loss, val_acc = eval_model(model, valid_iter,loss_fn,eval_config) ## uncomment if validation needed

        ## testing
        # Named test_f1 so the imported sklearn `f1_score` is not shadowed.
        _, _, test_f1 = eval_model(model, test_iter,loss_fn,eval_config,mode)
        log["f1_score"] = test_f1
        with open(log_path, 'w') as fp:
            json.dump(log, fp,indent=4)

    elif mode == "explain":

        print(f'Train Acc: {log["train_acc"]:.3f}%, Valid Acc: {log["valid_acc"]:.3f}%, Test Acc: {log["test_acc"]:.3f}%')

        eval_config.confusion = False
        eval_config.per_class = False

        ## explaining
        eval_model(model, test_iter,loss_fn,eval_config,mode,explain=True)


The section below shows the word importances for examples of the "sentimental" and "nostalgic" classes.

In [None]:

# Run call_eval in "explain" mode on a saved checkpoint.
# NOTE(review): the checkpoint path is absolute and machine-specific -- adjust
# it before re-running this cell.
call_eval("/home/ashvar/varsha/Emotion-Recognition/save/speaker+listener/bert/2020_09_27_19_49_06/model_best.pth.tar"
,"explain")


Loading dataset
Finished loading. Time taken:00.702 sec


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Loading Model
Train Acc: 68.540%, Valid Acc: 55.388%, Test Acc: 53.555%


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
sentimental,nostalgic (0.66),nostalgic,3.3,i found my dog ' s old toy . it brought back memories of when i first got her ##aw ##w . i had to put mine down last month . i miss her . i ' m sorry to hear that . what kind of dog was she ? border col ##lie . she was great !
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
sentimental,sentimental (0.59),sentimental,0.02,i won a storage shed auction last monday and found my grandmother ' s old glass eye in a steamer trunk inside of the shed . really reminded me of her and how much i miss her . wow that ' s incredible . what was her an ##me ? rita may . she used to mess with me by popping it out and showing me her empty eye socket just to get a laugh while she got drunk off of bourbon . those were such good times . ha ##ha that ##s really funny . i bet she was a special person .
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
nostalgic,nostalgic (0.85),nostalgic,2.52,"i miss high school . my girlfriend was the best ##i enjoyed highs ##cho ##ol , too . that ' s great . college was also fun ##coll ##ege was pretty laid back for me , i did a lot of online classes ."
,,,,
