In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [187]:
import argparse
import os
import random

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle

from analysis import rocstories as rocstories_analysis
from datasets import rocstories
from model_pytorch import DoubleHeadModel, load_openai_pretrained_model
from opt import OpenAIAdam
from text_utils import TextEncoder
from utils import (encode_dataset, iter_data,
                   ResultLogger, make_path, np_softmax)
from loss import LMLossCompute

# Helpers

In [3]:


def transform_roc(X1, X2, X3):
    """pad and crop sequences"""
    n_batch = len(X1)
    xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
    mmb = np.zeros((n_batch, 2, n_ctx), dtype=np.float32)
    start = encoder['_start_']
    delimiter = encoder['_delimiter_']
    for i, (x1, x2, x3), in enumerate(zip(X1, X2, X3)):
        x12 = [start] + x1[:max_len] + [delimiter] + x2[:max_len] + [clf_token]
        x13 = [start] + x1[:max_len] + [delimiter] + x3[:max_len] + [clf_token]
        l12 = len(x12)
        l13 = len(x13)
        xmb[i, 0, :l12, 0] = x12
        xmb[i, 1, :l13, 0] = x13
        mmb[i, 0, :l12] = 1
        mmb[i, 1, :l13] = 1
    # Position information that is added to the input embeddings in the TransformerModel
    xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
    return xmb, mmb


# def iter_apply(Xs, Ms, Ys):
#     # fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
#     logits = []
#     cost = 0
#     with torch.no_grad():
#         dh_model.eval()
#         for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
#             n = len(xmb)
#             XMB = torch.tensor(xmb, dtype=torch.long).to(device)
#             YMB = torch.tensor(ymb, dtype=torch.long).to(device)
#             MMB = torch.tensor(mmb).to(device)
#             lm_logits, clf_logits = dh_model(XMB)
#             clf_logits *= n
#             clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
#             clf_losses *= n
#             logits.append(clf_logits.to("cpu").numpy())
#             cost += clf_losses.sum().item()
#         logits = np.concatenate(logits, 0)
#     return logits, cost


def log(save_dir, desc):
    global best_score
    print("Logging")
#     tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
#     va_logits, va_cost = iter_apply(vaX, vaM, vaY)
#     tr_cost = tr_cost / len(trY[:n_valid])
#     va_cost = va_cost / n_valid
#     tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.
#     va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
    logger.log(n_epochs=n_epochs, n_updates=n_updates)#, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
    print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates))#, tr_cost, va_cost, tr_acc, va_acc))
#     if submit:
#         score = va_acc
#         if score > best_score:
#             best_score = score
#             path = os.path.join(save_dir, desc, 'best_params')
#             torch.save(dh_model.state_dict(), make_path(path))

def run_epoch():
    for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
                                   n_batch=n_batch_train, truncate=True, verbose=True):
        global n_updates
        dh_model.train()
        XMB = torch.tensor(xmb, dtype=torch.long).to(device)
        YMB = torch.tensor(ymb, dtype=torch.long).to(device)
        MMB = torch.tensor(mmb).to(device)
        lm_logits, _ = dh_model(XMB)
        compute_loss_fct(XMB, YMB, MMB, lm_logits)
        n_updates += 1
        if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
            log(save_dir, desc)



# Params

In [4]:

argmax = lambda x: np.argmax(x, 1)

pred_fns = {
    'rocstories': argmax,
}

filenames = {
    'rocstories': 'ROCStories.tsv',
}

label_decoders = {
    'rocstories': None,
}


In [207]:

parser = argparse.ArgumentParser()
parser.add_argument('--desc', type=str, help="Description")
parser.add_argument('--dataset', type=str)
parser.add_argument('--log_dir', type=str, default='log/')
parser.add_argument('--save_dir', type=str, default='save/')
parser.add_argument('--data_dir', type=str, default='data/')
parser.add_argument('--submission_dir', type=str, default='submission/')
parser.add_argument('--submit', action='store_true')
parser.add_argument('--analysis', action='store_true')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--n_iter', type=int, default=3)
parser.add_argument('--n_batch', type=int, default=8)
parser.add_argument('--max_grad_norm', type=int, default=1)
parser.add_argument('--lr', type=float, default=6.25e-5)
parser.add_argument('--lr_warmup', type=float, default=0.002)
parser.add_argument('--n_ctx', type=int, default=512)
parser.add_argument('--n_embd', type=int, default=768)
parser.add_argument('--n_head', type=int, default=12)
parser.add_argument('--n_layer', type=int, default=12)
parser.add_argument('--embd_pdrop', type=float, default=0.1)
parser.add_argument('--attn_pdrop', type=float, default=0.1)
parser.add_argument('--resid_pdrop', type=float, default=0.1)
parser.add_argument('--clf_pdrop', type=float, default=0.1)
parser.add_argument('--l2', type=float, default=0.01)
parser.add_argument('--vector_l2', action='store_true')
parser.add_argument('--opt', type=str, default='adam')
parser.add_argument('--afn', type=str, default='gelu')
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
parser.add_argument('--n_transfer', type=int, default=12)
parser.add_argument('--lm_coef', type=float, default=0.5)
parser.add_argument('--b1', type=float, default=0.9)
parser.add_argument('--b2', type=float, default=0.999)
parser.add_argument('--e', type=float, default=1e-8)
parser.add_argument('--n_valid', type=int, default=374)


args = parser.parse_args('''
--dataset data/corpus/erotic_gutenberg.csv 
--n_batch 2 
--submit 
--n_iter 15
'''.replace('\n','').split(' '))
print(args)

Namespace(afn='gelu', analysis=False, attn_pdrop=0.1, b1=0.9, b2=0.999, bpe_path='model/vocab_40000.bpe', clf_pdrop=0.1, data_dir='data/', dataset='data/corpus/erotic_gutenberg.csv', desc=None, e=1e-08, embd_pdrop=0.1, encoder_path='model/encoder_bpe_40000.json', l2=0.01, lm_coef=0.5, log_dir='log/', lr=6.25e-05, lr_schedule='warmup_linear', lr_warmup=0.002, max_grad_norm=1, n_batch=2, n_ctx=512, n_embd=768, n_head=12, n_iter=3, n_layer=12, n_transfer=12, n_valid=374, opt='adam', resid_pdrop=0.1, save_dir='save/', seed=42, submission_dir='submission/', submit=True, vector_l2=False)


# Init

In [208]:
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Constants
submit = args.submit
dataset = args.dataset
n_ctx = args.n_ctx
save_dir = args.save_dir
desc = args.desc
data_dir = args.data_dir
log_dir = args.log_dir
submission_dir = args.submission_dir

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
print("device", device, "n_gpu", n_gpu)

device cuda n_gpu 1


In [210]:
logger = ResultLogger(
    path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)

# bpe tokenizer BYTE PAIR ENCODING https://en.wikipedia.org/wiki/Byte_pair_encoding
# this is compression where we replace frequent pairs with unused byte codes
text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
encoder = text_encoder.encoder
n_vocab = len(text_encoder.encoder)

# Data

In [211]:
print("Encoding dataset...")
((trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY),
 (teX1, teX2, teX3)) = encode_dataset(
     *rocstories(data_dir, n_valid=args.n_valid), encoder=text_encoder)

encoder['_start_'] = len(encoder)
encoder['_delimiter_'] = len(encoder)
encoder['_classify_'] = len(encoder)
clf_token = encoder['_classify_']
n_special = 3
max_len = n_ctx // 2 - 2

n_ctx = min(
    max([
        len(x1[:max_len]) + max(len(x2[:max_len]), len(x3[:max_len]))
        for x1, x2, x3 in zip(trX1, trX2, trX3)
    ] + [
        len(x1[:max_len]) + max(len(x2[:max_len]), len(x3[:max_len]))
        for x1, x2, x3 in zip(vaX1, vaX2, vaX3)
    ] + [
        len(x1[:max_len]) + max(len(x2[:max_len]), len(x3[:max_len]))
        for x1, x2, x3 in zip(teX1, teX2, teX3)
    ]) + 3, n_ctx)
vocab = n_vocab + n_special + n_ctx

trX, trM = transform_roc(trX1, trX2, trX3)
vaX, vaM = transform_roc(vaX1, vaX2, vaX3)

if submit:
    teX, teM = transform_roc(teX1, teX2, teX3)

n_train = len(trY)
n_valid = len(vaY)
n_batch_train = args.n_batch * max(n_gpu, 1)
n_updates_total = (n_train // n_batch_train) * args.n_iter
dict(n_train=n_train, n_valid=n_valid, n_vocab=n_vocab, n_batch_train=n_batch_train, n_updates_total=n_updates_total)

  0%|                                                  | 0/2402 [00:00<?, ?it/s]

Encoding dataset...


                                                                                

{'n_batch_train': 2,
 'n_train': 2402,
 'n_updates_total': 3603,
 'n_valid': 374,
 'n_vocab': 40478}

In [212]:
trX.shape

(2402, 2, 349, 2)

# Model, loss, opt

In [11]:
# model
dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)

# loss, optimizer
criterion = nn.CrossEntropyLoss(reduce=False)
model_opt = OpenAIAdam(
    dh_model.parameters(),
    lr=args.lr,
    schedule=args.lr_schedule,
    warmup=args.lr_warmup,
    t_total=n_updates_total,
    b1=args.b1,
    b2=args.b2,
    e=args.e,
    l2=args.l2,
    vector_l2=args.vector_l2,
    max_grad_norm=args.max_grad_norm)
compute_loss_fct = LMLossCompute(criterion, model_opt)
## move up?
# load pretrained model
load_openai_pretrained_model(
    dh_model.transformer, n_ctx=n_ctx, n_special=n_special)

dh_model.to(device)
dh_model = nn.DataParallel(dh_model)




Loading weights...


In [12]:
n_updates = 0
n_epochs = 0
if dataset != 'stsb':
    trYt = trY

# save params
if submit:
    path = os.path.join(save_dir, desc, 'best_params')
    torch.save(dh_model.state_dict(), make_path(path))

# Run

In [13]:
# run
best_score = 0
for i in range(args.n_iter):
    print("running epoch", i)
    run_epoch()
    n_epochs += 1
    log(save_dir, desc)

  0%|                                                  | 0/1201 [00:00<?, ?it/s]

running epoch 0


 83%|█████████████████████████████████▎      | 999/1201 [07:47<01:35,  2.12it/s]
  0%|                                                   | 0/187 [00:00<?, ?it/s][A
  1%|▏                                          | 1/187 [00:00<00:21,  8.55it/s][A

Logging



  1%|▍                                          | 2/187 [00:00<00:22,  8.33it/s][A
  2%|▋                                          | 3/187 [00:00<00:22,  8.25it/s][A
  2%|▉                                          | 4/187 [00:00<00:22,  8.12it/s][A
  3%|█▏                                         | 5/187 [00:00<00:22,  8.11it/s][A
  3%|█▍                                         | 6/187 [00:00<00:22,  8.04it/s][A
  4%|█▌                                         | 7/187 [00:00<00:22,  8.01it/s][A
  4%|█▊                                         | 8/187 [00:00<00:22,  8.05it/s][A
  5%|██                                         | 9/187 [00:01<00:21,  8.15it/s][A
  5%|██▏                                       | 10/187 [00:01<00:21,  8.22it/s][A
  6%|██▍                                       | 11/187 [00:01<00:21,  8.26it/s][A
  6%|██▋                                       | 12/187 [00:01<00:21,  8.30it/s][A
  7%|██▉                                       | 13/187 [00:01<00:20,  8.38

 53%|██████████████████████▏                   | 99/187 [00:12<00:10,  8.22it/s][A
 53%|█████████████████████▉                   | 100/187 [00:12<00:10,  8.29it/s][A
 54%|██████████████████████▏                  | 101/187 [00:12<00:10,  8.27it/s][A
 55%|██████████████████████▎                  | 102/187 [00:12<00:10,  8.36it/s][A
 55%|██████████████████████▌                  | 103/187 [00:12<00:10,  8.32it/s][A
 56%|██████████████████████▊                  | 104/187 [00:12<00:09,  8.30it/s][A
 56%|███████████████████████                  | 105/187 [00:12<00:09,  8.29it/s][A
 57%|███████████████████████▏                 | 106/187 [00:13<00:09,  8.34it/s][A
 57%|███████████████████████▍                 | 107/187 [00:13<00:09,  8.10it/s][A
 58%|███████████████████████▋                 | 108/187 [00:13<00:09,  8.12it/s][A
 58%|███████████████████████▉                 | 109/187 [00:13<00:09,  8.14it/s][A
 59%|████████████████████████                 | 110/187 [00:13<00:09,  8.26i

  4%|█▌                                         | 7/187 [00:00<00:22,  8.17it/s][A
  4%|█▊                                         | 8/187 [00:00<00:21,  8.15it/s][A
  5%|██                                         | 9/187 [00:01<00:21,  8.28it/s][A
  5%|██▏                                       | 10/187 [00:01<00:21,  8.37it/s][A
  6%|██▍                                       | 11/187 [00:01<00:21,  8.25it/s][A
  6%|██▋                                       | 12/187 [00:01<00:21,  7.96it/s][A
  7%|██▉                                       | 13/187 [00:01<00:21,  8.00it/s][A
  7%|███▏                                      | 14/187 [00:01<00:21,  8.16it/s][A
  8%|███▎                                      | 15/187 [00:01<00:20,  8.25it/s][A
  9%|███▌                                      | 16/187 [00:01<00:20,  8.15it/s][A
  9%|███▊                                      | 17/187 [00:02<00:20,  8.12it/s][A
 10%|████                                      | 18/187 [00:02<00:20,  8.05i

 56%|██████████████████████▊                  | 104/187 [00:12<00:09,  8.44it/s][A
 56%|███████████████████████                  | 105/187 [00:12<00:09,  8.48it/s][A
 57%|███████████████████████▏                 | 106/187 [00:12<00:09,  8.41it/s][A
 57%|███████████████████████▍                 | 107/187 [00:12<00:09,  8.42it/s][A
 58%|███████████████████████▋                 | 108/187 [00:13<00:09,  8.43it/s][A
 58%|███████████████████████▉                 | 109/187 [00:13<00:09,  8.44it/s][A
 59%|████████████████████████                 | 110/187 [00:13<00:09,  8.28it/s][A
 59%|████████████████████████▎                | 111/187 [00:13<00:09,  7.99it/s][A
 60%|████████████████████████▌                | 112/187 [00:13<00:09,  7.82it/s][A
 60%|████████████████████████▊                | 113/187 [00:13<00:09,  7.79it/s][A
 61%|████████████████████████▉                | 114/187 [00:13<00:09,  7.99it/s][A
 61%|█████████████████████████▏               | 115/187 [00:13<00:09,  7.95i

0 1000 0.000 0.000 100.00 100.00


  1%|▏                                          | 1/187 [00:00<00:21,  8.57it/s]

Logging


  0%|                                                  | 0/1201 [00:00<?, ?it/s]

1 1201 0.000 0.000 100.00 100.00
running epoch 1


  1%|▏                                          | 1/187 [00:00<00:21,  8.62it/s]

Logging


  0%|                                                  | 0/1201 [00:00<?, ?it/s]

2 2402 0.000 0.000 100.00 100.00
running epoch 2


  1%|▏                                          | 1/187 [00:00<00:21,  8.58it/s]

Logging


                                                                                

3 3603 0.000 0.000 100.00 100.00




# Test

In [213]:

def iter_predict(Xs, Ms):
    logits = []
    with torch.no_grad():
        dh_model.eval()
        for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
            n = len(xmb)
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
#             MMB = torch.tensor(mmb).to(device)
            lm_logits, _ = dh_model(XMB)
#             print(lm_logits)
#             return lm_logits.to("cpu").numpy()
            logits.append(lm_logits.to("cpu").numpy())
#     logits = np.concatenate(logits, 0)
    logits = np.stack(logits, 0)
    return logits



In [214]:
lm_logits = iter_predict(teX[:10], teM[:10])
lm_logits.shape

                                                                                

(5, 1392, 40830)

In [233]:
encoder['_start_'], encoder['_delimiter_'], clf_token

(40478, 40479, 40480)

In [243]:
convert = [    
    ['</w><unk>', '<unk>'],
    ['.<unk>', '<unk>'],
    ['"<unk>', '<unk>'],
    ['<unk></w>', '<unk>'],
    ['<unk>.', '<unk>'],
    ['<unk>"', '<unk>'],
    ['<unk>', ''],
    ['</w>,', ','],
    ["</w>'", "'"],
    ['</w>.', '.'],
    ['</w>', ' '],
    ['"</w>', ''],
    ['_delimiter_', '\n'],
#     ['_classify_', '\n'],
    ['_start_', '\n']
    
]

In [244]:
decoder = {v:k for k,v in text_encoder.encoder.items()}
temperature = 0.5 # 1 is quite random, 0 is the most likely letter every time
for batch in range(probs.shape[0]):
    probs=np_softmax(lm_logits[batch], t=temperature) # softmax with temperature

    dist = torch.distributions.Multinomial(probs=torch.from_numpy(probs)) # make distribution
    y_encoded = dist.sample().argmax(-1).numpy() # sample
    y_text = [decoder.get(i, "<unk:{}>".format(i)) for i in y_encoded] # decode
    y_string = '\n'+''.join(y_text) # join into text

    # clean up tokens
    y_string_raw = str(y_string)
    for a, b in convert:
        y_string = y_string.replace(a, b)

    print(y_string)


"., a same, a his silk he with evcast the trees forest he the h and to he is the moral's into way.. 
1 _classify_1 " " ,;, ..,. ., .. . .,. ,,,,,,,,,, ., a same, a his blue she with same lengthened the chapel, he the h, to he, the man's into the.. 
. _classify_. 2 19.. ., ., .. ;. .  . , . .   ,,,
,;,,,... 
1 _classify_1 _classify__classify__,,,,a... 
. _classify_

" 
 

" object and is _ the emerson the _, which has us fuller resemblance impression of the love of is has envy tolerance have caused about the race race, is no yet therto been written described. elaborated. to the original. it has to convey an more description of the the has existed and decent thinminded people beings have done offer in. notions about prejudices. order with the prejudices gnance. but see here poem of the very who lost so " and his young desire strong man to the weak, a - bitter, selfish, and acious, a mother ments of innocent children and his of his children ; the was a and cold, he did him but felt, were

IndexError: index 5 is out of bounds for axis 0 with size 5