# LM

In [2]:
args_list = [
        "--data", "/home/baihe/datasets/LM_data/wikitext-103/",
        "--dataset", "wt103",
        "--split", "valid",
        "--batch_size", "64",
        "--tgt_len", "150",
        "--cuda", 
        "--work_dir", "/home/baihe/projects/Dynasparse-transformer/wiki103/0710/base_trans",
      ]

In [3]:
# coding: utf-8
import argparse
import time
import math
import os, sys

import torch

from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.exp_utils import get_logger

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='wt103',
                    choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                    help='dataset name')
parser.add_argument('--split', type=str, default='all',
                    choices=['all', 'valid', 'test'],
                    help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')
parser.add_argument('--tgt_len', type=int, default=5,
                    help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous heads')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
                    help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
                    help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
                    help='set same length attention with masking')
parser.add_argument('--sega', action='store_true',
                    help='sega or not')
parser.add_argument('--sparse_mode', type=str, default='none',
                    help='spare mode for longformer')
args = parser.parse_args(args_list)
assert args.ext_len >= 0, 'extended context length must be non-negative'
args.sent_eos=False
if 'eos' in args.sparse_mode:
    args.sent_eos=True
args.compressed_mem = False
if 'compress' in args.sparse_mode:
    args.compressed_mem=True
device = torch.device("cuda" if args.cuda else "cpu")

args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
# Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
                     log_=not args.no_log)

# Load dataset
corpus = get_lm_corpus(args.data, args.dataset, sega=args.sega, sent_eos=args.sent_eos)
ntokens = len(corpus.vocab)

va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
    device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
    device=device, ext_len=args.ext_len)

# Load the best saved model.
with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
    model = torch.load(f)
model.backward_compatible()
model = model.to(device)

logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
       args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
    model.clamp_len = args.clamp_len
if args.same_length:
    model.same_length = True

###############################################################################
# Evaluation code
###############################################################################

def get_all_props(model, data, target, *mems):
    if not mems: mems = model.init_mems()

    tgt_len = target.size(0)
    hidden, new_mems = model._forward(data, mems=mems)

    pred_hid = hidden[-tgt_len:]
    probs = model.crit.get_all_props(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))

    if new_mems is None:
        return [probs]
    else:
        return [probs] + new_mems

def get_nll(model, data, target, *mems):
    if not mems: mems = model.init_mems()

    tgt_len = target.size(0)
    hidden, new_mems = model._forward(data, mems=mems)

    pred_hid = hidden[-tgt_len:]
    nll = model.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
    nll = nll.view(tgt_len, -1)
    if new_mems is None:
        return [nll]
    else:
        return [nll] + new_mems
def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    all_words, all_probs, all_targets = [],[],[]
    start_time = time.time()
    with torch.no_grad():
        mems = tuple()
        for idx, (data, target, seq_len) in enumerate(eval_iter):
            ret = get_nll(model,data, target, *mems)
            loss = ret[0]#, ret[1]
            # all_probs.append(loss)
            # all_targets.append(target)
            all_probs.extend(loss.tolist())
            all_targets.extend(target.view(-1).tolist())
            mems = ret[1:]
        total_time = time.time() - start_time
    logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
            total_time, 1000 * total_time / (idx+1)))
    return all_probs, all_targets#all_words, 

all_probs, all_targets = evaluate(va_iter)
print('done')

# all_probs = torch.cat(all_probs).cpu().numpy()
# all_targets = torch.cat(all_targets).cpu().numpy()


Loading cached dataset...
Evaluating with bsz 64 tgt_len 150 ext_len 0 mem_len 0 clamp_len -1
	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)
  indices_i = mask_i.nonzero().squeeze()
Time : 3.43s, 149.31ms/segment
done


In [4]:
import numpy as np

In [55]:
nll_threshold = np.sort(all_probs)[int(len(all_targets)*0.5)]

In [56]:
nll_over_threshold = np.array(all_targets)[np.array(all_probs)>nll_threshold]

In [57]:
len(nll_over_threshold)

108767

In [60]:
np.sort(nll_over_threshold)[90000:]

array([  4185,   4187,   4187, ..., 263386, 263697, 264664])

In [48]:
corpus.vocab.get_symbols(nll_over_threshold)

['Security',
 'features',
 'and',
 'to',
 'adheres',
 'Biggs',
 'both',
 '<eos>',
 'also',
 ',',
 'was',
 'southeast',
 'to',
 'the',
 'and',
 'a',
 'January',
 'the',
 'in',
 'of',
 'the',
 'gammarus',
 're',
 'Angel',
 'Benchmark',
 'think',
 'is',
 'Myanmar',
 'has',
 'the',
 '"',
 'the',
 ',',
 ',',
 '=',
 'post',
 'to',
 ',',
 'near',
 'the',
 'used',
 'For',
 'ruins',
 'layout',
 'and',
 '9',
 "'t",
 'Beckham',
 '.',
 'before',
 'rope',
 'following',
 'backstage',
 'of',
 '.',
 'character',
 ',',
 'create',
 '8',
 'Columbus',
 'Winston',
 'others',
 'been',
 'that',
 'September',
 'gamers',
 'the',
 '"',
 'a',
 'New',
 'Jack',
 'the',
 '.',
 'Kyle',
 '.',
 'former',
 'there',
 '"',
 'Burmese',
 '<eos>',
 'proven',
 '7',
 'the',
 'In',
 'Battalion',
 'chases',
 '.',
 'was',
 'the',
 'The',
 'interaction',
 'IGN',
 '.',
 'Early',
 'the',
 'Subsequent',
 '1880s',
 'the',
 'white',
 'integrating',
 'off',
 ',',
 'navigation',
 'Opera',
 'CW',
 'Early',
 'life',
 'team',
 '26',
 ',',


In [None]:
import glob, lxml, re
from lxml import etree

# clean the input
def clean(l):
    l = l.replace('<size=-1>','')
    l = l.replace('</size>','')
    l = l.replace('<br>','')
    l = l.replace('&','&amp;')
    l = l.replace('"<"','&lt;')
    l = l.replace('">"','&gt;')
    return l.rstrip(',;\n') + '\n'

# get information from the xml
def headword(class_element):
    return re.sub('[0-9#\[\] ]','',class_element.find("headword").find("b").text)

def pos(pos_element):
    return re.sub('[.#]','',pos_element.find("b").text)

def words(paragraph_element):
    return set([word.strip() for i in paragraph_element 
           if not i.text is None 
           for word in i.text.split(',') 
           if not word == ' '
           ])

def index(fn,root):
    return re.sub('[/heads.txt]','',fn + ' ') + headword(root)

# helper generator:
def pospargen(c):
    for a,b in [c[x:x+2] for x in range(len(c)-1)]:
        if a.tag == 'pos' and b.tag == 'paragraph':
            yield [a,b]

# get list of [POS, [words,in,entry]]
def pos_words(c):
    return dict([[pos(a),words(b)] for a,b in pospargen(c)])

In [None]:
roget = {}

for fn in glob.glob("./roget/heads/head*.txt"):
    with open(fn,'r',encoding="windows-1252") as f:
        xml = ['<class>']+[clean(l) for l in f.readlines()]+['</class>']
        root = etree.fromstring(''.join(xml), parser=etree.XMLParser(encoding="windows-1252"))
        roget[index(fn,root)] = pos_words(root.getchildren())

parts_of_speech = ['INT', 'VB', 'ADJ', 'N']

from collections import defaultdict

reverse_roget = defaultdict(set)
for category in roget:
    for pos in parts_of_speech:
        if pos in roget[category]:
            for word in roget[category][pos]:
                reverse_roget[word + '_' + pos].add(category)

In [3]:
import numpy as np
import pandas as pd
from copy import deepcopy

In [4]:
nlls = np.array(all_probs)

In [5]:
targets = np.array(all_targets)

In [9]:
encoded_roget_dict = {}
for k,v in roget.items():
    encoded_pos_dict = {}
    for k_pos, word_set in v.items():
        encoded_set = set()
        for word in word_set:
            if word in corpus.vocab.sym2idx:
                encoded_set.add(corpus.vocab.sym2idx[word])
        encoded_pos_dict[k_pos] = encoded_set
    encoded_roget_dict[k] = encoded_pos_dict

In [10]:
encoded_roget_dict['rog871 Regret']

{'N': {17956, 25458, 69389, 81721, 140717},
 'VB': {10901, 44930, 57544, 97448, 180515},
 'ADJ': {16318, 51645, 58464, 61463, 77647, 83160},
 'ADV': {21478, 113881, 141177},
 'INT': set()}

In [11]:
def get_ppl_freq(encoded_roget_dict, pos_tag="N",df=None):
    if df is not None:
        df[pos+'_ppl']=np.nan
        df[pos+'_freq']=np.nan
        df[pos+'_wordlist'] = ''
        df[pos+'_wordfreq'] = ''
        df[pos+'_avgfreq'] = np.nan
    all_noun = []
    noun_keys = []
    for k,v in encoded_roget_dict.items():
        if pos_tag in v:
            noun_keys.append(k)
            all_noun.append(v[pos_tag])
    encoded_roget_pos_dict = {}
    start_index = 0
    for k in noun_keys:
        current_word_set = encoded_roget_dict[k][pos]
        end_index = start_index+ len(current_word_set)
        remain_word_set = all_noun[:start_index]+all_noun[end_index:]
        noun_set = []
        for word in current_word_set:
            if word not in remain_word_set:
                noun_set.append(word)
        start_index = end_index
        encoded_roget_pos_dict[k] = noun_set
    class_ppl = {}
    class_freq = {}
    
    for k,v in encoded_roget_pos_dict.items():
        word_list = []
        wordfreq_list = []
        freq = []
        ppl = []
        for word_id in v:
            if word_id<=10000:
                continue
            word_list.append(corpus.vocab.idx2sym[word_id])
            wordfreq_list.append(str(word_id))
            indices = np.where(targets==word_id)[0]
            for index in indices:
                freq.append(word_id)
                ppl_i = nlls[index]
                ppl.append(ppl_i)
        if df is not None and len(ppl)>0:
            df[pos+'_wordlist'][df['class']==k] = ' '.join(word_list)
            df[pos+'_wordfreq'][df['class']==k] = ' '.join(wordfreq_list)
            df[pos+'_avgfreq'][df['class']==k] = np.mean(freq)
            df[pos+'_ppl'][df['class']==k] = np.mean(ppl)
            df[pos+'_freq'][df['class']==k] = len(ppl)
        class_ppl[k] = np.mean(ppl)
        class_freq[k] = len(ppl)
    return class_ppl, class_freq

In [12]:
class_df = pd.DataFrame({'class':list(encoded_roget_dict.keys())})
class_df['name'] = [' '.join(k.split()[1:]) for k in class_df['class']]
for pos in ['N','VB','ADJ']:
    get_ppl_freq(encoded_roget_dict, pos,class_df)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#return

In [307]:
class_df.describe(
)

Unnamed: 0,N_ppl,N_freq,N_avgfreq,VB_ppl,VB_freq,VB_avgfreq,ADJ_ppl,ADJ_freq,ADJ_avgfreq
count,770.0,770.0,770.0,493.0,493.0,493.0,604.0,604.0,604.0
mean,9.677942,5.892208,24747.458609,9.481188,3.959432,25147.212721,9.566391,4.183775,24905.110897
std,2.713695,6.758519,16042.489442,3.173086,3.677328,18292.187339,3.002031,3.907428,19113.299971
min,1.582873,1.0,10039.0,0.056076,1.0,10018.0,0.433158,1.0,10006.0
25%,8.123974,2.0,16211.25,7.837631,1.0,15192.125,7.792482,1.0,15566.0
50%,9.676776,4.0,20624.530303,9.535905,3.0,19065.25,9.648365,3.0,19937.642857
75%,11.14563,7.0,27005.5,11.144556,5.0,27997.0,11.27729,6.0,28489.5
max,21.254297,59.0,201381.0,20.632851,20.0,145572.0,19.876669,33.0,260308.0


In [138]:
import matplotlib.pylab as plt

In [154]:
index_of_n = (class_df['N_freq']>10) # &((class_df['N_ppl']<7)|(class_df['N_ppl']>11))
noun_class_df = class_df[index_of_n].sort_values(by='N_freq')

In [311]:
def print_head_tail_n_pos_class(df,pos_tag, num):
    class_df = df[(df[pos_tag+'_freq']>10)]
    print('%s hard class %d:' % (pos_tag, num))
    print(class_df.sort_values(by=pos_tag+'_ppl',ascending=False).head(num)[['name',pos_tag+'_wordlist',pos_tag+'_avgfreq',pos_tag+'_ppl']])
    print('\n%s easy class %d:' % (pos_tag, num))
    print(class_df[class_df[pos_tag+'_ppl']>0].sort_values(by=pos_tag+'_ppl').head(num)[['name',pos_tag+'_wordlist',pos_tag+'_avgfreq',pos_tag+'_ppl']])

In [312]:
 pd.options.display.max_colwidth = 30
 pd.options.display.width = 150

In [313]:
print_head_tail_n_pos_class(class_df,'N',10)

N hard class 10:
               name                     N_wordlist     N_avgfreq      N_ppl
631        Activity  dispatch zealot ado meddli...  26547.416667  11.731858
269            Love  eros fervor fondness darli...  23337.800000  11.490992
539        Evildoer  savage barbarian oppressor...  33008.428571  11.474117
101       Inclosure  ditch railing barricade co...  19198.272727  11.456427
254        Property  folkland paraphernalia fie...  22237.500000  11.442895
771       Deception  flytrap bait spoof hoax ti...  40407.083333  11.414145
447      Instrument  helm oar harness paraphern...  32706.666667  11.372717
724  Representation  likeness personification s...  16693.437500  11.289961
982        Painting  enamel holograph portraitu...  15810.250000  11.225190
74             Hope  buoyancy assumption cheer ...  21529.363636  11.187591

N easy class 10:
              name                     N_wordlist     N_avgfreq     N_ppl
1015  Unimportance  rubbish weed refuse nonent...  2028

In [314]:
print_head_tail_n_pos_class(class_df,'VB',10)

VB hard class 10:
                name                    VB_wordlist    VB_avgfreq     VB_ppl
49        Excitation  fascinate excite pique inf...  18351.705882  12.116376
658       Resentment  excite pique inflame fret ...  22577.181818  11.587866
736           Motive  lure incite fascinate urge...  19128.923077  11.433347
997        Cleanness  comb scrub weed sponge def...  23554.454545  11.263761
1029       Hindrance  hustle barricade inhibit m...  36810.066667  11.000488
850         Ejection  excrete dispatch spurt dro...  22439.071429  10.786085
251        Falsehood  deceive invent feign quibb...  36049.384615  10.581319
724   Representation  likeness illustrate statue...  22214.500000  10.402710
1011      Disclosure  snitch uncover concede div...  32359.583333  10.305008
869        Agitation  hustle ferment jerk hitch ...  31026.250000  10.251559

VB easy class 10:
            name                    VB_wordlist    VB_avgfreq    VB_ppl
559    Restraint  suppress inhibit cloister 

In [315]:
print_head_tail_n_pos_class(class_df,'ADJ',10)

ADJ hard class 10:
              name                   ADJ_wordlist   ADJ_avgfreq    ADJ_ppl
574       Dullness  pedestrian stupid stolid p...  31788.454545  12.665000
631       Activity  afoot workaday instant med...  30845.214286  12.619541
209    Ostentation  punctilious flaunting flas...  21473.454545  12.112767
183       Ugliness  forbidding grisly gaunt un...  31005.727273  12.049217
804    Drunkenness  maudlin drunken corned boo...  26872.000000  11.510492
485     Importance  weighty instant stirring c...  26125.181818  11.422877
103           Vice  lax sinister corrupt sinfu...  25509.545455  11.143114
332           Fear  apprehensive horrific trem...  17242.714286  11.101037
57     Uncleanness  beastly corrupt moldy deca...  18354.909091  10.715285
1015  Unimportance  respectable miserable scur...  18628.875000  10.449393

ADJ easy class 10:
                 name                   ADJ_wordlist   ADJ_avgfreq   ADJ_ppl
897              Land  alluvial littoral earthy m...  26614