In [1]:
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

import random
import string
from collections import defaultdict
from itertools import product, chain
import numpy as np
from pattern.en import comparative

In [3]:
import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from enum import Enum
from typing import List, Optional, Union

from child_frames import frames
from utils import *

import sys
sys.path.insert(0, '/nas/xd/projects/transformers/src/transformers')

import dataclasses
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, Optional

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import HfArgumentParser, Trainer, TrainingArguments, set_seed
from transformers import RobertaForMaskedLM, RobertaTokenizer
from transformers.modeling_roberta import RobertaDoubleHeadsModel  # XD

logging.basicConfig(level=logging.ERROR)

In [31]:
tokenizer.tokenize(' shortest')

['Ġshortest']

In [4]:
A_template = "{rel_prefix} {dt} {ent0} {rel} {dt} {ent1} {rel_suffix}"
B_templates = ["{pred_prefix} {dt} {ent} {pred}", "{pred_prefix} {pred} {dt} {ent}"]
B_template = B_templates[0]
entailment_templates = [
    "{A} ? {conj} , {B} .",  # yes/no/maybe
    "{A} , so {B} ? {conj} .",
]
marker = '*'
def get_comparative(word, add_marker=False):
    compar = comparative(word)
    if add_marker:
        compar = compar.replace('more ', 'more %s ' % marker) if compar.startswith('more ') else marker + ' ' + compar
    return compar
    
def negate_sent(sent):
    assert ' is ' in sent
    neg_sents = []
    neg_sents.append(sent.replace(' is ', ' is not '))
    neg_sents.append('it is unlikely that ' + sent)
    return neg_sents

def strip_rel_id(s, lexical_rel=''):
    rel_id_span = s[s.index(':'): s.index(':') + 2]
    if lexical_rel != '': lexical_rel = ' ( ' + lexical_rel + ' )'
    return s.replace(rel_id_span, lexical_rel)
        
tag2id = {'Ġsame': 0, 'Ġopposite': 1, 'Ġunrelated': 2, 'Ġformer': 3, 'Ġlatter': 4, 'Ġanother': 5}
id2tag = {v: k for k, v in tag2id.items()}

def make_sentences(index=-1, entities=["_X", "_Z"], entity_set=string.ascii_uppercase, determiner="",
                   relation_group=[["big",], ["small"]], rand_relation_group=[["short"], ["tall", "high"]],
                   relation_prefix="", relation_suffix="", predicate_prefix="",
                   n_entity_trials=3, has_negA=True, has_negB=True, has_neutral=False, mask_types=['sent_rel'], 
                   lexical_relations=['same', 'opposite', 'unrelated'], tag_lexical_rel=False, tag_entity_rel=False):
    def form_As(relations):
        return [A_template.format(dt=determiner, ent0=ent0, ent1=ent1, rel=rel, rel_prefix=relation_prefix, rel_suffix=relation_suffix)
              for ent0, ent1, rel in [entities + relations[:1], reverse(entities) + reverse(relations)[:1]]]

    As = []
    for rel0 in relation_group[0]:
        for rel1 in relation_group[1]:
            relations = ["is %s:%d than" % (get_comparative(rel), i) for i, rel in enumerate([rel0, rel1])]
            As += form_As(relations)
    As = list(set(As))
    negAs = join_lists([negate_sent(A)[:1] for A in As]) if has_negA else []

    def form_Bs(predicates): 
        f = mask if 'entity' in mask_types else (lambda x: x)
        return [B_template.format(dt=determiner, ent=f(ent), pred=pred, pred_prefix=predicate_prefix)
              for ent, pred in zip(entities, predicates)]

    Bs, negBs = {'orig': [], 'rand': []}, {}
    for k, group in zip(['orig', 'rand'], [relation_group, rand_relation_group]):
        for rel0 in group[0]:
            for rel1 in group[1]:
                predicates = ["is %s:%d" % (get_comparative(rel), i) for i, rel in enumerate([rel0, rel1])]
                Bs[k] += form_Bs(predicates)
    for k in Bs:
        Bs[k] = list(set(Bs[k]))
        if has_negB:
            negBs[k] = join_lists([negate_sent(B)[:1] for B in Bs[k]])
            Bs[k], negBs[k] = Bs[k] + [swap_entities(negB) for negB in negBs[k]], negBs[k] + [swap_entities(B) for B in Bs[k]]
        else:
            negBs[k] = [swap_entities(B) for B in Bs[k]]

    def form_sentences(sentence_template, As, Bs, conj):
        def extract_rel_id(s): return int(s[s.index(':') + 1])
        def get_lexical_rel(rel_id_A, rel_id_B): return 'same' if rel_id_A == rel_id_B else 'opposite'
        def compare_and_tag_entity(B, A):
            entity = [e for e in entities if e in B][0]
            entity_rel = 'former' if A.strip().startswith(entity) else 'latter'
            if 'entity_rel' in mask_types: entity_rel = mask(entity_rel)
            return B.replace(entity, entity + ' ( ' + entity_rel + ' )')                    
        
        if 'sent_rel' in mask_types: conj = mask(conj)
        As_with_rel_ids = [(A, extract_rel_id(A)) for A in As]
        Bs_with_rel_ids = [(B, extract_rel_id(B)) for B in Bs]
            
        sentences = []
        for (A, rel_id_A), (B, rel_id_B) in product(As_with_rel_ids, Bs_with_rel_ids):
            lexical_rel = 'unrelated' if 'Maybe' in conj else get_lexical_rel(rel_id_A, rel_id_B)
            if lexical_rel in lexical_relations:
                if tag_entity_rel: B = compare_and_tag_entity(B, A)
                if not tag_lexical_rel: lexical_rel = ''
                elif 'lexical_rel' in mask_types: lexical_rel = mask(lexical_rel)
                sent = sentence_template.format(A=strip_rel_id(A), B=strip_rel_id(B, lexical_rel), conj=conj)
                sent = " " + " ".join(sent.split())
                sentences.append(sent)
        return sentences

    sentences = defaultdict(list)
    for entailment_template in entailment_templates[-1:]:
        for A, B, conj in [(As, Bs['orig'], 'Right'), 
                           (negAs, negBs['orig'], 'Right'), 
                           (As, negBs['orig'], 'Wrong'), 
                           (negAs, Bs['orig'], 'Wrong'),
                           (As, Bs['rand'], 'Maybe'), 
                           (negAs, negBs['rand'], 'Maybe'), 
                           (As, negBs['rand'], 'Maybe'), 
                           (negAs, Bs['rand'], 'Maybe'),
                          ]:
            sentences[conj] += form_sentences(entailment_template, A, B, conj)
    assert len(sentences['Right']) == len(sentences['Wrong']), '%d %d' % (len(sentences['Right']), len(sentences['Wrong']))
    if has_neutral: sentences['Maybe'] = random.sample(sentences['Maybe'], len(sentences['Right']))
    sentences = join_lists(sentences[k] for k in (sentences.keys() if has_neutral else ['Right', 'Wrong']))
    
    substituted_sent_groups = []
    for sent in sentences:
        sent_group = []
        for _ in range(n_entity_trials):
            e0, e1 = random.sample(entity_set, 2)
            sent_group.append(sent.replace(entities[0], e0).replace(entities[1], e1))
        substituted_sent_groups.append(sent_group)
    return sentences, substituted_sent_groups

make_sentences(has_negA=True, has_negB=True, has_neutral=True, tag_lexical_rel=True, tag_entity_rel=True, 
               mask_types=['sent_rel'])[0]

[' _X is bigger than _Z , so _Z ( latter ) is smaller ( opposite ) ? [ Right ] .',
 ' _X is bigger than _Z , so _X ( former ) is bigger ( same ) ? [ Right ] .',
 ' _X is bigger than _Z , so _X ( former ) is not smaller ( opposite ) ? [ Right ] .',
 ' _X is bigger than _Z , so _Z ( latter ) is not bigger ( same ) ? [ Right ] .',
 ' _Z is smaller than _X , so _Z ( former ) is smaller ( same ) ? [ Right ] .',
 ' _Z is smaller than _X , so _X ( latter ) is bigger ( opposite ) ? [ Right ] .',
 ' _Z is smaller than _X , so _X ( latter ) is not smaller ( same ) ? [ Right ] .',
 ' _Z is smaller than _X , so _Z ( former ) is not bigger ( opposite ) ? [ Right ] .',
 ' _X is not bigger than _Z , so _Z ( latter ) is not smaller ( opposite ) ? [ Right ] .',
 ' _X is not bigger than _Z , so _X ( former ) is not bigger ( same ) ? [ Right ] .',
 ' _X is not bigger than _Z , so _X ( former ) is smaller ( opposite ) ? [ Right ] .',
 ' _X is not bigger than _Z , so _Z ( latter ) is bigger ( same ) ? [ Ri

In [18]:
P_template = '{ent0} {rel} {ent1}'
transitive_template = '{p0} and {p1} , so {Q} ? {conj} .'
transitive_wh_QA_template = '{which} is {pred} ? {ent} .'
    
def make_transitive(entities=["_X", "_Y", "_Z"], entity_set=string.ascii_uppercase, relation_group=[["big", ], ["small", ]], 
                    n_entity_trials=3, has_negP=True, has_negQ=True, has_neutral=False, mask_types=['sent_rel']):
    def form_atoms(relations, entities, has_neg=True):
        atoms = [P_template.format(ent0=ent0, ent1=ent1, rel=rel) 
                 for ent0, ent1, rel in [entities + relations[:1], reverse(entities) + reverse(relations)[:1]]]
        if has_neg:
            neg_relations = [r.replace('is ', 'is not ') for r in relations]
            atoms += [P_template.format(ent0=ent0, ent1=ent1, rel=rel) 
                      for ent0, ent1, rel in [entities + reverse(neg_relations)[:1], reverse(entities) + neg_relations[:1]]]
        return atoms
 
    def form_sentences(transitive_template, Ps, Qs, conj):
        sentences = []
        if 'sent_rel' in mask_types: conj = mask(conj)
        for (p0, p1), Q in product(Ps, Qs):
            sent = transitive_template.format(p0=strip_rel_id(p0), p1=strip_rel_id(p1), Q=strip_rel_id(Q), conj=conj)
            sent = " " + " ".join(sent.split())
            sentences.append(sent)
        return sentences
    
    def form_all(P0_entities, P1_entities, Q_entities, neutral=False):
        P0, P1 = [], []
        for rel0 in relation_group[0]:
            for rel1 in relation_group[1]:
                relations = ["is %s:%d than" % (get_comparative(rel), i) for i, rel in enumerate([rel0, rel1])]
                P0 += form_atoms(relations, P0_entities, has_neg=has_negP)
                P1 += form_atoms(relations, P1_entities, has_neg=has_negP)
        Ps = [(p0, p1) for p0, p1 in list(product(P0, P1)) + list(product(P1, P0))]

        Qs = form_atoms(relations, Q_entities, has_neg=has_negQ)
        negQs = [swap_entities(Q, *Q_entities) for Q in Qs]
        
        for P, Q, conj in [(Ps, Qs, 'Right'), (Ps, negQs, 'Wrong')]:
            if neutral: conj = 'Maybe'
            sentences[conj] += form_sentences(transitive_template, P, Q, conj)
        return sentences
    
    e0, e1, e2 = entities
    sentences = defaultdict(list)
    form_all(P0_entities=[e0, e1], P1_entities=[e1, e2], Q_entities=[e0, e2])
    assert len(sentences['Right']) == len(sentences['Wrong']), '%d %d' % (len(sentences['Right']), len(sentences['Wrong']))
    sample_ratio = len(relation_group[0]) * len(relation_group[1])
    if sample_ratio > 1:
        for key in sentences: sentences[key] = random.sample(sentences[key], len(sentences[key]) // sample_ratio)
#     print('nRight =', len(sentences['Right']))
    if has_neutral:
        form_all(P0_entities=[e0, e1], P1_entities=[e0, e2], Q_entities=[e1, e2], neutral=True)
        sentences['Maybe'] = random.sample(sentences['Maybe'], len(sentences['Right']))
    sentences = join_lists(sentences[k] for k in (sentences.keys() if has_neutral else ['Right', 'Wrong']))
    
    substituted_sent_groups = []
    for sent in sentences:
        sent_group = []
        for _ in range(n_entity_trials):
            e0, e1, e2 = random.sample(entity_set, 3)
            sent_group.append(sent.replace(entities[0], e0).replace(entities[1], e1).replace(entities[2], e2))
        substituted_sent_groups.append(sent_group)
    return sentences, substituted_sent_groups

make_transitive(has_negP=False, has_negQ=False, has_neutral=False)

([' _X is bigger than _Y and _Y is bigger than _Z , so _X is bigger than _Z ? [ Right ] .',
  ' _X is bigger than _Y and _Y is bigger than _Z , so _Z is smaller than _X ? [ Right ] .',
  ' _X is bigger than _Y and _Z is smaller than _Y , so _X is bigger than _Z ? [ Right ] .',
  ' _X is bigger than _Y and _Z is smaller than _Y , so _Z is smaller than _X ? [ Right ] .',
  ' _Y is smaller than _X and _Y is bigger than _Z , so _X is bigger than _Z ? [ Right ] .',
  ' _Y is smaller than _X and _Y is bigger than _Z , so _Z is smaller than _X ? [ Right ] .',
  ' _Y is smaller than _X and _Z is smaller than _Y , so _X is bigger than _Z ? [ Right ] .',
  ' _Y is smaller than _X and _Z is smaller than _Y , so _Z is smaller than _X ? [ Right ] .',
  ' _Y is bigger than _Z and _X is bigger than _Y , so _X is bigger than _Z ? [ Right ] .',
  ' _Y is bigger than _Z and _X is bigger than _Y , so _Z is smaller than _X ? [ Right ] .',
  ' _Y is bigger than _Z and _Y is smaller than _X , so _X is bigge

In [10]:
model_class, tokenizer_class, shortcut = RobertaForMaskedLM, RobertaTokenizer, 'roberta-large'
# model_class = RobertaDoubleHeadsModel
model, tokenizer = None, tokenizer_class.from_pretrained(shortcut)

In [23]:
random.shuffle(frames)
# all_lines = [make_sentences(relation_group=rg, rand_relation_group=frames[(i + 1) % len(frames)], n_entity_trials=10, 
#                             has_negA=True, has_negB=True, tag_lexical_rel=True, tag_entity_rel=True,
#                             has_neutral=False, mask_types=['sent_rel'])[1] 
#              for i, rg in enumerate(frames)]
all_lines = [make_transitive(relation_group=rg, n_entity_trials=10, 
                             has_negP=False, has_negQ=False, has_neutral=False, mask_types=['sent_rel'])[1] 
             for i, rg in enumerate(frames)]
# all_lines = join_lists(all_lines)
# all_lines = join_lists(all_lines)
tokenizer.tag2id, tokenizer.id2tag = tag2id, id2tag
for k in CHILDDataset.all_lines: CHILDDataset.all_lines[k] = None
train_dataset = CHILDDataset(all_lines, tokenizer, max_noise_len=0, split_pct=[0.7, 0.3, 0.0], mode='train')
eval_dataset = CHILDDataset(all_lines, tokenizer, max_noise_len=0, split_pct=[0.7, 0.3, 0.0], mode='dev')
print('nTrain = %d, nValid = %d' % (len(train_dataset), len(eval_dataset)))

in convert_example_to_features: features.labels = [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 5143, -100, -100, -100, -100]
in convert_example_to_features: features.labels = [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 5143, -100, -100]
nTrain = 32320, nValid = 15040


In [24]:
model = model_class.from_pretrained('roberta-base', model=model)

training_args = TrainingArguments(output_dir="./models/model_name", 
    overwrite_output_dir=True, do_train=True, do_eval=True,
    per_device_train_batch_size=32, per_device_eval_batch_size=64,
    learning_rate=2e-5, num_train_epochs=3,
    logging_steps=100, eval_steps=100, save_steps=0,
    no_cuda=False, evaluate_during_training=True,
)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.tokenizer = tokenizer
trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1010.0, style=ProgressStyle(description_w…

{'loss': 0.901, 'learning_rate': 1.9339933993399344e-05, 'epoch': 0.099, 'step': 100}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.705, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 0.00 0.50, ĠWrong 1.00 0.56, ', 'epoch': 0.099, 'step': 100}
{'loss': 0.73, 'learning_rate': 1.867986798679868e-05, 'epoch': 0.198, 'step': 200}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.741, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 1.00 0.65, ', 'epoch': 0.198, 'step': 200}
{'loss': 0.725, 'learning_rate': 1.8019801980198022e-05, 'epoch': 0.297, 'step': 300}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.695, 'eval_acc_0': 0.502, 'eval_stat_0': 'ĠRight 0.42 0.51, ĠWrong 0.58 0.51, ', 'epoch': 0.297, 'step': 300}
{'loss': 0.737, 'learning_rate': 1.735973597359736e-05, 'epoch': 0.396, 'step': 400}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.697, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 1.00 0.54, ', 'epoch': 0.396, 'step': 400}
{'loss': 0.721, 'learning_rate': 1.66996699669967e-05, 'epoch': 0.495, 'step': 500}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.705, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 1.00 0.57, ', 'epoch': 0.495, 'step': 500}
{'loss': 0.71, 'learning_rate': 1.6039603960396042e-05, 'epoch': 0.594, 'step': 600}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.715, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.60, ', 'epoch': 0.594, 'step': 600}
{'loss': 0.713, 'learning_rate': 1.537953795379538e-05, 'epoch': 0.693, 'step': 700}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.698, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.54, ', 'epoch': 0.693, 'step': 700}
{'loss': 0.705, 'learning_rate': 1.4719471947194721e-05, 'epoch': 0.792, 'step': 800}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.698, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 1.00 0.55, ', 'epoch': 0.792, 'step': 800}
{'loss': 0.713, 'learning_rate': 1.405940594059406e-05, 'epoch': 0.891, 'step': 900}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.695, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.52, ', 'epoch': 0.891, 'step': 900}
{'loss': 0.708, 'learning_rate': 1.33993399339934e-05, 'epoch': 0.99, 'step': 1000}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.731, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 1.00 0.63, ', 'epoch': 0.99, 'step': 1000}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1010.0, style=ProgressStyle(description_w…

{'loss': 0.718, 'learning_rate': 1.2739273927392741e-05, 'epoch': 1.089, 'step': 1100}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.702, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.56, ', 'epoch': 1.089, 'step': 1100}
{'loss': 0.703, 'learning_rate': 1.207920792079208e-05, 'epoch': 1.188, 'step': 1200}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.702, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.56, ', 'epoch': 1.188, 'step': 1200}
{'loss': 0.698, 'learning_rate': 1.141914191419142e-05, 'epoch': 1.287, 'step': 1300}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.698, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.55, ', 'epoch': 1.287, 'step': 1300}
{'loss': 0.702, 'learning_rate': 1.075907590759076e-05, 'epoch': 1.386, 'step': 1400}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.693, 'eval_acc_0': 0.505, 'eval_stat_0': 'ĠRight 0.18 0.50, ĠWrong 0.82 0.50, ', 'epoch': 1.386, 'step': 1400}
{'loss': 0.706, 'learning_rate': 1.00990099009901e-05, 'epoch': 1.485, 'step': 1500}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.693, 'eval_acc_0': 0.497, 'eval_stat_0': 'ĠRight 0.74 0.51, ĠWrong 0.26 0.50, ', 'epoch': 1.485, 'step': 1500}
{'loss': 0.706, 'learning_rate': 9.43894389438944e-06, 'epoch': 1.584, 'step': 1600}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.693, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 1.00 0.51, ĠWrong 0.00 0.50, ', 'epoch': 1.584, 'step': 1600}
{'loss': 0.7, 'learning_rate': 8.77887788778878e-06, 'epoch': 1.683, 'step': 1700}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.725, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 1.00 0.62, ', 'epoch': 1.683, 'step': 1700}
{'loss': 0.704, 'learning_rate': 8.11881188118812e-06, 'epoch': 1.782, 'step': 1800}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.695, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.53, ', 'epoch': 1.782, 'step': 1800}
{'loss': 0.702, 'learning_rate': 7.458745874587459e-06, 'epoch': 1.881, 'step': 1900}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.702, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.56, ', 'epoch': 1.881, 'step': 1900}
{'loss': 0.702, 'learning_rate': 6.798679867986799e-06, 'epoch': 1.98, 'step': 2000}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.695, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.53, ', 'epoch': 1.98, 'step': 2000}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1010.0, style=ProgressStyle(description_w…

{'loss': 0.698, 'learning_rate': 6.138613861386139e-06, 'epoch': 2.079, 'step': 2100}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.693, 'eval_acc_0': 0.499, 'eval_stat_0': 'ĠRight 0.02 0.50, ĠWrong 0.98 0.51, ', 'epoch': 2.079, 'step': 2100}
{'loss': 0.698, 'learning_rate': 5.4785478547854785e-06, 'epoch': 2.178, 'step': 2200}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.694, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 1.00 0.52, ', 'epoch': 2.178, 'step': 2200}
{'loss': 0.697, 'learning_rate': 4.818481848184819e-06, 'epoch': 2.277, 'step': 2300}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.699, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 1.00 0.55, ', 'epoch': 2.277, 'step': 2300}
{'loss': 0.698, 'learning_rate': 4.158415841584159e-06, 'epoch': 2.376, 'step': 2400}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.693, 'eval_acc_0': 0.503, 'eval_stat_0': 'ĠRight 0.54 0.50, ĠWrong 0.46 0.50, ', 'epoch': 2.376, 'step': 2400}
{'loss': 0.698, 'learning_rate': 3.4983498349834986e-06, 'epoch': 2.475, 'step': 2500}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.693, 'eval_acc_0': 0.502, 'eval_stat_0': 'ĠRight 0.86 0.50, ĠWrong 0.14 0.50, ', 'epoch': 2.475, 'step': 2500}
{'loss': 0.698, 'learning_rate': 2.8382838283828383e-06, 'epoch': 2.574, 'step': 2600}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.694, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.52, ', 'epoch': 2.574, 'step': 2600}
{'loss': 0.697, 'learning_rate': 2.1782178217821785e-06, 'epoch': 2.673, 'step': 2700}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.695, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.53, ', 'epoch': 2.673, 'step': 2700}
{'loss': 0.698, 'learning_rate': 1.5181518151815183e-06, 'epoch': 2.772, 'step': 2800}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.693, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠRight 0.05 0.50, ĠWrong 0.95 0.51, ', 'epoch': 2.772, 'step': 2800}
{'loss': 0.696, 'learning_rate': 8.580858085808581e-07, 'epoch': 2.871, 'step': 2900}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.693, 'eval_acc_0': 0.499, 'eval_stat_0': 'ĠRight 0.30 0.50, ĠWrong 0.70 0.50, ', 'epoch': 2.871, 'step': 2900}
{'loss': 0.695, 'learning_rate': 1.9801980198019803e-07, 'epoch': 2.97, 'step': 3000}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.694, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.51, ', 'epoch': 2.97, 'step': 3000}




HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=235.0, style=ProgressStyle(description_w…


{'eval_loss': 0.694, 'eval_acc_0': 0.5, 'eval_stat_0': 'ĠWrong 1.00 0.51, ', 'epoch': 3.0, 'step': 3030}


TrainOutput(global_step=3030, training_loss=0.7125059162429457)

In [174]:
model.roberta.encoder.layer[0].attention.self

BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [128]:
dataloader = trainer.get_eval_dataloader()

In [129]:
for inputs in dataloader: break

In [130]:
inputs

{'input_ids': tensor([[  0, 387,  16,  ...,   2,   0,   0],
         [  0, 534,  16,  ...,   2,   0,   0],
         [  0, 975,  16,  ...,   2,   0,   0],
         ...,
         [  0, 487,  16,  ..., 479,   2,   0],
         [  0, 673,  16,  ..., 479,   2,   0],
         [  0, 725,  16,  ..., 479,   2,   0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 0]]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'labels': tensor([[-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         ...,
         [-100, -1

In [133]:
inputs = trainer._prepare_inputs(inputs, model)
loss, logits = model(**inputs)

In [134]:
probs = F.softmax((logits * (inputs['labels'] != -100).unsqueeze(-1)).sum(dim=1), dim=-1)

In [144]:
bsz, seq_len, vocab_size = (logits).size()
masks = (inputs['labels'] != -100)
n_mask = masks.sum(dim=-1)[0].item()
pred_labels = logits.masked_select(mask.unsqueeze(-1)).view(bsz, n_mask, vocab_size).argmax(dim=-1)
labels = inputs['labels'].masked_select(masks).view(bsz, n_mask)

torch.Size([64, 3])

In [132]:
inputs['labels'][0]

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 5442,
        -100, -100, -100, -100, -100, 5483, -100, -100, 5143, -100, -100, -100,
        -100])

In [108]:
pred_labels = (logits * (inputs['labels'] != -100).unsqueeze(-1)).sum(dim=1).argmax(dim=-1)

labels = (inputs['labels'] * (inputs['labels'] != -100)).sum(dim=-1)

(pred_labels == labels).float().mean().item()

0.90625

In [35]:
print('%.1f' % 94.433)

94.4


In [33]:
for i, label_id in enumerate(pred_labels):
    print(tokenizer._convert_id_to_token(label_id.item()), probs[i, label_id].item())

ĠRight 0.6260414719581604
ĠRight 0.6174842119216919
ĠRight 0.627226710319519
ĠRight 0.6074517965316772
ĠRight 0.6034893989562988
ĠRight 0.5928120613098145
ĠMaybe 0.9999963045120239
ĠMaybe 0.9999955892562866
ĠMaybe 0.9999963045120239
ĠRight 0.6209385395050049
ĠRight 0.6153814196586609
ĠRight 0.6189284324645996
ĠRight 0.6173365712165833
ĠRight 0.6170286536216736
ĠRight 0.6148619651794434
ĠMaybe 0.9999971389770508
ĠMaybe 0.9999972581863403
ĠMaybe 0.9999969005584717
ĠRight 0.6081119179725647
ĠRight 0.6188942790031433
ĠRight 0.6106655597686768
ĠRight 0.6149485111236572
ĠRight 0.6151918172836304
ĠRight 0.6127419471740723
ĠRight 0.6444377303123474
ĠRight 0.6497629284858704
ĠRight 0.6418505311012268
ĠRight 0.6264133453369141
ĠRight 0.6273400783538818
ĠRight 0.6344654560089111
ĠRight 0.6158509254455566
ĠRight 0.6200970411300659
ĠRight 0.6208559274673462
ĠMaybe 0.9999988079071045
ĠMaybe 0.9999986886978149
ĠMaybe 0.9999988079071045
ĠMaybe 0.9999974966049194
ĠMaybe 0.9999971389770508
ĠMaybe 0.9999

In [21]:
pred_labels.bincount().topk(3)

torch.return_types.topk(
values=tensor([40, 24,  0], device='cuda:0'),
indices=tensor([5143, 5359,    0], device='cuda:0'))

In [22]:
labels.bincount().topk(3)

torch.return_types.topk(
values=tensor([24, 22, 18], device='cuda:0'),
indices=tensor([ 5359,  5143, 31273], device='cuda:0'))

In [26]:
tokenizer._convert_id_to_token(5359)

'ĠMaybe'

In [46]:
values, indices = probs.topk(5, dim=-1)

In [52]:
for top_idx in indices:
    print(top_idx.tolist())
    print(tokenizer.convert_ids_to_tokens(top_idx.tolist()))

[5143, 31273, 13984, 235, 37234]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠCorrect']
[5143, 31273, 235, 13984, 10039]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠLeft']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
[31273, 5143, 235, 13984, 38103]
['ĠWrong', 'ĠRight', 'Ġright', 'Right', 'ĠRIGHT']
[31273, 5143, 235, 13984, 38103]
['ĠWrong', 'ĠRight', 'Ġright', 'Right', 'ĠRIGHT']
[5143, 31273, 235, 13984, 1593]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'Ġwrong']
[5143, 31273, 13984, 235, 38103]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠRIGHT']
[5143, 31273, 13984, 235, 37234]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠCorrect']
[5143, 31273, 13984, 235, 37234]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠCorrect']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
