# FocusNet: Low-Light Road Hazard Detection

## Step-by-Step Guide to Run FocusNet in Google Colab

### Prerequisites:
1. Google account with access to Google Drive and Google Colab
2. Your dataset in COCO format
3. The 5 essential FocusNet Python files

### Step 1: Prepare Your Files
Before opening Colab, make sure you have:
- **Dataset**: Upload your COCO format dataset to Google Drive
- **Python Files**: Download these 5 files to your computer:
  - `backbone_cbam_mnv3.py`
  - `cbam.py` 
  - `detector.py`
  - `ssd_head.py`
  - `transforms_lowlight.py`

### Step 2: Open Google Colab
1. Go to [colab.research.google.com](https://colab.research.google.com)
2. Upload this notebook file or create a new notebook
3. Go to **Runtime** → **Change runtime type** 
4. Set **Hardware accelerator** to **GPU** (T4 recommended)
5. Click **Save**

### Step 3: Install Dependencies
Run the cell below to install required packages.

In [None]:
# Step 3: Install Dependencies
# NOTE: lines starting with `!` are IPython shell escapes, so this cell only
# runs inside a notebook (Colab/Jupyter), not as a plain Python script.
print("Installing required packages...")

# Install PyTorch with CUDA support
# The index URL selects the CUDA 12.1 (cu121) wheels.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install other required packages
!pip install opencv-python-headless
!pip install pillow
!pip install matplotlib
!pip install scikit-learn

# NOTE(review): this message is printed unconditionally; it does not check
# whether the pip commands above succeeded, so read the pip output as well.
print("All packages installed successfully!")

### Step 4: Load Python Files from Google Drive

Upload your 5 essential FocusNet Python files to Google Drive, then load them from there so they are not lost if the Colab session crashes or is reset.

In [None]:
import sys
import os

# Define Google Drive path for your Python files
GDRIVE_CODE_PATH = "/content/drive/MyDrive/focusnet_code"  # UPDATE THIS PATH!

# Drive has to be mounted BEFORE the existence checks below; otherwise nothing
# under /content/drive exists yet and every file is reported as missing.
# Calling mount again later (e.g. in the next step) is harmless.
from google.colab import drive
drive.mount('/content/drive')

# Add the code directory to Python path (only once, even if the cell is re-run)
if GDRIVE_CODE_PATH not in sys.path:
    sys.path.insert(0, GDRIVE_CODE_PATH)

# Check if files exist in Google Drive
required_files = ['backbone_cbam_mnv3.py', 'cbam.py', 'detector.py', 'ssd_head.py', 'transforms_lowlight.py']

print("Checking for Python files in Google Drive...")
missing_files = []

for filename in required_files:
    filepath = os.path.join(GDRIVE_CODE_PATH, filename)
    if os.path.exists(filepath):
        print(f"✓ Found: {filename}")
    else:
        missing_files.append(filename)
        print(f"✗ Missing: {filename}")

# Files that are still unavailable after the optional upload fallback.
still_missing = []

if missing_files:
    print(f"\n⚠️ Missing files: {missing_files}")
    print(f"Please upload these files to Google Drive at: {GDRIVE_CODE_PATH}")

    # Fallback: Upload files directly to Colab
    print("\n📁 Fallback: Upload files directly to Colab session")
    from google.colab import files
    uploaded = files.upload()

    print(f"📤 Uploaded {len(uploaded)} files to current session:")
    for filename in uploaded:
        print(f"  - {filename}")

    # `uploaded` is keyed by file name; anything not in it was not provided.
    # NOTE(review): uploads presumably land in the current working directory,
    # which must be on sys.path for the imports to work - confirm.
    still_missing = [name for name in missing_files if name not in uploaded]
else:
    print("✅ All Python files found in Google Drive!")
    print("🔒 Files are safe from session crashes and will persist across sessions.")

print(f"\n📂 Code directory: {GDRIVE_CODE_PATH}")
if still_missing:
    # Do not claim readiness when the imports in the next step would fail.
    print(f"⚠️ Still missing after upload: {still_missing}")
else:
    print("🚀 Python files ready to import!")

### Step 5: Initialize FocusNet Model and Load Dataset

Now let's set up the FocusNet model and prepare your dataset for training.

In [None]:
# FocusNet: SSD + MobileNetV3 + CBAM for Low-Light Road Hazard Detection
# Complete training and evaluation pipeline for thesis validation
#
# This cell mounts Google Drive, imports the FocusNet modules, and defines the
# dataset paths, training hyper-parameters and compute device used later on.

import torch
from torch.utils.data import DataLoader
import json
import os
from PIL import Image
import matplotlib.pyplot as plt

# === MOUNT GOOGLE DRIVE FOR DATASET ACCESS ===
from google.colab import drive
drive.mount('/content/drive')
print("📂 Google Drive mounted successfully!")

# === CORE FOCUSNET ARCHITECTURE IMPORTS ===
# Import the 5 essential Python files we uploaded:
try:
    from backbone_cbam_mnv3 import MNV3BackboneWithCBAM
    from cbam import CBAM, ChannelAttention, SpatialAttention
    from detector import SSD_CBAM_MNV3
    from ssd_head import SSDHead
    from transforms_lowlight import FocusNetTransforms
    print("✅ All FocusNet modules imported successfully!")
except ImportError as e:
    print(f"❌ Import Error: {e}")
    print("Please make sure all Python files are uploaded to Google Drive or current session")
    # Stop here: continuing would only postpone the failure to a later cell,
    # where it would surface as a confusing NameError.
    raise

# === DATASET CONFIGURATION ===
# UPDATE THESE PATHS for your dataset location
DATASET_BASE_PATH = "/content/drive/MyDrive/your_dataset"  # UPDATE THIS!
TRAIN_IMG_DIR = f"{DATASET_BASE_PATH}/images/train"
VAL_IMG_DIR = f"{DATASET_BASE_PATH}/images/val"
TRAIN_ANN_FILE = f"{DATASET_BASE_PATH}/annotations/train.json"
VAL_ANN_FILE = f"{DATASET_BASE_PATH}/annotations/val.json"

# === TRAINING CONFIGURATION ===
BATCH_SIZE = 8  # Adjust based on your GPU memory
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
NUM_CLASSES = 2  # Background + your object class (UPDATE if different)

# === DEVICE CONFIGURATION ===
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️ Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; divide by 1e9 to print gigabytes.
    print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ CUDA not available - using CPU (training will be slow)")

print("🎯 FocusNet Configuration:")
print(f"  - Batch Size: {BATCH_SIZE}")
print(f"  - Learning Rate: {LEARNING_RATE}")
print(f"  - Number of Classes: {NUM_CLASSES}")
print(f"  - Training Epochs: {NUM_EPOCHS}")

### Step 6: Training and Evaluation Functions
Now we'll define the training and evaluation functions for FocusNet.

In [None]:
# Step 6: Define Training and Evaluation Functions

def train_one_epoch(model, loss_fn, loader, optimizer, device, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(images)``; must return the triple
            ``(cls_logits, box_deltas, anchors)``.
        loss_fn: Callable invoked as
            ``loss_fn(cls_logits, box_deltas, anchors, batch_targets)`` and
            returning a scalar tensor.
        loader: Sized iterable yielding ``(images, targets)``; each target is a
            mapping with ``'boxes'`` and ``'labels'`` tensors.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the images and targets are moved to.
        epoch: Epoch number, used only in the progress message.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    num_batches = len(loader)
    running_loss = 0.0

    print(f"Training Epoch {epoch}...")
    for step, (images, targets) in enumerate(loader, start=1):
        images = images.to(device)
        # Keep only the two fields the loss consumes, moved to the device.
        batch_targets = [
            {'boxes': target['boxes'].to(device), 'labels': target['labels'].to(device)}
            for target in targets
        ]

        cls_logits, box_deltas, anchors = model(images)
        loss = loss_fn(cls_logits, box_deltas, anchors, batch_targets)

        optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 5.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        running_loss += loss.item()

        # Report the running average every 10 batches.
        if step % 10 == 0:
            print(f"  Batch {step}/{num_batches}, Loss: {running_loss / step:.4f}")

    return running_loss / num_batches

@torch.no_grad()
def evaluate(model, loss_fn, loader, device):
    """Return the mean batch loss of ``model`` over ``loader`` without training.

    The model is switched to eval mode and gradients are disabled for the whole
    call. Arguments have the same meaning as in ``train_one_epoch``; no
    optimizer step is taken.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.eval()
    loss_sum = 0.0

    print("Evaluating...")
    for images, targets in loader:
        images = images.to(device)
        batch_targets = []
        for target in targets:
            batch_targets.append({
                'boxes': target['boxes'].to(device),
                'labels': target['labels'].to(device),
            })

        cls_logits, box_deltas, anchors = model(images)
        batch_loss = loss_fn(cls_logits, box_deltas, anchors, batch_targets)
        loss_sum += batch_loss.item()

    return loss_sum / len(loader)

# Confirmation message so the user can see that this cell ran to completion.
print("Training functions defined successfully!")

### Step 7: Start Training
Now we'll train the FocusNet model. You can adjust the number of epochs based on your needs.

In [None]:
# Step 7: Train FocusNet Model (with Crash Protection)
#
# Trains for `num_epochs`, keeps the checkpoint with the lowest validation
# loss, and mirrors checkpoints and loss plots to Google Drive so progress
# survives a Colab session reset.
# Relies on `model`, `loss_fn`, `optimizer`, `train_loader` and `val_loader`
# having been created in an earlier cell.

import shutil
import time

# Training configuration
# Use the epoch count from the configuration cell so the value printed there
# matches what is actually trained here. Lower NUM_EPOCHS for a quicker run.
num_epochs = NUM_EPOCHS
best_val_loss = float('inf')
train_losses = []
val_losses = []

# Define backup paths
# Backups live next to the dataset on Drive (DATASET_BASE_PATH comes from the
# configuration cell).
GDRIVE_BACKUP_DIR = f"{DATASET_BASE_PATH}/training_backups"
os.makedirs(GDRIVE_BACKUP_DIR, exist_ok=True)

print(f"Starting FocusNet training for {num_epochs} epochs...")
print(f"Device: {device}")
print(f"Batch size: {train_loader.batch_size}")
print(f"Automatic backups will be saved to: {GDRIVE_BACKUP_DIR}")
print("=" * 50)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Training
    train_loss = train_one_epoch(model, loss_fn, train_loader, optimizer, device, epoch+1)
    train_losses.append(train_loss)

    # Validation
    val_loss = evaluate(model, loss_fn, val_loader, device)
    val_losses.append(val_loss)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Save best model (local)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # NOTE: 'epoch' is stored zero-based, while file names use epoch+1.
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_losses': train_losses,
            'val_losses': val_losses
        }

        # Save locally
        torch.save(checkpoint, 'best_focusnet_model.pt')
        print(f"Saved best model (val_loss: {val_loss:.4f})")

        # Automatic backup to Google Drive (crash protection)
        backup_path = f"{GDRIVE_BACKUP_DIR}/best_model_epoch_{epoch+1}.pt"
        shutil.copy('best_focusnet_model.pt', backup_path)
        print(f"Backup saved to Drive: best_model_epoch_{epoch+1}.pt")

    # Save checkpoint every 5 epochs (additional protection)
    if (epoch + 1) % 5 == 0:
        checkpoint_path = f"{GDRIVE_BACKUP_DIR}/checkpoint_epoch_{epoch+1}.pt"
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'train_losses': train_losses,
            'val_losses': val_losses
        }, checkpoint_path)
        print(f"Checkpoint saved: checkpoint_epoch_{epoch+1}.pt")

        # Plot training progress
        fig = plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(train_losses, label='Train Loss', color='blue')
        plt.plot(val_losses, label='Val Loss', color='red')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('FocusNet Training Progress')
        plt.grid(True)

        plt.subplot(1, 2, 2)
        plt.plot(val_losses, label='Validation Loss', color='orange')
        plt.xlabel('Epoch')
        plt.ylabel('Validation Loss')
        plt.title('Validation Loss Trend')
        plt.grid(True)

        plt.tight_layout()
        plt.savefig(f'{GDRIVE_BACKUP_DIR}/training_progress_epoch_{epoch+1}.png')
        plt.show()
        # Release the figure so repeated plots do not accumulate in memory.
        plt.close(fig)

print("\nTraining completed!")
print(f"Best validation loss: {best_val_loss:.4f}")
print("Files saved:")
print("- best_focusnet_model.pt (local session)")
print(f"- Multiple backups in: {GDRIVE_BACKUP_DIR}")
print("Your training is protected against crashes!")

### Step 8: Test the Trained Model
Let's test our trained FocusNet model on some validation images.

In [None]:
# Step 8: Test FocusNet Model

import torch.nn.functional as F

def _nms(boxes, scores, iou_thresh):
    """Greedy non-maximum suppression on ``[N, 4]`` x1y1x2y2 boxes.

    Returns a 1-D tensor of kept indices into ``boxes``, highest score first.
    A box is dropped when its IoU with an already kept box exceeds
    ``iou_thresh``.
    """
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        best = order[0]
        keep.append(best)
        rest = order[1:]
        if rest.numel() == 0:
            break
        # Intersection rectangle between the best box and every remaining box.
        top_left = torch.maximum(boxes[best, :2], boxes[rest, :2])
        bottom_right = torch.minimum(boxes[best, 2:], boxes[rest, 2:])
        inter = (bottom_right - top_left).clamp(min=0).prod(dim=-1)
        area_best = (boxes[best, 2:] - boxes[best, :2]).prod()
        area_rest = (boxes[rest, 2:] - boxes[rest, :2]).prod(dim=-1)
        # Clamp the union to avoid 0/0 on degenerate (zero-area) boxes.
        iou = inter / (area_best + area_rest - inter).clamp(min=1e-12)
        order = rest[iou <= iou_thresh]
    if not keep:
        return torch.empty(0, dtype=torch.long, device=boxes.device)
    return torch.stack(keep)


def decode_predictions(cls_logits, box_deltas, anchors, score_thresh=0.5, nms_thresh=0.45):
    """Decode model predictions into bounding boxes and scores.

    Args:
        cls_logits: Class logits; softmax is taken over the last dimension and
            the first dimension is iterated as the batch ([B, A, C]).
        box_deltas: Box regression outputs, indexed ``[b][anchor_mask]``.
        anchors: Anchor boxes shared by all images, indexed by the same
            per-anchor mask; decoded as (cx, cy, w, h).
        score_thresh: Minimum top-class probability for a detection to be kept.
        nms_thresh: IoU above which a lower-scoring box of the same class is
            suppressed.

    Returns:
        One ``(boxes, labels, scores)`` tuple per image. For images with
        detections these are CPU tensors sorted by descending score; for
        images without any, the tuple is ``([], [], [])``.
    """
    cls_scores = F.softmax(cls_logits, dim=-1)  # [B, A, C]

    # Get max scores and predicted classes
    max_scores, pred_labels = cls_scores.max(dim=-1)  # [B, A]

    # Filter by score threshold and exclude background (class 0)
    valid_mask = (max_scores > score_thresh) & (pred_labels > 0)

    results = []
    for b in range(cls_logits.size(0)):
        valid_b = valid_mask[b]
        if not valid_b.any():
            results.append(([], [], []))
            continue

        scores_b = max_scores[b][valid_b]
        labels_b = pred_labels[b][valid_b]
        pred_boxes = decode_boxes(anchors[valid_b], box_deltas[b][valid_b])

        # Class-wise NMS: boxes only compete with boxes of the same class.
        kept_parts = []
        for cls_id in labels_b.unique():
            cls_idx = torch.nonzero(labels_b == cls_id).squeeze(1)
            kept = _nms(pred_boxes[cls_idx], scores_b[cls_idx], nms_thresh)
            kept_parts.append(cls_idx[kept])
        keep = torch.cat(kept_parts)
        # Present the survivors from most to least confident.
        keep = keep[scores_b[keep].argsort(descending=True)]

        results.append((pred_boxes[keep].cpu(), labels_b[keep].cpu(), scores_b[keep].cpu()))

    return results

def decode_boxes(anchors, deltas, center_variance=0.1, size_variance=0.2):
    """Decode box deltas to corner coordinates (x1, y1, x2, y2).

    ``anchors`` are read as (cx, cy, w, h) and ``deltas`` as (dx, dy, dw, dh).
    The centre offset is scaled by the anchor *size*, following the standard
    SSD parameterisation:

        cx = dx * center_variance * w_a + cx_a
        w  = exp(dw * size_variance) * w_a

    NOTE(review): this assumes the training loss encodes targets with the same
    standard SSD scheme - confirm against the encoder used by the loss.
    """
    cxcy = deltas[..., :2] * center_variance * anchors[..., 2:] + anchors[..., :2]
    wh = torch.exp(deltas[..., 2:] * size_variance) * anchors[..., 2:]

    # Convert to x1y1x2y2
    x1y1 = cxcy - wh / 2
    x2y2 = cxcy + wh / 2

    return torch.cat([x1y1, x2y2], dim=-1)

def visualize_predictions(image_tensor, boxes, labels, scores, class_names=None):
    """Show an image with predicted boxes and class/score labels drawn on it.

    Args:
        image_tensor: Either a ``torch.Tensor`` laid out channel-first (it is
            permuted to channel-last and de-normalised with ImageNet mean/std),
            or any array-like that is displayed as-is.
        boxes: Iterable of (x1, y1, x2, y2) boxes in normalised [0, 1]
            coordinates; they are scaled by the image width and height.
        labels: Iterable of integer class ids, parallel to ``boxes``.
        scores: Iterable of confidence scores, parallel to ``boxes``.
        class_names: Optional sequence mapping class id to a display name.
    """
    # Local import: numpy is not imported at cell level anywhere in the notebook.
    import numpy as np

    if isinstance(image_tensor, torch.Tensor):
        # Convert tensor to numpy (channel-last for matplotlib)
        img = image_tensor.detach().permute(1, 2, 0).cpu().numpy()
        # Denormalize (assuming ImageNet normalization)
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = np.clip(img * std + mean, 0, 1)
    else:
        # NOTE(review): non-tensor input is assumed to be display-ready already.
        img = np.asarray(image_tensor)

    plt.figure(figsize=(12, 8))
    plt.imshow(img)

    # Image size is the same for every box, so read it once.
    h, w = img.shape[:2]

    # Draw bounding boxes
    ax = plt.gca()
    for box, label, score in zip(boxes, labels, scores):
        # Plain Python numbers keep matplotlib independent of tensor types.
        x1, y1, x2, y2 = (float(v) for v in box)
        label = int(label)
        score = float(score)

        # Convert normalized coords to pixel coords
        x1, y1, x2, y2 = x1*w, y1*h, x2*w, y2*h

        # Draw rectangle
        rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
                           fill=False, color='red', linewidth=2)
        ax.add_patch(rect)

        # Add label
        label_text = f"Class {label}: {score:.2f}"
        if class_names and label < len(class_names):
            label_text = f"{class_names[label]}: {score:.2f}"

        plt.text(x1, y1-5, label_text, color='red', fontsize=10,
                bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.7))

    plt.axis('off')
    plt.title('FocusNet Predictions')
    plt.tight_layout()
    plt.show()

# Load best model for testing
# Restores the weights saved by the training step and runs the model on the
# first validation batch, visualising up to three images.
print("Loading best trained model...")
checkpoint = torch.load('best_focusnet_model.pt', map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()
print("Model loaded successfully!")

# Test on a few validation samples
print("Testing FocusNet on validation samples...")

# Get class names from dataset
# Built once: the list is the same for every sample. Index 0 is reserved for
# background because the decoder treats class 0 as background.
class_names = ['background'] + [cat['name'] for cat in val_ds.categories.values()]

with torch.no_grad():
    # Only test first batch; `None` means the validation loader is empty.
    first_batch = next(iter(val_loader), None)

    if first_batch is not None:
        images, targets = first_batch
        images = images.to(device)

        # Make predictions
        cls_logits, box_deltas, anchors = model(images)

        # Decode predictions
        predictions = decode_predictions(cls_logits, box_deltas, anchors,
                                       score_thresh=0.3, nms_thresh=0.45)

        # Visualize first 3 images from the batch
        for i in range(min(3, len(images))):
            boxes, labels, scores = predictions[i]

            print(f"\nSample {i+1}:")
            print(f"Detected {len(boxes)} objects")

            # Visualize
            visualize_predictions(images[i], boxes, labels, scores, class_names)

            # Print detection details
            if len(boxes) > 0:
                for j, (label, score) in enumerate(zip(labels, scores)):
                    # Convert to Python scalars before indexing / formatting.
                    label = int(label)
                    score = float(score)
                    class_name = class_names[label] if label < len(class_names) else f"Class {label}"
                    print(f"  Detection {j+1}: {class_name} (confidence: {score:.3f})")
            else:
                print("  No objects detected")

print("Testing completed!")

### Step 9: Download Your Trained Model
Save your trained FocusNet model to your computer and optionally to Google Drive.

In [None]:
# Step 9: Download and Save Your Trained Model
#
# Sends the best checkpoint to the browser as a download, optionally copies it
# to Google Drive, and prints a summary of the run.

# Download the model file to your computer
from google.colab import files
files.download('best_focusnet_model.pt')
print("Model downloaded to your computer!")

# Optional: Also save to Google Drive for backup
import shutil

save_to_drive = input("Save model to Google Drive as backup? (y/n): ").lower().strip()
if save_to_drive == 'y':
    # DATASET_BASE_PATH is the Drive folder defined in the configuration cell.
    drive_backup_path = f"{DATASET_BASE_PATH}/best_focusnet_model.pt"
    shutil.copy('best_focusnet_model.pt', drive_backup_path)
    print(f"Model also saved to Google Drive: {drive_backup_path}")

# Display final training summary
print("\n" + "="*60)
print("FOCUSNET TRAINING COMPLETED SUCCESSFULLY!")
print("="*60)
print("Architecture: SSD + MobileNetV3 + CBAM")
print(f"Total epochs trained: {num_epochs}")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Dataset: {len(train_ds)} training + {len(val_ds)} validation samples")
print(f"Classes: {len(train_ds.categories)} hazard categories")
print("\nFiles created:")
print("- best_focusnet_model.pt (downloaded to your computer)")
if save_to_drive == 'y':
    print("- best_focusnet_model.pt (backed up to Google Drive)")
print("\nYour FocusNet model is ready for low-light road hazard detection!")
print("="*60)