In [None]:
import csv
import math
import os
import random
import re

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup

In [None]:
# Start a Weights & Biases run; train() logs its per-epoch loss here and finishes the run.
WANDB_PROJECT = 'gpt-2'
WANDB_ENTITY = 'tomislav-krog'
wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY)

In [None]:
# Load the lyrics dataset and keep only one artist's songs.
LYRICS_CSV = '/kaggle/input/lyrics/dataset.csv'
ARTIST = 'alicia-keys'

df = pd.read_csv(LYRICS_CSV)
df = df.loc[df['Artist'] == ARTIST]

df.head()

In [None]:
df.shape[0]  # number of songs for the selected artist

In [None]:
# Drop songs whose lyrics are too long. Words are counted on any whitespace, so words
# separated by newlines count individually (split(' ') would merge them and undercount).
# .copy() makes `df` an independent frame, so assigning df['Lyric'] later does not raise
# SettingWithCopyWarning. Missing lyrics (NaN) fail the comparison and are dropped.
MAX_LYRIC_WORDS = 450
df = df[df['Lyric'].str.split().str.len() < MAX_LYRIC_WORDS].copy()

In [None]:
df.shape[0]  # number of songs left after the length filter

In [None]:
def preprocess_lyrics(lyrics):
    """Normalize one song's lyrics into a single line of plain text.

    Steps, in order:
      1. drop ``[...]`` tags such as ``[Chorus]`` (non-greedy, within a line);
      2. drop every line that contains ``|``;
      3. join the remaining lines with a space;
      4. drop parenthesized section markers such as ``(Chorus)`` or ``(VERSE 2: x)``,
         case-insensitively;
      5. transliterate accented letters, map ``;``/``:`` to ``,`` and curly apostrophes
         to ``'``, and delete assorted stray symbols;
      6. delete remaining parentheses, quotes, dashes and bracket spans; turn special
         spaces and ``-`` into spaces; collapse whitespace runs to one space.

    Args:
        lyrics: raw lyric text (str), possibly spanning several lines.

    Returns:
        The cleaned lyric as a single-line str with no leading/trailing whitespace.
    """
    # 1. Bracketed tags like "[Chorus]" / "[Verse 1: ...]".
    lyrics = re.sub(r'\[.*?\]', '', lyrics)

    # 2-3. Drop lines containing '|' while line structure still exists, then join with a
    # space. Doing the '|' removal after joining would erase the whole song, and joining
    # with '' would fuse the last word of a line onto the first word of the next one.
    lyrics = ' '.join(line for line in lyrics.split('\n') if '|' not in line)

    # 4. Parenthesized section markers, in any letter case ("(Chorus)", "(VERSE 2)").
    lyrics = re.sub(r'\((chorus|verse|intro)(.*?)\)', '', lyrics, flags=re.IGNORECASE)

    # 5. Special symbols: transliterate or delete.
    replacements = {
        'à': 'a', 'á': 'a', 'â': 'a', 'ã': 'a', 'ä': 'a',
        'ç': 'c',
        'ö': 'o',
        'ú': 'u', 'ü': 'u',
        'œ': 'oe',
        'Â': 'A',
        '‰': '', '™': '', '´': '', '·': '', '¦': '', '': '', '': '',
        '˜': '', '“': '', '†': '', '…': '', '′': '', '″': '', '�': '',
        'í': 'i', 'é': 'e', 'ï': 'i', 'ó': 'o', ';': ',', '‘': '\'', '’': '\'', ':': ',', 'е': 'e'
    }
    for symbol, replacement in replacements.items():
        lyrics = lyrics.replace(symbol, replacement)

    # 6. Remaining punctuation (raw strings: these are regex patterns).
    remove_patterns = [r'\)', r'\(', '–', '"', '”', '"', r'\[.*\]', '—', '(VERSE)', '(CHORUS ONE)']
    for pattern in remove_patterns:
        lyrics = re.sub(pattern, '', lyrics)
    for char in ['\u2005', '\u200b', '\u205f', '\xa0', '-']:
        lyrics = lyrics.replace(char, ' ')

    # The removals above leave runs of spaces; collapse them so the tokenizer
    # doesn't see spurious whitespace tokens.
    return re.sub(r'\s+', ' ', lyrics).strip()

In [None]:
# Clean every lyric: preprocess_lyrics maps one lyric string to one cleaned string.
df['Lyric'] = df['Lyric'].map(preprocess_lyrics)

In [None]:
df.head(n=5)  # first rows after cleaning

In [None]:
# Hold out a fraction of the songs as a test set. For each held-out song, the last
# HELD_OUT_WORDS words become the reference ending and the rest becomes the prompt.
TEST_FRACTION = 0.05
HELD_OUT_WORDS = 30
SPLIT_SEED = 42  # fixed so the split (and therefore the evaluation) is reproducible

test_set_size = int(len(df) * TEST_FRACTION)

# test set
test_set = df.sample(n=test_set_size, random_state=SPLIT_SEED)
df = df.loc[~df.index.isin(test_set.index)]
test_set = test_set.reset_index()
df = df.reset_index()

# keep last HELD_OUT_WORDS words in a new column, then remove them from original column
_test_words = test_set['Lyric'].str.split()
test_set['TrueFinalLyric'] = _test_words.str[-HELD_OUT_WORDS:].apply(' '.join)
test_set['Lyric'] = _test_words.str[:-HELD_OUT_WORDS].apply(' '.join)

In [None]:
# Peek at the training split instead of rendering the entire frame.
df.head()

In [None]:
class SongLyrics(Dataset):
    """Tokenized song lyrics for GPT-2 fine-tuning.

    Each example is ``<|{control_code}|>{lyric[:max_length]}<|endoftext|>`` encoded with
    ``GPT2Tokenizer`` and stored as a 1-D tensor of token ids.

    Args:
        control_code: short label interpolated between ``<|`` and ``|>`` at the start of
            every example. It is formatted with an f-string, so pass a plain string, not a
            pandas object (a Series would embed its whole multi-line repr in each example).
        truncate: if True, keep only the first 20000 lyrics.
        gpt2_type: model name passed to ``GPT2Tokenizer.from_pretrained``.
        max_length: maximum number of *characters* (not tokens) kept from each lyric.
        lyrics: iterable of lyric strings to encode. Defaults to the notebook-global
            ``df['Lyric']`` for backward compatibility.
    """

    def __init__(self, control_code, truncate=False, gpt2_type="gpt2", max_length=1024, lyrics=None):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)

        rows = list(df['Lyric'] if lyrics is None else lyrics)
        if truncate:
            # Truncate before tokenizing so discarded rows are never encoded.
            rows = rows[:20000]

        self.lyrics = [
            torch.tensor(
                self.tokenizer.encode(f"<|{control_code}|>{row[:max_length]}<|endoftext|>")
            )
            for row in rows
        ]
        self.lyrics_count = len(self.lyrics)

    def __len__(self):
        return self.lyrics_count

    def __getitem__(self, item):
        return self.lyrics[item]


# Pass the training lyrics explicitly and use a short string as the control code.
dataset = SongLyrics("lyrics", truncate=True, gpt2_type="gpt2", lyrics=df['Lyric'])

In [None]:
# Pretrained GPT-2 tokenizer and language-model head that will be fine-tuned.
BASE_MODEL = 'gpt2'
tokenizer, model = GPT2Tokenizer.from_pretrained(BASE_MODEL), GPT2LMHeadModel.from_pretrained(BASE_MODEL)

# Sequence packing: merge several short examples into one forward pass.
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    """Append ``new_tensor`` to ``packed_tensor`` if the result fits in ``max_seq_len`` tokens.

    Tensors are 2-D (batch, seq) token-id batches joined along dim 1. In this notebook the
    batch dimension is 1 (DataLoader with batch_size=1).

    Args:
        new_tensor: the next example to pack.
        packed_tensor: the pack built so far, or None if no pack has been started.
        max_seq_len: maximum length (dim 1) of a pack.

    Returns:
        ``(packed, carry_on, remainder)``:
        - no pack yet: ``(new_tensor, True, None)``;
        - fits: ``(packed_tensor followed by new_tensor, True, None)``;
        - does not fit: ``(packed_tensor, False, new_tensor)``. The caller should consume
          ``packed`` and start the next pack from ``remainder``.
    """
    if packed_tensor is None:
        return new_tensor, True, None
    if new_tensor.size(1) + packed_tensor.size(1) > max_seq_len:
        return packed_tensor, False, new_tensor
    # Plain concatenation: no token is dropped from either example.
    return torch.cat([packed_tensor, new_tensor], dim=1), True, None

In [None]:
def train(
    dataset, model, tokenizer,
    batch_size=16, epochs=15, lr=2e-5,
    max_seq_len=768, warmup_steps=200,
    gpt2_type="gpt2", output_dir=".", output_prefix="wreckgar",
    test_mode=False, save_model_on_epoch=False,
):
    """Fine-tune ``model`` on ``dataset`` with sequence packing and gradient accumulation.

    Examples are drawn one at a time (shuffled), packed with ``pack_tensor`` into sequences
    of at most ``max_seq_len`` tokens, and each packed sequence gets one forward/backward
    pass with ``labels=input``. Gradients from ``batch_size`` passes are accumulated before
    each optimizer + LR-scheduler step. Any leftover partial accumulation is applied after
    the last epoch. The mean loss per packed pass is printed and logged to wandb each
    epoch, and ``wandb.finish()`` is called at the end.

    Args:
        dataset: torch Dataset yielding 1-D token-id tensors.
        model: causal LM whose output's first element is the loss when ``labels`` is given.
        batch_size: number of packed passes accumulated per optimizer step.
        epochs: number of passes over ``dataset``.
        lr: AdamW learning rate.
        max_seq_len: packing budget in tokens per forward pass.
        warmup_steps: linear warmup length, in optimizer steps.
        output_dir, output_prefix: checkpoint location, used when ``save_model_on_epoch``.
        save_model_on_epoch: if True, save ``state_dict`` to
            ``{output_dir}/{output_prefix}-{epoch}.pt`` after each epoch.
        tokenizer, gpt2_type, test_mode: unused; kept for interface compatibility.

    Returns:
        The trained model, left on the training device and in train mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=lr)
    # The linear schedule needs a real horizon: with num_training_steps=-1 the LR drops to 0
    # as soon as warmup ends. Each packed pass consumes >= 1 example, so this is an upper
    # bound on the number of optimizer steps.
    total_steps = max(1, math.ceil(len(train_dataloader) * epochs / batch_size))
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    optimizer.zero_grad()

    passes_since_step = 0

    def apply_update():
        """Apply the accumulated gradients and reset the accumulation counter."""
        nonlocal passes_since_step
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        passes_since_step = 0

    def run_pass(batch):
        """Forward/backward on one packed batch; step every batch_size passes. Returns the loss."""
        nonlocal passes_since_step
        batch = batch.to(device)
        loss = model(batch, labels=batch)[0]
        loss.backward()
        passes_since_step += 1
        if passes_since_step == batch_size:
            apply_update()
        return loss.item()

    for epoch in range(epochs):
        print(f"Training epoch {epoch}")

        epoch_loss = 0.0
        n_passes = 0
        input_tensor = None

        for entry in tqdm(train_dataloader):
            input_tensor, carry_on, remainder = pack_tensor(entry, input_tensor, max_seq_len)
            if carry_on:
                continue
            epoch_loss += run_pass(input_tensor)
            n_passes += 1
            # The entry that did not fit starts the next pack instead of being dropped.
            input_tensor = remainder

        # Train on whatever is still packed when the epoch's data runs out.
        if input_tensor is not None:
            epoch_loss += run_pass(input_tensor)
            n_passes += 1

        avg_loss = epoch_loss / max(n_passes, 1)

        # Log the average loss per packed forward pass
        wandb.log({"epoch": epoch, "loss": avg_loss})

        if save_model_on_epoch:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch}.pt"),
            )

        print(f"Epoch {epoch} loss: {avg_loss}")

    if passes_since_step:
        apply_update()

    wandb.finish()

    return model

In [None]:
model = train(dataset=dataset, model=model, tokenizer=tokenizer)  # train() returns the fine-tuned model

In [None]:
def generate(
    model,
    tokenizer,
    prompt,
    entry_count=10,
    entry_length=30, #maximum number of generated tokens (not words)
    top_p=0.8,
    temperature=1.,
):
    """Sample ``entry_count`` continuations of ``prompt`` with nucleus (top-p) sampling.

    Each continuation is extended one token at a time for at most ``entry_length`` tokens,
    stopping early when a ``<|endoftext|>`` token is sampled. Generation runs on the
    device holding the model's parameters (CPU if that cannot be determined). The model is
    put in eval mode.

    Args:
        model: causal LM; ``model(ids)[0]`` must be logits shaped (batch, seq, vocab).
        tokenizer: object with ``encode(str) -> list[int]`` and ``decode(list[int]) -> str``.
        prompt: text to continue. If the model exposes ``config.n_positions``, only the
            last ``n_positions - entry_length`` prompt tokens are kept so prompt plus
            continuation fits in the model's context.
        entry_count: number of independent samples.
        entry_length: maximum number of tokens generated per sample.
        top_p: cumulative-probability cutoff for nucleus sampling.
        temperature: logits are divided by this; values <= 0 are treated as 1.0.

    Returns:
        list[str] of length ``entry_count``. Each item is the decoded (possibly truncated)
        prompt plus continuation; samples that hit ``entry_length`` without producing
        ``<|endoftext|>`` have ``"<|endoftext|>"`` appended.

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    model.eval()
    generated_list = []

    filter_value = -float("Inf")
    eos_ids = set(tokenizer.encode("<|endoftext|>"))

    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        raise ValueError("prompt encodes to zero tokens; nothing to condition generation on")
    n_positions = getattr(getattr(model, "config", None), "n_positions", None)
    if n_positions is not None and n_positions > entry_length:
        prompt_ids = prompt_ids[-(n_positions - entry_length):]

    with torch.no_grad():

        for _ in trange(entry_count):

            entry_finished = False
            generated = torch.tensor(prompt_ids, device=device).unsqueeze(0)

            for _ in range(entry_length):
                # No labels: only the logits are needed, not a loss.
                logits = model(generated)[0]
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Drop tokens beyond the top_p mass; shift right so the token crossing the
                # threshold is kept, and always keep the most likely token.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False

                logits[:, sorted_indices[sorted_indices_to_remove]] = filter_value

                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                if next_token.item() in eos_ids:
                    entry_finished = True
                    break

            output_text = tokenizer.decode(generated[0].tolist())
            generated_list.append(output_text if entry_finished else f"{output_text}<|endoftext|>")

    return generated_list

In [None]:
def text_generation(test_data):
    """Generate one continuation for each prompt in ``test_data['Lyric']``.

    Uses the notebook-global ``model`` and ``tokenizer``. The model is moved to CPU once,
    before the loop, rather than on every row.

    Args:
        test_data: DataFrame with a ``'Lyric'`` column of prompt strings. Rows are visited
            in order, so any index works (not only 0..n-1).

    Returns:
        list with one entry per row: the list returned by ``generate(..., entry_count=1)``.
    """
    cpu_model = model.to('cpu')
    return [
        generate(cpu_model, tokenizer, prompt, entry_count=1)
        for prompt in test_data['Lyric']
    ]

In [None]:
# Generate a continuation for every held-out prompt (one list per test song).
generated_lyrics = text_generation(test_data=test_set)

In [None]:
test_set.loc[1, 'Lyric']  # prompt given to the model for test song #1

In [None]:
test_set.loc[1, 'TrueFinalLyric']  # held-out real ending of test song #1

In [None]:
# Model output for test song #1: the one-element list that generate(..., entry_count=1)
# returned for that prompt (decoded prompt followed by the sampled continuation).
generated_lyrics[1]