# Ada-FracBNN Testing Notebook

This notebook provides an interactive environment to test and explore the Adaptive FracBNN (Ada-FracBNN) implementation.

## Features:
1. **Baseline FracBNN** - Original FracBNN with fixed gates
2. **Adaptive PG** - Learnable per-channel fractionalization
3. **Knowledge Distillation** - KD from compact FP teacher

## Quick Start:
Run cells sequentially to:
- Load and configure models
- Test on CIFAR-10 dataset
- Visualize gate statistics
- Compare model performance


## 1. Setup and Imports


In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import os
import sys

# Add project root to path
project_root = os.path.abspath('.')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import project modules
import utils.utils as util
import utils.quantization as q
import model.fracbnn_cifar10 as m

# Plot styling used by every figure in this notebook
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

# Report the runtime environment (PyTorch build and CUDA availability)
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))


## 2. Configuration

In [None]:
# Experiment configuration shared by every cell below
config = dict(
    device='cuda' if torch.cuda.is_available() else 'cpu',
    batch_size=128,
    num_workers=2,
    data_dir='./data/cifar10',

    # Adaptive PG parameters
    target_sparsity=0.15,   # Target 15% sparsity (85% of channels use 1-bit)
    sparsity_weight=0.01,   # Weight for sparsity regularization

    # Knowledge Distillation parameters
    kd_temperature=4.0,
    kd_alpha=0.7,

    # Training parameters (epoch count kept small for quick testing)
    learning_rate=1e-3,
    num_epochs=5,
)

print("Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in config.items()))


## 3. Data Loading

In [None]:
def load_cifar10(config, normalize=False):
    """Build the CIFAR-10 train/test DataLoaders.

    Both splits are downloaded into ``config['data_dir']`` if missing.

    Args:
        config: dict providing 'data_dir', 'batch_size' and 'num_workers'.
        normalize: when True, append a per-channel ``Normalize`` transform
            after ``ToTensor``.

    Returns:
        ``(trainloader, testloader, classes)``. ``classes`` is the tuple of
        the ten label names in index order.
    """
    base_pipeline = [transforms.ToTensor()]
    if normalize:
        # NOTE(review): these are ImageNet channel statistics, not CIFAR-10's
        # own; confirm this matches how any pretrained model was trained.
        base_pipeline.append(transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ))

    # Training-only augmentation runs before the shared tensor conversion
    augmentation = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, 4)]

    def _dataset(is_train, pipeline):
        # Wrap the torchvision dataset for one split with its transform chain
        return torchvision.datasets.CIFAR10(
            root=config['data_dir'],
            train=is_train,
            download=True,
            transform=transforms.Compose(pipeline),
        )

    def _loader(dataset, shuffle):
        # drop_last=True on BOTH splits: the binary input encoder is built for
        # a fixed batch size, so any final partial batch is discarded (test
        # accuracy is therefore measured on slightly fewer than all images).
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=config['batch_size'],
            shuffle=shuffle,
            num_workers=config['num_workers'],
            pin_memory=True,
            drop_last=True,
        )

    train_split = _dataset(True, augmentation + base_pipeline)
    test_split = _dataset(False, base_pipeline)

    label_names = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

    return _loader(train_split, shuffle=True), _loader(test_split, shuffle=False), label_names

# Build the loaders once; every later cell reuses them
print("Loading CIFAR-10 dataset...")
trainloader, testloader, classes = load_cifar10(config)
print(f"Training batches: {len(trainloader)}",
      f"Testing batches: {len(testloader)}",
      f"Classes: {classes}",
      sep="\n")


## 4. Model Creation and Testing


In [None]:
def create_baseline_model(config):
    """Return the baseline FracBNN ResNet-20 (binput-pg).

    Built with ``config['batch_size']`` and the number of visible CUDA
    devices; no adaptive gating is requested.
    """
    return m.resnet20(
        batch_size=config['batch_size'],
        num_gpus=torch.cuda.device_count(),
    )

def create_adaptive_pg_model(config):
    """Return the Adaptive PG ResNet-20 (adaptive-pg).

    Same as the baseline constructor, but it enables ``adaptive_pg`` and
    forwards ``config['target_sparsity']``.
    """
    model_kwargs = dict(
        batch_size=config['batch_size'],
        num_gpus=torch.cuda.device_count(),
        adaptive_pg=True,
        target_sparsity=config['target_sparsity'],
    )
    return m.resnet20(**model_kwargs)

def create_teacher_model():
    """Return the full-precision ResNet-20 teacher (10 classes) used for KD."""
    return m.fp_resnet20(num_classes=10)

# Smoke test: build both student models, move them to the configured device
# and push one test batch through each to confirm the forward pass runs.
# (The check marks were previously mis-encoded as "âœ“" in the output.)
print("Creating Baseline FracBNN model...")
baseline_model = create_baseline_model(config)
baseline_model = baseline_model.to(config['device'])
print("✓ Baseline model created")
print(f"  Parameters: {sum(p.numel() for p in baseline_model.parameters()):,}")

print("\nCreating Adaptive PG model...")
adaptive_model = create_adaptive_pg_model(config)
adaptive_model = adaptive_model.to(config['device'])
print("✓ Adaptive PG model created")
print(f"  Parameters: {sum(p.numel() for p in adaptive_model.parameters()):,}")

# Take one batch from the test loader for the forward-pass check
test_images, test_labels = next(iter(testloader))
test_images = test_images.to(config['device'])
test_labels = test_labels.to(config['device'])

print("\n✓ Testing forward pass...")
baseline_model.eval()
with torch.no_grad():
    baseline_output = baseline_model(test_images)
print(f"  Baseline output shape: {baseline_output.shape}")

adaptive_model.eval()
with torch.no_grad():
    adaptive_output = adaptive_model(test_images)
print(f"  Adaptive output shape: {adaptive_output.shape}")


## 5. Analyze Adaptive PG Gates


In [None]:
def analyze_adaptive_gates(model, config):
    """Print and plot per-layer statistics of the adaptive PG gates.

    Args:
        model: object expected to expose ``get_gate_statistics()``, returning
            a list of dicts with keys 'layer_name', 'active_fraction',
            'gate_mean' and 'gate_std'.
        config: dict providing 'target_sparsity'.

    Returns:
        The list returned by ``model.get_gate_statistics()``. Returns None when
        the model has no such method or reports no statistics.
    """
    if not hasattr(model, 'get_gate_statistics'):
        print("Model does not have adaptive gates.")
        return None

    gate_stats = model.get_gate_statistics()
    if not gate_stats:
        print("No gate statistics available.")
        return None

    # Column-wise views of the per-layer records
    layer_names = [stats['layer_name'] for stats in gate_stats]
    active_fractions = [stats['active_fraction'] for stats in gate_stats]
    gate_means = [stats['gate_mean'] for stats in gate_stats]
    gate_stds = [stats['gate_std'] for stats in gate_stats]
    target = config['target_sparsity']

    # Text report
    print("\n" + "=" * 60)
    print("ADAPTIVE PG GATE ANALYSIS")
    print("=" * 60)

    for name, frac, mean, std in zip(layer_names, active_fractions, gate_means, gate_stds):
        print(f"\n{name}:")
        print(f"  2-bit fraction: {frac:.3f}")
        print(f"  Gate mean: {mean:.3f}")
        print(f"  Gate std: {std:.3f}")

    avg_active = np.mean(active_fractions)
    # target_sparsity is compared against the 2-bit fraction (the plot below
    # draws it on the 2-bit-fraction axis). The old report printed
    # "Actual sparsity" as 1 - 2-bit fraction, which is the other quantity and
    # could never line up with the target. Report both fractions explicitly
    # and label the target with what it is compared to.
    print(f"\nOverall 2-bit fraction: {avg_active:.3f}")
    print(f"Overall 1-bit fraction: {1.0 - avg_active:.3f}")
    print(f"Target 2-bit fraction (target_sparsity): {target:.3f}")

    # Visualization
    short_names = [name.split('.')[-1] for name in layer_names]
    x = np.arange(len(layer_names))
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: 2-bit fraction per layer against the target and the average
    axes[0].bar(x, active_fractions, alpha=0.7)
    axes[0].axhline(y=target, color='r', linestyle='--',
                    label=f"Target: {target:.2f}")
    axes[0].axhline(y=avg_active, color='g', linestyle='--',
                    label=f"Average: {avg_active:.2f}")
    axes[0].set_xlabel('Layer')
    axes[0].set_ylabel('2-bit Fraction')
    axes[0].set_title('2-bit Fraction per Layer')
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(short_names, rotation=45, ha='right')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Plot 2: side-by-side gate mean and std per layer
    width = 0.35
    axes[1].bar(x - width / 2, gate_means, width, label='Mean', alpha=0.7)
    axes[1].bar(x + width / 2, gate_stds, width, label='Std', alpha=0.7)
    axes[1].set_xlabel('Layer')
    axes[1].set_ylabel('Value')
    axes[1].set_title('Gate Statistics')
    axes[1].set_xticks(x)
    axes[1].set_xticklabels(short_names, rotation=45, ha='right')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return gate_stats

# Inspect the gates of the freshly created (untrained) adaptive model
gate_stats = analyze_adaptive_gates(model=adaptive_model, config=config)


## 6. Quick Evaluation Functions


In [None]:
def evaluate_model(model, dataloader, device, desc="Evaluating", max_batches=None):
    """Return the top-1 accuracy, in percent, of ``model`` on ``dataloader``.

    Args:
        model: classifier module; this call switches it to eval mode and
            leaves it there.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.
        desc: label for the tqdm progress bar.
        max_batches: when not None, stop after this many batches
            (``0`` evaluates nothing).

    Returns:
        ``100 * correct / total`` as a float, or ``float('nan')`` when no
        samples were evaluated (empty loader or ``max_batches=0``). Previously
        this raised ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc=desc)):
            # `is not None` so max_batches=0 means "no batches", not "all"
            if max_batches is not None and i >= max_batches:
                break
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return float('nan')
    return 100.0 * correct / total

# Smoke-test accuracy on a handful of test batches (weights are untrained)
print("Quick evaluation (first 10 batches)...")
baseline_acc = evaluate_model(baseline_model, testloader, config['device'],
                              desc="Baseline", max_batches=10)
print(f"Baseline FracBNN: {baseline_acc:.2f}%")

adaptive_acc = evaluate_model(adaptive_model, testloader, config['device'],
                              desc="Adaptive", max_batches=10)
print(f"Adaptive PG: {adaptive_acc:.2f}%")

print("\n(Note: These are random initialization results)")


## 7. Summary and Full Training Commands

### ✅ Test Summary:
- Successfully loaded CIFAR-10 dataset
- Created and tested Baseline and Adaptive PG models
- Verified forward pass works correctly
- Analyzed adaptive gate statistics
- Quick evaluation completed

### 🚀 For Full Training, Use the Command Line:

```bash
# 1. Baseline FracBNN (250 epochs)
python cifar10.py -id 0 -e 250 -b 128 -g 0.0 -s

# 2. Adaptive PG (Ada-FracBNN)
python cifar10.py -id 1 -e 250 -b 128 -ts 0.15 -sw 0.01 -s

# 3. Adaptive PG + Knowledge Distillation
python cifar10.py -id 2 -e 250 -b 128 -ts 0.15 -sw 0.01 \
    -temp 4.0 -alpha 0.7 -tp teacher.pth -s
```

### 📊 Parameter Guide:
- `-id`: Model ID (0=baseline, 1=adaptive-pg, 2=adaptive-pg-kd)
- `-e`: Number of epochs
- `-b`: Batch size
- `-ts`: Target sparsity (0.15 = 15%)
- `-sw`: Sparsity regularization weight
- `-temp`: KD temperature
- `-alpha`: KD loss weight
- `-tp`: Teacher model path
- `-s`: Save model
