In [1]:
import argparse
import numpy as np
import torch
from configs import DEFAULT_MODEL_CFG, EMOTION_CATES
from model import ELMModel
from indexer import Indexer
import json

In [2]:
# Shared hyper-parameters and the vocabulary indexer reused by every model loaded below.
cfg = DEFAULT_MODEL_CFG
indexer = Indexer(DEFAULT_MODEL_CFG.n_ctx)

In [3]:
def _top_entries(bias, col, k):
    """Return the ``k`` highest-scoring rows of column ``col`` of ``bias`` as (text, score) pairs.

    Row indices are decoded to text with the module-level ``indexer``. Scores are Python
    floats in descending order (``torch.topk`` sorts by default). ``k`` is clamped to the
    number of rows so that ``topk`` does not raise for a small matrix.
    """
    k = min(k, bias.size(0))
    values, indices = bias[:, col].topk(k)
    return [(indexer.decode_index2text(idx), v)
            for v, idx in zip(values.tolist(), indices.tolist())]


def find_tops(beta, model_path, k=200):
    """Load an ELM checkpoint and summarise its emotion/vocabulary bias matrices.

    For the speaker (``VS``/``ES``) and listener (``VL``/``EL``) weight pairs, the bias
    matrix ``V @ E.T`` is formed. Column ``i`` is read as the scores for
    ``EMOTION_CATES[i]``, and row indices are decoded to text through ``indexer``.

    Args:
        beta: passed through unchanged to the ``ELMModel`` constructor.
        model_path: path of a state dict loadable with ``torch.load``.
        k: number of top entries kept per emotion (clamped to the matrix row count).

    Returns:
        dict mapping each emotion name to a dict with keys ``bias_mean_S``,
        ``bias_std_S``, ``top_S``, ``bias_mean_L``, ``bias_std_L``, ``top_L``.
        The mean/std values are taken over the whole matrix, so they are the same for
        every emotion. Each ``top_*`` entry is a list of ``(text, score)`` pairs, highest first.
    """
    model = ELMModel(cfg, indexer.n_vocab, indexer.n_special, indexer.n_ctx, indexer, beta)
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # On torch >= 1.13 consider passing weights_only=True.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

    # Bias matrices: rows are indexed like the vocabulary weights, columns like the emotion weights.
    bias_S = model.VS.weight.detach() @ model.ES.weight.detach().T
    bias_L = model.VL.weight.detach() @ model.EL.weight.detach().T

    # Whole-matrix statistics do not depend on the emotion, so compute them once
    # rather than once per category.
    mean_S, std_S = bias_S.mean().item(), bias_S.std().item()
    mean_L, std_L = bias_L.mean().item(), bias_L.std().item()

    res = {}
    for i, emo in enumerate(EMOTION_CATES):
        res[emo] = {
            'bias_mean_S': mean_S,
            'bias_std_S': std_S,
            'top_S': _top_entries(bias_S, i, k),
            'bias_mean_L': mean_L,
            'bias_std_L': std_L,
            'top_L': _top_entries(bias_L, i, k),
        }
    return res

In [4]:
# beta = 1.0 model: collect its bias statistics and persist them as JSON.
res_1 = find_tops(1.0, 'save/elm_1')
with open('save/stat_elm_1.json', 'w') as stat_file:
    stat_file.write(json.dumps(res_1))

In [5]:
# beta = 2.0 model: collect its bias statistics and persist them as JSON.
res_2 = find_tops(2.0, 'save/elm_2')
with open('save/stat_elm_2.json', 'w') as stat_file:
    stat_file.write(json.dumps(res_2))

In [6]:
# beta = 3.0 model: collect its bias statistics and persist them as JSON.
res_3 = find_tops(3.0, 'save/elm_3')
with open('save/stat_elm_3.json', 'w') as stat_file:
    stat_file.write(json.dumps(res_3))

In [7]:
def _emotion_average(res, pick):
    """Average ``pick(res[e])`` over every emotion category in ``EMOTION_CATES``.

    Divides by ``len(EMOTION_CATES)`` instead of a hard-coded count, so the result stays
    correct if the category list changes. Returns NaN when there are no categories.
    """
    n_emotions = len(EMOTION_CATES)
    if n_emotions == 0:
        return float('nan')
    return sum(pick(res[e]) for e in EMOTION_CATES) / n_emotions


# One value per model; suffixes 1/2/3 correspond to res_1/res_2/res_3.
# 'bias_mean_*' is the mean stored per emotion. 'top_*'[0][1] is the score of the first
# entry of the top-k list (sorted descending by topk), i.e. the emotion's highest bias.
_results = (res_1, res_2, res_3)
mean_1_S, mean_2_S, mean_3_S = (_emotion_average(r, lambda d: d['bias_mean_S']) for r in _results)
mean_1_L, mean_2_L, mean_3_L = (_emotion_average(r, lambda d: d['bias_mean_L']) for r in _results)
max_1_S, max_2_S, max_3_S = (_emotion_average(r, lambda d: d['top_S'][0][1]) for r in _results)
max_1_L, max_2_L, max_3_L = (_emotion_average(r, lambda d: d['top_L'][0][1]) for r in _results)

In [8]:
# One line per statistic (speaker mean, listener mean, speaker max, listener max);
# the three columns are the beta = 1.0, 2.0 and 3.0 models.
for row in (
    (mean_1_S, mean_2_S, mean_3_S),
    (mean_1_L, mean_2_L, mean_3_L),
    (max_1_S, max_2_S, max_3_S),
    (max_1_L, max_2_L, max_3_L),
):
    print(*row)

-0.010458228178322315 -0.014138422906398773 -0.014582054689526558
-0.010079051367938519 -0.014124294742941856 -0.014553489163517952
0.06316951010376215 0.059953586431220174 0.058305080397985876
0.05899635446257889 0.05655928107444197 0.05512049177195877
