In [1]:
from pruning import *
from lw_retrain_utils import *
from evaluation import *
import json
import copy 
from datasets import load_dataset
import os
import gc
import shutil

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "openai-community/gpt2-medium"
model, tokenizer = load_model(model_name)

In [3]:
dataset = load_dataset("stas/openwebtext-10k", trust_remote_code=True)
calibration_pass(model=model,
                 tokenizer=tokenizer,
                 dataset=dataset,
                 sample_size=128,
                 batch_size=4,)

100%|██████████| 32/32 [00:09<00:00,  3.47it/s]


In [None]:
tokenized_dataset = tokenize_dataset(tokenizer, dataset)

os.makedirs("./saved_metrics", exist_ok=True)

training_metrics_path = "./saved_metrics/training_metrics.json"
eval_metrics_path = "./saved_metrics/eval_metrics.json"

if os.path.exists(training_metrics_path):
    with open(training_metrics_path, "r") as f:
        training_metrics = json.load(f)
else:
    training_metrics = {}

if os.path.exists(eval_metrics_path):
    with open(eval_metrics_path, "r") as f:
        eval_metrics = json.load(f)
else:
    eval_metrics = {}

num_heads = 12
mult_hidden = 2.5
embed_size = 1024
prune_model_width(model, int(mult_hidden * embed_size), num_heads, embed_size)
print(sum(t.numel() for t in model.parameters()))
param_key = f"num_heads={num_heads}_mlp_exp={mult_hidden}_embed_size={embed_size}"
gc.collect()

if param_key+"_before_training" in eval_metrics:
    print(f"Skipping evaluation for {param_key}, already exists.")
else:
    print(f"Evaluating perplexity for {param_key} before training...")
    eval_metrics[param_key+"_before_training"] = evaluate_perplexity(model, tokenizer, stride=1024).item()

    with open(eval_metrics_path, "w") as f:
        json.dump(eval_metrics, f, indent=4)
if param_key in training_metrics:
    print(f"Skipping training for {param_key}, already exists.")
else:
    print(f"Training model for {param_key}...")

    
trainer = trainer_gpt2(model, tokenizer, tokenized_dataset, batch_size=4, num_epochs=2, lr=2e-4) #, output_dir=f"./saved_models/{param_key}")
trainer.train()
    # training_metrics[param_key] = trainer.state.log_history

    # with open(training_metrics_path, "w") as f:
    #     json.dump(training_metrics, f, indent=4)
    # torch.cuda.empty_cache()
    # gc.collect()
    
if param_key in eval_metrics:
    print(f"Skipping evaluation for {param_key}, already exists.")
else:
    print(f"Evaluating perplexity for {param_key}...")
    eval_metrics[param_key] = evaluate_perplexity(model, tokenizer, stride=1024).item()

    with open(eval_metrics_path, "w") as f:
        json.dump(eval_metrics, f, indent=4)

torch.cuda.empty_cache()
gc.collect()

254104576
Evaluating perplexity for num_heads=12_mlp_exp=2.5_embed_size=1024 before training...


Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors
  0%|          | 0/281 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
100%|█████████▉| 280/281 [00:04<00:00, 57.01it/s]


Skipping training for num_heads=12_mlp_exp=2.5_embed_size=1024, already exists.


Step,Training Loss,Validation Loss
30,4.7195,4.164433


In [None]:
output_dir=f"./saved_models/{param_key}"

output_dir = f"./saved_models/{param_key}"
os.makedirs(output_dir, exist_ok=True)

torch.save(model, f"{output_dir}/model_lw_retrain.pth")
torch.save(model.state_dict(), f"{output_dir}/model_lw_retrain_state_dict.pth")

## Knowledge Distillation

In [None]:
import torch
from transformers import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
import knowledge_distillation as kd


# import knowledge_distillation  # Import the module
from evaluation import *

In [None]:
model_name = "openai-community/gpt2-medium"
teacher_model, tokenizer = load_model(model_name)
teacher_model.eval()
print('Models loaded successfully')

In [None]:
dataset = load_dataset("deven367/babylm-10M-cbt", trust_remote_code=True)
train_dataset = dataset['train']
val_dataset = dataset['valid'].select(range(500))

In [None]:

def tokenize_function(examples):
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )

    # Create input_ids and labels for language modeling
    input_ids = tokenized["input_ids"]
    labels = input_ids.clone()

    return {
        "input_ids": input_ids,
        "attention_mask": tokenized["attention_mask"],
        "labels": labels
    }
    
print("Tokenizing dataset...")
tokenized_train_datasets = train_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,  # Smaller batch size for mapping
    remove_columns=train_dataset.column_names,
    desc="Processing examples",
    load_from_cache_file=False  # Disable caching for debugging
)
tokenized_val_datasets = val_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,  # Smaller batch size for mapping
    remove_columns=val_dataset.column_names,
    desc="Processing examples",
    load_from_cache_file=False  # Disable caching for debugging
)
tokenized_train_datasets.set_format("torch")
tokenized_val_datasets.set_format("torch")

In [None]:
steps, train_losses, val_losses = kd.train_kd(model, teacher_model, tokenized_train_datasets, tokenized_val_datasets, 
                                            batch_size=4, num_epochs=2, accumulation_steps=8, lr=5e-5, 
                                            temperature=1.0, device='cuda', log_interval=10, val_interval=250)

In [None]:
metrics_path = "saved_metrics/kd_metrics.json"
model_name = "width"

# Load existing metrics if the file exists, otherwise start with an empty dict
if os.path.exists(metrics_path):
    with open(metrics_path, "r") as f:
        kd_metrics = json.load(f)
else:
    kd_metrics = {}

kd_metrics.setdefault(model_name, {})[model_name] = {"steps": steps, "train_losses": train_losses, "val_losses": val_losses}

# Save updated metrics back to the file
os.makedirs("saved_metrics", exist_ok=True)
with open(metrics_path, "w") as f:
    json.dump(kd_metrics, f, indent=4)

In [None]:
evaluate_perplexity(model, tokenizer, stride=1024).item()

In [None]:
test_dataset = load_dataset("ptb_text_only", split="test")
test_dataset = test_dataset.rename_columns({"sentence": "text"})
evaluate_perplexity(model, tokenizer, test=test_dataset, stride=512)

In [None]:
from itertools import islice
from datasets import Dataset

bookcorpus = load_dataset("bookcorpus", split="train", streaming=True, trust_remote_code=True)
bookcorpus_test = list(islice(bookcorpus, 10_000))
bookcorpus_test_dataset = Dataset.from_list(bookcorpus_test)

evaluate_perplexity(model, tokenizer, test=bookcorpus_test_dataset, stride=128)

In [None]:
test_dataset = load_dataset("lambada", split="test", trust_remote_code=True)
evaluate_perplexity(model, tokenizer, test=test_dataset, stride=1024)