In [1]:
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import Dataset
import mlflow
import torch
from transformers import Trainer, TrainingArguments
from transformer_lens import utils, HookedTransformer
from transformers import (
    GPT2Tokenizer, 
    GPT2LMHeadModel, 
    DataCollatorForLanguageModeling, 
    Trainer, 
    TrainingArguments
)
import mlflow
from transformers.integrations import MLflowCallback

from create_dataset import make_rows_from_chat

file_name = "artifacts/input_text.txt"

with open(file_name, "r") as file_read:
    chat = file_read.read()

chat_rows = make_rows_from_chat(chat)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
def make_generator(data, tokenizer) :
    for x in data : 
        tokens = tokenizer(x, truncation=True, padding=True)['input_ids']
        for i in range(1,len(tokens), 1) :
            x, y = tokens[:i], tokens[i]
            yield {"text":tokenizer.decode(x), "target":tokenizer.decode(y)}
ds = Dataset.from_generator(make_generator,gen_kwargs={'data':chat_rows[-3000:],
'tokenizer':tokenizer})
ds

Dataset({
    features: ['text', 'target'],
    num_rows: 62125
})

In [3]:
# Initialize the model
model = GPT2LMHeadModel.from_pretrained('gpt2')

for param in model.parameters():
    param.requires_grad = False

for param, name in zip(model.parameters(),model.named_parameters()) :
    if not re.match("^transformer.h.11.",name[0]) :
        continue
    param.requires_grad = True

input_text = "Hello, would you mind"
input_ids = tokenizer.encode(input_text, return_tensors='pt')
def next_token(input_text):
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(input_ids)

    # Extract logits for the last token
    logits = outputs.logits
    last_token_logits = logits[0, -1, :]

    # Apply softmax to get probabilities
    probabilities = torch.softmax(last_token_logits, dim=-1)

    # Find the predicted token id
    predicted_token_id = torch.argmax(probabilities).item()

    # Decode the predicted token to get the word
    predicted_token = tokenizer.decode([predicted_token_id])
    return predicted_token

print(input_text)
for i in range(15):
    add = next_token(input_text)
    input_text += add
    print(input_text)

Hello, would you mind
Hello, would you mind if
Hello, would you mind if I
Hello, would you mind if I could
Hello, would you mind if I could take
Hello, would you mind if I could take a
Hello, would you mind if I could take a moment
Hello, would you mind if I could take a moment to
Hello, would you mind if I could take a moment to explain
Hello, would you mind if I could take a moment to explain to
Hello, would you mind if I could take a moment to explain to you
Hello, would you mind if I could take a moment to explain to you how
Hello, would you mind if I could take a moment to explain to you how I
Hello, would you mind if I could take a moment to explain to you how I got
Hello, would you mind if I could take a moment to explain to you how I got here
Hello, would you mind if I could take a moment to explain to you how I got here?


In [9]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, padding=True, return_tensors='pt')

tokenized_datasets = ds.map(tokenize_function, batched=True)

split_ds = tokenized_datasets.train_test_split(test_size=0.1)
train_ds = split_ds["train"]
test_ds = split_ds["test"]

model.to("mps")

training_args = TrainingArguments(
    output_dir='./results',
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=8,  # Adjust batch size as needed
    save_steps=10_000,
    save_total_limit=2,
    learning_rate=0.001,
    logging_dir='./logs',
    logging_steps=200,
    report_to="none"  # Disable reporting to W&B
)

# Set up MLflow
mlflow.set_tracking_uri("./outputs")  # Replace with your tracking URI
mlflow.set_experiment("gpt2_fine_tuning")

# Initialize Trainer with MLflow callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    data_collator=data_collator,
    callbacks=[MLflowCallback()],
)

# Start training
with mlflow.start_run():
    trainer.train()
    eval_metrics = trainer.evaluate(eval_dataset=test_ds)
    for key, value in eval_metrics.items():
        mlflow.log_metric(key, value)
    mlflow.pyfunc.log_model(model)

  0%|          | 10/17475 [1:48:35<3161:06:24, 651.59s/it]
  0%|          | 0/34945 [1:17:01<?, ?it/s]
  1%|          | 200/34945 [26:45<678:57:20, 70.35s/it]  

{'loss': 4.9893, 'grad_norm': 0.6418070197105408, 'learning_rate': 0.0009942767205608815, 'epoch': 0.03}


  1%|          | 400/34945 [1:33:09<21:10:35,  2.21s/it]   

{'loss': 4.0992, 'grad_norm': 0.9630653858184814, 'learning_rate': 0.000988553441121763, 'epoch': 0.06}


  2%|▏         | 600/34945 [2:41:10<38:41:33,  4.06s/it]   

{'loss': 4.2117, 'grad_norm': nan, 'learning_rate': 0.0009828301616826441, 'epoch': 0.09}


  2%|▏         | 800/34945 [2:51:25<27:51:54,  2.94s/it]

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0009771068822435256, 'epoch': 0.11}


  2%|▏         | 862/34945 [3:22:02<148:56:08, 15.73s/it]  

KeyboardInterrupt: 

In [8]:
import os
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"]="0.1"

In [11]:
# see how it changed
model.to("cpu")
input_text = "Hello, would you mind"
print(input_text)
for i in range(15):
    add = next_token(input_text)
    input_text += add
    print(input_text)

Hello, would you mind
Hello, would you mind!
Hello, would you mind!!
Hello, would you mind!!!
Hello, would you mind!!!!
Hello, would you mind!!!!!
Hello, would you mind!!!!!!
Hello, would you mind!!!!!!!
Hello, would you mind!!!!!!!!
Hello, would you mind!!!!!!!!!
Hello, would you mind!!!!!!!!!!
Hello, would you mind!!!!!!!!!!!
Hello, would you mind!!!!!!!!!!!!
Hello, would you mind!!!!!!!!!!!!!
Hello, would you mind!!!!!!!!!!!!!!
Hello, would you mind!!!!!!!!!!!!!!!
