# Important Note:
1. While the batched versions of all generation functions are runnable (and more efficient), they pad before the sequence (pre-sequence padding), so you may get worse results when using them. We recommend setting `batch_size` to 1 if you would like to run the code without any further modification / optimization.
2. The rhyming-constrained functions are also runnable, but they are much less efficient than simply generating limericks without constraints and filtering them in post-processing. Also, the rhyming words are limited to the words provided by the `pronouncing` package.

In [1]:
# Start by installing required libraries (mainly Transformers)
!pip install transformers==4.17.0
!pip install scikit-learn
!pip install hydra-core
!pip install pronouncing

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [None]:
# Only needed when running in colab
from google.colab import drive
drive.mount("/content/drive/", force_remount=True)

In [None]:
!git clone https://{your_own_token}@github.com/coderalo/11785-automatic-poetry-generation.git

In [2]:
import copy
import glob
import json
import math
import numpy as np
import os
import pronouncing
import random
import shutil
import string as string_utils
import sys
import tempfile
import torch
import torch.optim as optim
import tqdm.notebook as tqdm
import yaml

from hydra import compose
from hydra import initialize_config_dir
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM
from transformers import GPT2LMHeadModel
from transformers import GPT2Model
from transformers import GPT2Tokenizer

In [3]:
%load_ext autoreload
%autoreload 2

# sys.path.append("/content/11785-automatic-poetry-generation/")

# from src.dataset import merge_lines, reorder, reverse_line
# from src.dataset import LimerickDataset
# from src.utils import load_dataset, get_tokenizer

from dataset import merge_lines, reorder, reverse_line
from dataset import LimerickDataset
from utils import load_dataset, get_tokenizer

In [4]:
def get_input_ids(
        prompt,
        tokenizer,
        use_bos,
        reverse,
        add_line_token
):
    """
    Normalise a textual prompt and tokenize it into a single-row batch.

    Arguments:
        prompt: str, surrounding whitespace is stripped first
        tokenizer: callable tokenizer; its result must expose `.input_ids`
        use_bos: bool, prepend "<BOS> " unless the prompt already starts
            with "<BOS>"
        reverse: bool, when True the token ids are passed through
            `reverse_line` (with `reverse_last_line=True`)
        add_line_token: bool, append " <LINE>" to a non-empty prompt that
            does not already end with "<LINE>"
    Return:
        input_ids: whatever the tokenizer yields for `return_tensors="pt"`
            when `reverse` is not True, otherwise a tensor reshaped to
            (1, -1) built from the reversed ids
    """
    text = prompt.strip()

    wants_line_token = (
        add_line_token and text != "" and not text.endswith("<LINE>"))
    if wants_line_token:
        text = text + " <LINE>"

    if use_bos and not text.startswith("<BOS>"):
        text = "<BOS> " + text

    # Forward order: the tokenizer output is already in the right shape.
    if reverse is not True:
        return tokenizer(text, return_tensors="pt").input_ids

    # Reverse order: tokenize to numpy, flip, then rebuild a one-row batch.
    forward_ids = tokenizer(text, return_tensors="np").input_ids[0]
    flipped_ids = reverse_line(
        input_ids=forward_ids,
        use_bos=use_bos,
        tokenizer=tokenizer,
        reverse_last_line=True)
    return torch.tensor(flipped_ids).reshape(1, -1)

In [5]:
def batch_decode(
        outputs,
        tokenizer,
        use_bos,
        reverse,
        reverse_last_line
):
    """
    Decode generated token sequences back into strings.

    The sequences are handed to the tokenizer as plain lists of ints rather
    than one stacked tensor, so sequences of different lengths (e.g. coming
    from different generation batches) and an empty `outputs` list are
    both handled.

    Arguments:
        outputs: List of 1-D torch.LongTensor
        tokenizer: the tokenizer used to decode tokens to words
        use_bos: bool, whether the <BOS> token is used or not
        reverse: bool, whether the tokens are in reverse order or not
        reverse_last_line: bool, forwarded to `reverse_line` when `reverse`
            is True
    Return:
        List of str, one per sequence; special tokens are kept
    """
    if reverse is True:
        sequences = [
            torch.tensor(
                reverse_line(
                    input_ids=output.cpu().numpy(),
                    use_bos=use_bos,
                    tokenizer=tokenizer,
                    reverse_last_line=reverse_last_line)
            ).reshape(-1).tolist()
            for output in outputs]
    else:
        sequences = [output.cpu().tolist() for output in outputs]

    return tokenizer.batch_decode(sequences, skip_special_tokens=False)

In [6]:
def count_lines(prompt):
    """Return the number of "<LINE>" markers contained in `prompt`."""
    return prompt.count("<LINE>")


def lengths_to_mask(lengths, dtype, device, position="pos"):
    """
    Build a (len(lengths), max(lengths)) padding mask from sequence lengths.

    Arguments:
        lengths: 1-D integer tensor of sequence lengths
        dtype: dtype of the returned mask
        device: device of the returned mask
        position: "pos" marks the first `length` columns of each row
            (padding at the end); any other value marks the last `length`
            columns (padding at the front)
    Return:
        mask: tensor whose marked entries are true / 1 and the rest 0
    """
    width = lengths.max().item()

    if position == "pos":
        # column indices 0 .. width-1
        columns = torch.arange(
            width, dtype=lengths.dtype, device=lengths.device)
    else:
        # column indices width-1 .. 0, which right-aligns the marked cells
        columns = torch.arange(
            width - 1, -1, -1, dtype=lengths.dtype, device=lengths.device)

    # [1, width] < [n, 1] broadcasts to [n, width]
    marked = columns.unsqueeze(0) < lengths.unsqueeze(1)

    return marked.to(dtype=dtype, device=device)

In [7]:
def generate_lines(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation,
        batch_size,
        add_line_token
):
    """
    Generate / finish one line of the limerick. The prompts should be in the
    correct word order (you don't need to revert the words before passing into
    the function)

    Arguments:
        model: object exposing `generate(input_ids, **params)`
        tokenizer: the tokenizer matching `model`
        config: config object; `config.data.use_bos`, `config.data.reverse`
            and `config.device` are read
        prompts: list of str
        generate_params: dict of keyword arguments for `model.generate`;
            "max_length" is dropped and "max_new_tokens" is forced to 30
        num_generation: int, number of samples drawn per prompt
        batch_size: int, number of rows passed to the model at once
        add_line_token: bool, forwarded to `get_input_ids`
    Return:
        list of str; each entry is a prompt extended by one line and ends
        with " <LINE>". Samples with fewer "<LINE>" markers than the prompt
        had plus one are dropped, so the list can be shorter than
        len(prompts) * num_generation.
    """
    # Nothing to generate: avoid max() / .max() on empty inputs below.
    if not prompts or num_generation < 1:
        return []

    use_bos = config.data.use_bos
    reverse = config.data.reverse

    """
    Step 1:
        concat the input ids into a large tensor; notice that the prompts
        are in variable lengths, thus we need to pad **before** the prompt,
        and generate the attention mask accordingly
    """
    full_input_ids = []
    # one entry per row of the batch: number of lines the finished row
    # must contain (each prompt keeps its own target)
    expected_lines = []
    for prompt in prompts:
        input_ids = get_input_ids(
            prompt=prompt,
            tokenizer=tokenizer,
            use_bos=use_bos,
            reverse=reverse,
            add_line_token=add_line_token)
        full_input_ids.append(input_ids.repeat(num_generation, 1))
        expected_lines.extend([count_lines(prompt) + 1] * num_generation)

    # generate attention mask
    lengths = []
    for input_ids in full_input_ids:
        lengths += [input_ids.shape[1]] * input_ids.shape[0]
    lengths = torch.tensor(lengths, dtype=torch.long)
    full_attention_mask = lengths_to_mask(lengths, torch.long, "cpu", "pre")

    # pad the input ids
    max_seq_len = max(input_ids.shape[1] for input_ids in full_input_ids)
    full_input_ids = [
        torch.cat([
            torch.full(
                (input_ids.shape[0], max_seq_len - input_ids.shape[1]),
                fill_value=tokenizer.eos_token_id, dtype=torch.long
            ),
            input_ids
        ], dim=1)
        for input_ids in full_input_ids]
    full_input_ids = torch.cat(full_input_ids, dim=0)

    num_batches = math.ceil(full_input_ids.shape[0] / batch_size)

    # assume that a line cannot be longer than 30 tokens
    tmp_params = copy.deepcopy(generate_params)
    tmp_params.pop("max_length", None)
    tmp_params["max_new_tokens"] = 30

    # Step 2: pass the batch into model to get generation output
    outputs, targets = [], []
    for i in range(num_batches):
        rows = slice(i * batch_size, (i + 1) * batch_size)
        input_ids = full_input_ids[rows].to(device=config.device)
        attention_mask = full_attention_mask[rows].to(device=config.device)
        with torch.no_grad():
            output = model.generate(
                input_ids, **tmp_params,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.eos_token_id)
        # the model may return several sequences per input row (e.g. when
        # generate_params asks for it); keep the targets aligned with them
        per_row = output.shape[0] // input_ids.shape[0]
        for target in expected_lines[rows]:
            targets.extend([target] * per_row)
        outputs.extend(torch.unbind(output))

    # Step 3: convert the generation result back to strings
    outputs = batch_decode(
        outputs=outputs,
        tokenizer=tokenizer,
        use_bos=use_bos,
        reverse=reverse,
        reverse_last_line=False)

    clean_outputs = []
    for output, target in zip(outputs, targets):
        if count_lines(output) < target:
            continue
        output = output.strip().split(" <LINE> ")[:target]
        output = " <LINE> ".join(output) + " <LINE>"
        # clean up the prepended tokens
        output = output.replace("<|endoftext|>", "").strip()
        clean_outputs.append(output)

    return clean_outputs

In [8]:
def generate_new_lines(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation,
        batch_size
):
    """
    Call `generate_lines` with `add_line_token=True`; all other arguments
    are forwarded unchanged and its result is returned as is.
    """
    return generate_lines(
        model, tokenizer, config, prompts,
        generate_params, num_generation, batch_size,
        add_line_token=True)
    

def finish_lines(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation,
        batch_size
):
    """
    Call `generate_lines` with `add_line_token=False`; all other arguments
    are forwarded unchanged and its result is returned as is.
    """
    return generate_lines(
        model, tokenizer, config, prompts,
        generate_params, num_generation, batch_size,
        add_line_token=False)

In [9]:
def generate_limericks(
        model,
        tokenizer,
        config,
        prompts,
        generate_params,
        num_generation=10,
        batch_size=1,
        add_line_token=True,
):
    """
    Generate whole limericks (five lines) from the given prompts.

    Arguments:
        model: object exposing `generate(input_ids, **params)`
        tokenizer: the tokenizer matching `model`
        config: config object; `config.data.use_bos`, `config.data.reverse`
            and `config.device` are read
        prompts: list of str
        generate_params: dict of keyword arguments for `model.generate`
        num_generation: int, number of samples drawn per prompt
        batch_size: int, number of rows passed to the model at once
        add_line_token: bool, forwarded to `get_input_ids`
    Return:
        list of str, each trimmed to its first five lines and ending with
        " <LINE>". Samples containing fewer than five "<LINE>" markers are
        dropped, so the list may be empty.
    """
    # Nothing to generate: avoid max() / .max() on empty inputs below.
    if not prompts or num_generation < 1:
        return []

    use_bos = config.data.use_bos
    reverse = config.data.reverse

    """
    Step 1:
        concat the input ids into a large tensor; notice that the prompts
        are in variable lengths, thus we need to pad **before** the prompts,
        and generate the attention mask accordingly
    """
    full_input_ids = []
    for prompt in prompts:
        input_ids = get_input_ids(
            prompt=prompt,
            tokenizer=tokenizer,
            use_bos=use_bos,
            reverse=reverse,
            add_line_token=add_line_token)
        full_input_ids.append(input_ids.repeat(num_generation, 1))

    # generate attention mask
    lengths = []
    for input_ids in full_input_ids:
        lengths += [input_ids.shape[1]] * input_ids.shape[0]
    lengths = torch.tensor(lengths, dtype=torch.long)
    full_attention_mask = lengths_to_mask(lengths, torch.long, "cpu", "pre")

    # pad the input ids
    max_seq_len = max(input_ids.shape[1] for input_ids in full_input_ids)
    full_input_ids = [
        torch.cat([
            torch.full(
                (input_ids.shape[0], max_seq_len - input_ids.shape[1]),
                fill_value=tokenizer.eos_token_id, dtype=torch.long
            ),
            input_ids
        ], dim=1)
        for input_ids in full_input_ids]
    full_input_ids = torch.cat(full_input_ids, dim=0)

    num_batches = math.ceil(full_input_ids.shape[0] / batch_size)

    # Step 2: pass the batch into model to get generation output
    outputs = []
    for i in range(num_batches):
        rows = slice(i * batch_size, (i + 1) * batch_size)
        input_ids = full_input_ids[rows].to(device=config.device)
        attention_mask = full_attention_mask[rows].to(device=config.device)
        with torch.no_grad():
            output = model.generate(
                input_ids, **generate_params,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.eos_token_id)
        outputs.extend(torch.unbind(output))

    # Step 3: convert the generation result back to strings
    outputs = batch_decode(
        outputs=outputs,
        tokenizer=tokenizer,
        use_bos=use_bos,
        reverse=reverse,
        reverse_last_line=False)

    clean_outputs = []
    for output in outputs:
        # a limerick needs five complete lines
        if count_lines(output) < 5:
            continue
        output = output.strip().split(" <LINE> ")[:5]
        output = " <LINE> ".join(output) + " <LINE>"
        # clean up the prepended tokens
        output = output.replace("<|endoftext|>", "").strip()
        clean_outputs.append(output)

    return clean_outputs

In [10]:
def generate_limericks_two_stage(
        standard_lm,
        reverse_lm,
        standard_tokenizer,
        reverse_tokenizer,
        standard_config,
        reverse_config,
        prompts,
        generate_params,
        num_generation_1=10,
        num_generation_2=1,
        batch_size=64,
):
    """
    Two-stage generation: `finish_lines` is run with the standard model on
    `prompts` (`num_generation_1` samples each), and its output is used as
    the prompts for `generate_limericks` with the reverse model
    (`num_generation_2` samples each). Both stages share `generate_params`
    and `batch_size`. Returns the result of the second stage.
    """
    # Stage 1: standard model
    stage_one_args = dict(
        model=standard_lm,
        tokenizer=standard_tokenizer,
        config=standard_config,
        prompts=prompts,
        generate_params=generate_params,
        num_generation=num_generation_1,
        batch_size=batch_size)
    first_lines = finish_lines(**stage_one_args)

    # Stage 2: reverse model, prompted with the stage-1 output
    stage_two_args = dict(
        model=reverse_lm,
        tokenizer=reverse_tokenizer,
        config=reverse_config,
        prompts=first_lines,
        generate_params=generate_params,
        num_generation=num_generation_2,
        batch_size=batch_size)
    return generate_limericks(**stage_two_args)

In [11]:
def get_last_words(prompt):
    """
    Return, for every space-separated token equal to "<LINE>", the token
    directly in front of it (in order of appearance).

    Note: if the very first token is "<LINE>", index -1 wraps around and
    the last token of the prompt is returned for it.
    """
    tokens = prompt.split(' ')
    return [
        tokens[pos - 1]
        for pos, token in enumerate(tokens)
        if token == "<LINE>"]


def get_current_rhymes(prompt, tokenizer, allow_repetition=False):
    """
    Collect candidate rhyming words for the next line of an AABBA limerick.

    Arguments:
        prompt: str, the lines written so far (separated by "<LINE>")
        tokenizer: callable tokenizer; called on the list of candidate
            words, its result must provide an 'input_ids' entry
        allow_repetition: bool, when False the words being rhymed with are
            removed from the candidates
    Return:
        (rhyme_tokens, rhymes): `rhymes` is an alphabetically sorted list
        of candidate words (sorted so that seeded sampling downstream is
        reproducible) and `rhyme_tokens` holds the token ids of each
        candidate in reversed order. Both are empty when the next line
        opens a new rhyme (0 or 2 lines so far), when the expected
        line-ending word is missing, or when no rhyme is found.
    """
    num_lines = count_lines(prompt)
    if num_lines in (0, 2):  # first A or first B
        return [], []

    words = get_last_words(prompt)
    try:
        if num_lines == 1:  # 2nd A rhymes with the 1st line
            words = [words[0]]
        elif num_lines == 4:  # 3rd A rhymes with the 1st and 2nd lines
            words = [words[0], words[1]]
        elif num_lines == 3:  # 2nd B rhymes with the 3rd line
            words = [words[2]]
        # any other line count: every collected line-ending word is used
    except IndexError:
        # fewer line-ending words than "<LINE>" markers (e.g. a marker
        # glued to a word): nothing to rhyme with
        return [], []

    candidates = set()
    for word in words:
        candidates.update(pronouncing.rhymes(word))
    if not allow_repetition:
        candidates.difference_update(words)
    rhymes = sorted(candidates)

    if not rhymes:
        return [], []

    rhyme_tokens = [
        token_ids[::-1] for token_ids in tokenizer(rhymes)['input_ids']]

    return rhyme_tokens, rhymes

In [12]:
def pad_tokens(tokens, tokenizer, max_len):
    """
    Right-pad every token list in `tokens` to `max_len`.

    Arguments:
        tokens: list of lists of token ids
        tokenizer: object providing `pad_token_id`, used as the filler
        max_len: int, target length of every row
    Return:
        (padded_tokens, attention_mask): a long tensor with the padded ids
        and a float tensor holding 1. for real tokens and 0. for padding
    """
    id_rows, mask_rows = [], []
    for row in tokens:
        missing = max_len - len(row)
        id_rows.append(row + [tokenizer.pad_token_id] * missing)
        mask_rows.append([1.] * len(row) + [0.] * missing)

    return (
        torch.tensor(id_rows, dtype=torch.long),
        torch.tensor(mask_rows, dtype=torch.float),
    )

In [13]:
def get_rhyming_word_score(
        reverse_lm,
        tokenizer,
        config,
        prompts,
        rhymes,
        temperature,
        batch_size=64
):
    """
    Turn the reverse language model's likelihood of each candidate rhyming
    word into a sampling distribution per prompt.

    Every candidate is appended to its (reversed) prompt, the log
    probabilities of its tokens are summed, and the sums of all candidates
    of one prompt are normalised with a temperature softmax.

    Arguments:
        reverse_lm: callable taking `input_ids` / `attention_mask` and
            returning a mapping with a 'logits' entry
            (assumed shape [batch, seq_len, vocab] -- TODO confirm)
        tokenizer: the tokenizer of `reverse_lm`; `pad_token_id` is used
        config: config object; `config.data.use_bos` and `config.device`
            are read
        prompts: list of str
        rhymes: list (aligned with `prompts`) of non-empty lists of token
            id lists, one per candidate word
        temperature: positive float dividing the summed log probabilities
            before the softmax; None is treated as 1.0
        batch_size: int, number of rows passed to the model at once
    Return:
        probs_list: list of 1-D float64 numpy arrays, one per prompt, each
        summing to 1 over that prompt's candidates
    Raises:
        ValueError: if `temperature` is not positive
    """
    if temperature is None:
        temperature = 1.0
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if len(prompts) == 0:
        return []

    """
    Step 1:
        generate input ids for each prompts (not concatenated now)
        also collect the max rhyme (tokens) len for next step
    """
    lengths, max_rhyme_len = [], 0
    input_ids_list = []
    for prompt, rhymes_ in zip(prompts, rhymes):
        input_ids = get_input_ids(
            prompt=prompt,
            tokenizer=tokenizer,
            use_bos=config.data.use_bos,
            reverse=True,
            add_line_token=True)

        # [l_0, ..., l_0, l_1, ..., l_1, ...]
        lengths.extend([input_ids.shape[1]] * len(rhymes_))
        input_ids_list.append(input_ids.repeat(len(rhymes_), 1))

        rhyme_len = max(len(rhyme) for rhyme in rhymes_)
        max_rhyme_len = max(max_rhyme_len, rhyme_len)

    """
    Step 2:
        generate input ids for each rhyme word list to concat with prompts
        the attention mask is generated to calculate the scores later
    """
    padded_rhymes_list = []
    rhyme_masks = []
    for rhymes_ in rhymes:
        padded, mask = pad_tokens(rhymes_, tokenizer, max_rhyme_len)
        padded_rhymes_list.append(padded)
        rhyme_masks.append(mask)

    padded_rhymes = torch.cat(padded_rhymes_list, dim=0)
    rhyme_masks = torch.cat(rhyme_masks, dim=0)

    """
    Step 3:
        concat the input ids of prompts with rhyme words
        also need to pad them to the same length for batching
    """
    input_ids_list = [
        torch.cat([input_ids, padded], dim=1)
        for input_ids, padded in zip(input_ids_list, padded_rhymes_list)]

    max_seq_len = max(input_ids.shape[1] for input_ids in input_ids_list)
    input_ids_list = [
        torch.cat(
            [
                input_ids,
                torch.full(
                    (input_ids.shape[0], max_seq_len - input_ids.shape[1]),
                    fill_value=tokenizer.pad_token_id,
                    dtype=torch.long, device="cpu")
            ], dim=1)
        for input_ids in input_ids_list]

    full_input_ids = torch.cat(input_ids_list, dim=0)
    num_examples = full_input_ids.shape[0]
    num_batches = math.ceil(num_examples / batch_size)

    lengths = torch.tensor(lengths, dtype=torch.long)
    total_lengths = lengths + max_rhyme_len
    attention_masks = lengths_to_mask(total_lengths, torch.float, "cpu")

    """
    Step 4:
        pass the batches into model to get logits, which then are converted
        into log probs and aggregated to get the final scores
    """
    # offset of each rhyme token relative to the end of the prompt
    rhyme_steps = torch.arange(max_rhyme_len).reshape(1, -1)

    full_scores = []
    for i in tqdm.trange(num_batches, leave=False):
        rows = slice(i * batch_size, (i + 1) * batch_size)
        input_ids = full_input_ids[rows].to(device=config.device)
        attention_mask = attention_masks[rows].to(device=config.device)
        batch_padded_rhymes = padded_rhymes[rows].to(device=config.device)
        batch_rhyme_masks = rhyme_masks[rows].to(device=config.device)

        # The j-th rhyme token sits at position length + j, and the logits
        # that predict it are the ones emitted one step earlier.
        # [batch_size, max_rhyme_len]
        positions = lengths[rows].reshape(-1, 1) - 1 + rhyme_steps
        positions = positions.to(device=config.device)
        row_index = torch.arange(
            input_ids.shape[0], device=positions.device).reshape(-1, 1)

        with torch.no_grad():
            logits = reverse_lm(
                input_ids=input_ids,
                attention_mask=attention_mask)['logits']

            # [batch_size, max_rhyme_len, vocab_size]
            rhyme_logits = logits[row_index, positions]
            log_probs = torch.log_softmax(rhyme_logits.float(), dim=-1)

            # [batch_size, max_rhyme_len]; squeeze only the gathered axis
            # so that batch_size == 1 or max_rhyme_len == 1 keep their shape
            token_scores = torch.gather(
                log_probs, 2,
                batch_padded_rhymes.unsqueeze(2)).squeeze(2)
            # padded rhyme slots are masked out of the sum
            scores = torch.sum(token_scores * batch_rhyme_masks, dim=1)

        full_scores.append(scores.cpu().numpy())

    scores = np.concatenate(full_scores, axis=0).astype(np.float64)

    """
    Step 5:
        split the final results back into array for each prompt
    """
    probs_list, anchor = [], 0
    for rhymes_ in rhymes:
        word_scores = scores[anchor: anchor + len(rhymes_)] / temperature
        # subtract the max before exponentiating for numerical stability
        probs = np.exp(word_scores - np.max(word_scores))
        probs /= np.sum(probs)
        probs_list.append(probs)
        anchor += len(rhymes_)

    return probs_list

In [14]:
def attach_next_rhyming_word(
        reverse_lm,
        tokenizer,
        config,
        prompts,
        num_samples,
        weighted,
        temperature=None,
        batch_size=64
):
    """
    Append a sampled rhyming word to every prompt that has rhyme candidates.

    For each prompt `get_current_rhymes` is queried. Prompts with
    candidates get `num_samples` copies, each followed by one sampled word;
    prompts without candidates are simply repeated `num_samples` times.
    When `weighted` is true the sampling distribution comes from
    `get_rhyming_word_score` (temperature defaults to 1.0), otherwise it is
    uniform.

    Return:
        flat list of str, grouped in the order of `prompts`
    """
    rhymable, plain = [], []
    for position, prompt in enumerate(prompts):
        tokens, words = get_current_rhymes(prompt, tokenizer)
        if tokens == []:
            plain.append([position, prompt])
        else:
            rhymable.append([position, prompt, tokens, words])

    if weighted and rhymable != []:
        probs_list = get_rhyming_word_score(
            reverse_lm=reverse_lm,
            tokenizer=tokenizer,
            config=config,
            prompts=[entry[1] for entry in rhymable],
            rhymes=[entry[2] for entry in rhymable],
            temperature=(1.0 if temperature is None else temperature),
            batch_size=batch_size)
        torch.cuda.empty_cache()
    else:
        # uniform distribution over each prompt's candidates
        probs_list = [
            np.ones(len(entry[3])) / len(entry[3])
            for entry in rhymable]

    # one slot per input prompt, filled by position
    expanded = [None for _ in prompts]

    for (position, prompt, _, words), probs in zip(rhymable, probs_list):
        picks = np.random.choice(len(words), num_samples, p=probs)
        expanded[position] = [f"{prompt} {words[pick]}" for pick in picks]

    for position, prompt in plain:
        expanded[position] = [prompt] * num_samples

    return [text for group in expanded for text in group]

In [15]:
def generate_limericks_with_rhyming(
        reverse_lm,
        tokenizer,
        config,
        prompts,
        generate_params,
        weighted,
        num_generation=10,
        batch_size=10
):
    """
    Generate limericks with the reverse model, picking a rhyming word
    before each of the later lines.

    `generate_new_lines` is first run on `prompts` (`num_generation`
    samples each). Then, four times in a row, `attach_next_rhyming_word`
    (one sample, temperature 1.0) and `finish_lines` (one sample) are
    applied to the current texts. Every intermediate text is printed.

    Return:
        the texts produced by the last `finish_lines` call
    """
    # arguments shared by all generation calls
    shared = dict(
        model=reverse_lm,
        tokenizer=tokenizer,
        config=config,
        generate_params=generate_params,
        batch_size=batch_size)

    prompts = generate_new_lines(
        prompts=prompts, num_generation=num_generation, **shared)
    for text in prompts:
        print(text)

    remaining_rounds = 4
    while remaining_rounds > 0:
        with_rhyme = attach_next_rhyming_word(
            reverse_lm=reverse_lm,
            tokenizer=tokenizer,
            config=config,
            prompts=prompts,
            num_samples=1,
            weighted=weighted,
            temperature=1.0)
        prompts = finish_lines(
            prompts=with_rhyme, num_generation=1, **shared)
        for text in prompts:
            print(text)
        remaining_rounds -= 1

    return prompts

In [16]:
def generate_limericks_two_stage_with_rhyming(
        standard_lm,
        reverse_lm,
        standard_tokenizer,
        reverse_tokenizer,
        standard_config,
        reverse_config,
        prompts,
        generate_params,
        weighted,
        num_generation_1=10,
        num_generation_2=1,
        batch_size=1,
):
    """
    Two-stage generation with rhyme selection.

    `finish_lines` is run with the standard model on `prompts`
    (`num_generation_1` samples each). Then, four times in a row, the
    reverse model is used to attach a rhyming word
    (`attach_next_rhyming_word`, one sample, temperature 1.0) and to finish
    the line (`finish_lines`, one sample). `num_generation_2` is accepted
    but not used by this function.

    Return:
        the texts produced by the last `finish_lines` call
    """
    texts = finish_lines(
        model=standard_lm,
        tokenizer=standard_tokenizer,
        config=standard_config,
        prompts=prompts,
        generate_params=generate_params,
        num_generation=num_generation_1,
        batch_size=batch_size)

    # arguments shared by both reverse-model calls of every round
    reverse_side = dict(tokenizer=reverse_tokenizer, config=reverse_config)

    remaining_rounds = 4
    while remaining_rounds > 0:
        texts = attach_next_rhyming_word(
            reverse_lm=reverse_lm,
            prompts=texts,
            num_samples=1,
            weighted=weighted,
            temperature=1.0,
            **reverse_side)
        texts = finish_lines(
            model=reverse_lm,
            prompts=texts,
            generate_params=generate_params,
            num_generation=1,
            batch_size=batch_size,
            **reverse_side)
        remaining_rounds -= 1

    return texts

In [21]:
def load_model(exp_dir, tmp_root="/scratch/sthilaga/GPT2_Poem_Generator/content/test/"):
    """
    Load the config, tokenizer and fine-tuned GPT-2 model of an experiment.

    The checkpoint weights are loaded into a fresh "gpt2" model (with its
    embeddings resized to the tokenizer), written out with
    `save_pretrained` and read back through `AutoModelForCausalLM`. The
    intermediate copy lives in a temporary directory below `tmp_root`
    that is removed again before returning.

    Arguments:
        exp_dir: str, directory containing "config.yaml", "tokenizer/" and
            "best-model.ckpt"
        tmp_root: str, existing directory in which the temporary model
            copy is created
    Return:
        (config, tokenizer, model): the OmegaConf config, the
        GPT2Tokenizer and the reloaded model moved to CUDA
    Raises:
        FileNotFoundError: if `tmp_root` does not exist
    """
    # close the config file deterministically instead of leaking the handle
    with open(exp_dir + "/config.yaml") as config_file:
        config = OmegaConf.create(yaml.safe_load(config_file))
    tokenizer = GPT2Tokenizer.from_pretrained(f"{exp_dir}/tokenizer")

    if not os.path.exists(tmp_root):
        print("Path does not exist! :(")
        # fail before the expensive checkpoint load, with the same
        # exception type the temp-dir creation would raise
        raise FileNotFoundError(
            f"temporary root directory does not exist: {tmp_root}")

    states = torch.load(f"{exp_dir}/best-model.ckpt")

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.resize_token_embeddings(len(tokenizer))
    model = model.cuda()
    model.load_state_dict(states['model_state_dict'])

    # the saved copy is only needed until it has been read back, so the
    # directory is deleted as soon as the block exits
    with tempfile.TemporaryDirectory(dir=tmp_root) as tmp_dir:
        model.save_pretrained(tmp_dir)
        new_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
    new_model = new_model.cuda()

    return config, tokenizer, new_model

## Example of one-stage generation

In [22]:
# Experiment directory of the reverse-order GPT-2 model
exp_dir = "/scratch/sthilaga/GPT2_Poem_Generator/config/reverse-gpt2"
config, tokenizer, model = load_model(exp_dir)

In [23]:
generate_params = {
    "do_sample": True,
    "max_length": 100,
}

results = []
for _ in range(100):
    # generate_limericks drops samples with fewer than five lines and may
    # therefore return an empty list; extend() keeps whatever survived
    # instead of crashing on an out-of-range [0]
    results.extend(
        generate_limericks(
            model,
            tokenizer,
            config,
            [""],
            generate_params,
            num_generation=1,
            batch_size=1,
            add_line_token=True))

for res in results:
    print(res)

<BOS> who's living in paris or moma <LINE> she's determined to live on a chammy <LINE> and so does his life <LINE> the needs of his wife <LINE> she says he sleeps out through the mammy <LINE>
<BOS> in the '60s are all et vie-cee <LINE> and we marched in, with much too much glee <LINE> early time that we fought <LINE> where they once looted and burned <LINE> to the tealas we fled; we were free <LINE>
<BOS> it's plural (a word that is called a) <LINE> aorto-ment? a kind of aawa <LINE> by the french u.k. <LINE> all the experts agree <LINE> not a home vacation in nova <LINE>
<BOS> here's an ogerman, auger? right for it <LINE> he's some better, no longer fight for it <LINE> just be sure, for a guy <LINE> an officer's eye <LINE> for that. that's why he will fight for it <LINE>
<BOS> i was probed, on the head of this leafy <LINE> that on earth. and this new christian theory <LINE> and a true young ahexis <LINE> have purged on my poshysys <LINE> don't need not be duped. i be wary <LINE>
<BOS> 

## Example of two-stage generation

In [None]:
# Experiment directories of the standard-order and reverse-order models
standard_exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/bos-gpt2"
reverse_exp_dir = "/content/drive/MyDrive/11-785-final/ckpt/reverse-bos-gpt2"

(standard_config,
 standard_tokenizer,
 standard_model) = load_model(standard_exp_dir)
(reverse_config,
 reverse_tokenizer,
 reverse_model) = load_model(reverse_exp_dir)

In [None]:
generate_params = {
    "do_sample": True,
    "max_length": 100,
}

results = []
for _ in range(100):
    # Keyword arguments: the function expects the tokenizers before the
    # configs, so passing them positionally in config-first order would
    # swap them.
    # extend() instead of [0]: the call may return an empty list when the
    # sample is filtered out.
    results.extend(
        generate_limericks_two_stage(
            standard_lm=standard_model,
            reverse_lm=reverse_model,
            standard_tokenizer=standard_tokenizer,
            reverse_tokenizer=reverse_tokenizer,
            standard_config=standard_config,
            reverse_config=reverse_config,
            prompts=[""],
            generate_params=generate_params,
            num_generation_1=1,
            num_generation_2=1,
            batch_size=1))

for res in results:
    print(res)