In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, TrainerCallback
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt
import time

In [2]:
# Step 1: Load the dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
print(f"dataset shape : {dataset.shape}")

dataset shape : {'test': (4358, 1), 'train': (36718, 1), 'validation': (3760, 1)}


In [3]:
# Step 2: Load pre-trained GPT-2 model and tokenizer 
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Tokenizer padding
tokenizer.pad_token = tokenizer.eos_token



In [4]:
# Test generation before fine-tuning
def generate_text_before(prompt, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(inputs["input_ids"], max_length=max_length, num_return_sequences=1)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [5]:
prompt = "A PhD student is excited to join Huawei research team"
print("Before fine-tuning:")
print(generate_text_before(prompt))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Before fine-tuning:
A PhD student is excited to join Huawei research team and work with them on a new project.

"We are excited to be working with Huawei on a new project that will allow us to develop a new wireless technology that will enable us to deliver


In [6]:
# Step 3: Tokenize the text
def tokenize_function(examples):
    #1024 is set as the max_length to utilizes the full context window of GPT-2,
    #  which is better for understanding long sequences of text, however it uses much memory comparing to 512
    tokenized = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
    #Use input_ids as labels for training
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

In [7]:
#For effeciency and simplicity the text columns is removed
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_datasets.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

In [8]:
# Custom callback to log losses
class LossLogger(TrainerCallback):
    def __init__(self):
        self.train_losses = []
        self.eval_losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            # Log training loss
            if 'loss' in logs:
                self.train_losses.append(logs['loss'])

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is not None:
            # Log evaluation loss
            if 'eval_loss' in metrics:
                self.eval_losses.append(metrics['eval_loss'])

# Initialize the loss logger
loss_logger = LossLogger()

In [9]:
# Step 5: Training arguments for fine-tuning
training_args = TrainingArguments(
    output_dir="./gpt2-finetuned",
    overwrite_output_dir=True,
    evaluation_strategy="steps",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,#7
    logging_dir="./logs",
    logging_steps=50,
    eval_steps=50,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),  # Enables mixed precision if GPU supports it
    save_steps=500
)



In [10]:
# Step 6: Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"]
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [11]:
def print_gpu_memory_usage():
    if torch.cuda.is_available():
        print(f"Allocated: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
        print(f"Cached: {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB")

In [12]:
# Step 7: Fine-tune the model

# Start timing
start_time = time.perf_counter()
print_gpu_memory_usage()
trainer.train()
print_gpu_memory_usage()
# End timing
end_time = time.perf_counter()

# Calculate training time
training_time = end_time - start_time
print(f"Training Time: {training_time:.2f} seconds")

Allocated: 487.47 MB
Cached: 542.00 MB


  0%|          | 0/27540 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 2.0776, 'grad_norm': 1.0430114269256592, 'learning_rate': 4.9923747276688455e-05, 'epoch': 0.01}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4658660590648651, 'eval_runtime': 294.258, 'eval_samples_per_second': 12.778, 'eval_steps_per_second': 3.194, 'epoch': 0.01}
{'loss': 0.3767, 'grad_norm': 1.5032342672348022, 'learning_rate': 4.983297022512709e-05, 'epoch': 0.01}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4543038606643677, 'eval_runtime': 295.1228, 'eval_samples_per_second': 12.74, 'eval_steps_per_second': 3.185, 'epoch': 0.01}
{'loss': 0.4624, 'grad_norm': 1.493654489517212, 'learning_rate': 4.9742193173565725e-05, 'epoch': 0.02}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4504323899745941, 'eval_runtime': 292.9769, 'eval_samples_per_second': 12.834, 'eval_steps_per_second': 3.208, 'epoch': 0.02}
{'loss': 0.4658, 'grad_norm': 1.2388556003570557, 'learning_rate': 4.9651416122004356e-05, 'epoch': 0.02}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.44561636447906494, 'eval_runtime': 294.1831, 'eval_samples_per_second': 12.781, 'eval_steps_per_second': 3.195, 'epoch': 0.02}
{'loss': 0.4, 'grad_norm': 1.8811702728271484, 'learning_rate': 4.9560639070442995e-05, 'epoch': 0.03}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.44519656896591187, 'eval_runtime': 293.9984, 'eval_samples_per_second': 12.789, 'eval_steps_per_second': 3.197, 'epoch': 0.03}
{'loss': 0.4432, 'grad_norm': 0.1982816904783249, 'learning_rate': 4.946986201888163e-05, 'epoch': 0.03}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4427814185619354, 'eval_runtime': 294.0213, 'eval_samples_per_second': 12.788, 'eval_steps_per_second': 3.197, 'epoch': 0.03}
{'loss': 0.4133, 'grad_norm': 0.31772828102111816, 'learning_rate': 4.9379084967320265e-05, 'epoch': 0.04}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.44387614727020264, 'eval_runtime': 293.9832, 'eval_samples_per_second': 12.79, 'eval_steps_per_second': 3.197, 'epoch': 0.04}
{'loss': 0.4972, 'grad_norm': 0.20903582870960236, 'learning_rate': 4.9288307915758896e-05, 'epoch': 0.04}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4424620270729065, 'eval_runtime': 293.9568, 'eval_samples_per_second': 12.791, 'eval_steps_per_second': 3.198, 'epoch': 0.04}
{'loss': 0.4466, 'grad_norm': 1.5317144393920898, 'learning_rate': 4.9197530864197535e-05, 'epoch': 0.05}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.44207483530044556, 'eval_runtime': 293.8847, 'eval_samples_per_second': 12.794, 'eval_steps_per_second': 3.199, 'epoch': 0.05}
{'loss': 0.4653, 'grad_norm': 1.5647974014282227, 'learning_rate': 4.9106753812636166e-05, 'epoch': 0.05}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.44046053290367126, 'eval_runtime': 294.6865, 'eval_samples_per_second': 12.759, 'eval_steps_per_second': 3.19, 'epoch': 0.05}
{'loss': 0.5072, 'grad_norm': 0.8568047285079956, 'learning_rate': 4.90159767610748e-05, 'epoch': 0.06}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43968459963798523, 'eval_runtime': 293.8319, 'eval_samples_per_second': 12.796, 'eval_steps_per_second': 3.199, 'epoch': 0.06}
{'loss': 0.449, 'grad_norm': 0.6725171804428101, 'learning_rate': 4.8925199709513436e-05, 'epoch': 0.07}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.44085580110549927, 'eval_runtime': 292.2246, 'eval_samples_per_second': 12.867, 'eval_steps_per_second': 3.217, 'epoch': 0.07}
{'loss': 0.5358, 'grad_norm': 0.741547703742981, 'learning_rate': 4.8834422657952074e-05, 'epoch': 0.07}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43882766366004944, 'eval_runtime': 294.6885, 'eval_samples_per_second': 12.759, 'eval_steps_per_second': 3.19, 'epoch': 0.07}
{'loss': 0.3852, 'grad_norm': 1.320408821105957, 'learning_rate': 4.8743645606390706e-05, 'epoch': 0.08}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4392436444759369, 'eval_runtime': 293.9373, 'eval_samples_per_second': 12.792, 'eval_steps_per_second': 3.198, 'epoch': 0.08}
{'loss': 0.4922, 'grad_norm': 1.5941028594970703, 'learning_rate': 4.865286855482934e-05, 'epoch': 0.08}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4377095103263855, 'eval_runtime': 292.2422, 'eval_samples_per_second': 12.866, 'eval_steps_per_second': 3.217, 'epoch': 0.08}
{'loss': 0.464, 'grad_norm': 1.3240922689437866, 'learning_rate': 4.8562091503267976e-05, 'epoch': 0.09}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43750420212745667, 'eval_runtime': 293.5146, 'eval_samples_per_second': 12.81, 'eval_steps_per_second': 3.203, 'epoch': 0.09}
{'loss': 0.4144, 'grad_norm': 1.5798345804214478, 'learning_rate': 4.8471314451706614e-05, 'epoch': 0.09}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43705129623413086, 'eval_runtime': 295.9506, 'eval_samples_per_second': 12.705, 'eval_steps_per_second': 3.176, 'epoch': 0.09}
{'loss': 0.4295, 'grad_norm': 1.4635580778121948, 'learning_rate': 4.8380537400145245e-05, 'epoch': 0.1}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43603038787841797, 'eval_runtime': 296.8854, 'eval_samples_per_second': 12.665, 'eval_steps_per_second': 3.166, 'epoch': 0.1}
{'loss': 0.3733, 'grad_norm': 0.6076154708862305, 'learning_rate': 4.828976034858388e-05, 'epoch': 0.1}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43775519728660583, 'eval_runtime': 296.7946, 'eval_samples_per_second': 12.669, 'eval_steps_per_second': 3.167, 'epoch': 0.1}
{'loss': 0.4738, 'grad_norm': 0.19454355537891388, 'learning_rate': 4.8198983297022515e-05, 'epoch': 0.11}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43598562479019165, 'eval_runtime': 294.8921, 'eval_samples_per_second': 12.75, 'eval_steps_per_second': 3.188, 'epoch': 0.11}
{'loss': 0.4164, 'grad_norm': 1.638607382774353, 'learning_rate': 4.8108206245461154e-05, 'epoch': 0.11}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4365951716899872, 'eval_runtime': 293.9611, 'eval_samples_per_second': 12.791, 'eval_steps_per_second': 3.198, 'epoch': 0.11}
{'loss': 0.4652, 'grad_norm': 0.1091577410697937, 'learning_rate': 4.8017429193899785e-05, 'epoch': 0.12}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4363793134689331, 'eval_runtime': 293.6937, 'eval_samples_per_second': 12.802, 'eval_steps_per_second': 3.201, 'epoch': 0.12}
{'loss': 0.5012, 'grad_norm': 1.4645557403564453, 'learning_rate': 4.792665214233842e-05, 'epoch': 0.13}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4357193410396576, 'eval_runtime': 295.3178, 'eval_samples_per_second': 12.732, 'eval_steps_per_second': 3.183, 'epoch': 0.13}
{'loss': 0.4357, 'grad_norm': 0.6720056533813477, 'learning_rate': 4.7835875090777055e-05, 'epoch': 0.13}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4347008168697357, 'eval_runtime': 295.4059, 'eval_samples_per_second': 12.728, 'eval_steps_per_second': 3.182, 'epoch': 0.13}
{'loss': 0.4622, 'grad_norm': 0.8318189382553101, 'learning_rate': 4.774509803921569e-05, 'epoch': 0.14}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.434119313955307, 'eval_runtime': 295.1456, 'eval_samples_per_second': 12.739, 'eval_steps_per_second': 3.185, 'epoch': 0.14}
{'loss': 0.4996, 'grad_norm': 1.236153483390808, 'learning_rate': 4.7654320987654325e-05, 'epoch': 0.14}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4338391125202179, 'eval_runtime': 294.7708, 'eval_samples_per_second': 12.756, 'eval_steps_per_second': 3.189, 'epoch': 0.14}
{'loss': 0.5093, 'grad_norm': 1.193713903427124, 'learning_rate': 4.7563543936092956e-05, 'epoch': 0.15}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4341387152671814, 'eval_runtime': 294.9721, 'eval_samples_per_second': 12.747, 'eval_steps_per_second': 3.187, 'epoch': 0.15}
{'loss': 0.581, 'grad_norm': 1.6505205631256104, 'learning_rate': 4.7472766884531595e-05, 'epoch': 0.15}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4345899224281311, 'eval_runtime': 295.1211, 'eval_samples_per_second': 12.741, 'eval_steps_per_second': 3.185, 'epoch': 0.15}
{'loss': 0.4391, 'grad_norm': 1.704192042350769, 'learning_rate': 4.7381989832970226e-05, 'epoch': 0.16}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43403932452201843, 'eval_runtime': 295.6185, 'eval_samples_per_second': 12.719, 'eval_steps_per_second': 3.18, 'epoch': 0.16}
{'loss': 0.398, 'grad_norm': 0.246282696723938, 'learning_rate': 4.729121278140886e-05, 'epoch': 0.16}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.433960884809494, 'eval_runtime': 295.1897, 'eval_samples_per_second': 12.738, 'eval_steps_per_second': 3.184, 'epoch': 0.16}
{'loss': 0.4218, 'grad_norm': 1.2182611227035522, 'learning_rate': 4.7200435729847496e-05, 'epoch': 0.17}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43404486775398254, 'eval_runtime': 295.3823, 'eval_samples_per_second': 12.729, 'eval_steps_per_second': 3.182, 'epoch': 0.17}
{'loss': 0.4649, 'grad_norm': 0.7265605330467224, 'learning_rate': 4.7109658678286135e-05, 'epoch': 0.17}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4332168400287628, 'eval_runtime': 289.9793, 'eval_samples_per_second': 12.966, 'eval_steps_per_second': 3.242, 'epoch': 0.17}
{'loss': 0.4316, 'grad_norm': 0.39429306983947754, 'learning_rate': 4.7018881626724766e-05, 'epoch': 0.18}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43327122926712036, 'eval_runtime': 290.271, 'eval_samples_per_second': 12.953, 'eval_steps_per_second': 3.238, 'epoch': 0.18}
{'loss': 0.4207, 'grad_norm': 0.20005400478839874, 'learning_rate': 4.69281045751634e-05, 'epoch': 0.19}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4344761073589325, 'eval_runtime': 290.26, 'eval_samples_per_second': 12.954, 'eval_steps_per_second': 3.238, 'epoch': 0.19}
{'loss': 0.4201, 'grad_norm': 0.47727611660957336, 'learning_rate': 4.6837327523602036e-05, 'epoch': 0.19}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43385419249534607, 'eval_runtime': 290.3171, 'eval_samples_per_second': 12.951, 'eval_steps_per_second': 3.238, 'epoch': 0.19}
{'loss': 0.4768, 'grad_norm': 3.1601099967956543, 'learning_rate': 4.674655047204067e-05, 'epoch': 0.2}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.434255450963974, 'eval_runtime': 291.8445, 'eval_samples_per_second': 12.884, 'eval_steps_per_second': 3.221, 'epoch': 0.2}
{'loss': 0.4856, 'grad_norm': 3.4885709285736084, 'learning_rate': 4.6655773420479306e-05, 'epoch': 0.2}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43378737568855286, 'eval_runtime': 296.3672, 'eval_samples_per_second': 12.687, 'eval_steps_per_second': 3.172, 'epoch': 0.2}
{'loss': 0.3944, 'grad_norm': 0.18796782195568085, 'learning_rate': 4.656499636891794e-05, 'epoch': 0.21}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4336297810077667, 'eval_runtime': 291.2641, 'eval_samples_per_second': 12.909, 'eval_steps_per_second': 3.227, 'epoch': 0.21}
{'loss': 0.369, 'grad_norm': 1.1330397129058838, 'learning_rate': 4.6474219317356576e-05, 'epoch': 0.21}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.4336042106151581, 'eval_runtime': 290.6313, 'eval_samples_per_second': 12.937, 'eval_steps_per_second': 3.234, 'epoch': 0.21}
{'loss': 0.4048, 'grad_norm': 0.5846811532974243, 'learning_rate': 4.638344226579521e-05, 'epoch': 0.22}


  0%|          | 0/940 [00:00<?, ?it/s]

{'eval_loss': 0.43322551250457764, 'eval_runtime': 294.075, 'eval_samples_per_second': 12.786, 'eval_steps_per_second': 3.196, 'epoch': 0.22}


KeyboardInterrupt: 

In [None]:
# Access logged losses
train_losses = loss_logger.train_losses
eval_losses = loss_logger.eval_losses

# Plot the losses
plt.figure(figsize=(12, 6))
plt.plot(train_losses, label='Training Loss', color='blue')
plt.plot(range(len(eval_losses)), eval_losses, label='Evaluation Loss', color='orange', marker='o')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.title('Training and Evaluation Losses')
plt.legend()
plt.show()

In [None]:
# Save the fine-tuned model
model.save_pretrained("./gpt2-finetuned")
tokenizer.save_pretrained("./gpt2-finetuned")

In [17]:
# Load the fine-tuned GPT-2 model and tokenizer
fine_tuned_model_path = "./gpt2-finetuned"  # Path to your fine-tuned model

tokenizer = AutoTokenizer.from_pretrained(fine_tuned_model_path)
model = AutoModelForCausalLM.from_pretrained(fine_tuned_model_path)

In [18]:
# Test generation after fine-tuning
def generate_text_after(prompt, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(inputs["input_ids"], max_length=max_length, num_return_sequences=1, attention_mask=inputs["attention_mask"], eos_token_id=None)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
print("After fine-tuning:")
prompt = "A PhD student is excited to join Huawei research team"
print(generate_text_after(prompt))