In [1]:
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

import random
import string
from collections import defaultdict
from itertools import product, chain
import numpy as np
from pattern.en import comparative

In [3]:
import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from enum import Enum
from typing import List, Optional, Union

from child_frames import frames
from utils import *

import sys
sys.path.insert(0, '/nas/xd/projects/transformers/src/transformers')

import dataclasses
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, Optional

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import HfArgumentParser, Trainer, TrainingArguments, set_seed
from transformers import RobertaForMaskedLM, RobertaTokenizer
from transformers.data.datasets.glue import Split
logging.basicConfig(level=logging.ERROR)

In [58]:
tokenizer.tokenize(' latter')

['Ġlatter']

In [61]:
A_template = "{rel_prefix} {dt} {ent0} {rel} {dt} {ent1} {rel_suffix}"
B_templates = ["{pred_prefix} {dt} {ent} {pred}", "{pred_prefix} {pred} {dt} {ent}"]
B_template = B_templates[0]
entailment_templates = [
    "{A} ? {conj} , {B} .",  # yes/no/maybe
    "{A} , so {B} ? {conj} .",
]

def negate_sent(sent):
    assert ' is ' in sent
    neg_sents = []
    neg_sents.append(sent.replace(' is ', ' is not '))
    neg_sents.append('it is unlikely that ' + sent)
    return neg_sents

# def synonym(x, y): return x == y
# def antonym(x, y): return x != y
# def arbitrary(x, y): return True

def make_sentences(index=-1, orig_sentence='', entities=["X", "Z"], determiner="",
                   relation_group=[["big",], ["small"]], rand_relation_group=[["short"], ["tall", "high"]],
                   relation_prefix="", relation_suffix="", predicate_prefix="",
                   n_entity_trials=3, has_negA=True, has_negB=True, has_neutral=False, predict_relation=True, 
                   lexical_relations=['same', 'opposite'], tag_lexical_rel=False, tag_entity=False, entity_set=string.ascii_uppercase):
    def form_As(relations):
        return [A_template.format(dt=determiner, ent0=ent0, ent1=ent1, rel=rel, rel_prefix=relation_prefix, rel_suffix=relation_suffix)
              for ent0, ent1, rel in [entities + relations[:1], reverse(entities) + reverse(relations)[:1]]]

    As = []
    for rel0 in relation_group[0]:
        for rel1 in relation_group[1]:
            relations = ["is %s:%d than" % (comparative(rel), i) for i, rel in enumerate([rel0, rel1])]
            As += form_As(relations)
    As = list(set(As))
    negAs = join_lists([negate_sent(A)[:1] for A in As]) if has_negA else []

    def form_Bs(predicates): 
        f = (lambda x: x) if predict_relation else mask
        return [B_template.format(dt=determiner, ent=f(ent), pred=pred, pred_prefix=predicate_prefix)
              for ent, pred in zip(entities, predicates)]

    Bs, negBs = {'orig': [], 'rand': []}, {}
    for k, group in zip(['orig', 'rand'], [relation_group, rand_relation_group]):
        for rel0 in group[0]:
            for rel1 in group[1]:
                predicates = ["is %s:%d" % (comparative(rel), i) for i, rel in enumerate([rel0, rel1])]
                Bs[k] += form_Bs(predicates)
    for k in Bs:
        Bs[k] = list(set(Bs[k]))
        if has_negB:
            negBs[k] = join_lists([negate_sent(B)[:1] for B in Bs[k]])
            Bs[k], negBs[k] = Bs[k] + [swap_entities(negB) for negB in negBs[k]], negBs[k] + [swap_entities(B) for B in Bs[k]]
        else:
            negBs[k] = [swap_entities(B) for B in Bs[k]]

#     def form_sentences(sentence_template, As, Bs, conj, lexical_rel_filter=arbitrary):
    def form_sentences(sentence_template, As, Bs, conj):
        def extract_rel_id(s): return int(s[s.index(':') + 1])
        def get_lexical_rel(rel_id_A, rel_id_B): return 'same' if rel_id_A == rel_id_B else 'opposite'
        def strip_rel_id(s, lexical_rel_tag=''):
            rel_id_span = s[s.index(':'): s.index(':') + 2]
            return s.replace(rel_id_span, lexical_rel_tag)
        def compare_and_tag_entity(B, A):
            entity = [e for e in entities if e in B][0]
            entity_tag = 'former' if A.strip().startswith(entity) else 'latter'
            return B.replace(entity, entity + ' ( ' + entity_tag + ' )')                    
        
        if predict_relation: conj = mask(conj)
        As_with_rel_ids = [(A, extract_rel_id(A)) for A in As]
        Bs_with_rel_ids = [(B, extract_rel_id(B)) for B in Bs]
            
        sentences = []
        for (A, rel_id_A), (B, rel_id_B) in product(As_with_rel_ids, Bs_with_rel_ids):
            lexical_rel = get_lexical_rel(rel_id_A, rel_id_B)
            if lexical_rel in lexical_relations:
                lexical_rel_tag = ' ( ' + lexical_rel + ' )' if tag_lexical_rel else ''
                if tag_entity: B = compare_and_tag_entity(B, A)
                sent = sentence_template.format(A=strip_rel_id(A), B=strip_rel_id(B, lexical_rel_tag=lexical_rel_tag), conj=conj)
                sent = " " + " ".join(sent.split())
                sentences.append(sent)
        return sentences
#         return [" " + " ".join(sentence_template.format(A=strip_rel_id(A), B=strip_rel_id(B), conj=conj).split()) 
#                 for A, B in product(As, Bs) if lexical_rel_filter(extract_rel_id(A), extract_rel_id(B))]

    sentences = defaultdict(list)
    for entailment_template in entailment_templates[-1:]:
        for A, B, conj in [(As, Bs['orig'], 'Right'), 
                           (negAs, negBs['orig'], 'Right'), 
                           (As, negBs['orig'], 'Wrong'), 
                           (negAs, Bs['orig'], 'Wrong'),
                           (As, Bs['rand'], 'Maybe'), 
                           (negAs, negBs['rand'], 'Maybe'), 
                           (As, negBs['rand'], 'Maybe'), 
                           (negAs, Bs['rand'], 'Maybe'),
                          ]:
            sentences[conj] += form_sentences(entailment_template, A, B, conj)
    assert len(sentences['Right']) == len(sentences['Wrong']), '%d %d' % (len(sentences['Right']), len(sentences['Wrong']))
    if has_neutral: sentences['Maybe'] = random.sample(sentences['Maybe'], len(sentences['Right']))
    sentences = join_lists(sentences[k] for k in (sentences.keys() if has_neutral else ['Right', 'Wrong']))
    
    substituted_sent_groups = []
    for sent in sentences:
        sent_group = []
        for _ in range(n_entity_trials):
            e0, e1 = random.sample(entity_set, 2)
            sent_group.append(sent.replace(entities[0], e0).replace(entities[1], e1))
        substituted_sent_groups.append(sent_group)
    return sentences, substituted_sent_groups

# make_sentences(has_negA=True, has_negB=True, tag_lexical_rel=True, tag_entity=True)

([' X is bigger than Z , so Z ( latter ) is smaller ( opposite ) ? [ Right ] .',
  ' X is bigger than Z , so X ( former ) is bigger ( same ) ? [ Right ] .',
  ' X is bigger than Z , so X ( former ) is not smaller ( opposite ) ? [ Right ] .',
  ' X is bigger than Z , so Z ( latter ) is not bigger ( same ) ? [ Right ] .',
  ' Z is smaller than X , so Z ( former ) is smaller ( same ) ? [ Right ] .',
  ' Z is smaller than X , so X ( latter ) is bigger ( opposite ) ? [ Right ] .',
  ' Z is smaller than X , so X ( latter ) is not smaller ( same ) ? [ Right ] .',
  ' Z is smaller than X , so Z ( former ) is not bigger ( opposite ) ? [ Right ] .',
  ' X is not bigger than Z , so Z ( latter ) is not smaller ( opposite ) ? [ Right ] .',
  ' X is not bigger than Z , so X ( former ) is not bigger ( same ) ? [ Right ] .',
  ' X is not bigger than Z , so X ( former ) is smaller ( opposite ) ? [ Right ] .',
  ' X is not bigger than Z , so Z ( latter ) is bigger ( same ) ? [ Right ] .',
  ' Z is not s

In [5]:
class CHILDDataset(Dataset):
    all_lines = {Split.train: None, Split.dev: None, Split.test: None}
    
    def __init__(self, all_lines, tokenizer, max_seq_len=None, max_noise_len=0, split_pct=[0.7, 0.3, 0.0], mode=Split.train):
        if isinstance(mode, str): mode = Split[mode]
        if CHILDDataset.all_lines[mode] is None:
            random.shuffle(all_lines)
            n_dev = int(round(len(all_lines) * split_pct[1]))
            n_test = int(round(len(all_lines) * split_pct[2]))
            n_train = len(all_lines) - n_dev - n_test
            
            def flatten(lines):
                if len(lines) > 0 and type(lines[0]) == list: lines = join_lists(lines)
                return join_lists(lines) if len(lines) > 0 and type(lines[0]) == list else lines
            
            CHILDDataset.all_lines[Split.train] = flatten(all_lines[:n_train])
            CHILDDataset.all_lines[Split.dev] = flatten(all_lines[n_train: n_train + n_dev])
            CHILDDataset.all_lines[Split.test] = flatten(all_lines[n_train + n_dev:])

        examples = []
        for i, line in enumerate(CHILDDataset.all_lines[mode]):
            t1, t2, is_next_label = self.split_sent(line)
            tokens_a = rejoin_masked_tokens(tokenizer.tokenize(t1))
            tokens_b = rejoin_masked_tokens(tokenizer.tokenize(t2)) if t2 is not None else None
            example = InputExample(guid=i, tokens_a=tokens_a, tokens_b=tokens_b, is_next=is_next_label)
            examples.append(example)

        if max_seq_len is None:
            max_seq_len = max([len(example.tokens_a) + len(example.tokens_b) + 3
                if example.tokens_b is not None else len(example.tokens_a) + 2
                for example in examples])

        self.features = [convert_example_to_features(example, max_seq_len, tokenizer, max_noise_len=max_noise_len)
             for example in examples]

    def split_sent(self, line):
        label = 0
        if "|||" in line:
            t1, t2 = [t.strip() for t in line.split("|||")]
            assert len(t1) > 0 and len(t2) > 0, "%d %d" % (len(t1), len(t2))
        else:
            # assert self.one_sent
            t1, t2 = line.strip(), None
        return t1, t2, label

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i):
        return self.features[i]

In [63]:
model_class, tokenizer_class, shortcut = RobertaForMaskedLM, RobertaTokenizer, 'roberta-large'
model, tokenizer = None, tokenizer_class.from_pretrained(shortcut)

In [67]:
random.shuffle(frames)
all_lines = [make_sentences(relation_group=rg, rand_relation_group=frames[(i + 1) % len(frames)], n_entity_trials=10, 
                            has_negA=True, has_negB=True, tag_lexical_rel=True, tag_entity=True,
                            has_neutral=False, predict_relation=True)[1] 
             for i, rg in enumerate(frames)]
all_lines = join_lists(all_lines)
# all_lines = join_lists(all_lines)
for k in CHILDDataset.all_lines: CHILDDataset.all_lines[k] = None
train_dataset = CHILDDataset(all_lines, tokenizer, max_noise_len=0, split_pct=[0.7, 0.3, 0.0], mode='train')
eval_dataset = CHILDDataset(all_lines, tokenizer, max_noise_len=0, split_pct=[0.7, 0.3, 0.0], mode='dev')
print('nTrain = %d, nValid = %d' % (len(train_dataset), len(eval_dataset)))

nTrain = 35280, nValid = 15120


In [None]:
model = model_class.from_pretrained('roberta-base', model=model)

training_args = TrainingArguments(output_dir="./models/model_name", 
    overwrite_output_dir=True, do_train=True, do_eval=True,
    per_device_train_batch_size=32, per_device_eval_batch_size=64,
    learning_rate=2e-5, num_train_epochs=3,
    logging_steps=100, eval_steps=100, save_steps=0,
    evaluate_during_training=True,
)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.tokenizer = tokenizer
trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1103.0, style=ProgressStyle(description_w…

{'loss': 0.856, 'learning_rate': 1.9395587790873378e-05, 'epoch': 0.091, 'step': 100}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=237.0, style=ProgressStyle(description_w…


{'eval_loss': 0.703, 'eval_acc': 0.47, 'eval_stat': 'ĠRight 0.94 0.53, ĠWrong 0.06 0.51, ', 'epoch': 0.091, 'step': 100}
{'loss': 0.73, 'learning_rate': 1.8791175581746754e-05, 'epoch': 0.181, 'step': 200}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=237.0, style=ProgressStyle(description_w…


{'eval_loss': 0.716, 'eval_acc': 0.471, 'eval_stat': 'ĠRight 1.00 0.57, ', 'epoch': 0.181, 'step': 200}
{'loss': 0.721, 'learning_rate': 1.818676337262013e-05, 'epoch': 0.272, 'step': 300}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=237.0, style=ProgressStyle(description_w…

In [13]:
dataloader = trainer.get_eval_dataloader()

In [14]:
for inputs in dataloader: break

In [15]:
inputs

{'input_ids': tensor([[  0, 530,  16,  ..., 479,   2,   0],
         [  0, 791,  16,  ..., 479,   2,   0],
         [  0, 574,  16,  ..., 479,   2,   0],
         ...,
         [  0, 975,  16,  ...,   2,   0,   0],
         [  0, 574,  16,  ...,   2,   0,   0],
         [  0, 495,  16,  ..., 479,   2,   0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0]]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'labels': tensor([[-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         ...,
         [-100, -1

In [16]:
inputs = trainer._prepare_inputs(inputs, model)
loss, logits = model(**inputs)

In [29]:
probs = F.softmax((logits * (inputs['labels'] != -100).unsqueeze(-1)).sum(dim=1), dim=-1)

In [17]:
pred_labels = (logits * (inputs['labels'] != -100).unsqueeze(-1)).sum(dim=1).argmax(dim=-1)

labels = (inputs['labels'] * (inputs['labels'] != -100)).sum(dim=-1)

(pred_labels == labels).float().mean().item()

0.71875

In [35]:
print('%.1f' % 94.433)

94.4


In [33]:
for i, label_id in enumerate(pred_labels):
    print(tokenizer._convert_id_to_token(label_id.item()), probs[i, label_id].item())

ĠRight 0.6260414719581604
ĠRight 0.6174842119216919
ĠRight 0.627226710319519
ĠRight 0.6074517965316772
ĠRight 0.6034893989562988
ĠRight 0.5928120613098145
ĠMaybe 0.9999963045120239
ĠMaybe 0.9999955892562866
ĠMaybe 0.9999963045120239
ĠRight 0.6209385395050049
ĠRight 0.6153814196586609
ĠRight 0.6189284324645996
ĠRight 0.6173365712165833
ĠRight 0.6170286536216736
ĠRight 0.6148619651794434
ĠMaybe 0.9999971389770508
ĠMaybe 0.9999972581863403
ĠMaybe 0.9999969005584717
ĠRight 0.6081119179725647
ĠRight 0.6188942790031433
ĠRight 0.6106655597686768
ĠRight 0.6149485111236572
ĠRight 0.6151918172836304
ĠRight 0.6127419471740723
ĠRight 0.6444377303123474
ĠRight 0.6497629284858704
ĠRight 0.6418505311012268
ĠRight 0.6264133453369141
ĠRight 0.6273400783538818
ĠRight 0.6344654560089111
ĠRight 0.6158509254455566
ĠRight 0.6200970411300659
ĠRight 0.6208559274673462
ĠMaybe 0.9999988079071045
ĠMaybe 0.9999986886978149
ĠMaybe 0.9999988079071045
ĠMaybe 0.9999974966049194
ĠMaybe 0.9999971389770508
ĠMaybe 0.9999

In [21]:
pred_labels.bincount().topk(3)

torch.return_types.topk(
values=tensor([40, 24,  0], device='cuda:0'),
indices=tensor([5143, 5359,    0], device='cuda:0'))

In [22]:
labels.bincount().topk(3)

torch.return_types.topk(
values=tensor([24, 22, 18], device='cuda:0'),
indices=tensor([ 5359,  5143, 31273], device='cuda:0'))

In [26]:
tokenizer._convert_id_to_token(5359)

'ĠMaybe'

In [46]:
values, indices = probs.topk(5, dim=-1)

In [52]:
for top_idx in indices:
    print(top_idx.tolist())
    print(tokenizer.convert_ids_to_tokens(top_idx.tolist()))

[5143, 31273, 13984, 235, 37234]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠCorrect']
[5143, 31273, 235, 13984, 10039]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠLeft']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
[31273, 5143, 235, 13984, 38103]
['ĠWrong', 'ĠRight', 'Ġright', 'Right', 'ĠRIGHT']
[31273, 5143, 235, 13984, 38103]
['ĠWrong', 'ĠRight', 'Ġright', 'Right', 'ĠRIGHT']
[5143, 31273, 235, 13984, 1593]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'Ġwrong']
[5143, 31273, 13984, 235, 38103]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠRIGHT']
[5143, 31273, 13984, 235, 37234]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠCorrect']
[5143, 31273, 13984, 235, 37234]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠCorrect']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
