<a href="https://colab.research.google.com/github/peeyushsinghal/da/blob/main/mitigating_bias_sa_da_v25.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mitigating bias in sentiment analysis using domain adaptation

In [93]:
! pip install torchtext==0.10.0 --quiet # DOWNGRADE YOUR TORCHTEXT
! pip install ekphrasis --quiet # library to pre process twitter data
! pip install emoji --upgrade --quiet #library to deal with emoji data

In [94]:
## Import statements
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchtext.legacy.data import Dataset, Field, TabularDataset, BucketIterator
from torchtext.vocab import GloVe
import torchtext.vocab as vocab
import numpy as np
from ekphrasis.classes.preprocessor import TextPreProcessor
from ekphrasis.classes.tokenizer import SocialTokenizer
from ekphrasis.dicts.emoticons import emoticons
import emoji
from torchtext.legacy.vocab import Vectors
from tqdm import tqdm
import random
import torch.optim as optim
import scipy.stats as stats
from statistics import mean

import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
en_stops = set(stopwords.words('english'))

import time

from copy import deepcopy

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [95]:
# checking device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:{}".format(DEVICE))

Running on:cuda


## Data loading

In [96]:
#Mounting google drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


## Data Configuration

In [97]:

BASE_PATH = '/content/drive/MyDrive/semeval-2018'

DATA_DIR = os.path.join(BASE_PATH,'datasets')
TARGET_DIR = os.path.join(BASE_PATH,'targetdataset')

MODEL_DIR = os.path.join(BASE_PATH,'models')
REF_DIR = os.path.join(BASE_PATH,'reference')
EMBEDDINGS_DIR = os.path.join(BASE_PATH,'embeddings')

MAX_SIZE = 50
MAX_VOCAB_SIZE = 10000
BATCH_SIZE = 8


 
EMBEDDING_TO_BE_USED = 'glove' # {'glove', 'glove_gn'}
TARGET_BATCH_SIZE = 8

NUM_EPOCHS = 100
# NUM_EPOCHS = 1

EWC_LAMBDA = 0.4

if not os.path.exists(MODEL_DIR):
  os.makedirs(MODEL_DIR)
  print("The new directory is created!")

# EMB_MATRIX_DIR = os.path.join(BASE_PATH,'emb_matrix')
# if not os.path.exists(EMB_MATRIX_DIR):
#   os.makedirs(EMB_MATRIX_DIR)
#   print("The Embedding Matrix directory is created!")


BONFERRONI_CORRECTION = 5.0

In [98]:
# data configuration

class TASK1(object):
  
    EI_reg = {
        'anger': {
            'train': os.path.join(
                DATA_DIR, 'task1/EI-reg/training/EI-reg-En-anger-train.txt'),
            'dev': os.path.join(
                DATA_DIR, 'task1/EI-reg/development/2018-EI-reg-En-anger-dev.txt'),
            'gold': os.path.join(
                DATA_DIR, 'task1/EI-reg/test-gold/2018-EI-reg-En-anger-test-gold.txt')
                },
        'joy': {
                'train': os.path.join(
                    DATA_DIR, 'task1/EI-reg/training/EI-reg-En-joy-train.txt'),
                'dev': os.path.join(
                    DATA_DIR, 'task1/EI-reg/development/2018-EI-reg-En-joy-dev.txt'),
                'gold': os.path.join(
                    DATA_DIR, 'task1/EI-reg/test-gold/2018-EI-reg-En-joy-test-gold.txt')
                },
        'fear': {
            'train': os.path.join(
                DATA_DIR, 'task1/EI-reg/training/EI-reg-En-fear-train.txt'),
            'dev': os.path.join(
                DATA_DIR, 'task1/EI-reg/development/2018-EI-reg-En-fear-dev.txt'),
            'gold': os.path.join(
                DATA_DIR, 'task1/EI-reg/test-gold/2018-EI-reg-En-fear-test-gold.txt')
                },
        'sadness': {
            'train': os.path.join(
                DATA_DIR, 'task1/EI-reg/training/EI-reg-En-sadness-train.txt'),
            'dev': os.path.join(
                DATA_DIR, 'task1/EI-reg/development/2018-EI-reg-En-sadness-dev.txt'),
            'gold': os.path.join(
                DATA_DIR, 'task1/EI-reg/test-gold/2018-EI-reg-En-sadness-test-gold.txt')
                }                     
        }

    V_reg = {
        'train': os.path.join(
            DATA_DIR, 'task1/V-reg/2018-Valence-reg-En-train.txt'),
        'dev': os.path.join(
            DATA_DIR, 'task1/V-reg/2018-Valence-reg-En-dev.txt'),
        'gold': os.path.join(
            DATA_DIR, 'task1/V-reg/2018-Valence-reg-En-test-gold.txt')
             }

    EEC = {
        'eec': os.path.join(
            DATA_DIR, 'task1/Equity-Evaluation-Corpus/Equity-Evaluation-Corpus.csv')
             }

## Source Data
Parsing Emotion and Valence regression data : `format [ID	Tweet	Affect Dimension	Intensity Score]`

In [99]:
def parse_reg(data_file, label_format='tuple'):
    """
    This is for datasets for the EI-reg and V-reg English tasks 
    Returns:
        df: dataframe with columns in the first row of file [ID-Tweet-Affect Dimension-Intensity Score]
    """
    with open(data_file, 'r') as fd:
      data = [l.strip().split('\t') for l in fd.readlines()]
    # print(data)
    df = pd.DataFrame (data[1:],columns=data[0])
    csv_file_name = (data_file.split("/")[-1]).split('.')[0]+".csv"
    csv_file = df.to_csv(str(csv_file_name))
    return csv_file_name


Generic Source Data Parser

In [100]:
def parse_csv(task, dataset, emotion='anger'):
    if task == 'EI-reg':
        data_train = TASK1.EI_reg[emotion][dataset]
        csv_file_name = parse_reg(data_train)
        return csv_file_name

    elif task == 'V-reg':
        data_train = TASK1.V_reg[dataset]

        csv_file_name = parse_reg(data_train)
        return csv_file_name

    else:
        return None

In [101]:
emotions = ['anger','joy','fear','sadness']
dict_data ={'train':'train','dev':'val','gold':'test'}
dict_file_name ={}
for emotion in emotions:
  for data_info, data_usage in dict_data.items():
    file_name = str('file_EI_'+ emotion + "_" + data_usage)
    dict_file_name[file_name] = parse_csv('EI-reg', data_info, emotion)

    file_name2 = str('file_V_'+ data_usage)
    dict_file_name[file_name2] = parse_csv('V-reg', data_info)

(dict_file_name)

{'file_EI_anger_train': 'EI-reg-En-anger-train.csv',
 'file_V_train': '2018-Valence-reg-En-train.csv',
 'file_EI_anger_val': '2018-EI-reg-En-anger-dev.csv',
 'file_V_val': '2018-Valence-reg-En-dev.csv',
 'file_EI_anger_test': '2018-EI-reg-En-anger-test-gold.csv',
 'file_V_test': '2018-Valence-reg-En-test-gold.csv',
 'file_EI_joy_train': 'EI-reg-En-joy-train.csv',
 'file_EI_joy_val': '2018-EI-reg-En-joy-dev.csv',
 'file_EI_joy_test': '2018-EI-reg-En-joy-test-gold.csv',
 'file_EI_fear_train': 'EI-reg-En-fear-train.csv',
 'file_EI_fear_val': '2018-EI-reg-En-fear-dev.csv',
 'file_EI_fear_test': '2018-EI-reg-En-fear-test-gold.csv',
 'file_EI_sadness_train': 'EI-reg-En-sadness-train.csv',
 'file_EI_sadness_val': '2018-EI-reg-En-sadness-dev.csv',
 'file_EI_sadness_test': '2018-EI-reg-En-sadness-test-gold.csv'}

## Preprocess tweets

In [102]:
# reference : https://github.com/cbaziotis/ekphrasis


text_processor = TextPreProcessor(
    # terms that will be normalized
    normalize=['url', 'email', 'percent', 'money', 'phone', 'user',
        'time', 'url', 'date', 'number'],
    # terms that will be annotated
    annotate={"hashtag", "allcaps", "elongated", "repeated",
        'emphasis', 'censored'},
    fix_html=True,  # fix HTML tokens
    
    # corpus from which the word statistics are going to be used 
    # for word segmentation 
    segmenter="twitter", 
    
    # corpus from which the word statistics are going to be used 
    # for spell correction
    corrector="twitter", 
    
    unpack_hashtags=True,  # perform word segmentation on hashtags
    unpack_contractions=True,  # Unpack contractions (can't -> can not)
    spell_correct_elong=False,  # spell correction for elongated words
    
    # select a tokenizer. You can use SocialTokenizer, or pass your own
    # the tokenizer, should take as input a string and return a list of tokens
    tokenizer=SocialTokenizer(lowercase=True).tokenize,
    
    # list of dictionaries, for replacing tokens extracted from the text,
    # with other expressions. You can pass more than one dictionaries.
    dicts=[emoticons]
)

Reading twitter - 1grams ...
Reading twitter - 2grams ...
Reading twitter - 1grams ...


In [103]:
def preprocess_tweet(tweet): 
  tweet_processed = text_processor.pre_process_doc(tweet)
  # print (tweet_processed)
  demojized_list =[]
  final_list =[]
  for index, tweet in enumerate(tweet_processed):
      demojized_list.append(emoji.demojize(tweet, language = 'en'))
  final_list = [w for w in demojized_list if w not in en_stops] 
  
  # print(df)
  return final_list

## TorchText Treatment

In [104]:
dict_file_name.keys()

dict_keys(['file_EI_anger_train', 'file_V_train', 'file_EI_anger_val', 'file_V_val', 'file_EI_anger_test', 'file_V_test', 'file_EI_joy_train', 'file_EI_joy_val', 'file_EI_joy_test', 'file_EI_fear_train', 'file_EI_fear_val', 'file_EI_fear_test', 'file_EI_sadness_train', 'file_EI_sadness_val', 'file_EI_sadness_test'])

In [105]:
dict_fields ={}
list_name = list(set(["_".join(key.split("_")[1:-1]) for key in list(dict_file_name.keys())]))


for name in list_name:
  field_tweet = Field(sequential=True, 
                      use_vocab = True, 
                      tokenize = preprocess_tweet, 
                      fix_length = MAX_SIZE, 
                      batch_first = True)
  field_intensity = Field(sequential= False, 
                        dtype = torch.float,
                        use_vocab = False)
  fields = {
    'Tweet':('tweet', field_tweet ), #
    'Intensity Score': ('intensity',field_intensity) # Intensity Score is name of the dataset column, field_intensity is how we have defined the field, intensity is the name of the variable going fwd
    }
  
  dict_fields[name] = fields

  # dict_fields[name]= { 'field_tweet': Field(sequential=True,
  #                                        use_vocab = True,
  #                                        tokenize = preprocess_tweet,
  #                                        fix_length = MAX_SIZE,
  #                                        batch_first = True ), 
  #                           'field_intensity': Field(sequential= False,
  #                                              dtype = torch.float,
  #                                              use_vocab = False )}

dict_fields

{'EI_anger': {'Tweet': ('tweet',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8d00>),
  'Intensity Score': ('intensity',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8c70>)},
 'EI_fear': {'Tweet': ('tweet',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8f70>),
  'Intensity Score': ('intensity',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8be0>)},
 'EI_sadness': {'Tweet': ('tweet',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8ca0>),
  'Intensity Score': ('intensity',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8c10>)},
 'V': {'Tweet': ('tweet',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8c40>),
  'Intensity Score': ('intensity',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8850>)},
 'EI_joy': {'Tweet': ('tweet',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8b80>),
  'Intensity Score': ('intensity',
   <torchtext.legacy.data.field.Field at 0x7f3e1e1c8e50>)}}

In [106]:
dict_fields['EI_sadness']['Tweet'][1]

<torchtext.legacy.data.field.Field at 0x7f3e1e1c8ca0>

In [107]:
dict_dataset ={}
for file_key, file_name in dict_file_name.items():
  # print(file_key,file_name)
  if "train" in (file_key.split("_")[-1]):
    head_name = "_".join(file_key.split("_")[0:-1])
    base_name = "_".join(file_key.split("_")[1:-1])
    # print(base_name)
    train_file = dict_file_name[head_name+"_train"]
    val_file = dict_file_name[head_name+"_val"]
    test_file =  dict_file_name[head_name+"_test"]

    train, val, test =TabularDataset.splits( path = './', 
                                            train = train_file, 
                                            validation = val_file, 
                                            test = test_file,
                                            format = 'csv', 
                                            fields = dict_fields[base_name])
    
    # print(train_file,val_file,test_file)
    
    # dict_dataset[base_name+"_train"], dict_dataset[base_name+"_val"],dict_dataset[base_name+"_test"]=TabularDataset.splits( path = './',
    #                                                                                                                        train = train_file,
    #                                                                                                                        validation = val_file,
    #                                                                                                                        test = test_file,
    #                                                                                                                        format = 'csv',
    #                                                                                                                        fields = fields)
    dict_dataset[base_name] = {"train_dataset": train, "val_dataset":val,"test_dataset":test}

In [108]:
dict_dataset

{'EI_anger': {'train_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3e5812ea00>,
  'val_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3e5812ebe0>,
  'test_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3e5812ef70>},
 'V': {'train_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3e5812eee0>,
  'val_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3dbdd57c40>,
  'test_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3e23e34a90>},
 'EI_joy': {'train_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3da3b78df0>,
  'val_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3da38c0d30>,
  'test_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3e580e2610>},
 'EI_fear': {'train_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3da36914c0>,
  'val_dataset': <torchtext.legacy.data.dataset.TabularDataset at 0x7f3da34fdd90>,
  'test_dataset': <torchtext.legacy

In [109]:
for key, value in dict_dataset.items():
  # count = 0
  for name, dataset in value.items():
    for example in dataset.examples:
      print(key, name, example.tweet, example.intensity)
      break


# for example in test_data.examples:
#   print(example.tweet, example.intensity)
#   count += 1
#   if count > 2:
#     break

EI_anger train_dataset ['<user>', '<user>', 'shut', 'hashtags', 'cool', '<hashtag>', 'offended', '</hashtag>'] 0.562
EI_anger val_dataset ["'", 'need', 'something', '.', 'something', 'must', 'done', '!', '<repeated>', "'", '\\', 'n', '\\', 'nyour', 'anxiety', 'amusing', '.', 'nothing', 'done', '.', 'despair', '.'] 0.517
EI_anger test_dataset ['<user>', 'know', 'mean', 'well', 'offended', '.', 'prick', '.'] 0.734
V train_dataset ['<user>', 'yeah', '!', '<happy>', 'playing', 'well'] 0.600
V val_dataset ['<user>', 'site', 'crashes', 'everytime', 'try', 'book', '-', 'help', '?', 'tell', "'", 'nothing', 'wrong', '&', 'hang', '<hashtag>', 'furious', '</hashtag>', '<hashtag>', 'helpless', '</hashtag>', '<user>'] 0.141
V test_dataset ['gm', '<hashtag>', 'tuesday', '</hashtag>', '!'] 0.589
EI_joy train_dataset ['<user>', 'quite', 'saddened', '.', '<repeated>', 'us', 'dates', ',', 'joyous', 'anticipation', 'attending', 'dg', 'concert', '(', 'since', '<number>', ')', '.', 'happy', 'keeping', 'bus

## Building iterator and Vocabulary

In [110]:
for name, value in dict_fields.items():
  print(name, value)

EI_anger {'Tweet': ('tweet', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8d00>), 'Intensity Score': ('intensity', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8c70>)}
EI_fear {'Tweet': ('tweet', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8f70>), 'Intensity Score': ('intensity', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8be0>)}
EI_sadness {'Tweet': ('tweet', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8ca0>), 'Intensity Score': ('intensity', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8c10>)}
V {'Tweet': ('tweet', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8c40>), 'Intensity Score': ('intensity', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8850>)}
EI_joy {'Tweet': ('tweet', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8b80>), 'Intensity Score': ('intensity', <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8e50>)}


In [111]:
dict_emb_file = {'glove':'glove.6B.300d.txt',
                 'glove_gn': '1b-vectors300-0.8-0.8.txt'}

print("EMBEDDING_TO_BE_USED:", EMBEDDING_TO_BE_USED)
emb_file_path = os.path.join(EMBEDDINGS_DIR ,dict_emb_file[EMBEDDING_TO_BE_USED])
emb_file_path

EMBEDDING_TO_BE_USED: glove


'/content/drive/MyDrive/semeval-2018/embeddings/glove.6B.300d.txt'

In [112]:

for name, value in dict_fields.items():
  print(name, value['Tweet'][1])
  # value['Tweet'][1].build_vocab(dict_dataset[name]['train_dataset'],
  #                                  max_size = MAX_VOCAB_SIZE,
  #                                  min_freq = 1,
  #                                  vectors = "glove.6B.100d",
  #                                  unk_init=torch.Tensor.normal_)
  
  ## start for embeddings from text file
  value['Tweet'][1].build_vocab(dict_dataset[name]['train_dataset'])
  vectors = vocab.Vectors(emb_file_path) # location of embeddings file, full path
  value['Tweet'][1].vocab.set_vectors(vectors.stoi, vectors.vectors, vectors.dim)
  ## end for embeddings from text file

  value['Intensity Score'][1].build_vocab(dict_dataset[name]['train_dataset'])


EI_anger <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8d00>
EI_fear <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8f70>
EI_sadness <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8ca0>
V <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8c40>
EI_joy <torchtext.legacy.data.field.Field object at 0x7f3e1e1c8b80>


In [113]:
dict_iterator ={}
for name, value in dict_dataset.items():
  VALID_BATCH_SIZE = len(value['val_dataset']) - 1
  TEST_BATCH_SIZE = len(value['test_dataset'])  -1
  # VALID_TEST_BATCH_SIZE= min(len(value['val_dataset']),len(value['test_dataset']) )
  print(name, VALID_BATCH_SIZE , TEST_BATCH_SIZE)
  train_iterator, val_iterator, test_iterator= BucketIterator.splits(
      (value['train_dataset'], value['val_dataset'],value['test_dataset']),
      batch_sizes= (BATCH_SIZE,VALID_BATCH_SIZE, TEST_BATCH_SIZE),
      sort_key = lambda x: len(x.tweet),
      sort_within_batch=True,
      device = DEVICE,
      shuffle= True)
  
  dict_iterator[name] = {"train_iterator": train_iterator, "val_iterator":val_iterator,"test_iterator":test_iterator}

  # dict_iterator[name]['train_iterator'], dict_iterator[name]['val_iterator'], dict_iterator[name]['test_iterator'] = BucketIterator.splits((dict_dataset[name]['train_dataset'], dict_dataset[name]['val_dataset'],dict_dataset[name]['test_dataset']), 
  #                                                     batch_sizes= (BATCH_SIZE,VALID_TEST_BATCH_SIZE,VALID_TEST_BATCH_SIZE),
  #                                                     sort_key = lambda x: len(x.tweet),
  #                                                     sort_within_batch=True,
  #                                                     device = DEVICE,
  #                                                     shuffle= True)

EI_anger 387 1001
V 448 936
EI_joy 289 1104
EI_fear 388 985
EI_sadness 396 974


In [114]:
dict_iterator.items()

dict_items([('EI_anger', {'train_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da2339520>, 'val_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da23395e0>, 'test_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da23396a0>}), ('V', {'train_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da2339640>, 'val_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da2339820>, 'test_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da23398b0>}), ('EI_joy', {'train_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da23397c0>, 'val_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da23399a0>, 'test_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da2339a30>}), ('EI_fear', {'train_iterator': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3da2339d30>, 'val_iterator': <torchtex

In [115]:
for key, value in dict_iterator.items():
  for name, iterator in value.items():
    for batch in iterator:
      print(key, name, batch.tweet)
      print(batch.intensity)
      break
    break
  break
    
# count = 0a
# for batch in train_iterator:
#   print (batch.tweet)
#   print (batch.intensity)
#   count += 1
#   if count > 2:
#     break

EI_anger train_iterator tensor([[   5,    5,  145,  184, 2882, 1354, 2057, 1340, 3076,   33,   75, 4672,
          242, 3309,    4,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1],
        [1981, 3508,  421,  671, 1113,  279, 3732,    3,   54,    2,   52,  191,
         2300,  312,   57,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1],
        [   5, 2975,   19, 4416, 1712,    4,    9,    3, 2846,    2,    3, 1063,
          233,  502,    2,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,   

In [116]:
for name in list_name:
  print(dict_fields[name]['Tweet'][1].vocab.stoi.items())

dict_items([('<unk>', 0), ('<pad>', 1), ('</hashtag>', 2), ('<hashtag>', 3), ('.', 4), ('<user>', 5), (',', 6), ("'", 7), ('!', 8), ('<repeated>', 9), ('<number>', 10), ('?', 11), ('\\', 12), ('</allcaps>', 13), ('<allcaps>', 14), ('-', 15), ('like', 16), ('n', 17), ('get', 18), ('&', 19), ('fear', 20), ('anxiety', 21), ('awful', 22), ('horror', 23), ('nightmare', 24), ('terrorism', 25), ('nervous', 26), ('start', 27), ('terrible', 28), ('would', 29), ('people', 30), ('day', 31), ('know', 32), ('shocking', 33), (':', 34), ('terror', 35), ('time', 36), ('going', 37), ('panic', 38), ('one', 39), ('afraid', 40), ('go', 41), ('bully', 42), ('/', 43), ('horrible', 44), ('think', 45), ('life', 46), ('still', 47), ('really', 48), ('never', 49), ('us', 50), ('make', 51), ('good', 52), ('u', 53), ('worry', 54), ('back', 55), ('bad', 56), ('want', 57), ('today', 58), (':face_with_tears_of_joy:', 59), ('need', 60), ('see', 61), ('last', 62), ('much', 63), ('new', 64), ('trump', 65), ('even', 66),

## TorchText treatment of Target Data

In [117]:
# Target Data
df_target = pd.read_csv(os.path.join(TARGET_DIR,'winomt_saunders_combined.csv'))
print(df_target.head())
target_file_name = 'target_csv.csv'
target_csv_file = df_target.to_csv(str(target_file_name))

                             Tweet  Intensity Score
0    the actor finished her work .                0
1    the actor finished his work .                0
2  the actuary finished her work .                0
3  the actuary finished his work .                0
4    the agent finished her work .                0


In [118]:
dict_target_dataset = {}
for name in list_name:
  # print(dict_fields[name])
  dict_target_dataset [name] = TabularDataset( path = os.path.join(TARGET_DIR,'winomt_saunders_combined.csv'), 
                                              format = 'csv',
                                              fields = dict_fields[name] )

print(dict_target_dataset)

{'EI_anger': <torchtext.legacy.data.dataset.TabularDataset object at 0x7f3e5812e8e0>, 'EI_fear': <torchtext.legacy.data.dataset.TabularDataset object at 0x7f3e1ec73400>, 'EI_sadness': <torchtext.legacy.data.dataset.TabularDataset object at 0x7f3e5812e880>, 'V': <torchtext.legacy.data.dataset.TabularDataset object at 0x7f3e218ab310>, 'EI_joy': <torchtext.legacy.data.dataset.TabularDataset object at 0x7f3e21a56fa0>}


In [119]:
for name, dataset in dict_target_dataset.items():
  count=0
  for example in dataset:
    print(example.tweet, example.intensity)
    count += 1
    if count > 2:
      break

['actor', 'finished', 'work', '.'] 0
['actor', 'finished', 'work', '.'] 0
['actuary', 'finished', 'work', '.'] 0
['actor', 'finished', 'work', '.'] 0
['actor', 'finished', 'work', '.'] 0
['actuary', 'finished', 'work', '.'] 0
['actor', 'finished', 'work', '.'] 0
['actor', 'finished', 'work', '.'] 0
['actuary', 'finished', 'work', '.'] 0
['actor', 'finished', 'work', '.'] 0
['actor', 'finished', 'work', '.'] 0
['actuary', 'finished', 'work', '.'] 0
['actor', 'finished', 'work', '.'] 0
['actor', 'finished', 'work', '.'] 0
['actuary', 'finished', 'work', '.'] 0


In [120]:
dict_target_iterator = {}
for name in list_name:
  dict_target_iterator [name] = BucketIterator(dict_target_dataset[name], # given that there is only one dataset we are not using splits
                                 batch_size= TARGET_BATCH_SIZE,
                                 sort_key = lambda x: len(x.tweet),
                                 sort_within_batch=True,
                                 device = DEVICE,
                                 repeat=True,
                                 shuffle= True)

print(dict_target_iterator)

{'EI_anger': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3e22a993a0>, 'EI_fear': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3e22a99370>, 'EI_sadness': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3e22a99460>, 'V': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3e22a99580>, 'EI_joy': <torchtext.legacy.data.iterator.BucketIterator object at 0x7f3e22a99610>}


In [121]:
# next(iter(target_iterator))

In [122]:
for name, iterator in dict_target_iterator.items():
  count = 0
  for batch in iterator:
    print(name)
    print(batch)
    print (batch.tweet)
    print (batch.intensity)
    count += 1
    break
    if count > 2:
      break

EI_anger

[torchtext.legacy.data.batch.Batch of size 8]
	[.tweet]:[torch.cuda.LongTensor of size 8x50 (GPU 0)]
	[.intensity]:[torch.cuda.FloatTensor of size 8 (GPU 0)]
tensor([[   0,  137,    0,    0,  332,    0,    4,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1],
        [   0,  251,    0,  117,    0,    0,    4,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1],
        [   0,    0, 1436,    0,    0,  302,    4,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,   

In [123]:
# count = 0
# for batch in target_iterator:
#   print(batch)
#   print (batch.tweet)
#   print (batch.intensity)
#   count += 1
#   if count > 2:
#     break

## CNN 1d model

### Gradient Reversal layer

In [124]:
from torch.autograd import Function


class GradientReversalFn(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None

CNN 1 D model
Reference: A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification, Ye Zhang, Byron Wallace 2015

Difference:

use of embedding
use of sigmoid function, as we are having a regression model not a classififer as the main task

In [125]:
import torch.nn as nn
import torch.nn.functional as F

class CNN1d(nn.Module):
    def __init__(self, 
                 vocab_size, 
                 embedding_dim, 
                 n_filters, 
                 filter_sizes, 
                 output_dim, 
                 dropout, 
                 pad_idx
                 ):
        super().__init__()
        
        #---------------------Feature Extractor Network----------------------#
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)

        # Convolutional Network
        self.convs = nn.ModuleList([
                                    nn.Conv1d(in_channels = embedding_dim, 
                                              out_channels = n_filters, 
                                              kernel_size = fs)
                                    for fs in filter_sizes
                                    ])
        
        #---------------------Regression Network------------------------#
        # Fully-connected layer and Dropout
        self.regression = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(len(filter_sizes) * n_filters, len(filter_sizes) * n_filters // 2),
            nn.ReLU(),
            nn.Linear(len(filter_sizes) * n_filters // 2, output_dim * 10),
            nn.ReLU(),
            nn.Linear(output_dim * 10, output_dim)
            # ,
            # nn.Sigmoid()
        )
        # self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim * 10)
        # self.fc2 = nn.Linear(output_dim * 10, output_dim)
        # self.dropout = nn.Dropout(dropout)

        #---------------------Domain Classifier Network------------------------#
        # Fully-connected layer and Dropout
        self.domain_classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(len(filter_sizes) * n_filters, len(filter_sizes) * n_filters // 2),
            nn.ReLU(),
            nn.Linear(len(filter_sizes) * n_filters // 2, output_dim * 10),
            nn.ReLU(),
            nn.Linear(output_dim * 10, 2),
            nn.LogSoftmax(dim=1),
        )
        
    def forward(self, text, alpha=1.0):
        
        #text = [batch size, sent len]
        
        embedded = self.embedding(text)
                
        #embedded = [batch size, sent len, emb dim]
        
        embedded = embedded.permute(0, 2, 1)
        
        #embedded = [batch size, emb dim, sent len]
        
        conved = [F.relu(conv(embedded)) for conv in self.convs]
            
        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
        
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        
        #pooled_n = [batch size, n_filters]
        
        x_feature = torch.cat(pooled, dim = 1)
        
        #x_feature = [batch size, n_filters * len(filter_sizes)]
        
        reverse_feature = GradientReversalFn.apply(x_feature, alpha)
        # print("reverse_feature",reverse_feature)
    
        regression_output = self.regression(x_feature)
    
        domain_classifier_output = self.domain_classifier(reverse_feature)


        return regression_output, domain_classifier_output

In [126]:
# INPUT_DIM = len(field_tweet.vocab) # these change for each model
EMBEDDING_DIM = 100
N_FILTERS = 100
FILTER_SIZES = [2, 3, 4, 5]
OUTPUT_DIM = 1
DROPOUT = 0.5
# PAD_IDX = field_tweet.vocab.stoi[field_tweet.pad_token] # these change for each model

### Model Architecture Creation for each variant, Loading pre-trained embeddings

In [127]:
dict_model_arch ={}
for name in list_name:
  
  # INPUT_DIM = len(field_tweet.vocab) # single model
  INPUT_DIM = len(dict_fields[name]['Tweet'][1].vocab)
  # print(INPUT_DIM)

  # PAD_IDX = field_tweet.vocab.stoi[field_tweet.pad_token] # # single model
  PAD_IDX = dict_fields[name]['Tweet'][1].vocab.stoi[dict_fields[name]['Tweet'][1].pad_token]
  # print(PAD_IDX)

  dict_model_arch[name] = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)
  dict_model_arch[name].to(DEVICE)

  # pretrained_embeddings = field_tweet.vocab.vectors # single model
  pretrained_embeddings = dict_fields[name]['Tweet'][1].vocab.vectors

  # model.embedding.weight.data.copy_(pretrained_embeddings) # single model
  # dict_model_arch[name].embedding.weight.data.copy_(pretrained_embeddings)

  ## start for embeddings from text file
  dict_model_arch[name].embedding.from_pretrained(torch.FloatTensor(dict_fields[name]['Tweet'][1].vocab.vectors))
  ## end for embeddings from text file

  # UNK_IDX = field_tweet.vocab.stoi[field_tweet.unk_token] # single model
  UNK_IDX = dict_fields[name]['Tweet'][1].vocab.stoi[dict_fields[name]['Tweet'][1].unk_token]

  # model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM) # single model
  dict_model_arch[name].embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM) 
  
  # model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM) # single model
  dict_model_arch[name].embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM) 

dict_model_arch

{'EI_anger': CNN1d(
   (embedding): Embedding(4689, 100, padding_idx=1)
   (convs): ModuleList(
     (0): Conv1d(100, 100, kernel_size=(2,), stride=(1,))
     (1): Conv1d(100, 100, kernel_size=(3,), stride=(1,))
     (2): Conv1d(100, 100, kernel_size=(4,), stride=(1,))
     (3): Conv1d(100, 100, kernel_size=(5,), stride=(1,))
   )
   (regression): Sequential(
     (0): Dropout(p=0.5, inplace=False)
     (1): Linear(in_features=400, out_features=200, bias=True)
     (2): ReLU()
     (3): Linear(in_features=200, out_features=10, bias=True)
     (4): ReLU()
     (5): Linear(in_features=10, out_features=1, bias=True)
   )
   (domain_classifier): Sequential(
     (0): Dropout(p=0.5, inplace=False)
     (1): Linear(in_features=400, out_features=200, bias=True)
     (2): ReLU()
     (3): Linear(in_features=200, out_features=10, bias=True)
     (4): ReLU()
     (5): Linear(in_features=10, out_features=2, bias=True)
     (6): LogSoftmax(dim=1)
   )
 ), 'EI_fear': CNN1d(
   (embedding): Embeddin

In [128]:
# model = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)
# model.to(DEVICE)

### Load Pre trained embeddings
we'll load the pre-trained *embeddings*

In [129]:
# pretrained_embeddings = field_tweet.vocab.vectors
# model.embedding.weight.data.copy_(pretrained_embeddings)

In [130]:
# field_tweet.vocab.vectors.shape

In [131]:
# UNK_IDX = field_tweet.vocab.stoi[field_tweet.unk_token]

# model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
# model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

## Training the model

### Without training one forward pass

In [132]:
for name, model_arch in dict_model_arch.items():
  for batch in dict_iterator[name]['train_iterator']:
    print(batch.tweet)
    output = model_arch(batch.tweet)
    print (output)
    break

tensor([[   5,  121, 2281,   15,   54,  133,    8,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1],
        [   5,  689,  321, 2668, 1011, 3433,   11,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1],
        [  36,   32, 1250,   85, 3993,   84,  325,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,   

In [133]:
# import torch.optim as optim

# optimizer = optim.Adam(model.parameters())

# criterion = nn.BCEWithLogitsLoss()

# model = model.to(DEVICE)
# criterion = criterion.to(DEVICE)

### Typical Train Model Function

In [134]:
# Typical Training Function

from tqdm import tqdm # for beautiful model training updates

def train_model(model, device, train_loader, optimizer, epoch):
    model.train() # setting the model in training mode
    pbar = tqdm(train_loader) # putting the iterator in pbar
    correct = 0 # for accuracy numerator
    processed =0 # for accuracy denominator
    epoch_loss = 0.0
    
    # used for EWC
    # fisher_dict_m = {}
    # optpar_dict_m = {}

    for batch_idx, batch in enumerate(pbar):

        tweets, intensities = batch.tweet.to(device), batch.intensity.to(device)  # plural, we are not interested in domain
        #sending data to CPU or GPU as per device

        optimizer.zero_grad() # setting gradients to zero to avoid accumulation

        y_preds,_ = model(tweets) # forward pass, result captured in y_preds (plural as there are many body in a batch)
        # we are not interested in domain prediction
        # the predictions are in one hot vector

        regression_loss = regression_loss_function(y_preds,intensities.unsqueeze(1)) # Computing loss
        # loss = F.mse_loss(y_preds,intensities.unsqueeze(1)) # Computing loss

        # train_regresion_losses.append(regression_loss.item()) # to capture loss over many epochs

        regression_loss.backward() # backpropagation

        # # start support for EWC
        # for param_name, param in model.named_parameters():
        #   optpar_dict_m[param_name] = param.data.clone()
        #   fisher_dict_m[param_name] = param.grad.data.clone().pow(2)
        # # end support for EWC

        optimizer.step() # updating the params

        # preds = y_preds.argmax(dim=1, keepdim=True)  # get the index olf the max log-probability
        # correct += preds.eq(labels.view_as(preds)).sum().item()
        epoch_loss += regression_loss.item()

        processed += len(tweets)

        pbar.set_description(desc= f'Loss={regression_loss.item()} Batch_id={batch_idx} Epoch Average loss={epoch_loss/processed:0.4f}')
    return float("{:.6f}".format(epoch_loss/processed)) #, fisher_dict, optpar_dict
    # return epoch_loss/processed
    # train_accuracy.append(100*epoch_loss/len(train_loader))

### Typical Test Function

In [135]:
def test_model(model,device, data_loader, mode= 'test'):
    model.eval() # setting the model in evaluation mode
    loss = 0
    correct = 0 # for accuracy numerator
    test_regresion_losses =[] # for overall epoch (summed over batches)
    valid_regresion_losses =[] # for overall epoch (summed over batches)

    with torch.no_grad():
        for batch in data_loader:

            tweets, intensities  = batch.tweet.to(device), batch.intensity.to(device) #sending data to CPU or GPU as per device
            # we are not interested in domains
            
            y_preds,_ = model(tweets) # forward pass, result captured in outputs (plural as there are many bodies in a batch)
            # the outputs are in batch size x one hot vector 
            # not interested in domain output

            regression_loss = regression_loss_function(y_preds,intensities.unsqueeze(1))

            if mode == 'test':
              test_regresion_losses.append(regression_loss.item())
              # print(f'...in the batch...{regression_loss}')
            else:
              valid_regresion_losses.append(regression_loss.item())
              # print(f'...in the batch...{regression_loss}')

        # regression_loss.item() /= len(data_loader.dataset) # average test loss
        if mode == 'test':
          # total_test_regression_loss = sum(test_regresion_losses)
          # test_regresion_losses.append(regression_loss) # to capture loss over many batches
          # print('...Average test loss: {:.8f}'.format((total_test_regression_loss)/len(data_loader.dataset)))
          avg_epoch_test_loss = float("{:.6f}".format(sum(test_regresion_losses) / len(data_loader)))
          # print(f'TEST LOSS (Average) : {sum(test_regresion_losses) / len(data_loader)}')
          print(f'TEST LOSS (Average) : {avg_epoch_test_loss}')
          return float(avg_epoch_test_loss)
        else:
          # valid_regresion_losses.append(regression_loss) # to capture loss over many batches
          # total_valid_regression_loss = sum(valid_regresion_losses)
          # print('...Average validation loss: {:.8f}'.format((total_valid_regression_loss)/len(data_loader.dataset)))
          # print(f'VALIDATION LOSS (Average) : {sum(valid_regresion_losses) / len(data_loader)}')
          avg_epoch_valid_loss = float("{:.6f}".format(sum(valid_regresion_losses) / len(data_loader)))
          print(f'VALIDATION LOSS (Average) : {avg_epoch_valid_loss}')
          return float(avg_epoch_valid_loss)

### Execution Non DANN

In [136]:
# EXECUTION (NON DANN) FOR MULTIPLE MODELS
lr = 2e-5
# EPOCHS = 2
# EPOCHS = 100
EPOCHS = NUM_EPOCHS
dict_non_dann_model_saved= {}
dict_non_dann_losses_list = {}

# # start support for EWC
# fisher_dict = {}
# optpar_dict = {}
# ewc_lambda = 0.4
# # end support for EWC


for name, model_arch in dict_model_arch.items():
  model = model_arch
  optimizer = optim.Adam(model.parameters(), lr=lr)
  domain_loss_function= nn.BCEWithLogitsLoss()
  regression_loss_function = nn.L1Loss()
  model = model.to(DEVICE)
  domain_loss_function = domain_loss_function.to(DEVICE)
  regression_loss_function = regression_loss_function.to(DEVICE)
  
  # start support for EWC
  # fisher_dict[name] = {}
  # optpar_dict[name] = {} 
  # end support for EWC

  train_losses = [] # to capture train losses over training epochs
  # train_accuracy = [] # to capture train accuracy over training epochs
  val_losses = [] # to capture validation loss over epochs
  # test_losses = [] # to capture test losses 
  # test_accuracy = [] # to capture test accuracy 

  # dict_val_loss = {}
  # dict_test_loss = {}
  # train_regresion_losses = [] # to capture train losses over training epochs
  # train_domain_losses = []
  # train_accuracy = [] # to capture train accuracy over training epochs
  # valid_regresion_losses = [] # to capture validation loss
  # test_regresion_losses = [] # to capture test losses 
  # total_test_regression_loss =[]
  # total_valid_regression_loss =[]
  print(f'----------------------training started for {name}-----------------')
  for epoch in range(EPOCHS):
    print("EPOCH:", epoch+1)
    # train_model(model, DEVICE, train_iterator, optimizer, epoch) # single model
    avg_epoch_loss= train_model(model, DEVICE, dict_iterator[name]['train_iterator'], optimizer, epoch)
    train_losses.append(avg_epoch_loss)
    # fisher_dict[name] = fisher_dict_m
    # optpar_dict[name] = optpar_dict_m
    
    # start support for EWC
    # for param_name, param in model.named_parameters():
    #   fisher_dict[name][param_name] = param.data.clone()
      # optpar_dict[name][param_name] = param.grad.data.clone().pow(2)
    # end support for EWC

    # test_model(model, DEVICE, valid_iterator, mode = 'val')# single model
    avg_epoch_valid_loss = test_model(model, DEVICE, dict_iterator[name]['val_iterator'], mode = 'val')
    val_losses.append(avg_epoch_valid_loss)
  # testing the model when all epochs are finished (outsied epoch loop)
  # test_model(model, DEVICE, test_iterator, mode = 'test')# single model
  test_loss = test_model(model, DEVICE, dict_iterator[name]['test_iterator'], mode = 'test')

  dict_non_dann_losses_list [name] = {'train_losses' : train_losses, 'val_losses': val_losses, 'test_loss' : test_loss }
  # dict_val_loss[name]  = val_losses
  # dict_test_loss[name] = test_losses




  model_name = name + "_" +str(time.strftime("%d_%m"))+ "_non_dann_"+EMBEDDING_TO_BE_USED+".pt"
  torch.save(model.state_dict(), os.path.join(MODEL_DIR, model_name))
  dict_non_dann_model_saved[name]= model_name
  print(f'----------------------training complete for {name}-----------------')
for name, values in dict_non_dann_losses_list.items():
    # print ("loss for ", name, " \t:\t\t\t", values['train_losses'], values['val_losses'], values['test_loss'])
    print ("test loss for ", name, " \t:\t\t\t",  values['test_loss'])

# print(dict_val_loss.items())
# print(dict_test_loss.items())

----------------------training started for EI_anger-----------------
EPOCH: 1


Loss=0.14446017146110535 Batch_id=212 Epoch Average loss=0.0312: 100%|██████████| 213/213 [00:01<00:00, 132.65it/s]


VALIDATION LOSS (Average) : 0.37218
EPOCH: 2


Loss=0.15165357291698456 Batch_id=212 Epoch Average loss=0.0224: 100%|██████████| 213/213 [00:01<00:00, 130.46it/s]


VALIDATION LOSS (Average) : 0.375485
EPOCH: 3


Loss=0.21485623717308044 Batch_id=212 Epoch Average loss=0.0221: 100%|██████████| 213/213 [00:01<00:00, 133.55it/s]


VALIDATION LOSS (Average) : 0.34133
EPOCH: 4


Loss=0.10235917568206787 Batch_id=212 Epoch Average loss=0.0212: 100%|██████████| 213/213 [00:01<00:00, 131.79it/s]


VALIDATION LOSS (Average) : 0.349688
EPOCH: 5


Loss=0.10656140744686127 Batch_id=212 Epoch Average loss=0.0208: 100%|██████████| 213/213 [00:01<00:00, 129.35it/s]


VALIDATION LOSS (Average) : 0.35503
EPOCH: 6


Loss=0.184834823012352 Batch_id=212 Epoch Average loss=0.0204: 100%|██████████| 213/213 [00:01<00:00, 134.12it/s]


VALIDATION LOSS (Average) : 0.338722
EPOCH: 7


Loss=0.09591052681207657 Batch_id=212 Epoch Average loss=0.0204: 100%|██████████| 213/213 [00:01<00:00, 133.36it/s]


VALIDATION LOSS (Average) : 0.341583
EPOCH: 8


Loss=0.10433168709278107 Batch_id=212 Epoch Average loss=0.0197: 100%|██████████| 213/213 [00:01<00:00, 134.35it/s]


VALIDATION LOSS (Average) : 0.327457
EPOCH: 9


Loss=0.13656632602214813 Batch_id=212 Epoch Average loss=0.0195: 100%|██████████| 213/213 [00:01<00:00, 134.07it/s]


VALIDATION LOSS (Average) : 0.314883
EPOCH: 10


Loss=0.11220662295818329 Batch_id=212 Epoch Average loss=0.0191: 100%|██████████| 213/213 [00:01<00:00, 131.63it/s]


VALIDATION LOSS (Average) : 0.316755
EPOCH: 11


Loss=0.11265118420124054 Batch_id=212 Epoch Average loss=0.0186: 100%|██████████| 213/213 [00:01<00:00, 132.76it/s]


VALIDATION LOSS (Average) : 0.313058
EPOCH: 12


Loss=0.09471019357442856 Batch_id=212 Epoch Average loss=0.0189: 100%|██████████| 213/213 [00:01<00:00, 133.84it/s]


VALIDATION LOSS (Average) : 0.309431
EPOCH: 13


Loss=0.17980635166168213 Batch_id=212 Epoch Average loss=0.0183: 100%|██████████| 213/213 [00:01<00:00, 130.50it/s]


VALIDATION LOSS (Average) : 0.305793
EPOCH: 14


Loss=0.2218531221151352 Batch_id=212 Epoch Average loss=0.0182: 100%|██████████| 213/213 [00:01<00:00, 131.97it/s]


VALIDATION LOSS (Average) : 0.292348
EPOCH: 15


Loss=0.14487916231155396 Batch_id=212 Epoch Average loss=0.0179: 100%|██████████| 213/213 [00:01<00:00, 131.65it/s]


VALIDATION LOSS (Average) : 0.301546
EPOCH: 16


Loss=0.12288036197423935 Batch_id=212 Epoch Average loss=0.0178: 100%|██████████| 213/213 [00:01<00:00, 133.73it/s]


VALIDATION LOSS (Average) : 0.286071
EPOCH: 17


Loss=0.11521485447883606 Batch_id=212 Epoch Average loss=0.0177: 100%|██████████| 213/213 [00:01<00:00, 117.52it/s]


VALIDATION LOSS (Average) : 0.281377
EPOCH: 18


Loss=0.10480823367834091 Batch_id=212 Epoch Average loss=0.0175: 100%|██████████| 213/213 [00:01<00:00, 110.77it/s]


VALIDATION LOSS (Average) : 0.285698
EPOCH: 19


Loss=0.17464883625507355 Batch_id=212 Epoch Average loss=0.0171: 100%|██████████| 213/213 [00:02<00:00, 105.89it/s]


VALIDATION LOSS (Average) : 0.281263
EPOCH: 20


Loss=0.152290940284729 Batch_id=212 Epoch Average loss=0.0168: 100%|██████████| 213/213 [00:01<00:00, 118.75it/s]


VALIDATION LOSS (Average) : 0.273709
EPOCH: 21


Loss=0.12985171377658844 Batch_id=212 Epoch Average loss=0.0165: 100%|██████████| 213/213 [00:01<00:00, 131.16it/s]


VALIDATION LOSS (Average) : 0.272627
EPOCH: 22


Loss=0.11490219831466675 Batch_id=212 Epoch Average loss=0.0166: 100%|██████████| 213/213 [00:01<00:00, 131.20it/s]


VALIDATION LOSS (Average) : 0.264012
EPOCH: 23


Loss=0.15564021468162537 Batch_id=212 Epoch Average loss=0.0162: 100%|██████████| 213/213 [00:01<00:00, 131.87it/s]


VALIDATION LOSS (Average) : 0.264019
EPOCH: 24


Loss=0.1021035760641098 Batch_id=212 Epoch Average loss=0.0161: 100%|██████████| 213/213 [00:01<00:00, 129.93it/s]


VALIDATION LOSS (Average) : 0.269039
EPOCH: 25


Loss=0.13493096828460693 Batch_id=212 Epoch Average loss=0.0160: 100%|██████████| 213/213 [00:01<00:00, 131.31it/s]


VALIDATION LOSS (Average) : 0.26446
EPOCH: 26


Loss=0.14421167969703674 Batch_id=212 Epoch Average loss=0.0161: 100%|██████████| 213/213 [00:01<00:00, 131.11it/s]


VALIDATION LOSS (Average) : 0.259109
EPOCH: 27


Loss=0.09002096951007843 Batch_id=212 Epoch Average loss=0.0159: 100%|██████████| 213/213 [00:01<00:00, 129.38it/s]


VALIDATION LOSS (Average) : 0.257982
EPOCH: 28


Loss=0.06882177293300629 Batch_id=212 Epoch Average loss=0.0154: 100%|██████████| 213/213 [00:01<00:00, 126.51it/s]


VALIDATION LOSS (Average) : 0.248417
EPOCH: 29


Loss=0.10523095726966858 Batch_id=212 Epoch Average loss=0.0153: 100%|██████████| 213/213 [00:01<00:00, 130.35it/s]


VALIDATION LOSS (Average) : 0.245505
EPOCH: 30


Loss=0.13481171429157257 Batch_id=212 Epoch Average loss=0.0150: 100%|██████████| 213/213 [00:01<00:00, 125.79it/s]


VALIDATION LOSS (Average) : 0.259711
EPOCH: 31


Loss=0.1505657136440277 Batch_id=212 Epoch Average loss=0.0148: 100%|██████████| 213/213 [00:01<00:00, 128.83it/s]


VALIDATION LOSS (Average) : 0.251374
EPOCH: 32


Loss=0.1151944100856781 Batch_id=212 Epoch Average loss=0.0145: 100%|██████████| 213/213 [00:01<00:00, 129.83it/s]


VALIDATION LOSS (Average) : 0.244308
EPOCH: 33


Loss=0.1657119244337082 Batch_id=212 Epoch Average loss=0.0145: 100%|██████████| 213/213 [00:01<00:00, 126.98it/s]


VALIDATION LOSS (Average) : 0.237581
EPOCH: 34


Loss=0.18180300295352936 Batch_id=212 Epoch Average loss=0.0145: 100%|██████████| 213/213 [00:01<00:00, 129.75it/s]


VALIDATION LOSS (Average) : 0.238416
EPOCH: 35


Loss=0.17090992629528046 Batch_id=212 Epoch Average loss=0.0144: 100%|██████████| 213/213 [00:01<00:00, 128.05it/s]


VALIDATION LOSS (Average) : 0.250492
EPOCH: 36


Loss=0.12145606428384781 Batch_id=212 Epoch Average loss=0.0141: 100%|██████████| 213/213 [00:01<00:00, 128.15it/s]


VALIDATION LOSS (Average) : 0.243763
EPOCH: 37


Loss=0.13703033328056335 Batch_id=212 Epoch Average loss=0.0143: 100%|██████████| 213/213 [00:01<00:00, 131.50it/s]


VALIDATION LOSS (Average) : 0.246306
EPOCH: 38


Loss=0.06718603521585464 Batch_id=212 Epoch Average loss=0.0139: 100%|██████████| 213/213 [00:01<00:00, 127.31it/s]


VALIDATION LOSS (Average) : 0.234714
EPOCH: 39


Loss=0.08251848071813583 Batch_id=212 Epoch Average loss=0.0137: 100%|██████████| 213/213 [00:01<00:00, 129.67it/s]


VALIDATION LOSS (Average) : 0.231704
EPOCH: 40


Loss=0.1397419422864914 Batch_id=212 Epoch Average loss=0.0139: 100%|██████████| 213/213 [00:01<00:00, 131.06it/s]


VALIDATION LOSS (Average) : 0.230973
EPOCH: 41


Loss=0.1259678602218628 Batch_id=212 Epoch Average loss=0.0135: 100%|██████████| 213/213 [00:01<00:00, 128.39it/s]


VALIDATION LOSS (Average) : 0.22598
EPOCH: 42


Loss=0.10140269249677658 Batch_id=212 Epoch Average loss=0.0134: 100%|██████████| 213/213 [00:01<00:00, 128.65it/s]


VALIDATION LOSS (Average) : 0.230472
EPOCH: 43


Loss=0.15616118907928467 Batch_id=212 Epoch Average loss=0.0133: 100%|██████████| 213/213 [00:01<00:00, 127.64it/s]


VALIDATION LOSS (Average) : 0.226854
EPOCH: 44


Loss=0.10043127834796906 Batch_id=212 Epoch Average loss=0.0132: 100%|██████████| 213/213 [00:01<00:00, 128.02it/s]


VALIDATION LOSS (Average) : 0.232834
EPOCH: 45


Loss=0.07525426149368286 Batch_id=212 Epoch Average loss=0.0131: 100%|██████████| 213/213 [00:01<00:00, 128.26it/s]


VALIDATION LOSS (Average) : 0.243636
EPOCH: 46


Loss=0.1197650283575058 Batch_id=212 Epoch Average loss=0.0132: 100%|██████████| 213/213 [00:01<00:00, 128.38it/s]


VALIDATION LOSS (Average) : 0.247783
EPOCH: 47


Loss=0.08187585324048996 Batch_id=212 Epoch Average loss=0.0129: 100%|██████████| 213/213 [00:01<00:00, 124.42it/s]


VALIDATION LOSS (Average) : 0.239398
EPOCH: 48


Loss=0.06904847919940948 Batch_id=212 Epoch Average loss=0.0130: 100%|██████████| 213/213 [00:01<00:00, 124.55it/s]


VALIDATION LOSS (Average) : 0.236456
EPOCH: 49


Loss=0.09894375503063202 Batch_id=212 Epoch Average loss=0.0128: 100%|██████████| 213/213 [00:01<00:00, 126.96it/s]


VALIDATION LOSS (Average) : 0.235604
EPOCH: 50


Loss=0.07721032202243805 Batch_id=212 Epoch Average loss=0.0132: 100%|██████████| 213/213 [00:01<00:00, 129.33it/s]


VALIDATION LOSS (Average) : 0.243476
EPOCH: 51


Loss=0.13854841887950897 Batch_id=212 Epoch Average loss=0.0128: 100%|██████████| 213/213 [00:01<00:00, 124.93it/s]


VALIDATION LOSS (Average) : 0.23376
EPOCH: 52


Loss=0.12045560032129288 Batch_id=212 Epoch Average loss=0.0127: 100%|██████████| 213/213 [00:01<00:00, 127.99it/s]


VALIDATION LOSS (Average) : 0.224827
EPOCH: 53


Loss=0.17227186262607574 Batch_id=212 Epoch Average loss=0.0127: 100%|██████████| 213/213 [00:01<00:00, 127.28it/s]


VALIDATION LOSS (Average) : 0.229873
EPOCH: 54


Loss=0.0790162980556488 Batch_id=212 Epoch Average loss=0.0122: 100%|██████████| 213/213 [00:01<00:00, 129.24it/s]


VALIDATION LOSS (Average) : 0.232765
EPOCH: 55


Loss=0.14458231627941132 Batch_id=212 Epoch Average loss=0.0123: 100%|██████████| 213/213 [00:01<00:00, 117.17it/s]


VALIDATION LOSS (Average) : 0.23225
EPOCH: 56


Loss=0.11793158948421478 Batch_id=212 Epoch Average loss=0.0121: 100%|██████████| 213/213 [00:01<00:00, 106.51it/s]


VALIDATION LOSS (Average) : 0.234855
EPOCH: 57


Loss=0.08306079357862473 Batch_id=212 Epoch Average loss=0.0120: 100%|██████████| 213/213 [00:01<00:00, 124.89it/s]


VALIDATION LOSS (Average) : 0.247394
EPOCH: 58


Loss=0.07451409101486206 Batch_id=212 Epoch Average loss=0.0120: 100%|██████████| 213/213 [00:01<00:00, 126.25it/s]


VALIDATION LOSS (Average) : 0.245737
EPOCH: 59


Loss=0.12018223106861115 Batch_id=212 Epoch Average loss=0.0120: 100%|██████████| 213/213 [00:01<00:00, 125.32it/s]


VALIDATION LOSS (Average) : 0.248018
EPOCH: 60


Loss=0.08763305842876434 Batch_id=212 Epoch Average loss=0.0120: 100%|██████████| 213/213 [00:01<00:00, 120.92it/s]


VALIDATION LOSS (Average) : 0.243475
EPOCH: 61


Loss=0.054839957505464554 Batch_id=212 Epoch Average loss=0.0117: 100%|██████████| 213/213 [00:01<00:00, 122.40it/s]


VALIDATION LOSS (Average) : 0.233739
EPOCH: 62


Loss=0.12049798667430878 Batch_id=212 Epoch Average loss=0.0117: 100%|██████████| 213/213 [00:01<00:00, 122.58it/s]


VALIDATION LOSS (Average) : 0.237982
EPOCH: 63


Loss=0.121701180934906 Batch_id=212 Epoch Average loss=0.0116: 100%|██████████| 213/213 [00:01<00:00, 123.47it/s]


VALIDATION LOSS (Average) : 0.245246
EPOCH: 64


Loss=0.09506852924823761 Batch_id=212 Epoch Average loss=0.0114: 100%|██████████| 213/213 [00:01<00:00, 120.18it/s]


VALIDATION LOSS (Average) : 0.24721
EPOCH: 65


Loss=0.10582049190998077 Batch_id=212 Epoch Average loss=0.0116: 100%|██████████| 213/213 [00:01<00:00, 125.37it/s]


VALIDATION LOSS (Average) : 0.250547
EPOCH: 66


Loss=0.06157268211245537 Batch_id=212 Epoch Average loss=0.0114: 100%|██████████| 213/213 [00:01<00:00, 125.63it/s]


VALIDATION LOSS (Average) : 0.253968
EPOCH: 67


Loss=0.07903517782688141 Batch_id=212 Epoch Average loss=0.0115: 100%|██████████| 213/213 [00:01<00:00, 123.06it/s]


VALIDATION LOSS (Average) : 0.255055
EPOCH: 68


Loss=0.1293107271194458 Batch_id=212 Epoch Average loss=0.0112: 100%|██████████| 213/213 [00:01<00:00, 122.11it/s]


VALIDATION LOSS (Average) : 0.251767
EPOCH: 69


Loss=0.11140713095664978 Batch_id=212 Epoch Average loss=0.0112: 100%|██████████| 213/213 [00:01<00:00, 124.61it/s]


VALIDATION LOSS (Average) : 0.237984
EPOCH: 70


Loss=0.08093496412038803 Batch_id=212 Epoch Average loss=0.0112: 100%|██████████| 213/213 [00:01<00:00, 121.88it/s]


VALIDATION LOSS (Average) : 0.241205
EPOCH: 71


Loss=0.06952042877674103 Batch_id=212 Epoch Average loss=0.0110: 100%|██████████| 213/213 [00:01<00:00, 121.65it/s]


VALIDATION LOSS (Average) : 0.241412
EPOCH: 72


Loss=0.08046026527881622 Batch_id=212 Epoch Average loss=0.0114: 100%|██████████| 213/213 [00:01<00:00, 124.85it/s]


VALIDATION LOSS (Average) : 0.246668
EPOCH: 73


Loss=0.1025586947798729 Batch_id=212 Epoch Average loss=0.0113: 100%|██████████| 213/213 [00:01<00:00, 123.96it/s]


VALIDATION LOSS (Average) : 0.253438
EPOCH: 74


Loss=0.06815474480390549 Batch_id=212 Epoch Average loss=0.0109: 100%|██████████| 213/213 [00:01<00:00, 123.89it/s]


VALIDATION LOSS (Average) : 0.247113
EPOCH: 75


Loss=0.09659822285175323 Batch_id=212 Epoch Average loss=0.0107: 100%|██████████| 213/213 [00:01<00:00, 124.90it/s]


VALIDATION LOSS (Average) : 0.242576
EPOCH: 76


Loss=0.055993109941482544 Batch_id=212 Epoch Average loss=0.0107: 100%|██████████| 213/213 [00:01<00:00, 124.02it/s]


VALIDATION LOSS (Average) : 0.247984
EPOCH: 77


Loss=0.07746297121047974 Batch_id=212 Epoch Average loss=0.0108: 100%|██████████| 213/213 [00:01<00:00, 123.89it/s]


VALIDATION LOSS (Average) : 0.243175
EPOCH: 78


Loss=0.05202851444482803 Batch_id=212 Epoch Average loss=0.0108: 100%|██████████| 213/213 [00:01<00:00, 124.93it/s]


VALIDATION LOSS (Average) : 0.245878
EPOCH: 79


Loss=0.05613350868225098 Batch_id=212 Epoch Average loss=0.0107: 100%|██████████| 213/213 [00:01<00:00, 128.42it/s]


VALIDATION LOSS (Average) : 0.253919
EPOCH: 80


Loss=0.06478865444660187 Batch_id=212 Epoch Average loss=0.0104: 100%|██████████| 213/213 [00:01<00:00, 123.56it/s]


VALIDATION LOSS (Average) : 0.250316
EPOCH: 81


Loss=0.08085377514362335 Batch_id=212 Epoch Average loss=0.0106: 100%|██████████| 213/213 [00:01<00:00, 112.74it/s]


VALIDATION LOSS (Average) : 0.24595
EPOCH: 82


Loss=0.04960685968399048 Batch_id=212 Epoch Average loss=0.0106: 100%|██████████| 213/213 [00:01<00:00, 113.51it/s]


VALIDATION LOSS (Average) : 0.242466
EPOCH: 83


Loss=0.12935464084148407 Batch_id=212 Epoch Average loss=0.0107: 100%|██████████| 213/213 [00:01<00:00, 121.34it/s]


VALIDATION LOSS (Average) : 0.256158
EPOCH: 84


Loss=0.09500393271446228 Batch_id=212 Epoch Average loss=0.0103: 100%|██████████| 213/213 [00:01<00:00, 111.04it/s]


VALIDATION LOSS (Average) : 0.25454
EPOCH: 85


Loss=0.08488398790359497 Batch_id=212 Epoch Average loss=0.0104: 100%|██████████| 213/213 [00:01<00:00, 115.76it/s]


VALIDATION LOSS (Average) : 0.245191
EPOCH: 86


Loss=0.02821292355656624 Batch_id=212 Epoch Average loss=0.0103: 100%|██████████| 213/213 [00:01<00:00, 122.97it/s]


VALIDATION LOSS (Average) : 0.249075
EPOCH: 87


Loss=0.07359158992767334 Batch_id=212 Epoch Average loss=0.0102: 100%|██████████| 213/213 [00:01<00:00, 124.85it/s]


VALIDATION LOSS (Average) : 0.253378
EPOCH: 88


Loss=0.04243045300245285 Batch_id=212 Epoch Average loss=0.0102: 100%|██████████| 213/213 [00:01<00:00, 123.04it/s]


VALIDATION LOSS (Average) : 0.246965
EPOCH: 89


Loss=0.0783330574631691 Batch_id=212 Epoch Average loss=0.0100: 100%|██████████| 213/213 [00:01<00:00, 122.90it/s]


VALIDATION LOSS (Average) : 0.246129
EPOCH: 90


Loss=0.05688399448990822 Batch_id=212 Epoch Average loss=0.0102: 100%|██████████| 213/213 [00:01<00:00, 121.57it/s]


VALIDATION LOSS (Average) : 0.246811
EPOCH: 91


Loss=0.07282164692878723 Batch_id=212 Epoch Average loss=0.0102: 100%|██████████| 213/213 [00:01<00:00, 124.10it/s]


VALIDATION LOSS (Average) : 0.2566
EPOCH: 92


Loss=0.10287578403949738 Batch_id=212 Epoch Average loss=0.0101: 100%|██████████| 213/213 [00:01<00:00, 122.71it/s]


VALIDATION LOSS (Average) : 0.250486
EPOCH: 93


Loss=0.09415309131145477 Batch_id=212 Epoch Average loss=0.0099: 100%|██████████| 213/213 [00:01<00:00, 123.29it/s]


VALIDATION LOSS (Average) : 0.24402
EPOCH: 94


Loss=0.0609150156378746 Batch_id=212 Epoch Average loss=0.0100: 100%|██████████| 213/213 [00:01<00:00, 121.03it/s]


VALIDATION LOSS (Average) : 0.247586
EPOCH: 95


Loss=0.10843852162361145 Batch_id=212 Epoch Average loss=0.0098: 100%|██████████| 213/213 [00:01<00:00, 119.83it/s]


VALIDATION LOSS (Average) : 0.250342
EPOCH: 96


Loss=0.07633993029594421 Batch_id=212 Epoch Average loss=0.0097: 100%|██████████| 213/213 [00:01<00:00, 123.32it/s]


VALIDATION LOSS (Average) : 0.244849
EPOCH: 97


Loss=0.0503949299454689 Batch_id=212 Epoch Average loss=0.0099: 100%|██████████| 213/213 [00:01<00:00, 121.12it/s]


VALIDATION LOSS (Average) : 0.253772
EPOCH: 98


Loss=0.06152310222387314 Batch_id=212 Epoch Average loss=0.0097: 100%|██████████| 213/213 [00:01<00:00, 121.02it/s]


VALIDATION LOSS (Average) : 0.252032
EPOCH: 99


Loss=0.0703776404261589 Batch_id=212 Epoch Average loss=0.0095: 100%|██████████| 213/213 [00:01<00:00, 122.43it/s]


VALIDATION LOSS (Average) : 0.248448
EPOCH: 100


Loss=0.05974984169006348 Batch_id=212 Epoch Average loss=0.0097: 100%|██████████| 213/213 [00:01<00:00, 124.06it/s]


VALIDATION LOSS (Average) : 0.246375
TEST LOSS (Average) : 0.158792
----------------------training complete for EI_anger-----------------
----------------------training started for EI_fear-----------------
EPOCH: 1


Loss=0.18557964265346527 Batch_id=281 Epoch Average loss=0.0324: 100%|██████████| 282/282 [00:02<00:00, 123.00it/s]


VALIDATION LOSS (Average) : 0.298082
EPOCH: 2


Loss=0.15874797105789185 Batch_id=281 Epoch Average loss=0.0234: 100%|██████████| 282/282 [00:02<00:00, 123.53it/s]


VALIDATION LOSS (Average) : 0.277289
EPOCH: 3


Loss=0.1819995790719986 Batch_id=281 Epoch Average loss=0.0233: 100%|██████████| 282/282 [00:02<00:00, 120.70it/s]


VALIDATION LOSS (Average) : 0.283274
EPOCH: 4


Loss=0.17389711737632751 Batch_id=281 Epoch Average loss=0.0228: 100%|██████████| 282/282 [00:02<00:00, 121.56it/s]


VALIDATION LOSS (Average) : 0.26088
EPOCH: 5


Loss=0.15040656924247742 Batch_id=281 Epoch Average loss=0.0224: 100%|██████████| 282/282 [00:02<00:00, 120.51it/s]


VALIDATION LOSS (Average) : 0.256505
EPOCH: 6


Loss=0.2072889804840088 Batch_id=281 Epoch Average loss=0.0221: 100%|██████████| 282/282 [00:02<00:00, 123.84it/s]


VALIDATION LOSS (Average) : 0.245408
EPOCH: 7


Loss=0.13814795017242432 Batch_id=281 Epoch Average loss=0.0218: 100%|██████████| 282/282 [00:02<00:00, 120.99it/s]


VALIDATION LOSS (Average) : 0.2376
EPOCH: 8


Loss=0.20751145482063293 Batch_id=281 Epoch Average loss=0.0217: 100%|██████████| 282/282 [00:02<00:00, 107.71it/s]


VALIDATION LOSS (Average) : 0.22759
EPOCH: 9


Loss=0.1646861732006073 Batch_id=281 Epoch Average loss=0.0211: 100%|██████████| 282/282 [00:03<00:00, 93.94it/s] 


VALIDATION LOSS (Average) : 0.230497
EPOCH: 10


Loss=0.1761854588985443 Batch_id=281 Epoch Average loss=0.0208: 100%|██████████| 282/282 [00:02<00:00, 122.64it/s]


VALIDATION LOSS (Average) : 0.225196
EPOCH: 11


Loss=0.19342368841171265 Batch_id=281 Epoch Average loss=0.0209: 100%|██████████| 282/282 [00:02<00:00, 120.12it/s]


VALIDATION LOSS (Average) : 0.216993
EPOCH: 12


Loss=0.21605102717876434 Batch_id=281 Epoch Average loss=0.0208: 100%|██████████| 282/282 [00:02<00:00, 120.09it/s]


VALIDATION LOSS (Average) : 0.217953
EPOCH: 13


Loss=0.20775297284126282 Batch_id=281 Epoch Average loss=0.0204: 100%|██████████| 282/282 [00:02<00:00, 122.15it/s]


VALIDATION LOSS (Average) : 0.225601
EPOCH: 14


Loss=0.22209489345550537 Batch_id=281 Epoch Average loss=0.0202: 100%|██████████| 282/282 [00:02<00:00, 122.69it/s]


VALIDATION LOSS (Average) : 0.208592
EPOCH: 15


Loss=0.20669308304786682 Batch_id=281 Epoch Average loss=0.0199: 100%|██████████| 282/282 [00:02<00:00, 123.02it/s]


VALIDATION LOSS (Average) : 0.22745
EPOCH: 16


Loss=0.14774766564369202 Batch_id=281 Epoch Average loss=0.0194: 100%|██████████| 282/282 [00:02<00:00, 120.53it/s]


VALIDATION LOSS (Average) : 0.212591
EPOCH: 17


Loss=0.19255883991718292 Batch_id=281 Epoch Average loss=0.0193: 100%|██████████| 282/282 [00:02<00:00, 121.43it/s]


VALIDATION LOSS (Average) : 0.20802
EPOCH: 18


Loss=0.11640098690986633 Batch_id=281 Epoch Average loss=0.0191: 100%|██████████| 282/282 [00:02<00:00, 121.45it/s]


VALIDATION LOSS (Average) : 0.206349
EPOCH: 19


Loss=0.171068474650383 Batch_id=281 Epoch Average loss=0.0188: 100%|██████████| 282/282 [00:02<00:00, 120.86it/s]


VALIDATION LOSS (Average) : 0.199146
EPOCH: 20


Loss=0.18586799502372742 Batch_id=281 Epoch Average loss=0.0188: 100%|██████████| 282/282 [00:02<00:00, 120.72it/s]


VALIDATION LOSS (Average) : 0.189725
EPOCH: 21


Loss=0.12837287783622742 Batch_id=281 Epoch Average loss=0.0185: 100%|██████████| 282/282 [00:02<00:00, 119.73it/s]


VALIDATION LOSS (Average) : 0.184553
EPOCH: 22


Loss=0.1264829784631729 Batch_id=281 Epoch Average loss=0.0182: 100%|██████████| 282/282 [00:02<00:00, 119.62it/s]


VALIDATION LOSS (Average) : 0.192234
EPOCH: 23


Loss=0.20749469101428986 Batch_id=281 Epoch Average loss=0.0180: 100%|██████████| 282/282 [00:02<00:00, 121.79it/s]


VALIDATION LOSS (Average) : 0.204646
EPOCH: 24


Loss=0.13030096888542175 Batch_id=281 Epoch Average loss=0.0179: 100%|██████████| 282/282 [00:02<00:00, 120.35it/s]


VALIDATION LOSS (Average) : 0.204639
EPOCH: 25


Loss=0.1212918609380722 Batch_id=281 Epoch Average loss=0.0178: 100%|██████████| 282/282 [00:02<00:00, 121.35it/s]


VALIDATION LOSS (Average) : 0.192649
EPOCH: 26


Loss=0.08358285576105118 Batch_id=281 Epoch Average loss=0.0173: 100%|██████████| 282/282 [00:02<00:00, 117.08it/s]


VALIDATION LOSS (Average) : 0.184224
EPOCH: 27


Loss=0.15556466579437256 Batch_id=281 Epoch Average loss=0.0168: 100%|██████████| 282/282 [00:02<00:00, 120.29it/s]


VALIDATION LOSS (Average) : 0.190566
EPOCH: 28


Loss=0.16993220150470734 Batch_id=281 Epoch Average loss=0.0171: 100%|██████████| 282/282 [00:02<00:00, 121.10it/s]


VALIDATION LOSS (Average) : 0.189729
EPOCH: 29


Loss=0.1669681817293167 Batch_id=281 Epoch Average loss=0.0170: 100%|██████████| 282/282 [00:02<00:00, 120.64it/s]


VALIDATION LOSS (Average) : 0.188555
EPOCH: 30


Loss=0.13413068652153015 Batch_id=281 Epoch Average loss=0.0166: 100%|██████████| 282/282 [00:02<00:00, 118.74it/s]


VALIDATION LOSS (Average) : 0.176593
EPOCH: 31


Loss=0.14821194112300873 Batch_id=281 Epoch Average loss=0.0161: 100%|██████████| 282/282 [00:02<00:00, 119.42it/s]


VALIDATION LOSS (Average) : 0.179726
EPOCH: 32


Loss=0.1221856102347374 Batch_id=281 Epoch Average loss=0.0161: 100%|██████████| 282/282 [00:02<00:00, 115.41it/s]


VALIDATION LOSS (Average) : 0.178807
EPOCH: 33


Loss=0.19789564609527588 Batch_id=281 Epoch Average loss=0.0161: 100%|██████████| 282/282 [00:02<00:00, 111.81it/s]


VALIDATION LOSS (Average) : 0.178325
EPOCH: 34


Loss=0.13752451539039612 Batch_id=281 Epoch Average loss=0.0161: 100%|██████████| 282/282 [00:02<00:00, 114.76it/s]


VALIDATION LOSS (Average) : 0.164412
EPOCH: 35


Loss=0.11172246187925339 Batch_id=281 Epoch Average loss=0.0161: 100%|██████████| 282/282 [00:02<00:00, 103.00it/s]


VALIDATION LOSS (Average) : 0.16811
EPOCH: 36


Loss=0.08609116822481155 Batch_id=281 Epoch Average loss=0.0154: 100%|██████████| 282/282 [00:02<00:00, 116.95it/s]


VALIDATION LOSS (Average) : 0.162909
EPOCH: 37


Loss=0.11458835005760193 Batch_id=281 Epoch Average loss=0.0156: 100%|██████████| 282/282 [00:02<00:00, 117.04it/s]


VALIDATION LOSS (Average) : 0.153757
EPOCH: 38


Loss=0.09677494317293167 Batch_id=281 Epoch Average loss=0.0158: 100%|██████████| 282/282 [00:02<00:00, 119.85it/s]


VALIDATION LOSS (Average) : 0.165404
EPOCH: 39


Loss=0.16797491908073425 Batch_id=281 Epoch Average loss=0.0155: 100%|██████████| 282/282 [00:02<00:00, 118.38it/s]


VALIDATION LOSS (Average) : 0.160762
EPOCH: 40


Loss=0.17939503490924835 Batch_id=281 Epoch Average loss=0.0150: 100%|██████████| 282/282 [00:02<00:00, 121.49it/s]


VALIDATION LOSS (Average) : 0.157085
EPOCH: 41


Loss=0.13807755708694458 Batch_id=281 Epoch Average loss=0.0150: 100%|██████████| 282/282 [00:02<00:00, 119.65it/s]


VALIDATION LOSS (Average) : 0.154023
EPOCH: 42


Loss=0.11998063325881958 Batch_id=281 Epoch Average loss=0.0148: 100%|██████████| 282/282 [00:02<00:00, 118.96it/s]


VALIDATION LOSS (Average) : 0.158549
EPOCH: 43


Loss=0.1394682228565216 Batch_id=281 Epoch Average loss=0.0145: 100%|██████████| 282/282 [00:02<00:00, 120.31it/s]


VALIDATION LOSS (Average) : 0.159844
EPOCH: 44


Loss=0.10843376815319061 Batch_id=281 Epoch Average loss=0.0146: 100%|██████████| 282/282 [00:02<00:00, 118.71it/s]


VALIDATION LOSS (Average) : 0.159537
EPOCH: 45


Loss=0.1299315094947815 Batch_id=281 Epoch Average loss=0.0147: 100%|██████████| 282/282 [00:02<00:00, 117.27it/s]


VALIDATION LOSS (Average) : 0.147581
EPOCH: 46


Loss=0.16351431608200073 Batch_id=281 Epoch Average loss=0.0143: 100%|██████████| 282/282 [00:02<00:00, 117.90it/s]


VALIDATION LOSS (Average) : 0.15885
EPOCH: 47


Loss=0.15209028124809265 Batch_id=281 Epoch Average loss=0.0145: 100%|██████████| 282/282 [00:02<00:00, 120.03it/s]


VALIDATION LOSS (Average) : 0.149795
EPOCH: 48


Loss=0.1170397624373436 Batch_id=281 Epoch Average loss=0.0143: 100%|██████████| 282/282 [00:02<00:00, 118.11it/s]


VALIDATION LOSS (Average) : 0.163491
EPOCH: 49


Loss=0.07471536099910736 Batch_id=281 Epoch Average loss=0.0141: 100%|██████████| 282/282 [00:02<00:00, 117.80it/s]


VALIDATION LOSS (Average) : 0.167639
EPOCH: 50


Loss=0.13616007566452026 Batch_id=281 Epoch Average loss=0.0138: 100%|██████████| 282/282 [00:02<00:00, 119.34it/s]


VALIDATION LOSS (Average) : 0.161547
EPOCH: 51


Loss=0.1486130952835083 Batch_id=281 Epoch Average loss=0.0142: 100%|██████████| 282/282 [00:02<00:00, 121.32it/s]


VALIDATION LOSS (Average) : 0.161263
EPOCH: 52


Loss=0.0946682021021843 Batch_id=281 Epoch Average loss=0.0137: 100%|██████████| 282/282 [00:02<00:00, 120.73it/s]


VALIDATION LOSS (Average) : 0.160139
EPOCH: 53


Loss=0.09659033268690109 Batch_id=281 Epoch Average loss=0.0139: 100%|██████████| 282/282 [00:02<00:00, 119.41it/s]


VALIDATION LOSS (Average) : 0.15536
EPOCH: 54


Loss=0.08400662243366241 Batch_id=281 Epoch Average loss=0.0139: 100%|██████████| 282/282 [00:02<00:00, 119.26it/s]


VALIDATION LOSS (Average) : 0.149635
EPOCH: 55


Loss=0.13943341374397278 Batch_id=281 Epoch Average loss=0.0138: 100%|██████████| 282/282 [00:02<00:00, 120.72it/s]


VALIDATION LOSS (Average) : 0.166957
EPOCH: 56


Loss=0.12887518107891083 Batch_id=281 Epoch Average loss=0.0132: 100%|██████████| 282/282 [00:02<00:00, 120.21it/s]


VALIDATION LOSS (Average) : 0.151513
EPOCH: 57


Loss=0.1408393830060959 Batch_id=281 Epoch Average loss=0.0135: 100%|██████████| 282/282 [00:02<00:00, 117.09it/s]


VALIDATION LOSS (Average) : 0.150645
EPOCH: 58


Loss=0.11726800352334976 Batch_id=281 Epoch Average loss=0.0132: 100%|██████████| 282/282 [00:02<00:00, 119.73it/s]


VALIDATION LOSS (Average) : 0.15895
EPOCH: 59


Loss=0.05654486268758774 Batch_id=281 Epoch Average loss=0.0133: 100%|██████████| 282/282 [00:02<00:00, 119.67it/s]


VALIDATION LOSS (Average) : 0.156222
EPOCH: 60


Loss=0.10805951058864594 Batch_id=281 Epoch Average loss=0.0132: 100%|██████████| 282/282 [00:02<00:00, 119.87it/s]


VALIDATION LOSS (Average) : 0.148832
EPOCH: 61


Loss=0.13177722692489624 Batch_id=281 Epoch Average loss=0.0133: 100%|██████████| 282/282 [00:02<00:00, 121.87it/s]


VALIDATION LOSS (Average) : 0.137266
EPOCH: 62


Loss=0.15018488466739655 Batch_id=281 Epoch Average loss=0.0130: 100%|██████████| 282/282 [00:02<00:00, 116.25it/s]


VALIDATION LOSS (Average) : 0.143436
EPOCH: 63


Loss=0.08894417434930801 Batch_id=281 Epoch Average loss=0.0129: 100%|██████████| 282/282 [00:02<00:00, 118.15it/s]


VALIDATION LOSS (Average) : 0.151439
EPOCH: 64


Loss=0.0720006674528122 Batch_id=281 Epoch Average loss=0.0128: 100%|██████████| 282/282 [00:02<00:00, 118.87it/s]


VALIDATION LOSS (Average) : 0.148057
EPOCH: 65


Loss=0.12316878139972687 Batch_id=281 Epoch Average loss=0.0127: 100%|██████████| 282/282 [00:02<00:00, 115.26it/s]


VALIDATION LOSS (Average) : 0.152464
EPOCH: 66


Loss=0.09350700676441193 Batch_id=281 Epoch Average loss=0.0127: 100%|██████████| 282/282 [00:02<00:00, 116.01it/s]


VALIDATION LOSS (Average) : 0.156927
EPOCH: 67


Loss=0.09627734124660492 Batch_id=281 Epoch Average loss=0.0129: 100%|██████████| 282/282 [00:02<00:00, 119.63it/s]


VALIDATION LOSS (Average) : 0.144932
EPOCH: 68


Loss=0.09052389115095139 Batch_id=281 Epoch Average loss=0.0126: 100%|██████████| 282/282 [00:02<00:00, 118.75it/s]


VALIDATION LOSS (Average) : 0.156629
EPOCH: 69


Loss=0.0828281044960022 Batch_id=281 Epoch Average loss=0.0126: 100%|██████████| 282/282 [00:02<00:00, 119.67it/s]


VALIDATION LOSS (Average) : 0.151735
EPOCH: 70


Loss=0.10950573533773422 Batch_id=281 Epoch Average loss=0.0124: 100%|██████████| 282/282 [00:02<00:00, 118.14it/s]


VALIDATION LOSS (Average) : 0.150156
EPOCH: 71


Loss=0.06596627831459045 Batch_id=281 Epoch Average loss=0.0125: 100%|██████████| 282/282 [00:02<00:00, 116.40it/s]


VALIDATION LOSS (Average) : 0.150946
EPOCH: 72


Loss=0.06464093178510666 Batch_id=281 Epoch Average loss=0.0121: 100%|██████████| 282/282 [00:02<00:00, 118.51it/s]


VALIDATION LOSS (Average) : 0.15395
EPOCH: 73


Loss=0.13567593693733215 Batch_id=281 Epoch Average loss=0.0124: 100%|██████████| 282/282 [00:02<00:00, 118.60it/s]


VALIDATION LOSS (Average) : 0.150596
EPOCH: 74


Loss=0.08573513478040695 Batch_id=281 Epoch Average loss=0.0121: 100%|██████████| 282/282 [00:02<00:00, 116.92it/s]


VALIDATION LOSS (Average) : 0.145498
EPOCH: 75


Loss=0.24373066425323486 Batch_id=281 Epoch Average loss=0.0123: 100%|██████████| 282/282 [00:02<00:00, 115.86it/s]


VALIDATION LOSS (Average) : 0.145807
EPOCH: 76


Loss=0.09281758219003677 Batch_id=281 Epoch Average loss=0.0121: 100%|██████████| 282/282 [00:02<00:00, 117.48it/s]


VALIDATION LOSS (Average) : 0.149201
EPOCH: 77


Loss=0.07230603694915771 Batch_id=281 Epoch Average loss=0.0119: 100%|██████████| 282/282 [00:02<00:00, 118.28it/s]


VALIDATION LOSS (Average) : 0.146756
EPOCH: 78


Loss=0.12400104105472565 Batch_id=281 Epoch Average loss=0.0121: 100%|██████████| 282/282 [00:02<00:00, 114.57it/s]


VALIDATION LOSS (Average) : 0.148953
EPOCH: 79


Loss=0.0804964080452919 Batch_id=281 Epoch Average loss=0.0118: 100%|██████████| 282/282 [00:02<00:00, 118.63it/s]


VALIDATION LOSS (Average) : 0.145506
EPOCH: 80


Loss=0.1240512952208519 Batch_id=281 Epoch Average loss=0.0118: 100%|██████████| 282/282 [00:02<00:00, 113.34it/s]


VALIDATION LOSS (Average) : 0.141913
EPOCH: 81


Loss=0.056522417813539505 Batch_id=281 Epoch Average loss=0.0120: 100%|██████████| 282/282 [00:02<00:00, 104.06it/s]


VALIDATION LOSS (Average) : 0.149625
EPOCH: 82


Loss=0.105085089802742 Batch_id=281 Epoch Average loss=0.0118: 100%|██████████| 282/282 [00:02<00:00, 109.50it/s]


VALIDATION LOSS (Average) : 0.15123
EPOCH: 83


Loss=0.1069638580083847 Batch_id=281 Epoch Average loss=0.0117: 100%|██████████| 282/282 [00:02<00:00, 116.22it/s]


VALIDATION LOSS (Average) : 0.146197
EPOCH: 84


Loss=0.08213482797145844 Batch_id=281 Epoch Average loss=0.0116: 100%|██████████| 282/282 [00:02<00:00, 117.89it/s]


VALIDATION LOSS (Average) : 0.155131
EPOCH: 85


Loss=0.06931069493293762 Batch_id=281 Epoch Average loss=0.0114: 100%|██████████| 282/282 [00:02<00:00, 117.42it/s]


VALIDATION LOSS (Average) : 0.148064
EPOCH: 86


Loss=0.11760568618774414 Batch_id=281 Epoch Average loss=0.0115: 100%|██████████| 282/282 [00:02<00:00, 118.29it/s]


VALIDATION LOSS (Average) : 0.153293
EPOCH: 87


Loss=0.07183848321437836 Batch_id=281 Epoch Average loss=0.0115: 100%|██████████| 282/282 [00:02<00:00, 117.01it/s]


VALIDATION LOSS (Average) : 0.147282
EPOCH: 88


Loss=0.04798547178506851 Batch_id=281 Epoch Average loss=0.0114: 100%|██████████| 282/282 [00:02<00:00, 115.86it/s]


VALIDATION LOSS (Average) : 0.152621
EPOCH: 89


Loss=0.09114647656679153 Batch_id=281 Epoch Average loss=0.0111: 100%|██████████| 282/282 [00:02<00:00, 118.83it/s]


VALIDATION LOSS (Average) : 0.146502
EPOCH: 90


Loss=0.14713358879089355 Batch_id=281 Epoch Average loss=0.0111: 100%|██████████| 282/282 [00:02<00:00, 115.99it/s]


VALIDATION LOSS (Average) : 0.155081
EPOCH: 91


Loss=0.04823409393429756 Batch_id=281 Epoch Average loss=0.0113: 100%|██████████| 282/282 [00:02<00:00, 118.37it/s]


VALIDATION LOSS (Average) : 0.152302
EPOCH: 92


Loss=0.07152917981147766 Batch_id=281 Epoch Average loss=0.0114: 100%|██████████| 282/282 [00:02<00:00, 118.25it/s]


VALIDATION LOSS (Average) : 0.150692
EPOCH: 93


Loss=0.10175637155771255 Batch_id=281 Epoch Average loss=0.0111: 100%|██████████| 282/282 [00:02<00:00, 115.90it/s]


VALIDATION LOSS (Average) : 0.144404
EPOCH: 94


Loss=0.05288875475525856 Batch_id=281 Epoch Average loss=0.0110: 100%|██████████| 282/282 [00:02<00:00, 118.18it/s]


VALIDATION LOSS (Average) : 0.151179
EPOCH: 95


Loss=0.10458128154277802 Batch_id=281 Epoch Average loss=0.0112: 100%|██████████| 282/282 [00:02<00:00, 116.60it/s]


VALIDATION LOSS (Average) : 0.148243
EPOCH: 96


Loss=0.10047220438718796 Batch_id=281 Epoch Average loss=0.0110: 100%|██████████| 282/282 [00:02<00:00, 114.39it/s]


VALIDATION LOSS (Average) : 0.148929
EPOCH: 97


Loss=0.07405298948287964 Batch_id=281 Epoch Average loss=0.0111: 100%|██████████| 282/282 [00:02<00:00, 115.27it/s]


VALIDATION LOSS (Average) : 0.15718
EPOCH: 98


Loss=0.025773711502552032 Batch_id=281 Epoch Average loss=0.0112: 100%|██████████| 282/282 [00:02<00:00, 117.48it/s]


VALIDATION LOSS (Average) : 0.154967
EPOCH: 99


Loss=0.14373788237571716 Batch_id=281 Epoch Average loss=0.0108: 100%|██████████| 282/282 [00:02<00:00, 115.14it/s]


VALIDATION LOSS (Average) : 0.156381
EPOCH: 100


Loss=0.09004268050193787 Batch_id=281 Epoch Average loss=0.0109: 100%|██████████| 282/282 [00:02<00:00, 117.57it/s]


VALIDATION LOSS (Average) : 0.151128
TEST LOSS (Average) : 0.111371
----------------------training complete for EI_fear-----------------
----------------------training started for EI_sadness-----------------
EPOCH: 1


Loss=0.14735163748264313 Batch_id=191 Epoch Average loss=0.0243: 100%|██████████| 192/192 [00:01<00:00, 113.18it/s]


VALIDATION LOSS (Average) : 0.250505
EPOCH: 2


Loss=0.16504526138305664 Batch_id=191 Epoch Average loss=0.0240: 100%|██████████| 192/192 [00:01<00:00, 114.76it/s]


VALIDATION LOSS (Average) : 0.257025
EPOCH: 3


Loss=0.19069820642471313 Batch_id=191 Epoch Average loss=0.0232: 100%|██████████| 192/192 [00:01<00:00, 113.94it/s]


VALIDATION LOSS (Average) : 0.245911
EPOCH: 4


Loss=0.20511090755462646 Batch_id=191 Epoch Average loss=0.0229: 100%|██████████| 192/192 [00:01<00:00, 112.85it/s]


VALIDATION LOSS (Average) : 0.257461
EPOCH: 5


Loss=0.14757290482521057 Batch_id=191 Epoch Average loss=0.0224: 100%|██████████| 192/192 [00:01<00:00, 115.42it/s]


VALIDATION LOSS (Average) : 0.228498
EPOCH: 6


Loss=0.2189410924911499 Batch_id=191 Epoch Average loss=0.0221: 100%|██████████| 192/192 [00:01<00:00, 114.50it/s]


VALIDATION LOSS (Average) : 0.242436
EPOCH: 7


Loss=0.15517543256282806 Batch_id=191 Epoch Average loss=0.0218: 100%|██████████| 192/192 [00:01<00:00, 113.03it/s]


VALIDATION LOSS (Average) : 0.255462
EPOCH: 8


Loss=0.20520958304405212 Batch_id=191 Epoch Average loss=0.0215: 100%|██████████| 192/192 [00:01<00:00, 113.89it/s]


VALIDATION LOSS (Average) : 0.243276
EPOCH: 9


Loss=0.109791100025177 Batch_id=191 Epoch Average loss=0.0211: 100%|██████████| 192/192 [00:01<00:00, 114.65it/s]


VALIDATION LOSS (Average) : 0.234107
EPOCH: 10


Loss=0.1427367478609085 Batch_id=191 Epoch Average loss=0.0211: 100%|██████████| 192/192 [00:01<00:00, 114.14it/s]


VALIDATION LOSS (Average) : 0.249987
EPOCH: 11


Loss=0.16442379355430603 Batch_id=191 Epoch Average loss=0.0208: 100%|██████████| 192/192 [00:01<00:00, 119.35it/s]


VALIDATION LOSS (Average) : 0.232557
EPOCH: 12


Loss=0.19030143320560455 Batch_id=191 Epoch Average loss=0.0209: 100%|██████████| 192/192 [00:01<00:00, 115.66it/s]


VALIDATION LOSS (Average) : 0.240037
EPOCH: 13


Loss=0.18150293827056885 Batch_id=191 Epoch Average loss=0.0206: 100%|██████████| 192/192 [00:01<00:00, 116.00it/s]


VALIDATION LOSS (Average) : 0.229855
EPOCH: 14


Loss=0.17113877832889557 Batch_id=191 Epoch Average loss=0.0202: 100%|██████████| 192/192 [00:01<00:00, 116.50it/s]


VALIDATION LOSS (Average) : 0.237654
EPOCH: 15


Loss=0.11722350120544434 Batch_id=191 Epoch Average loss=0.0199: 100%|██████████| 192/192 [00:01<00:00, 115.59it/s]


VALIDATION LOSS (Average) : 0.235834
EPOCH: 16


Loss=0.21376627683639526 Batch_id=191 Epoch Average loss=0.0195: 100%|██████████| 192/192 [00:01<00:00, 114.83it/s]


VALIDATION LOSS (Average) : 0.21483
EPOCH: 17


Loss=0.11896148324012756 Batch_id=191 Epoch Average loss=0.0194: 100%|██████████| 192/192 [00:01<00:00, 117.84it/s]


VALIDATION LOSS (Average) : 0.224734
EPOCH: 18


Loss=0.054428715258836746 Batch_id=191 Epoch Average loss=0.0196: 100%|██████████| 192/192 [00:01<00:00, 113.76it/s]


VALIDATION LOSS (Average) : 0.208274
EPOCH: 19


Loss=0.11645534634590149 Batch_id=191 Epoch Average loss=0.0189: 100%|██████████| 192/192 [00:01<00:00, 113.63it/s]


VALIDATION LOSS (Average) : 0.212068
EPOCH: 20


Loss=0.09938691556453705 Batch_id=191 Epoch Average loss=0.0187: 100%|██████████| 192/192 [00:01<00:00, 119.29it/s]


VALIDATION LOSS (Average) : 0.21672
EPOCH: 21


Loss=0.21871832013130188 Batch_id=191 Epoch Average loss=0.0190: 100%|██████████| 192/192 [00:01<00:00, 115.42it/s]


VALIDATION LOSS (Average) : 0.209628
EPOCH: 22


Loss=0.11436816304922104 Batch_id=191 Epoch Average loss=0.0187: 100%|██████████| 192/192 [00:01<00:00, 118.12it/s]


VALIDATION LOSS (Average) : 0.205677
EPOCH: 23


Loss=0.13396026194095612 Batch_id=191 Epoch Average loss=0.0182: 100%|██████████| 192/192 [00:01<00:00, 115.42it/s]


VALIDATION LOSS (Average) : 0.212003
EPOCH: 24


Loss=0.11150304973125458 Batch_id=191 Epoch Average loss=0.0180: 100%|██████████| 192/192 [00:01<00:00, 116.61it/s]


VALIDATION LOSS (Average) : 0.202824
EPOCH: 25


Loss=0.10939200222492218 Batch_id=191 Epoch Average loss=0.0180: 100%|██████████| 192/192 [00:01<00:00, 116.87it/s]


VALIDATION LOSS (Average) : 0.195864
EPOCH: 26


Loss=0.1369798630475998 Batch_id=191 Epoch Average loss=0.0176: 100%|██████████| 192/192 [00:01<00:00, 115.31it/s]


VALIDATION LOSS (Average) : 0.204982
EPOCH: 27


Loss=0.18988662958145142 Batch_id=191 Epoch Average loss=0.0174: 100%|██████████| 192/192 [00:01<00:00, 115.68it/s]


VALIDATION LOSS (Average) : 0.208335
EPOCH: 28


Loss=0.09861107170581818 Batch_id=191 Epoch Average loss=0.0174: 100%|██████████| 192/192 [00:01<00:00, 111.75it/s]


VALIDATION LOSS (Average) : 0.200939
EPOCH: 29


Loss=0.16502685844898224 Batch_id=191 Epoch Average loss=0.0169: 100%|██████████| 192/192 [00:01<00:00, 112.74it/s]


VALIDATION LOSS (Average) : 0.19332
EPOCH: 30


Loss=0.13864150643348694 Batch_id=191 Epoch Average loss=0.0168: 100%|██████████| 192/192 [00:01<00:00, 113.58it/s]


VALIDATION LOSS (Average) : 0.185217
EPOCH: 31


Loss=0.14461079239845276 Batch_id=191 Epoch Average loss=0.0165: 100%|██████████| 192/192 [00:01<00:00, 112.33it/s]


VALIDATION LOSS (Average) : 0.19317
EPOCH: 32


Loss=0.07048705220222473 Batch_id=191 Epoch Average loss=0.0166: 100%|██████████| 192/192 [00:01<00:00, 112.80it/s]


VALIDATION LOSS (Average) : 0.191288
EPOCH: 33


Loss=0.10986166447401047 Batch_id=191 Epoch Average loss=0.0161: 100%|██████████| 192/192 [00:01<00:00, 113.70it/s]


VALIDATION LOSS (Average) : 0.200553
EPOCH: 34


Loss=0.1199319064617157 Batch_id=191 Epoch Average loss=0.0162: 100%|██████████| 192/192 [00:01<00:00, 113.44it/s]


VALIDATION LOSS (Average) : 0.192875
EPOCH: 35


Loss=0.20706596970558167 Batch_id=191 Epoch Average loss=0.0155: 100%|██████████| 192/192 [00:01<00:00, 116.32it/s]


VALIDATION LOSS (Average) : 0.185717
EPOCH: 36


Loss=0.0784207433462143 Batch_id=191 Epoch Average loss=0.0155: 100%|██████████| 192/192 [00:01<00:00, 123.65it/s]


VALIDATION LOSS (Average) : 0.196604
EPOCH: 37


Loss=0.08815395832061768 Batch_id=191 Epoch Average loss=0.0155: 100%|██████████| 192/192 [00:01<00:00, 114.52it/s]


VALIDATION LOSS (Average) : 0.185907
EPOCH: 38


Loss=0.08476981520652771 Batch_id=191 Epoch Average loss=0.0156: 100%|██████████| 192/192 [00:01<00:00, 108.02it/s]


VALIDATION LOSS (Average) : 0.182429
EPOCH: 39


Loss=0.13656070828437805 Batch_id=191 Epoch Average loss=0.0155: 100%|██████████| 192/192 [00:01<00:00, 111.08it/s]


VALIDATION LOSS (Average) : 0.187042
EPOCH: 40


Loss=0.0870363786816597 Batch_id=191 Epoch Average loss=0.0150: 100%|██████████| 192/192 [00:01<00:00, 113.52it/s]


VALIDATION LOSS (Average) : 0.190997
EPOCH: 41


Loss=0.11097535490989685 Batch_id=191 Epoch Average loss=0.0149: 100%|██████████| 192/192 [00:01<00:00, 118.29it/s]


VALIDATION LOSS (Average) : 0.20067
EPOCH: 42


Loss=0.07800483703613281 Batch_id=191 Epoch Average loss=0.0148: 100%|██████████| 192/192 [00:01<00:00, 111.39it/s]


VALIDATION LOSS (Average) : 0.190264
EPOCH: 43


Loss=0.14878350496292114 Batch_id=191 Epoch Average loss=0.0151: 100%|██████████| 192/192 [00:01<00:00, 113.08it/s]


VALIDATION LOSS (Average) : 0.196886
EPOCH: 44


Loss=0.13195940852165222 Batch_id=191 Epoch Average loss=0.0149: 100%|██████████| 192/192 [00:01<00:00, 119.99it/s]


VALIDATION LOSS (Average) : 0.21037
EPOCH: 45


Loss=0.10408490896224976 Batch_id=191 Epoch Average loss=0.0144: 100%|██████████| 192/192 [00:01<00:00, 119.32it/s]


VALIDATION LOSS (Average) : 0.191613
EPOCH: 46


Loss=0.11601641774177551 Batch_id=191 Epoch Average loss=0.0144: 100%|██████████| 192/192 [00:01<00:00, 120.70it/s]


VALIDATION LOSS (Average) : 0.185931
EPOCH: 47


Loss=0.15748608112335205 Batch_id=191 Epoch Average loss=0.0146: 100%|██████████| 192/192 [00:01<00:00, 122.29it/s]


VALIDATION LOSS (Average) : 0.195025
EPOCH: 48


Loss=0.13196749985218048 Batch_id=191 Epoch Average loss=0.0147: 100%|██████████| 192/192 [00:01<00:00, 120.68it/s]


VALIDATION LOSS (Average) : 0.190354
EPOCH: 49


Loss=0.17514711618423462 Batch_id=191 Epoch Average loss=0.0143: 100%|██████████| 192/192 [00:01<00:00, 110.02it/s]


VALIDATION LOSS (Average) : 0.188337
EPOCH: 50


Loss=0.17934450507164001 Batch_id=191 Epoch Average loss=0.0142: 100%|██████████| 192/192 [00:01<00:00, 115.13it/s]


VALIDATION LOSS (Average) : 0.188348
EPOCH: 51


Loss=0.15709340572357178 Batch_id=191 Epoch Average loss=0.0142: 100%|██████████| 192/192 [00:01<00:00, 114.45it/s]


VALIDATION LOSS (Average) : 0.183511
EPOCH: 52


Loss=0.10017120838165283 Batch_id=191 Epoch Average loss=0.0137: 100%|██████████| 192/192 [00:01<00:00, 114.05it/s]


VALIDATION LOSS (Average) : 0.17906
EPOCH: 53


Loss=0.0689081996679306 Batch_id=191 Epoch Average loss=0.0139: 100%|██████████| 192/192 [00:01<00:00, 114.78it/s]


VALIDATION LOSS (Average) : 0.192409
EPOCH: 54


Loss=0.10387866944074631 Batch_id=191 Epoch Average loss=0.0137: 100%|██████████| 192/192 [00:01<00:00, 113.65it/s]


VALIDATION LOSS (Average) : 0.177399
EPOCH: 55


Loss=0.22009870409965515 Batch_id=191 Epoch Average loss=0.0135: 100%|██████████| 192/192 [00:01<00:00, 112.46it/s]


VALIDATION LOSS (Average) : 0.183638
EPOCH: 56


Loss=0.14790809154510498 Batch_id=191 Epoch Average loss=0.0132: 100%|██████████| 192/192 [00:01<00:00, 114.76it/s]


VALIDATION LOSS (Average) : 0.185221
EPOCH: 57


Loss=0.07857519388198853 Batch_id=191 Epoch Average loss=0.0136: 100%|██████████| 192/192 [00:01<00:00, 116.83it/s]


VALIDATION LOSS (Average) : 0.187363
EPOCH: 58


Loss=0.10253558307886124 Batch_id=191 Epoch Average loss=0.0135: 100%|██████████| 192/192 [00:01<00:00, 115.02it/s]


VALIDATION LOSS (Average) : 0.188381
EPOCH: 59


Loss=0.09859241545200348 Batch_id=191 Epoch Average loss=0.0127: 100%|██████████| 192/192 [00:01<00:00, 112.75it/s]


VALIDATION LOSS (Average) : 0.183575
EPOCH: 60


Loss=0.12532885372638702 Batch_id=191 Epoch Average loss=0.0133: 100%|██████████| 192/192 [00:01<00:00, 114.62it/s]


VALIDATION LOSS (Average) : 0.183135
EPOCH: 61


Loss=0.07791498303413391 Batch_id=191 Epoch Average loss=0.0132: 100%|██████████| 192/192 [00:01<00:00, 116.93it/s]


VALIDATION LOSS (Average) : 0.183079
EPOCH: 62


Loss=0.060030460357666016 Batch_id=191 Epoch Average loss=0.0129: 100%|██████████| 192/192 [00:01<00:00, 113.10it/s]


VALIDATION LOSS (Average) : 0.185872
EPOCH: 63


Loss=0.1299349069595337 Batch_id=191 Epoch Average loss=0.0128: 100%|██████████| 192/192 [00:01<00:00, 114.09it/s]


VALIDATION LOSS (Average) : 0.184037
EPOCH: 64


Loss=0.13769736886024475 Batch_id=191 Epoch Average loss=0.0128: 100%|██████████| 192/192 [00:01<00:00, 112.03it/s]


VALIDATION LOSS (Average) : 0.188044
EPOCH: 65


Loss=0.0999765545129776 Batch_id=191 Epoch Average loss=0.0128: 100%|██████████| 192/192 [00:01<00:00, 114.58it/s]


VALIDATION LOSS (Average) : 0.18734
EPOCH: 66


Loss=0.11347700655460358 Batch_id=191 Epoch Average loss=0.0127: 100%|██████████| 192/192 [00:01<00:00, 115.26it/s]


VALIDATION LOSS (Average) : 0.187902
EPOCH: 67


Loss=0.07827579975128174 Batch_id=191 Epoch Average loss=0.0126: 100%|██████████| 192/192 [00:01<00:00, 110.61it/s]


VALIDATION LOSS (Average) : 0.18289
EPOCH: 68


Loss=0.1253729909658432 Batch_id=191 Epoch Average loss=0.0125: 100%|██████████| 192/192 [00:01<00:00, 114.12it/s]


VALIDATION LOSS (Average) : 0.191617
EPOCH: 69


Loss=0.06001625955104828 Batch_id=191 Epoch Average loss=0.0122: 100%|██████████| 192/192 [00:01<00:00, 113.66it/s]


VALIDATION LOSS (Average) : 0.187086
EPOCH: 70


Loss=0.08720231801271439 Batch_id=191 Epoch Average loss=0.0126: 100%|██████████| 192/192 [00:01<00:00, 113.30it/s]


VALIDATION LOSS (Average) : 0.195934
EPOCH: 71


Loss=0.10344232618808746 Batch_id=191 Epoch Average loss=0.0124: 100%|██████████| 192/192 [00:01<00:00, 113.70it/s]


VALIDATION LOSS (Average) : 0.186616
EPOCH: 72


Loss=0.11902204900979996 Batch_id=191 Epoch Average loss=0.0122: 100%|██████████| 192/192 [00:01<00:00, 111.88it/s]


VALIDATION LOSS (Average) : 0.193491
EPOCH: 73


Loss=0.0938577875494957 Batch_id=191 Epoch Average loss=0.0119: 100%|██████████| 192/192 [00:01<00:00, 111.13it/s]


VALIDATION LOSS (Average) : 0.191042
EPOCH: 74


Loss=0.0889434963464737 Batch_id=191 Epoch Average loss=0.0120: 100%|██████████| 192/192 [00:01<00:00, 116.40it/s]


VALIDATION LOSS (Average) : 0.194778
EPOCH: 75


Loss=0.04752786457538605 Batch_id=191 Epoch Average loss=0.0119: 100%|██████████| 192/192 [00:01<00:00, 113.89it/s]


VALIDATION LOSS (Average) : 0.187707
EPOCH: 76


Loss=0.019462019205093384 Batch_id=191 Epoch Average loss=0.0119: 100%|██████████| 192/192 [00:01<00:00, 114.16it/s]


VALIDATION LOSS (Average) : 0.185457
EPOCH: 77


Loss=0.09648934006690979 Batch_id=191 Epoch Average loss=0.0121: 100%|██████████| 192/192 [00:01<00:00, 115.21it/s]


VALIDATION LOSS (Average) : 0.184245
EPOCH: 78


Loss=0.09939633309841156 Batch_id=191 Epoch Average loss=0.0117: 100%|██████████| 192/192 [00:01<00:00, 115.65it/s]


VALIDATION LOSS (Average) : 0.189609
EPOCH: 79


Loss=0.13906961679458618 Batch_id=191 Epoch Average loss=0.0118: 100%|██████████| 192/192 [00:01<00:00, 112.19it/s]


VALIDATION LOSS (Average) : 0.194263
EPOCH: 80


Loss=0.13359341025352478 Batch_id=191 Epoch Average loss=0.0117: 100%|██████████| 192/192 [00:01<00:00, 114.62it/s]


VALIDATION LOSS (Average) : 0.192674
EPOCH: 81


Loss=0.10174641013145447 Batch_id=191 Epoch Average loss=0.0115: 100%|██████████| 192/192 [00:01<00:00, 112.10it/s]


VALIDATION LOSS (Average) : 0.18485
EPOCH: 82


Loss=0.0549749992787838 Batch_id=191 Epoch Average loss=0.0117: 100%|██████████| 192/192 [00:01<00:00, 114.74it/s]


VALIDATION LOSS (Average) : 0.199843
EPOCH: 83


Loss=0.11437776684761047 Batch_id=191 Epoch Average loss=0.0117: 100%|██████████| 192/192 [00:01<00:00, 116.32it/s]


VALIDATION LOSS (Average) : 0.195922
EPOCH: 84


Loss=0.08836618065834045 Batch_id=191 Epoch Average loss=0.0114: 100%|██████████| 192/192 [00:01<00:00, 113.66it/s]


VALIDATION LOSS (Average) : 0.191459
EPOCH: 85


Loss=0.07720670104026794 Batch_id=191 Epoch Average loss=0.0109: 100%|██████████| 192/192 [00:01<00:00, 113.53it/s]


VALIDATION LOSS (Average) : 0.191169
EPOCH: 86


Loss=0.10769487917423248 Batch_id=191 Epoch Average loss=0.0114: 100%|██████████| 192/192 [00:01<00:00, 114.72it/s]


VALIDATION LOSS (Average) : 0.189149
EPOCH: 87


Loss=0.09894878417253494 Batch_id=191 Epoch Average loss=0.0116: 100%|██████████| 192/192 [00:01<00:00, 114.94it/s]


VALIDATION LOSS (Average) : 0.183757
EPOCH: 88


Loss=0.1078236922621727 Batch_id=191 Epoch Average loss=0.0110: 100%|██████████| 192/192 [00:01<00:00, 110.24it/s]


VALIDATION LOSS (Average) : 0.192506
EPOCH: 89


Loss=0.0774875357747078 Batch_id=191 Epoch Average loss=0.0110: 100%|██████████| 192/192 [00:01<00:00, 109.83it/s]


VALIDATION LOSS (Average) : 0.186799
EPOCH: 90


Loss=0.09162284433841705 Batch_id=191 Epoch Average loss=0.0109: 100%|██████████| 192/192 [00:01<00:00, 112.06it/s]


VALIDATION LOSS (Average) : 0.184474
EPOCH: 91


Loss=0.11967086791992188 Batch_id=191 Epoch Average loss=0.0108: 100%|██████████| 192/192 [00:01<00:00, 115.55it/s]


VALIDATION LOSS (Average) : 0.189515
EPOCH: 92


Loss=0.05301528796553612 Batch_id=191 Epoch Average loss=0.0112: 100%|██████████| 192/192 [00:01<00:00, 115.41it/s]


VALIDATION LOSS (Average) : 0.199888
EPOCH: 93


Loss=0.06916622817516327 Batch_id=191 Epoch Average loss=0.0106: 100%|██████████| 192/192 [00:01<00:00, 113.22it/s]


VALIDATION LOSS (Average) : 0.185956
EPOCH: 94


Loss=0.07847198098897934 Batch_id=191 Epoch Average loss=0.0112: 100%|██████████| 192/192 [00:01<00:00, 117.88it/s]


VALIDATION LOSS (Average) : 0.195891
EPOCH: 95


Loss=0.1310117244720459 Batch_id=191 Epoch Average loss=0.0109: 100%|██████████| 192/192 [00:01<00:00, 114.34it/s]


VALIDATION LOSS (Average) : 0.193199
EPOCH: 96


Loss=0.0941113606095314 Batch_id=191 Epoch Average loss=0.0106: 100%|██████████| 192/192 [00:01<00:00, 116.00it/s]


VALIDATION LOSS (Average) : 0.196772
EPOCH: 97


Loss=0.06142699718475342 Batch_id=191 Epoch Average loss=0.0107: 100%|██████████| 192/192 [00:01<00:00, 112.67it/s]


VALIDATION LOSS (Average) : 0.191235
EPOCH: 98


Loss=0.07754012197256088 Batch_id=191 Epoch Average loss=0.0108: 100%|██████████| 192/192 [00:01<00:00, 115.30it/s]


VALIDATION LOSS (Average) : 0.183198
EPOCH: 99


Loss=0.062290459871292114 Batch_id=191 Epoch Average loss=0.0105: 100%|██████████| 192/192 [00:01<00:00, 111.43it/s]


VALIDATION LOSS (Average) : 0.193123
EPOCH: 100


Loss=0.08056892454624176 Batch_id=191 Epoch Average loss=0.0106: 100%|██████████| 192/192 [00:01<00:00, 113.90it/s]


VALIDATION LOSS (Average) : 0.187611
TEST LOSS (Average) : 0.180681
----------------------training complete for EI_sadness-----------------
----------------------training started for V-----------------
EPOCH: 1


Loss=0.1944463551044464 Batch_id=147 Epoch Average loss=0.0519: 100%|██████████| 148/148 [00:01<00:00, 118.10it/s]


VALIDATION LOSS (Average) : 0.151765
EPOCH: 2


Loss=0.14994260668754578 Batch_id=147 Epoch Average loss=0.0255: 100%|██████████| 148/148 [00:01<00:00, 116.26it/s]


VALIDATION LOSS (Average) : 0.133775
EPOCH: 3


Loss=0.10792087018489838 Batch_id=147 Epoch Average loss=0.0253: 100%|██████████| 148/148 [00:01<00:00, 115.31it/s]


VALIDATION LOSS (Average) : 0.128854
EPOCH: 4


Loss=0.12926927208900452 Batch_id=147 Epoch Average loss=0.0249: 100%|██████████| 148/148 [00:01<00:00, 98.32it/s]


VALIDATION LOSS (Average) : 0.161026
EPOCH: 5


Loss=0.22398146986961365 Batch_id=147 Epoch Average loss=0.0245: 100%|██████████| 148/148 [00:01<00:00, 109.24it/s]


VALIDATION LOSS (Average) : 0.154527
EPOCH: 6


Loss=0.1799168884754181 Batch_id=147 Epoch Average loss=0.0239: 100%|██████████| 148/148 [00:01<00:00, 106.36it/s]


VALIDATION LOSS (Average) : 0.145406
EPOCH: 7


Loss=0.18495899438858032 Batch_id=147 Epoch Average loss=0.0241: 100%|██████████| 148/148 [00:01<00:00, 115.58it/s]


VALIDATION LOSS (Average) : 0.148143
EPOCH: 8


Loss=0.20014244318008423 Batch_id=147 Epoch Average loss=0.0238: 100%|██████████| 148/148 [00:01<00:00, 118.25it/s]


VALIDATION LOSS (Average) : 0.142908
EPOCH: 9


Loss=0.2199961543083191 Batch_id=147 Epoch Average loss=0.0235: 100%|██████████| 148/148 [00:01<00:00, 104.56it/s]


VALIDATION LOSS (Average) : 0.162789
EPOCH: 10


Loss=0.2668469548225403 Batch_id=147 Epoch Average loss=0.0229: 100%|██████████| 148/148 [00:01<00:00, 113.64it/s]


VALIDATION LOSS (Average) : 0.178233
EPOCH: 11


Loss=0.20848095417022705 Batch_id=147 Epoch Average loss=0.0230: 100%|██████████| 148/148 [00:01<00:00, 112.70it/s]


VALIDATION LOSS (Average) : 0.166365
EPOCH: 12


Loss=0.19332733750343323 Batch_id=147 Epoch Average loss=0.0223: 100%|██████████| 148/148 [00:01<00:00, 111.57it/s]


VALIDATION LOSS (Average) : 0.177303
EPOCH: 13


Loss=0.13443134725093842 Batch_id=147 Epoch Average loss=0.0225: 100%|██████████| 148/148 [00:01<00:00, 114.74it/s]


VALIDATION LOSS (Average) : 0.188026
EPOCH: 14


Loss=0.22996632754802704 Batch_id=147 Epoch Average loss=0.0226: 100%|██████████| 148/148 [00:01<00:00, 120.13it/s]


VALIDATION LOSS (Average) : 0.186707
EPOCH: 15


Loss=0.22110337018966675 Batch_id=147 Epoch Average loss=0.0217: 100%|██████████| 148/148 [00:01<00:00, 118.26it/s]


VALIDATION LOSS (Average) : 0.176484
EPOCH: 16


Loss=0.1128542423248291 Batch_id=147 Epoch Average loss=0.0215: 100%|██████████| 148/148 [00:01<00:00, 119.05it/s]


VALIDATION LOSS (Average) : 0.177441
EPOCH: 17


Loss=0.13711071014404297 Batch_id=147 Epoch Average loss=0.0215: 100%|██████████| 148/148 [00:01<00:00, 116.08it/s]


VALIDATION LOSS (Average) : 0.171256
EPOCH: 18


Loss=0.1284485161304474 Batch_id=147 Epoch Average loss=0.0215: 100%|██████████| 148/148 [00:01<00:00, 114.61it/s]


VALIDATION LOSS (Average) : 0.191031
EPOCH: 19


Loss=0.14398051798343658 Batch_id=147 Epoch Average loss=0.0210: 100%|██████████| 148/148 [00:01<00:00, 118.71it/s]


VALIDATION LOSS (Average) : 0.193066
EPOCH: 20


Loss=0.16861958801746368 Batch_id=147 Epoch Average loss=0.0212: 100%|██████████| 148/148 [00:01<00:00, 114.83it/s]


VALIDATION LOSS (Average) : 0.206259
EPOCH: 21


Loss=0.10312698036432266 Batch_id=147 Epoch Average loss=0.0200: 100%|██████████| 148/148 [00:01<00:00, 116.38it/s]


VALIDATION LOSS (Average) : 0.191698
EPOCH: 22


Loss=0.19949066638946533 Batch_id=147 Epoch Average loss=0.0200: 100%|██████████| 148/148 [00:01<00:00, 113.59it/s]


VALIDATION LOSS (Average) : 0.195853
EPOCH: 23


Loss=0.13244511187076569 Batch_id=147 Epoch Average loss=0.0203: 100%|██████████| 148/148 [00:01<00:00, 118.71it/s]


VALIDATION LOSS (Average) : 0.1948
EPOCH: 24


Loss=0.15423467755317688 Batch_id=147 Epoch Average loss=0.0196: 100%|██████████| 148/148 [00:01<00:00, 119.73it/s]


VALIDATION LOSS (Average) : 0.192774
EPOCH: 25


Loss=0.18167370557785034 Batch_id=147 Epoch Average loss=0.0199: 100%|██████████| 148/148 [00:01<00:00, 114.33it/s]


VALIDATION LOSS (Average) : 0.219324
EPOCH: 26


Loss=0.1328461766242981 Batch_id=147 Epoch Average loss=0.0194: 100%|██████████| 148/148 [00:01<00:00, 108.71it/s]


VALIDATION LOSS (Average) : 0.196184
EPOCH: 27


Loss=0.12973248958587646 Batch_id=147 Epoch Average loss=0.0192: 100%|██████████| 148/148 [00:01<00:00, 113.83it/s]


VALIDATION LOSS (Average) : 0.204703
EPOCH: 28


Loss=0.1403256505727768 Batch_id=147 Epoch Average loss=0.0188: 100%|██████████| 148/148 [00:01<00:00, 114.50it/s]


VALIDATION LOSS (Average) : 0.207738
EPOCH: 29


Loss=0.14650958776474 Batch_id=147 Epoch Average loss=0.0186: 100%|██████████| 148/148 [00:01<00:00, 112.57it/s]


VALIDATION LOSS (Average) : 0.214409
EPOCH: 30


Loss=0.22121307253837585 Batch_id=147 Epoch Average loss=0.0183: 100%|██████████| 148/148 [00:01<00:00, 115.64it/s]


VALIDATION LOSS (Average) : 0.197199
EPOCH: 31


Loss=0.17216262221336365 Batch_id=147 Epoch Average loss=0.0183: 100%|██████████| 148/148 [00:01<00:00, 114.14it/s]


VALIDATION LOSS (Average) : 0.240199
EPOCH: 32


Loss=0.16914652287960052 Batch_id=147 Epoch Average loss=0.0183: 100%|██████████| 148/148 [00:01<00:00, 111.35it/s]


VALIDATION LOSS (Average) : 0.213235
EPOCH: 33


Loss=0.1748577207326889 Batch_id=147 Epoch Average loss=0.0177: 100%|██████████| 148/148 [00:01<00:00, 112.63it/s]


VALIDATION LOSS (Average) : 0.220165
EPOCH: 34


Loss=0.08587915450334549 Batch_id=147 Epoch Average loss=0.0184: 100%|██████████| 148/148 [00:01<00:00, 112.11it/s]


VALIDATION LOSS (Average) : 0.214794
EPOCH: 35


Loss=0.1932550072669983 Batch_id=147 Epoch Average loss=0.0176: 100%|██████████| 148/148 [00:01<00:00, 112.05it/s]


VALIDATION LOSS (Average) : 0.199211
EPOCH: 36


Loss=0.13593727350234985 Batch_id=147 Epoch Average loss=0.0172: 100%|██████████| 148/148 [00:01<00:00, 108.87it/s]


VALIDATION LOSS (Average) : 0.235202
EPOCH: 37


Loss=0.11042512953281403 Batch_id=147 Epoch Average loss=0.0171: 100%|██████████| 148/148 [00:01<00:00, 116.97it/s]


VALIDATION LOSS (Average) : 0.207498
EPOCH: 38


Loss=0.08667952567338943 Batch_id=147 Epoch Average loss=0.0172: 100%|██████████| 148/148 [00:01<00:00, 114.19it/s]


VALIDATION LOSS (Average) : 0.230674
EPOCH: 39


Loss=0.16047511994838715 Batch_id=147 Epoch Average loss=0.0168: 100%|██████████| 148/148 [00:01<00:00, 113.35it/s]


VALIDATION LOSS (Average) : 0.224393
EPOCH: 40


Loss=0.11826711148023605 Batch_id=147 Epoch Average loss=0.0163: 100%|██████████| 148/148 [00:01<00:00, 117.02it/s]


VALIDATION LOSS (Average) : 0.205021
EPOCH: 41


Loss=0.10761700570583344 Batch_id=147 Epoch Average loss=0.0172: 100%|██████████| 148/148 [00:01<00:00, 117.25it/s]


VALIDATION LOSS (Average) : 0.212912
EPOCH: 42


Loss=0.06510035693645477 Batch_id=147 Epoch Average loss=0.0161: 100%|██████████| 148/148 [00:01<00:00, 113.22it/s]


VALIDATION LOSS (Average) : 0.19349
EPOCH: 43


Loss=0.12722399830818176 Batch_id=147 Epoch Average loss=0.0164: 100%|██████████| 148/148 [00:01<00:00, 111.29it/s]


VALIDATION LOSS (Average) : 0.221647
EPOCH: 44


Loss=0.15403354167938232 Batch_id=147 Epoch Average loss=0.0158: 100%|██████████| 148/148 [00:01<00:00, 117.93it/s]


VALIDATION LOSS (Average) : 0.213684
EPOCH: 45


Loss=0.1518835425376892 Batch_id=147 Epoch Average loss=0.0162: 100%|██████████| 148/148 [00:01<00:00, 120.97it/s]


VALIDATION LOSS (Average) : 0.235262
EPOCH: 46


Loss=0.13400475680828094 Batch_id=147 Epoch Average loss=0.0162: 100%|██████████| 148/148 [00:01<00:00, 116.99it/s]


VALIDATION LOSS (Average) : 0.217725
EPOCH: 47


Loss=0.051434487104415894 Batch_id=147 Epoch Average loss=0.0157: 100%|██████████| 148/148 [00:01<00:00, 120.09it/s]


VALIDATION LOSS (Average) : 0.207776
EPOCH: 48


Loss=0.1348671317100525 Batch_id=147 Epoch Average loss=0.0156: 100%|██████████| 148/148 [00:01<00:00, 120.36it/s]


VALIDATION LOSS (Average) : 0.235837
EPOCH: 49


Loss=0.13917967677116394 Batch_id=147 Epoch Average loss=0.0158: 100%|██████████| 148/148 [00:01<00:00, 121.78it/s]


VALIDATION LOSS (Average) : 0.239881
EPOCH: 50


Loss=0.12550660967826843 Batch_id=147 Epoch Average loss=0.0153: 100%|██████████| 148/148 [00:01<00:00, 115.01it/s]


VALIDATION LOSS (Average) : 0.220074
EPOCH: 51


Loss=0.1245536208152771 Batch_id=147 Epoch Average loss=0.0152: 100%|██████████| 148/148 [00:01<00:00, 113.41it/s]


VALIDATION LOSS (Average) : 0.212969
EPOCH: 52


Loss=0.12078399211168289 Batch_id=147 Epoch Average loss=0.0154: 100%|██████████| 148/148 [00:01<00:00, 112.54it/s]


VALIDATION LOSS (Average) : 0.204823
EPOCH: 53


Loss=0.11082950234413147 Batch_id=147 Epoch Average loss=0.0151: 100%|██████████| 148/148 [00:01<00:00, 108.83it/s]


VALIDATION LOSS (Average) : 0.217105
EPOCH: 54


Loss=0.0812467709183693 Batch_id=147 Epoch Average loss=0.0153: 100%|██████████| 148/148 [00:01<00:00, 115.67it/s]


VALIDATION LOSS (Average) : 0.231737
EPOCH: 55


Loss=0.09949766099452972 Batch_id=147 Epoch Average loss=0.0143: 100%|██████████| 148/148 [00:01<00:00, 118.58it/s]


VALIDATION LOSS (Average) : 0.230302
EPOCH: 56


Loss=0.13277146220207214 Batch_id=147 Epoch Average loss=0.0146: 100%|██████████| 148/148 [00:01<00:00, 115.38it/s]


VALIDATION LOSS (Average) : 0.220093
EPOCH: 57


Loss=0.12234961986541748 Batch_id=147 Epoch Average loss=0.0144: 100%|██████████| 148/148 [00:01<00:00, 114.26it/s]


VALIDATION LOSS (Average) : 0.209418
EPOCH: 58


Loss=0.10506193339824677 Batch_id=147 Epoch Average loss=0.0143: 100%|██████████| 148/148 [00:01<00:00, 120.68it/s]


VALIDATION LOSS (Average) : 0.226061
EPOCH: 59


Loss=0.09138667583465576 Batch_id=147 Epoch Average loss=0.0146: 100%|██████████| 148/148 [00:01<00:00, 122.87it/s]


VALIDATION LOSS (Average) : 0.232822
EPOCH: 60


Loss=0.07118465006351471 Batch_id=147 Epoch Average loss=0.0145: 100%|██████████| 148/148 [00:01<00:00, 126.28it/s]


VALIDATION LOSS (Average) : 0.229129
EPOCH: 61


Loss=0.1417294144630432 Batch_id=147 Epoch Average loss=0.0142: 100%|██████████| 148/148 [00:01<00:00, 126.13it/s]


VALIDATION LOSS (Average) : 0.230765
EPOCH: 62


Loss=0.08041292428970337 Batch_id=147 Epoch Average loss=0.0136: 100%|██████████| 148/148 [00:01<00:00, 121.43it/s]


VALIDATION LOSS (Average) : 0.228185
EPOCH: 63


Loss=0.12006361782550812 Batch_id=147 Epoch Average loss=0.0141: 100%|██████████| 148/148 [00:01<00:00, 114.45it/s]


VALIDATION LOSS (Average) : 0.22839
EPOCH: 64


Loss=0.07970469444990158 Batch_id=147 Epoch Average loss=0.0138: 100%|██████████| 148/148 [00:01<00:00, 110.97it/s]


VALIDATION LOSS (Average) : 0.233268
EPOCH: 65


Loss=0.12035287916660309 Batch_id=147 Epoch Average loss=0.0142: 100%|██████████| 148/148 [00:01<00:00, 115.94it/s]


VALIDATION LOSS (Average) : 0.23653
EPOCH: 66


Loss=0.06678404659032822 Batch_id=147 Epoch Average loss=0.0138: 100%|██████████| 148/148 [00:01<00:00, 115.00it/s]


VALIDATION LOSS (Average) : 0.223593
EPOCH: 67


Loss=0.08362641930580139 Batch_id=147 Epoch Average loss=0.0141: 100%|██████████| 148/148 [00:01<00:00, 116.92it/s]


VALIDATION LOSS (Average) : 0.222787
EPOCH: 68


Loss=0.10194769501686096 Batch_id=147 Epoch Average loss=0.0137: 100%|██████████| 148/148 [00:01<00:00, 122.01it/s]


VALIDATION LOSS (Average) : 0.228767
EPOCH: 69


Loss=0.07632829248905182 Batch_id=147 Epoch Average loss=0.0138: 100%|██████████| 148/148 [00:01<00:00, 123.97it/s]


VALIDATION LOSS (Average) : 0.226072
EPOCH: 70


Loss=0.11387976258993149 Batch_id=147 Epoch Average loss=0.0135: 100%|██████████| 148/148 [00:01<00:00, 124.34it/s]


VALIDATION LOSS (Average) : 0.22484
EPOCH: 71


Loss=0.08524656295776367 Batch_id=147 Epoch Average loss=0.0135: 100%|██████████| 148/148 [00:01<00:00, 118.43it/s]


VALIDATION LOSS (Average) : 0.221864
EPOCH: 72


Loss=0.08363016694784164 Batch_id=147 Epoch Average loss=0.0135: 100%|██████████| 148/148 [00:01<00:00, 118.53it/s]


VALIDATION LOSS (Average) : 0.234458
EPOCH: 73


Loss=0.09564778953790665 Batch_id=147 Epoch Average loss=0.0128: 100%|██████████| 148/148 [00:01<00:00, 111.50it/s]


VALIDATION LOSS (Average) : 0.238748
EPOCH: 74


Loss=0.07798439264297485 Batch_id=147 Epoch Average loss=0.0131: 100%|██████████| 148/148 [00:01<00:00, 112.62it/s]


VALIDATION LOSS (Average) : 0.226736
EPOCH: 75


Loss=0.10550187528133392 Batch_id=147 Epoch Average loss=0.0131: 100%|██████████| 148/148 [00:01<00:00, 117.37it/s]


VALIDATION LOSS (Average) : 0.216641
EPOCH: 76


Loss=0.06907147169113159 Batch_id=147 Epoch Average loss=0.0126: 100%|██████████| 148/148 [00:01<00:00, 118.57it/s]


VALIDATION LOSS (Average) : 0.222701
EPOCH: 77


Loss=0.10470576584339142 Batch_id=147 Epoch Average loss=0.0133: 100%|██████████| 148/148 [00:01<00:00, 113.30it/s]


VALIDATION LOSS (Average) : 0.21673
EPOCH: 78


Loss=0.10209716856479645 Batch_id=147 Epoch Average loss=0.0127: 100%|██████████| 148/148 [00:01<00:00, 115.19it/s]


VALIDATION LOSS (Average) : 0.230524
EPOCH: 79


Loss=0.17521625757217407 Batch_id=147 Epoch Average loss=0.0128: 100%|██████████| 148/148 [00:01<00:00, 108.43it/s]


VALIDATION LOSS (Average) : 0.235222
EPOCH: 80


Loss=0.07685860991477966 Batch_id=147 Epoch Average loss=0.0126: 100%|██████████| 148/148 [00:01<00:00, 113.07it/s]


VALIDATION LOSS (Average) : 0.215896
EPOCH: 81


Loss=0.0938618928194046 Batch_id=147 Epoch Average loss=0.0122: 100%|██████████| 148/148 [00:01<00:00, 108.59it/s]


VALIDATION LOSS (Average) : 0.213013
EPOCH: 82


Loss=0.11930548399686813 Batch_id=147 Epoch Average loss=0.0124: 100%|██████████| 148/148 [00:01<00:00, 108.83it/s]


VALIDATION LOSS (Average) : 0.248332
EPOCH: 83


Loss=0.11514568328857422 Batch_id=147 Epoch Average loss=0.0129: 100%|██████████| 148/148 [00:01<00:00, 115.69it/s]


VALIDATION LOSS (Average) : 0.232453
EPOCH: 84


Loss=0.09051889926195145 Batch_id=147 Epoch Average loss=0.0123: 100%|██████████| 148/148 [00:01<00:00, 110.32it/s]


VALIDATION LOSS (Average) : 0.220179
EPOCH: 85


Loss=0.12904372811317444 Batch_id=147 Epoch Average loss=0.0123: 100%|██████████| 148/148 [00:01<00:00, 110.83it/s]


VALIDATION LOSS (Average) : 0.223084
EPOCH: 86


Loss=0.10095181316137314 Batch_id=147 Epoch Average loss=0.0119: 100%|██████████| 148/148 [00:01<00:00, 111.37it/s]


VALIDATION LOSS (Average) : 0.232008
EPOCH: 87


Loss=0.19721829891204834 Batch_id=147 Epoch Average loss=0.0126: 100%|██████████| 148/148 [00:01<00:00, 113.87it/s]


VALIDATION LOSS (Average) : 0.241566
EPOCH: 88


Loss=0.08008390665054321 Batch_id=147 Epoch Average loss=0.0122: 100%|██████████| 148/148 [00:01<00:00, 110.40it/s]


VALIDATION LOSS (Average) : 0.239141
EPOCH: 89


Loss=0.08354896306991577 Batch_id=147 Epoch Average loss=0.0121: 100%|██████████| 148/148 [00:01<00:00, 97.18it/s]


VALIDATION LOSS (Average) : 0.208837
EPOCH: 90


Loss=0.1884506642818451 Batch_id=147 Epoch Average loss=0.0121: 100%|██████████| 148/148 [00:01<00:00, 103.38it/s]


VALIDATION LOSS (Average) : 0.216522
EPOCH: 91


Loss=0.08520908653736115 Batch_id=147 Epoch Average loss=0.0120: 100%|██████████| 148/148 [00:01<00:00, 111.71it/s]


VALIDATION LOSS (Average) : 0.237814
EPOCH: 92


Loss=0.17827147245407104 Batch_id=147 Epoch Average loss=0.0118: 100%|██████████| 148/148 [00:01<00:00, 118.47it/s]


VALIDATION LOSS (Average) : 0.218563
EPOCH: 93


Loss=0.10928837954998016 Batch_id=147 Epoch Average loss=0.0116: 100%|██████████| 148/148 [00:01<00:00, 111.60it/s]


VALIDATION LOSS (Average) : 0.220678
EPOCH: 94


Loss=0.1039886325597763 Batch_id=147 Epoch Average loss=0.0115: 100%|██████████| 148/148 [00:01<00:00, 114.19it/s]


VALIDATION LOSS (Average) : 0.232033
EPOCH: 95


Loss=0.10319127142429352 Batch_id=147 Epoch Average loss=0.0118: 100%|██████████| 148/148 [00:01<00:00, 118.35it/s]


VALIDATION LOSS (Average) : 0.217746
EPOCH: 96


Loss=0.08386844396591187 Batch_id=147 Epoch Average loss=0.0116: 100%|██████████| 148/148 [00:01<00:00, 115.49it/s]


VALIDATION LOSS (Average) : 0.243213
EPOCH: 97


Loss=0.0456571951508522 Batch_id=147 Epoch Average loss=0.0117: 100%|██████████| 148/148 [00:01<00:00, 112.25it/s]


VALIDATION LOSS (Average) : 0.218011
EPOCH: 98


Loss=0.12034478038549423 Batch_id=147 Epoch Average loss=0.0117: 100%|██████████| 148/148 [00:01<00:00, 110.81it/s]


VALIDATION LOSS (Average) : 0.240908
EPOCH: 99


Loss=0.1265222728252411 Batch_id=147 Epoch Average loss=0.0114: 100%|██████████| 148/148 [00:01<00:00, 114.22it/s]


VALIDATION LOSS (Average) : 0.239915
EPOCH: 100


Loss=0.05663932487368584 Batch_id=147 Epoch Average loss=0.0115: 100%|██████████| 148/148 [00:01<00:00, 105.38it/s]


VALIDATION LOSS (Average) : 0.231237
TEST LOSS (Average) : 0.107908
----------------------training complete for V-----------------
----------------------training started for EI_joy-----------------
EPOCH: 1


Loss=0.17213749885559082 Batch_id=201 Epoch Average loss=0.0254: 100%|██████████| 202/202 [00:01<00:00, 110.06it/s]


VALIDATION LOSS (Average) : 0.15513
EPOCH: 2


Loss=0.16308516263961792 Batch_id=201 Epoch Average loss=0.0237: 100%|██████████| 202/202 [00:01<00:00, 111.96it/s]


VALIDATION LOSS (Average) : 0.162902
EPOCH: 3


Loss=0.09934140741825104 Batch_id=201 Epoch Average loss=0.0231: 100%|██████████| 202/202 [00:01<00:00, 116.03it/s]


VALIDATION LOSS (Average) : 0.148836
EPOCH: 4


Loss=0.19119691848754883 Batch_id=201 Epoch Average loss=0.0229: 100%|██████████| 202/202 [00:01<00:00, 122.47it/s]


VALIDATION LOSS (Average) : 0.172641
EPOCH: 5


Loss=0.13697129487991333 Batch_id=201 Epoch Average loss=0.0221: 100%|██████████| 202/202 [00:01<00:00, 116.37it/s]


VALIDATION LOSS (Average) : 0.146706
EPOCH: 6


Loss=0.1570802628993988 Batch_id=201 Epoch Average loss=0.0218: 100%|██████████| 202/202 [00:01<00:00, 119.88it/s]


VALIDATION LOSS (Average) : 0.149856
EPOCH: 7


Loss=0.22088392078876495 Batch_id=201 Epoch Average loss=0.0214: 100%|██████████| 202/202 [00:01<00:00, 120.92it/s]


VALIDATION LOSS (Average) : 0.157121
EPOCH: 8


Loss=0.15628838539123535 Batch_id=201 Epoch Average loss=0.0213: 100%|██████████| 202/202 [00:01<00:00, 119.14it/s]


VALIDATION LOSS (Average) : 0.1455
EPOCH: 9


Loss=0.21564893424510956 Batch_id=201 Epoch Average loss=0.0210: 100%|██████████| 202/202 [00:01<00:00, 111.26it/s]


VALIDATION LOSS (Average) : 0.143792
EPOCH: 10


Loss=0.18085049092769623 Batch_id=201 Epoch Average loss=0.0206: 100%|██████████| 202/202 [00:01<00:00, 116.97it/s]


VALIDATION LOSS (Average) : 0.149324
EPOCH: 11


Loss=0.20465201139450073 Batch_id=201 Epoch Average loss=0.0204: 100%|██████████| 202/202 [00:01<00:00, 108.85it/s]


VALIDATION LOSS (Average) : 0.14457
EPOCH: 12


Loss=0.1766718327999115 Batch_id=201 Epoch Average loss=0.0201: 100%|██████████| 202/202 [00:01<00:00, 110.01it/s]


VALIDATION LOSS (Average) : 0.130198
EPOCH: 13


Loss=0.18204550445079803 Batch_id=201 Epoch Average loss=0.0203: 100%|██████████| 202/202 [00:01<00:00, 107.72it/s]


VALIDATION LOSS (Average) : 0.156748
EPOCH: 14


Loss=0.11430492997169495 Batch_id=201 Epoch Average loss=0.0195: 100%|██████████| 202/202 [00:01<00:00, 110.33it/s]


VALIDATION LOSS (Average) : 0.141136
EPOCH: 15


Loss=0.14084112644195557 Batch_id=201 Epoch Average loss=0.0194: 100%|██████████| 202/202 [00:01<00:00, 117.67it/s]


VALIDATION LOSS (Average) : 0.138067
EPOCH: 16


Loss=0.174467533826828 Batch_id=201 Epoch Average loss=0.0193: 100%|██████████| 202/202 [00:01<00:00, 113.09it/s]


VALIDATION LOSS (Average) : 0.154059
EPOCH: 17


Loss=0.20240332186222076 Batch_id=201 Epoch Average loss=0.0191: 100%|██████████| 202/202 [00:01<00:00, 111.58it/s]


VALIDATION LOSS (Average) : 0.148999
EPOCH: 18


Loss=0.13031893968582153 Batch_id=201 Epoch Average loss=0.0187: 100%|██████████| 202/202 [00:01<00:00, 121.38it/s]


VALIDATION LOSS (Average) : 0.1645
EPOCH: 19


Loss=0.1736459583044052 Batch_id=201 Epoch Average loss=0.0186: 100%|██████████| 202/202 [00:01<00:00, 118.34it/s]


VALIDATION LOSS (Average) : 0.169261
EPOCH: 20


Loss=0.12102460116147995 Batch_id=201 Epoch Average loss=0.0181: 100%|██████████| 202/202 [00:01<00:00, 112.43it/s]


VALIDATION LOSS (Average) : 0.158943
EPOCH: 21


Loss=0.14271703362464905 Batch_id=201 Epoch Average loss=0.0180: 100%|██████████| 202/202 [00:01<00:00, 121.66it/s]


VALIDATION LOSS (Average) : 0.143044
EPOCH: 22


Loss=0.10211828351020813 Batch_id=201 Epoch Average loss=0.0179: 100%|██████████| 202/202 [00:01<00:00, 123.87it/s]


VALIDATION LOSS (Average) : 0.163873
EPOCH: 23


Loss=0.15883250534534454 Batch_id=201 Epoch Average loss=0.0175: 100%|██████████| 202/202 [00:01<00:00, 124.33it/s]


VALIDATION LOSS (Average) : 0.156286
EPOCH: 24


Loss=0.18726244568824768 Batch_id=201 Epoch Average loss=0.0177: 100%|██████████| 202/202 [00:01<00:00, 123.73it/s]


VALIDATION LOSS (Average) : 0.157659
EPOCH: 25


Loss=0.14593157172203064 Batch_id=201 Epoch Average loss=0.0175: 100%|██████████| 202/202 [00:01<00:00, 123.02it/s]


VALIDATION LOSS (Average) : 0.177973
EPOCH: 26


Loss=0.08733449876308441 Batch_id=201 Epoch Average loss=0.0172: 100%|██████████| 202/202 [00:01<00:00, 121.19it/s]


VALIDATION LOSS (Average) : 0.141679
EPOCH: 27


Loss=0.06250105053186417 Batch_id=201 Epoch Average loss=0.0166: 100%|██████████| 202/202 [00:01<00:00, 124.36it/s]


VALIDATION LOSS (Average) : 0.14848
EPOCH: 28


Loss=0.2172383815050125 Batch_id=201 Epoch Average loss=0.0169: 100%|██████████| 202/202 [00:01<00:00, 125.08it/s]


VALIDATION LOSS (Average) : 0.164421
EPOCH: 29


Loss=0.18699049949645996 Batch_id=201 Epoch Average loss=0.0164: 100%|██████████| 202/202 [00:01<00:00, 121.37it/s]


VALIDATION LOSS (Average) : 0.173445
EPOCH: 30


Loss=0.1684301197528839 Batch_id=201 Epoch Average loss=0.0165: 100%|██████████| 202/202 [00:01<00:00, 118.14it/s]


VALIDATION LOSS (Average) : 0.175166
EPOCH: 31


Loss=0.08582602441310883 Batch_id=201 Epoch Average loss=0.0163: 100%|██████████| 202/202 [00:01<00:00, 114.20it/s]


VALIDATION LOSS (Average) : 0.156597
EPOCH: 32


Loss=0.1358548402786255 Batch_id=201 Epoch Average loss=0.0157: 100%|██████████| 202/202 [00:01<00:00, 112.08it/s]


VALIDATION LOSS (Average) : 0.167277
EPOCH: 33


Loss=0.13471385836601257 Batch_id=201 Epoch Average loss=0.0162: 100%|██████████| 202/202 [00:01<00:00, 113.63it/s]


VALIDATION LOSS (Average) : 0.161725
EPOCH: 34


Loss=0.16356100142002106 Batch_id=201 Epoch Average loss=0.0160: 100%|██████████| 202/202 [00:01<00:00, 113.31it/s]


VALIDATION LOSS (Average) : 0.160246
EPOCH: 35


Loss=0.1504097282886505 Batch_id=201 Epoch Average loss=0.0158: 100%|██████████| 202/202 [00:01<00:00, 113.74it/s]


VALIDATION LOSS (Average) : 0.156511
EPOCH: 36


Loss=0.11722269654273987 Batch_id=201 Epoch Average loss=0.0154: 100%|██████████| 202/202 [00:01<00:00, 115.56it/s]


VALIDATION LOSS (Average) : 0.172793
EPOCH: 37


Loss=0.0882093757390976 Batch_id=201 Epoch Average loss=0.0150: 100%|██████████| 202/202 [00:01<00:00, 113.97it/s]


VALIDATION LOSS (Average) : 0.159451
EPOCH: 38


Loss=0.09841026365756989 Batch_id=201 Epoch Average loss=0.0150: 100%|██████████| 202/202 [00:01<00:00, 119.19it/s]


VALIDATION LOSS (Average) : 0.156932
EPOCH: 39


Loss=0.15222607553005219 Batch_id=201 Epoch Average loss=0.0150: 100%|██████████| 202/202 [00:01<00:00, 119.73it/s]


VALIDATION LOSS (Average) : 0.159006
EPOCH: 40


Loss=0.07490606606006622 Batch_id=201 Epoch Average loss=0.0151: 100%|██████████| 202/202 [00:01<00:00, 115.11it/s]


VALIDATION LOSS (Average) : 0.178325
EPOCH: 41


Loss=0.1268351823091507 Batch_id=201 Epoch Average loss=0.0149: 100%|██████████| 202/202 [00:01<00:00, 122.15it/s]


VALIDATION LOSS (Average) : 0.170174
EPOCH: 42


Loss=0.09572718292474747 Batch_id=201 Epoch Average loss=0.0147: 100%|██████████| 202/202 [00:01<00:00, 121.07it/s]


VALIDATION LOSS (Average) : 0.167649
EPOCH: 43


Loss=0.11084111034870148 Batch_id=201 Epoch Average loss=0.0145: 100%|██████████| 202/202 [00:01<00:00, 115.74it/s]


VALIDATION LOSS (Average) : 0.153727
EPOCH: 44


Loss=0.11630814522504807 Batch_id=201 Epoch Average loss=0.0147: 100%|██████████| 202/202 [00:01<00:00, 116.47it/s]


VALIDATION LOSS (Average) : 0.174671
EPOCH: 45


Loss=0.12735334038734436 Batch_id=201 Epoch Average loss=0.0145: 100%|██████████| 202/202 [00:01<00:00, 120.38it/s]


VALIDATION LOSS (Average) : 0.177526
EPOCH: 46


Loss=0.04627999663352966 Batch_id=201 Epoch Average loss=0.0146: 100%|██████████| 202/202 [00:01<00:00, 113.07it/s]


VALIDATION LOSS (Average) : 0.178489
EPOCH: 47


Loss=0.11003836244344711 Batch_id=201 Epoch Average loss=0.0144: 100%|██████████| 202/202 [00:01<00:00, 113.49it/s]


VALIDATION LOSS (Average) : 0.173976
EPOCH: 48


Loss=0.09360246360301971 Batch_id=201 Epoch Average loss=0.0143: 100%|██████████| 202/202 [00:01<00:00, 110.23it/s]


VALIDATION LOSS (Average) : 0.149869
EPOCH: 49


Loss=0.08639783412218094 Batch_id=201 Epoch Average loss=0.0142: 100%|██████████| 202/202 [00:01<00:00, 112.32it/s]


VALIDATION LOSS (Average) : 0.153812
EPOCH: 50


Loss=0.11711503565311432 Batch_id=201 Epoch Average loss=0.0141: 100%|██████████| 202/202 [00:01<00:00, 111.09it/s]


VALIDATION LOSS (Average) : 0.156892
EPOCH: 51


Loss=0.17452958226203918 Batch_id=201 Epoch Average loss=0.0140: 100%|██████████| 202/202 [00:01<00:00, 105.30it/s]


VALIDATION LOSS (Average) : 0.151953
EPOCH: 52


Loss=0.04192594438791275 Batch_id=201 Epoch Average loss=0.0136: 100%|██████████| 202/202 [00:01<00:00, 111.17it/s]


VALIDATION LOSS (Average) : 0.156972
EPOCH: 53


Loss=0.1347198337316513 Batch_id=201 Epoch Average loss=0.0136: 100%|██████████| 202/202 [00:01<00:00, 116.44it/s]


VALIDATION LOSS (Average) : 0.123708
EPOCH: 54


Loss=0.08747248351573944 Batch_id=201 Epoch Average loss=0.0137: 100%|██████████| 202/202 [00:01<00:00, 104.76it/s]


VALIDATION LOSS (Average) : 0.176047
EPOCH: 55


Loss=0.18549564480781555 Batch_id=201 Epoch Average loss=0.0135: 100%|██████████| 202/202 [00:01<00:00, 105.49it/s]


VALIDATION LOSS (Average) : 0.156159
EPOCH: 56


Loss=0.07080531120300293 Batch_id=201 Epoch Average loss=0.0135: 100%|██████████| 202/202 [00:01<00:00, 113.41it/s]


VALIDATION LOSS (Average) : 0.171155
EPOCH: 57


Loss=0.1358564794063568 Batch_id=201 Epoch Average loss=0.0134: 100%|██████████| 202/202 [00:01<00:00, 116.88it/s]


VALIDATION LOSS (Average) : 0.163982
EPOCH: 58


Loss=0.06963866949081421 Batch_id=201 Epoch Average loss=0.0131: 100%|██████████| 202/202 [00:01<00:00, 113.71it/s]


VALIDATION LOSS (Average) : 0.174432
EPOCH: 59


Loss=0.11929325014352798 Batch_id=201 Epoch Average loss=0.0133: 100%|██████████| 202/202 [00:01<00:00, 108.70it/s]


VALIDATION LOSS (Average) : 0.155464
EPOCH: 60


Loss=0.1636955738067627 Batch_id=201 Epoch Average loss=0.0132: 100%|██████████| 202/202 [00:01<00:00, 112.81it/s]


VALIDATION LOSS (Average) : 0.163597
EPOCH: 61


Loss=0.11859740316867828 Batch_id=201 Epoch Average loss=0.0130: 100%|██████████| 202/202 [00:01<00:00, 116.08it/s]


VALIDATION LOSS (Average) : 0.150405
EPOCH: 62


Loss=0.11832430958747864 Batch_id=201 Epoch Average loss=0.0129: 100%|██████████| 202/202 [00:01<00:00, 109.63it/s]


VALIDATION LOSS (Average) : 0.148902
EPOCH: 63


Loss=0.07334844768047333 Batch_id=201 Epoch Average loss=0.0131: 100%|██████████| 202/202 [00:01<00:00, 111.36it/s]


VALIDATION LOSS (Average) : 0.163418
EPOCH: 64


Loss=0.06957235932350159 Batch_id=201 Epoch Average loss=0.0127: 100%|██████████| 202/202 [00:01<00:00, 117.48it/s]


VALIDATION LOSS (Average) : 0.140488
EPOCH: 65


Loss=0.09159237146377563 Batch_id=201 Epoch Average loss=0.0127: 100%|██████████| 202/202 [00:01<00:00, 117.70it/s]


VALIDATION LOSS (Average) : 0.146715
EPOCH: 66


Loss=0.11322818696498871 Batch_id=201 Epoch Average loss=0.0123: 100%|██████████| 202/202 [00:01<00:00, 121.30it/s]


VALIDATION LOSS (Average) : 0.162462
EPOCH: 67


Loss=0.1572876274585724 Batch_id=201 Epoch Average loss=0.0125: 100%|██████████| 202/202 [00:01<00:00, 125.17it/s]


VALIDATION LOSS (Average) : 0.157813
EPOCH: 68


Loss=0.12878495454788208 Batch_id=201 Epoch Average loss=0.0125: 100%|██████████| 202/202 [00:01<00:00, 125.81it/s]


VALIDATION LOSS (Average) : 0.160099
EPOCH: 69


Loss=0.09417425096035004 Batch_id=201 Epoch Average loss=0.0127: 100%|██████████| 202/202 [00:01<00:00, 123.06it/s]


VALIDATION LOSS (Average) : 0.158002
EPOCH: 70


Loss=0.12587225437164307 Batch_id=201 Epoch Average loss=0.0125: 100%|██████████| 202/202 [00:01<00:00, 123.56it/s]


VALIDATION LOSS (Average) : 0.166519
EPOCH: 71


Loss=0.10009409487247467 Batch_id=201 Epoch Average loss=0.0121: 100%|██████████| 202/202 [00:01<00:00, 123.86it/s]


VALIDATION LOSS (Average) : 0.149122
EPOCH: 72


Loss=0.08395520597696304 Batch_id=201 Epoch Average loss=0.0124: 100%|██████████| 202/202 [00:01<00:00, 123.52it/s]


VALIDATION LOSS (Average) : 0.168399
EPOCH: 73


Loss=0.13088789582252502 Batch_id=201 Epoch Average loss=0.0122: 100%|██████████| 202/202 [00:01<00:00, 127.01it/s]


VALIDATION LOSS (Average) : 0.150729
EPOCH: 74


Loss=0.10909497737884521 Batch_id=201 Epoch Average loss=0.0119: 100%|██████████| 202/202 [00:01<00:00, 124.66it/s]


VALIDATION LOSS (Average) : 0.157448
EPOCH: 75


Loss=0.10987119376659393 Batch_id=201 Epoch Average loss=0.0121: 100%|██████████| 202/202 [00:01<00:00, 122.59it/s]


VALIDATION LOSS (Average) : 0.171324
EPOCH: 76


Loss=0.09602971374988556 Batch_id=201 Epoch Average loss=0.0119: 100%|██████████| 202/202 [00:01<00:00, 119.37it/s]


VALIDATION LOSS (Average) : 0.171072
EPOCH: 77


Loss=0.13180696964263916 Batch_id=201 Epoch Average loss=0.0119: 100%|██████████| 202/202 [00:01<00:00, 114.66it/s]


VALIDATION LOSS (Average) : 0.159278
EPOCH: 78


Loss=0.08565736562013626 Batch_id=201 Epoch Average loss=0.0117: 100%|██████████| 202/202 [00:01<00:00, 120.23it/s]


VALIDATION LOSS (Average) : 0.159692
EPOCH: 79


Loss=0.10066226124763489 Batch_id=201 Epoch Average loss=0.0119: 100%|██████████| 202/202 [00:01<00:00, 115.80it/s]


VALIDATION LOSS (Average) : 0.177515
EPOCH: 80


Loss=0.056128036230802536 Batch_id=201 Epoch Average loss=0.0117: 100%|██████████| 202/202 [00:01<00:00, 116.58it/s]


VALIDATION LOSS (Average) : 0.173695
EPOCH: 81


Loss=0.07190266251564026 Batch_id=201 Epoch Average loss=0.0114: 100%|██████████| 202/202 [00:01<00:00, 117.96it/s]


VALIDATION LOSS (Average) : 0.175378
EPOCH: 82


Loss=0.11152997612953186 Batch_id=201 Epoch Average loss=0.0118: 100%|██████████| 202/202 [00:01<00:00, 114.42it/s]


VALIDATION LOSS (Average) : 0.173
EPOCH: 83


Loss=0.0919499546289444 Batch_id=201 Epoch Average loss=0.0117: 100%|██████████| 202/202 [00:01<00:00, 110.98it/s]


VALIDATION LOSS (Average) : 0.179311
EPOCH: 84


Loss=0.1178470253944397 Batch_id=201 Epoch Average loss=0.0117: 100%|██████████| 202/202 [00:01<00:00, 115.62it/s]


VALIDATION LOSS (Average) : 0.164005
EPOCH: 85


Loss=0.09060624986886978 Batch_id=201 Epoch Average loss=0.0114: 100%|██████████| 202/202 [00:01<00:00, 112.43it/s]


VALIDATION LOSS (Average) : 0.167643
EPOCH: 86


Loss=0.08455686271190643 Batch_id=201 Epoch Average loss=0.0114: 100%|██████████| 202/202 [00:01<00:00, 111.85it/s]


VALIDATION LOSS (Average) : 0.153823
EPOCH: 87


Loss=0.11740852892398834 Batch_id=201 Epoch Average loss=0.0116: 100%|██████████| 202/202 [00:01<00:00, 121.04it/s]


VALIDATION LOSS (Average) : 0.162978
EPOCH: 88


Loss=0.11472780257463455 Batch_id=201 Epoch Average loss=0.0110: 100%|██████████| 202/202 [00:01<00:00, 112.40it/s]


VALIDATION LOSS (Average) : 0.168091
EPOCH: 89


Loss=0.05697011947631836 Batch_id=201 Epoch Average loss=0.0110: 100%|██████████| 202/202 [00:01<00:00, 117.68it/s]


VALIDATION LOSS (Average) : 0.155829
EPOCH: 90


Loss=0.11854969710111618 Batch_id=201 Epoch Average loss=0.0113: 100%|██████████| 202/202 [00:01<00:00, 112.76it/s]


VALIDATION LOSS (Average) : 0.171467
EPOCH: 91


Loss=0.0638270303606987 Batch_id=201 Epoch Average loss=0.0112: 100%|██████████| 202/202 [00:01<00:00, 116.11it/s]


VALIDATION LOSS (Average) : 0.161674
EPOCH: 92


Loss=0.08973442018032074 Batch_id=201 Epoch Average loss=0.0111: 100%|██████████| 202/202 [00:01<00:00, 111.74it/s]


VALIDATION LOSS (Average) : 0.161696
EPOCH: 93


Loss=0.1064705103635788 Batch_id=201 Epoch Average loss=0.0111: 100%|██████████| 202/202 [00:01<00:00, 111.41it/s]


VALIDATION LOSS (Average) : 0.170582
EPOCH: 94


Loss=0.11896196007728577 Batch_id=201 Epoch Average loss=0.0110: 100%|██████████| 202/202 [00:01<00:00, 113.90it/s]


VALIDATION LOSS (Average) : 0.181556
EPOCH: 95


Loss=0.0737551897764206 Batch_id=201 Epoch Average loss=0.0110: 100%|██████████| 202/202 [00:01<00:00, 126.05it/s]


VALIDATION LOSS (Average) : 0.170476
EPOCH: 96


Loss=0.05441451072692871 Batch_id=201 Epoch Average loss=0.0109: 100%|██████████| 202/202 [00:01<00:00, 124.58it/s]


VALIDATION LOSS (Average) : 0.177659
EPOCH: 97


Loss=0.09940747916698456 Batch_id=201 Epoch Average loss=0.0105: 100%|██████████| 202/202 [00:01<00:00, 120.53it/s]


VALIDATION LOSS (Average) : 0.169883
EPOCH: 98


Loss=0.1010756641626358 Batch_id=201 Epoch Average loss=0.0108: 100%|██████████| 202/202 [00:01<00:00, 122.32it/s]


VALIDATION LOSS (Average) : 0.174023
EPOCH: 99


Loss=0.09705574065446854 Batch_id=201 Epoch Average loss=0.0109: 100%|██████████| 202/202 [00:01<00:00, 121.54it/s]


VALIDATION LOSS (Average) : 0.175964
EPOCH: 100


Loss=0.07479342818260193 Batch_id=201 Epoch Average loss=0.0111: 100%|██████████| 202/202 [00:01<00:00, 121.60it/s]


VALIDATION LOSS (Average) : 0.181528
TEST LOSS (Average) : 0.179269
----------------------training complete for EI_joy-----------------
test loss for  EI_anger  	:			 0.158792
test loss for  EI_fear  	:			 0.111371
test loss for  EI_sadness  	:			 0.180681
test loss for  V  	:			 0.107908
test loss for  EI_joy  	:			 0.179269


In [137]:
# for name, param in fisher_dict['V'].items():
#   print (name, param)

# for name, value in fisher_dict.items():
#   print(name, value)

In [138]:
# # EXECUTION

# lr = 2e-5
# optimizer = optim.Adam(model.parameters(), lr=lr)
# domain_loss_function= nn.BCEWithLogitsLoss()
# regression_loss_function = nn.L1Loss()


# model = model.to(DEVICE)
# domain_loss_function = domain_loss_function.to(DEVICE)
# regression_loss_function = regression_loss_function.to(DEVICE)

# # train_losses = [] # to capture train losses over training epochs
# train_accuracy = [] # to capture train accuracy over training epochs
# # val_losses = [] # to capture validation loss
# # test_losses = [] # to capture test losses 
# # test_accuracy = [] # to capture test accuracy 

# # EPOCHS = 2
# EPOCHS = 100
# # dict_val_loss = {}
# # dict_test_loss = {}


# train_regresion_losses = [] # to capture train losses over training epochs
# train_domain_losses = []
# train_accuracy = [] # to capture train accuracy over training epochs
# # valid_regresion_losses = [] # to capture validation loss
# # test_regresion_losses = [] # to capture test losses 
# total_test_regression_loss =[]
# total_valid_regression_loss =[]
# # print(f'----------------------training started for {name}-----------------')
# for epoch in range(EPOCHS):
#   print("EPOCH:", epoch+1)
#   train_model(model, DEVICE, train_iterator, optimizer, epoch)
#   # print("for validation.......")
#   # val_name = train_name.replace("train", "val" )
#   # test_model(typical_model, device, dict_val_loader[val_name], mode = 'val')
#   test_model(model, DEVICE, valid_iterator, mode = 'val')


#   # print("for test  .......")
#   # test_name = train_name.replace("train", "test" )
#   # test_model(typical_model, device, dict_test_loader[test_name], mode = 'test')
#   test_model(model, DEVICE, test_iterator, mode = 'test')

# # dict_val_loss[name] = val_losses
# # dict_test_loss[name] = test_losses

# model_name = "Non_DANN"+".pt"
# torch.save(model.state_dict(), os.path.join(MODEL_DIR, model_name))
# # print(f'----------------------training complete for {name}-----------------')
# # print(dict_val_loss.items())
# # print(dict_test_loss.items())

## DANN Model - Training and Testing

In [139]:
def compute_accuracy(logits, labels):
    
    predicted_labels_dict = {
      0: 0,
      1: 0,
    }
    
    predicted_label = logits.max(dim = 1)[1]
    
    for pred in predicted_label:
        predicted_labels_dict[pred.item()] += 1
    acc = (predicted_label == labels).float().mean()
    
    return acc, predicted_labels_dict

In [140]:
# def binary_acc(y_pred, y_test):
#     y_pred_tag = torch.round(torch.sigmoid(y_pred))

#     correct_results_sum = (y_pred_tag == y_test).sum().float()
#     acc = correct_results_sum/y_test.shape[0]
#     acc = torch.round(acc * 100)
    
#     return acc

In [141]:
# def evaluate(model, dataloader, mode = 'test', percentage = 5):
#     with torch.no_grad():
#         predicted_labels_dict = {                                                   
#           0: 0,                                                                     
#           1: 0,                                                                     
#         }
        
#         mean_accuracy = 0.0
#         # total_batches = len(dataloader)
#         # print("total_batches: ",total_batches )

        



#             sentiment_pred, _ = model(**inputs)
#             accuracy, predicted_labels = compute_accuracy(sentiment_pred, inputs["labels"])
#             mean_accuracy += accuracy
#             predicted_labels_dict[0] += predicted_labels[0]
#             predicted_labels_dict[1] += predicted_labels[1]  
#         print(predicted_labels_dict)
#     return mean_accuracy/total_batches

### Execution DANN

In [142]:
## DANN For multiple datasets
# n_epochs = 100 # number of epochs
# n_epochs = 2 # number of epochs
n_epochs = NUM_EPOCHS
lr = 2e-5

dict_dann_losses_list = {}

dict_dann_model_saved= {}
for name, model_arch in dict_model_arch.items():
  model = model_arch
  optimizer = optim.Adam(model.parameters(), lr=lr)
  model = model.to(DEVICE)
  model.train()  
  domain_loss_function= nn.NLLLoss()
  regression_loss_function = nn.L1Loss()
  domain_loss_function = domain_loss_function.to(DEVICE)
  regression_loss_function = regression_loss_function.to(DEVICE)
  max_batches = min(len(dict_iterator[name]['train_iterator']), len(dict_target_iterator[name]))
  # max_batches = min(len(train_iterator), len(target_data)//TARGET_BATCH_SIZE)
  # max_batches = min(len(train_iterator), len(target_iterator))

  # print(max_batches)
  train_losses = [] # to capture train losses over training epochs
  val_losses = [] # to capture validation loss over epochs
  print(f'----------------------training started for DANN model - {name}-----------------')
  for epoch_idx in range(n_epochs):
      # source_iterator = iter(train_iterator) #single dataset
      source_iterator = iter(dict_iterator[name]['train_iterator'])
      # target_iterator = iter(target_iterator) #single dataset
      target_iterator = iter(dict_target_iterator[name]) 
      epoch_loss = 0
      processed = 0
      for batch_idx in range(max_batches):
          
          p = float(batch_idx + epoch_idx * max_batches) / (n_epochs * max_batches)
          alpha = 2. / (1. + np.exp(-10 * p)) - 1
          alpha = torch.tensor(alpha)
          
          # model.train()          
          optimizer.zero_grad()

          ## SOURCE DATASET TRAINING UPDATE
          
          source_batch = next(source_iterator)
          source_tweets, source_intensities = source_batch.tweet.to(DEVICE), source_batch.intensity.to(DEVICE)  # plural, we are not interested in domain
          
          source_intensity_outputs, source_domain_outputs = model(source_tweets, alpha = alpha)

          loss_source_regression= regression_loss_function(source_intensity_outputs,source_intensities.unsqueeze(1)) # Computing regression loss

          source_domain_inputs = torch.zeros(len(source_batch), dtype=torch.long).to(DEVICE) # source domain has 0 id
          loss_source_domain = domain_loss_function(source_domain_outputs,source_domain_inputs)

          epoch_loss += loss_source_regression.item()
          processed += len(source_tweets)


          ## TARGET DATASET TRAINING UPDATE
          target_batch = next(iter(target_iterator))
          target_tweets= target_batch.tweet.to(DEVICE) # plural

          _, target_domain_outputs = model(target_tweets, alpha = alpha)

          target_domain_inputs = torch.ones(len(target_batch), dtype=torch.long).to(DEVICE) # target domain has 1 id
          loss_target_domain = domain_loss_function(target_domain_outputs,target_domain_inputs)

          # COMBINING LOSS
          loss = loss_source_regression + loss_source_domain + loss_target_domain
          loss.backward()
          optimizer.step()

          if (batch_idx % 100 == 0):
            print("Epoch [{}/{}] Step [{}/{}]: domain_loss_target={:.4f} / domain_loss_source={:.4f} / regression_loss_source={:.4f} / alpha={:.4f}"
                .format(epoch_idx + 1,
                        n_epochs,
                        batch_idx + 1,
                        max_batches,
                        loss_target_domain.item()
                        ,loss_source_domain.item()
                        ,loss_source_regression.item(),alpha))

      # After every epoch
      avg_train_regression_loss =  float("{:.6f}".format(epoch_loss/processed))
      train_losses.append(avg_train_regression_loss)

      # Evaluate the model after every epoch
      # test_model(model, DEVICE, valid_iterator, mode = 'val') # single model
      avg_valid_regression_loss = test_model(model, DEVICE, dict_iterator[name]['val_iterator'], mode = 'val')
      val_losses.append(avg_valid_regression_loss) 
  
  
  # testing the model when all epochs are finished (outsied epoch loop)
  # test_model(model, DEVICE, test_iterator, mode = 'test') # single model
  test_loss = test_model(model, DEVICE, dict_iterator[name]['test_iterator'], mode = 'test')
  dict_dann_losses_list [name] = {'train_losses' : train_losses, 'val_losses': val_losses, 'test_loss' : test_loss }


  model_name = name + "_" + str(time.strftime("%d_%m")) + "_dann_"+EMBEDDING_TO_BE_USED+".pt" 
  torch.save(model.state_dict(), os.path.join(MODEL_DIR,model_name))
  dict_dann_model_saved[name] = model_name
  print(f'----------------------training complete for DANN model - {name}-----------------')

for name, values in dict_non_dann_losses_list.items():
    print ("test loss for ", name, " \t:\t\t\t", values['test_loss'])

----------------------training started for DANN model - EI_anger-----------------
Epoch [1/100] Step [1/213]: domain_loss_target=0.8042 / domain_loss_source=0.5935 / regression_loss_source=0.0675 / alpha=0.0000
Epoch [1/100] Step [101/213]: domain_loss_target=0.8284 / domain_loss_source=0.5260 / regression_loss_source=0.0820 / alpha=0.0235
Epoch [1/100] Step [201/213]: domain_loss_target=0.8671 / domain_loss_source=0.4818 / regression_loss_source=0.0747 / alpha=0.0469
VALIDATION LOSS (Average) : 0.249016
Epoch [2/100] Step [1/213]: domain_loss_target=0.8948 / domain_loss_source=0.4813 / regression_loss_source=0.1524 / alpha=0.0500
Epoch [2/100] Step [101/213]: domain_loss_target=0.8649 / domain_loss_source=0.4717 / regression_loss_source=0.0280 / alpha=0.0733
Epoch [2/100] Step [201/213]: domain_loss_target=0.7367 / domain_loss_source=0.4229 / regression_loss_source=0.0270 / alpha=0.0966
VALIDATION LOSS (Average) : 0.283739
Epoch [3/100] Step [1/213]: domain_loss_target=0.8800 / domain

## DANN with Elastic weight consolidation

### Loading Models

In [143]:
# Loading Models for EWC, we copy the weights of these models for EWC
dict_non_dann_model_saved ={}
for name, _ in dict_model_arch.items():
  dict_non_dann_model_saved[name] = name +"_" +str(time.strftime("%d_%m"))+ "_non_dann_"+EMBEDDING_TO_BE_USED+".pt"
print(dict_non_dann_model_saved)


dict_loaded_models_for_ewc = {}
for name, non_dann_model_name in dict_non_dann_model_saved.items():
  
  INPUT_DIM = len(dict_fields[name]['Tweet'][1].vocab)
  print(name, INPUT_DIM)
  PAD_IDX = dict_fields[name]['Tweet'][1].vocab.stoi[dict_fields[name]['Tweet'][1].pad_token]

  loaded_model_non_dann = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)
  loaded_model_non_dann.load_state_dict(torch.load(os.path.join(MODEL_DIR, non_dann_model_name),map_location=torch.device(DEVICE)))

  dict_loaded_models_for_ewc[name]=loaded_model_non_dann

print(dict_loaded_models_for_ewc)

{'EI_anger': 'EI_anger_02_01_non_dann_glove.pt', 'EI_fear': 'EI_fear_02_01_non_dann_glove.pt', 'EI_sadness': 'EI_sadness_02_01_non_dann_glove.pt', 'V': 'V_02_01_non_dann_glove.pt', 'EI_joy': 'EI_joy_02_01_non_dann_glove.pt'}
EI_anger 4689
EI_fear 5544
EI_sadness 4859
V 4320
EI_joy 4653
{'EI_anger': CNN1d(
  (embedding): Embedding(4689, 100, padding_idx=1)
  (convs): ModuleList(
    (0): Conv1d(100, 100, kernel_size=(2,), stride=(1,))
    (1): Conv1d(100, 100, kernel_size=(3,), stride=(1,))
    (2): Conv1d(100, 100, kernel_size=(4,), stride=(1,))
    (3): Conv1d(100, 100, kernel_size=(5,), stride=(1,))
  )
  (regression): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=400, out_features=200, bias=True)
    (2): ReLU()
    (3): Linear(in_features=200, out_features=10, bias=True)
    (4): ReLU()
    (5): Linear(in_features=10, out_features=1, bias=True)
  )
  (domain_classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=

In [144]:
# print(dict_non_dann_model_saved['V'])
# INPUT_DIM = len(dict_fields['V']['Tweet'][1].vocab)
# print(INPUT_DIM)
# PAD_IDX = dict_fields['V']['Tweet'][1].vocab.stoi[dict_fields['V']['Tweet'][1].pad_token]

# loaded_model_non_dann = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)
# loaded_model_non_dann.load_state_dict(torch.load(os.path.join(MODEL_DIR, dict_non_dann_model_saved['V']),map_location=torch.device(DEVICE)))

# loaded_model_non_dann


### EWC Code

In [145]:
def variable(t: torch.Tensor, use_cuda=True, **kwargs):
    if torch.cuda.is_available() and use_cuda:
        t = t.cuda()
    return Variable(t, **kwargs)

In [146]:
class EWC(object):
    def __init__(self, model: nn.Module, 
                 data_loader: torch.utils.data.DataLoader
                #  dataset: list
                 ,device=DEVICE):

        self.model = model
        # print("----------MODEL--------\n",self.model)
        self.dataset = data_loader
        self.device = device

        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        # for n, p in self.params.items():
        #   print("self.params ",n," : ", p)

        self._means = {}
        for n, p in deepcopy(self.params).items():
            self._means[n] = variable(p.data)
        
        # print("--------self._means----------------\n", self._means)

        self._precision_matrices = self._diag_fisher()

    def _diag_fisher(self):
        precision_matrices = {}
        for n, p in deepcopy(self.params).items():
            p.data.zero_()
            precision_matrices[n] = variable(p.data)
        # print("--------PRECISION MATRICES----------------\n", precision_matrices)

        self.model.eval()
        for batch in self.dataset:
            self.model.zero_grad()
            tweets, intensities  = variable(batch.tweet.to(DEVICE)), variable(batch.intensity.to(DEVICE))
            y_preds,_ = model(tweets) 
            regression_loss = regression_loss_function(y_preds,intensities.unsqueeze(1))
            regression_loss.backward()
            # print("--------regression loss----------------\n", regression_loss)

            for n, p in self.model.named_parameters():
              # if type(p.grad)
              if isinstance(p.grad, type(None)):
                pass # parameter value would remain as it is
              else:
                # print(n,":", type(p.grad))
                precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset)



              # print(n,":", type(p.grad))
              # if n in self.params.keys():
              #   print(n,":", type(p.grad))
                

                # # precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset)
                # print(n,":",p.grad.data)
            

        precision_matrices = {n: p for n, p in precision_matrices.items()}
        return precision_matrices

    def penalty(self, model: nn.Module):
        loss = 0
        for n, p in model.named_parameters():
            _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
            loss += _loss.sum()
        return loss


# def ewc_train(model: nn.Module, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader,
#               ewc: EWC, importance: float):
#     model.train()
#     pbar = tqdm(data_loader) # putting the iterator in pbar
#     processed = 0 
#     epoch_loss = 0.0
#     for batch_idx, batch in enumerate(pbar):
#         tweets, intensities = variable(batch.tweet), variable(batch.intensity)
#         optimizer.zero_grad() 
#         y_preds,_ = model(tweets) 
#         regression_loss = regression_loss_function(y_preds,intensities.unsqueeze(1)) 


#         loss = F.cross_entropy(output, target) + importance * ewc.penalty(model)
#         epoch_loss += loss.data[0]
#         loss.backward()
#         optimizer.step()
#     return epoch_loss / len(data_loader)



In [147]:
## DANN EWC For multiple datasets
# n_epochs = 100 # number of epochs
# n_epochs = 2 # number of epochs
n_epochs = NUM_EPOCHS
lr = 2e-5

dict_dann_ewc_losses_list = {}
dict_dann_ewc_model_saved= {}

# for name, model in dict_model_arch.items(): #dict_non_dann_model_saved.items() # dict_loaded_models_for_ewc
for name, existing_model in dict_loaded_models_for_ewc.items():
  model = deepcopy(existing_model) # copies params from existing model to another one https://discuss.pytorch.org/t/copying-weights-from-one-net-to-another/1492/2
  model = model.to(DEVICE)
  ewc_obj = EWC(model,dict_iterator[name]['train_iterator']) #instantiating EWC Object
  optimizer = optim.Adam(model.parameters(), lr=lr)
  
  domain_loss_function= nn.NLLLoss()
  regression_loss_function = nn.L1Loss()
  domain_loss_function = domain_loss_function.to(DEVICE)
  regression_loss_function = regression_loss_function.to(DEVICE)
  max_batches = min(len(dict_iterator[name]['train_iterator']), len(dict_target_iterator[name]))
  # max_batches = min(len(train_iterator), len(target_data)//TARGET_BATCH_SIZE)
  # max_batches = min(len(train_iterator), len(target_iterator))

  # print(max_batches)
  train_losses = [] # to capture train losses over training epochs
  val_losses = [] # to capture validation loss over epochs
  model.train()
  print(f'----------------------training started for DANN EWC model - {name}-----------------')
  for epoch_idx in range(n_epochs):
      # source_iterator = iter(train_iterator) #single dataset
      source_iterator = iter(dict_iterator[name]['train_iterator'])
      # target_iterator = iter(target_iterator) #single dataset
      target_iterator = iter(dict_target_iterator[name]) 
      epoch_loss = 0
      processed = 0
      for batch_idx in range(max_batches):
          optimizer.zero_grad()
          p = float(batch_idx + epoch_idx * max_batches) / (n_epochs * max_batches)
          alpha = 2. / (1. + np.exp(-10 * p)) - 1
          alpha = torch.tensor(alpha)
                 
          ## SOURCE DATASET TRAINING UPDATE
          
          source_batch = next(source_iterator)
          source_tweets, source_intensities = source_batch.tweet.to(DEVICE), source_batch.intensity.to(DEVICE)  # plural, we are not interested in domain
          
          source_intensity_outputs, source_domain_outputs = model(source_tweets, alpha = alpha)

          loss_source_regression= regression_loss_function(source_intensity_outputs,source_intensities.unsqueeze(1)) + EWC_LAMBDA * ewc_obj.penalty(model)# Computing regression loss

          source_domain_inputs = torch.zeros(len(source_batch), dtype=torch.long).to(DEVICE) # source domain has 0 id
          loss_source_domain = domain_loss_function(source_domain_outputs,source_domain_inputs)

          epoch_loss += loss_source_regression.item()
          processed += len(source_tweets)


          ## TARGET DATASET TRAINING UPDATE
          target_batch = next(iter(target_iterator))
          target_tweets= target_batch.tweet.to(DEVICE) # plural

          _, target_domain_outputs = model(target_tweets, alpha = alpha)

          target_domain_inputs = torch.ones(len(target_batch), dtype=torch.long).to(DEVICE) # target domain has 1 id
          loss_target_domain = domain_loss_function(target_domain_outputs,target_domain_inputs)

          # COMBINING LOSS
          loss = loss_source_regression + loss_source_domain + loss_target_domain
          loss.backward()
          optimizer.step()

          if (batch_idx % 100 == 0):
            print("Epoch [{}/{}] Step [{}/{}]: domain_loss_target={:.4f} / domain_loss_source={:.4f} / regression_loss_source={:.4f} / alpha={:.4f}"
                .format(epoch_idx + 1,
                        n_epochs,
                        batch_idx + 1,
                        max_batches,
                        loss_target_domain.item()
                        ,loss_source_domain.item()
                        ,loss_source_regression.item(),alpha))

      # After every epoch
      avg_train_regression_loss =  float("{:.6f}".format(epoch_loss/processed))
      train_losses.append(avg_train_regression_loss)

      # Evaluate the model after every epoch
      # test_model(model, DEVICE, valid_iterator, mode = 'val') # single model
      avg_valid_regression_loss = test_model(model, DEVICE, dict_iterator[name]['val_iterator'], mode = 'val')
      val_losses.append(avg_valid_regression_loss) 
  
  
  # testing the model when all epochs are finished (outsied epoch loop)
  # test_model(model, DEVICE, test_iterator, mode = 'test') # single model
  test_loss = test_model(model, DEVICE, dict_iterator[name]['test_iterator'], mode = 'test')
  dict_dann_ewc_losses_list [name] = {'train_losses' : train_losses, 'val_losses': val_losses, 'test_loss' : test_loss }


  model_name = name + "_" + str(time.strftime("%d_%m")) + "_dann_ewc_"+EMBEDDING_TO_BE_USED+".pt" 
  torch.save(model.state_dict(), os.path.join(MODEL_DIR,model_name))
  dict_dann_ewc_model_saved[name] = model_name
  print(f'----------------------training complete for DANN EWC model - {name}-----------------')

for name, values in dict_non_dann_losses_list.items():
    print ("test loss for ", name, " \t:\t\t\t", values['test_loss'])

----------------------training started for DANN EWC model - EI_anger-----------------
Epoch [1/100] Step [1/213]: domain_loss_target=0.8070 / domain_loss_source=0.6033 / regression_loss_source=0.0866 / alpha=0.0000
Epoch [1/100] Step [101/213]: domain_loss_target=0.8214 / domain_loss_source=0.5580 / regression_loss_source=0.0850 / alpha=0.0235
Epoch [1/100] Step [201/213]: domain_loss_target=0.9065 / domain_loss_source=0.4675 / regression_loss_source=0.1036 / alpha=0.0469
VALIDATION LOSS (Average) : 0.245383
Epoch [2/100] Step [1/213]: domain_loss_target=0.8699 / domain_loss_source=0.4770 / regression_loss_source=0.1339 / alpha=0.0500
Epoch [2/100] Step [101/213]: domain_loss_target=0.8329 / domain_loss_source=0.4950 / regression_loss_source=0.0233 / alpha=0.0733
Epoch [2/100] Step [201/213]: domain_loss_target=0.9179 / domain_loss_source=0.4024 / regression_loss_source=0.0261 / alpha=0.0966
VALIDATION LOSS (Average) : 0.272501
Epoch [3/100] Step [1/213]: domain_loss_target=0.8726 / do

# Equality Evaluation using Equity Evaluation Corpus

## Loading Evaluation Data
format `[ID	Sentence	Template	Person	Gender	Race Emotion	Emotion word]`

In [148]:
data_EEC = TASK1.EEC['eec']
df_EEC = pd.read_csv(data_EEC)
df_EEC.head()

Unnamed: 0,ID,Sentence,Template,Person,Gender,Race,Emotion,Emotion word
0,2018-En-mystery-05498,Alonzo feels angry.,<person subject> feels <emotion word>.,Alonzo,male,African-American,anger,angry
1,2018-En-mystery-11722,Alonzo feels furious.,<person subject> feels <emotion word>.,Alonzo,male,African-American,anger,furious
2,2018-En-mystery-11364,Alonzo feels irritated.,<person subject> feels <emotion word>.,Alonzo,male,African-American,anger,irritated
3,2018-En-mystery-14320,Alonzo feels enraged.,<person subject> feels <emotion word>.,Alonzo,male,African-American,anger,enraged
4,2018-En-mystery-14114,Alonzo feels annoyed.,<person subject> feels <emotion word>.,Alonzo,male,African-American,anger,annoyed


## Creating evalution function (includes pre-processing)

In [149]:
## padding function : adds padding / truncates to max size
def pad_or_truncate(some_list, target_len = MAX_SIZE, pad_idx = PAD_IDX):
    return some_list[:target_len] + [pad_idx]*(target_len - len(some_list))

## preprocessing function, takes in a tweet and returns padded indexed tweet (input for model)
# def text_pipeline(tweet):
#     indexed_tweet = [field_tweet.vocab.__getitem__(token) for token in preprocess_tweet(tweet)]
#     # print(indexed_tweet)
#     return pad_or_truncate(indexed_tweet, MAX_SIZE , pad_idx = PAD_IDX)
#     # print(indexed_tweet_padded)

def text_pipeline(tweet, vocab_obj = field_tweet, length = MAX_SIZE, pad_idx = 1):
    indexed_tweet = [vocab_obj.vocab.__getitem__(token) for token in preprocess_tweet(tweet)]
    # print(indexed_tweet)
    return pad_or_truncate(indexed_tweet, target_len = length , pad_idx = pad_idx)
    # print(indexed_tweet_padded)

In [150]:
# i = random.randint(0,len(df_EEC))
# tweet_example = df_EEC['Sentence'][i]
# print(tweet_example, text_pipeline(tweet_example))

## Loading model

In [151]:
dict_non_dann_model_saved ={}
dict_dann_model_saved ={}
dict_dann_ewc_model_saved ={}
for name, _ in dict_model_arch.items():
  dict_non_dann_model_saved[name] = name +"_" +str(time.strftime("%d_%m"))+ "_non_dann_"+EMBEDDING_TO_BE_USED+".pt"
  dict_dann_model_saved[name] = name +"_" +str(time.strftime("%d_%m"))+ "_dann_"+EMBEDDING_TO_BE_USED+".pt"  
  dict_dann_ewc_model_saved[name] = name + "_" + str(time.strftime("%d_%m")) + "_dann_ewc_"+EMBEDDING_TO_BE_USED+".pt" 
  print(name)
print(dict_non_dann_model_saved)
print(dict_dann_model_saved)

EI_anger
EI_fear
EI_sadness
V
EI_joy
{'EI_anger': 'EI_anger_02_01_non_dann_glove.pt', 'EI_fear': 'EI_fear_02_01_non_dann_glove.pt', 'EI_sadness': 'EI_sadness_02_01_non_dann_glove.pt', 'V': 'V_02_01_non_dann_glove.pt', 'EI_joy': 'EI_joy_02_01_non_dann_glove.pt'}
{'EI_anger': 'EI_anger_02_01_dann_glove.pt', 'EI_fear': 'EI_fear_02_01_dann_glove.pt', 'EI_sadness': 'EI_sadness_02_01_dann_glove.pt', 'V': 'V_02_01_dann_glove.pt', 'EI_joy': 'EI_joy_02_01_dann_glove.pt'}


In [152]:
### Loading Model


# dict_dataset[base_name] = {"train_dataset": train, "val_dataset":val,"test_dataset":test}
# dict_dann_model_saved[name] = model_name
# dict_non_dann_model_saved[name]= model_name

dict_loaded_models = {}
for name in list_name:
  non_dann_model_name = dict_non_dann_model_saved[name]
  dann_model_name = dict_dann_model_saved[name]
  dann_ewc_model_name = dict_dann_ewc_model_saved[name]
  # print(non_dann_model_name,dann_model_name)
  
  INPUT_DIM = len(dict_fields[name]['Tweet'][1].vocab)
  print(name, INPUT_DIM)
  PAD_IDX = dict_fields[name]['Tweet'][1].vocab.stoi[dict_fields[name]['Tweet'][1].pad_token]

  loaded_model_non_dann = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)
  loaded_model_non_dann.load_state_dict(torch.load(os.path.join(MODEL_DIR, non_dann_model_name),map_location=torch.device(DEVICE)))
  loaded_model_non_dann.eval()

  loaded_model_dann = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)
  loaded_model_dann.load_state_dict(torch.load(os.path.join(MODEL_DIR, dann_model_name),map_location=torch.device(DEVICE)))
  loaded_model_dann.eval()

  loaded_model_dann_ewc = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)
  loaded_model_dann_ewc.load_state_dict(torch.load(os.path.join(MODEL_DIR, dann_ewc_model_name),map_location=torch.device(DEVICE)))
  loaded_model_dann_ewc.eval()

  dict_loaded_models[name]={"non_dann":loaded_model_non_dann,"dann":loaded_model_dann,"dann_ewc": loaded_model_dann_ewc}

print(dict_loaded_models)

EI_anger 4689
EI_fear 5544
EI_sadness 4859
V 4320
EI_joy 4653
{'EI_anger': {'non_dann': CNN1d(
  (embedding): Embedding(4689, 100, padding_idx=1)
  (convs): ModuleList(
    (0): Conv1d(100, 100, kernel_size=(2,), stride=(1,))
    (1): Conv1d(100, 100, kernel_size=(3,), stride=(1,))
    (2): Conv1d(100, 100, kernel_size=(4,), stride=(1,))
    (3): Conv1d(100, 100, kernel_size=(5,), stride=(1,))
  )
  (regression): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=400, out_features=200, bias=True)
    (2): ReLU()
    (3): Linear(in_features=200, out_features=10, bias=True)
    (4): ReLU()
    (5): Linear(in_features=10, out_features=1, bias=True)
  )
  (domain_classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=400, out_features=200, bias=True)
    (2): ReLU()
    (3): Linear(in_features=200, out_features=10, bias=True)
    (4): ReLU()
    (5): Linear(in_features=10, out_features=2, bias=True)
    (6): LogSoftmax(dim=1)

In [153]:
### Loading Model (single dataset)

# dict_model_name = {'non_dann':'Non_DANN.pt','dann':'epoch_99.pt'}
# dict_loaded_model ={}
# for model_type, model_name in dict_model_name.items():
#   loaded_model = CNN1d(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)
#   loaded_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, model_name),map_location=torch.device(DEVICE)))
#   loaded_model.eval()
#   dict_loaded_model[model_type] = loaded_model
# print(dict_loaded_model)

In [154]:
from torch.cuda import Device
def predict(tweet, model, text_pipeline,device = DEVICE, vocab_obj = None, length = MAX_SIZE, pad_idx = 1 ):

  with torch.no_grad():
    # tweet_tensor = torch.tensor(text_pipeline(tweet)).unsqueeze(0).to(device)
    tweet_tensor = torch.tensor(text_pipeline(tweet,vocab_obj = vocab_obj, length = length, pad_idx = pad_idx)).unsqueeze(0).to(device)
    output = model(tweet_tensor)
    return output[0].item()

In [155]:
# i = random.randint(0,len(df_EEC))
# tweet_example = df_EEC['Sentence'][i]
# loaded_model_device = 'cpu'
# loaded_model = dict_loaded_models['EI_anger']['dann'].to(loaded_model_device)
# print(predict(tweet_example, loaded_model,text_pipeline, device= loaded_model_device))

## Creating Sentence pairs (as per SEMVAL18 paper)

In [156]:
dict_f_m_noun_phrase = {'she':'he', 
            'her':'him',
            'this woman':'this man',
            'this girl':'this boy',
            'my sister' : 'my brother',
            'my daughter' : 'my son',
            'my wife': 'my husband',
            'my girlfriend':'my boyfriend',
            'my mother':'my father',
            'my aunt':'my uncle',
            'my mom': 'my dad'
            }

name_male = ['Alonzo','Jamel','Alphonse','Jerome','Leroy','Torrance','Darnell','Lamar','Malik','Terrence','Adam','Harry','Josh','Roger','Alan','Frank','Justin','Ryan','Andrew','Jack'] 
name_female = ['Nichelle','Shereen','Ebony','Latisha','Shaniqua','Jasmine','Tanisha','Tia','Lakisha','Latoya','Amanda','Courtney','Heather','Melanie','Katie','Betsy','Kristin','Nancy','Stephanie','Ellen']


In [157]:
list_unique_template = list(df_EEC['Template'].dropna().unique())
# print(list_unique_template)
list_emotion_word = list(df_EEC['Emotion word'].unique()) # contains nan also
# print(list_emotion_word)
list_gender = list(df_EEC['Gender'].dropna().unique())
# print(list_gender)
list_person = list(df_EEC['Person'].unique())   
# print(list_person)

In [158]:
# list_f_m_noun_phrase =[]
# list_f_m_noun_phrase.extend(name_male)
# list_f_m_noun_phrase.extend(name_female)
# [list_f_m_noun_phrase.extend([f,m]) for f,m in dict_f_m_noun_phrase.items()]
# print(list_f_m_noun_phrase)
# assert set(list_f_m_noun_phrase)<= set(list_person), "The noun phrases are not subset of overall person list"

In [159]:
print(list_emotion_word)
# list_emotion_word= list_emotion_word.append('')
# print(list_emotion_word)

['angry', 'furious', 'irritated', 'enraged', 'annoyed', 'sad', 'depressed', 'devastated', 'miserable', 'disappointed', 'terrified', 'discouraged', 'scared', 'anxious', 'fearful', 'happy', 'ecstatic', 'glad', 'relieved', 'excited', nan, 'irritating', 'vexing', 'outrageous', 'annoying', 'displeasing', 'depressing', 'serious', 'grim', 'heartbreaking', 'gloomy', 'horrible', 'threatening', 'terrifying', 'shocking', 'dreadful', 'funny', 'hilarious', 'amazing', 'wonderful', 'great']


In [160]:
# Template - F - M Noun Phrases chunks (Checked again 3012)

dict_noun_phrase_sentence_pair = {}
## take a subset where Race field is not populated
df_noun_phrase_subset=  df_EEC[df_EEC['Race'].isna()] ## includes values which do not have Race 
count = 0
# print(len(df_noun_phrase_subset),df_noun_phrase_subset.head())
for template in list_unique_template:
  for emotion_word in list_emotion_word:
    condition_1 = (df_noun_phrase_subset['Template']== template)
    condition_2 = (df_noun_phrase_subset['Emotion word'] == emotion_word)
    condition_3 = (df_noun_phrase_subset['Gender'] == 'female')
    condition_4 = (df_noun_phrase_subset['Gender'] == 'male')
    list_female = df_noun_phrase_subset[(condition_1) & (condition_2) & (condition_3)]['Sentence'].to_list()
    list_male = df_noun_phrase_subset[(condition_1) & (condition_2) & (condition_4)]['Sentence'].to_list()
    if ((len(list_female) >0) & (len(list_male) >0)):
      dict_noun_phrase_sentence_pair[count]=(list_female,list_male)
      # print(emotion_word, "emotion word")
      count = count + 1
    
    if pd.isnull(emotion_word):
      condition_5 = (df_noun_phrase_subset['Emotion word'].isna())
      list_female_2 = df_noun_phrase_subset[(condition_5) & (condition_1) & (condition_3)]['Sentence'].to_list()
      list_male_2 = df_noun_phrase_subset[(condition_5) & (condition_1) & (condition_4)]['Sentence'].to_list()
      if ((len(list_female_2) >0) & (len(list_male_2) >0)):
        dict_noun_phrase_sentence_pair[count]=(list_female_2,list_male_2)
        # print("no emotion word")
        count = count + 1

print(count)
print(len(dict_noun_phrase_sentence_pair),"dict_noun_phrase_sentence_pair............ \n",dict_noun_phrase_sentence_pair)


144
144 dict_noun_phrase_sentence_pair............ 
 {0: (['She feels angry.', 'This woman feels angry.', 'This girl feels angry.', 'My sister feels angry.', 'My daughter feels angry.', 'My wife feels angry.', 'My girlfriend feels angry.', 'My mother feels angry.', 'My aunt feels angry.', 'My mom feels angry.'], ['He feels angry.', 'This man feels angry.', 'This boy feels angry.', 'My brother feels angry.', 'My son feels angry.', 'My husband feels angry.', 'My boyfriend feels angry.', 'My father feels angry.', 'My uncle feels angry.', 'My dad feels angry.']), 1: (['She feels furious.', 'This woman feels furious.', 'This girl feels furious.', 'My sister feels furious.', 'My daughter feels furious.', 'My wife feels furious.', 'My girlfriend feels furious.', 'My mother feels furious.', 'My aunt feels furious.', 'My mom feels furious.'], ['He feels furious.', 'This man feels furious.', 'This boy feels furious.', 'My brother feels furious.', 'My son feels furious.', 'My husband feels furiou

In [161]:
# Template -  ORIGINAL TRYING 3012

dict_original_sentence_pair = {}
count = 0

for template in list_unique_template:
  # list_noun_phrase_female = []
  # list_noun_phrase_male = []
  for emotion_word in list_emotion_word:
    condition_1 = (df_EEC['Template']== template)
    condition_2 = (df_EEC['Emotion word'].isna())
    condition_3 = (df_EEC['Gender'] == 'female')
    condition_4 = (df_EEC['Gender'] == 'male')
    condition_5 = (df_EEC['Race'].notnull())
    condition_6 = (df_EEC['Race'].isna())
    condition_9 = (df_EEC['Emotion word'].notnull())
    list_noun_phrase_female = []
    list_noun_phrase_male = []
    if pd.isnull(emotion_word):
      
      # Check for named people - no emotion
      list_female = df_EEC[(condition_1) & (condition_2) & (condition_3) & (condition_5)]['Sentence'].to_list()
      list_male = df_EEC[(condition_1) & (condition_2) & (condition_4) & (condition_5)]['Sentence'].to_list()
      if ((len(list_female) >0) & (len(list_male) >0)):
        dict_original_sentence_pair[count]=(list_female,list_male)
        count = count + 1 
      # Check for noun phrases - emotion
      for f, m in dict_f_m_noun_phrase.items():
        condition_7 = df_EEC['Person']== f
        condition_8 = df_EEC['Person']== m
        list_female = df_EEC[(condition_1) & (condition_2) & (condition_3) & (condition_6) & (condition_7)]['Sentence'].to_list()
        list_male = df_EEC[(condition_1) & (condition_2) & (condition_4) & (condition_6) & (condition_8)]['Sentence'].to_list()
        if ((len(list_female) >0) & (len(list_male) >0)):
          list_noun_phrase_female.append(list_female[0])
          list_noun_phrase_male.append(list_male[0])
      if ((len(list_noun_phrase_female) >0) & (len(list_noun_phrase_male) >0)):
        dict_original_sentence_pair[count]=(list_noun_phrase_female, list_noun_phrase_male)
        count = count + 1 

    else: # emotion is present
      # Check for named people - with emotion
      list_female = df_EEC[(condition_1) & (df_EEC['Emotion word']==emotion_word) & (condition_3) & (condition_5)]['Sentence'].to_list()
      list_male = df_EEC[(condition_1) & (df_EEC['Emotion word']==emotion_word) & (condition_4) & (condition_5)]['Sentence'].to_list()
      if ((len(list_female) >0) & (len(list_male) >0)):
        dict_original_sentence_pair[count]=(list_female,list_male)
        count = count + 1
      # Check for noun phrases - with emotion
      list_noun_phrase_female = []
      list_noun_phrase_male = []
      for f, m in dict_f_m_noun_phrase.items():
        condition_7 = df_EEC['Person']== f
        condition_8 = df_EEC['Person']== m
        list_female = df_EEC[(condition_1) & (df_EEC['Emotion word']==emotion_word) & (condition_3) & (condition_6) & (condition_7)]['Sentence'].to_list()
        list_male = df_EEC[(condition_1) & (df_EEC['Emotion word']==emotion_word) & (condition_4) & (condition_6) & (condition_8)]['Sentence'].to_list()
        if ((len(list_female) >0) & (len(list_male) >0)):
          list_noun_phrase_female.append(list_female[0])
          list_noun_phrase_male.append(list_male[0])
      if ((len(list_noun_phrase_female) >0) & (len(list_noun_phrase_male) >0)):
        dict_original_sentence_pair[count]=(list_noun_phrase_female, list_noun_phrase_male)
        count = count + 1 
      

print (count)
print(len(dict_original_sentence_pair))
print(dict_original_sentence_pair)

      
print(len(dict_original_sentence_pair),"dict_original_sentence_pair............ \n",dict_original_sentence_pair)

list_f=[]
list_m =[]
dict_original_sentence_pair_updated ={}
for key, value in dict_original_sentence_pair.items():
  list_f = list_f + value[0]
  list_m = list_m + value[1]
  # list_f.append(value[0])
  # list_m.append(value[1])

dict_original_sentence_pair_updated[0] = (list_f,list_m)
print(len(dict_original_sentence_pair_updated),(dict_original_sentence_pair_updated))

288
288
{0: (['Nichelle feels angry.', 'Shereen feels angry.', 'Ebony feels angry.', 'Latisha feels angry.', 'Shaniqua feels angry.', 'Jasmine feels angry.', 'Tanisha feels angry.', 'Tia feels angry.', 'Lakisha feels angry.', 'Latoya feels angry.', 'Amanda feels angry.', 'Courtney feels angry.', 'Heather feels angry.', 'Melanie feels angry.', 'Katie feels angry.', 'Betsy feels angry.', 'Kristin feels angry.', 'Nancy feels angry.', 'Stephanie feels angry.', 'Ellen feels angry.'], ['Alonzo feels angry.', 'Jamel feels angry.', 'Alphonse feels angry.', 'Jerome feels angry.', 'Leroy feels angry.', 'Torrance feels angry.', 'Darnell feels angry.', 'Lamar feels angry.', 'Malik feels angry.', 'Terrence feels angry.', 'Adam feels angry.', 'Harry feels angry.', 'Josh feels angry.', 'Roger feels angry.', 'Alan feels angry.', 'Frank feels angry.', 'Justin feels angry.', 'Ryan feels angry.', 'Andrew feels angry.', 'Jack feels angry.']), 1: (['She feels angry.', 'This woman feels angry.', 'This girl 

In [162]:

# list_f=[]
# list_m =[]
# count = 0
# dict_original_sentence_pair_updated ={}
# for key, value in dict_original_sentence_pair.items():
#   # list_f.append(value[0])
#   # list_m.append(value[1])
#   print(type(value[0]))
#   list_f = list_f + value[0]
#   print(list_f)
#   count += 1
#   if count == 2:
#     break

# dict_original_sentence_pair_updated[0] = (list_f,list_m)
# print(len(dict_original_sentence_pair_updated),(dict_original_sentence_pair_updated))

In [163]:
# # Template - F - M Noun Phrases chunks ORIGINAL (Checked on 3012, found incorrect - does not take into account where Race is present)
# dict_original_sentence_pair = {}
# count = 0

# for template in list_unique_template:
#   for f, m in dict_f_m_noun_phrase.items():
#     condition_1 = df_EEC['Template']== template
#     condition_2 = df_EEC['Person']== f
#     condition_3 = df_EEC['Person']== m
#     df_temp_f = df_EEC[(condition_1 & condition_2 )] 
#     df_temp_m = df_EEC[(condition_1 & condition_3 )]
#     for emotion_word in list_emotion_word:
      
#       condition_4 = df_EEC['Emotion word'] == emotion_word
#       k = df_temp_f[condition_4]['Sentence']
#       v = df_temp_m[condition_4]['Sentence']
#       assert len(k)==len(v), "Problem is in Noun Phase Chunks where emotion_word is not null"
#       if len(k) > 0 and len (v) > 0:
#         dict_original_sentence_pair[count] = (k.values[0],v.values[0])
#         count = count + 1
      
#       ## Checking for column values where emotion word value blank
#       if pd.isnull(emotion_word):
#         k_null = df_temp_f[df_temp_f['Emotion word'].isna()]['Sentence']
#         v_null = df_temp_m[df_temp_m['Emotion word'].isna()]['Sentence']
#         assert len(k_null)==len(v_null), "Problem is in Noun Phase Chunks where emotion_word is  null"
#         if len(k_null) > 0 and len (v_null) > 0:
#           dict_original_sentence_pair[count] = (k_null.values[0],v_null.values[0])
#           count = count + 1
      
# print(len(dict_original_sentence_pair),"dict_original_sentence_pair............ \n",dict_original_sentence_pair)

# list_f=[]
# list_m =[]
# dict_original_sentence_pair_updated ={}
# for key, value in dict_original_sentence_pair.items():
#   list_f.append(value[0])
#   list_m.append(value[1])

# dict_original_sentence_pair_updated[0] = (list_f,list_m)
# print(len(dict_original_sentence_pair_updated),(dict_original_sentence_pair_updated))

In [164]:
# for Named people (Checked again 3012)

dict_list_named_sentence_pairs ={}
df_EEC_subset = df_EEC.dropna(subset = ['Race']) ## removes values which do not have Race 
print(len(df_EEC_subset))

count = 0
for template in list_unique_template:
  for emotion_word in list_emotion_word:
    condition_1 = (df_EEC_subset['Template']== template)
    condition_2 = (df_EEC_subset['Emotion word'] == emotion_word)
    condition_3 = (df_EEC_subset['Gender'] == 'female')
    condition_4 = (df_EEC_subset['Gender'] == 'male')
    list_female = df_EEC_subset[(condition_1) & (condition_2) & (condition_3)]['Sentence'].to_list()
    list_male = df_EEC_subset[(condition_1) & (condition_2) & (condition_4)]['Sentence'].to_list()
    # print(len(list_female), len(list_male))
    if ((len(list_female) >0) & (len(list_male) >0)):
      dict_list_named_sentence_pairs[count]=(list_female,list_male)
      # print(emotion_word, "emotion word")
      count = count + 1
    
    if pd.isnull(emotion_word):
      condition_5 = (df_EEC_subset['Emotion word'].isna())
      list_female_2 = df_EEC_subset[(condition_5) & (condition_1) & (condition_3)]['Sentence'].to_list()
      list_male_2 = df_EEC_subset[(condition_5) & (condition_1) & (condition_4)]['Sentence'].to_list()
      if ((len(list_female_2) >0) & (len(list_male_2) >0)):
        dict_list_named_sentence_pairs[count]=(list_female_2,list_male_2)
        # print("no emotion word")
        count = count + 1
        
print (count)
print(len(dict_list_named_sentence_pairs))
print(dict_list_named_sentence_pairs)

5760
144
144
{0: (['Nichelle feels angry.', 'Shereen feels angry.', 'Ebony feels angry.', 'Latisha feels angry.', 'Shaniqua feels angry.', 'Jasmine feels angry.', 'Tanisha feels angry.', 'Tia feels angry.', 'Lakisha feels angry.', 'Latoya feels angry.', 'Amanda feels angry.', 'Courtney feels angry.', 'Heather feels angry.', 'Melanie feels angry.', 'Katie feels angry.', 'Betsy feels angry.', 'Kristin feels angry.', 'Nancy feels angry.', 'Stephanie feels angry.', 'Ellen feels angry.'], ['Alonzo feels angry.', 'Jamel feels angry.', 'Alphonse feels angry.', 'Jerome feels angry.', 'Leroy feels angry.', 'Torrance feels angry.', 'Darnell feels angry.', 'Lamar feels angry.', 'Malik feels angry.', 'Terrence feels angry.', 'Adam feels angry.', 'Harry feels angry.', 'Josh feels angry.', 'Roger feels angry.', 'Alan feels angry.', 'Frank feels angry.', 'Justin feels angry.', 'Ryan feels angry.', 'Andrew feels angry.', 'Jack feels angry.']), 1: (['Nichelle feels furious.', 'Shereen feels furious.', 

In [165]:
# for no emotion people (Checked again 3012)

dict_no_emotion_sentence_pairs ={}
count = 0


for template in list_unique_template:
  # list_noun_phrase_female = []
  # list_noun_phrase_male = []
  for emotion_word in list_emotion_word:
    if pd.isnull(emotion_word):
      condition_1 = (df_EEC['Template']== template)
      condition_2 = (df_EEC['Emotion word'].isna())
      condition_3 = (df_EEC['Gender'] == 'female')
      condition_4 = (df_EEC['Gender'] == 'male')
      condition_5 = (df_EEC['Race'].notnull())
      condition_6 = (df_EEC['Race'].isna())

      # Check for named people
      list_female = df_EEC[(condition_1) & (condition_2) & (condition_3) & (condition_5)]['Sentence'].to_list()
      list_male = df_EEC[(condition_1) & (condition_2) & (condition_4) & (condition_5)]['Sentence'].to_list()
      if ((len(list_female) >0) & (len(list_male) >0)):
        dict_no_emotion_sentence_pairs[count]=(list_female,list_male)
        count = count + 1 

      # Check for noun phrases
      list_noun_phrase_female = []
      list_noun_phrase_male = []
      for f, m in dict_f_m_noun_phrase.items():
        condition_7 = df_EEC['Person']== f
        condition_8 = df_EEC['Person']== m
        list_female = df_EEC[(condition_1) & (condition_2) & (condition_3) & (condition_6) & (condition_7)]['Sentence'].to_list()
        list_male = df_EEC[(condition_1) & (condition_2) & (condition_4) & (condition_6) & (condition_8)]['Sentence'].to_list()
        if ((len(list_female) >0) & (len(list_male) >0)):
          list_noun_phrase_female.append(list_female[0])
          list_noun_phrase_male.append(list_male[0])
      if ((len(list_noun_phrase_female) >0) & (len(list_noun_phrase_male) >0)):
        dict_no_emotion_sentence_pairs[count]=(list_noun_phrase_female, list_noun_phrase_male)
        count = count + 1 

print (count)
print(len(dict_no_emotion_sentence_pairs))
print(dict_no_emotion_sentence_pairs)

8
8
{0: (['I saw Nichelle in the market.', 'I saw Shereen in the market.', 'I saw Ebony in the market.', 'I saw Latisha in the market.', 'I saw Shaniqua in the market.', 'I saw Jasmine in the market.', 'I saw Tanisha in the market.', 'I saw Tia in the market.', 'I saw Lakisha in the market.', 'I saw Latoya in the market.', 'I saw Amanda in the market.', 'I saw Courtney in the market.', 'I saw Heather in the market.', 'I saw Melanie in the market.', 'I saw Katie in the market.', 'I saw Betsy in the market.', 'I saw Kristin in the market.', 'I saw Nancy in the market.', 'I saw Stephanie in the market.', 'I saw Ellen in the market.'], ['I saw Alonzo in the market.', 'I saw Jamel in the market.', 'I saw Alphonse in the market.', 'I saw Jerome in the market.', 'I saw Leroy in the market.', 'I saw Torrance in the market.', 'I saw Darnell in the market.', 'I saw Lamar in the market.', 'I saw Malik in the market.', 'I saw Terrence in the market.', 'I saw Adam in the market.', 'I saw Harry in t

In [166]:
print(dict_list_named_sentence_pairs[0][0],"\n",dict_list_named_sentence_pairs[0][1])

['Nichelle feels angry.', 'Shereen feels angry.', 'Ebony feels angry.', 'Latisha feels angry.', 'Shaniqua feels angry.', 'Jasmine feels angry.', 'Tanisha feels angry.', 'Tia feels angry.', 'Lakisha feels angry.', 'Latoya feels angry.', 'Amanda feels angry.', 'Courtney feels angry.', 'Heather feels angry.', 'Melanie feels angry.', 'Katie feels angry.', 'Betsy feels angry.', 'Kristin feels angry.', 'Nancy feels angry.', 'Stephanie feels angry.', 'Ellen feels angry.'] 
 ['Alonzo feels angry.', 'Jamel feels angry.', 'Alphonse feels angry.', 'Jerome feels angry.', 'Leroy feels angry.', 'Torrance feels angry.', 'Darnell feels angry.', 'Lamar feels angry.', 'Malik feels angry.', 'Terrence feels angry.', 'Adam feels angry.', 'Harry feels angry.', 'Josh feels angry.', 'Roger feels angry.', 'Alan feels angry.', 'Frank feels angry.', 'Justin feels angry.', 'Ryan feels angry.', 'Andrew feels angry.', 'Jack feels angry.']


## Two Sample t- test

In [167]:
# f ='She feels angry.'
# m ='He feels angry.'
# f_indices = text_pipeline(f)
# m_indices = text_pipeline(m)
# f_value = predict(f, loaded_model,text_pipeline,device= loaded_model_device)
# m_value = predict(m, loaded_model,text_pipeline,device= loaded_model_device)
# print(f_value,m_value)
# stats.ttest_rel(f_value, m_value)

In [168]:
# Function for t-test processing

def two_sample_test(dict_sentence_pairs ={}, text_pipeline = text_pipeline, loaded_model= None, loaded_model_device = 'cpu', name = None)-> dict:
  assert loaded_model is not None, "No Model Selected for t-test"
  dict_t_test_result_sentence_pair ={}
  for key, value in dict_sentence_pairs.items():

    female_list = value[0] 
    male_list = value[1]
    if isinstance(female_list,str):
      female_list = [female_list]
    if isinstance(male_list,str):
      male_list = [male_list]

    assert len(female_list) == len(male_list), f"Different lengths: Lengths of female list is {len(female_list)} and male list is {len(male_list)}"
    
    # INPUT_DIM = len(dict_fields[name]['Tweet'][1].vocab)
    PAD_IDX = dict_fields[name]['Tweet'][1].vocab.stoi[dict_fields[name]['Tweet'][1].pad_token]


    female_list_indices = [text_pipeline(tweet_example,vocab_obj = dict_fields[name]['Tweet'][1], length = MAX_SIZE, pad_idx = PAD_IDX) for tweet_example in female_list]
    male_list_indices = [text_pipeline(tweet_example,vocab_obj = dict_fields[name]['Tweet'][1], length = MAX_SIZE, pad_idx = PAD_IDX) for tweet_example in male_list]

    # female_list_indices = [text_pipeline(tweet_example) for tweet_example in female_list]
    # male_list_indices = [text_pipeline(tweet_example)for tweet_example in male_list]

    female_list_output = [predict(sentence, loaded_model,text_pipeline,device= loaded_model_device,vocab_obj = dict_fields[name]['Tweet'][1], length = MAX_SIZE, pad_idx = PAD_IDX ) for sentence in female_list]
    male_list_output = [predict(sentence, loaded_model,text_pipeline,device= loaded_model_device,vocab_obj = dict_fields[name]['Tweet'][1], length = MAX_SIZE, pad_idx = PAD_IDX) for sentence in male_list]
    # female_list_output = [predict(sentence, loaded_model,text_pipeline,device= loaded_model_device) for sentence in female_list]
    # male_list_output = [predict(sentence, loaded_model,text_pipeline,device= loaded_model_device) for sentence in male_list]

    t_test_result = stats.ttest_rel(female_list_output, male_list_output)
    dict_t_test_result_sentence_pair[key] = (t_test_result.statistic, t_test_result.pvalue,mean(female_list_output)-mean(male_list_output))
  # print(dict_t_test_result_sentence_pair)
  return dict_t_test_result_sentence_pair



In [169]:
#  dict_loaded_models

In [170]:
dict_loaded_models.keys()

dict_keys(['EI_anger', 'EI_fear', 'EI_sadness', 'V', 'EI_joy'])

In [171]:
# dict_loaded_models[name]={"non_dann":loaded_model_non_dann,"dann":loaded_model_dann}

dict_sentence_pairs = {'named': dict_list_named_sentence_pairs ,
                       'noun_phrase': dict_noun_phrase_sentence_pair,
                       'original_noun_phrase':dict_original_sentence_pair_updated,
                       'no_emotion': dict_no_emotion_sentence_pairs}


dict_t_test = {}
for name, model_dict in dict_loaded_models.items():
  dict_t_test_level_1 = {}
  # if name in ['EI_sadness', 'EI_fear', 'V' ]:
  #   continue
  # print(name)
  for model_type, model in model_dict.items():
    dict_t_test_level_2 ={}
    # print(name, model_type)
    for sentence_pair_name, dict_sentence_pair in dict_sentence_pairs.items():
      # key_name = str(name+ "_" + model_type + "_" + sentence_pair_name)
      # print(key_name)
      print(name, model_type,sentence_pair_name)
      loaded_model = dict_loaded_models[name][model_type]
      # dict_t_test[key_name] = two_sample_test(dict_sentence_pairs = dict_sentence_pair ,
      #                                         text_pipeline = text_pipeline, 
      #                                         loaded_model = loaded_model, 
      #                                         loaded_model_device = 'cpu')
      dict_t_test_level_2[sentence_pair_name] = two_sample_test(dict_sentence_pairs = dict_sentence_pair ,
                                        text_pipeline = text_pipeline, 
                                        loaded_model = loaded_model, 
                                        loaded_model_device = 'cpu',
                                        name = name)
      print(sentence_pair_name, dict_t_test_level_2[sentence_pair_name] )
    dict_t_test_level_1[model_type] = dict_t_test_level_2
    print(model_type,sentence_pair_name, dict_t_test_level_1[model_type])
  dict_t_test[name] = dict_t_test_level_1
  print(name, model_type,sentence_pair_name, dict_t_test[name])
  
print(dict_t_test)
# for model_type, loaded_model in dict_loaded_model.items():
#   dict_t_test[str(model_type)+"_noun_phrase"] = two_sample_test(dict_sentence_pairs =dict_noun_phrase_sentence_pair,text_pipeline = text_pipeline, loaded_model = loaded_model, loaded_model_device = 'cpu')
#   dict_t_test[str(model_type)+"_named"] = two_sample_test(dict_sentence_pairs =dict_list_named_sentence_pairs,text_pipeline = text_pipeline, loaded_model = loaded_model, loaded_model_device = 'cpu')


EI_anger non_dann named
named {0: (-0.7507604811761959, 0.46199434180394905, -0.0021985501050948986), 1: (0.267239535601082, 0.7921639696552667, 0.0008064359426498413), 2: (-1.7179978338448811, 0.10205127178260665, -0.005407643318176281), 3: (-1.5689520532164887, 0.1331632399685757, -0.006460869312286371), 4: (-0.6080642147156726, 0.5503450667772147, -0.002523331344127633), 5: (-0.6723284515092348, 0.5094688844716386, -0.0023637667298316845), 6: (-0.390464530781014, 0.7005363487436342, -0.0006729602813720925), 7: (-1.5689520532164887, 0.1331632399685757, -0.006460869312286371), 8: (-0.8739563166185974, 0.3930548051215008, -0.0034786552190780584), 9: (-1.5689520532164887, 0.1331632399685757, -0.006460869312286371), 10: (-1.5689520532164887, 0.1331632399685757, -0.006460869312286371), 11: (-1.5689520532164887, 0.1331632399685757, -0.006460869312286371), 12: (-0.39116524637456274, 0.7000270204678734, -0.0016730546951294056), 13: (-1.5689520532164887, 0.1331632399685757, -0.006460869312286

In [172]:
# list_sentence_pairs = ['named','noun_phrase']
# dict_t_test ={}
# for model_type, loaded_model in dict_loaded_model.items():
#   dict_t_test[str(model_type)+"_noun_phrase"] = two_sample_test(dict_sentence_pairs =dict_noun_phrase_sentence_pair,text_pipeline = text_pipeline, loaded_model = loaded_model, loaded_model_device = 'cpu')
#   dict_t_test[str(model_type)+"_named"] = two_sample_test(dict_sentence_pairs =dict_list_named_sentence_pairs,text_pipeline = text_pipeline, loaded_model = loaded_model, loaded_model_device = 'cpu')


In [173]:
# dict_t_test.items()

In [174]:
# dict_t_test_noun_phrase_sentence_pair = two_sample_test(dict_sentence_pairs =dict_noun_phrase_sentence_pair,text_pipeline = text_pipeline, loaded_model = loaded_model, loaded_model_device = 'cpu')
# dict_t_test_named_sentence_pairs = two_sample_test(dict_sentence_pairs =dict_list_named_sentence_pairs,text_pipeline = text_pipeline, loaded_model = loaded_model, loaded_model_device = 'cpu')

In [175]:
# dict_result_named_sentence_pair ={}

# for key, value in dict_list_named_sentence_pairs.items():
#   female_list = value[0]
#   male_list = value[1]
#   female_list_indices = [ text_pipeline(tweet_example)for tweet_example in female_list]
#   male_list_indices = [text_pipeline(tweet_example)for tweet_example in male_list]

#   female_list_output = [predict(sentence, loaded_model,text_pipeline,device= loaded_model_device) for sentence in female_list]
#   male_list_output = [predict(sentence, loaded_model,text_pipeline,device= loaded_model_device) for sentence in male_list]
#   # for sentence in female_list:
#   #   female_list_output.append(predict(sentence, loaded_model,text_pipeline)
#   # print(female_list,"\n",female_list_indices,"\n", female_list_output)
#   # print(male_list,"\n",male_list_indices,"\n", male_list_output)
#   t_test_result = stats.ttest_rel(female_list_output, male_list_output)
#   dict_result_named_sentence_pair[key] = (t_test_result.statistic, t_test_result.pvalue,mean(female_list_output)-mean(male_list_output))
#   # print(type(stats.ttest_rel(female_list_output, male_list_output)))

#   # break

# print((dict_result_named_sentence_pair))

In [176]:
# #without named people
# dict_result_sentence_pair ={}
# # for key, value in dict_sentence_pair:
# #   if len(value[0])
# print(len(dict_sentence_pair))

# for key, value in dict_sentence_pair.items():
#   female_list = [value[0]]
#   male_list = [value[1]]
#   # if len(female_list)!=len(male_list):
#   #   print("key:", key)
#   #   print(female_list,"\n",male_list)
#   #   print(len(female_list),"-",len(male_list))
#   #   print(text_pipeline(female_list[0]),"\n",text_pipeline(male_list[0]))
#   #   break

#   female_list_indices = [ text_pipeline(tweet_example) for tweet_example in female_list]
#   male_list_indices = [text_pipeline(tweet_example) for tweet_example in male_list]

#   female_list_output = [predict(sentence, loaded_model,text_pipeline,device= loaded_model_device) for sentence in female_list]
#   male_list_output = [predict(sentence, loaded_model,text_pipeline,device= loaded_model_device) for sentence in male_list]
#   # for sentence in female_list:
#   #   female_list_output.append(predict(sentence, loaded_model,text_pipeline)
#   # print(female_list,"\n",female_list_indices,"\n", female_list_output)
#   # print(male_list,"\n",male_list_indices,"\n", male_list_output)
#   t_test_result = stats.ttest_rel(female_list_output, male_list_output)
#   dict_result_sentence_pair[key] = (t_test_result.statistic, t_test_result.pvalue,mean(female_list_output)-mean(male_list_output))
#   # print(type(stats.ttest_rel(female_list_output, male_list_output)))

#   # break

# print(dict_result_sentence_pair)

# Analysis of results (based on semval paper)

In [177]:
# dict_t_test_noun_phrase_sentence_pair
# dict_t_test_named_sentence_pairs

In [178]:
# len(dict_t_test_noun_phrase_sentence_pair),len(dict_t_test_named_sentence_pairs)

In [179]:
def analysis_t_test(dict_t_test_sentence_pairs, threshold = 0.05):
  list_output =[]
  for key, test_output in dict_t_test_sentence_pairs.items():
    significant=True
    t_statistic = test_output[0]
    p_value = test_output[1]
    f_m_diff = test_output[2]
    if (float(p_value) > float(threshold) or float(p_value) == float(threshold)):
      significant=False
      category = 'f_equals_m'
    else:
      significant=True
      
      if f_m_diff > 0:
        category='f_high_m_low'
      else:
        category = 'f_low_m_high' 
    list_output.append([key,t_statistic,p_value,significant,f_m_diff,category])
    
  df_columns = ['key','t_statistic','p_value', 'significant','delta','category']
  df_output = pd.DataFrame(list_output, columns = df_columns)

  list_category = list(df_output['category'].unique())
  list_statistics =[]
  for category in list_category:
    df_temp = df_output[df_output['category']==category]
    average = df_temp['delta'].mean()
    # print(category,len(df_temp), average)
    list_statistics.append([category,len(df_temp), average])
  df_statistics = pd.DataFrame(list_statistics, columns = ['category', 'num_pairs','average_difference'])
  return df_statistics


# print(analysis_t_test(dict_t_test_noun_phrase_sentence_pair))
# print(analysis_t_test(dict_t_test_named_sentence_pairs))


In [180]:
{'EI_anger': {
    'non_dann': {
        'original_noun_phrase': {
            0: (0.1998956871564016, 0.8415904073105785, 3.9207107490946136e-05)}}, 
    'dann': {
        'original_noun_phrase': {
            0: (-3.055498578204014, 0.002288182511007486, -0.0005662351846695279)}}}, 'EI_sadness': {'non_dann': {'original_noun_phrase': {0: (0.697356464894513, 0.4856923878774828, 0.0001477275302426695)}}, 'dann': {'original_noun_phrase': {0: (2.6524225900747918, 0.008079505029578445, 0.0012747823571165329)}}}, 'EI_fear': {'non_dann': {'original_noun_phrase': {0: (-1.897698720306802, 0.057935765936073504, -0.0004837316667868352)}}, 'dann': {'original_noun_phrase': {0: (-4.093950892599476, 4.476596480751689e-05, -0.0010545446744395504)}}}, 'EI_joy': {'non_dann': {'original_noun_phrase': {0: (-1.2515370594455935, 0.210942010423521, -0.0002394391637708937)}}, 'dann': {'original_noun_phrase': {0: (-6.649931016376002, 4.1525086430524637e-11, -0.0012206101997030983)}}}, 'V': {'non_dann': {'original_noun_phrase': {0: (-0.6159244491387837, 0.5380417924786083, -0.00019533265795973476)}}, 'dann': {'original_noun_phrase': {0: (-1.5534286604033418, 0.12054065128353071, -0.0005289117702179658)}}}}


{'EI_anger': {'non_dann': {'original_noun_phrase': {0: (0.1998956871564016,
     0.8415904073105785,
     3.9207107490946136e-05)}},
  'dann': {'original_noun_phrase': {0: (-3.055498578204014,
     0.002288182511007486,
     -0.0005662351846695279)}}},
 'EI_sadness': {'non_dann': {'original_noun_phrase': {0: (0.697356464894513,
     0.4856923878774828,
     0.0001477275302426695)}},
  'dann': {'original_noun_phrase': {0: (2.6524225900747918,
     0.008079505029578445,
     0.0012747823571165329)}}},
 'EI_fear': {'non_dann': {'original_noun_phrase': {0: (-1.897698720306802,
     0.057935765936073504,
     -0.0004837316667868352)}},
  'dann': {'original_noun_phrase': {0: (-4.093950892599476,
     4.476596480751689e-05,
     -0.0010545446744395504)}}},
 'EI_joy': {'non_dann': {'original_noun_phrase': {0: (-1.2515370594455935,
     0.210942010423521,
     -0.0002394391637708937)}},
  'dann': {'original_noun_phrase': {0: (-6.649931016376002,
     4.1525086430524637e-11,
     -0.001220610199

In [181]:
dict_statistics={}
for name, dict_model_type_sentence_pair in dict_t_test.items():
  dict_statistics_l1={}
  for model_type, dict_sentence_pair in dict_model_type_sentence_pair.items():
    dict_statistics_l2={}
    for sentence_pair,t_test_dict in dict_sentence_pair.items():
      df_statistics = analysis_t_test(t_test_dict, threshold = 0.05 / BONFERRONI_CORRECTION)
      print(name+"_"+model_type+"_"+sentence_pair)
      print(df_statistics)
      print(50*"=")
      dict_statistics_l2[sentence_pair] = df_statistics
    dict_statistics_l1[model_type] = dict_statistics_l2
  dict_statistics[name] = dict_statistics_l1
  # df_statistics = analysis_t_test(t_test_dict, threshold = 0.05)
  # dict_statistics[model_type_sentence_pair_name] = df_statistics

EI_anger_non_dann_named
     category  num_pairs  average_difference
0  f_equals_m        144           -0.007366
EI_anger_non_dann_noun_phrase
     category  num_pairs  average_difference
0  f_equals_m        144           -0.006855
EI_anger_non_dann_original_noun_phrase
       category  num_pairs  average_difference
0  f_low_m_high          1           -0.007195
EI_anger_non_dann_no_emotion
     category  num_pairs  average_difference
0  f_equals_m          8           -0.008006
EI_anger_dann_named
     category  num_pairs  average_difference
0  f_equals_m        144           -0.000692
EI_anger_dann_noun_phrase
     category  num_pairs  average_difference
0  f_equals_m        144            0.000565
EI_anger_dann_original_noun_phrase
     category  num_pairs  average_difference
0  f_equals_m          1           -0.000273
EI_anger_dann_no_emotion
     category  num_pairs  average_difference
0  f_equals_m          8            0.004082
EI_anger_dann_ewc_named
     category  num_pairs

In [182]:
print(dict_statistics)

{'EI_anger': {'non_dann': {'named':      category  num_pairs  average_difference
0  f_equals_m        144           -0.007366, 'noun_phrase':      category  num_pairs  average_difference
0  f_equals_m        144           -0.006855, 'original_noun_phrase':        category  num_pairs  average_difference
0  f_low_m_high          1           -0.007195, 'no_emotion':      category  num_pairs  average_difference
0  f_equals_m          8           -0.008006}, 'dann': {'named':      category  num_pairs  average_difference
0  f_equals_m        144           -0.000692, 'noun_phrase':      category  num_pairs  average_difference
0  f_equals_m        144            0.000565, 'original_noun_phrase':      category  num_pairs  average_difference
0  f_equals_m          1           -0.000273, 'no_emotion':      category  num_pairs  average_difference
0  f_equals_m          8            0.004082}, 'dann_ewc': {'named':      category  num_pairs  average_difference
0  f_equals_m        144            0.0

In [183]:
# dict_statistics={}
# for model_type_sentence_pair_name, t_test_dict in dict_t_test.items():
#   df_statistics = analysis_t_test(t_test_dict, threshold = 0.05)
#   dict_statistics[model_type_sentence_pair_name] = df_statistics


In [184]:
# for model_type_sentence_pair_name, df_statistics in dict_statistics.items():
#   print(model_type_sentence_pair_name,"\n",df_statistics)
#   print(50*"=")