# The Experments about LoRA and Purning based on Knowledge Distilled GPT2

In [None]:
# Imports and setup
import torch
import time, math, logging
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig,AutoModelForCausalLM
from torch.optim import AdamW
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from peft.tuners.lora import LoraLayer
import torch.nn.utils.prune as prune
import pandas as pd
import gc
import copy
import wandb
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm


In [None]:
# Configure logging
for h in logging.root.handlers[:]:
    logging.root.removeHandler(h)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define experiment parameters
model_name = "gpt2"
max_length = 128
num_train_epochs = 2
eval_every_steps = 200
batch_size = 16
inference_batch_size = 8
num_inference_batches = 50
pruning_amount = 0.8  # 80% pruning rate
lora_r = 8  # LoRA rank
lora_alpha = 32
lora_target_modules = ["c_attn"]  # Target attention modules for LoRA
baseline_lr = 5e-5
lora_lr = 1e-4
run_inference_benchmark = True

# Initialize wandb for experiment tracking
try:
    run = wandb.init(
        project="lora-pruning-comparison-dstill-3",
        name=f"Baseline_vs_LoRA_vs_Pruned_{num_train_epochs}Ep_{pruning_amount:.0%}",
        config={
            "model_name": model_name, "max_length": max_length,
            "num_train_epochs": num_train_epochs, "batch_size": batch_size,
            "pruning_amount": pruning_amount, "lora_r": lora_r,
            "lora_alpha": lora_alpha, "lora_target_modules": lora_target_modules,
            "baseline_lr": baseline_lr, "lora_lr": lora_lr,
            "eval_every_steps": eval_every_steps, "device": str(device),
            "inference_batch_size": inference_batch_size,
            "num_inference_batches": num_inference_batches,
            "run_inference_benchmark": run_inference_benchmark
        }
    )
    logger.info("Weights & Biases initialized successfully.")
except Exception as e:
    logger.error(f"Failed to initialize Weights & Biases: {e}")
    run = None

# Load and prepare dataset
logger.info("Loading data...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
train_texts_full = [t for t in dataset["train"]["text"] if t.strip()]
val_texts_full = [t for t in dataset["validation"]["text"] if t.strip()]
test_texts_full = val_texts_full[:inference_batch_size * num_inference_batches]
logger.info(f"Data loaded. Train: {len(train_texts_full)}, Val/Test: {len(val_texts_full)}")

# Function to calculate perplexity for model evaluation
@torch.no_grad()
def compute_perplexity(model, tokenizer, texts, device, batch_size=8, max_length=128):
    model.eval()
    losses = []
    total_evaluated = 0
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        if not batch: continue
        total_evaluated += len(batch)
        inputs = tokenizer(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
        outputs = model(**inputs, labels=inputs.input_ids)
        if hasattr(outputs, 'loss') and outputs.loss is not None:
            losses.append(outputs.loss.item() * len(batch))
    if not losses or total_evaluated == 0: return float('inf')
    avg_loss = sum(losses) / total_evaluated
    if avg_loss <= 0: return float('inf')
    return math.exp(avg_loss)

# Function for model training with tracking
def train_fixed_duration(model, tokenizer, train_texts, val_texts, num_epochs,
                         device, lr=5e-5, batch_size=8, max_length=128,
                         eval_every=500, run_label="Training"):
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    start_time = time.time()
    peak_mem_overall = 0

    if device == torch.device("cuda"):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats(device)

    total_expected_steps = (len(train_texts) // batch_size) * num_epochs
    step = 0
    all_val_ppl = []
    steps_log = []

    logger.info(f"--- Starting {run_label} ---")
    logger.info(f"Epochs: {num_epochs}, LR: {lr}, Batch Size: {batch_size}, Eval Every: {eval_every}")

    for epoch in range(num_epochs):
        model.train()
        logger.info(f"[{run_label} Epoch {epoch+1}/{num_epochs}] Starting...")
        progress_bar = tqdm(range(0, len(train_texts), batch_size), desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        for i in progress_bar:
            batch = train_texts[i:i+batch_size]
            if not batch: continue

            inputs = tokenizer(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
            outputs = model(**inputs, labels=inputs.input_ids)
            loss = outputs.loss

            if loss is None or torch.isnan(loss):
                logger.warning(f"[{run_label} Step {step}] Loss is None or NaN. Skipping step.")
                step += 1
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            step += 1

            progress_bar.set_postfix({'loss': f"{loss.item():.3f}"})

            is_last_batch_overall = (epoch == num_epochs - 1 and i + batch_size >= len(train_texts))
            if step % eval_every == 0 or is_last_batch_overall:
                current_loss_val = loss.item() if loss is not None else float('nan')
                val_ppl = compute_perplexity(model, tokenizer, val_texts, device, batch_size=batch_size, max_length=max_length)
                if not math.isinf(val_ppl) and not math.isnan(val_ppl):
                    all_val_ppl.append(val_ppl)
                    steps_log.append(step)

                elapsed_h = (time.time() - start_time) / 3600
                current_peak_mem = 0
                if device == torch.device("cuda"):
                    current_peak_mem = torch.cuda.max_memory_allocated(device) / 1024**2
                    peak_mem_overall = max(peak_mem_overall, current_peak_mem)

                logger.info(f"[{run_label} Step {step}/{total_expected_steps}] val_ppl={val_ppl:.2f} loss={current_loss_val:.3f} time={elapsed_h:.2f}h peak_mem={current_peak_mem:.1f}MB")

                if run:
                     wandb.log({
                         f"{run_label}/Step": step, f"{run_label}/Val Perplexity": val_ppl,
                         f"{run_label}/Loss": current_loss_val, f"{run_label}/Elapsed Hours": elapsed_h,
                         f"{run_label}/Current Peak Mem MB": current_peak_mem
                     }, step=step)
                model.train()

        logger.info(f"[{run_label} Epoch {epoch+1}/{num_epochs}] Completed.")

    total_h = (time.time() - start_time) / 3600
    final_mem = peak_mem_overall
    final_ppl = all_val_ppl[-1] if all_val_ppl else float('inf')

    logger.info(f"--- Finished {run_label} ---")
    return {
        "steps": step, "total_hours": total_h, "peak_mem_mb": final_mem,
        "final_ppl": final_ppl, "steps_log": steps_log, "ppl_log": all_val_ppl
    }

# Function to apply pruning to model parameters
def apply_pruning(model_to_prune, amount):
    parameters_to_prune = []
    target_modules = set()
    logger.debug("Starting parameter search for pruning...")
    for name, module in model_to_prune.named_modules():
         if isinstance(module, LoraLayer):
             logger.debug(f"Found LoraLayer: {name}")
             for key in ['lora_A', 'lora_B']:
                if hasattr(module, key):
                    sub_module_dict = getattr(module, key)
                    if isinstance(sub_module_dict, torch.nn.ModuleDict):
                         for adapter_name, sub_module in sub_module_dict.items():
                             if hasattr(sub_module, 'weight'):
                                logger.debug(f"   Found weight in {key}.{adapter_name}. Adding.")
                                parameters_to_prune.append((sub_module, 'weight'))
                                target_modules.add(sub_module)
                    elif isinstance(sub_module_dict, torch.nn.Module):
                        if hasattr(sub_module_dict, 'weight'):
                            logger.debug(f"   Found weight directly in {key}. Adding.")
                            parameters_to_prune.append((sub_module_dict, 'weight'))
                            target_modules.add(sub_module_dict)

    if not parameters_to_prune:
        logger.warning("Could not find any LoRA parameters to prune!")
        return 0.0
    else:
        logger.info(f"Identified {len(parameters_to_prune)} parameter tensors for pruning.")
        prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=amount)
        logger.info("Pruning mask applied. Making pruning permanent...")
        final_total_params = 0; final_zero_params = 0; pruned_module_count = 0
        for module in target_modules:
            is_currently_pruned = prune.is_pruned(module)
            if is_currently_pruned:
                 try:
                     mask = getattr(module, 'weight_mask', None); orig_param = getattr(module, 'weight_orig', None)
                     if mask is not None and orig_param is not None:
                          param_element_count = orig_param.nelement(); param_zero_count = torch.sum(mask == 0).item()
                          final_total_params += param_element_count; final_zero_params += param_zero_count
                     elif hasattr(module, 'weight'):
                          param = getattr(module, 'weight'); final_total_params += param.nelement(); final_zero_params += torch.sum(param == 0).item()
                 except AttributeError:
                       if hasattr(module, 'weight'):
                           param = getattr(module, 'weight'); final_total_params += param.nelement(); final_zero_params += torch.sum(param == 0).item()
                 prune.remove(module, 'weight'); pruned_module_count += 1
            elif hasattr(module, 'weight'):
                 param = getattr(module, 'weight'); final_total_params += param.nelement()
        logger.info(f"Removed pruning buffer from {pruned_module_count} modules.")
        if final_total_params > 0:
            final_sparsity = 100. * float(final_zero_params) / float(final_total_params)
            logger.info(f"Pruning made permanent. Calculated sparsity: {final_sparsity:.2f}%")
            return final_sparsity
        else:
            logger.info("Pruning made permanent (no parameters found or counted)."); return 0.0

# Function for inference benchmarking
@torch.no_grad()
def benchmark_inference(model, tokenizer, texts, device, batch_size=8, max_length=128, num_batches=50, generation=False):
    model.eval()
    latencies = []
    total_samples = 0
    logger.info(f"--- Starting Inference Benchmark (Generation: {generation}) ---")
    generation_config = GenerationConfig(max_new_tokens=5, pad_token_id=tokenizer.pad_token_id) if generation else None

    for i in tqdm(range(0, min(len(texts), batch_size * num_batches), batch_size), desc="Inference", leave=False):
        batch = texts[i:i+batch_size]
        if not batch: continue
        inputs = tokenizer(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)
        batch_samples = inputs['input_ids'].shape[0]
        total_samples += batch_samples

        start_time = time.perf_counter()
        if generation:
            _ = model.generate(**inputs, generation_config=generation_config)
        else:
            _ = model(**inputs)
        if device == torch.device("cuda"): torch.cuda.synchronize()
        end_time = time.perf_counter()
        latencies.append(end_time - start_time)

    avg_latency_batch = np.mean(latencies) if latencies else 0
    throughput_samples_sec = total_samples / sum(latencies) if latencies else 0
    avg_latency_sample = (sum(latencies) / total_samples) * 1000 if total_samples > 0 else 0

    logger.info(f"--- Finished Inference Benchmark ---")
    return {
        "avg_inference_latency_ms_per_sample": avg_latency_sample,
        "avg_inference_throughput_samples_sec": throughput_samples_sec
    }

lora_cfg = LoraConfig(
    r=lora_r, lora_alpha=lora_alpha, target_modules=lora_target_modules,
    lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
)

all_results = {}
trained_models = {}


##  Train LoRA model(distilled model) with pre-pruning (pruning applied BEFORE training)

In [None]:
import torch
import time, math, logging
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig,AutoModelForCausalLM
from torch.optim import AdamW
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from peft.tuners.lora import LoraLayer
import torch.nn.utils.prune as prune
import pandas as pd
import gc
import copy
import wandb
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import os

logger.info("\n===== Starting LoRA Pre-Pruned on Distilled Model =====")
try:
    model_distilled_base = AutoModelForCausalLM.from_pretrained("./distilled_model").to(device)
    for param in model_distilled_base.parameters(): param.requires_grad = False

    lora_model_distilled = get_peft_model(model_distilled_base, lora_cfg)
    lora_model_distilled.print_trainable_parameters()

    logger.info(f"Applying {pruning_amount:.1%} pruning to distilled LoRA model BEFORE training...")
    apply_pruning(lora_model_distilled, pruning_amount)
    logger.info("Pruned distilled LoRA model.")

    distilled_lora_res = train_fixed_duration(
        lora_model_distilled, tokenizer, train_texts_full, val_texts_full,
        num_epochs=num_train_epochs, device=device, lr=lora_lr,
        batch_size=batch_size, max_length=max_length,
        eval_every=eval_every_steps, run_label="Distilled LoRA Pre-Pruned"
    )
    logger.info(f"Distilled LoRA Pre-Pruned results: {distilled_lora_res}")

    if 'final_ppl' in distilled_lora_res:
        all_results["Distilled LoRA Pre-Pruned"] = distilled_lora_res
        trained_models["Distilled LoRA Pre-Pruned"] = copy.deepcopy(lora_model_distilled)
        if run:
            wandb.log({f"Summary/Distilled LoRA Pre-Pruned/{k}": v for k, v in distilled_lora_res.items() if 'log' not in k})
except Exception as e:
    logger.error(f"Failed to train LoRA Pre-Pruned on Distilled model: {e}")


#===============================================   
distilled_lora_model_path = "./distilled_lora_prepruned"

os.makedirs(distilled_lora_model_path, exist_ok=True)
lora_model_distilled.save_pretrained(distilled_lora_model_path)
logger.info(f"Distilled LoRA Pre-Pruned model saved to {distilled_lora_model_path}")    
    
del model_distilled_base, lora_model_distilled; gc.collect(); torch.cuda.empty_cache() if device == torch.device("cuda") else None


if run_inference_benchmark:
    logger.info("\n===== Starting Inference Benchmarks =====")
    for label, model in trained_models.items():
        if model is not None:
            logger.info(f"--- Benchmarking Inference for: {label} ---")
            inference_res_fwd = benchmark_inference(
                model, tokenizer, test_texts_full, device,
                batch_size=inference_batch_size, max_length=max_length,
                num_batches=num_inference_batches, generation=False
            )
            inference_res_gen = benchmark_inference(
                model, tokenizer, test_texts_full, device,
                batch_size=inference_batch_size, max_length=max_length,
                num_batches=num_inference_batches // 2,
                generation=True
            )
            all_results[label].update({
                 "fwd_pass_latency_ms": inference_res_fwd["avg_inference_latency_ms_per_sample"],
                 "fwd_pass_throughput": inference_res_fwd["avg_inference_throughput_samples_sec"],
                 "gen_latency_ms": inference_res_gen["avg_inference_latency_ms_per_sample"],
                 "gen_throughput": inference_res_gen["avg_inference_throughput_samples_sec"]
            })
            logger.info(f"Inference Results for {label}: Fwd Latency={inference_res_fwd['avg_inference_latency_ms_per_sample']:.2f}ms, Gen Latency={inference_res_gen['avg_inference_latency_ms_per_sample']:.2f}ms")
            if run:
                wandb.log({
                    f"Summary/{label}/Fwd Pass Latency ms": inference_res_fwd["avg_inference_latency_ms_per_sample"],
                    f"Summary/{label}/Fwd Pass Throughput": inference_res_fwd["avg_inference_throughput_samples_sec"],
                    f"Summary/{label}/Gen Latency ms": inference_res_gen["avg_inference_latency_ms_per_sample"],
                    f"Summary/{label}/Gen Throughput": inference_res_gen["avg_inference_throughput_samples_sec"]
                })
        else:
            logger.warning(f"Skipping inference benchmark for {label} as training failed.")
        if model is not None: del model
        gc.collect(); torch.cuda.empty_cache() if device == torch.device("cuda") else None
    trained_models = {}

logger.info("\n===== Generating Learning Curve Plot =====")
plt.figure(figsize=(10, 6))
plot_data_logged = False
for label, results in all_results.items():
    if "steps_log" in results and "ppl_log" in results and results["steps_log"] and results["ppl_log"]:
        steps_axis = np.array(results["steps_log"])
        ppl_values = np.array(results["ppl_log"])
        mask = np.isfinite(ppl_values)
        if np.any(mask):
             plt.plot(steps_axis[mask], ppl_values[mask], marker='.', linestyle='-', label=label)
             plot_data_logged = True
        else:
             logger.warning(f"No valid PPL data points to plot for {label}.")

if plot_data_logged:
    plt.xlabel("Training Steps")
    plt.ylabel("Validation Perplexity")
    plt.title("Perplexity vs. Training Steps")
    plt.legend()
    plt.grid(True)
    min_ppl_overall = min(min(r["ppl_log"]) for r in all_results.values() if r and r.get("ppl_log")) if any(r and r.get("ppl_log") for r in all_results.values()) else 7
    plt.ylim(bottom=max(5, min(min_ppl_overall - 0.5, 7))) 

    if run:
        try:
            wandb.log({"Learning Curve": wandb.Image(plt)})
            logger.info("Learning curve plot logged to Weights & Biases.")
        except Exception as e:
            logger.error(f"Failed to log plot to wandb: {e}")
    plt.close()
else:
     logger.warning("No valid data found to generate learning curve plot.")

# Generate final comparison tables and statistics

In [None]:
logger.info("\n===== Final Benchmark Comparison =====")
results_list = []
indices = []

if "Baseline" in all_results:
    results_list.append(all_results["Baseline"])
    indices.append("Baseline")
if "LoRA Standard" in all_results:
    results_list.append(all_results["LoRA Standard"])
    indices.append("LoRA Standard")
if "LoRA Pre-Pruned" in all_results:
    results_list.append(all_results["LoRA Pre-Pruned"])
    indices.append(f"LoRA Pre-Pruned ({pruning_amount:.0%})")
    
if "Distilled LoRA Pre-Pruned" in all_results:
    results_list.append(all_results["Distilled LoRA Pre-Pruned"])
    indices.append("Distilled LoRA Pre-Pruned")

if results_list:
    df = pd.DataFrame(results_list, index=indices)
    df_display = df.drop(columns=['ppl_log', 'steps_log'], errors='ignore')

    cols_to_rename = {
        "steps": "Total Steps", "total_hours": "Total Hours (h)",
        "peak_mem_mb": "Peak GPU Mem (MB)", "final_ppl": "Final Val PPL"
    }
    if run_inference_benchmark:
        cols_to_rename.update({
            "fwd_pass_latency_ms": "Fwd Latency (ms)", "fwd_pass_throughput": "Fwd TP (samples/s)",
            "gen_latency_ms": "Gen Latency (ms)", "gen_throughput": "Gen TP (samples/s)"
        })

    df_display.rename(columns=cols_to_rename, inplace=True)

    format_map = {
        "Total Steps": '{:,.0f}', "Total Hours (h)": '{:.2f}',
        "Peak GPU Mem (MB)": '{:,.1f}', "Final Val PPL": '{:.2f}',
        "Fwd Latency (ms)": '{:.1f}', "Fwd TP (samples/s)": '{:.1f}',
        "Gen Latency (ms)": '{:.1f}', "Gen TP (samples/s)": '{:.1f}'
    }
    for col, fmt in format_map.items():
        if col in df_display.columns:
            try:
                df_display[col] = df_display[col].map(lambda x: fmt.format(x) if pd.notnull(x) else 'N/A')
            except (TypeError, ValueError):
                 logger.warning(f"Could not format column {col}. Skipping formatting.")


    logger.info("\nComparison DataFrame:\n%s", df_display.to_string())

    if run:
        try:
            df_log = df_display.reset_index().rename(columns={'index': 'Method'})
            wandb.log({"Comparison Table": wandb.Table(dataframe=df_log)})
            logger.info("Comparison table logged to Weights & Biases.")
        except Exception as e:
            logger.error(f"Failed to log DataFrame to Weights & Biases: {e}")

else:
    logger.error("No successful benchmark runs to compare.")

if run:
    wandb.finish()
    logger.info("Weights & Biases run finished.")

logger.info("\n===== Script Finished =====")


In [1]:
import torch
import time
import numpy as np
import logging
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from peft import LoraConfig, get_peft_model
from datasets import load_dataset

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
logger.info("Loading wikitext-2 dataset...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

2025-05-05 23:40:28,764 - INFO - Loading wikitext-2 dataset...


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
base_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(base_name)
tokenizer.pad_token = tokenizer.eos_token

distilled_model = GPT2LMHeadModel.from_pretrained( "./saved_models/distilled_lora_prepruned").to(device)

2025-05-05 23:42:54,195 - INFO - Using device: cuda


In [7]:
def benchmark_inference(model, tokenizer, dataset, batch_size, max_length=128, num_batches=100):
    import time, torch, numpy as np
    device = next(model.parameters()).device
    test_texts = [t for t in dataset["test"]["text"] if t.strip()]

    model.eval()
    infer_times, infer_mems = [], []
    infer_thrpts, infer_perps = [], []

    with torch.no_grad():
        for i in range(num_batches):
            batch = test_texts[i*batch_size:(i+1)*batch_size]
            if not batch:
                break

            inputs = tokenizer(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device)

            start = time.time()
            outputs = model(**inputs, labels=inputs.input_ids)
            loss = outputs.loss
            perp = torch.exp(loss).item()
            elapsed = time.time() - start
            mem = torch.cuda.memory_allocated(device) / 1024**2

            infer_times.append(elapsed)
            infer_mems.append(mem)
            infer_thrpts.append(batch_size / elapsed)
            infer_perps.append(perp)

    return {
        "time":       (np.mean(infer_times),   np.std(infer_times)),
        "memory":     (np.mean(infer_mems),    np.std(infer_mems)),
        "throughput": (np.mean(infer_thrpts),  np.std(infer_thrpts)),
        "perplexity": (np.mean(infer_perps),   np.std(infer_perps))
    }


In [12]:
max_length = 128
for bs in [8, 16, 32]:
    logger.info(f"\nRunning benchmark for Distilled LoRA model with batch_size={bs}, max_length={max_length}")
    logger.info("Configuration:")
    logger.info(f"  • Model:  Distilled LoRA on GPT-2")
    logger.info(f"  • Batch size:   {bs}")
    logger.info(f"  • Max length:   {max_length}")

    infer_stats = benchmark_inference(distilled_model, tokenizer, dataset, batch_size=bs)

    # Inference results
    i = infer_stats
    logger.info("\nInference:")
    logger.info(f"  Average time per batch:    {i['time'][0]:.4f} ± {i['time'][1]:.4f} seconds")
    logger.info(f"  Average memory usage:      {i['memory'][0]:.2f} ± {i['memory'][1]:.2f} MB")
    logger.info(f"  Average throughput:        {i['throughput'][0]:.2f} ± {i['throughput'][1]:.2f} samples/second")
    logger.info(f"  Average perplexity:        {i['perplexity'][0]:.4f} ± {i['perplexity'][1]:.4f}")

2025-05-05 23:51:41,950 - INFO - 
Running benchmark for Distilled LoRA model with batch_size=8, max_length=128
2025-05-05 23:51:41,952 - INFO - Configuration:
2025-05-05 23:51:41,952 - INFO -   • Model:  Distilled LoRA on GPT-2
2025-05-05 23:51:41,953 - INFO -   • Batch size:   8
2025-05-05 23:51:41,954 - INFO -   • Max length:   128
2025-05-05 23:51:49,327 - INFO - 
Inference:
2025-05-05 23:51:49,328 - INFO -   Average time per batch:    0.0673 ± 0.0170 seconds
2025-05-05 23:51:49,329 - INFO -   Average memory usage:      1077.80 ± 73.01 MB
2025-05-05 23:51:49,329 - INFO -   Average throughput:        149.14 ± 126.51 samples/second
2025-05-05 23:51:49,330 - INFO -   Average perplexity:        11.5527 ± 7.8329
2025-05-05 23:51:49,330 - INFO - 
Running benchmark for Distilled LoRA model with batch_size=16, max_length=128
2025-05-05 23:51:49,331 - INFO - Configuration:
2025-05-05 23:51:49,331 - INFO -   • Model:  Distilled LoRA on GPT-2
2025-05-05 23:51:49,331 - INFO -   • Batch size:   