In [12]:
import torch
from datasets import load_dataset
from transformers import TextDataset, Trainer, TrainingArguments, DataCollatorForLanguageModeling,GPT2Tokenizer, GPT2LMHeadModel
import os, sys, json
from nanoid import generate
import string

In [13]:
# Load the dataset. The hub repo is a DatasetDict with a single "train" split
# of rows with 'content', 'date' and 'username' features (2,664,797 rows when
# this notebook was written — see the printed repr in the output).
N_SAMPLES = 10_000  # rows kept for quick experiments; set to None to use the full split

raw_datasets = load_dataset("declip/Minecraft-Server-Chat")
print(raw_datasets)

dataset = raw_datasets['train']
if N_SAMPLES is not None:
    # select() raises on out-of-range indices, so clamp to the split size.
    dataset = dataset.select(range(min(N_SAMPLES, len(dataset))))
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['content', 'date', 'username'],
        num_rows: 2664797
    })
})
Dataset({
    features: ['content', 'date', 'username'],
    num_rows: 10000
})


In [14]:
# Clean the data:
#   - collapse whitespace (including embedded newlines) and cap content at 100 chars
#   - drop rows whose content is empty after cleaning
#   - anonymize usernames
MAX_CONTENT_LEN = 100  # maximum characters of message content kept per row

# Original username -> anonymous alias. Shared across all rows so each player
# keeps a single alias for the whole dataset.
username_map = {}


def get_anon_username(username):
    """Return the stable anonymous alias for ``username``.

    Aliases are assigned sequentially in order of first appearance
    ("P00001", "P00002", ...). Unlike independently drawn random IDs, two
    distinct usernames can never receive the same alias, and the mapping is
    reproducible across runs. Results are cached in ``username_map``.
    """
    alias = username_map.get(username)
    if alias is None:
        alias = f"P{len(username_map) + 1:05d}"
        username_map[username] = alias
    return alias


def clean_data(data, max_len=MAX_CONTENT_LEN):
    """Normalize one chat row in place.

    Args:
        data: Row mapping with at least 'content' and 'username' keys.
        max_len: Maximum number of characters of content to keep.

    Returns:
        The same mapping with cleaned 'content' and anonymized 'username',
        or None when the content is missing or empty after cleaning.

    NOTE(review): only the 'username' field is anonymized; player names
    mentioned inside the message content are left as-is.
    """
    content = data.get('content') or ''
    # Collapse every whitespace run (including newlines) to a single space:
    # the output file is one "<alias>: <message>" per line, so an embedded
    # newline would split a message and orphan its tail from its author.
    # The trailing rstrip() handles a cut that lands right after a space.
    content = " ".join(content.split())[:max_len].rstrip()
    if not content:
        return None
    data['content'] = content
    data['username'] = get_anon_username(data['username'])
    return data


In [15]:
# Turn each raw row into a single "<alias>: <message>" line.
def format_dataset(example):
    """Clean ``example`` via clean_data and render it as "<username>: <content>".

    Returns None when clean_data rejects the row (e.g. empty content).
    """
    cleaned = clean_data(example)
    if cleaned:
        return f"{cleaned['username']}: {cleaned['content']}"
    return None


# Format every row, then drop the rows that came back as None. The map always
# adds a "text" column, so checking for None is sufficient.
formatted_dataset = dataset.map(lambda row: {'text': format_dataset(row)})
formatted_dataset = formatted_dataset.filter(lambda row: row["text"] is not None)

# One message per line; this block of text is what gets written to disk.
all_text = "\n".join(formatted_dataset['text'])

In [16]:
# Write the newline-separated chat log to data/minecraft_chat.txt so it can be
# tokenized for training below.
os.makedirs("data", exist_ok=True)
# Explicit UTF-8: chat messages may contain non-ASCII characters that the
# platform's default encoding (e.g. cp1252 on Windows) cannot encode, and it
# keeps the file readable as UTF-8 regardless of where it was written.
with open("data/minecraft_chat.txt", "w", encoding="utf-8") as f:
    f.write(all_text)

In [None]:
# Deliberate stop: raising here halts "Run All" after the data-preparation
# cells above, so the model-loading and training cells below only run when
# executed manually.
raise Exception("Done")

In [6]:
# Load the pretrained "gpt2" checkpoint together with its matching tokenizer;
# both come from the same hub name so their vocabularies agree.
print("Initializing model...")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

Initializing model...


In [7]:

# Batch builder for training. mlm=False selects a plain (causal)
# language-modeling objective instead of masked-LM token masking.
print("Creating data collator...")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


Creating data collator...


In [8]:

# Tokenize the chat log into fixed-length token blocks for causal-LM training.
# NOTE(review): TextDataset is deprecated in recent transformers releases;
# consider migrating to a `datasets`-based tokenize-and-group pipeline.
print("Creating dataset...")
BLOCK_SIZE = 1024  # tokens per training example; must not exceed the model's max context length
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="data/minecraft_chat.txt",
    block_size=BLOCK_SIZE,
    # TextDataset caches tokenized features next to the input file and, by
    # default, reuses that cache even after the text file has been regenerated
    # (e.g. with a different row limit or cleaning rules). Force a rebuild so
    # training always reflects the current file; re-tokenizing is cheap for
    # the sampled data but can take a while on the full split.
    overwrite_cache=True,
)


Creating dataset...




In [9]:
# Define training arguments
print("Defining training arguments...")
training_args = TrainingArguments(
    output_dir="./data/gpt2-minecraft",
    overwrite_output_dir=True,
    num_train_epochs=3,
    # Keep the effective batch at 128 sequences per device
    # (per_device_train_batch_size * gradient_accumulation_steps = 8 * 16)
    # while only materializing 8 sequences of up to 1024 tokens at a time.
    # 128 full-length GPT-2 sequences in one forward/backward pass needs far
    # more memory than a typical single GPU provides.
    # NOTE(review): tune the 8/16 split to the available GPU memory.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=16,
    prediction_loss_only=True,
    logging_dir="./data/logs",
    learning_rate=2e-4,
)

# Initialize the Trainer
print("Initializing trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

# Train the model
print("Training the model...")
trainer.train()

# Persist the fine-tuned weights to output_dir. Step-based checkpointing at
# the default interval may never fire on a short run, so save explicitly. The
# tokenizer is saved alongside so the directory can be reloaded on its own.
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)

Defining training arguments...
Initializing trainer...
Training the model...


  0%|          | 0/3 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Function to generate a response
def generate_response(prompt, max_length=150):
    """Continue ``prompt`` using the global ``model`` and ``tokenizer``.

    Decoding is beam-search multinomial sampling: 5 beams with
    ``do_sample=True``, temperature 0.7, top-k 50 and top-p 0.95.

    Args:
        prompt: Text to continue, e.g. an "<alias>: <message>\\n" chat line.
        max_length: Maximum total length in tokens, including the prompt.

    Returns:
        The decoded best sequence (prompt included) with special tokens removed.
    """
    # After trainer.train() the model may still be in training mode, which
    # keeps dropout active; switch to eval mode for generation.
    model.eval()
    # Calling the tokenizer (rather than .encode) also returns an
    # attention_mask. generate() needs it because the pad token is set to the
    # EOS token, so padding cannot be inferred from the ids alone.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=1,
        num_beams=5,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example usage (single quotes inside the f-string keep it valid before Python 3.12)
prompt = f"{get_anon_username('Player1')}: Hey, how do you build a house?\n"
response = generate_response(prompt)
print(response)