In [1]:
from pruning import *
from lw_retrain_utils import *
from evaluation import *
import json
import copy 
from datasets import load_dataset
import os
import gc
import shutil

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face Hub id of the dense checkpoint that the later cells prune and retrain.
model_name = "openai-community/gpt2-medium"
# `load_model` comes from the project star imports above; it returns a (model, tokenizer) pair.
(model, tokenizer) = load_model(model_name)

In [3]:
# Text corpus used both for calibration here and for retraining in the next cell.
# NOTE(review): trust_remote_code=True lets the dataset repo execute its own loading
# script — keep it only for sources you trust.
dataset = load_dataset(
    "stas/openwebtext-10k",
    trust_remote_code=True,
)
# Run the project's calibration pass over a sample of `dataset` before pruning.
# This is left as the cell's final expression, so its return value is displayed.
calibration_pass(
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    sample_size=128,
    batch_size=4,
)

 78%|███████▊  | 25/32 [00:09<00:03,  1.93it/s]

: 

In [None]:
import torch  # used for CUDA cache clearing below; otherwise only reachable via the star imports

# --- Metrics cache ----------------------------------------------------------------
# Metrics are persisted as JSON so a re-run (or a fresh kernel) can skip evaluations
# and trainings whose results are already on disk.
SAVE_METRICS = True  # set False for a dry run that never writes the metric files
EVAL_STRIDE = 1024   # stride passed to evaluate_perplexity for both evaluations


def _load_metrics(path):
    """Return the dict stored as JSON at ``path``, or ``{}`` if the file does not exist."""
    if os.path.exists(path):
        with open(path, "r") as f:
            return json.load(f)
    return {}


def _save_metrics(path, metrics):
    """Write ``metrics`` to ``path`` as indented JSON (no-op when SAVE_METRICS is False).

    The data is written to a sibling temporary file first. It is then moved into place
    with ``os.replace``, so an interrupted write cannot leave a truncated cache file
    that would make ``_load_metrics`` fail on the next run.
    """
    if not SAVE_METRICS:
        return
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(metrics, f, indent=4)
    os.replace(tmp_path, path)


tokenized_dataset = tokenize_dataset(tokenizer, dataset)

os.makedirs("./saved_metrics", exist_ok=True)

training_metrics_path = "./saved_metrics/training_metrics.json"
eval_metrics_path = "./saved_metrics/eval_metrics.json"

training_metrics = _load_metrics(training_metrics_path)
eval_metrics = _load_metrics(eval_metrics_path)

# --- Prune ------------------------------------------------------------------------
num_heads = 12
mult_hidden = 2.5
embed_size = 1024
# NOTE(review): prune_model_width is called on `model` and its return value is ignored.
# It presumably prunes in place, so re-running this cell without re-running the load and
# calibration cells would prune an already-pruned model. Confirm against pruning.py.
prune_model_width(model, int(mult_hidden * embed_size), num_heads, embed_size)
print(sum(t.numel() for t in model.parameters()))  # parameter count after pruning
param_key = f"num_heads={num_heads}_mlp_exp={mult_hidden}_embed_size={embed_size}"
gc.collect()

# --- Perplexity of the pruned model before retraining -----------------------------
before_training_key = param_key + "_before_training"
if before_training_key in eval_metrics:
    print(f"Skipping evaluation for {param_key}, already exists.")
else:
    print(f"Evaluating perplexity for {param_key} before training...")
    eval_metrics[before_training_key] = evaluate_perplexity(model, tokenizer, stride=EVAL_STRIDE).item()
    _save_metrics(eval_metrics_path, eval_metrics)

# --- Retrain ----------------------------------------------------------------------
# The training cache stores only the log history, and this cell never reloads weights
# from output_dir. So when training is skipped, `model` is still the pruned, untrained
# network. This flag records whether it was actually retrained in this session.
trained_this_session = False
if param_key in training_metrics:
    print(f"Skipping training for {param_key}, already exists.")
else:
    print(f"Training model for {param_key}...")
    trainer = trainer_gpt2(model, tokenizer, tokenized_dataset, batch_size=4, num_epochs=2, lr=2e-4, output_dir=f"./saved_models/{param_key}")
    trainer.train()
    training_metrics[param_key] = trainer.state.log_history
    trained_this_session = True
    _save_metrics(training_metrics_path, training_metrics)
    torch.cuda.empty_cache()
    gc.collect()

# --- Perplexity after retraining --------------------------------------------------
if param_key in eval_metrics:
    print(f"Skipping evaluation for {param_key}, already exists.")
elif not trained_this_session:
    # Evaluating now would record the pruned-but-untrained perplexity under the
    # post-training key and poison the cache for every later run.
    print(f"Skipping evaluation for {param_key}: training was skipped, so the model was not retrained in this session.")
else:
    print(f"Evaluating perplexity for {param_key}...")
    eval_metrics[param_key] = evaluate_perplexity(model, tokenizer, stride=EVAL_STRIDE).item()
    _save_metrics(eval_metrics_path, eval_metrics)

torch.cuda.empty_cache()
gc.collect()

Map:   0%|          | 0/9000 [00:00<?, ? examples/s]

Map: 100%|██████████| 9000/9000 [00:04<00:00, 2048.49 examples/s]


262469632
Evaluating perplexity for depth=20_num_heads=12_mlp_exp=3.5_embed_size=1024 before training...


100%|█████████▉| 280/281 [00:04<00:00, 57.96it/s]


Perplexity on Wikitext-2: 85.16
Training model for depth=20_num_heads=12_mlp_exp=3.5_embed_size=1024...


Step,Training Loss,Validation Loss
30,3.4844,3.121555
60,3.2623,3.037397
90,3.2137,3.007232
120,3.1955,2.989893
150,3.1473,2.985213
180,3.1365,2.980379
210,3.1447,2.979733
240,3.1177,2.979151
270,3.125,2.979301


Evaluating perplexity for depth=20_num_heads=12_mlp_exp=3.5_embed_size=1024...


100%|█████████▉| 280/281 [00:06<00:00, 41.41it/s]


Perplexity on Wikitext-2: 36.63


33