<a href="https://colab.research.google.com/github/vinitvshah/SLM/blob/feature/Train_SLM_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Knowledge Distillation: Training a Small Fraud Detection Model
## Using Mistral-7B as Teacher to Train a 125M-1B Parameter Student Model

This notebook demonstrates **knowledge distillation** to create a small, fast, production-ready fraud detection model:

### Architecture
- **Teacher Model**: Mistral-7B (7B parameters) - generates soft labels
- **Student Model**: Custom transformer (125M-1B parameters) - learns from teacher
- **Result**: 10-50x faster inference with 85-95% of teacher's accuracy

### Why Knowledge Distillation?
| Approach | Training Cost | Inference Speed | Quality | Deployment |
|----------|--------------|-----------------|---------|------------|
| Fine-tune Mistral-7B | Low | Slow (~500ms) | Best | Hard (28GB GPU) |
| Train from Scratch | Very High | Fast | Poor | Easy |
| **Knowledge Distillation** | Medium | **Fast (~20ms)** | Good (85-95%) | **Easy** |

## 1. Setup and Configuration

In [1]:
# ============================================================================
# INSTALL REQUIRED PACKAGES
# ============================================================================
# !pip install -U transformers accelerate bitsandbytes
# !pip install -U torch torchvision torchaudio
# !pip install pandas scikit-learn tqdm tensorrt vllm

import os
import json
import math
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from typing import Optional, Tuple, List, Dict
from sklearn.model_selection import train_test_split
import random
import uuid
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# GPU CHECK
# ============================================================================
print("=" * 70)
print("GPU Configuration for Knowledge Distillation")
print("=" * 70)

if torch.cuda.is_available():
    device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"‚úÖ GPU Available: {gpu_name}")
    print(f"   Total Memory: {gpu_memory:.2f} GB")

    if gpu_memory >= 24:
        print("   ‚úÖ Sufficient for teacher + student training")
    elif gpu_memory >= 16:
        print("   ‚ö†Ô∏è  Will use aggressive quantization for teacher")
    else:
        print("   ‚ö†Ô∏è  Limited memory - consider smaller student model")
else:
    device = torch.device("cpu")
    print("‚ùå No GPU available. Training will be very slow.")

print(f"\nUsing device: {device}")

GPU Configuration for Knowledge Distillation
‚úÖ GPU Available: NVIDIA A100-SXM4-40GB
   Total Memory: 42.47 GB
   ‚úÖ Sufficient for teacher + student training

Using device: cuda


In [2]:
# ============================================================================
# CONFIGURATION
# ============================================================================

@dataclass
class DistillationConfig:
    """Configuration for knowledge distillation training"""

    # Student model size: "tiny" (25M), "small" (125M), "medium" (350M), "large" (1B)
    student_size: str = "small"

    # Model configurations by size
    MODEL_CONFIGS: Dict = field(default_factory=lambda: {
        "tiny":   {"n_layers": 4,  "n_heads": 4,  "d_model": 256,  "d_ff": 1024,  "params": "~25M"},
        "small":  {"n_layers": 6,  "n_heads": 8,  "d_model": 512,  "d_ff": 2048,  "params": "~125M"},
        "medium": {"n_layers": 12, "n_heads": 12, "d_model": 768,  "d_ff": 3072,  "params": "~350M"},
        "large":  {"n_layers": 24, "n_heads": 16, "d_model": 1024, "d_ff": 4096,  "params": "~1B"},
    })

    # Training hyperparameters
    max_length: int = 512
    batch_size: int = 4
    gradient_accumulation_steps: int = 4  # Effective batch = 16
    learning_rate: float = 1e-4
    num_epochs: int = 10
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0

    # Distillation hyperparameters
    temperature: float = 2.0      # Softens teacher probabilities
    alpha: float = 0.7            # Weight: distillation vs hard label loss

    # Teacher model
    teacher_model: str = "mistralai/Mistral-7B-v0.1"

    # Output paths
    output_dir: str = "./fraud_slm_distilled"
    checkpoint_dir: str = "./fraud_slm_checkpoints"

config = DistillationConfig()
student_cfg = config.MODEL_CONFIGS[config.student_size]

# Create output directories
os.makedirs(config.output_dir, exist_ok=True)
os.makedirs(config.checkpoint_dir, exist_ok=True)

print("=" * 70)
print("Distillation Configuration")
print("=" * 70)
print(f"\nüìä Student Model:")
print(f"   Size: {config.student_size}")
print(f"   Parameters: {student_cfg['params']}")
print(f"   Layers: {student_cfg['n_layers']}, Heads: {student_cfg['n_heads']}")
print(f"   Hidden: {student_cfg['d_model']}, FFN: {student_cfg['d_ff']}")

print(f"\nüéì Teacher Model: {config.teacher_model}")

print(f"\n‚öôÔ∏è  Training:")
print(f"   Epochs: {config.num_epochs}")
print(f"   Batch size: {config.batch_size} x {config.gradient_accumulation_steps} = {config.batch_size * config.gradient_accumulation_steps}")
print(f"   Learning rate: {config.learning_rate}")

print(f"\nüî• Distillation:")
print(f"   Temperature: {config.temperature}")
print(f"   Alpha (distill weight): {config.alpha}")

Distillation Configuration

üìä Student Model:
   Size: small
   Parameters: ~125M
   Layers: 6, Heads: 8
   Hidden: 512, FFN: 2048

üéì Teacher Model: mistralai/Mistral-7B-v0.1

‚öôÔ∏è  Training:
   Epochs: 10
   Batch size: 4 x 4 = 16
   Learning rate: 0.0001

üî• Distillation:
   Temperature: 2.0
   Alpha (distill weight): 0.7


## 2. Generate Fraud Training Data

Generate 10,000 synthetic records with 1-year historical context for each customer.

In [3]:
# ============================================================================
# GENERATE FRAUD TRAINING DATA - 10,000 RECORDS WITH 1-YEAR HISTORY
# ============================================================================

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import uuid
import os

def generate_historical_fraud_data(num_customers=2000, events_per_customer=5):
    """
    Generate synthetic fraud dataset with 1-year historical context.

    Creates denormalized records with:
    - Customer and account identifiers
    - 1-year transaction history
    - Current transaction details
    - Fraud/Legitimate labels

    Args:
        num_customers: Number of unique customers
        events_per_customer: Training samples per customer

    Returns:
        DataFrame with 10,000 training records
    """
    np.random.seed(42)  # Reproducibility

    # Merchant categories
    legit_merchants = {
        'retail': ['Walmart', 'Target', 'Costco', 'Amazon', 'Best Buy'],
        'food': ['McDonalds', 'Starbucks', 'Chipotle', 'Subway', 'Dominos'],
        'gas': ['Shell', 'Chevron', 'BP', 'ExxonMobil', 'Valero'],
        'utilities': ['Electric Co', 'Gas Utility', 'Water Works', 'Internet Provider']
    }

    fraud_merchants = [
        'CryptoExchange', 'OnlineCasino', 'WireTransfer', 'UnknownMerchant',
        'OverseasATM', 'GiftCardStore', 'SuspiciousVendor', 'HighRiskATM'
    ]

    event_types = ['purchase', 'refund', 'transfer', 'withdrawal', 'payment']

    records = []

    for i in range(num_customers):
        if i % 500 == 0:
            print(f"   Processing customer {i+1}/{num_customers}...")

        customer_id = f"CUST_{uuid.uuid4().hex[:10].upper()}"
        account_id = f"ACC_{uuid.uuid4().hex[:8].upper()}"

        # 30% fraud customers
        is_fraud_customer = np.random.random() < 0.3

        # Generate 1-year history
        base_date = datetime.now()
        history_events = []

        for month_offset in range(12, 0, -1):
            num_events = np.random.randint(5, 20)

            for _ in range(num_events):
                event_date = base_date - timedelta(days=month_offset * 30 + np.random.randint(0, 30))
                event_type = np.random.choice(event_types)

                # Fraud patterns appear in recent 2 months
                if is_fraud_customer and month_offset <= 2 and np.random.random() < 0.4:
                    merchant = np.random.choice(fraud_merchants)
                    amount = np.random.randint(1000, 15000)
                    status = 'suspicious'
                else:
                    category = np.random.choice(list(legit_merchants.keys()))
                    merchant = np.random.choice(legit_merchants[category])
                    amount = np.random.randint(10, 500)
                    status = 'approved'

                history_events.append({
                    'date': event_date.strftime('%Y-%m-%d'),
                    'type': event_type,
                    'merchant': merchant,
                    'amount': amount,
                    'status': status
                })

        history_events.sort(key=lambda x: x['date'])

        # Create training samples
        for _ in range(events_per_customer):
            event_id = f"EVT_{uuid.uuid4().hex[:8].upper()}"
            context_size = min(10, len(history_events))
            recent_history = history_events[-context_size:]

            # Build history text
            history_text = "\n".join([
                f"[{evt['date']}] {evt['type'].upper()}: {evt['merchant']}, ${evt['amount']}, {evt['status']}"
                for evt in recent_history
            ])

            # Current transaction
            current_date = base_date.strftime('%Y-%m-%d')
            if is_fraud_customer and np.random.random() < 0.5:
                current_merchant = np.random.choice(fraud_merchants)
                current_amount = np.random.randint(2000, 20000)
                label = "FRAUD"
                label_id = 1
            else:
                category = np.random.choice(list(legit_merchants.keys()))
                current_merchant = np.random.choice(legit_merchants[category])
                current_amount = np.random.randint(10, 500)
                label = "LEGITIMATE"
                label_id = 0

            # Format prompt
            prompt = f"""<|system|>
You are a fraud detection system. Analyze the transaction history and current transaction to determine if it is FRAUD or LEGITIMATE.
<|user|>
Customer: {customer_id} | Account: {account_id}

=== TRANSACTION HISTORY ===
{history_text}

=== CURRENT TRANSACTION ===
[{current_date}] {current_merchant}, ${current_amount}

Is this transaction FRAUD or LEGITIMATE?
<|assistant|>
Based on the transaction history and current transaction, this is: """

            # Full text includes label (for training)
            full_text = prompt + label

            records.append({
                'customer_id': customer_id,
                'account_id': account_id,
                'event_id': event_id,
                'prompt': prompt,
                'text': full_text,
                'label': label,
                'label_id': label_id,
                'current_amount': current_amount,
                'current_merchant': current_merchant,
                'history_text': history_text
            })

    return pd.DataFrame(records)


# Generate dataset
print("\n" + "=" * 70)
print("Generating Fraud Detection Dataset")
print("=" * 70)

df = generate_historical_fraud_data(num_customers=2000, events_per_customer=5)

print(f"\n‚úÖ Dataset generated:")
print(f"   Total records: {len(df):,}")
print(f"   Unique customers: {df['customer_id'].nunique():,}")
print(f"   Fraud rate: {(df['label'] == 'FRAUD').mean():.2%}")

# ============================================================================
# SAVE DATASET TO CSV FILE (Current working directory)
# ============================================================================

# Use current working directory
CSV_FILE = "fraud_training_data.csv"

# Save to CSV
df.to_csv(CSV_FILE, index=False)

print(f"\nüíæ Dataset saved to CSV:")
print(f"   File path: {os.path.abspath(CSV_FILE)}")
print(f"   File size: {os.path.getsize(CSV_FILE) / (1024*1024):.2f} MB")
print(f"   Columns: {list(df.columns)}")

print(f"\nüìù Sample prompt:")
print("=" * 70)
print(df.iloc[0]['prompt'][:600] + "...")

# Display data summary
print(f"\nüìä Data Summary:")
print(f"   Fraud transactions: {(df['label'] == 'FRAUD').sum():,}")
print(f"   Legitimate transactions: {(df['label'] == 'LEGITIMATE').sum():,}")
print(f"   Avg current amount (Fraud): ${df[df['label'] == 'FRAUD']['current_amount'].mean():,.2f}")
print(f"   Avg current amount (Legit): ${df[df['label'] == 'LEGITIMATE']['current_amount'].mean():,.2f}")


Generating Fraud Detection Dataset
   Processing customer 1/2000...
   Processing customer 501/2000...
   Processing customer 1001/2000...
   Processing customer 1501/2000...

‚úÖ Dataset generated:
   Total records: 10,000
   Unique customers: 2,000
   Fraud rate: 14.41%

üíæ Dataset saved to CSV:
   File path: /content/fraud_training_data.csv
   File size: 22.66 MB
   Columns: ['customer_id', 'account_id', 'event_id', 'prompt', 'text', 'label', 'label_id', 'current_amount', 'current_merchant', 'history_text']

üìù Sample prompt:
<|system|>
You are a fraud detection system. Analyze the transaction history and current transaction to determine if it is FRAUD or LEGITIMATE.
<|user|>
Customer: CUST_B3EC7DF32A | Account: ACC_21B067C9

=== TRANSACTION HISTORY ===
[2025-11-03] REFUND: Amazon, $393, approved
[2025-11-07] REFUND: Chipotle, $451, approved
[2025-11-08] TRANSFER: Amazon, $332, approved
[2025-11-24] WITHDRAWAL: Walmart, $454, approved
[2025-11-26] PAYMENT: Amazon, $25, approved

## 3. Load Teacher Model (Mistral-7B)

Load Mistral-7B with 4-bit quantization to generate soft labels.

In [4]:
# ============================================================================
# INSTALL REQUIRED PACKAGES
# ============================================================================
# 'accelerate' is REQUIRED for device_map="auto" and 4-bit quantization
!pip install -U bitsandbytes accelerate transformers

Collecting bitsandbytes
  Downloading bitsandbytes-0.49.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.49.1-py3-none-manylinux_2_24_x86_64.whl (59.1 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m59.1/59.1 MB[0m [31m42.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.49.1


In [5]:
# ============================================================================
# LOAD TEACHER MODEL (MISTRAL-7B)
# ============================================================================

import torch
import os
import bitsandbytes as bnb  # Explicit import to check availability
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Check if config is defined
try:
    config
except NameError:
    raise NameError("The 'config' variable is not defined. Please run the Configuration cell (Step 1) above before running this cell.")

print("=" * 70)
print("Loading Teacher Model: Mistral-7B")
print("=" * 70)

try:
    print(f"   ‚úÖ bitsandbytes version: {bnb.__version__}")
except Exception as e:
    print(f"   ‚ö†Ô∏è Error importing bitsandbytes directly: {e}")

# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# ============================================================================
# PERSISTENCE CONFIGURATION
# ============================================================================
# Define where to store the model files
if os.path.exists('/content/drive/MyDrive'):
    model_cache_dir = '/content/drive/MyDrive/models/mistral-7b'
    print(f"üìÇ using Google Drive for persistence: {model_cache_dir}")
else:
    model_cache_dir = './model_cache/mistral-7b'
    print(f"üìÇ using local cache for persistence: {model_cache_dir}")

os.makedirs(model_cache_dir, exist_ok=True)

# Load tokenizer
print("\nüì• Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    config.teacher_model,
    cache_dir=model_cache_dir,  # Save/Load from specific folder
    trust_remote_code=True,
    padding_side="right",
)

# Set special tokens
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print(f"   ‚úÖ Tokenizer loaded: vocab_size={tokenizer.vocab_size:,}")

# Load teacher model
print(f"\nüì• Loading {config.teacher_model}...")
print("   (This may take a few minutes for first download)")

try:
    teacher_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model,
        quantization_config=bnb_config,
        device_map="auto",
        cache_dir=model_cache_dir,  # Save/Load from specific folder
        trust_remote_code=True,
        dtype=torch.float16,
    )
except ImportError as e:
    if "bitsandbytes" in str(e) or "accelerate" in str(e):
        print("\n‚ùå IMPORT ERROR CAUGHT")
        print("It seems required libraries (bitsandbytes or accelerate) are missing or not loaded.")
        print("1. Ensure you ran the '!pip install' cell above.")
        print("2. RESTART THE RUNTIME: Click 'Runtime' > 'Restart session' in the menu.")
        print("3. Re-run the cells starting from the imports.")
    raise e

teacher_model.eval()

# Disable gradient computation for teacher
for param in teacher_model.parameters():
    param.requires_grad = False

print(f"\n‚úÖ Teacher model loaded")
print(f"   GPU Memory used: ~{torch.cuda.memory_allocated()/1e9:.2f} GB")

Loading Teacher Model: Mistral-7B
   ‚úÖ bitsandbytes version: 0.49.1
üìÇ using local cache for persistence: ./model_cache/mistral-7b

üì• Loading tokenizer...


tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

   ‚úÖ Tokenizer loaded: vocab_size=32,000

üì• Loading mistralai/Mistral-7B-v0.1...
   (This may take a few minutes for first download)


config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]


‚úÖ Teacher model loaded
   GPU Memory used: ~4.13 GB


## 4. Define Student Model Architecture

Create a custom small transformer with the same architecture style as Mistral (RoPE, SwiGLU, RMSNorm).

In [6]:
# ============================================================================
# STUDENT MODEL ARCHITECTURE
# ============================================================================

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (like LLaMA/Mistral)"""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE)"""
    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with RoPE"""
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.rope = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, L, _ = x.shape

        # Project to Q, K, V
        q = self.q_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.n_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE
        cos, sin = self.rope(L)
        cos, sin = cos.to(x.device), sin.to(x.device)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Causal mask
        causal_mask = torch.triu(torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal_mask, float('-inf'))

        # Padding mask
        if attention_mask is not None:
            padding_mask = (attention_mask == 0).unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(padding_mask, float('-inf'))

        # Softmax and apply to values
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)

        return self.o_proj(out)


class SwiGLUFeedForward(nn.Module):
    """SwiGLU Feed-Forward Network (like LLaMA/Mistral)"""
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))


class TransformerBlock(nn.Module):
    """Transformer block with pre-norm architecture"""
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = SwiGLUFeedForward(d_model, d_ff, dropout)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.attention(self.norm1(x), attention_mask)
        x = x + self.feed_forward(self.norm2(x))
        return x


class FraudDetectionSLM(nn.Module):
    """
    Small Language Model for Fraud Detection.
    Architecture mirrors Mistral/LLaMA but with fewer parameters.
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_layers: int = 6,
        n_heads: int = 8,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Token embeddings
        self.embed_tokens = nn.Embedding(vocab_size, d_model)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Output
        self.norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying
        self.lm_head.weight = self.embed_tokens.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
    ) -> torch.Tensor:
        x = self.embed_tokens(input_ids)

        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)

        if return_hidden:
            return x

        logits = self.lm_head(x)
        return logits

    def count_parameters(self) -> Tuple[int, int]:
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total, trainable


# Create student model
print("=" * 70)
print(f"Creating Student Model ({config.student_size.upper()})")
print("=" * 70)

student_model = FraudDetectionSLM(
    vocab_size=tokenizer.vocab_size,
    d_model=student_cfg['d_model'],
    n_layers=student_cfg['n_layers'],
    n_heads=student_cfg['n_heads'],
    d_ff=student_cfg['d_ff'],
    max_seq_len=config.max_length,
    dropout=0.1,
).to(device)

total_params, trainable_params = student_model.count_parameters()

print(f"\n‚úÖ Student model created")
print(f"   Total parameters: {total_params:,} ({total_params/1e6:.1f}M)")
print(f"   Trainable: {trainable_params:,}")
print(f"   Architecture: {student_cfg['n_layers']} layers, {student_cfg['n_heads']} heads")
print(f"   Dimensions: d_model={student_cfg['d_model']}, d_ff={student_cfg['d_ff']}")

# Compare sizes
teacher_params = 7e9  # Mistral-7B
compression_ratio = teacher_params / total_params
print(f"\nüìä Compression: {compression_ratio:.0f}x smaller than teacher")

Creating Student Model (SMALL)

‚úÖ Student model created
   Total parameters: 41,556,480 (41.6M)
   Trainable: 41,556,480
   Architecture: 6 layers, 8 heads
   Dimensions: d_model=512, d_ff=2048

üìä Compression: 168x smaller than teacher


## 5. Prepare Dataset and DataLoaders

In [7]:
# ============================================================================
# DATASET AND DATALOADERS
# ============================================================================

class DistillationDataset(Dataset):
    """Dataset for knowledge distillation training"""

    def __init__(self, df: pd.DataFrame, tokenizer, max_length: int = 512):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict:
        row = self.df.iloc[idx]

        # Tokenize full text (prompt + label)
        encoding = self.tokenizer(
            row['text'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Create labels (shift input_ids for causal LM)
        labels = encoding['input_ids'].squeeze().clone()

        # Mask padding tokens in labels
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': labels,
            'label_id': torch.tensor(row['label_id']),
        }


# Split data
print("=" * 70)
print("Preparing Datasets")
print("=" * 70)

train_df, val_df = train_test_split(
    df, test_size=0.1, random_state=42, stratify=df['label_id']
)

print(f"\nüìä Data split:")
print(f"   Training: {len(train_df):,} samples")
print(f"   Validation: {len(val_df):,} samples")
print(f"   Train fraud rate: {train_df['label_id'].mean():.2%}")
print(f"   Val fraud rate: {val_df['label_id'].mean():.2%}")

# Create datasets
train_dataset = DistillationDataset(train_df, tokenizer, config.max_length)
val_dataset = DistillationDataset(val_df, tokenizer, config.max_length)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print(f"\n‚úÖ DataLoaders created")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")

Preparing Datasets

üìä Data split:
   Training: 9,000 samples
   Validation: 1,000 samples
   Train fraud rate: 14.41%
   Val fraud rate: 14.40%

‚úÖ DataLoaders created
   Train batches: 2250
   Val batches: 250


## 6. Knowledge Distillation Training

Train the student model using:
1. **Soft labels** from teacher (KL divergence loss)
2. **Hard labels** from ground truth (cross-entropy loss)

Combined loss: `L = Œ± * KL_loss + (1-Œ±) * CE_loss`

In [8]:
# ============================================================================
# DISTILLATION LOSS
# ============================================================================

class DistillationLoss(nn.Module):
    """
    Combined loss for knowledge distillation:
    L = Œ± * KL(student || teacher) + (1-Œ±) * CE(student, hard_labels)
    """
    def __init__(self, temperature: float = 2.0, alpha: float = 0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(
        self,
        student_logits: torch.Tensor,  # (B, L, V)
        teacher_logits: torch.Tensor,  # (B, L, V)
        labels: torch.Tensor,          # (B, L)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        # Reshape for loss computation
        B, L, V = student_logits.shape
        student_flat = student_logits.view(-1, V)
        teacher_flat = teacher_logits.view(-1, V)
        labels_flat = labels.view(-1)

        # Mask for valid positions (non-padding)
        valid_mask = labels_flat != -100

        if valid_mask.sum() == 0:
            return torch.tensor(0.0, device=student_logits.device), \
                   torch.tensor(0.0, device=student_logits.device), \
                   torch.tensor(0.0, device=student_logits.device)

        # Apply mask
        student_valid = student_flat[valid_mask]
        teacher_valid = teacher_flat[valid_mask]
        labels_valid = labels_flat[valid_mask]

        # Soft label loss (KL divergence with temperature)
        soft_student = F.log_softmax(student_valid / self.temperature, dim=-1)
        soft_teacher = F.softmax(teacher_valid / self.temperature, dim=-1)

        kl_loss = F.kl_div(
            soft_student,
            soft_teacher,
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # Hard label loss (cross-entropy)
        ce_loss = self.ce_loss(student_flat, labels_flat)

        # Combined loss
        total_loss = self.alpha * kl_loss + (1 - self.alpha) * ce_loss

        return total_loss, kl_loss, ce_loss


# Create loss function
distill_loss_fn = DistillationLoss(
    temperature=config.temperature,
    alpha=config.alpha
)

print("‚úÖ Distillation loss function created")
print(f"   Temperature: {config.temperature}")
print(f"   Alpha (distill weight): {config.alpha}")

‚úÖ Distillation loss function created
   Temperature: 2.0
   Alpha (distill weight): 0.7


In [9]:
# ============================================================================
# TRAINING LOOP
# ============================================================================

def train_distillation(
    student_model: nn.Module,
    teacher_model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: DistillationConfig,
    device: torch.device,
):
    """
    Main training loop for knowledge distillation.
    """
    # Optimizer
    optimizer = AdamW(
        student_model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    # Learning rate scheduler
    total_steps = len(train_loader) * config.num_epochs
    warmup_steps = int(total_steps * config.warmup_ratio)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.learning_rate,
        total_steps=total_steps,
        pct_start=config.warmup_ratio,
        anneal_strategy='cos',
    )

    # Loss function
    loss_fn = DistillationLoss(config.temperature, config.alpha)

    # Training history
    history = {
        'train_loss': [], 'train_kl': [], 'train_ce': [],
        'val_loss': [], 'val_kl': [], 'val_ce': [],
    }

    best_val_loss = float('inf')

    print("\n" + "=" * 70)
    print("Starting Knowledge Distillation Training")
    print("=" * 70)
    print(f"Total steps: {total_steps}")
    print(f"Warmup steps: {warmup_steps}")

    for epoch in range(config.num_epochs):
        # ========== Training ==========
        student_model.train()
        train_losses = {'total': 0, 'kl': 0, 'ce': 0}
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs} [Train]")

        optimizer.zero_grad()

        for batch_idx, batch in enumerate(pbar):
            # Move to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Get teacher logits (no grad)
            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                teacher_logits = teacher_outputs.logits.detach()

            # Get student logits
            student_logits = student_model(input_ids, attention_mask)

            # Compute loss
            total_loss, kl_loss, ce_loss = loss_fn(
                student_logits, teacher_logits, labels
            )

            # Scale loss for gradient accumulation
            scaled_loss = total_loss / config.gradient_accumulation_steps
            scaled_loss.backward()

            # Gradient accumulation step
            if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(
                    student_model.parameters(), config.max_grad_norm
                )

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Track losses
            train_losses['total'] += total_loss.item()
            train_losses['kl'] += kl_loss.item()
            train_losses['ce'] += ce_loss.item()
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'loss': f"{total_loss.item():.4f}",
                'kl': f"{kl_loss.item():.4f}",
                'ce': f"{ce_loss.item():.4f}",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}"
            })

        # Average training losses
        for key in train_losses:
            train_losses[key] /= num_batches

        # ========== Validation ==========
        student_model.eval()
        val_losses = {'total': 0, 'kl': 0, 'ce': 0}
        num_val_batches = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.num_epochs} [Val]"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                # Teacher logits
                teacher_outputs = teacher_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                teacher_logits = teacher_outputs.logits

                # Student logits
                student_logits = student_model(input_ids, attention_mask)

                # Compute loss
                total_loss, kl_loss, ce_loss = loss_fn(
                    student_logits, teacher_logits, labels
                )

                val_losses['total'] += total_loss.item()
                val_losses['kl'] += kl_loss.item()
                val_losses['ce'] += ce_loss.item()
                num_val_batches += 1

        # Average validation losses
        for key in val_losses:
            val_losses[key] /= num_val_batches

        # Update history
        history['train_loss'].append(train_losses['total'])
        history['train_kl'].append(train_losses['kl'])
        history['train_ce'].append(train_losses['ce'])
        history['val_loss'].append(val_losses['total'])
        history['val_kl'].append(val_losses['kl'])
        history['val_ce'].append(val_losses['ce'])

        # Print epoch summary
        print(f"\nüìä Epoch {epoch+1}/{config.num_epochs}:")
        print(f"   Train - Loss: {train_losses['total']:.4f}, KL: {train_losses['kl']:.4f}, CE: {train_losses['ce']:.4f}")
        print(f"   Val   - Loss: {val_losses['total']:.4f}, KL: {val_losses['kl']:.4f}, CE: {val_losses['ce']:.4f}")

        # Save best model
        if val_losses['total'] < best_val_loss:
            best_val_loss = val_losses['total']
            torch.save({
                'epoch': epoch,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_losses['total'],
                'config': student_cfg,
            }, os.path.join(config.checkpoint_dir, 'best_model.pt'))
            print(f"   ‚úÖ New best model saved!")

        # Save checkpoint
        if (epoch + 1) % 2 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'history': history,
            }, os.path.join(config.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt'))

    return history


# Run training
history = train_distillation(
    student_model=student_model,
    teacher_model=teacher_model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device,
)

print("\n" + "=" * 70)
print("‚úÖ Training Complete!")
print("=" * 70)


Starting Knowledge Distillation Training
Total steps: 22500
Warmup steps: 2250


Epoch 1/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:26<00:00,  3.97it/s, loss=7.7506, kl=9.6487, ce=3.3216, lr=1.80e-05]
Epoch 1/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.66it/s]



üìä Epoch 1/10:
   Train - Loss: 11.1372, KL: 13.4810, CE: 5.6682
   Val   - Loss: 7.6239, KL: 9.4983, CE: 3.2501
   ‚úÖ New best model saved!


Epoch 2/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:25<00:00,  3.98it/s, loss=1.3664, kl=1.4213, ce=1.2385, lr=5.20e-05]
Epoch 2/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.66it/s]



üìä Epoch 2/10:
   Train - Loss: 3.9852, KL: 4.8586, CE: 1.9471
   Val   - Loss: 1.3428, KL: 1.3732, CE: 1.2717
   ‚úÖ New best model saved!


Epoch 3/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:25<00:00,  3.98it/s, loss=0.8552, kl=0.7572, ce=1.0837, lr=8.59e-05]
Epoch 3/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.65it/s]



üìä Epoch 3/10:
   Train - Loss: 0.9980, KL: 0.9505, CE: 1.1088
   Val   - Loss: 0.8401, KL: 0.7297, CE: 1.0978
   ‚úÖ New best model saved!


Epoch 4/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:26<00:00,  3.97it/s, loss=0.7779, kl=0.6471, ce=1.0831, lr=1.00e-04]
Epoch 4/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.66it/s]



üìä Epoch 4/10:
   Train - Loss: 0.8099, KL: 0.6991, CE: 1.0685
   Val   - Loss: 0.7733, KL: 0.6357, CE: 1.0942
   ‚úÖ New best model saved!


Epoch 5/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:25<00:00,  3.98it/s, loss=0.7475, kl=0.6150, ce=1.0565, lr=9.98e-05]
Epoch 5/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.65it/s]



üìä Epoch 5/10:
   Train - Loss: 0.7667, KL: 0.6398, CE: 1.0627
   Val   - Loss: 0.7477, KL: 0.6083, CE: 1.0731
   ‚úÖ New best model saved!


Epoch 6/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:25<00:00,  3.98it/s, loss=0.7365, kl=0.5976, ce=1.0608, lr=9.92e-05]
Epoch 6/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.66it/s]



üìä Epoch 6/10:
   Train - Loss: 0.7469, KL: 0.6125, CE: 1.0605
   Val   - Loss: 0.7331, KL: 0.5902, CE: 1.0665
   ‚úÖ New best model saved!


Epoch 7/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:25<00:00,  3.98it/s, loss=0.7350, kl=0.5893, ce=1.0751, lr=9.83e-05]
Epoch 7/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.64it/s]



üìä Epoch 7/10:
   Train - Loss: 0.7351, KL: 0.5962, CE: 1.0593
   Val   - Loss: 0.7250, KL: 0.5786, CE: 1.0667
   ‚úÖ New best model saved!


Epoch 8/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:25<00:00,  3.98it/s, loss=0.7303, kl=0.5901, ce=1.0574, lr=9.70e-05]
Epoch 8/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.65it/s]



üìä Epoch 8/10:
   Train - Loss: 0.7272, KL: 0.5852, CE: 1.0585
   Val   - Loss: 0.7191, KL: 0.5749, CE: 1.0555
   ‚úÖ New best model saved!


Epoch 9/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:25<00:00,  3.98it/s, loss=0.7130, kl=0.5679, ce=1.0518, lr=9.53e-05]
Epoch 9/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.66it/s]



üìä Epoch 9/10:
   Train - Loss: 0.7215, KL: 0.5773, CE: 1.0579
   Val   - Loss: 0.7148, KL: 0.5636, CE: 1.0676
   ‚úÖ New best model saved!


Epoch 10/10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2250/2250 [09:25<00:00,  3.98it/s, loss=0.7112, kl=0.5635, ce=1.0560, lr=9.33e-05]
Epoch 10/10 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 250/250 [00:53<00:00,  4.66it/s]



üìä Epoch 10/10:
   Train - Loss: 0.7172, KL: 0.5714, CE: 1.0576
   Val   - Loss: 0.7110, KL: 0.5617, CE: 1.0594
   ‚úÖ New best model saved!

‚úÖ Training Complete!


## 7. Save Final Model

In [10]:
# ============================================================================
# SAVE FINAL MODEL
# ============================================================================

# Load best checkpoint
best_checkpoint = torch.load(os.path.join(config.checkpoint_dir, 'best_model.pt'))
student_model.load_state_dict(best_checkpoint['model_state_dict'])

# Save model and config
torch.save({
    'model_state_dict': student_model.state_dict(),
    'config': {
        'vocab_size': tokenizer.vocab_size,
        'd_model': student_cfg['d_model'],
        'n_layers': student_cfg['n_layers'],
        'n_heads': student_cfg['n_heads'],
        'd_ff': student_cfg['d_ff'],
        'max_seq_len': config.max_length,
    },
    'training_history': history,
}, os.path.join(config.output_dir, 'fraud_slm_final.pt'))

# Save tokenizer
tokenizer.save_pretrained(config.output_dir)

print("=" * 70)
print("Model Saved")
print("=" * 70)
print(f"\nüìÅ Output directory: {config.output_dir}")
print(f"   - fraud_slm_final.pt (model weights)")
print(f"   - tokenizer files")

Model Saved

üìÅ Output directory: ./fraud_slm_distilled
   - fraud_slm_final.pt (model weights)
   - tokenizer files


In [11]:
!pip install --upgrade huggingface_hub
from huggingface_hub import login

# This will prompt you for your token (paste the 'Write' token here)
login()

Collecting huggingface_hub
  Downloading huggingface_hub-1.3.3-py3-none-any.whl.metadata (13 kB)
Downloading huggingface_hub-1.3.3-py3-none-any.whl (536 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m536.6/536.6 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: huggingface_hub
  Attempting uninstall: huggingface_hub
    Found existing installation: huggingface-hub 0.36.0
    Uninstalling huggingface-hub-0.36.0:
      Successfully uninstalled huggingface-hub-0.36.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
transformers 4.57.6 requires huggingface-hub<1.0,>=0.34.0, but you have huggingface-hub 1.3.3 which is incompatible.[0m[31m
[0mSuccessfully installed huggingface_hub-1.3.3


Error importing huggingface_hub._login: cannot import name 'ANSI' from 'huggingface_hub.utils' (/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/__init__.py)


ImportError: cannot import name 'ANSI' from 'huggingface_hub.utils' (/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/__init__.py)

## 8. Test Student Model

In [None]:
# ============================================================================
# TEST STUDENT MODEL
# ============================================================================

student_model.eval()

def predict_fraud(model, tokenizer, prompt: str, device) -> str:
    """Generate fraud prediction from student model"""
    inputs = tokenizer(
        prompt,
        return_tensors='pt',
        truncation=True,
        max_length=config.max_length,
        padding=True,
    ).to(device)

    with torch.no_grad():
        logits = model(inputs['input_ids'], inputs['attention_mask'])
        # Get next token prediction
        next_token_logits = logits[0, -1, :]
        next_token_id = next_token_logits.argmax().item()
        predicted_token = tokenizer.decode([next_token_id])

    # Determine prediction
    if 'FRAUD' in predicted_token.upper() or 'fraud' in predicted_token.lower():
        return 'FRAUD'
    elif 'LEGIT' in predicted_token.upper() or 'legit' in predicted_token.lower():
        return 'LEGITIMATE'
    else:
        # Use logits for FRAUD/LEGITIMATE tokens
        fraud_tokens = tokenizer.encode('FRAUD', add_special_tokens=False)
        legit_tokens = tokenizer.encode('LEGITIMATE', add_special_tokens=False)

        fraud_score = next_token_logits[fraud_tokens[0]].item() if fraud_tokens else 0
        legit_score = next_token_logits[legit_tokens[0]].item() if legit_tokens else 0

        return 'FRAUD' if fraud_score > legit_score else 'LEGITIMATE'


# Test cases
test_cases = [
    {'merchant': 'Starbucks', 'amount': 12.50, 'expected': 'LEGITIMATE'},
    {'merchant': 'Wire Transfer Intl', 'amount': 9500, 'expected': 'FRAUD'},
    {'merchant': 'Amazon', 'amount': 150, 'expected': 'LEGITIMATE'},
    {'merchant': 'Crypto Exchange', 'amount': 15000, 'expected': 'FRAUD'},
    {'merchant': 'Whole Foods', 'amount': 85, 'expected': 'LEGITIMATE'},
]

print("=" * 70)
print("Student Model Predictions")
print("=" * 70)

correct = 0
for i, test in enumerate(test_cases):
    # Create test prompt
    prompt = f"""<|system|>
You are a fraud detection system.
<|user|>
Transaction: {test['merchant']}, ${test['amount']}
Is this FRAUD or LEGITIMATE?
<|assistant|>
Based on the transaction, this is: """

    prediction = predict_fraud(student_model, tokenizer, prompt, device)
    is_correct = prediction == test['expected']
    correct += int(is_correct)

    icon = "‚úÖ" if is_correct else "‚ùå"
    print(f"\n{i+1}. {test['merchant']} - ${test['amount']}")
    print(f"   Expected: {test['expected']}")
    print(f"   Predicted: {prediction} {icon}")

print(f"\n" + "=" * 70)
print(f"Accuracy: {correct}/{len(test_cases)} ({100*correct/len(test_cases):.0f}%)")

## 9. Performance Benchmark

In [None]:
# ============================================================================
# PERFORMANCE BENCHMARK
# ============================================================================

import time

print("=" * 70)
print("Performance Benchmark: Student vs Teacher")
print("=" * 70)

# Test prompt
test_prompt = """<|system|>
You are a fraud detection system.
<|user|>
Transaction: Amazon, $150
Is this FRAUD or LEGITIMATE?
<|assistant|>
Based on the transaction, this is: """

inputs = tokenizer(
    test_prompt,
    return_tensors='pt',
    truncation=True,
    max_length=config.max_length,
    padding=True,
).to(device)

NUM_RUNS = 100

# Benchmark Student Model
student_model.eval()
torch.cuda.synchronize() if torch.cuda.is_available() else None

start = time.time()
for _ in range(NUM_RUNS):
    with torch.no_grad():
        _ = student_model(inputs['input_ids'], inputs['attention_mask'])
    torch.cuda.synchronize() if torch.cuda.is_available() else None
student_time = (time.time() - start) / NUM_RUNS * 1000  # ms

# Benchmark Teacher Model
start = time.time()
for _ in range(NUM_RUNS):
    with torch.no_grad():
        _ = teacher_model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    torch.cuda.synchronize() if torch.cuda.is_available() else None
teacher_time = (time.time() - start) / NUM_RUNS * 1000  # ms

# Results
speedup = teacher_time / student_time

print(f"\n‚è±Ô∏è  Inference Latency (avg over {NUM_RUNS} runs):")
print(f"   Student ({student_cfg['params']}): {student_time:.2f} ms")
print(f"   Teacher (7B): {teacher_time:.2f} ms")
print(f"\nüöÄ Speedup: {speedup:.1f}x faster")

# Memory comparison
student_memory = sum(p.numel() * p.element_size() for p in student_model.parameters()) / 1e6
teacher_memory = 7000  # ~7GB for Mistral-7B (estimate)

print(f"\nüíæ Memory Footprint:")
print(f"   Student: ~{student_memory:.0f} MB")
print(f"   Teacher: ~{teacher_memory} MB")
print(f"   Reduction: {teacher_memory/student_memory:.0f}x smaller")

## 10. Export for Production (TensorRT / ONNX)

In [None]:
# ============================================================================
# EXPORT TO ONNX FOR TENSORRT
# ============================================================================

print("=" * 70)
print("Exporting Student Model to ONNX")
print("=" * 70)

# Prepare dummy input
dummy_input_ids = torch.randint(0, tokenizer.vocab_size, (1, config.max_length)).to(device)
dummy_attention_mask = torch.ones(1, config.max_length).to(device)

# Export to ONNX
onnx_path = os.path.join(config.output_dir, 'fraud_slm.onnx')

student_model.eval()
with torch.no_grad():
    torch.onnx.export(
        student_model,
        (dummy_input_ids, dummy_attention_mask),
        onnx_path,
        input_names=['input_ids', 'attention_mask'],
        output_names=['logits'],
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence'},
            'attention_mask': {0: 'batch_size', 1: 'sequence'},
            'logits': {0: 'batch_size', 1: 'sequence'},
        },
        opset_version=14,
        do_constant_folding=True,
    )

onnx_size = os.path.getsize(onnx_path) / 1e6

print(f"\n‚úÖ ONNX model exported:")
print(f"   Path: {onnx_path}")
print(f"   Size: {onnx_size:.1f} MB")

print(f"\nüìù TensorRT Conversion (run in terminal):")
print(f"   trtexec --onnx={onnx_path} --saveEngine=fraud_slm.trt --fp16")

## Summary

In [None]:
# ============================================================================
# FINAL SUMMARY
# ============================================================================

print("\n" + "=" * 70)
print("‚úÖ KNOWLEDGE DISTILLATION COMPLETE")
print("=" * 70)

print(f"""
‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
‚îÇ  SOLUTION SUMMARY                                                   ‚îÇ
‚îú‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î§
‚îÇ                                                                     ‚îÇ
‚îÇ  üéì TEACHER MODEL                                                   ‚îÇ
‚îÇ     ‚Ä¢ Mistral-7B (7B parameters)                                   ‚îÇ
‚îÇ     ‚Ä¢ Used to generate soft labels                                 ‚îÇ
‚îÇ                                                                     ‚îÇ
‚îÇ  üéí STUDENT MODEL                                                   ‚îÇ
‚îÇ     ‚Ä¢ Size: {config.student_size.upper()} ({student_cfg['params']})                                   ‚îÇ
‚îÇ     ‚Ä¢ Architecture: {student_cfg['n_layers']} layers, {student_cfg['n_heads']} heads, d={student_cfg['d_model']}                 ‚îÇ
‚îÇ     ‚Ä¢ Compression: {7e9/total_params:.0f}x smaller than teacher                            ‚îÇ
‚îÇ                                                                     ‚îÇ
‚îÇ  üìä TRAINING                                                        ‚îÇ
‚îÇ     ‚Ä¢ Dataset: 10,000 fraud detection samples                      ‚îÇ
‚îÇ     ‚Ä¢ Method: Knowledge Distillation (Œ±={config.alpha}, T={config.temperature})             ‚îÇ
‚îÇ     ‚Ä¢ Epochs: {config.num_epochs}                                                       ‚îÇ
‚îÇ                                                                     ‚îÇ
‚îÇ  üöÄ PERFORMANCE                                                     ‚îÇ
‚îÇ     ‚Ä¢ Inference: {student_time:.1f}ms ({speedup:.0f}x faster than teacher)                ‚îÇ
‚îÇ     ‚Ä¢ Memory: {student_memory:.0f}MB ({teacher_memory/student_memory:.0f}x smaller)                                 ‚îÇ
‚îÇ                                                                     ‚îÇ
‚îÇ  üìÅ OUTPUT FILES                                                    ‚îÇ
‚îÇ     ‚Ä¢ {config.output_dir}/fraud_slm_final.pt                      ‚îÇ
‚îÇ     ‚Ä¢ {config.output_dir}/fraud_slm.onnx                          ‚îÇ
‚îÇ     ‚Ä¢ {config.output_dir}/tokenizer files                         ‚îÇ
‚îÇ                                                                     ‚îÇ
‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò

üéØ NEXT STEPS:
   1. Convert ONNX to TensorRT for even faster inference
   2. Deploy with vLLM or TensorRT-LLM
   3. Fine-tune on real production data
   4. Add monitoring and A/B testing
""")