In [1]:
import pandas as pd
import numpy as np
import torch
from config import basic_conf as conf
from libs import ModelManager as mm
from config.constants import HyperParamKey, LoadingKey

import logging

import matplotlib.pyplot as plt
from importlib import reload

%matplotlib inline

In [7]:
# Initialise logging with logging.WARNING, so INFO-level messages are not expected in the cell outputs
# (assumes the first argument of init_logger is the log level — TODO confirm against config.basic_conf).
conf.init_logger(logging.WARNING, logfile=None)
logger = logging.getLogger('__main__')
# Manager object used by the cells below to load data and to create, restore and evaluate models.
mgr = mm.ModelManager(mode='notebook')

In [9]:
# Cross-genre evaluation.
# For every genre in conf.GENRE_LIST: load that genre's data, build an NLIRNN model
# labelled with the genre, restore its best checkpoint, then evaluate it on the 'val'
# loader of every *other* genre.
# Result layout: acc_dict[model_genre][eval_genre] = first element returned by eval_model.

# Same hyper-parameters for every genre; re-applied to the manager on each iteration.
hparams = {
    HyperParamKey.LR: 0.01,
    HyperParamKey.SCHEDULER_GAMMA: 0.09,
    HyperParamKey.RNN_HIDDEN_SIZE: 100,
    HyperParamKey.DROPOUT_FC: 0.25,
    HyperParamKey.DROPOUT_RNN: 0.25,
}

acc_dict = {}
for cur_genre in conf.GENRE_LIST:
    # The genre's own data is loaded before the model is created/restored.
    mgr.load_data(mm.loaderRegister.MNLI, genre=cur_genre)
    mgr.hparams.update(hparams)
    mgr.new_model(mm.modelRegister.NLIRNN, label=cur_genre)
    mgr.load_model(which_model=LoadingKey.LOAD_BEST)

    # In-genre score is only printed, not stored in acc_dict.
    acc = mgr.model.eval_model(mgr.dataloader.loaders['val'])[0]
    print("intra genre acc for %s = %s" % (cur_genre, acc))

    for cross_genre in conf.GENRE_LIST:
        if cross_genre == cur_genre:
            continue  # in-genre case already reported above

        print("loading data for %s ... " % cross_genre)
        # Swap the manager's data to the other genre while keeping the restored model.
        mgr.load_data(mm.loaderRegister.MNLI, genre=cross_genre)
        acc = mgr.model.eval_model(mgr.dataloader.loaders['val'])[0]

        # Create the inner dict on first use, then record this pair's score.
        acc_dict.setdefault(cur_genre, {})[cross_genre] = acc
acc_dict

intra genre acc for fiction = 58.81168177240685
loading data for telephone ... 
loading data for slate ... 
loading data for government ... 
loading data for travel ... 
intra genre acc for telephone = 57.91044776119403
loading data for fiction ... 
loading data for slate ... 
loading data for government ... 
loading data for travel ... 
intra genre acc for slate = 49.70059880239521
loading data for fiction ... 
loading data for telephone ... 
loading data for government ... 
loading data for travel ... 
intra genre acc for government = 61.12204724409449
loading data for fiction ... 
loading data for telephone ... 
loading data for slate ... 
loading data for travel ... 
intra genre acc for travel = 55.90631364562118
loading data for fiction ... 
loading data for telephone ... 
loading data for slate ... 
loading data for government ... 


{'fiction': {'telephone': 53.233830845771145,
  'slate': 47.20558882235529,
  'government': 52.16535433070866,
  'travel': 50.10183299389002},
 'telephone': {'fiction': 53.57502517623364,
  'slate': 49.40119760479042,
  'government': 54.330708661417326,
  'travel': 52.13849287169043},
 'slate': {'fiction': 53.071500503524675,
  'telephone': 51.44278606965174,
  'government': 54.03543307086614,
  'travel': 51.425661914460285},
 'government': {'fiction': 55.085599194360526,
  'telephone': 52.33830845771144,
  'slate': 49.800399201596804,
  'travel': 51.62932790224033},
 'travel': {'fiction': 49.34541792547835,
  'telephone': 51.343283582089555,
  'slate': 47.20558882235529,
  'government': 51.968503937007874}}

In [13]:
# Print, for each model genre, the mean of its cross-genre accuracies stored in acc_dict.
for cur_genre, cross_accs in acc_dict.items():
    my_acc = list(cross_accs.values())
    print("%s: %s" % (cur_genre, np.mean(my_acc)))

fiction: 50.67665174818128
telephone: 52.36135607853296
slate: 52.493845389625704
government: 52.213408688977275
travel: 49.96569856673277
