# Ada-FracBNN Testing Notebook (Google Colab Version)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/YOUR_USERNAME/endingengineering/blob/main/test_ada_fracbnn_colab.ipynb)

This notebook is optimized for Google Colab with GPU support.

## ⚠️ Important First Steps:
1. **Enable GPU**: `Runtime` → `Change runtime type` → `Hardware accelerator: GPU`
2. **Run the Step 0 cell**: Sets up the environment (clone repo, install packages)
3. **Run remaining cells** sequentially

## Features:
- Automatic environment setup for Colab
- GPU detection and verification
- Google Drive integration for saving models
- Optimized batch sizes for Colab GPUs


## 🚀 Step 0: Google Colab Setup (RUN THIS FIRST!)


In [None]:
# ============================================
# GOOGLE COLAB ENVIRONMENT SETUP
# ============================================

import sys
import os

# Detect if running in Colab
try:
    import google.colab
    IN_COLAB = True
    print("✓ Running in Google Colab")
except ImportError:
    # google.colab is only importable inside a Colab runtime
    IN_COLAB = False
    print("✓ Running locally")

if IN_COLAB:
    print("\n" + "="*60)
    print("SETTING UP GOOGLE COLAB ENVIRONMENT")
    print("="*60)
    
    # Check GPU
    import torch
    if torch.cuda.is_available():
        print(f"\n✓ GPU DETECTED: {torch.cuda.get_device_name(0)}")
        print(f"  CUDA Version: {torch.version.cuda}")
        print(f"  PyTorch Version: {torch.__version__}")
    else:
        print("\n⚠️  WARNING: GPU NOT DETECTED!")
        print("   Go to: Runtime → Change runtime type → Hardware accelerator: GPU")
        print("   Then restart this notebook.")
    
    # Clone repository
    print("\n📦 Cloning repository...")
    # Replace 'YOUR_USERNAME' with your GitHub username
    repo_url = "https://github.com/YOUR_USERNAME/endingengineering.git"
    
    # Alternative: Upload files manually
    print("   Option 1: Clone from GitHub (recommended)")
    print(f"   !git clone {repo_url}")
    print("\n   Option 2: Upload files manually")
    print("   Uncomment the lines below to upload a zip file:")
    print("""
    # from google.colab import files
    # uploaded = files.upload()  # Upload your project.zip
    # !unzip -q project.zip
    """)
    
    # For now, clone (you can modify this)
    # !git clone {repo_url}
    
    # Alternative: Upload files
    print("\n📁 Please choose setup method:")
    print("   A) Clone from GitHub - modify repo_url above and uncomment git clone")
    print("   B) Upload files - uncomment the upload code above")
    print("\n⚠️  After setup, uncomment the appropriate section and rerun this cell")
    
    # Uncomment ONE of these:
    # Method A: GitHub clone
    # !git clone https://github.com/YOUR_USERNAME/endingengineering.git
    # %cd endingengineering
    
    # Method B: Manual upload
    # from google.colab import files
    # import zipfile
    # uploaded = files.upload()
    # for f in uploaded.keys():
    #     if f.endswith('.zip'):
    #         !unzip -q {f}
    #         dir_name = f.replace('.zip', '')
    #         %cd {dir_name}
    
    # Install dependencies
    print("\n📚 Installing dependencies...")
    !pip install -q tqdm seaborn
    
    # Mount Google Drive for saving models
    print("\n💾 Mounting Google Drive (optional - for saving models)...")
    print("   This allows you to save trained models permanently.")
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    
    print("\n" + "="*60)
    print("✅ SETUP COMPLETE!")
    print("="*60)
    print("\nYou can now run the remaining cells.")
    print("Note: Make sure to uncomment the setup method above on first run!")
    
else:
    print("Running locally - no Colab setup needed.")


## 1. Setup and Imports


In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import os
import sys

# Add project root to path (works for both Colab and local)
try:
    import google.colab
    # Colab paths
    if os.path.exists('/content/endingengineering'):
        sys.path.insert(0, '/content/endingengineering')
    project_root = '/content/endingengineering'
except ImportError:
    # Local paths
    project_root = os.path.abspath('.')
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

# Import project modules
import utils.utils as util
import utils.quantization as q
import model.fracbnn_cifar10 as m

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("="*60)
print("ENVIRONMENT CHECK")
print("="*60)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
print(f"Working directory: {os.getcwd()}")
print("="*60)


## 2. Configuration (Optimized for Colab)


In [None]:
# Configuration (optimized for Google Colab)
try:
    import google.colab
    IN_COLAB = True
    data_dir = '/content/data/cifar10'
    batch_size = 256  # Larger batch for Colab GPU
    save_dir = '/content/drive/MyDrive/ada_fracbnn_models/'
except ImportError:
    IN_COLAB = False
    data_dir = './data/cifar10'
    batch_size = 128
    save_dir = './saved_models/'

config = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'batch_size': batch_size,
    'num_workers': 2,  # Optimal for Colab
    'data_dir': data_dir,
    'save_dir': save_dir,
    
    # Adaptive PG parameters
    'target_sparsity': 0.15,  # Target: about 15% of channels upgraded to 2-bit (85% stay 1-bit)
    'sparsity_weight': 0.01,   # Weight for sparsity regularization
    
    # Knowledge Distillation parameters
    'kd_temperature': 4.0,
    'kd_alpha': 0.7,
    
    # Training parameters
    'learning_rate': 1e-3,
    'num_epochs': 5,  # Small number for quick testing
}

# Create save directory
os.makedirs(config['save_dir'], exist_ok=True)

print("Configuration:")
print(f"  Environment: {'Google Colab' if IN_COLAB else 'Local'}")
for key, value in config.items():
    print(f"  {key}: {value}")


## 3. Load CIFAR-10 Dataset


In [None]:
def load_cifar10(config, normalize=False):
    """Load CIFAR-10 dataset"""
    transform_list = [transforms.ToTensor()]
    
    if normalize:
        normalize_transform = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        transform_list.append(normalize_transform)
    
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
    ] + transform_list)
    
    transform_test = transforms.Compose(transform_list)
    
    trainset = torchvision.datasets.CIFAR10(
        root=config['data_dir'],
        train=True,
        download=True,
        transform=transform_train
    )
    
    testset = torchvision.datasets.CIFAR10(
        root=config['data_dir'],
        train=False,
        download=True,
        transform=transform_test
    )
    
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True
    )
    
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True
    )
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')
    
    return trainloader, testloader, classes

print("Loading CIFAR-10 dataset...")
trainloader, testloader, classes = load_cifar10(config)
print(f"✓ Training batches: {len(trainloader)}")
print(f"✓ Testing batches: {len(testloader)}")
print(f"✓ Classes: {classes}")


## 4. Create Models


In [None]:
def create_baseline_model(config):
    """Create baseline FracBNN model"""
    model = m.resnet20(
        batch_size=config['batch_size'],
        num_gpus=torch.cuda.device_count()
    )
    return model

def create_adaptive_pg_model(config):
    """Create Adaptive PG model"""
    model = m.resnet20(
        batch_size=config['batch_size'],
        num_gpus=torch.cuda.device_count(),
        adaptive_pg=True,
        target_sparsity=config['target_sparsity']
    )
    return model

print("Creating models...")
print("="*60)

# Baseline model
print("\n1️⃣ Baseline FracBNN")
baseline_model = create_baseline_model(config).to(config['device'])
print(f"   ✓ Created | Parameters: {sum(p.numel() for p in baseline_model.parameters()):,}")

# Adaptive PG model
print("\n2️⃣ Adaptive PG (Ada-FracBNN)")
adaptive_model = create_adaptive_pg_model(config).to(config['device'])
print(f"   ✓ Created | Parameters: {sum(p.numel() for p in adaptive_model.parameters()):,}")

# Test forward pass
print("\n" + "="*60)
print("Testing forward pass...")
test_images, test_labels = next(iter(testloader))
test_images = test_images.to(config['device'])

baseline_model.eval()
with torch.no_grad():
    out1 = baseline_model(test_images)
print(f"✓ Baseline output: {out1.shape}")

adaptive_model.eval()
with torch.no_grad():
    out2 = adaptive_model(test_images)
print(f"✓ Adaptive output: {out2.shape}")

print("\n✅ All models working correctly!")


## 5. Analyze Adaptive PG Gates 📊

Visualize the learnable gates that control 2-bit upgrades
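
To make the reported statistics concrete, here is a minimal sketch of how per-layer gate statistics could be computed. It assumes one learnable logit per channel, a gate value equal to the sigmoid of that logit, and a channel counting as upgraded to 2-bit when its gate exceeds 0.5. These are assumptions for illustration only; the project's `get_gate_statistics()` may define the quantities differently, and `summarize_gates` is a hypothetical helper.

```python
import math
import statistics

def summarize_gates(logits, threshold=0.5):
    """Summarize one layer's gates (illustrative; not the project's implementation)."""
    # Assumption: gate value = sigmoid(logit), one logit per channel.
    gates = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    # Assumption: a channel is "active" (upgraded to 2-bit) when its gate exceeds the threshold.
    active = sum(g > threshold for g in gates) / len(gates)
    return {
        "active_fraction": active,
        "gate_mean": statistics.fmean(gates),
        "gate_std": statistics.pstdev(gates),
    }

# Four channels, only the last one has a positive logit.
stats = summarize_gates([-2.0, -1.0, -0.5, 3.0])
print(stats)
```

Under these assumptions, `active_fraction` is the quantity compared against `target_sparsity` in the bar chart, while `gate_mean` and `gate_std` describe how decisively the gates have moved away from their initial values.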


In [None]:
def analyze_gates(model, config):
    """Analyze and visualize gates"""
    if not hasattr(model, 'get_gate_statistics'):
        print("⚠️ Model doesn't have adaptive gates")
        return
    
    gate_stats = model.get_gate_statistics()
    if not gate_stats:
        print("⚠️ No gate statistics available")
        return
    
    # Extract data
    names = [s['layer_name'] for s in gate_stats]
    active = [s['active_fraction'] for s in gate_stats]
    means = [s['gate_mean'] for s in gate_stats]
    stds = [s['gate_std'] for s in gate_stats]
    
    # Print stats
    print("="*60)
    print("ADAPTIVE PG GATE ANALYSIS")
    print("="*60)
    for i, name in enumerate(names):
        print(f"{name}: 2-bit={active[i]:.3f}, mean={means[i]:.3f}, std={stds[i]:.3f}")
    
    avg = np.mean(active)
    print(f"\n📊 Average 2-bit fraction: {avg:.3f}")
    print(f"🎯 Target 2-bit fraction (target_sparsity): {config['target_sparsity']:.3f}")
    print(f"📉 Average 1-bit fraction: {1.0-avg:.3f}")
    
    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot 1: Active fractions
    axes[0].bar(range(len(names)), active, alpha=0.7, color='steelblue')
    axes[0].axhline(config['target_sparsity'], color='r', linestyle='--', 
                    label=f"Target: {config['target_sparsity']:.2f}")
    axes[0].axhline(avg, color='g', linestyle='--', label=f"Avg: {avg:.2f}")
    axes[0].set_xlabel('Layer')
    axes[0].set_ylabel('2-bit Fraction')
    axes[0].set_title('2-bit Fraction per Layer')
    axes[0].set_xticks(range(len(names)))
    axes[0].set_xticklabels([n.split('.')[-1] for n in names], rotation=45, ha='right')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Plot 2: Mean/Std
    x = np.arange(len(names))
    width = 0.35
    axes[1].bar(x - width/2, means, width, label='Mean', alpha=0.7, color='orange')
    axes[1].bar(x + width/2, stds, width, label='Std', alpha=0.7, color='purple')
    axes[1].set_xlabel('Layer')
    axes[1].set_ylabel('Value')
    axes[1].set_title('Gate Statistics')
    axes[1].set_xticks(x)
    axes[1].set_xticklabels([n.split('.')[-1] for n in names], rotation=45, ha='right')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Analyze adaptive model
analyze_gates(adaptive_model, config)


## 6. Quick Evaluation


In [None]:
def quick_eval(model, loader, device, max_batches=10):
    """Quick evaluation on subset"""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for i, (imgs, labels) in enumerate(loader):
            if i >= max_batches:
                break
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, pred = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
    return 100.0 * correct / total

print("Quick evaluation (10 batches, random weights)...")
print("="*60)
acc1 = quick_eval(baseline_model, testloader, config['device'])
print(f"Baseline FracBNN: {acc1:.2f}%")

acc2 = quick_eval(adaptive_model, testloader, config['device'])
print(f"Adaptive PG: {acc2:.2f}%")

print("\n(Note: ~10% is random chance for 10 classes)")
print("After training 250 epochs: ~90-92% accuracy expected")


## 7. Full Training (Optional)

Run full training for 250 epochs. **Warning**: Takes 8-12 hours!

```python
# Run this in a new cell if you want full training
!python cifar10.py -id 1 -e 250 -b 256 -ts 0.15 -sw 0.01 -s
```

Or run shorter training for testing (10 epochs):
```python
!python cifar10.py -id 1 -e 10 -b 256 -ts 0.15 -sw 0.01 -s
```


## ✅ Summary

### What we tested:
- ✅ Environment setup (GPU, dependencies)
- ✅ CIFAR-10 data loading
- ✅ Baseline FracBNN model
- ✅ Adaptive PG model with learnable gates
- ✅ Forward pass verification
- ✅ Gate statistics and visualization
- ✅ Quick accuracy evaluation

### Next steps:
1. 🚀 Run full training (250 epochs)
2. 📊 Compare baseline vs adaptive performance
3. 🔬 Analyze learned gate patterns
4. ⚡ Measure compute savings
5. 📝 Train with knowledge distillation
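
For step 4, a rough first-order estimate of compute savings can be derived from the measured 2-bit fraction. The sketch below assumes every channel pays for one 1-bit pass and only the upgraded fraction pays for a second pass, so the cost relative to an always-2-bit model is (1 + f) / 2. This cost model is an assumption that ignores gating overhead and hardware effects, and `estimated_relative_cost` is a hypothetical helper that is not part of the project.

```python
def estimated_relative_cost(two_bit_fraction):
    """Cost relative to running every channel at 2 bits (assumed cost model)."""
    if not 0.0 <= two_bit_fraction <= 1.0:
        raise ValueError("two_bit_fraction must be in [0, 1]")
    # 1 unit for the always-on 1-bit pass, plus 1 unit for the upgraded fraction,
    # divided by the 2 units an always-2-bit model would spend.
    return (1.0 + two_bit_fraction) / 2.0

cost = estimated_relative_cost(0.15)
print(f"Relative cost: {cost:.3f} (estimated saving: {1.0 - cost:.1%})")
```

Measured latency or operation counts on the target hardware should replace this estimate once they are available.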

### Full Training Commands:
```python
# Baseline FracBNN
!python cifar10.py -id 0 -e 250 -b 256 -g 0.0 -s

# Adaptive PG
!python cifar10.py -id 1 -e 250 -b 256 -ts 0.15 -sw 0.01 -s

# Adaptive PG + KD
!python cifar10.py -id 2 -e 250 -b 256 -ts 0.15 -sw 0.01 -temp 4.0 -alpha 0.7 -tp teacher.pth -s
```
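
The `-temp 4.0 -alpha 0.7` flags match `kd_temperature` and `kd_alpha` in the configuration cell. As a hedged illustration of what these two numbers usually control, the sketch below implements the common Hinton-style distillation loss in plain Python for a single sample: a temperature-softened KL term scaled by T² and blended with hard-label cross-entropy. Whether `cifar10.py` weights the distillation term by `alpha` or by `1 - alpha` is not stated in this notebook, so the blend direction is an assumption, and `kd_loss` is a hypothetical name.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, temperature=4.0, alpha=0.7):
    """Blend of soft-target KL and hard-label cross-entropy for one sample."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL(teacher || student) on softened distributions, scaled by T^2
    # so its gradient magnitude stays comparable to the hard-label term.
    kl = sum(t * math.log(t / s) for t, s in zip(p_teacher, p_student) if t > 0)
    soft_term = kl * temperature ** 2
    # Standard cross-entropy against the ground-truth class at T = 1.
    hard_term = -math.log(softmax(student_logits)[label])
    # Assumption: alpha weights the distillation term.
    return alpha * soft_term + (1.0 - alpha) * hard_term

loss = kd_loss([2.0, 0.5, -1.0], [3.0, 0.2, -2.0], label=0)
print(f"KD loss: {loss:.4f}")
```

A higher temperature flattens both distributions, which exposes more of the teacher's relative ranking of wrong classes to the student.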

### Save results to Drive:
```python
# Mount drive (if not already)
from google.colab import drive
drive.mount('/content/drive')

# Copy models
!cp -r save_CIFAR10_model /content/drive/MyDrive/ada_fracbnn_results/
```
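
If a Python-only alternative to `!cp -r` is preferred, for example to avoid overwriting results from earlier runs, the following sketch copies a results directory into a timestamped subfolder. The helper name `backup_results` and the timestamp naming scheme are assumptions for illustration; the paths in the commented usage line mirror the command above.

```python
import shutil
import time
from pathlib import Path

def backup_results(src, dest_root):
    """Copy src into dest_root/<src name>_<timestamp> and return the new path."""
    src = Path(src)
    if not src.is_dir():
        raise FileNotFoundError(f"Results directory not found: {src}")
    stamp = time.strftime("%Y%m%d_%H%M%S")
    dest = Path(dest_root) / f"{src.name}_{stamp}"
    # copytree creates dest and any missing parent directories.
    shutil.copytree(src, dest)
    return dest

# Usage in Colab after mounting Drive:
# backup_results("save_CIFAR10_model", "/content/drive/MyDrive/ada_fracbnn_results")
```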

---
**Ready for production training!** 🎉
