In [1]:
from pruning import *
from lw_retrain_utils import *
from evaluation import *
import json
import copy 
from datasets import load_dataset
import os
import gc
import shutil

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "openai-community/gpt2-medium"
model, tokenizer = load_model(model_name)

In [3]:
dataset = load_dataset("stas/openwebtext-10k", trust_remote_code=True)
calibration_pass(model=model,
                 tokenizer=tokenizer,
                 dataset=dataset,
                 sample_size=128,
                 batch_size=4,)

100%|██████████| 32/32 [00:10<00:00,  2.96it/s]


In [4]:
tokenized_dataset = tokenize_dataset(tokenizer, dataset)

os.makedirs("./saved_metrics", exist_ok=True)

training_metrics_path = "./saved_metrics/training_metrics.json"
eval_metrics_path = "./saved_metrics/eval_metrics.json"

if os.path.exists(training_metrics_path):
    with open(training_metrics_path, "r") as f:
        training_metrics = json.load(f)
else:
    training_metrics = {}

if os.path.exists(eval_metrics_path):
    with open(eval_metrics_path, "r") as f:
        eval_metrics = json.load(f)
else:
    eval_metrics = {}

num_heads = 12
mult_hidden = 2.5
embed_size = 1024
prune_model_width(model, int(mult_hidden * embed_size), num_heads, embed_size)
print(sum(t.numel() for t in model.parameters()))
param_key = f"num_heads={num_heads}_mlp_exp={mult_hidden}_embed_size={embed_size}"
gc.collect()

if param_key+"_before_training" in eval_metrics:
    print(f"Skipping evaluation for {param_key}, already exists.")
else:
    print(f"Evaluating perplexity for {param_key} before training...")
    eval_metrics[param_key+"_before_training"] = evaluate_perplexity(model, tokenizer, stride=1024).item()

    with open(eval_metrics_path, "w") as f:
        json.dump(eval_metrics, f, indent=4)
if param_key in training_metrics:
    print(f"Skipping training for {param_key}, already exists.")
else:
    print(f"Training model for {param_key}...")

    
trainer = trainer_gpt2(model, tokenizer, tokenized_dataset, batch_size=4, num_epochs=2, lr=2e-4) #, output_dir=f"./saved_models/{param_key}")
trainer.train()
    # training_metrics[param_key] = trainer.state.log_history

    # with open(training_metrics_path, "w") as f:
    #     json.dump(training_metrics, f, indent=4)
    # torch.cuda.empty_cache()
    # gc.collect()
    
if param_key in eval_metrics:
    print(f"Skipping evaluation for {param_key}, already exists.")
else:
    print(f"Evaluating perplexity for {param_key}...")
    eval_metrics[param_key] = evaluate_perplexity(model, tokenizer, stride=1024).item()

    with open(eval_metrics_path, "w") as f:
        json.dump(eval_metrics, f, indent=4)

torch.cuda.empty_cache()
gc.collect()

254104576
Skipping evaluation for num_heads=12_mlp_exp=2.5_embed_size=1024, already exists.
Skipping training for num_heads=12_mlp_exp=2.5_embed_size=1024, already exists.


`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss,Validation Loss
30,4.5208,3.994396
60,3.6075,3.341475
90,3.4384,3.227257
120,3.3909,3.18897
150,3.3342,3.171892
180,3.3223,3.163071
210,3.3253,3.16122
240,3.2966,3.160393
270,3.3099,3.1603


Evaluating perplexity for num_heads=12_mlp_exp=2.5_embed_size=1024...


Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors
100%|█████████▉| 280/281 [00:06<00:00, 43.98it/s]


1007

In [5]:
output_dir=f"./saved_models/{param_key}"

output_dir = f"./saved_models/{param_key}"
os.makedirs(output_dir, exist_ok=True)

torch.save(model, f"{output_dir}/model_lw_retrain.pth")
torch.save(model.state_dict(), f"{output_dir}/model_lw_retrain_state_dict.pth")

## Knowledge Distillation

In [6]:
import torch
from transformers import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
import knowledge_distillation as kd


# import knowledge_distillation  # Import the module
from evaluation import *

In [7]:
model_name = "openai-community/gpt2-medium"
teacher_model, tokenizer = load_model(model_name)
teacher_model.eval()
print('Models loaded successfully')

Models loaded successfully


In [8]:
dataset = load_dataset("deven367/babylm-10M-cbt", trust_remote_code=True)
train_dataset = dataset['train']
val_dataset = dataset['valid'].select(range(500))

In [9]:

def tokenize_function(examples):
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )

    # Create input_ids and labels for language modeling
    input_ids = tokenized["input_ids"]
    labels = input_ids.clone()

    return {
        "input_ids": input_ids,
        "attention_mask": tokenized["attention_mask"],
        "labels": labels
    }
    
print("Tokenizing dataset...")
tokenized_train_datasets = train_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,  # Smaller batch size for mapping
    remove_columns=train_dataset.column_names,
    desc="Processing examples",
    load_from_cache_file=False  # Disable caching for debugging
)
tokenized_val_datasets = val_dataset.map(
    tokenize_function,
    batched=True,
    batch_size=32,  # Smaller batch size for mapping
    remove_columns=val_dataset.column_names,
    desc="Processing examples",
    load_from_cache_file=False  # Disable caching for debugging
)
tokenized_train_datasets.set_format("torch")
tokenized_val_datasets.set_format("torch")

Tokenizing dataset...


Processing examples: 100%|██████████| 26000/26000 [00:04<00:00, 5522.27 examples/s]
Processing examples: 100%|██████████| 500/500 [00:00<00:00, 4778.66 examples/s]


In [10]:
steps, train_losses, val_losses = kd.train_kd(model, teacher_model, tokenized_train_datasets, tokenized_val_datasets, 
                                            batch_size=4, num_epochs=2, accumulation_steps=8, lr=5e-5, 
                                            temperature=1.0, device='cuda', log_interval=10, val_interval=250)




🔄 Epoch 1/2 - Training...


Training Loss: 54.6548:   4%|▍         | 251/6500 [00:29<1:48:07,  1.04s/it]

📉 Step 250: Validation Loss = 31.7630


Training Loss: 41.0530:   8%|▊         | 502/6500 [00:58<1:17:18,  1.29it/s]

📉 Step 500: Validation Loss = 26.7116


Training Loss: 35.2317:  12%|█▏        | 752/6500 [01:26<1:12:51,  1.31it/s]

📉 Step 750: Validation Loss = 24.5806


Training Loss: 32.0679:  15%|█▌        | 1002/6500 [01:55<1:08:56,  1.33it/s]

📉 Step 1000: Validation Loss = 23.2209


Training Loss: 29.8010:  19%|█▉        | 1252/6500 [02:24<1:05:46,  1.33it/s]

📉 Step 1250: Validation Loss = 22.3189


Training Loss: 28.1610:  23%|██▎       | 1502/6500 [02:53<1:02:37,  1.33it/s]

📉 Step 1500: Validation Loss = 21.6675


Training Loss: 26.9213:  27%|██▋       | 1752/6500 [03:22<59:35,  1.33it/s]  

📉 Step 1750: Validation Loss = 21.1274


Training Loss: 25.8992:  31%|███       | 2002/6500 [03:50<55:27,  1.35it/s]  

📉 Step 2000: Validation Loss = 20.6984


Training Loss: 25.0876:  35%|███▍      | 2252/6500 [04:18<52:25,  1.35it/s]  

📉 Step 2250: Validation Loss = 20.3276


Training Loss: 24.3914:  38%|███▊      | 2502/6500 [04:47<49:18,  1.35it/s]  

📉 Step 2500: Validation Loss = 19.9740


Training Loss: 23.8050:  42%|████▏     | 2752/6500 [05:15<46:38,  1.34it/s]  

📉 Step 2750: Validation Loss = 19.7060


Training Loss: 23.3643:  46%|████▌     | 3002/6500 [05:44<43:33,  1.34it/s]

📉 Step 3000: Validation Loss = 19.4544


Training Loss: 22.9172:  50%|█████     | 3252/6500 [06:12<40:05,  1.35it/s]

📉 Step 3250: Validation Loss = 19.2378


Training Loss: 22.5428:  54%|█████▍    | 3502/6500 [06:41<37:13,  1.34it/s]

📉 Step 3500: Validation Loss = 19.0627


Training Loss: 22.1537:  58%|█████▊    | 3752/6500 [07:09<34:13,  1.34it/s]

📉 Step 3750: Validation Loss = 18.8456


Training Loss: 21.8648:  62%|██████▏   | 4002/6500 [07:37<31:01,  1.34it/s]

📉 Step 4000: Validation Loss = 18.7033


Training Loss: 21.5565:  65%|██████▌   | 4252/6500 [08:06<27:54,  1.34it/s]

📉 Step 4250: Validation Loss = 18.5751


Training Loss: 21.3122:  69%|██████▉   | 4502/6500 [08:34<24:37,  1.35it/s]

📉 Step 4500: Validation Loss = 18.4204


Training Loss: 21.0483:  73%|███████▎  | 4752/6500 [09:03<21:45,  1.34it/s]

📉 Step 4750: Validation Loss = 18.2863


Training Loss: 20.8379:  77%|███████▋  | 5002/6500 [09:31<18:33,  1.35it/s]

📉 Step 5000: Validation Loss = 18.1392


Training Loss: 20.6216:  81%|████████  | 5252/6500 [10:00<15:31,  1.34it/s]

📉 Step 5250: Validation Loss = 18.0593


Training Loss: 20.4361:  85%|████████▍ | 5502/6500 [10:28<12:24,  1.34it/s]

📉 Step 5500: Validation Loss = 17.9577


Training Loss: 20.2805:  88%|████████▊ | 5752/6500 [10:57<09:24,  1.33it/s]

📉 Step 5750: Validation Loss = 17.8454


Training Loss: 20.0955:  92%|█████████▏| 6002/6500 [11:26<06:16,  1.32it/s]

📉 Step 6000: Validation Loss = 17.7508


Training Loss: 19.9378:  96%|█████████▌| 6252/6500 [11:54<03:06,  1.33it/s]

📉 Step 6250: Validation Loss = 17.6710


                                                                           

📉 Step 6500: Validation Loss = 17.5947
✅ Epoch 1: Average Training Loss = 19.7815


                                                                           

📉 Epoch 1: Final Validation Loss = 17.5947

🔄 Epoch 2/2 - Training...


Training Loss: 26.2582:   4%|▍         | 252/6500 [00:29<1:18:37,  1.32it/s]

📉 Step 6750: Validation Loss = 17.9929


Training Loss: 20.8962:   8%|▊         | 502/6500 [00:58<1:15:25,  1.33it/s]

📉 Step 7000: Validation Loss = 17.4880


Training Loss: 19.1643:  12%|█▏        | 752/6500 [01:27<1:12:13,  1.33it/s]

📉 Step 7250: Validation Loss = 17.3049


Training Loss: 18.2701:  15%|█▌        | 1002/6500 [01:56<1:08:50,  1.33it/s]

📉 Step 7500: Validation Loss = 17.2092


Training Loss: 17.6510:  19%|█▉        | 1252/6500 [02:24<1:04:46,  1.35it/s]

📉 Step 7750: Validation Loss = 17.1283


Training Loss: 17.3499:  23%|██▎       | 1502/6500 [02:52<1:01:48,  1.35it/s]

📉 Step 8000: Validation Loss = 17.0768


Training Loss: 17.0675:  27%|██▋       | 1752/6500 [03:21<59:09,  1.34it/s]  

📉 Step 8250: Validation Loss = 17.0182


Training Loss: 16.8149:  31%|███       | 2002/6500 [03:49<55:57,  1.34it/s]  

📉 Step 8500: Validation Loss = 16.9935


Training Loss: 16.6547:  35%|███▍      | 2252/6500 [04:18<53:02,  1.33it/s]  

📉 Step 8750: Validation Loss = 16.9283


Training Loss: 16.4491:  38%|███▊      | 2502/6500 [04:46<49:47,  1.34it/s]  

📉 Step 9000: Validation Loss = 16.8845


Training Loss: 16.3223:  42%|████▏     | 2752/6500 [05:15<46:29,  1.34it/s]  

📉 Step 9250: Validation Loss = 16.8245


Training Loss: 16.2026:  46%|████▌     | 3002/6500 [05:43<43:28,  1.34it/s]

📉 Step 9500: Validation Loss = 16.7793


Training Loss: 16.1298:  50%|█████     | 3252/6500 [06:12<40:36,  1.33it/s]

📉 Step 9750: Validation Loss = 16.7223


Training Loss: 16.0448:  54%|█████▍    | 3502/6500 [06:40<37:20,  1.34it/s]

📉 Step 10000: Validation Loss = 16.6651


Training Loss: 15.9682:  58%|█████▊    | 3752/6500 [07:09<34:23,  1.33it/s]

📉 Step 10250: Validation Loss = 16.6461


Training Loss: 15.9005:  62%|██████▏   | 4002/6500 [07:37<30:54,  1.35it/s]

📉 Step 10500: Validation Loss = 16.5967


Training Loss: 15.8334:  65%|██████▌   | 4252/6500 [08:06<27:55,  1.34it/s]

📉 Step 10750: Validation Loss = 16.5677


Training Loss: 15.8102:  69%|██████▉   | 4502/6500 [08:34<24:48,  1.34it/s]

📉 Step 11000: Validation Loss = 16.5363


Training Loss: 15.7518:  73%|███████▎  | 4752/6500 [09:03<21:40,  1.34it/s]

📉 Step 11250: Validation Loss = 16.4968


Training Loss: 15.6997:  77%|███████▋  | 5002/6500 [09:31<18:28,  1.35it/s]

📉 Step 11500: Validation Loss = 16.4594


Training Loss: 15.6517:  81%|████████  | 5252/6500 [09:59<15:22,  1.35it/s]

📉 Step 11750: Validation Loss = 16.4104


Training Loss: 15.6196:  85%|████████▍ | 5502/6500 [10:28<12:16,  1.36it/s]

📉 Step 12000: Validation Loss = 16.3992


Training Loss: 15.5777:  88%|████████▊ | 5752/6500 [10:56<09:15,  1.35it/s]

📉 Step 12250: Validation Loss = 16.3332


Training Loss: 15.5297:  92%|█████████▏| 6002/6500 [11:25<06:10,  1.35it/s]

📉 Step 12500: Validation Loss = 16.3135


Training Loss: 15.5069:  96%|█████████▌| 6252/6500 [11:53<03:04,  1.34it/s]

📉 Step 12750: Validation Loss = 16.2980


                                                                           

📉 Step 13000: Validation Loss = 16.2592
✅ Epoch 2: Average Training Loss = 15.4543


                                                                           

📉 Epoch 2: Final Validation Loss = 16.2592




In [11]:
metrics_path = "saved_metrics/kd_metrics.json"
model_name = "width"

# Load existing metrics if the file exists, otherwise start with an empty dict
if os.path.exists(metrics_path):
    with open(metrics_path, "r") as f:
        kd_metrics = json.load(f)
else:
    kd_metrics = {}

kd_metrics.setdefault(model_name, {})[model_name] = {"steps": steps, "train_losses": train_losses, "val_losses": val_losses}

# Save updated metrics back to the file
os.makedirs("saved_metrics", exist_ok=True)
with open(metrics_path, "w") as f:
    json.dump(kd_metrics, f, indent=4)

In [12]:
evaluate_perplexity(model, tokenizer, stride=1024).item()

Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors
100%|█████████▉| 280/281 [00:13<00:00, 20.49it/s]


42.33681106567383

In [13]:
test_dataset = load_dataset("ptb_text_only", split="test")
test_dataset = test_dataset.rename_columns({"sentence": "text"})
evaluate_perplexity(model, tokenizer, test=test_dataset, stride=512)

 99%|█████████▉| 205/207 [00:04<00:00, 43.80it/s]


tensor(41.9827, device='cuda:0')

In [14]:
from itertools import islice
from datasets import Dataset

bookcorpus = load_dataset("bookcorpus", split="train", streaming=True, trust_remote_code=True)
bookcorpus_test = list(islice(bookcorpus, 10_000))
bookcorpus_test_dataset = Dataset.from_list(bookcorpus_test)

evaluate_perplexity(model, tokenizer, test=bookcorpus_test_dataset, stride=128)

 99%|█████████▉| 1286/1294 [00:37<00:00, 34.04it/s]


tensor(22.6066, device='cuda:0')

In [15]:
test_dataset = load_dataset("lambada", split="test", trust_remote_code=True)
evaluate_perplexity(model, tokenizer, test=test_dataset, stride=1024)

100%|█████████▉| 419/420 [00:14<00:00, 28.64it/s]


tensor(58.6653, device='cuda:0')