<a href="https://colab.research.google.com/github/raki-rankawat/stm32/blob/main/VWW_KD_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import os
import time
import tarfile
import random
import shutil
from pathlib import Path
from urllib.request import urlretrieve

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [18]:
# Colab-only import: the `google.colab` package is only importable inside a Colab runtime.
from google.colab import drive
# Mount Google Drive at /content/drive; later cells load and save model checkpoints under this mount.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [19]:

# ----------------------------
# Auto Download + Prepare VWW (10k subset)
# ----------------------------
# Source archive for the dataset (a gzip-compressed tarball, per the file name).
vww_url = "https://www.silabs.com/public/files/github/machine_learning/benchmarks/datasets/vw_coco2014_96.tar.gz"

# Working directory layout; everything lives under base_dir:
#   archive_path - the downloaded tarball
#   extract_dir  - where the tarball is unpacked
#   subset_dir   - the balanced train/val subset consumed by the data loaders
base_dir = Path("/content/vww_work")
archive_path = base_dir / "vw_coco2014_96.tar.gz"
extract_dir = base_dir / "extracted"
subset_dir = base_dir / "vww_10k"

# Images sampled per class (2 classes -> 10k images total) and the fraction
# of each class that is held out as the validation split.
n_per_class = 5000
val_ratio = 0.20

# Seed Python's and PyTorch's global RNGs. `random` drives the subset shuffle;
# the torch seed affects weight initialisation and other torch-level randomness.
random.seed(41)
torch.manual_seed(41)

def download_vww():
    """Download the VWW archive to ``archive_path`` unless it is already there.

    The archive is considered present when ``archive_path`` exists and is
    non-empty. The download is written to a sibling ``.part`` file and only
    renamed to ``archive_path`` once ``urlretrieve`` returns successfully, so
    an interrupted or failed download can never be mistaken for a complete
    archive by a later run.

    Raises:
        Whatever ``urlretrieve`` raises on network failure or a truncated
        transfer; in that case no file is left at ``archive_path``.
    """
    base_dir.mkdir(parents=True, exist_ok=True)

    if archive_path.exists() and archive_path.stat().st_size > 0:
        print("‚úÖ VWW archive already downloaded")
        return

    print("‚¨áÔ∏è Downloading VWW archive...")
    # Stage the download next to the final path so the rename stays on one filesystem.
    partial_path = archive_path.with_name(archive_path.name + ".part")
    try:
        urlretrieve(vww_url, partial_path)
        partial_path.replace(archive_path)
    finally:
        # After a successful rename the staging file no longer exists; after a
        # failure this removes the incomplete download.
        if partial_path.exists():
            partial_path.unlink()
    print("‚úÖ Download complete:", archive_path)

def extract_vww():
    """Unpack ``archive_path`` into ``extract_dir`` unless it was already unpacked.

    Extraction is skipped when ``extract_dir`` already contains any entry.
    Because that check cannot tell a finished extraction from an interrupted
    one, a failed extraction removes ``extract_dir`` again so the next run
    starts from scratch instead of silently using a partial dataset.

    The archive comes from the network, so on Python versions that provide
    tarfile extraction filters the ``"data"`` filter is used to refuse members
    that would escape ``extract_dir`` or create special files.

    Raises:
        Any error raised by ``tarfile`` while opening or extracting the archive.
    """
    extract_dir.mkdir(parents=True, exist_ok=True)

    if any(extract_dir.iterdir()):
        print("‚úÖ VWW already extracted")
        return

    print("üì¶ Extracting VWW archive...")
    try:
        with tarfile.open(archive_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(extract_dir, filter="data")
            else:
                # Older Python without extraction filters: keep the legacy call.
                tar.extractall(extract_dir)
    except BaseException:
        # Also covers KeyboardInterrupt (a manually stopped cell). The directory
        # was empty before extraction began, so removing it loses nothing.
        shutil.rmtree(extract_dir, ignore_errors=True)
        raise
    print("‚úÖ Extraction complete:", extract_dir)

def find_vww_root():
    """Locate the directory under ``extract_dir`` that holds both class folders.

    Returns:
        The parent of the first ``person`` directory found (searching
        recursively) that also has a sibling ``non_person`` directory.

    Raises:
        RuntimeError: if no such directory pair exists under ``extract_dir``.
    """
    matching_roots = (
        hit.parent
        for hit in extract_dir.rglob("person")
        if hit.is_dir() and (hit.parent / "non_person").is_dir()
    )
    root = next(matching_roots, None)
    if root is None:
        raise RuntimeError("‚ùå Could not find 'person' and 'non_person' directories under extracted dataset")
    return root

def list_images(folder):
    """Return every image file found recursively under ``folder``, sorted by path.

    A file counts as an image when its suffix, compared case-insensitively,
    is ``.jpg``, ``.jpeg`` or ``.png``.

    The result is sorted because directory traversal order depends on the
    filesystem; callers shuffle this list with a seeded RNG, which is only
    reproducible when the starting order is stable.

    Args:
        folder: a ``pathlib.Path`` pointing at the directory to scan.

    Returns:
        A sorted list of ``Path`` objects (empty if no image is found).
    """
    exts = {".jpg", ".jpeg", ".png"}
    return sorted(
        p for p in folder.rglob("*") if p.is_file() and p.suffix.lower() in exts
    )

def make_vww_subset(src_root):
    """Create the balanced train/val image subset under ``subset_dir``.

    ``n_per_class`` images are sampled from each of ``src_root / "person"``
    and ``src_root / "non_person"``, split into train/val by ``val_ratio``
    and copied into ``subset_dir/<split>/<class>``.

    The subset is treated as already built only when all four class folders
    exist AND hold exactly the expected number of images. Checking for the
    folders alone is not enough: they are created before any file is copied,
    so a run that failed or was interrupted would otherwise be mistaken for
    a finished one. Any incomplete or stale ``subset_dir`` is deleted and
    rebuilt from scratch.

    Args:
        src_root: directory containing the ``person`` and ``non_person`` folders.

    Raises:
        ValueError: if either class has fewer than ``n_per_class`` images.
            Raised before anything is written to ``subset_dir``.
    """
    class_names = ["person", "non_person"]
    n_val_per_class = int(n_per_class * val_ratio)
    expected_counts = {
        "train": n_per_class - n_val_per_class,
        "val": n_val_per_class,
    }

    def subset_is_complete():
        # Complete means every split/class folder exists with the expected image count.
        for split, expected in expected_counts.items():
            for c in class_names:
                class_dir = subset_dir / split / c
                if not class_dir.is_dir() or len(list_images(class_dir)) != expected:
                    return False
        return True

    if subset_is_complete():
        print("‚úÖ VWW 10k subset already exists:", subset_dir)
        return

    person_imgs = list_images(src_root / "person")
    nonperson_imgs = list_images(src_root / "non_person")

    # Validate before touching the disk so a failure leaves no empty folders behind.
    if len(person_imgs) < n_per_class or len(nonperson_imgs) < n_per_class:
        raise ValueError(
            f"‚ùå Not enough images:\n"
            f"person: {len(person_imgs)} (need {n_per_class})\n"
            f"non_person: {len(nonperson_imgs)} (need {n_per_class})"
        )

    # Drop leftovers of an earlier partial or differently-sized build; copying on
    # top of them would mix two different random selections.
    if subset_dir.exists():
        shutil.rmtree(subset_dir)

    for split in expected_counts:
        for c in class_names:
            (subset_dir / split / c).mkdir(parents=True, exist_ok=True)

    random.shuffle(person_imgs)
    random.shuffle(nonperson_imgs)

    person_sel = person_imgs[:n_per_class]
    nonperson_sel = nonperson_imgs[:n_per_class]

    def split_list(lst, ratio):
        # The first `ratio` share of the shuffled list becomes val, the rest train.
        cut = int(len(lst) * ratio)
        return lst[cut:], lst[:cut]  # train, val

    p_train, p_val = split_list(person_sel, val_ratio)
    n_train, n_val = split_list(nonperson_sel, val_ratio)

    def copy_files(files, dst_dir):
        for f in files:
            dst = dst_dir / f.name
            if dst.exists():
                # Same file name coming from a different source sub-folder.
                dst = dst_dir / (f"{f.parent.name}_{f.name}")
            attempt = 1
            while dst.exists():
                # Never overwrite an already-copied image: keep names unique.
                dst = dst_dir / (f"{f.parent.name}_{attempt}_{f.name}")
                attempt += 1
            shutil.copy2(f, dst)

    print("üß© Creating VWW 10k subset...")
    copy_files(p_train, subset_dir / "train" / "person")
    copy_files(p_val,   subset_dir / "val"   / "person")
    copy_files(n_train, subset_dir / "train" / "non_person")
    copy_files(n_val,   subset_dir / "val"   / "non_person")
    print("‚úÖ VWW subset created at:", subset_dir)

# Run the data-preparation steps in dependency order:
# download the archive, unpack it, locate the class folders, build the subset.
download_vww()
extract_vww()
# Directory that directly contains the `person` and `non_person` folders.
vww_root = find_vww_root()
print("‚úÖ Found VWW root:", vww_root)
make_vww_subset(vww_root)

‚úÖ VWW archive already downloaded
‚úÖ VWW already extracted
‚úÖ Found VWW root: /content/vww_work/extracted/vw_coco2014_96
‚úÖ VWW 10k subset already exists: /content/vww_work/vww_10k


In [20]:
# ----------------------------
# Data Loaders (same style)
# ----------------------------
batch_size = 64
img_size = 96

# Training pipeline: random flip, random-resized crop and colour jitter are
# applied to each image before it is converted to a normalised tensor.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=img_size, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# Evaluation pipeline: deterministic resize only, same normalisation as training.
test_transform = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# ImageFolder derives the labels from the class sub-folder names of each split.
train_data = datasets.ImageFolder(
    root=str(subset_dir / "train"),
    transform=train_transform,
)
test_data = datasets.ImageFolder(
    root=str(subset_dir / "val"),
    transform=test_transform,
)

# Only the training loader shuffles; the evaluation order stays fixed.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

print("Class mapping:", train_data.class_to_idx)

Class mapping: {'non_person': 0, 'person': 1}


In [21]:
# ----------------------------
# Teacher Model (same style as your VWWConvNet)
# ----------------------------
class VWWConvNet(nn.Module):
    """Teacher classifier: four conv/BN/ReLU/max-pool stages and a two-layer head.

    ``fc1`` takes 128 * 6 * 6 features, which is what a 96x96 input yields
    after four 2x2 poolings, so the network expects 3-channel 96x96 images.
    It returns two raw logits per image.
    """

    def __init__(self):
        super().__init__()

        # Attribute names and creation order are kept as they are: the names
        # are the state_dict keys of saved checkpoints, and the order fixes
        # how the RNG is consumed during weight initialisation.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=32)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=64)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(num_features=128)

        self.fc1 = nn.Linear(in_features=128 * 6 * 6, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=2)

        self.dropout = nn.Dropout(p=0.30)

    def forward(self, x):
        """Map a batch of images to a batch of two-class logits."""
        stages = (
            (self.conv1, self.bn1),  # 96 -> 48
            (self.conv2, self.bn2),  # 48 -> 24
            (self.conv3, self.bn3),  # 24 -> 12
            (self.conv4, self.bn4),  # 12 -> 6
        )
        for conv, bn in stages:
            x = F.max_pool2d(F.relu(bn(conv(x))), 2)

        flat = x.view(x.size(0), -1)

        hidden = F.relu(self.fc1(flat))
        # Dropout is applied only while the module itself is in training mode.
        if self.training:
            hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [22]:
# Load Teacher weights
# `device` is defined here and reused by every later cell that moves tensors or models.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = VWWConvNet().to(device)

# The checkpoint is deserialised onto the CPU; load_state_dict then copies the
# values into the parameters of `teacher`, which already live on `device`.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust;
# consider weights_only=True if the installed PyTorch version supports it.
teacher.load_state_dict(torch.load(
    "/content/drive/My Drive/Colab Notebooks/stm_vww_best.pth",
    map_location=torch.device("cpu")
))
# Put the teacher in eval mode (affects its dropout and batch-norm layers).
teacher.eval()
print("‚úÖ Loaded Teacher: stm_vww_best.pth")

‚úÖ Loaded Teacher: stm_vww_best.pth


In [23]:
# ----------------------------
# Student Model (Smaller) - keep simple
# ----------------------------
class StudentVWWNet(nn.Module):
    """Compact student classifier: two conv/BN/ReLU/max-pool stages and a two-layer head.

    ``fc1`` takes 32 * 24 * 24 features, which is what a 96x96 input yields
    after two 2x2 poolings, so the network expects 3-channel 96x96 images.
    It returns two raw logits per image and uses no dropout.
    """

    def __init__(self):
        super().__init__()

        # Attribute names are the state_dict keys of saved checkpoints; keep them.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=32)

        # Two poolings: 96 -> 48 -> 24
        self.fc1 = nn.Linear(in_features=32 * 24 * 24, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=2)

    def forward(self, x):
        """Map a batch of images to a batch of two-class logits."""
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.max_pool2d(F.relu(bn(conv(x))), 2)

        flat = x.view(x.size(0), -1)

        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [24]:
# Fresh, randomly initialised student placed on the same device as the teacher.
# Re-running this cell discards any training progress held in `student`.
student = StudentVWWNet().to(device)

In [25]:
# ----------------------------
# Optimizer
# ----------------------------
# Adam over the student's parameters only; the teacher is not given to any optimizer here.
# The optimizer is bound to the current `student` object, so this cell must be
# re-run whenever `student` is re-created.
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

In [26]:
def kd_loss(student_logits, teacher_logits, labels, T=4, alpha=0.3):
    """Knowledge-distillation loss: weighted sum of a hard-label and a soft-label term.

    Args:
        student_logits: raw student outputs; classes along dim 1.
        teacher_logits: raw teacher outputs, same layout as ``student_logits``.
        labels: ground-truth class indices for the cross-entropy term.
        T: temperature dividing both sets of logits before the softmax.
        alpha: weight of the cross-entropy term; ``1 - alpha`` weights the
            distillation term.

    Returns:
        Scalar tensor ``alpha * CE(student, labels) + (1 - alpha) * T^2 * KL``,
        where KL is the batch-mean divergence between the temperature-softened
        teacher distribution and the student's.
    """
    # Softened distributions: log-probabilities for the student, probabilities
    # for the teacher, which is the input/target layout F.kl_div expects.
    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    teacher_probs = F.softmax(teacher_logits / T, dim=1)

    # The T*T factor offsets the 1/T^2 gradient scaling from the softened logits.
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)
    hard_loss = F.cross_entropy(student_logits, labels)

    hard_term = alpha * hard_loss
    soft_term = (1 - alpha) * soft_loss
    return hard_term + soft_term

In [27]:
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    The model is switched to eval mode and is left in that mode on return.
    Batches are moved to the notebook-level ``device`` before the forward pass.

    Args:
        model: module mapping an input batch to per-class scores along dim 1.
        loader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float. An empty loader raises
        ZeroDivisionError.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(dim=1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return 100 * n_correct / n_seen

In [28]:
# Baseline: accuracy of the loaded teacher on the held-out split (test_loader),
# the same split the student is scored on during training.
acc_teacher = evaluate(teacher, test_loader)
print(f"üéì Teacher Model Accuracy: {acc_teacher:.2f}%")

üéì Teacher Model Accuracy: 78.55%


In [31]:
# ----------------------------
# Student Training (KD)
# ----------------------------
# NOTE(review): `student` and `optimizer` are created in earlier cells, so
# re-running only this cell continues training the existing student while
# `best_acc` is reset to 0 - re-run those cells first for a clean experiment.
epochs = 30
start_time = time.time()

# Tracking of the best checkpoint seen so far (by accuracy on test_loader).
best_acc = 0.0
best_epoch = 1
best_path = "/content/drive/My Drive/Colab Notebooks/stm_vww_kd_best.pth"

for epoch in range(1, epochs + 1):
    # evaluate() below leaves the student in eval mode, so training mode has
    # to be restored at the start of every epoch.
    student.train()
    train_losses = 0.0  # running sum of per-batch losses for this epoch

    for X, y in train_loader:
        X, y = X.to(device), y.to(device)

        student_logits = student(X)
        # The teacher only provides targets; no gradients are tracked for it.
        with torch.no_grad():
            teacher_logits = teacher(X)

        loss = kd_loss(student_logits, teacher_logits, y, T=4, alpha=0.3)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        train_losses += loss.item()

    # Score the student after each epoch; the printed loss is the mean over batches.
    acc = evaluate(student, test_loader)
    print(f"Epoch {epoch} / {epochs} | Loss: {train_losses / len(train_loader):.4f} | Student Test Acc: {acc:.2f}%")

    # Keep only the weights of the best epoch so far (strict improvement required).
    # NOTE(review): the checkpoint is selected on the same split that is
    # reported as "Test Acc", so the best figure is an optimistic estimate.
    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch
        torch.save(student.state_dict(), best_path)
        print("‚úÖ Student best saved as stm_vww_kd_best.pth for epoch", best_epoch)

print(f"Training time: {(time.time() - start_time)/60:.2f} minutes")
print(f"Best epoch: {best_epoch} | Best Student Test Acc: {best_acc:.2f}%")

Epoch 1 / 30 | Loss: 0.4397 | Student Test Acc: 68.85%
‚úÖ Student best saved as stm_vww_kd_best.pth for epoch 1
Epoch 2 / 30 | Loss: 0.4204 | Student Test Acc: 69.35%
‚úÖ Student best saved as stm_vww_kd_best.pth for epoch 2
Epoch 3 / 30 | Loss: 0.4136 | Student Test Acc: 70.15%
‚úÖ Student best saved as stm_vww_kd_best.pth for epoch 3
Epoch 4 / 30 | Loss: 0.4013 | Student Test Acc: 69.30%
Epoch 5 / 30 | Loss: 0.3937 | Student Test Acc: 65.70%
Epoch 6 / 30 | Loss: 0.3847 | Student Test Acc: 68.30%
Epoch 7 / 30 | Loss: 0.3734 | Student Test Acc: 69.95%
Epoch 8 / 30 | Loss: 0.3819 | Student Test Acc: 68.55%
Epoch 9 / 30 | Loss: 0.3645 | Student Test Acc: 68.70%
Epoch 10 / 30 | Loss: 0.3554 | Student Test Acc: 69.80%
Epoch 11 / 30 | Loss: 0.3517 | Student Test Acc: 69.70%
Epoch 12 / 30 | Loss: 0.3504 | Student Test Acc: 70.65%
‚úÖ Student best saved as stm_vww_kd_best.pth for epoch 12
Epoch 13 / 30 | Loss: 0.3430 | Student Test Acc: 69.70%
Epoch 14 / 30 | Loss: 0.3472 | Student Test Acc: