In [1]:
# Enable the autoreload extension so edited modules are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = 'all'

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

import random
import string
from collections import defaultdict
from itertools import product, chain
import numpy as np
from pattern.en import comparative

In [3]:
import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from enum import Enum
from typing import List, Optional, Union

from child_frames import frames
from utils import *

import sys
sys.path.insert(0, '/nas/xd/projects/transformers/src/transformers')

import dataclasses
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, Optional

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import HfArgumentParser, Trainer, TrainingArguments, set_seed
from transformers import RobertaForMaskedLM, RobertaTokenizer
from transformers.data.datasets.glue import Split
logging.basicConfig(level=logging.ERROR)

In [4]:
# Premise template: "<prefix> <determiner> <ent0> <relation> <determiner> <ent1> <suffix>".
A_template = "{rel_prefix} {dt} {ent0} {rel} {dt} {ent1} {rel_suffix}"
# Two word orders for the hypothesis; only the first one is selected as B_template.
B_templates = ["{pred_prefix} {dt} {ent} {pred}", "{pred_prefix} {pred} {dt} {ent}"]
B_template = B_templates[0]
# Ways of combining a premise A, a hypothesis B and the answer word `conj`.
# make_sentences iterates over entailment_templates[-1:], i.e. only the last entry is used.
entailment_templates = [
    "{A} ? {conj} , {B} .",  # yes/no/maybe
    "{A} , so {B} ? {conj} .",
]

def join_lists(x):
    """Flatten one level of nesting: concatenate the iterables in ``x`` into a single list."""
    flat = []
    for sub in x:
        flat.extend(sub)
    return flat

def reverse(l):
    """Return a new list with the items of ``l`` in reverse order; ``l`` itself is not modified."""
    return [item for item in reversed(l)]

def mask(ent_str):
    """Wrap the content word of ``ent_str`` in spaced square brackets.

    A single word ``w`` becomes ``'[ w ]'``; a two-word string must start with
    ``'the'`` and becomes ``'the [ w ]'``. Any other input fails an assertion
    whose message is the original string.
    """
    words = ent_str.strip().split()
    if len(words) == 1:
        (word,) = words
        return '[ {} ]'.format(word)
    if len(words) == 2:
        article, noun = words
        assert article == 'the', ent_str
        return '{} [ {} ]'.format(article, noun)
    assert False, ent_str

def negate_sent(sent):
    """Return two negated variants of ``sent``, which must contain ``' is '``.

    The first variant rewrites every ``' is '`` as ``' is not '``; the second
    prefixes the untouched sentence with ``'it is unlikely that '``.
    """
    assert ' is ' in sent
    return [
        sent.replace(' is ', ' is not '),
        'it is unlikely that ' + sent,
    ]

def swap_entities(sent, e0='X', e1='Z'):
    """Return ``sent`` with every occurrence of ``e0`` and ``e1`` exchanged.

    The swap is done without an intermediate placeholder, so text that merely
    happens to contain a placeholder-like substring (the previous version used
    ``'xx'``) is left intact.

    Args:
        sent: The sentence to rewrite.
        e0: First entity name.
        e1: Second entity name.

    Returns:
        The rewritten sentence. ``sent`` is returned unchanged when either
        name is empty or both names are equal, since there is nothing
        meaningful to swap.
    """
    if not e0 or not e1 or e0 == e1:
        return sent
    # Splitting on e0 sets its occurrences aside; e1 is rewritten to e0 inside
    # the remaining pieces, and the pieces are glued back together with e1.
    return e1.join(piece.replace(e1, e0) for piece in sent.split(e0))

def make_sentences(index=-1, orig_sentence='', entities=["X", "Z"], determiner="",
                   relation_group=[["big", "large"], ["small"]], rand_relation_group=[["short"], ["tall", "high"]],
                   relation_prefix="", relation_suffix="", predicate_prefix="",
                   n_entity_trials=3, has_negation=True, has_neutral=True, entity_set=string.ascii_uppercase):
    """Generate masked entailment sentences for one relation frame.

    A premise compares the two entities ("X is <comparative> than Z"), a
    hypothesis states a property of a single entity ("X is <comparative>"),
    and both are combined through the last entry of ``entailment_templates``
    together with a bracketed answer word produced by ``mask``:

    * ``Right`` - the hypothesis uses the same frame and agrees with the premise;
    * ``Wrong`` - same frame, but the entities of the hypothesis are swapped;
    * ``Maybe`` - the hypothesis uses ``rand_relation_group`` instead.

    Args:
        index: Accepted but not used by this function.
        orig_sentence: Accepted but not used by this function.
        entities: List with the two placeholder entity names.
        determiner: Text substituted for ``{dt}`` in the templates.
        relation_group: Two lists of adjectives; an adjective from the first
            list is attached to ``entities[0]`` and one from the second list
            to ``entities[1]``.
        rand_relation_group: Same structure, used only for ``Maybe`` hypotheses.
        relation_prefix: Text substituted for ``{rel_prefix}`` in the premise.
        relation_suffix: Text substituted for ``{rel_suffix}`` in the premise.
        predicate_prefix: Text substituted for ``{pred_prefix}`` in the hypothesis.
        n_entity_trials: Number of random entity substitutions per sentence.
        has_negation: Also build "is not" variants of premises and hypotheses.
        has_neutral: Include the ``Maybe`` sentences in the result.
        entity_set: Sequence from which replacement entity names are drawn.

    Returns:
        ``(sentences, substituted_sentences)``: the sentences written with the
        placeholder entities, and ``n_entity_trials`` copies of each sentence
        in which the placeholders are replaced by two distinct random members
        of ``entity_set``.
    """
    ent_a, ent_b = entities

    def swap(sent):
        # Swap the entities actually in use rather than swap_entities' default 'X'/'Z'.
        return swap_entities(sent, ent_a, ent_b)

    def form_As(relations):
        # Two premises describing the same fact: "a REL0 b" and "b REL1 a".
        return [A_template.format(dt=determiner, ent0=ent0, ent1=ent1, rel=rel, rel_prefix=relation_prefix, rel_suffix=relation_suffix)
              for ent0, ent1, rel in [entities + relations[:1], reverse(entities) + reverse(relations)[:1]]]

    As = []
    for rel0 in relation_group[0]:
        for rel1 in relation_group[1]:
            # `comparative` comes from pattern.en; presumably it yields the comparative form of the adjective.
            relations = ["is %s than" % comparative(rel) for rel in [rel0, rel1]]
            As += form_As(relations)
    As = list(set(As))  # de-duplicate; the resulting order is not defined
    negAs = join_lists([negate_sent(A)[:1] for A in As]) if has_negation else []

    def form_Bs(predicates):
        # One hypothesis per entity, pairing entities[i] with predicates[i].
        return [B_template.format(dt=determiner, ent=ent, pred=pred, pred_prefix=predicate_prefix)
              for ent, pred in zip(entities, predicates)]

    Bs, negBs = {'orig': [], 'rand': []}, {}
    for k, group in zip(['orig', 'rand'], [relation_group, rand_relation_group]):
        for rel0 in group[0]:
            for rel1 in group[1]:
                predicates = ["is %s" % comparative(rel) for rel in [rel0, rel1]]
                Bs[k] += form_Bs(predicates)
    for k in Bs:
        Bs[k] = list(set(Bs[k]))
        if has_negation:
            negBs[k] = join_lists([negate_sent(B)[:1] for B in Bs[k]])
            # A negated hypothesis about the other entity counts as a positive one, and vice versa.
            Bs[k], negBs[k] = Bs[k] + [swap(negB) for negB in negBs[k]], negBs[k] + [swap(B) for B in Bs[k]]
        else:
            negBs[k] = [swap(B) for B in Bs[k]]

    def form_sentences(sentence_template, As, Bs, conj):
        # Collapse runs of whitespace left by empty template slots and prefix a single space.
        return [" " + " ".join(sentence_template.format(A=A, B=B, conj=conj).split()) for A, B in product(As, Bs)]

    sentences = defaultdict(list)
    for entailment_template in entailment_templates[-1:]:
        for A, B, conj in [(As, Bs['orig'], 'Right'),
                           (negAs, negBs['orig'], 'Right'),
                           (As, negBs['orig'], 'Wrong'),
                           (negAs, Bs['orig'], 'Wrong'),
                           (As, Bs['rand'], 'Maybe'),
                           (negAs, negBs['rand'], 'Maybe'),
                           (As, negBs['rand'], 'Maybe'),
                           (negAs, Bs['rand'], 'Maybe'),
                          ]:
            sentences[conj] += form_sentences(entailment_template, A, B, mask(conj))
    assert len(sentences['Right']) == len(sentences['Wrong']), '%d %d' % (len(sentences['Right']), len(sentences['Wrong']))
    # Keep half as many neutral sentences as 'Right' ones, but never ask
    # random.sample for more items than exist (that would raise ValueError).
    n_maybe = min(len(sentences['Right']) // 2, len(sentences['Maybe']))
    sentences['Maybe'] = random.sample(sentences['Maybe'], n_maybe)
    sentences = join_lists(sentences[k] for k in (sentences.keys() if has_neutral else ['Right', 'Wrong']))

    def substitute(sent, new_a, new_b):
        # Replace both placeholders in one pass. Two chained str.replace calls
        # would rewrite the freshly inserted new_a again whenever it equals the
        # second placeholder, leaving both entities with the same name.
        return new_a.join(piece.replace(ent_b, new_b) for piece in sent.split(ent_a))

    substituted_sentences = []
    for sent in sentences:
        for _ in range(n_entity_trials):
            e0, e1 = random.sample(entity_set, 2)
            substituted_sentences.append(substitute(sent, e0, e1))
    return sentences, substituted_sentences

# make_sentences(has_negation=False, has_neutral=False)[0]

In [5]:
def rejoin_masked_tokens(tokens):
    """Merge each ``[ word ]`` token triple back into a single ``[word]`` token.

    The tokenizer splits a bracketed word into an opening bracket, the word
    and a closing bracket (each bracket possibly carrying the ``Ġ`` prefix).
    This function glues every such triple into one token, dropping ``Ġ`` from
    the brackets only; all other tokens are copied through unchanged.

    Args:
        tokens: Sequence of token strings. It is read, never modified.

    Returns:
        A new list of tokens with the bracketed triples merged.

    Raises:
        AssertionError: If a closing bracket appears where an opening one is
            expected, or an opening bracket is not closed two tokens later.
        IndexError: If the sequence ends in the middle of a bracketed triple.
    """
    opening, closing = ('[', 'Ġ['), (']', 'Ġ]')
    out = []
    pos, n_tokens = 0, len(tokens)
    while pos < n_tokens:
        token = tokens[pos]
        if token not in opening and token not in closing:
            out.append(token)
            pos += 1
            continue
        assert token in opening
        masked_word = tokens[pos + 1]  # the masked word
        close_token = tokens[pos + 2]  # "]" symbol
        assert close_token in closing
        out.append(token.replace('Ġ', '') + masked_word + close_token.replace('Ġ', ''))
        pos += 3
    return out

class CHILDDataset(Dataset):
    """Dataset of tokenized sentences for one split (train / dev / test).

    The split is computed once and cached on the class, so every instance
    created afterwards reads from the same partition. Reset the entries of
    ``CHILDDataset.all_lines`` to ``None`` to force a new split.
    """

    # Class-level cache: split -> list of lines, or None when not yet computed.
    all_lines = {Split.train: None, Split.dev: None, Split.test: None}

    def __init__(self, all_lines, tokenizer, max_seq_len=None, split_pct=[0.7, 0.3, 0.0], max_noise_len=0, mode=Split.train):
        """Build the features of the split selected by ``mode``.

        Args:
            all_lines: Lines, or lists of lines (groups). Groups are assigned
                to a split as a whole and flattened afterwards. Ignored when
                the class-level cache for ``mode`` is already filled.
            tokenizer: Object with a ``tokenize(str)`` method; it is also
                forwarded to ``convert_example_to_features``.
            max_seq_len: Sequence length passed to the feature converter. When
                ``None`` it is derived from the longest example, plus 2
                positions for single sentences or 3 for sentence pairs
                (presumably for special tokens - confirm).
            split_pct: ``[train, dev, test]`` fractions. Only the dev and test
                entries are read; train receives whatever remains.
            max_noise_len: Forwarded to ``convert_example_to_features``.
            mode: A ``Split`` member or its name as a string.

        Raises:
            KeyError: If ``mode`` is a string that names no ``Split`` member.
        """
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError:
                raise KeyError("mode is not a valid split name")
        if CHILDDataset.all_lines[mode] is None:
            # Shuffle a copy so the caller's list keeps its order.
            shuffled = list(all_lines)
            random.shuffle(shuffled)
            n_dev = int(round(len(shuffled) * split_pct[1]))
            n_test = int(round(len(shuffled) * split_pct[2]))
            n_train = len(shuffled) - n_dev - n_test

            def flatten(lines):
                # Lists of lines are split as whole groups and only then flattened.
                return list(chain.from_iterable(lines)) if len(lines) > 0 and isinstance(lines[0], list) else lines

            CHILDDataset.all_lines[Split.train] = flatten(shuffled[:n_train])
            CHILDDataset.all_lines[Split.dev] = flatten(shuffled[n_train: n_train + n_dev])
            CHILDDataset.all_lines[Split.test] = flatten(shuffled[n_train + n_dev:])

        examples = []
        for i, line in enumerate(CHILDDataset.all_lines[mode]):
            t1, t2, is_next_label = self.split_sent(line)
            tokens_a = rejoin_masked_tokens(tokenizer.tokenize(t1))
            tokens_b = rejoin_masked_tokens(tokenizer.tokenize(t2)) if t2 is not None else None
            # InputExample is not defined in this file; presumably it comes from `utils`.
            example = InputExample(guid=i, tokens_a=tokens_a, tokens_b=tokens_b, is_next=is_next_label)
            examples.append(example)

        if max_seq_len is None:
            # default=0 keeps an empty split (e.g. a 0.0 test share) from raising ValueError.
            max_seq_len = max((len(example.tokens_a) + len(example.tokens_b) + 3
                if example.tokens_b is not None else len(example.tokens_a) + 2
                for example in examples), default=0)

        # convert_example_to_features is not defined in this file; presumably it comes from `utils`.
        self.features = [convert_example_to_features(example, max_seq_len, tokenizer, max_noise_len=max_noise_len)
             for example in examples]

    def split_sent(self, line):
        """Split ``line`` on ``|||`` into two stripped sentences.

        Returns:
            ``(t1, t2, label)``; ``t2`` is ``None`` when the line has no
            separator, and ``label`` is always 0.
        """
        label = 0
        if "|||" in line:
            t1, t2 = [t.strip() for t in line.split("|||")]
            assert len(t1) > 0 and len(t2) > 0, "%d %d" % (len(t1), len(t2))
        else:
            # assert self.one_sent
            t1, t2 = line.strip(), None
        return t1, t2, label

    def __len__(self):
        """Number of examples in this split."""
        return len(self.features)

    def __getitem__(self, i):
        """Return the features of example ``i``."""
        return self.features[i]

In [6]:
# Model class, tokenizer class and the checkpoint name used for the tokenizer.
model_class, tokenizer_class, shortcut = RobertaForMaskedLM, RobertaTokenizer, 'roberta-large'
# Only the tokenizer is loaded here; `model` stays None until the training cell assigns it.
model, tokenizer = None, tokenizer_class.from_pretrained(shortcut)

INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at /home/xd/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b
INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /home/xd/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda


In [7]:
# Shuffle the frames in place so the "rand" frame paired with each frame below varies.
# NOTE(review): no random seed is set in the visible cells, so this order differs between runs.
random.shuffle(frames)
# One list of entity-substituted sentences per frame ([1] picks the second return value of
# make_sentences). Each frame uses its successor in the shuffled list, wrapping around,
# as rand_relation_group.
all_lines = [make_sentences(relation_group=rg, rand_relation_group=frames[(i + 1) % len(frames)], 
                            has_negation=False, has_neutral=True)[1] 
             for i, rg in enumerate(frames)]
# all_lines = join_lists(all_lines)
# Clear the class-level split cache so the datasets below are split from scratch.
for k in CHILDDataset.all_lines: CHILDDataset.all_lines[k] = None
# The first construction computes and caches the split on the class; the second one finds
# the cache filled and reads the dev portion of that same split.
train_dataset = CHILDDataset(all_lines, tokenizer, split_pct=[0.7, 0.3, 0.0], mode='train')
eval_dataset = CHILDDataset(all_lines, tokenizer, split_pct=[0.7, 0.3, 0.0], mode='dev')

In [8]:
# NOTE(review): the tokenizer was loaded from `shortcut` ('roberta-large') but the model is
# loaded from 'roberta-base' - confirm this mismatch is intended.
# NOTE(review): `model` is None on a fresh top-to-bottom run and is passed as an extra
# positional argument - confirm what from_pretrained does with it.
model = model_class.from_pretrained('roberta-base', model)

# Log and evaluate every 50 steps, save a checkpoint every 1000 steps, train for 3 epochs.
training_args = TrainingArguments(output_dir="./models/model_name", 
    overwrite_output_dir=True, do_train=True, do_eval=True,
    per_device_train_batch_size=32, per_device_eval_batch_size=64,
    learning_rate=2e-5, num_train_epochs=3,
    logging_steps=50, eval_steps=50, save_steps=1000,
    evaluate_during_training=True,
)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.train()

INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /home/xd/.cache/torch/transformers/e1a2a406b5a05063c31f4dfdee7608986ba7c6393f7f79db5e69dcd197208534.117c81977c5979de8c088352e74ec6e70f5c66096c28b61d3c50101609b39690
INFO:transformers.configuration_utils:Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "type_vocab_size": 1,
  "vocab_size": 50265
}

INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/roberta-base-pyt

loading state_dict took 0.992 sec


INFO:transformers.trainer:You are instantiating a Trainer but W&B is not installed. To use wandb logging, run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface.
INFO:transformers.trainer:To use comet_ml logging, run `pip/conda install comet_ml` see https://www.comet.ml/docs/python-sdk/huggingface/
INFO:transformers.trainer:***** Running training *****
INFO:transformers.trainer:  Num examples = 3294
INFO:transformers.trainer:  Num Epochs = 3
INFO:transformers.trainer:  Instantaneous batch size per device = 32
INFO:transformers.trainer:  Total train batch size (w. parallel, distributed & accumulation) = 32
INFO:transformers.trainer:  Gradient Accumulation steps = 1
INFO:transformers.trainer:  Total optimization steps = 309


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=103.0, style=ProgressStyle(description_wi…

{'loss': 1.1985808324813843, 'learning_rate': 1.6763754045307445e-05, 'epoch': 0.4854368932038835, 'step': 50}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=22.0, style=ProgressStyle(description_wi…


{'eval_loss': 0.7780519994822416, 'eval_accuracy': 0.5875946974212473, 'epoch': 0.4854368932038835, 'step': 50}
{'loss': 0.7714587771892547, 'learning_rate': 1.3527508090614887e-05, 'epoch': 0.970873786407767, 'step': 100}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=22.0, style=ProgressStyle(description_wi…


{'eval_loss': 1.1149881427938289, 'eval_accuracy': 0.5223958343267441, 'epoch': 0.970873786407767, 'step': 100}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=103.0, style=ProgressStyle(description_wi…

{'loss': 0.5574883329868316, 'learning_rate': 1.029126213592233e-05, 'epoch': 1.4563106796116505, 'step': 150}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=22.0, style=ProgressStyle(description_wi…


{'eval_loss': 0.7463125004009767, 'eval_accuracy': 0.7781250016255812, 'epoch': 1.4563106796116505, 'step': 150}
{'loss': 0.30606507778167724, 'learning_rate': 7.055016181229773e-06, 'epoch': 1.941747572815534, 'step': 200}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=22.0, style=ProgressStyle(description_wi…


{'eval_loss': 0.5870019793510437, 'eval_accuracy': 0.8660511374473572, 'epoch': 1.941747572815534, 'step': 200}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=103.0, style=ProgressStyle(description_wi…

{'loss': 0.17191166400909424, 'learning_rate': 3.818770226537217e-06, 'epoch': 2.4271844660194173, 'step': 250}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=22.0, style=ProgressStyle(description_wi…


{'eval_loss': 0.8572778322479941, 'eval_accuracy': 0.8596590919928118, 'epoch': 2.4271844660194173, 'step': 250}
{'loss': 0.13333035230636597, 'learning_rate': 5.825242718446603e-07, 'epoch': 2.912621359223301, 'step': 300}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=22.0, style=ProgressStyle(description_wi…


{'eval_loss': 0.880185604095459, 'eval_accuracy': 0.8617897738109935, 'epoch': 2.912621359223301, 'step': 300}




HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=22.0, style=ProgressStyle(description_wi…

INFO:transformers.trainer:

Training completed. Do not forget to share your model on huggingface.co/models =)





{'eval_loss': 0.8635010434822603, 'eval_accuracy': 0.8639204556291754, 'epoch': 3.0, 'step': 309}


TrainOutput(global_step=309, training_loss=0.5107385703817637)

In [12]:
# Evaluation dataloader built by the Trainer; presumably it iterates over the
# eval_dataset given at construction - confirm.
dataloader = trainer.get_eval_dataloader()



In [16]:
# Keep only the first batch: the loop stops right after binding it to `inputs`.
for inputs in dataloader: break

{'input_ids': tensor([[    0,   534,    16,    55,  1473,    87,   468, 17487, 50264,  2156,
           272,    16,    55,  1473,   479,     2,     0,     0,     0,     0,
             0],
        [    0,  1301,    16,    55,  1473,    87,   226, 17487, 50264,  2156,
           525,    16,    55,  1473,   479,     2,     0,     0,     0,     0,
             0],
        [    0,   530,    16,    55,  1473,    87,   384, 17487, 50264,  2156,
           229,    16,    55,  1473,   479,     2,     0,     0,     0,     0,
             0],
        [    0,   565,    16,    55,  1473,    87,   289, 17487, 50264,  2156,
           289,    16,    55, 37192,   479,     2,     0,     0,     0,     0,
             0],
        [    0,   791,    16,    55,  1473,    87,   274, 17487, 50264,  2156,
           274,    16,    55, 37192,   479,     2,     0,     0,     0,     0,
             0],
        [    0,   975,    16,    55,  1473,    87,   256, 17487, 50264,  2156,
           256,    16,    55, 37

In [20]:
# NOTE(review): _prepare_inputs is a private Trainer method; presumably it moves the batch
# to the model's device - confirm. `inputs` is overwritten, so re-running this cell feeds
# already-prepared inputs back in.
inputs = trainer._prepare_inputs(inputs, model)
# Forward pass on the single batch (no torch.no_grad() / model.eval() is applied in this cell).
outputs = model(**inputs)

In [22]:
# The first two outputs are taken as loss and logits; presumably this order holds because
# `labels` is part of `inputs` - confirm.
loss, logits = outputs[:2]

In [44]:
# Zero the logits at every position whose label is -100, sum over dim 1 and apply softmax
# over the last dim. Assumes logits are (batch, seq, vocab) and each row has exactly one
# labelled position, so the sum selects that position's logits - TODO confirm.
probs = F.softmax((logits * (inputs['labels'] != -100).unsqueeze(-1)).sum(dim=1), dim=-1)

In [77]:
# Predicted id per example: argmax over the last dim of the logits kept at positions whose
# label is not -100 (same masking and sum as used for `probs`).
pred_labels = (logits * (inputs['labels'] != -100).unsqueeze(-1)).sum(dim=1).argmax(dim=-1)

# Gold id per example: -100 entries are multiplied by 0, so the sum leaves the remaining
# label (assumes one labelled position per row - TODO confirm).
labels = (inputs['labels'] * (inputs['labels'] != -100)).sum(dim=-1)

# Fraction of examples in this batch whose prediction equals the gold id.
(pred_labels == labels).float().mean().item()

0.875

In [46]:
# Five largest probabilities and their indices along the last dim, per example.
values, indices = probs.topk(5, dim=-1)

In [52]:
# For each example, print the top-5 ids followed by the tokens the tokenizer maps them to.
for top_idx in indices:
    print(top_idx.tolist())
    print(tokenizer.convert_ids_to_tokens(top_idx.tolist()))

[5143, 31273, 13984, 235, 37234]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠCorrect']
[5143, 31273, 235, 13984, 10039]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠLeft']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
[31273, 5143, 235, 13984, 38103]
['ĠWrong', 'ĠRight', 'Ġright', 'Right', 'ĠRIGHT']
[31273, 5143, 235, 13984, 38103]
['ĠWrong', 'ĠRight', 'Ġright', 'Right', 'ĠRIGHT']
[5143, 31273, 235, 13984, 1593]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'Ġwrong']
[5143, 31273, 13984, 235, 38103]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠRIGHT']
[5143, 31273, 13984, 235, 37234]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠCorrect']
[5143, 31273, 13984, 235, 37234]
['ĠRight', 'ĠWrong', 'Right', 'Ġright', 'ĠCorrect']
[5143, 31273, 235, 13984, 38103]
['ĠRight', 'ĠWrong', 'Ġright', 'Right', 'ĠRIGHT']
