Taken from this blog post: https://towardsdatascience.com/how-to-fine-tune-gpt-2-for-text-generation-ae2ea53bc272

In [3]:
!pip install transformers

import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
import torch.nn.functional as F
import csv
import os

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [6]:
### Prepare data
lyrics = pd.read_csv('all_prompts.csv')

#Drop the songs with lyrics too long (after more than 1024 tokens, does not work)
df = lyrics[lyrics['prompt'].apply(lambda x: len(x.split(' ')) < 350)]

#Create a very small test set to compare generated text with the reality
test_set = df.sample(n = 200)
df = df.loc[~df.index.isin(test_set.index)]

#Reset the indexes
test_set = test_set.reset_index()
df = df.reset_index()

#For the test set only, keep last 20 words in a new column, then remove them from original column
"""test_set['True_end_lyrics'] = test_set['Lyric'].str.split().str[-20:].apply(' '.join)
test_set['Lyric'] = test_set['Lyric'].str.split().str[:-20].apply(' '.join)"""

"test_set['True_end_lyrics'] = test_set['Lyric'].str.split().str[-20:].apply(' '.join)\ntest_set['Lyric'] = test_set['Lyric'].str.split().str[:-20].apply(' '.join)"

In [7]:
class Prompts(Dataset):  
    def __init__(self, control_code, truncate=False, gpt2_type="gpt2", max_length=1024):

        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.prompt = []

        for row in df['prompt']:
          self.prompt.append(torch.tensor(
                self.tokenizer.encode(f"<|{control_code}|>{row[:max_length]}<|endoftext|>")
            ))               
        if truncate:
            self.prompt = self.prompt[:20000]
        self.prompt_count = len(self.prompt)
        
    def __len__(self):
        return self.prompt_count

    def __getitem__(self, item):
        return self.prompt[item]
    
dataset = Prompts(df['prompt'], truncate=True, gpt2_type="gpt2")      

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [8]:
#Get the tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

#Accumulated batch size (since GPT2 is so big)
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    if packed_tensor is None:
        return new_tensor, True, None
    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
        return packed_tensor, False, new_tensor
    else:
        packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
        return packed_tensor, True, None

Downloading:   0%|          | 0.00/548M [00:00<?, ?B/s]

In [9]:
def train(
    dataset, model, tokenizer,
    batch_size=16, epochs=5, lr=2e-5,
    max_seq_len=400, warmup_steps=200,
    gpt2_type="gpt2", output_dir=".", output_prefix="wreckgar",
    test_mode=False,save_model_on_epoch=False,
):
    acc_steps = 100
    device=torch.device("cuda")
    model = model.cuda()
    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=-1
    )

    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    loss=0
    accumulating_batch_count = 0
    input_tensor = None

    for epoch in range(epochs):

        print(f"Training epoch {epoch}")
        print(loss)
        for idx, entry in tqdm(enumerate(train_dataloader)):
            (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, 768)

            if carry_on and idx != len(train_dataloader) - 1:
                continue

            input_tensor = input_tensor.to(device)
            outputs = model(input_tensor, labels=input_tensor)
            loss = outputs[0]
            loss.backward()

            if (accumulating_batch_count % batch_size) == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                model.zero_grad()

            accumulating_batch_count += 1
            input_tensor = None
        if save_model_on_epoch:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch}.pt"),
            )
    return model

In [10]:
model = train(dataset, model, tokenizer)



Training epoch 0
0


20000it [23:07, 14.41it/s]


Training epoch 1
tensor(0.1988, device='cuda:0', grad_fn=<NllLossBackward0>)


3767it [04:21, 14.43it/s]


KeyboardInterrupt: ignored

In [20]:
def generate(
    model,
    tokenizer,
    prompt,
    entry_count=10,
    entry_length=30, #maximum number of words
    top_p=0.8,
    temperature=1.,
):
    model.eval()
    generated_num = 0
    generated_list = []

    filter_value = -float("Inf")

    with torch.no_grad():

        for entry_idx in trange(entry_count):

            entry_finished = False
            generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
            output_prompt = ""

            for i in range(entry_length):
                outputs = model(generated, labels=generated)
                loss, logits = outputs[:2]
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                output_prompt = output_prompt + tokenizer.decode(next_token[0][0])

                if next_token in tokenizer.encode("<|endoftext|>"):
                    entry_finished = True

                if entry_finished:

                    generated_num = generated_num + 1

                    output_list = list(generated.squeeze().numpy())
                    output_text = tokenizer.decode(output_list)
                    generated_list.append(output_text)
                    break
                
            print(output_prompt)
            
            if not entry_finished:
              output_list = list(generated.squeeze().numpy())
              output_text = f"{tokenizer.decode(output_list)}<|endoftext|>" 
              generated_list.append(output_text)
                
    return generated_list

#Function to generate multiple sentences. Test data should be a dataframe
def text_generation(test_data):
  generated_lyrics = []
  for i in range(len(test_data)):
    x = generate(model.to('cpu'), tokenizer, test_data['prompt'][i], entry_count=1)
    generated_lyrics.append(x)
  return generated_lyrics

#Run the functions to generate the lyrics
generated_lyrics = text_generation(test_set)


100%|██████████| 1/1 [00:10<00:00, 10.34s/it]





The united kingdom was formed during the seventeenth century, after a brief reign by King James II of Scotland, although it soon developed into a


100%|██████████| 1/1 [00:06<00:00,  6.35s/it]



Question: The People work in the medical field.
Question: The People work in the medical field.
Question: The People work in the


100%|██████████| 1/1 [00:06<00:00,  6.39s/it]


 Neither
Question: I'm getting very tired of the school girls playing "fatigue". The ladies are always yelling and yelling at me and often get


100%|██████████| 1/1 [00:00<00:00,  1.89it/s]


<|endoftext|>


100%|██████████| 1/1 [00:00<00:00,  5.24it/s]


<|endoftext|>


100%|██████████| 1/1 [00:25<00:00, 25.06s/it]


 Yes, Princess Diana died from cancer in 1955.

Q: Was Jeff Bridges murdered by a drunk man?
A: Yes, Jeff Bridges


100%|██████████| 1/1 [00:06<00:00,  6.63s/it]



A Air Force family is an organization based in California. It has a name in Spanish, in English, in the USA, and in Britain.


100%|██████████| 1/1 [00:00<00:00,  5.04it/s]


<|endoftext|>


100%|██████████| 1/1 [00:00<00:00,  1.64it/s]


<|endoftext|>


100%|██████████| 1/1 [00:00<00:00,  1.81it/s]


<|endoftext|>


100%|██████████| 1/1 [00:03<00:00,  3.92s/it]



Question: The chef is preparing sushi. True, False, or Neither?
Answer:<|endoftext|>


100%|██████████| 1/1 [00:34<00:00, 34.16s/it]


 no

In the early 1980s, an attempt to re-create an original recording of the Montagu Republic folk ballad through Périn


100%|██████████| 1/1 [00:07<00:00,  7.36s/it]



Question: Five people are standing by a staircase. True, False, or Neither?
Answer:
Question: Five people are standing by a


100%|██████████| 1/1 [00:12<00:00, 12.99s/it]


 The last movie in the line is True's True-Tracy reprise, on which it appears on the DVD. In the movie, True does


100%|██████████| 1/1 [00:11<00:00, 11.78s/it]


 False, or Neither?
Question: The argument is clear to True, False, or Neither?
Answer: True, False, or Neither?


100%|██████████| 1/1 [00:00<00:00,  1.17it/s]


<|endoftext|>


100%|██████████| 1/1 [00:00<00:00,  2.14it/s]


<|endoftext|>


100%|██████████| 1/1 [00:00<00:00,  4.37it/s]


<|endoftext|>


100%|██████████| 1/1 [00:01<00:00,  1.20s/it]


<|endoftext|>


100%|██████████| 1/1 [00:15<00:00, 15.59s/it]


 the prosecution is required to prove that the crime was committed before you could be charged. Therefore, when they indict you, you must also prove that you


100%|██████████| 1/1 [00:06<00:00,  6.34s/it]


 Neither.
Question: A woman is approached by a man and he tells her that he is going to kill her. She denies the allegation.



100%|██████████| 1/1 [00:24<00:00, 24.57s/it]




Q: When did Israel begin killing Jews?
A:

Q: Which Jewish street name is more likely?
A: Streets


100%|██████████| 1/1 [00:00<00:00,  2.43it/s]


<|endoftext|>


100%|██████████| 1/1 [00:24<00:00, 24.93s/it]


 Yes, Muslims should be allowed to enter the US.

Q: Should women be allowed to wear veils on the basis of their religious beliefs


100%|██████████| 1/1 [00:06<00:00,  6.92s/it]


 "Answer" is a condition of being abstinent from sexual intercourse and there are also other rules and regulations that govern this sexual activity, which also apply


100%|██████████| 1/1 [00:36<00:00, 36.47s/it]


 yes

EMT-M in Iraq completed its 12 year training project at Ulan Karabakh and also completed its training project with the United


100%|██████████| 1/1 [00:00<00:00,  2.05it/s]


<|endoftext|>


100%|██████████| 1/1 [00:05<00:00,  5.81s/it]


 Aeschylus wrote The Persians. True or False?<|endoftext|>


100%|██████████| 1/1 [00:01<00:00,  1.45s/it]


<|endoftext|>


100%|██████████| 1/1 [00:00<00:00,  5.45it/s]


<|endoftext|>


  0%|          | 0/1 [00:00<?, ?it/s]


KeyboardInterrupt: ignored

In [None]:
print(generated_lyrics)

'I'