In [1]:
import argparse
import numpy as np
import torch
from configs import DEFAULT_MODEL_CFG, EMOTION_CATES
from model import ELMModel
from indexer import Indexer
import json

In [2]:
# `cfg` is an alias of the imported DEFAULT_MODEL_CFG object (not a copy), so the
# `cfg.n_emo_embd = ...` assignment inside find_tops() also changes DEFAULT_MODEL_CFG.
cfg = DEFAULT_MODEL_CFG
# Indexer is constructed with the configured context length; find_tops() reads
# n_vocab / n_special / n_ctx from it and uses it to decode row indices to text.
indexer = Indexer(cfg.n_ctx)

In [6]:
# Display the emotion labels. Order matters: find_tops() pairs column i of each
# bias matrix with EMOTION_CATES[i].
EMOTION_CATES

['surprised',
 'excited',
 'angry',
 'proud',
 'sad',
 'annoyed',
 'grateful',
 'lonely',
 'afraid',
 'terrified',
 'guilty',
 'impressed',
 'disgusted',
 'hopeful',
 'confident',
 'furious',
 'anxious',
 'anticipating',
 'joyful',
 'nostalgic',
 'disappointed',
 'prepared',
 'jealous',
 'content',
 'devastated',
 'embarrassed',
 'caring',
 'sentimental',
 'trusting',
 'ashamed',
 'apprehensive',
 'faithful']

In [11]:
def find_tops(beta, init_std, n_emo_embd, model_path, k=200):
    """Rank row indices of the speaker/listener bias matrices for every emotion.

    Builds an ``ELMModel`` from the notebook-level ``cfg`` and ``indexer``,
    loads the state dict stored at ``model_path`` onto the CPU, and forms the
    two bias matrices ``VS @ ES.T`` (speaker) and ``VL @ EL.T`` (listener).
    Column ``i`` of each matrix is paired with ``EMOTION_CATES[i]``; its
    highest-valued rows are decoded to text with
    ``indexer.decode_index2text``.

    Side effect: assigns ``cfg.n_emo_embd`` on the shared notebook-level
    config object, so the value persists after the call returns.

    Args:
        beta: Forwarded unchanged to the ``ELMModel`` constructor.
        init_std: Forwarded unchanged to the ``ELMModel`` constructor.
        n_emo_embd: Written to ``cfg.n_emo_embd`` before the model is built.
        model_path: Path of the saved state dict handed to ``torch.load``.
        k: Number of top rows reported per emotion. Capped at the number of
            rows of the bias matrix so an oversized ``k`` does not make
            ``topk`` raise.

    Returns:
        dict: ``{emotion: stats}`` where ``stats`` has the keys
        ``bias_mean_S``, ``bias_std_S``, ``top_S``, ``bias_mean_L``,
        ``bias_std_L`` and ``top_L`` (in that order). The mean/std entries
        are taken over the whole bias matrix and are therefore identical for
        every emotion. ``top_S`` / ``top_L`` are lists of
        ``(decoded_text, bias_value)`` pairs in descending bias order.
    """
    # load model
    cfg.n_emo_embd = n_emo_embd
    model = ELMModel(cfg, indexer.n_vocab, indexer.n_special, indexer.n_ctx, indexer, beta, init_std)
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

    # calculate bias
    bias_S = torch.matmul(model.VS.weight.data, model.ES.weight.data.T)
    bias_L = torch.matmul(model.VL.weight.data, model.EL.weight.data.T)

    # Whole-matrix statistics do not depend on the emotion, so compute them
    # once instead of once per loop iteration.
    mean_S, std_S = bias_S.mean().item(), bias_S.std().item()
    mean_L, std_L = bias_L.mean().item(), bias_L.std().item()

    def top_entries(bias, column):
        """Return (decoded text, value) pairs for the largest rows of one column."""
        n_top = min(k, bias.size(0))  # topk raises if asked for more rows than exist
        values, indices = bias[:, column].topk(n_top)
        return [(indexer.decode_index2text(idx), v)
                for v, idx in zip(values.tolist(), indices.tolist())]

    # stat
    res = {}
    for i, emo in enumerate(EMOTION_CATES):
        res[emo] = {
            # speaker
            'bias_mean_S': mean_S,
            'bias_std_S': std_S,
            'top_S': top_entries(bias_S, i),
            # listener
            'bias_mean_L': mean_L,
            'bias_std_L': std_L,
            'top_L': top_entries(bias_L, i),
        }

    return res

In [12]:
# One result per saved checkpoint: beta is fixed at 2.0, while init_std
# (0.02 / 0.1) and the emotion-embedding width (768 / 300) vary.
r_b2_std002_h768 = find_tops(
    beta=2.0, init_std=0.02, n_emo_embd=768,
    model_path='save/emo2/elm_b2_std002_h768',
)
r_b2_std01_h768 = find_tops(
    beta=2.0, init_std=0.1, n_emo_embd=768,
    model_path='save/emo2/elm_b2_std01_h768',
)
r_b2_std002_h300 = find_tops(
    beta=2.0, init_std=0.02, n_emo_embd=300,
    model_path='save/emo2/elm_b2_std002_h300',
)
r_b2_std01_h300 = find_tops(
    beta=2.0, init_std=0.1, n_emo_embd=300,
    model_path='save/emo2/elm_b2_std01_h300',
)

In [None]:
# Persist two of the result dicts as JSON (the (text, value) tuples are
# serialized as two-element lists).
for out_path, result in (
    ('save/elm_b2_std002_h768.json', r_b2_std002_h768),
    ('save/elm_b2_std01_h300.json', r_b2_std01_h300),
):
    with open(out_path, 'w') as outfile:
        json.dump(result, outfile)

In [16]:
# Inspect the bias statistics and top-ranked entries for the 'confident' emotion
# in the beta=2.0, init_std=0.1, n_emo_embd=300 run.
r_b2_std01_h300['confident']

{'bias_mean_S': -0.03542833402752876,
 'bias_std_S': 0.15137436985969543,
 'top_S': [('ē', 0.5552480816841125),
  ('differently', 0.552210807800293),
  ('mollie', 0.5358927249908447),
  ('horizon', 0.5329807996749878),
  ('recover', 0.5270529985427856),
  ('ν', 0.5209773182868958),
  ('difference', 0.5086113810539246),
  ('eviden', 0.5032172203063965),
  ('tooth', 0.5028625130653381),
  ('dle', 0.4818674325942993),
  ('nesses', 0.4812135696411133),
  ('headmistress', 0.4790532886981964),
  ('inten', 0.4757104218006134),
  ('80s', 0.46945029497146606),
  ('era', 0.465280681848526),
  ('tire', 0.4650290012359619),
  ('tongs', 0.46486276388168335),
  ('eved', 0.45877280831336975),
  ('ress', 0.45858192443847656),
  ('immac', 0.4574316740036011),
  ('ensured', 0.4555746912956238),
  ('recourse', 0.45252951979637146),
  ('ronni', 0.4513828754425049),
  ('mole', 0.4505506455898285),
  ('broken', 0.4498981535434723),
  ('pull', 0.4497259557247162),
  ('bravest', 0.4494256377220154),
  ('succe