In [2]:
import torch
from datasets import load_dataset
from transformers import TextDataset, Trainer, TrainingArguments, DataCollatorForLanguageModeling, GPT2LMHeadModel, GPT2Tokenizer, GPT2LMHeadModel, AdamW
import os, sys, json
from nanoid import generate
import string
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from tqdm.notebook import tqdm

In [10]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Load the Minecraft server chat corpus.
dataset = load_dataset("declip/Minecraft-Server-Chat")

# The freshly loaded object is expected to print as:
#
#   DatasetDict({
#       train: Dataset({
#           features: ['content', 'date', 'username'],
#           num_rows: 2664797
#       })
#   })
print(dataset)

# Work on the "train" split only, limited to the first 10,000 rows for
# initial testing.
dataset = dataset['train'].select(range(10000))
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['content', 'date', 'username'],
        num_rows: 2664797
    })
})
Dataset({
    features: ['content', 'date', 'username'],
    num_rows: 10000
})


In [4]:
"""clean the data
    - cap content at 100 characters and trim whitespace
    - reject rows whose content is empty
    - anonymize usernames
"""
# Real username -> pseudonym, filled lazily by get_anon_username().
username_map = {}
# Every pseudonym handed out so far, kept as a set for O(1) collision checks.
_assigned_anon_ids = set()


def get_anon_username(username):
    """Return a stable pseudonym of the form ``P<digits>`` for ``username``.

    The first call for a given username draws a random 5-digit ID and
    remembers it in ``username_map``; later calls return the same value.

    A random 5-digit ID has only 100,000 possible values, so two different
    users can draw the same one. To keep pseudonyms unique, a draw that is
    already taken is retried with one more digit each time (which also
    guarantees the loop cannot spin forever on a nearly full ID space).

    Args:
        username: The original username (used as a dictionary key).

    Returns:
        The pseudonym string assigned to ``username``.
    """
    if username not in username_map:
        id_length = 5
        candidate = f"P{generate(string.digits, id_length)}"
        while candidate in _assigned_anon_ids:
            id_length += 1
            candidate = f"P{generate(string.digits, id_length)}"
        _assigned_anon_ids.add(candidate)
        username_map[username] = candidate
    return username_map[username]
def clean_data(data):
    """Normalise one chat row in place and anonymise its author.

    The message is trimmed on both sides *before* it is cut to 100
    characters, so leading whitespace neither survives nor uses up the
    length budget; the cut result is right-trimmed again because the cut can
    land just after a space. A missing (``None``) message is treated as
    empty.

    Args:
        data: Mutable mapping with at least ``'content'`` and ``'username'``
            keys. It is modified in place.

    Returns:
        The same ``data`` object with cleaned ``'content'`` and a pseudonymous
        ``'username'``, or ``None`` when the cleaned content is empty (in
        which case ``'username'`` is left untouched).
    """
    raw_content = data['content'] or ''
    data['content'] = raw_content.strip()[:100].rstrip()
    if not data['content']:
        return None
    data['username'] = get_anon_username(data['username'])
    return data


In [5]:
def format_dataset(example):
    """Render one dataset row as a ``"<username>: <content>"`` chat line.

    The row is first passed through ``clean_data``; when that rejects the
    row, ``None`` is returned instead of a string.
    """
    cleaned = clean_data(example)
    if cleaned:
        return f"{cleaned['username']}: {cleaned['content']}"
    return None


def _with_text(row):
    """Build the extra ``text`` column for one row (``None`` if the row was rejected)."""
    return {'text': format_dataset(row)}


def _has_text(row):
    """Tell whether a row carries a ``text`` value that is not ``None``."""
    return "text" in row and row["text"] is not None


# Add the formatted chat line to every row, then discard rows without one
formatted_dataset = dataset.map(_with_text).filter(_has_text)

# Join all messages into a single text block, one message per line
all_text = "\n".join(formatted_dataset['text'])

In [6]:
# Write all_text to data/minecraft_chat.txt
file_path = "data/minecraft_chat.txt"

# Derive the directory from file_path so the two cannot drift apart.
os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)

# Pin the encoding: without it open() uses the platform default, which can
# differ between machines and may be unable to encode non-ASCII chat text.
with open(file_path, "w", encoding="utf-8") as f:
    f.write(all_text)

In [19]:
# Initialize the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')

# Reuse the end-of-sequence token for padding when the tokenizer has no
# dedicated padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Pack consecutive chat lines into groups of at most `max_length` tokens.
max_length = 1024
current_length = 0
current_group = []
grouped_lines = []

for line in tqdm(all_text.split("\n")):
    line = line.strip()
    if not line:
        # Nothing to encode; an empty line contributes no tokens.
        continue

    # Re-attach the newline that split() removed so that consecutive messages
    # stay delimited inside a group instead of being fused together. The slice
    # guarantees a single line can never exceed the group limit by itself.
    encoded_line = tokenizer.encode(line + "\n")[:max_length]

    # Close the running group when this line would overflow it. The
    # `current_group` check avoids emitting an empty group.
    if current_group and current_length + len(encoded_line) > max_length:
        grouped_lines.append(current_group)
        current_group = []
        current_length = 0

    current_group.extend(encoded_line)
    current_length += len(encoded_line)

# Flush the last, partially filled group.
if current_group:
    grouped_lines.append(current_group)

# Custom Dataset
class MinecraftChatDataset(Dataset):
    """Map-style dataset over pre-tokenized groups of chat lines.

    Args:
        grouped_lines: Sequence of token-id lists, one list per sample.
    """

    def __init__(self, grouped_lines):
        self.grouped_lines = grouped_lines

    def __len__(self):
        """Return the number of token groups."""
        return len(self.grouped_lines)

    def __getitem__(self, idx):
        """Return group ``idx`` as a 1-D ``torch.long`` tensor on the CPU.

        Samples are deliberately kept on the host: moving data to an
        accelerator is left to the code that assembles batches, which keeps
        this class independent of any global device and usable from
        DataLoader worker processes.
        """
        return torch.tensor(self.grouped_lines[idx], dtype=torch.long, device="cpu")

# Create dataset and split into train and validation (80% / 20%).
# NOTE: this rebinds `dataset`, which earlier referred to the raw chat rows.
dataset = MinecraftChatDataset(grouped_lines)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same groups land in the validation set on every run;
# otherwise validation losses from different runs are not comparable.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

def collate_fn(batch):
    """Pad a list of 1-D token tensors into model inputs and labels.

    Args:
        batch: List of 1-D ``torch.long`` tensors of varying length.

    Returns:
        ``(input_ids, labels)``, both of shape ``(len(batch), longest)`` and
        moved to ``device``. ``input_ids`` is right-padded with the
        tokenizer's pad token id. ``labels`` holds the same tokens but is
        padded with -100, so that padding positions are excluded from the
        language-modeling loss instead of being learned as targets.
    """
    input_ids = pad_sequence(batch, batch_first=True, padding_value=tokenizer.pad_token_id).to(device)
    # Padding the labels separately (rather than masking by token id) keeps
    # genuine occurrences of the pad/EOS token inside a sample as targets.
    labels = pad_sequence(batch, batch_first=True, padding_value=-100).to(device)
    return input_ids, labels

# Settings shared by both loaders.
loader_options = dict(batch_size=8, collate_fn=collate_fn)

# Each loader gets its own generator created on `device`.
# NOTE(review): presumably needed because the loaders are iterated inside a
# torch.device context in the training cell - confirm before changing.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    generator=torch.Generator(device),
    **loader_options,
)
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    generator=torch.Generator(device),
    **loader_options,
)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [21]:
# Use the `device` chosen earlier instead of hard-coding "cuda", so the cell
# also runs on machines without a GPU.
with torch.device(device):
    # Load pre-trained GPT-2 model
    model = GPT2LMHeadModel.from_pretrained('distilgpt2')
    # Place the weights explicitly rather than relying on the surrounding
    # device context to do it during loading.
    model.to(device)
    model.train()

    # Training parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Training loop
    for epoch in range(3):  # Adjust the number of epochs as needed
        for batch in tqdm(train_loader):
            optimizer.zero_grad()
            inputs, labels = batch
            # Passing labels makes the model return the language-modeling loss.
            outputs = model(inputs, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f"Epoch: {epoch}, Loss: {loss.item()}")

        # Validation step (optional): mean of the per-batch losses.
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for batch in val_loader:
                inputs, labels = batch
                outputs = model(inputs, labels=labels)
                val_loss += outputs.loss.item()
            # A very small corpus can leave the validation split empty; skip
            # the average instead of dividing by zero.
            if len(val_loader) > 0:
                val_loss /= len(val_loader)
                print(f"Validation Loss: {val_loss}")
            else:
                print("Validation skipped: the validation set is empty")
        model.train()

    # Save the fine-tuned model together with its tokenizer
    model.save_pretrained('fine_tuned_gpt2_minecraft')
    tokenizer.save_pretrained('fine_tuned_gpt2_minecraft')



  0%|          | 0/9 [00:00<?, ?it/s]