In [None]:
import re

import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling
from transformers import TrainingArguments, Trainer

# Load the Bob Dylan lyrics corpus (a Kaggle dataset) into a DataFrame.
# The path is relative to the notebook's working directory.
CORPUS_CSV = './bob dylan corpus.csv'
df = pd.read_csv(CORPUS_CSV)

# Preview the first five rows to confirm the file loaded as expected.
df.head()

In [None]:
# Combine all lyrics into a single string, and remove extra new lines.
all_lyrics = "\n".join(df['lyrics'].dropna())
all_lyrics = all_lyrics.replace("\n\n","\n")

# Check the first 500 characters to confirm output is correct
all_lyrics[:500]

In [None]:
# Write the final cleaned text to a file; the training cell reads it back via
# TextDataset. Write it as UTF-8 explicitly. Without `encoding`, open() uses
# the platform's locale encoding (e.g. cp1252 on Windows), which can raise on
# or mangle non-ASCII characters, while transformers' TextDataset reads the
# file back as UTF-8.
with open("all_lyrics.txt", "w", encoding="utf-8") as text_file:
    text_file.write(all_lyrics)

In [None]:
# Fine-tune the pre-trained GPT-2 medium language model on the lyrics file
# written in the previous cell.
model_name = "gpt2-medium"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Build the training set directly from the lyrics file written above:
# TextDataset tokenizes the file and cuts the token stream into blocks of
# `block_size` tokens.
# NOTE(review): TextDataset is deprecated in recent transformers releases;
# consider migrating to the `datasets` library.
train_dataset = TextDataset(
    file_path="./all_lyrics.txt",
    tokenizer=tokenizer,
    block_size=128,
)

# mlm=False selects causal (next-token) language modelling, the objective GPT-2
# is trained with, instead of BERT-style masked-token prediction.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training hyperparameters; tune epochs and batch size to the available
# hardware (gpt2-medium at batch size 32 may not fit on smaller GPUs).
# NOTE(review): with a single epoch, save_steps=10_000 may never be reached,
# so no checkpoint may be written to ./results; the fine-tuned weights would
# then exist only in memory. Consider trainer.save_model() after training.
training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=32,
    output_dir="./results",
    overwrite_output_dir=True,
    save_steps=10_000,
    save_total_limit=2,
)

# Wire model, arguments, data and collator together, then run fine-tuning.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
trainer.train()

In [None]:
# Prompt the fine-tuned model with a line from a SpongeBob joke and print five
# sampled continuations.
input_text = "In a cosmic sort of way"

# After trainer.train() the model is typically still in training mode (dropout
# active) and, on a GPU machine, lives on the GPU. Switch to eval mode so
# sampling does not run through dropout. Move the prompt tensors to the model's
# device so generate() does not fail with a CPU/CUDA device mismatch.
model.eval()
encoded = tokenizer(input_text, return_tensors='pt').to(model.device)
input_ids = encoded["input_ids"]

# GPT-2 has no pad token. Passing the attention mask and pad_token_id
# explicitly (EOS is what generate() falls back to) silences generate()'s
# warnings without changing the result.
output = model.generate(
    input_ids,
    attention_mask=encoded["attention_mask"],
    pad_token_id=tokenizer.eos_token_id,
    max_length=100,  # total length in tokens, prompt included
    num_return_sequences=5,
    temperature=0.9,
    do_sample=True,
)

# skip_special_tokens drops the <|endoftext|> padding appended to sequences
# that finished before max_length.
for i, seq in enumerate(output):
    print(f"Generated Text {i+1}: {tokenizer.decode(seq, skip_special_tokens=True)}")
    print()