In [2]:
from IPython import get_ipython
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [623]:
import sys
import os
os.environ['HF_HOME'] = '/raid/xd/.cache/torch'
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="6"

from types import MethodType
from tqdm import tqdm
from collections import defaultdict, OrderedDict, Counter
from datetime import datetime
from io import StringIO
from itertools import chain
import math
from functools import reduce, partial

%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataloader import DataLoader

from transformers.data.data_collator import DataCollator, default_data_collator
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import RobertaForMaskedLM, RobertaTokenizer, GPT2LMHeadModel, GPT2Tokenizer, GPTNeoForCausalLM
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import HfArgumentParser, Trainer, TrainingArguments, set_seed, AdamW
# from torch.optim.lr_scheduler import CosineAnnealingLR
# from transformers.trainer_utils import EvaluationStrategy

In [1303]:
sys.path.insert(0, '/nas/xd/projects/PyFunctional')
from functional import seq
from functional.pipeline import Sequence
from fn import _
from collections import namedtuple 

In [860]:
from utils import *
from child_utils import *
from common_utils import *

In [19]:
# models = {}
# cache_dir = '/nas/xd/.cache/torch/transformers/'  # for models besides t5-3b/11b
# proxies = {'http': '192.168.50.1:1081'} 

In [None]:
# model_name = "EleutherAI/gpt-neo-1.3B"
for model_name in ['gpt2-large']:#, 'gpt2-xl', 'KoboldAI/fairseq-dense-6.7B']:
    if model_name not in models:
        with Timer(model_name):
            model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)  
            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
            models[model_name] = model, tokenizer

In [637]:
def get_openai_model(engine):
    def forward(input_ids, attention_mask=None):
        text = tokenizer.decode(input_ids[0])
        response = openai.Completion.create(engine=engine, prompt=text, max_tokens=0, echo=True, logprobs=5)
        return Outputs(logits=response.choices[0].logprobs)
    return forward

In [661]:
for model_name in models: print(model_name)

gpt2-large
gpt2-xl
EleutherAI/gpt-neo-1.3B
EleutherAI/gpt-j-6B
KoboldAI/fairseq-dense-6.7B
KoboldAI/fairseq-dense-13B
text-babbage-001
text-curie-001
text-davinci-001
text-davinci-002


In [660]:
engines = ['text-babbage-001', 'text-curie-001', 'text-davinci-001', 'text-davinci-002']
for engine in engines:
    model = get_openai_model(engine)
    models[engine] = model, tokenizer

In [877]:
import openai

openai.api_key = open('/nas/xd/projects/openai_api_keys.txt').readlines()[-1].split()[0]
# text = 'Once upon a time'
response = openai.Completion.create(engine=engines[1], prompt=text, max_tokens=20, echo=True, logprobs=5)
print(response.choices[0].text)

adult young. Saturday -> Friday
microgram nanogram. doctor -> master
master bachelor. micrometer -> nanometer
c d. large -> huge
modern renaissance. doctor -> master
g h. child -> teenager
planet continent. h -> g
d e. microgram -> milligram
magnitude percent. large -> huge


In [13]:
prompt_token = 'Ġ!'; prompt_id = tokenizer._convert_token_to_id(prompt_token)
bop_str = 'Instruction: '; bop_id = tokenizer.encode(bop_str)[0]  # 'Inst'
eop_str = '. For example:'; eop_id = tokenizer.encode(eop_str)[2] # 'Ġexample'
bos_id = tokenizer._convert_token_to_id('Ġ->')
eos_id = tokenizer._convert_token_to_id('Ċ')


class CHILDDataset(Dataset):
    def __init__(self, input_strs, tokenizer):
        if tokenizer.pad_token is None: tokenizer.pad_token = '!'
        self.inputs = tokenizer.batch_encode_plus(input_strs, add_special_tokens=False, padding=True, return_tensors='pt')
        input_ids = self.inputs.input_ids
        self.labels = torch.ones_like(input_ids) * (-100)
        for bi in range(input_ids.size(0)):
            bop_idx = (input_ids[bi] == bop_id).nonzero().squeeze(1)
            eop_idx = (input_ids[bi] == eop_id).nonzero().squeeze(1)
            if len(bop_idx) > 0:
                assert len(bop_idx) == 1 and len(eop_idx) == 1
                bop_idx, eop_idx = bop_idx.item(), eop_idx.item()
                input_ids[bi, bop_idx: eop_idx + 2] *= -1  # use prompt embedding for prompt tokens
            
            bos_indices = (input_ids[bi] == bos_id).nonzero().squeeze(1)
            eos_indices = (input_ids[bi] == eos_id).nonzero()[-len(bos_indices):].squeeze(1)
            for i, (bos_i, eos_i) in enumerate(zip(bos_indices.tolist(), eos_indices.tolist())):
                assert eos_i > bos_i + 1
                if i >= 2: self.labels[bi, bos_i + 1: eos_i] = input_ids[bi, bos_i + 1: eos_i]

    def __len__(self):
        return len(sel f.inputs['input_ids'])

    def __getitem__(self, i):
        return {'input_ids': self.inputs['input_ids'][i],
                'attention_mask': self.inputs['attention_mask'][i],
                'labels': self.labels[i]}

In [299]:
class WrappedEmbedding(nn.Module):
    def __init__(self, 
                wte: nn.Embedding,
                prompt_id: int = None,
                prompt_len: int = 10, 
                random_range: float = 0.5,
                initialize_from_vocab: bool = True):
        super(WrappedEmbedding, self).__init__()
#         self.wte = wte
#         self.prompt_id = prompt_id
#         self.prompt_len = prompt_len
        self.__dict__.update(locals()); del self.self
        if self.prompt_id is not None:
            self.prompt_embedding = nn.parameter.Parameter(
                self.initialize_embedding(random_range, initialize_from_vocab)).to(self.wte.weight.device)
        else:
            self.prompt_embedding = nn.Embedding(self.prompt_len, self.wte.weight.size(1)).to(self.wte.weight.device)
            assert initialize_from_vocab
            self.init_prompt_embedding_()
#             self.prompt_embedding.weight.data = self.initialize_embedding(random_range, initialize_from_vocab)     
            
    def initialize_embedding(self, random_range: float = 0.5, initialize_from_vocab: bool = True):
        if initialize_from_vocab: return self.wte.weight[:self.prompt_len].clone().detach()
        return torch.FloatTensor(self.prompt_len, self.wte.weight.size(1)).uniform_(-random_range, random_range)
    
    def init_prompt_embedding_(self):
        self.prompt_embedding.weight.data[:] = self.wte.weight[:self.prompt_len]
            
    def forward(self, input_ids):
        if self.prompt_id is not None:
            input_embeds = self.wte(input_ids)
            input_embeds[input_ids == self.prompt_id] = self.prompt_embedding.expand(input_embeds.size(0), -1, -1)
        else: # adapted from cpm-2
            prompt_mask = input_ids < 0
            prompt_ids = -input_ids * prompt_mask
            assert torch.all(prompt_ids < self.prompt_len)
            p_embeds = self.prompt_embedding(prompt_ids) * prompt_mask.float().unsqueeze(-1)
            input_ids = input_ids * ~prompt_mask
            w_embeds = self.wte(input_ids) * (~prompt_mask).float().unsqueeze(-1)
            input_embeds = w_embeds + p_embeds
        return input_embeds

In [None]:
# adapted from cpm-2: https://github.com/TsinghuaAI/CPM-2-Finetune/blob/master/utils.py#L133-L164
def get_params_for_prompt_optimization(module: nn.Module):
    params = []
    for t in module.named_modules():
        if "prompt_embedding" in t[0]:
            params.append({'params': [p for p in list(t[1]._parameters.values()) if p is not None]})
    for t in module.named_parameters():
        if "prompt" not in t[0]:
            t[1].requires_grad_(False)    
    return params

def create_optimizer(model, training_args):
    from torch.nn.parallel.distributed import DistributedDataParallel as DDP
    while isinstance(model, (DDP, )): model = model.module
    we.init_prompt_embedding_()
    param_groups = get_params_for_prompt_optimization(model)
    optimizer = AdamW(param_groups, lr=training_args.learning_rate, 
                      betas=(training_args.adam_beta1, training_args.adam_beta2),eps=training_args.adam_epsilon)
    return optimizer

In [249]:
wte = model.get_input_embeddings()
if hasattr(wte, 'wte'): wte = wte.wte  # already been wrapped
we = WrappedEmbedding(wte, prompt_len=10000)
model.set_input_embeddings(we)

In [341]:
def verbalize(obj):
    if type(obj) == bool: return 'Yes' if obj else 'No'
    return str(obj)
    
def make_query_str(instruction, query):
    if instruction is None and query is None: return ''
    s = '.'
    if instruction is not None: s = s + ' ' + instruction
    if query is not None:
        if type(query) in [int, bool, str]: query = [query]
        if type(query) == dict:
    #         return '. ' + '{' + ','.join([' %s: %s' % (str(k), str(v)) for k, v in query.items()]) + ' }'
            s = s + ' ' + '{' + ','.join([' replace %s with %s' % (str(k), str(v)) for k, v in query.items()]) + ' }'
        elif type(query) in [list,]:
            s = s + ' ' + ' '.join([str(i) for i in query])
    return s

def make_example_str(example, with_instruction=False):
    instruction, l, query, ans = example
    if type(ans) not in [Sequence, list]: ans = [ans]
    ans = [verbalize(a) for a in ans]
    return '%s -> %s' % (' '.join(l) + make_query_str(instruction if with_instruction else None, query), ' '.join(ans))

def sample_rand_len(vocab, k): return sample(vocab, k=randint(1, k))

In [933]:
# def _str(l, sep=' : '):
#     if l is None: return ''
#     if isinstance(l, str) or not isinstance(l, collections.abc.Iterable): l = [l]
#     l = [e for e in l if not my_isinstance(e, Sequence)] #type(e).__name__ != 'Sequence']
#     if isinstance(l, (dict, OrderedDict)): l = [f'{k}: {v}' for k, v in l.items()]
#     return sep.join(str(i) for i in l)

# def options2str(options): return '[' + ' | '.join(options) + ']'

In [1030]:
def _str(l, sep=' '):
    if l is None: return ''
    if isinstance(l, str) or not isinstance(l, collections.abc.Iterable): l = [l]
    l = [e for e in l if not my_isinstance(e, Sequence)] #type(e).__name__ != 'Sequence']
    if isinstance(l, (dict, OrderedDict)): l = [f'{k}: {v}' for k, v in l.items()]
    return sep.join(str(i) for i in l)

def options2str(options): return '[' + ' | '.join(options) + ']'
# def options2str(options): return ' or '.join(options) + '?'

In [1460]:
def promptize(s):
#     return prompt_token * len(s.split())
    return bop_str + s + eop_str
    
def make_examples(task, nrows=4, ncols=4, full_vocab=None):
    if full_vocab is None: full_vocab = string.ascii_uppercase + string.digits
    transform_fn, vocab_fn, sample_fn, query_fn = task[:4]
    # instruction = transform_fn.__name__.replace('_', ' ')
    if vocab_fn is None: vocab_fn = lambda: full_vocab
    if query_fn is None: query_fn = lambda *_: (None, None)

    examples = []
    qa_set = set() # for dedup
    for i in range(nrows):
        vocab = vocab_fn()
        cxt = sample_fn(vocab)#, k=ncols)
        query, options = query_fn(cxt, vocab)#, ncols)
        ans = transform_fn(cxt, query)
        if (query, ans) not in qa_set:
            qa_set.add((query, ans))
            examples.append([cxt, query, options, ans])
    return examples

def make_input_str(task, examples, options_position=None):
    task += (_str,) * (4 + 3 - len(task))
    cxt2str, query2str, ans2str = task[4:]

    def example2str(example, with_instruction=False):
        cxt, query, options, ans = example
        strs = [cxt2str(cxt), query2str(query)]
        if options_position is not None: strs.insert(options_position, options2str(options))
        # strs = [options2str(options)] + strs if options_position == 'pre' else strs + [options2str(options)]
        return '. '.join(s for s in strs if s != '') + ' -> ' + ans2str(ans)

    desc = promptize(instruction) + '\n' if False else ''
    text = '\n'.join(example2str(e) for e in examples)
    text = desc + text + '\n'
    return text

In [1472]:
def ith_element(cxt, query=None): return seq(cxt).slice(1, 2)
def besides(cxt, query): return seq(cxt).difference(query)[0]
# def besides_query(cxt, vocab): return cxt.a(sample, 2), cxt.list()
def get_poset(e): return tuple([p for p in posets if e in p][0])
def special(cxt, query): return seq(cxt).group_by(get_poset).map(_[1]).find(lambda x: len(x) == 1)[0]
# def special_cxt(vocab, k=3): sample(vocab[0], k - 1) + sample(vocab[1], 1)

def after_query(r, p):
    e = r.dom().init().a(choice)
    options = r.image(e).map(beside)[0].a(sample, 2)
    return e, options

def before_query(r, p):
    # e = r.dom().tail().a(choice)
    e = choice(r.dom().init().tail().list())
    options = r.image(e).map(beside)[0].a(sample, 2)
    return e, options

def after(r, q): return r.image(q).map(next())[0]
def before(r, q): return r.image(q).map(prev())[0]
def between(r, q): 
    return r.image(q[0]).map(nexts)[0].intersection(r.image(q[1]).map(prevs)[0]).union(
        r.image(q[0]).map(prevs)[0].intersection(r.image(q[1]).map(nexts)[0]))
    
def monotone_map_cxt(vocab):
    P, p = vocab
    R = p2r(P)
    E1 = R.dom().init().tail().a(choice)
    E2 = R.image(E1).map(beside)[0].a(choice)
    return R, E1, E2

def monotone_map_query(cxt, vocab):
    P, p = vocab
    r = p2r(p)
    e1 = r.dom().init().tail().a(choice)
    options = r.image(e1).map(beside)[0]
    return (r, e1), options

def monotone_map(cxt, query, reverse=False):
    R, E1, E2 = cxt
    r, e1 = query
    return r.image(e1).map(
        seq([prev(), next()]).find(lambda f: (E2 in R.image(E1).map(f)[0]) != reverse)  # reverse = not in. too tricky
    )[0]
    
tasks = [
    (ith_element, None, partial(sample, k=3), None),
    (besides, None, partial(sample, k=3), lambda cxt, vocab: (sample(cxt, 2), cxt)),
    (special, lambda: sample(posets[1:3], 2), lambda vocab: sample(sample(vocab[0], 2) + sample(vocab[1], 1), 2 + 1), None),
    
    (after, lambda: choice(closed_posets), p2r, after_query, lambda r: ''),
    (before, lambda: choice(closed_posets), p2r, before_query, lambda r: ''),
    (between, lambda: choice(posets), p2r, lambda r, p: r.image(r.dom().init().tail().a(choice)).map(beside)[0].a(sample, 2), lambda r: ''),
    (partial(monotone_map, reverse=False), lambda: sample(posets, 2), monotone_map_cxt, monotone_map_query),
    (partial(monotone_map, reverse=True), lambda: sample(closed_posets, 2), monotone_map_cxt, monotone_map_query),
]

In [1492]:
# polygons = ['triangle', 'quadrangle', 'pentagon', 'hexagon', 'heptagon', 'octagon', 'nonagon', 'decagon',]# 'undecagon', 'dodecagon']
times_of_day = ['dawn', 'morning', 'noon', 'afternoon', 'evening', 'night',]# 'midnight']
days_of_week = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
seasons = ['spring', 'summer', 'autumn', 'winter']
# ages_of_life = ['baby', 'child', 'teenager', 'young', 'adult', 'elder']
ages_of_life = ['baby', 'child', 'adolescent', 'adult']
times_of_history = ['ancient', 'medieval', 'modern', 'contemporary'] #'renaissance', 
# units_of_time = ['nanosecond', 'microsecond', 'millisecond', ][:0] + ['second', 'minute', 'hour', 'day', 'week', 'month', 'year', 'decade', 'century', 'millennium'] # first 3 multi-token
units_of_length = ['nanometer', 'micrometer', 'millimeter', 'meter', 'kilometer', 'mile']
units_of_mass = ['nanogram', 'microgram', 'milligram', 'gram', 'kilogram', 'ton']
# SI_prefixes_small = ['pico', 'nana', 'micro', 'milli', 'centi', 'deci']
# SI_prefixes_large = ['kilo', 'mega', 'giga', 'tera', 'peta', 'exa', 'zetta', 'yotta']

things = ['atom', 'molecule', 'cell', 'tissue', 'organ', 'system', 'person', 'community', 'city', 'state', 'country', 'continent', 'planet', 'star', 'galaxy', 'universe']
sizes = ['tiny', 'small', 'large', 'huge',]# 'medium', 'gigantic']
# degrees = ['bachelor', 'master', 'doctor', 'postdoc']
posets = [list(string.ascii_uppercase)[:14], list(string.ascii_lowercase)[:14], list(string.ascii_uppercase)[14:], list(string.ascii_lowercase)[14:], digits, cardinals, ordinals,
    times_of_day, days_of_week, months, seasons, ages_of_life, times_of_history, #units_of_time, 
    things, sizes]# units_of_length, units_of_mass, SI_prefixes_small, SI_prefixes_large]
closed_posets = [list(string.ascii_uppercase)[:], list(string.ascii_lowercase)[:], digits, cardinals, ordinals, 
    days_of_week, months, ]#seasons, times_of_history, ages_of_life, sizes]
open_posets = [times_of_day, ages_of_life, times_of_history, units_of_length, units_of_mass, things, sizes, ]

In [1590]:
# task = tasks[4]
# _examples = make_examples(task, nrows=16)

text = make_input_str(task, _examples, options_position=None)
bos_token, eos_token ='Ġ->', 'Ċ'
examples = text.strip().split('\n')
k_shot = 3
print(text)
inputs = tokenizer.encode_plus(text, return_tensors='pt')
input_ids = inputs.input_ids
options_ids = [[tokenizer.encode(' ' + option)[0] for option in options] 
    for cxt, query, options, ans in _examples if options if not None]
qlen = input_ids.size(1)
# tokenize without tokenization artifact -> needed for visualization, from unseal
tokens = tokenizer.tokenize(text)
tokens = list(map(tokenizer.convert_tokens_to_string, map(lambda x: [x], tokens))) 

bos_indices, eos_indices, answers, labels = locate_answers(input_ids, tokenizer)
labels[:, :bos_indices[k_shot]] = -100  # 只算k_shot个示例后的loss

Saturday -> Friday
2 -> 1
M -> L
V -> U
July -> June
I -> H
6 -> 5
k -> j
seven -> six
February -> January
Wednesday -> Tuesday
Y -> X
seventh -> sixth
y -> x
third -> second
5 -> 4



In [1592]:
losses = []
for model_name, (model, tokenizer) in models.items():
    if any(model_name.startswith(s) for s in ['gpt2-', 'KoboldAI/fairseq-dense', 'text-davinci-001', ]): continue
    if not isinstance(model, types.FunctionType): _ = model.eval()
    with Timer(model_name): outputs = model(**inputs)
    options_ids_list = [[tokenizer.encode(' ' + option)[0] for option in options] for cxt, query, options, ans in _examples]
    mask_logits_fn = partial(mask_logits, indices=bos_indices, kept_ids=options_ids_list)
    loss, _ = show_predictions(text, examples, tokenizer, outputs.logits, bos_indices, eos_indices, answers, labels,
                    mask_logits_fn=None, topk=3, loss_reduction='mean', show_range=range(k_shot, len(examples)), sep='\t')
    print(loss)
    losses.append(loss.item() if hasattr(loss, 'item') else loss)
    if model_name == 'EleutherAI/gpt-j-6B': break
print(sum(losses) / len(losses))

EleutherAI/gpt-neo-1.3B ... done 0:00:00
  U 0.014 {' V': 0.237, ' W': 0.187, ' M': 0.065} 	 V -> U
  June 0.044 {' August': 0.194, ' July': 0.109, ' June': 0.044} 	 July -> June
  H 0.056 {' I': 0.219, ' II': 0.106, ' J': 0.06} 	 I -> H
* 5 0.709 {' 5': 0.709, ' 4': 0.049, ' 6': 0.049} 	 6 -> 5
  j 0.004 {' l': 0.264, ' k': 0.129, ' 4': 0.06} 	 k -> j
* six 0.658 {' six': 0.658, ' seven': 0.117, ' eight': 0.06} 	 seven -> six
* January 0.389 {' January': 0.389, ' February': 0.376, ' jan': 0.031} 	 February -> January
* Tuesday 0.396 {' Tuesday': 0.396, ' Wednesday': 0.346, ' Thursday': 0.131} 	 Wednesday -> Tuesday
  X 0.028 {' N': 0.191, ' Z': 0.162, ' M': 0.05} 	 Y -> X
  sixth 0.404 {' seventh': 0.427, ' sixth': 0.404, ' eighth': 0.054} 	 seventh -> sixth
  x 0.062 {' y': 0.132, ' z': 0.123, ' e': 0.094} 	 y -> x
* second 0.553 {' second': 0.553, ' third': 0.394, ' two': 0.014} 	 third -> second
* 4 0.986 {' 4': 0.986, ' 3': 0.004, ' 5': 0.002} 	 5 -> 4
tensor(2.0254, grad_fn=<NllL

In [557]:
def p2r(p): p = seq(p); return p.zip(p.inits().zip(p.tails()))#.slice(1, p.len() - 1)

In [430]:
relational_functions = [prev(), next()]
rel_fns = [prevs, nexts]

In [517]:
def neighbour(direction, k=1): return lambda x: x[direction][k]
def prev(k=1): return neighbour(0, k)
def next(k=1): return neighbour(1, k)
prevs, nexts = _[0][1:], _[1][1:]
beside = lambda x: (x[0][1], x[1][1])

**TODO: read children books for more posets**  
**TODO: Prompt gpt3 to elicit the posets it knows**  
$x \to f(x)$ where $f \in \{\text{prev/next in posets of numbers/letters/months/days, antonym, hypernym, hyponym, ...}\}$  
$x \to f^2(x)$  
one poset or mixed posets  
$x, f(x).~y \to Ff^{[-1]}(y)$ one poset or mixed posets  
$x, f^k(x).~y \to Ff^{[-1]}(y)~/Ff^{[-]k}(y)$  
$x, f(f(x))~/f(f(x)), x \to f(x)$ in between, the simplest form of sequence completion  
$x, f(x) \to Gf$ where $Gf \in \{<, >\}$  
$x, f(x); y, g(y) \to Ff \stackrel{?}{=} g^{[-1]}$ where $\text{output} \in \{\text{True}, \text{False}\}$  
sort

There is a *natural* monotone map/functor $F$ between posets/sets $A$ and $B$.  Compose the computation (set operations, sorting etc.) between $A$ and $B$ with $F$ to make harder tasks.  
$P(A) ,P(B) \to F(P(A)) \setminus ~/ \cap ~/ \triangle P(B)$. Harder form of set difference/intersection.  
$P(A) \to F(\text{sorted}(P(A)))$. Harder form of sorting.


In [385]:
n_total, n_valid = 192, 64
n_train = n_total - n_valid

input_strs = [make_input_str(tasks[4], nrows=4, ncols=5) for __ in range(n_total)]
for s in sample(input_strs, 3): print(s)

Instruction: replace with the other. For example:
G H G G G -> H G H H H
I I I I M -> M M M M I
A A F A A -> F F A F F
9 9 9 I I -> I I I 9 9

Instruction: replace with the other. For example:
V Q Q V V -> Q V V Q Q
G L L G L -> L G G L G
G 2 2 2 G -> 2 G G G 2
I I Z Z Z -> Z Z I I I

Instruction: replace with the other. For example:
R H H H R -> H R R R H
B 9 9 B B -> 9 B B 9 9
D 2 2 2 D -> 2 D D D 2
A A A A W -> W W W W A



In [368]:
sum(s.count('Yes') for s in input_strs)

370

In [322]:
train_dataset = CHILDDataset(input_strs[:-n_valid], tokenizer)
eval_dataset = CHILDDataset(input_strs[-n_valid:], tokenizer)

In [121]:
if n_total == 1:
    inputs = tokenizer.encode_plus(text, return_tensors='pt')
    inputs = prepare_inputs(inputs, model.device)
    outputs = model(**inputs, output_attentions=False)

    # assert inputs.input_ids.size(0) == 1
    input_ids = inputs.input_ids
    logits = outputs.logits

    bsz = input_ids.size(0); assert bsz == 1
    labels = torch.ones_like(input_ids) * (-100)
    for bi in range(bsz):
        bos_indices = (input_ids[bi] == bos_id).nonzero().squeeze(1)
        eos_indices = (input_ids[bi] == eos_id).nonzero()[-nrows:].squeeze(1)
        for i, (example, bos_i, eos_i) in enumerate(zip(examples, bos_indices.tolist(), eos_indices.tolist())):
            print(' ' + make_example_str(example))
            ans_ids = input_ids[bi, bos_i + 1: eos_i]
            if i >= 2: labels[bi, bos_i: eos_i - 1] = ans_ids
            ans_prob_dist = logits[bi, bos_i: eos_i - 1].softmax(-1)
            ans_probs = ans_prob_dist[torch.arange(ans_prob_dist.size(0)), ans_ids]
            ans_tokens = tokenizer.convert_ids_to_tokens(ans_ids)
            for ans_id, ans_token, ans_prob, dist in zip(ans_ids, ans_tokens, numpy(ans_probs, decimals=3), ans_prob_dist):
                top1_correct = (dist.argmax() == ans_id).item()
                print(('*' if top1_correct else ' ') + ans_token, ans_prob, 
                      show_topk(*dist.topk(5), indices_fn=tokenizer.convert_ids_to_tokens)) 
    loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss

In [329]:
training_args = TrainingArguments(output_dir="./models/model_name", 
    overwrite_output_dir=True, do_train=True, do_eval=True,
    per_device_train_batch_size=16, per_device_eval_batch_size=16,
    weight_decay=0.01, adam_beta2=0.98, adam_epsilon=1e-6,
    lr_scheduler_type='constant', learning_rate=5e-3, num_train_epochs=4,
    logging_strategy ='epoch', evaluation_strategy ='epoch', save_steps=0,
    no_cuda=True, report_to='none',  # to avoid report to wandb
)

In [330]:
trainer = Trainer(model, training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,
                  optimizers=(create_optimizer(model, training_args), None))

In [333]:
trainer.place_model_on_device

True

In [None]:
def get_prev(elem):
    i, v = elem
    return _l[i - 1] if i > 0 else None

false = lambda *_: False
true  = lambda *_: True

In [None]:
Element = namedtuple('Element', 'index value')
_l = 'A B C B'.split()
n = len(_l)
# l = [Element._make(e) for e in enumerate(l)]
l = seq(_l)
l = l.enumerate().map(Element._make)

In [6]:
l.map(lambda x: {'B': 'D'}.get(x, x))

l.filter(lambda x: get_prev(x) == 'B').select(_.value)

find_fn = _.index == 1
l.filter(find_fn).select(_.value).map(lower)

find_fn = _.value == 'C'
l.filter(find_fn).select(_.index)

# move x to first
update_filter = _.value == 'C'
get_new = lambda x: -1
l.map(lambda x: Element(update_fn(x, 'index'), x.value)).order_by(_.index).select(_.value)

# swap first and last
update_filter = true
get_new = lambda x: {0: n - 1, n - 1: 0}.get(x.index, x.index)
l.map(lambda x: Element(update_fn(x, 'index'), x.value)).order_by(_.index).select(_.value)

# get inbetween == drop_while + take_while?

# update by index to its prev
update_filter = _.index == 1
get_new = lambda x: get_prev(x)
def update_fn(x, update_field): return get_new(x) if update_filter(x) else getattr(x, update_field)
l.map(lambda x: Element(x.index, update_fn(x, 'value')))

# if two adjacent elements by indices are equal
l.filter(lambda x: x.index in [0, 1]).select(_.value).distinct().len() == 1

seq('A B C B C'.split()).group_by(_).select(_[1]).flatten()

# count occurance till current
seq('A B A C B A'.split()).inits().reverse().tail().map(lambda x: x.filter(_ == x.last()).len())

# find special
seq('A B A A'.split()).count_by_value().filter(_[1] == 1).select(_[0])

# generalized find special
seq('A A B C C D D'.split()).group_by(_).map(lambda x: (x[0], len(x[1]))).filter(_[1] == 1).select(_[0])