Language model code from:
    https://github.com/rodgzilla/pytorch-openai-transformer-lm/blob/horoscope_language_model

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Imports

In [2]:
import os
import pandas as pd
import pdb
import argparse
import itertools
import datetime

import numpy as np

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F

from model_pytorch import TransformerModel, LMHead, load_openai_pretrained_model, DEFAULT_CONFIG
from model_pytorch import LanguageModel
from utils import encode_dataset, flatten, iter_data, ResultLogger, make_path
from text_utils import TextEncoder
from opt import OpenAIAdam
from loss import LanguageModelingLossCompute

In [3]:

n_updates  = 0
best_score = 0

# Helpers

# Data

In [4]:

def _chunk_word_list(word_list, max_sequence_len = 50000):
    # We have to split the text into text of 100.000 characters
    # because of the parser limitations.
    word_sequences    = [[]]
    last_sequence_len = 0
    for word in word_list:
        # If the last word list has reached the maximum size
        if last_sequence_len + len(word) > max_sequence_len:
            # We transform it into a string by rejoining the words
            word_sequences[-1] = ' '.join(word_sequences[-1])
            # and then begin a new word sequence
            word_sequences.append([])
            last_sequence_len = 0
        word_sequences[-1].append(word)
        last_sequence_len += len(word)

    if type(word_sequences[-1]) == list:
        word_sequences[-1] = ' '.join(word_sequences[-1])

    return word_sequences

def load_dataset(text_encoder, window_size, path = 'data/erotic_gutenberg_dataset.csv',
                 shuffle = True, seed = 142857,
                 test_size = 0.2):
    df             = pd.read_csv(path)
    all_text       = ' '.join(df.TEXT)
    word_list      = all_text.split(' ')
    word_sequences = _chunk_word_list(word_list, )
    encoded_text   = text_encoder.encode(word_sequences)
    word_idx_list  = list(itertools.chain.from_iterable(encoded_text))
    context_list   = []
    target_list    = []

    for start_idx in range(len(word_idx_list) - window_size - 1):
        context_list.append(word_idx_list[start_idx : start_idx + window_size])
        target_list.append(word_idx_list[start_idx + window_size])

    X_train, X_val, y_train, y_val = train_test_split(
        context_list,
        target_list,
        test_size    = test_size,
        shuffle      = shuffle,
        random_state = seed
    )
    return (X_train, y_train), (X_val, y_val)

def transform_dataset(dataset, encoder, max_len, n_vocab, n_special, n_ctx):
    n_batch   = len(dataset)
    xmb       = np.zeros((n_batch, n_ctx, 2), dtype = np.int32)
    mmb       = np.zeros((n_batch, n_ctx), dtype = np.float32)
    start     = encoder.encoder['_start_']
    clf_token = encoder.encoder['_classify_']
    for i, x in enumerate(dataset):
        x_with_tokens   = [start] + x[:max_len] + [clf_token]
        l_x             = len(x_with_tokens)
        xmb[i, :l_x, 0] = x_with_tokens
        mmb[i, :l_x]    = 1
    xmb[:, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)

    return xmb, mmb



## Data

In [5]:

def iter_apply(model, n_batch_train, device, compute_loss_fct, Xs, Ms, Ys, return_logits = True):
    if return_logits:
        logits = []
    cost = 0
    with torch.no_grad():
        model.eval()
        for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
            n = len(xmb)
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
            YMB = torch.tensor(ymb, dtype=torch.long).to(device)
            MMB = torch.tensor(mmb).to(device)
            lm_logits = model(XMB)
            lm_logits *= n
            lm_losses = compute_loss_fct(XMB, YMB, MMB, lm_logits, only_return_losses=True)
            lm_losses *= n
            if return_logits:
                logits.append(lm_logits.to("cpu").numpy())
            cost += lm_losses.sum().item()

    if return_logits:
        logits = np.concatenate(logits, 0)
        return logits, cost

    return cost


In [6]:

def decode_word(text_encoder, idx):
    if idx not in text_encoder.decoder:
        return '<oov>'

    word = text_encoder.decoder[idx]

    return word[:-4] if word[-4:] == '</w>' else word

def decode_sentence(text_encoder, idx_list):
    word_list = [decode_word(text_encoder, idx) for idx in idx_list]

    # Fix some weird grammer, but not all
    replace = [
        ["' ", "'"],
        [" '", "'"],
        [" ,", ","],
        [" .", "."],
        [" i ", " I "],
        [" n't", "n't"],
        [" ?", "?"],
    ]
    results2 = ' '.join(word_list)
    for a,b in replace:
        results2 = results2.replace(a, b)

    return results2

def try_on_a_sentence(model, text_encoder, sentence, window_size,
                      n_vocab, n_special, n_ctx, device,
                      final_len = 200):
    model.eval()
    start_token  = text_encoder.encoder['_start_']
    clf_token    = text_encoder.encoder['_classify_']
    encoded_text = text_encoder.encode([sentence])[0]
    while len(encoded_text) < final_len:
        # We take the last 'window_size' words of the text being generated
        # and run it through the model.
        context         = encoded_text[-window_size:]
        X_trans, X_mask = transform_dataset(
            [context],
            text_encoder,
            window_size,
            n_vocab,
            n_special,
            n_ctx
        )
        XMB                = torch.tensor(X_trans, dtype = torch.long).to(device)
        lm_logits          = model(XMB)
        
        # We truncate the resulting predictions to actual vocabulary
        # words in order to exclude special tokens and positional
        # embeddings.
        lm_logits          = lm_logits[:, : n_vocab]
        
        # We then select the logit corresponding to the 'clf_token'
        # position (last one of the sequence).
        X_trans_tensor     = torch.from_numpy(X_trans)
        clf_token_bool_idx = X_trans_tensor[0, :, 0] == clf_token
        
        # probabilistic sample so we don't get into loops
        predictions = torch.distributions.Multinomial(logits=lm_logits).sample().argmax(dim = 1)
        pred               = predictions[clf_token_bool_idx[1:]].item()
        encoded_text.append(pred)

    return decode_sentence(text_encoder, encoded_text)



## Run

In [7]:

def run_epoch(model, n_batch_train, device, compute_loss_fct, logger,
              save_dir, desc, submit, n_valid, n_epochs, X_train,
              X_train_mask, y_train, X_val, X_val_mask, y_val,
              generation_params):
    for xmb, mmb, ymb in iter_data(X_train,
                                   X_train_mask,
                                   y_train,
                                   n_batch = n_batch_train,
                                   truncate=True,
                                   verbose=True):
        global n_updates
        model.train()
        XMB        = torch.tensor(xmb, dtype=torch.long).to(device)
        YMB        = torch.tensor(ymb, dtype=torch.long).to(device)
        MMB        = torch.tensor(mmb).to(device)
        lm_logits  = model(XMB)
        compute_loss_fct(XMB, YMB, MMB, lm_logits)
        if n_updates % 500 == 0:
            log(
                model,
                n_batch_train,
                device,
                compute_loss_fct,
                logger,
                save_dir,
                desc,
                submit,
                n_valid,
                n_epochs,
                n_updates,
                X_train,
                X_train_mask,
                y_train,
                X_val,
                X_val_mask,
                y_val,
                generation_params
            )
        n_updates += 1

def log(model, n_batch_train, device, compute_loss_fct, logger,
        save_dir, desc, submit, n_valid, n_epochs, n_updates, X_train,
        X_train_mask, y_train, X_val, X_val_mask, y_val,
        generation_params):
    global best_score
    result = try_on_a_sentence(**generation_params)
    print("\n\n Base: {} \n\n Result: {}".format(generation_params['sentence'], result))
    print("\nLogging")
    tr_cost = iter_apply(
        model,
        n_batch_train,
        device,
        compute_loss_fct,
        X_train[:n_valid],
        X_train_mask[:n_valid],
        y_train[:n_valid],
        False
    )
    va_cost = iter_apply(
        model,
        n_batch_train,
        device,
        compute_loss_fct,
        X_val,
        X_val_mask,
        y_val,
        False
    )
    tr_cost = tr_cost / len(y_train[:n_valid])
    va_cost = va_cost / n_valid
    logger.log(
        n_epochs  = n_epochs,
        n_updates = n_updates,
        tr_cost   = tr_cost,
        va_cost   = va_cost
    )
    print('\n%d %d %.3f %.3f' % (n_epochs, n_updates, tr_cost, va_cost))
    if submit:
        score = va_cost
        if score > best_score:
            best_score = score
            path = os.path.join(save_dir, desc, 'best_params')
            torch.save(model.state_dict(), make_path(path))

# Params

In [8]:
# Training configuration
epochs                             = 3
n_batch_train                      = 12
window_size                        = 128
max_len                            = window_size
# General configuration
save_dir                           = 'save/'
log_dir                            = 'log/'
desc                               = 'erotic_gutenberg'
submit                             = True
args                               = DEFAULT_CONFIG
logger                             = ResultLogger(
    path = os.path.join(
        log_dir,
        '{}.jsonl'.format(desc)
    ),
    **args.__dict__
)
device                             = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bpe_path                           = 'model/vocab_40000.bpe'
encoder_path                       = 'model/encoder_bpe_40000.json'
data_path                          = 'data/erotic_gutenberg_dataset.csv'
text_encoder                       = TextEncoder(encoder_path, bpe_path)
encoder                            = text_encoder.encoder
n_special                          = 2
n_vocab                            = len(encoder)
encoder['_start_']                 = len(encoder)
encoder['_classify_']              = len(encoder)
clf_token                          = encoder['_classify_']

n_ctx                              = window_size + n_special
total_vocab_size                   = n_vocab + n_special + n_ctx

args, dict(n_ctx=n_ctx, total_vocab_size=total_vocab_size, n_special=n_special)

({'afn': 'gelu',
  'attn_pdrop': 0.1,
  'clf_pdrop': 0.1,
  'embd_pdrop': 0.1,
  'n_embd': 768,
  'n_head': 12,
  'n_layer': 12,
  'resid_pdrop': 0.1},
 {'n_ctx': 130, 'n_special': 2, 'total_vocab_size': 40610})

# Dataset

In [9]:

(X_train, y_train), (X_val, y_val) = load_dataset(
    text_encoder,
    window_size = window_size,
    path        = data_path
)
n_train                     = len(y_train)
n_valid                     = len(y_val) // 10
n_updates_total             = (n_train // n_batch_train) * epochs

X_train_trans, X_train_mask = transform_dataset(
    X_train,
    text_encoder,
    window_size,
    n_vocab,
    n_special,
    n_ctx
)
X_val_trans, X_val_mask = transform_dataset(
    X_val,
    text_encoder,
    window_size,
    n_vocab,
    n_special,
    n_ctx
)
X_train_trans.shape, X_train_mask.shape

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=151), HTML(value='')), layout=Layout(display=…



((1644206, 130, 2), (1644206, 130))

# Model

In [10]:
language_model = LanguageModel(
    args,
    vocab = total_vocab_size,
    n_ctx = n_ctx
)
load_openai_pretrained_model(
    language_model.transformer,
    n_ctx = n_ctx,
    n_special = n_special
)
language_model.to(device)
1

Loading weights...


1

In [11]:
save_path = 'model/{}.pkl'.format(desc)
if os.path.isfile(save_path):
    state_dict = torch.load(save_path)
    language_model.load_state_dict(state_dict)
    print('loaded', save_path)

# init opt, loss

In [12]:
model_opt = OpenAIAdam(
    params        = language_model.parameters(),
    lr            = 6.25e-5,
    schedule      = 'warmup_linear',
    warmup        = 0.002,
    t_total       = n_updates_total,
    b1            = 0.9,
    b2            = 0.999,
    e             = 1e-8,
    l2            = 0.01,
    vector_l2     = 'store_true',
    max_grad_norm = 1
)
criterion        = nn.CrossEntropyLoss(reduce = False)
compute_loss_fct = LanguageModelingLossCompute(
    lm_criterion = criterion,
    opt = model_opt
)

generation_parameters = {
    'model'        : language_model,
    'text_encoder' : text_encoder,
    'sentence'     : 'You had a great morning but your afternoon will be ruined because',
    'window_size'  : window_size,
    'n_vocab'      : n_vocab,
    'n_special'    : n_special,
    'n_ctx'        : n_ctx,
    'device'       : device,
    'final_len'    : 150
}



# run

In [None]:
for epoch in range(epochs):
    run_epoch(
        model             = language_model,
        n_batch_train     = n_batch_train,
        device            = device,
        compute_loss_fct  = compute_loss_fct,
        logger            = logger,
        save_dir          = save_dir,
        desc              = desc,
        submit            = submit,
        n_valid           = n_valid,
        n_epochs          = epoch,
        X_train           = X_train_trans,
        X_train_mask      = X_train_mask,
        y_train           = y_train,
        X_val             = X_val_trans,
        X_val_mask        = X_val_mask,
        y_val             = y_val,
        generation_params = generation_parameters
    )
    torch.save(language_model.state_dict(), save_path)
    
    ts = datetime.datetime.utcnow().strftime('%Y%m%d_%H-%M-%S')
    torch.save(language_model.state_dict(), save_path.replace('.pkl','_{}_{}.pkl'.format(epoc, ts)))

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=137017), HTML(value='')), layout=Layout(displ…

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1), HTML(value='')), layout=Layout(display='i…



 Base: You had a great morning but your afternoon will be ruined because 

 Result: you had a great morning but your afternoon will be ruined because you don't give me exactly how I want it. get it? " 
 he stopped abruptly just short of a collision. " my entire evening will be ruined because of you. " 
 " merrill, wait... " 
 " the itch to get her away from you is growing ever since she went to bed last night and begged me to get you before night - that will never happen. " 
 " okay ! all right. " 
 gavin grinned. lissa snorted at that. hu ck didn't see it that way. 
 * * * 
 the dance was beautiful. I loved dancing with humans. xenides breathed deeply, savoring the intoxicating scent as he stalked kifirin. his opponent was older

Logging


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3425), HTML(value='')), layout=Layout(display…

In [None]:
sentences = [
    'i want you to want',
    'please help me',
    'let us run far away from',
    'rosy',
    'that unspeakable creature'
    'when can I see you',
    'I must',
    'gaze at your enhanting',
]

In [None]:
for sentence in sentences:
    generation_parameters['sentence'] = sentence
    result = try_on_a_sentence(**generation_parameters)
    print("\n\n Base: {} \n\n Result: {}".format(generation_parameters['sentence'], result))

# DEBUG check for produced string in source text

In [None]:
input_data = open(data_path).read().lower()

In [None]:
last_i = 1
while last_i>0:
    i = input_data[last_i+50:].index('their minds') + last_i+50
    print(input_data[i-10:i+50])
    last_i=i