<a href="https://colab.research.google.com/github/raki-rankawat/stm32/blob/main/VWW_Pruned_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [48]:
import os
import time
import tarfile
import random
import shutil
from pathlib import Path
from urllib.request import urlretrieve

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [49]:
# Mount Google Drive at /content/drive; the checkpoint paths used later in this
# notebook live under that mount point.
# NOTE(review): google.colab is presumably only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [50]:
# ----------------------------
# VWW dataset configuration (archive location + balanced 10k subset)
# ----------------------------
vww_url = "https://www.silabs.com/public/files/github/machine_learning/benchmarks/datasets/vw_coco2014_96.tar.gz"

# Everything the pipeline writes lives below base_dir.
base_dir = Path("/content/vww_work")
archive_path = base_dir.joinpath("vw_coco2014_96.tar.gz")   # downloaded tarball
extract_dir = base_dir.joinpath("extracted")                # unpacked archive
subset_dir = base_dir.joinpath("vww_10k")                   # train/val subset

n_per_class = 5_000   # images drawn per class (2 classes -> 10k images)
val_ratio = 0.2       # share of each class held out for validation

# Seed both RNGs used below: `random` drives the subset shuffle, torch drives
# the DataLoader shuffling.
torch.manual_seed(41)
random.seed(41)

def download_vww():
    """Download the VWW archive to ``archive_path`` unless it is already there.

    The transfer is written to a sibling ``.part`` file and renamed onto
    ``archive_path`` only after it has finished. The "already downloaded" check
    accepts any non-empty file, so writing straight to ``archive_path`` would
    let an interrupted download pass as a complete archive on the next run.

    Raises:
        Whatever ``urlretrieve`` raises on network or I/O failure; the partial
        file is removed before the exception propagates.
    """
    base_dir.mkdir(parents=True, exist_ok=True)

    if archive_path.exists() and archive_path.stat().st_size > 0:
        print("‚úÖ VWW archive already downloaded")
        return

    partial_path = archive_path.with_name(archive_path.name + ".part")

    print("‚¨áÔ∏è Downloading VWW archive...")
    try:
        urlretrieve(vww_url, partial_path)
        # Same directory, so the rename cannot leave a half-written archive.
        partial_path.replace(archive_path)
    finally:
        # No-op after a successful rename; cleans up after a failed transfer.
        partial_path.unlink(missing_ok=True)
    print("‚úÖ Download complete:", archive_path)

def extract_vww():
    """Unpack ``archive_path`` into ``extract_dir`` unless it already has content.

    The archive comes from the network, so members are extracted with the
    ``"data"`` filter where the running Python provides it. That filter rejects
    absolute paths, ``..`` traversal, links pointing outside the target and
    special files.

    NOTE: "already extracted" only means ``extract_dir`` is non-empty; an
    extraction that was interrupted half-way is not detected here.
    """
    extract_dir.mkdir(parents=True, exist_ok=True)

    if any(extract_dir.iterdir()):
        print("‚úÖ VWW already extracted")
        return

    print("üì¶ Extracting VWW archive...")
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(extract_dir, filter="data")
        else:
            # Older interpreters have no extraction filters; fall back to the
            # unfiltered behaviour rather than failing.
            tar.extractall(extract_dir)
    print("‚úÖ Extraction complete:", extract_dir)

def find_vww_root():
    """Locate the dataset root inside ``extract_dir``.

    The root is the parent of the first ``person`` directory found (recursive
    search) that also has a ``non_person`` directory as a sibling.

    Returns:
        Path of the directory containing both class folders.

    Raises:
        RuntimeError: If no such pair of directories exists.
    """
    matching_parents = (
        hit.parent
        for hit in extract_dir.rglob("person")
        if hit.is_dir() and (hit.parent / "non_person").is_dir()
    )
    root = next(matching_parents, None)
    if root is None:
        raise RuntimeError("‚ùå Could not find 'person' and 'non_person' directories under extracted dataset")
    return root

def list_images(folder):
    """Return all image files below ``folder`` (recursive), sorted by path.

    A file counts as an image when its suffix is ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive).

    The result is sorted because directory traversal order depends on the
    filesystem. Callers shuffle this list with a seeded RNG, and a fixed seed
    only yields the same selection when the input order is fixed too.

    Args:
        folder: ``pathlib.Path`` of the directory to scan.

    Returns:
        Sorted list of ``pathlib.Path`` objects.
    """
    exts = {".jpg", ".jpeg", ".png"}
    return sorted(
        p for p in folder.rglob("*")
        if p.is_file() and p.suffix.lower() in exts
    )

def make_vww_subset(src_root):
    """Build a balanced train/val subset of VWW under ``subset_dir``.

    ``n_per_class`` images are drawn at random from ``src_root/person`` and
    from ``src_root/non_person``. For each class the first ``val_ratio`` share
    of the draw is copied to ``val/<class>`` and the remainder to
    ``train/<class>``.

    An existing subset is reused only when all four class folders hold the
    expected number of images. Directory existence alone is not enough: the
    folders exist as soon as a build starts, so a failed or interrupted build
    (or one made with a different ``n_per_class``) would otherwise be accepted
    as complete. Anything that does not match is deleted and rebuilt.

    Args:
        src_root: Directory containing the ``person`` and ``non_person`` folders.

    Raises:
        ValueError: If either class has fewer than ``n_per_class`` images.
            Raised before anything under ``subset_dir`` is touched.
    """
    classes = ["person", "non_person"]
    n_val_expected = int(n_per_class * val_ratio)
    expected = {"train": n_per_class - n_val_expected, "val": n_val_expected}

    def subset_is_complete():
        # Every split/class folder must exist and hold exactly the expected count.
        return all(
            (subset_dir / split / c).is_dir()
            and len(list_images(subset_dir / split / c)) == expected[split]
            for split in ["train", "val"]
            for c in classes
        )

    if subset_is_complete():
        print("‚úÖ VWW 10k subset already exists:", subset_dir)
        return

    person_imgs = list_images(src_root / "person")
    nonperson_imgs = list_images(src_root / "non_person")

    if len(person_imgs) < n_per_class or len(nonperson_imgs) < n_per_class:
        raise ValueError(
            f"‚ùå Not enough images:\n"
            f"person: {len(person_imgs)} (need {n_per_class})\n"
            f"non_person: {len(nonperson_imgs)} (need {n_per_class})"
        )

    # Drop any partial or stale subset so leftovers cannot mix with the new one.
    if subset_dir.exists():
        shutil.rmtree(subset_dir)

    for split in ["train", "val"]:
        for c in classes:
            (subset_dir / split / c).mkdir(parents=True, exist_ok=True)

    random.shuffle(person_imgs)
    random.shuffle(nonperson_imgs)

    person_sel = person_imgs[:n_per_class]
    nonperson_sel = nonperson_imgs[:n_per_class]

    def split_list(lst, val_ratio):
        n_val = int(len(lst) * val_ratio)
        return lst[n_val:], lst[:n_val]  # train, val

    p_train, p_val = split_list(person_sel, val_ratio)
    n_train, n_val = split_list(nonperson_sel, val_ratio)

    def copy_files(files, dst_dir):
        for f in files:
            dst = dst_dir / f.name
            if dst.exists():
                # Same file name from a different source folder: prefix with
                # the parent folder name to keep both.
                dst = dst_dir / (f"{f.parent.name}_{f.name}")
            counter = 1
            while dst.exists():
                # Still taken: add a counter so no copy ever overwrites another.
                dst = dst_dir / (f"{f.parent.name}_{counter}_{f.name}")
                counter += 1
            shutil.copy2(f, dst)

    print("üß© Creating VWW 10k subset...")
    copy_files(p_train, subset_dir / "train" / "person")
    copy_files(p_val,   subset_dir / "val"   / "person")
    copy_files(n_train, subset_dir / "train" / "non_person")
    copy_files(n_val,   subset_dir / "val"   / "non_person")
    print("‚úÖ VWW subset created at:", subset_dir)

# Dataset pipeline: fetch the archive, unpack it, locate the class folders,
# then build the train/val subset from them.
download_vww()
extract_vww()

vww_root = find_vww_root()
print("‚úÖ Found VWW root:", vww_root)

make_vww_subset(src_root=vww_root)

‚úÖ VWW archive already downloaded
‚úÖ VWW already extracted
‚úÖ Found VWW root: /content/vww_work/extracted/vw_coco2014_96
‚úÖ VWW 10k subset already exists: /content/vww_work/vww_10k


In [51]:

# ----------------------------
# Data loaders: augmented train split, deterministic val split
# ----------------------------
batch_size = 64
img_size = 96

# Training pipeline: random flip / crop / colour jitter, then tensor + normalize.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=img_size, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

# Evaluation pipeline: fixed resize only, same normalization as training.
test_transform = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

# ImageFolder derives the labels from the class sub-folder names.
train_data = datasets.ImageFolder(
    root=str(subset_dir / "train"),
    transform=train_transform,
)
test_data = datasets.ImageFolder(
    root=str(subset_dir / "val"),
    transform=test_transform,
)

train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

print("Class mapping:", train_data.class_to_idx)

Class mapping: {'non_person': 0, 'person': 1}


In [52]:
# ----------------------------
# CNN model - layer names and shapes must match the trained VWW checkpoint that is loaded below
# ----------------------------
class VWWConvNet(nn.Module):
    """Four conv/BN/ReLU/max-pool stages followed by a two-layer classifier.

    Each stage halves the spatial size, so a 96x96 input reaches the classifier
    as a 128 x 6 x 6 feature map; ``fc1`` is sized for exactly that. The
    network returns raw logits for 2 classes.

    Attribute names are part of the checkpoint format (state_dict keys) and are
    referenced by name elsewhere in the notebook, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: channels 3 -> 16 -> 32 -> 64 -> 128.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # Classifier head.
        self.fc1 = nn.Linear(128 * 6 * 6, 256)
        self.fc2 = nn.Linear(256, 2)

        self.dropout = nn.Dropout(p=0.30)

    def forward(self, x):
        """Map a batch of images ``(N, 3, 96, 96)`` to logits ``(N, 2)``."""
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), kernel_size=2)  # 96 -> 48
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), kernel_size=2)  # 48 -> 24
        x = F.max_pool2d(F.relu(self.bn3(self.conv3(x))), kernel_size=2)  # 24 -> 12
        x = F.max_pool2d(F.relu(self.bn4(self.conv4(x))), kernel_size=2)  # 12 -> 6

        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)

        x = F.relu(self.fc1(x))
        # Dropout is applied in training mode only.
        x = self.dropout(x) if self.training else x
        return self.fc2(x)

In [53]:
# ----------------------------
# Load weights (BEST VWW checkpoint)
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VWWConvNet().to(device)

# The checkpoint is a plain state_dict, so load it with weights_only=True:
# torch.load is pickle-based, and this restricts unpickling to tensors and
# basic containers instead of arbitrary objects.
# Tensors are read onto the CPU first; load_state_dict copies them into the
# model's parameters on `device`.
model.load_state_dict(
    torch.load(
        "/content/drive/My Drive/Colab Notebooks/stm_vww_best.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)

<All keys matched successfully>

In [54]:
# ----------------------------
# Accuracy Before Pruning
# ----------------------------
def test_accuracy_full(model, loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            pred = out.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return 100.0 * correct / total


# Baseline: accuracy of the loaded checkpoint on the validation split.
acc_base = test_accuracy_full(model=model, loader=test_loader, device=device)
print(f"‚úÖ PyTorch FULL test accuracy (BASE): {acc_base:.2f}%")

‚úÖ PyTorch FULL test accuracy (BASE): 76.70%


In [55]:
# ----------------------------
# Pruning
# ----------------------------
PRUNE_AMOUNT = 0.20          # fraction pruned per layer, e.g. 0.10 | 0.20 | 0.30
PRUNE_TYPE = "structured"    # "structured" | "unstructured"

# Fail fast on a typo instead of silently falling through to the
# unstructured branch.
if PRUNE_TYPE not in ("structured", "unstructured"):
    raise ValueError(
        f"Unknown PRUNE_TYPE {PRUNE_TYPE!r}; expected 'structured' or 'unstructured'"
    )

layers_to_prune = [
    (model.conv2, "weight"),
    (model.conv3, "weight"),
    (model.conv4, "weight"),
]

# Pruning a layer that already carries a mask stacks a second mask on top and
# compounds the pruned fraction, so re-running this cell would over-prune.
# Only layers without a mask are processed.
pending_layers = [
    (layer, param) for layer, param in layers_to_prune
    if not prune.is_pruned(layer)
]
if len(pending_layers) != len(layers_to_prune):
    print("Some layers already carry a pruning mask and were skipped; "
          "reload the checkpoint to prune from scratch.")

if PRUNE_TYPE == "structured":
    for layer, param in pending_layers:
        # n=2, dim=0: rank whole output filters by L2 norm and mask the weakest.
        prune.ln_structured(layer, name=param, amount=PRUNE_AMOUNT, n=2, dim=0)
    print(f"‚úÖ Structured pruning: {PRUNE_AMOUNT*100:.0f}% filters on conv2/conv3/conv4")
else:
    for layer, param in pending_layers:
        # Mask the individual weights with the smallest absolute value.
        prune.l1_unstructured(layer, name=param, amount=PRUNE_AMOUNT)
    print(f"‚úÖ Unstructured pruning: {PRUNE_AMOUNT*100:.0f}% weights on conv2/conv3/conv4")

‚úÖ Structured pruning: 20% filters on conv2/conv3/conv4


In [56]:
# Accuracy right after pruning, before any fine-tuning (FT)
acc_after_prune = test_accuracy_full(model=model, loader=test_loader, device=device)
print(f"‚úÖ PyTorch FULL test accuracy (AFTER PRUNE, before FT): {acc_after_prune:.2f}%")

‚úÖ PyTorch FULL test accuracy (AFTER PRUNE, before FT): 55.55%


In [57]:
# ----------------------------
# Fine-tune the pruned network
# ----------------------------
FT_EPOCHS = 3
FT_LR = 1e-4

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=FT_LR)

start_time = time.time()

for epoch in range(1, FT_EPOCHS + 1):
    model.train()
    # Per-epoch accumulators: summed loss, correct predictions, samples seen.
    running_loss, correct, total = 0.0, 0, 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad(set_to_none=True)
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so the
        # epoch figure is a true per-sample average.
        n_batch = y.size(0)
        running_loss += loss.item() * n_batch
        pred = out.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += n_batch

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    print(f"FT Epoch {epoch}/{FT_EPOCHS} | Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.2f}%")

elapsed_min = (time.time() - start_time) / 60
print(f"Fine-tune time: {elapsed_min:.2f} minutes")

FT Epoch 1/3 | Train Loss: 0.4374 | Train Acc: 79.56%
FT Epoch 2/3 | Train Loss: 0.4230 | Train Acc: 80.90%
FT Epoch 3/3 | Train Loss: 0.4176 | Train Acc: 80.59%
Fine-tune time: 0.80 minutes


In [58]:
# Accuracy of the pruned model once fine-tuning (FT) has finished
acc_after_ft = test_accuracy_full(model=model, loader=test_loader, device=device)
print(f"‚úÖ PyTorch FULL test accuracy (AFTER FT): {acc_after_ft:.2f}%")

‚úÖ PyTorch FULL test accuracy (AFTER FT): 76.15%


In [59]:
# Make pruning permanent before saving/exporting.
# prune.remove raises ValueError on a layer that has no pruning mask, which is
# the case when this cell is run a second time; the is_pruned guard keeps the
# cell re-runnable.
for layer, param in layers_to_prune:
    if prune.is_pruned(layer):
        prune.remove(layer, param)

PRUNED_FT_PTH = "/content/drive/My Drive/Colab Notebooks/stm_vww_pruned_model.pth"
torch.save(model.state_dict(), PRUNED_FT_PTH)
print("‚úÖ Saved pruned+finetuned weights:", PRUNED_FT_PTH)

‚úÖ Saved pruned+finetuned weights: /content/drive/My Drive/Colab Notebooks/stm_vww_pruned_model.pth
