# 04 - Teacher Output Caching

**Thesis Section Reference:** Chapter 3.8 - Knowledge Distillation Methods

This notebook caches teacher model outputs for knowledge distillation:
1. **KD1 (Logit-based):** Cache soft logits for both SST-2 and SQuAD
2. **KD2 (Sequence-level):** Generate and cache teacher answers for SQuAD
3. **KD3 (Feature-based):** Cache hidden states (memory-safe chunked saving)

## Memory Management Notes
- Uses fp32 for MPS stability
- Gradient checkpointing enabled
- Periodic cache clearing
- Automatic fallback to smaller teacher if OOM

In [1]:
# Standard setup
import os
import sys
import json
import gc
from pathlib import Path

import torch

ROOT_DIR = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
sys.path.insert(0, str(ROOT_DIR / "src"))

from dotenv import load_dotenv
load_dotenv(ROOT_DIR / ".env")

from config import load_config
from utils_seed import set_seed

config = load_config(str(ROOT_DIR / "configs" / "experiment.yaml"))
config.ensure_dirs()

SEED = config.get_seeds()[0]
set_seed(SEED)

# Device setup
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Mode: {'FAST' if config.fast_mode else 'FULL'}")
print(f"Device: {DEVICE}")
print(f"Seed: {SEED}")

Mode: FAST
Device: mps
Seed: 42


In [2]:
# Set up paths
DATA_DIR = ROOT_DIR / "results" / "processed_data"
CACHE_DIR = ROOT_DIR / "results" / "teacher_cache"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

print(f"Data directory: {DATA_DIR}")
print(f"Cache directory: {CACHE_DIR}")

Data directory: /Users/pjere/Workshop/thesis-exp/results/processed_data
Cache directory: /Users/pjere/Workshop/thesis-exp/results/teacher_cache


In [3]:
# Check what already exists
cache_status = {
    "sst2_logits": (CACHE_DIR / "sst2_logits.pt").exists(),
    "squad_logits": (CACHE_DIR / "squad_logits.pt").exists(),
    "squad_answers": (CACHE_DIR / "squad_teacher_answers.json").exists(),
    "hidden_states": any(CACHE_DIR.glob("hidden_states_*.pt"))
}

print("Cache status:")
for name, exists in cache_status.items():
    status = "✓ exists" if exists else "✗ needs generation"
    print(f"  {name}: {status}")

Cache status:
  sst2_logits: ✗ needs generation
  squad_logits: ✗ needs generation
  squad_answers: ✗ needs generation
  hidden_states: ✗ needs generation


In [4]:
# Load teacher model with fallback
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_teacher_model(config, device):
    """Load teacher with automatic fallback to smaller model."""
    # Try primary teacher first
    primary = os.getenv("TEACHER_PRIMARY", config.teacher.primary)
    fallback = os.getenv("TEACHER_FALLBACK", config.teacher.local_fallback)
    
    for model_name in [primary, fallback]:
        try:
            print(f"Attempting to load: {model_name}")
            
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                cache_dir=str(ROOT_DIR / "hf_cache")
            )
            
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                dtype=torch.float32,  # fp32 for MPS stability
                cache_dir=str(ROOT_DIR / "hf_cache"),
                low_cpu_mem_usage=True
            )
            
            model = model.to(device)
            model.eval()
            
            # Enable gradient checkpointing for memory
            if hasattr(model, 'gradient_checkpointing_enable'):
                model.gradient_checkpointing_enable()
            
            print(f"✓ Successfully loaded: {model_name}")
            return model, tokenizer, model_name
            
        except Exception as e:
            print(f"Failed to load {model_name}: {e}")
            if device.type == "mps":
                torch.mps.empty_cache()
            gc.collect()
            continue
    
    raise RuntimeError("Failed to load any teacher model")

# Only load if we need to generate something
needs_teacher = not all(cache_status.values())

if needs_teacher:
    teacher_model, teacher_tokenizer, teacher_name = load_teacher_model(config, DEVICE)
    print(f"\nTeacher model parameters: {sum(p.numel() for p in teacher_model.parameters()) / 1e9:.2f}B")
else:
    print("All caches exist, skipping teacher loading.")

Attempting to load: meta-llama/Llama-3.2-8B-Instruct
Failed to load meta-llama/Llama-3.2-8B-Instruct: meta-llama/Llama-3.2-8B-Instruct is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `hf auth login` or by passing `token=<your_token>`
Attempting to load: Qwen/Qwen2.5-3B-Instruct


Loading weights:   0%|          | 0/434 [00:00<?, ?it/s]

✓ Successfully loaded: Qwen/Qwen2.5-3B-Instruct

Teacher model parameters: 3.09B


In [5]:
# Load processed datasets
from datasets import load_from_disk

if needs_teacher:
    print("Loading processed datasets...")
    
    sst2_train = load_from_disk(str(DATA_DIR / "sst2_train"))
    squad_train = load_from_disk(str(DATA_DIR / "squad_train"))
    
    print(f"  SST-2 train: {len(sst2_train)} examples")
    print(f"  SQuAD train: {len(squad_train)} examples")
    
    # Load prompts for KD2
    with open(DATA_DIR / "squad_train_prompts.json", "r") as f:
        squad_prompts = json.load(f)

Loading processed datasets...
  SST-2 train: 2000 examples
  SQuAD train: 2000 examples


## KD1: Cache Teacher Logits

For logit-based KD, we need the teacher's output logits for each training example.
We'll store the top-k logits to save memory.

In [6]:
# Cache SST-2 teacher logits (memory-efficient approach)
from tqdm.auto import tqdm

if not cache_status["sst2_logits"]:
    print("Caching SST-2 teacher logits...")
    print("Using memory-efficient chunked approach with top-k storage")
    
    TOP_K = 50  # Only store top-50 logits per position to save memory
    CHUNK_SIZE = 50  # Save every 50 examples
    
    all_logits = []
    all_indices = []
    chunk_idx = 0
    
    teacher_model.eval()
    
    with torch.no_grad():
        for i in tqdm(range(len(sst2_train)), desc="SST-2 logits"):
            example = sst2_train[i]
            
            # Handle both tensor and list formats
            if isinstance(example["input_ids"], torch.Tensor):
                input_ids = example["input_ids"].unsqueeze(0).to(DEVICE)
                attention_mask = example["attention_mask"].unsqueeze(0).to(DEVICE)
            else:
                input_ids = torch.tensor([example["input_ids"]], device=DEVICE)
                attention_mask = torch.tensor([example["attention_mask"]], device=DEVICE)
            
            outputs = teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=False
            )
            
            # Get top-k logits only (huge memory savings)
            logits = outputs.logits[0]  # [seq_len, vocab_size]
            top_values, top_indices = torch.topk(logits, k=TOP_K, dim=-1)
            
            all_logits.append(top_values.cpu().half())
            all_indices.append(top_indices.cpu().int())
            
            # Clear intermediate tensors
            del outputs, logits, top_values, top_indices, input_ids, attention_mask
            
            # Save chunk periodically
            if len(all_logits) >= CHUNK_SIZE:
                chunk_data = {
                    "logits": [l for l in all_logits],
                    "indices": [idx for idx in all_indices]
                }
                torch.save(chunk_data, CACHE_DIR / f"sst2_logits_chunk_{chunk_idx}.pt")
                print(f"  Saved chunk {chunk_idx}")
                all_logits = []
                all_indices = []
                chunk_idx += 1
                
                if DEVICE.type == "mps":
                    torch.mps.empty_cache()
                gc.collect()
    
    # Save final chunk
    if all_logits:
        chunk_data = {
            "logits": all_logits,
            "indices": all_indices
        }
        torch.save(chunk_data, CACHE_DIR / f"sst2_logits_chunk_{chunk_idx}.pt")
        print(f"  Saved final chunk {chunk_idx}")
        chunk_idx += 1
    
    # Save metadata
    meta = {"num_chunks": chunk_idx, "top_k": TOP_K, "task": "sst2"}
    torch.save(meta, CACHE_DIR / "sst2_logits.pt")  # Marker file with metadata
    
    print(f"✓ Saved SST-2 logits in {chunk_idx} chunks")
    
    if DEVICE.type == "mps":
        torch.mps.empty_cache()
    gc.collect()
else:
    print("✓ SST-2 logits already cached")

Caching SST-2 teacher logits...
Using memory-efficient chunked approach with top-k storage


SST-2 logits:   0%|          | 0/2000 [00:00<?, ?it/s]

  Saved chunk 0
  Saved chunk 1
  Saved chunk 2
  Saved chunk 3
  Saved chunk 4
  Saved chunk 5
  Saved chunk 6
  Saved chunk 7
  Saved chunk 8
  Saved chunk 9
  Saved chunk 10
  Saved chunk 11
  Saved chunk 12
  Saved chunk 13
  Saved chunk 14
  Saved chunk 15
  Saved chunk 16
  Saved chunk 17
  Saved chunk 18
  Saved chunk 19
  Saved chunk 20
  Saved chunk 21
  Saved chunk 22
  Saved chunk 23
  Saved chunk 24
  Saved chunk 25
  Saved chunk 26
  Saved chunk 27
  Saved chunk 28
  Saved chunk 29
  Saved chunk 30
  Saved chunk 31
  Saved chunk 32
  Saved chunk 33
  Saved chunk 34
  Saved chunk 35
  Saved chunk 36
  Saved chunk 37
  Saved chunk 38
  Saved chunk 39
✓ Saved SST-2 logits in 40 chunks


In [7]:
# Cache SQuAD teacher logits (memory-efficient approach)
if not cache_status["squad_logits"]:
    print("Caching SQuAD teacher logits...")
    print("Using memory-efficient chunked approach with top-k storage")
    
    TOP_K = 50
    CHUNK_SIZE = 25  # Smaller chunks for longer sequences
    
    all_logits = []
    all_indices = []
    chunk_idx = 0
    
    with torch.no_grad():
        for i in tqdm(range(len(squad_train)), desc="SQuAD logits"):
            example = squad_train[i]
            
            if isinstance(example["input_ids"], torch.Tensor):
                input_ids = example["input_ids"].unsqueeze(0).to(DEVICE)
                attention_mask = example["attention_mask"].unsqueeze(0).to(DEVICE)
            else:
                input_ids = torch.tensor([example["input_ids"]], device=DEVICE)
                attention_mask = torch.tensor([example["attention_mask"]], device=DEVICE)
            
            outputs = teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=False
            )
            
            logits = outputs.logits[0]
            top_values, top_indices = torch.topk(logits, k=TOP_K, dim=-1)
            
            all_logits.append(top_values.cpu().half())
            all_indices.append(top_indices.cpu().int())
            
            del outputs, logits, top_values, top_indices, input_ids, attention_mask
            
            if len(all_logits) >= CHUNK_SIZE:
                chunk_data = {
                    "logits": all_logits,
                    "indices": all_indices
                }
                torch.save(chunk_data, CACHE_DIR / f"squad_logits_chunk_{chunk_idx}.pt")
                print(f"  Saved chunk {chunk_idx}")
                all_logits = []
                all_indices = []
                chunk_idx += 1
                
                if DEVICE.type == "mps":
                    torch.mps.empty_cache()
                gc.collect()
    
    if all_logits:
        chunk_data = {
            "logits": all_logits,
            "indices": all_indices
        }
        torch.save(chunk_data, CACHE_DIR / f"squad_logits_chunk_{chunk_idx}.pt")
        print(f"  Saved final chunk {chunk_idx}")
        chunk_idx += 1
    
    meta = {"num_chunks": chunk_idx, "top_k": TOP_K, "task": "squad"}
    torch.save(meta, CACHE_DIR / "squad_logits.pt")
    
    print(f"✓ Saved SQuAD logits in {chunk_idx} chunks")
    
    if DEVICE.type == "mps":
        torch.mps.empty_cache()
    gc.collect()
else:
    print("✓ SQuAD logits already cached")

Caching SQuAD teacher logits...
Using memory-efficient chunked approach with top-k storage


SQuAD logits:   0%|          | 0/2000 [00:00<?, ?it/s]

  Saved chunk 0
  Saved chunk 1
  Saved chunk 2
  Saved chunk 3
  Saved chunk 4
  Saved chunk 5
  Saved chunk 6
  Saved chunk 7
  Saved chunk 8
  Saved chunk 9
  Saved chunk 10
  Saved chunk 11
  Saved chunk 12
  Saved chunk 13
  Saved chunk 14
  Saved chunk 15
  Saved chunk 16
  Saved chunk 17
  Saved chunk 18
  Saved chunk 19
  Saved chunk 20
  Saved chunk 21
  Saved chunk 22
  Saved chunk 23
  Saved chunk 24
  Saved chunk 25
  Saved chunk 26
  Saved chunk 27
  Saved chunk 28
  Saved chunk 29
  Saved chunk 30
  Saved chunk 31
  Saved chunk 32
  Saved chunk 33
  Saved chunk 34
  Saved chunk 35
  Saved chunk 36
  Saved chunk 37
  Saved chunk 38
  Saved chunk 39
  Saved chunk 40
  Saved chunk 41
  Saved chunk 42
  Saved chunk 43
  Saved chunk 44
  Saved chunk 45
  Saved chunk 46
  Saved chunk 47
  Saved chunk 48
  Saved chunk 49
  Saved chunk 50
  Saved chunk 51
  Saved chunk 52
  Saved chunk 53
  Saved chunk 54
  Saved chunk 55
  Saved chunk 56
  Saved chunk 57
  Saved chunk 58
  Saved

## KD2: Generate Teacher Answers (Sequence-level KD)

For sequence-level KD on SQuAD, we generate teacher's predicted answers.
The student learns to mimic the teacher's generated sequences.

In [8]:
# Generate teacher answers for SQuAD
from tqdm.auto import tqdm

if not cache_status["squad_answers"]:
    print("Generating teacher answers for SQuAD...")
    
    teacher_answers = []
    
    # Generation config
    gen_config = {
        "max_new_tokens": 64,
        "do_sample": False,  # Greedy for reproducibility
        "pad_token_id": teacher_tokenizer.pad_token_id,
        "eos_token_id": teacher_tokenizer.eos_token_id,
    }
    
    batch_size = 1  # One at a time for MPS stability
    
    with torch.no_grad():
        for i in tqdm(range(0, len(squad_prompts), batch_size), desc="Generating"):
            batch = squad_prompts[i:i+batch_size]
            
            for item in batch:
                inputs = teacher_tokenizer(
                    item["prompt"],
                    return_tensors="pt",
                    max_length=config.get_max_length("squad") - 64,
                    truncation=True
                ).to(DEVICE)
                
                outputs = teacher_model.generate(
                    **inputs,
                    **gen_config
                )
                
                # Decode only new tokens
                generated = teacher_tokenizer.decode(
                    outputs[0][inputs["input_ids"].shape[1]:],
                    skip_special_tokens=True
                ).strip()
                
                teacher_answers.append({
                    "id": item["id"],
                    "prompt": item["prompt"],
                    "teacher_answer": generated,
                    "gold_answers": item["gold_answers"]
                })
            
            # Periodic cleanup
            if i % 50 == 0 and DEVICE.type == "mps":
                torch.mps.empty_cache()
    
    # Save
    with open(CACHE_DIR / "squad_teacher_answers.json", "w") as f:
        json.dump(teacher_answers, f, indent=2)
    
    print(f"\n✓ Generated {len(teacher_answers)} teacher answers")
    
    # Show samples
    print("\nSample teacher answers:")
    for sample in teacher_answers[:3]:
        print(f"  Q: {sample['prompt'].split('Question:')[1].split('Context:')[0].strip()[:50]}...")
        print(f"  Teacher: {sample['teacher_answer'][:50]}...")
        print(f"  Gold: {sample['gold_answers'][0]}")
        print()
else:
    print("✓ SQuAD teacher answers already cached")

Generating teacher answers for SQuAD...


Generating:   0%|          | 0/2000 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



✓ Generated 2000 teacher answers

Sample teacher answers:
  Q: What percentage of Egyptians polled support death ...
  Teacher: According to the context provided, 84% of Egyptian...
  Gold: 84%

  Q: Ann Arbor ranks 1st among what goods sold?

Answer...
  Teacher: Ann Arbor ranks 1st among U.S. cities in the numbe...
  Gold: books

  Q: In developing countries, who makes most of the spe...
  Teacher: In developing countries, the executive branch make...
  Gold: the executive



## KD3: Cache Hidden States (Feature-based KD)

For feature-based KD, we cache teacher's hidden states from selected layers.
This is memory-intensive, so we:
1. Store only selected layers (e.g., every 4th layer)
2. Use float16 for storage
3. Save in chunks

In [9]:
# Cache hidden states (memory-safe)
if not cache_status["hidden_states"]:
    print("Caching hidden states for KD3...")
    print("Note: This is memory-intensive. Using chunked saving.")
    
    # Configuration
    LAYER_STRIDE = 4  # Cache every 4th layer
    CHUNK_SIZE = 100  # Save every 100 examples
    USE_SST2_ONLY = config.fast_mode  # In fast mode, only cache SST-2
    
    dataset = sst2_train
    task_name = "sst2"
    
    num_layers = teacher_model.config.num_hidden_layers
    selected_layers = list(range(0, num_layers, LAYER_STRIDE))
    print(f"  Caching layers: {selected_layers} (of {num_layers} total)")
    
    all_hidden_states = []
    chunk_idx = 0
    
    with torch.no_grad():
        for i in tqdm(range(len(dataset)), desc="Extracting hidden states"):
            example = dataset[i]
            
            # Prepare input
            input_ids = torch.tensor([example["input_ids"]], device=DEVICE)
            attention_mask = torch.tensor([example["attention_mask"]], device=DEVICE)
            
            # Get hidden states
            outputs = teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )
            
            # Extract selected layers and convert to float16
            layer_states = []
            for layer_idx in selected_layers:
                # Mean pool over sequence length to reduce memory
                layer_output = outputs.hidden_states[layer_idx + 1]  # +1 for embeddings
                pooled = layer_output.mean(dim=1)  # [1, hidden_size]
                layer_states.append(pooled.cpu().half())
            
            # Stack layers: [num_selected_layers, hidden_size]
            stacked = torch.cat(layer_states, dim=0)
            all_hidden_states.append(stacked)
            
            # Save chunk
            if len(all_hidden_states) >= CHUNK_SIZE:
                chunk_tensor = torch.stack(all_hidden_states)
                torch.save(chunk_tensor, CACHE_DIR / f"hidden_states_{task_name}_{chunk_idx}.pt")
                print(f"  Saved chunk {chunk_idx} ({len(all_hidden_states)} examples)")
                all_hidden_states = []
                chunk_idx += 1
                
                if DEVICE.type == "mps":
                    torch.mps.empty_cache()
    
    # Save final chunk
    if all_hidden_states:
        chunk_tensor = torch.stack(all_hidden_states)
        torch.save(chunk_tensor, CACHE_DIR / f"hidden_states_{task_name}_{chunk_idx}.pt")
        print(f"  Saved chunk {chunk_idx} ({len(all_hidden_states)} examples)")
    
    # Save metadata
    hidden_state_meta = {
        "task": task_name,
        "num_chunks": chunk_idx + 1,
        "selected_layers": selected_layers,
        "total_layers": num_layers,
        "hidden_size": teacher_model.config.hidden_size,
        "pooling": "mean",
        "dtype": "float16"
    }
    
    with open(CACHE_DIR / f"hidden_states_{task_name}_meta.json", "w") as f:
        json.dump(hidden_state_meta, f, indent=2)
    
    print(f"\n✓ Saved hidden states in {chunk_idx + 1} chunks")
else:
    print("✓ Hidden states already cached")

Caching hidden states for KD3...
Note: This is memory-intensive. Using chunked saving.
  Caching layers: [0, 4, 8, 12, 16, 20, 24, 28, 32] (of 36 total)


Extracting hidden states:   0%|          | 0/2000 [00:00<?, ?it/s]

  Saved chunk 0 (100 examples)
  Saved chunk 1 (100 examples)
  Saved chunk 2 (100 examples)
  Saved chunk 3 (100 examples)
  Saved chunk 4 (100 examples)
  Saved chunk 5 (100 examples)
  Saved chunk 6 (100 examples)
  Saved chunk 7 (100 examples)
  Saved chunk 8 (100 examples)
  Saved chunk 9 (100 examples)
  Saved chunk 10 (100 examples)
  Saved chunk 11 (100 examples)
  Saved chunk 12 (100 examples)
  Saved chunk 13 (100 examples)
  Saved chunk 14 (100 examples)
  Saved chunk 15 (100 examples)
  Saved chunk 16 (100 examples)
  Saved chunk 17 (100 examples)
  Saved chunk 18 (100 examples)
  Saved chunk 19 (100 examples)

✓ Saved hidden states in 21 chunks


In [10]:
# Cleanup
if needs_teacher:
    print("Cleaning up...")
    del teacher_model
    if 'cache' in dir():
        del cache
    if DEVICE.type == "mps":
        torch.mps.empty_cache()
    gc.collect()
    print("✓ Memory freed")

Cleaning up...
✓ Memory freed


In [11]:
# Verify all caches
print("Verifying cached files...")
print()

cache_files = list(CACHE_DIR.iterdir())
total_size = sum(f.stat().st_size for f in cache_files if f.is_file())

print(f"Cache directory: {CACHE_DIR}")
print(f"Total size: {total_size / 1024 / 1024:.2f} MB")
print()

for f in sorted(cache_files):
    if f.is_file():
        size_mb = f.stat().st_size / 1024 / 1024
        print(f"  {f.name}: {size_mb:.2f} MB")

Verifying cached files...

Cache directory: /Users/pjere/Workshop/thesis-exp/results/teacher_cache
Total size: 514.70 MB

  .gitkeep: 0.00 MB
  hidden_states_sst2_0.pt: 3.52 MB
  hidden_states_sst2_1.pt: 3.52 MB
  hidden_states_sst2_10.pt: 3.52 MB
  hidden_states_sst2_11.pt: 3.52 MB
  hidden_states_sst2_12.pt: 3.52 MB
  hidden_states_sst2_13.pt: 3.52 MB
  hidden_states_sst2_14.pt: 3.52 MB
  hidden_states_sst2_15.pt: 3.52 MB
  hidden_states_sst2_16.pt: 3.52 MB
  hidden_states_sst2_17.pt: 3.52 MB
  hidden_states_sst2_18.pt: 3.52 MB
  hidden_states_sst2_19.pt: 3.52 MB
  hidden_states_sst2_2.pt: 3.52 MB
  hidden_states_sst2_3.pt: 3.52 MB
  hidden_states_sst2_4.pt: 3.52 MB
  hidden_states_sst2_5.pt: 3.52 MB
  hidden_states_sst2_6.pt: 3.52 MB
  hidden_states_sst2_7.pt: 3.52 MB
  hidden_states_sst2_8.pt: 3.52 MB
  hidden_states_sst2_9.pt: 3.52 MB
  hidden_states_sst2_meta.json: 0.00 MB
  squad_logits.pt: 0.00 MB
  squad_logits_chunk_0.pt: 3.68 MB
  squad_logits_chunk_1.pt: 3.68 MB
  squad_log

In [12]:
# Summary
print("=" * 60)
print("TEACHER OUTPUT CACHING COMPLETE")
print("=" * 60)
print(f"""
Teacher Model: {teacher_name if 'teacher_name' in dir() else 'N/A (used cache)'}
Mode: {'FAST' if config.fast_mode else 'FULL'}

Cached Outputs:
  KD1 (Logit-based):
    - sst2_logits.pt
    - squad_logits.pt
  
  KD2 (Sequence-level):
    - squad_teacher_answers.json
  
  KD3 (Feature-based):
    - hidden_states_*.pt (chunked)

Next Steps:
  1. Run 05_train_baseline_and_kd1.ipynb for baseline and logit KD
  2. Run 06_train_kd2_and_kd3.ipynb for sequence and feature KD
""")

TEACHER OUTPUT CACHING COMPLETE

Teacher Model: Qwen/Qwen2.5-3B-Instruct
Mode: FAST

Cached Outputs:
  KD1 (Logit-based):
    - sst2_logits.pt
    - squad_logits.pt

  KD2 (Sequence-level):
    - squad_teacher_answers.json

  KD3 (Feature-based):
    - hidden_states_*.pt (chunked)

Next Steps:
  1. Run 05_train_baseline_and_kd1.ipynb for baseline and logit KD
  2. Run 06_train_kd2_and_kd3.ipynb for sequence and feature KD

