In [None]:
# Colab: install deps
!pip -q install torch torchvision facenet-pytorch opencv-python tqdm

import copy
import os, re, glob, math, random, shutil
import warnings
from pathlib import Path

import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from facenet_pytorch import MTCNN


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m61.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Switch the working directory to the Google Drive root.
# NOTE(review): assumes Drive is already mounted at /content/drive — no mount call is visible here; confirm.
cd /content/drive/MyDrive

/content/drive/MyDrive


In [None]:
# List the current directory to confirm the FaceQuadrantNet project folder is reachable.
ls

[0m[01;34m'Colab Notebooks'[0m/   [01;34mFaceQuadrantNet[0m/   Test1.ipynb


In [None]:
# ====== USER CONFIG ======
DATA_ROOT = "/content/drive/MyDrive/FaceQuadrantNet/Dataset"  # <-- change if needed
SAVE_DIR = "/content/drive/MyDrive/FaceQuadrantNet"
os.makedirs(SAVE_DIR, exist_ok=True)

# If you don't have class folders, but a single folder with names in filenames (e.g., alexandra_01.jpg),
# set this True and point DATA_ROOT to that folder.
INFER_FROM_FILENAME = False


def normalize_class_names(names):
    """Return lower-cased class names with duplicates removed, keeping first-seen order.

    Lower-casing happens BEFORE de-duplication so that entries differing only in
    case (e.g. "Zac" and "zac") collapse into a single class.

    Args:
        names: iterable of class-name strings.

    Returns:
        list[str]: unique, lower-case names in their original order.
    """
    return list(dict.fromkeys(name.lower() for name in names))


# Target classes (unique, case-insensitive). The position in this list is the class index.
TARGET_CLASSES = normalize_class_names(["alexandra", "courtney", "elizabeth", "henry", "zac"])

IMG_SIZE = 224             # model input (square)
BATCH_SIZE = 16
EPOCHS = 15
LR = 1e-4
VAL_SPLIT = 0.15
USE_ALIGNMENT = True       # MTCNN face crop+align
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' keeps the returned Generator repr out of the cell output


<torch._C.Generator at 0x7eb3c4bfdab0>

In [None]:
# Face detector used for crop+alignment. It is moved to the GPU only when DEVICE is "cuda";
# otherwise no device is passed (None).
_mtcnn_device = DEVICE if DEVICE == "cuda" else None
mtcnn = MTCNN(image_size=IMG_SIZE, margin=20, post_process=True, device=_mtcnn_device)

# ImageNet channel statistics (the classifier uses a ResNet backbone).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_tfms(augment):
    """Compose the preprocessing pipeline; `augment=True` adds the random flip and colour jitter."""
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
    ]
    if augment:
        steps.append(transforms.RandomHorizontalFlip(p=0.5))
        steps.append(transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1, hue=0.02))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD))
    return transforms.Compose(steps)


# Training pipeline: basic augmentation + normalization.
train_tfms = _build_tfms(augment=True)

# Validation / inference pipeline: deterministic resize + normalization only.
val_tfms = _build_tfms(augment=False)


In [None]:
class FacesQuadrantDataset(Dataset):
    """Face-identity dataset that yields ``(image_tensor, label, path)`` triples.

    Images are discovered either from per-class sub-folders (``root/<class>/**``)
    or, when ``infer_from_filename`` is True, from the leading alphabetic part of
    each file name under ``root``. Each image is read with OpenCV, converted to
    RGB, optionally cropped/aligned with the module-level ``mtcnn`` detector, and
    then passed through ``transform``.

    Args:
        root: dataset directory.
        classes: class names; they are lower-cased and their position is the label index.
        transform: callable applied to the HxWx3 uint8 RGB array of each image.
        use_alignment: if True, try MTCNN crop+align before ``transform``.
        infer_from_filename: if True, derive labels from file-name prefixes
            instead of folder names.

    Raises:
        RuntimeError: if no image matching the configured classes is found.
    """

    # File suffixes (compared lower-case) that are treated as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

    # Set once the first alignment exception has been reported, to avoid flooding the log.
    _alignment_warned = False

    def __init__(self, root, classes, transform, use_alignment=True, infer_from_filename=False):
        self.root = Path(root)
        self.classes = [c.lower() for c in classes]
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        self.transform = transform
        self.use_alignment = use_alignment
        self.infer_from_filename = infer_from_filename

        # List of (path_str, label_idx). Paths are sorted so that the sample order,
        # and therefore any seeded split, does not depend on filesystem listing order.
        self.samples = []
        if not infer_from_filename:
            # Expect subfolders /class/*.jpg
            for cls in self.classes:
                folder = self.root / cls
                if not folder.exists():
                    print(f"[WARN] Missing folder for class: {cls} -> {folder}")
                    continue
                for p in sorted(folder.rglob("*")):
                    if self._is_image(p):
                        self.samples.append((str(p), self.class_to_idx[cls]))
        else:
            # Single folder; infer class from filename prefix
            for p in sorted(self.root.rglob("*")):
                if not self._is_image(p):
                    continue
                # Leading alphabetic chunk of the stem is the candidate label;
                # files with an unknown (or no) prefix are skipped.
                m = re.match(r"([a-z]+)", p.stem.lower())
                if m and m.group(1) in self.class_to_idx:
                    self.samples.append((str(p), self.class_to_idx[m.group(1)]))

        if len(self.samples) == 0:
            raise RuntimeError("No images found. Check DATA_ROOT and folder/filename setup.")

    def __len__(self):
        return len(self.samples)

    @classmethod
    def _is_image(cls, path):
        """True if ``path`` is a regular file with a recognised image suffix."""
        return path.suffix.lower() in cls.IMAGE_EXTENSIONS and path.is_file()

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV and return it as an RGB array.

        Raises:
            OSError: if OpenCV cannot read/decode the file (``cv2.imread`` returned None).
        """
        img = cv2.imread(path)
        if img is None:
            # An explicit exception (not ``assert``) so the check survives ``python -O``.
            raise OSError(f"Failed to read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _face_tensor_to_uint8(face, standardized=True):
        """Convert an MTCNN face tensor (C,H,W) into an HxWxC uint8 image.

        With ``post_process=True`` MTCNN returns ``(pixel - 127.5) / 128`` (roughly
        [-1, 1]), so the standardization is inverted here. With ``standardized=False``
        the tensor is taken to already hold 0..255 pixel values.
        """
        arr = face.detach().permute(1, 2, 0).cpu().numpy().astype(np.float32)
        if standardized:
            arr = arr * 128.0 + 127.5
        # Round and clip before the cast so out-of-range values cannot wrap around in uint8.
        return np.clip(np.rint(arr), 0, 255).astype(np.uint8)

    def _align_face(self, img):
        """Return an MTCNN-aligned face crop, or a plain resize of ``img`` if that fails.

        Uses the module-level ``mtcnn`` and ``IMG_SIZE``. A missing detection falls
        back silently; an exception falls back too, but is reported once via a warning.
        """
        pil_img = Image.fromarray(img)
        try:
            aligned = mtcnn(pil_img)
            if aligned is not None:
                return self._face_tensor_to_uint8(
                    aligned, standardized=getattr(mtcnn, "post_process", True)
                )
        except Exception as exc:
            # Best-effort: keep going with the unaligned image, but surface the first failure.
            if not FacesQuadrantDataset._alignment_warned:
                FacesQuadrantDataset._alignment_warned = True
                warnings.warn(
                    f"MTCNN alignment raised {exc!r}; falling back to a plain resize "
                    "(further alignment errors will not be reported).",
                    RuntimeWarning,
                )
        # fallback: resize the whole image without alignment
        return cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed_image, label_idx, path)``."""
        path, label = self.samples[idx]
        img = self._read_image(path)
        if self.use_alignment:
            img = self._align_face(img)
        else:
            img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

        img_t = self.transform(img)   # Tensor (3, H, W) for the transforms used in this notebook
        return img_t, label, path


In [None]:
class QuadrantFusionNet(nn.Module):
    """Classifier that fuses a whole-image embedding with attention-pooled quadrant embeddings.

    The same backbone embeds the full image and each of its four quadrants
    (top-left, top-right, bottom-left, bottom-right). A small MLP turns the
    concatenated quadrant embeddings into softmax weights over the four quadrants;
    the weighted sum is concatenated with the global embedding, projected to
    ``emb_dim`` and fed to a linear classifier.

    Args:
        num_classes: number of output classes.
        backbone_name: only ``"resnet18"`` is implemented.
        emb_dim: size of the fused embedding.
        pretrained: if True (default) start from ImageNet weights; pass False to skip
            the download, e.g. when a checkpoint is loaded right afterwards.

    Raises:
        NotImplementedError: for any ``backbone_name`` other than ``"resnet18"``.
    """

    def __init__(self, num_classes=5, backbone_name="resnet18", emb_dim=512, pretrained=True):
        super().__init__()
        # Backbone (its classification layer is replaced by Identity so it outputs features)
        if backbone_name == "resnet18":
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            net = models.resnet18(weights=weights)
            feat_dim = net.fc.in_features
            net.fc = nn.Identity()
        else:
            raise NotImplementedError("Only resnet18 implemented here.")
        self.backbone = net
        self.feat_dim = feat_dim

        # Projects concat[global, pooled-local] to the embedding
        self.proj = nn.Sequential(
            nn.Linear(feat_dim*2, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(inplace=True),
        )

        # Attention over 4 locals: take concat of locals -> weights over 4
        self.local_att = nn.Sequential(
            nn.Linear(feat_dim*4, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 4)
        )

        # Classifier head over the fused embedding
        self.classifier = nn.Linear(emb_dim, num_classes)

    @staticmethod
    def _split_quadrants(x):
        """Split ``x`` (B,C,H,W) into TL, TR, BL, BR views.

        The split point is ``H//2`` / ``W//2``; for odd sizes the bottom/right
        quadrants are one pixel larger.
        """
        _, _, H, W = x.shape
        h2, w2 = H//2, W//2
        TL = x[:, :, 0:h2,   0:w2]
        TR = x[:, :, 0:h2,   w2:W]
        BL = x[:, :, h2:H,   0:w2]
        BR = x[:, :, h2:H,   w2:W]
        return TL, TR, BL, BR

    def _embed_single(self, x):
        """Run the backbone on ``x`` (B,3,H,W) and return its features (B,feat_dim)."""
        return self.backbone(x)

    def _fuse_global_local(self, Eg, locals_cat, locals_list):
        """Attention-pool the quadrant embeddings and fuse them with the global one.

        Args:
            Eg: global embedding, (B, feat_dim).
            locals_cat: concat[E_tl, E_tr, E_bl, E_br], (B, 4*feat_dim).
            locals_list: the same four embeddings as a list of (B, feat_dim) tensors.

        Returns:
            (fused, weights): fused embedding (B, emb_dim) and the softmax
            attention weights (B, 4) in the order of ``locals_list``.
        """
        weights = self.local_att(locals_cat)                    # (B,4)
        weights = F.softmax(weights, dim=1).unsqueeze(-1)       # (B,4,1)
        locals_stack = torch.stack(locals_list, dim=1)          # (B,4,feat_dim)
        El = torch.sum(weights * locals_stack, dim=1)           # (B,feat_dim)

        fused = torch.cat([Eg, El], dim=1)                      # (B, 2*feat_dim)
        fused = self.proj(fused)                                # (B, emb_dim)
        return fused, weights.squeeze(-1)                       # return weights for logging

    def forward(self, x):
        """Return ``(logits, fused_embedding, attention_weights)`` for a batch ``x`` (B,3,H,W).

        ``attention_weights`` is (B,4) in the order [TL, TR, BL, BR].
        """
        # Full image
        Eg = self._embed_single(x)                              # (B,feat_dim)

        # Quadrants
        TL, TR, BL, BR = self._split_quadrants(x)
        Etl = self._embed_single(TL)
        Etr = self._embed_single(TR)
        Ebl = self._embed_single(BL)
        Ebr = self._embed_single(BR)

        locals_list = [Etl, Etr, Ebl, Ebr]
        locals_cat  = torch.cat(locals_list, dim=1)             # (B,4*feat_dim)

        fused, att_w = self._fuse_global_local(Eg, locals_cat, locals_list)
        logits = self.classifier(fused)
        return logits, fused, att_w   # att_w in order [TL,TR,BL,BR]

    @torch.no_grad()
    def forward_embeddings(self, x):
        """Return ``(embedding, attention_weights)`` computed in eval mode without gradients.

        The train/eval state of every sub-module is restored afterwards, so calling
        this in the middle of training does not leave the model stuck in eval mode.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            _, emb, att_w = self.forward(x)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
        return emb, att_w


In [None]:
# Scan the dataset once; this instance carries the TRAIN transforms.
full_ds = FacesQuadrantDataset(
    root=DATA_ROOT,
    classes=TARGET_CLASSES,
    transform=train_tfms,
    use_alignment=USE_ALIGNMENT,
    infer_from_filename=INFER_FROM_FILENAME
)

# Both subsets returned by random_split point at the SAME dataset object, so assigning
# `val_ds.dataset.transform` would also replace the training transforms.
# A shallow copy shares the scanned sample list but gets its own `transform`.
val_view_ds = copy.copy(full_ds)
val_view_ds.transform = val_tfms

# Train/Val split (seeded so the partition is repeatable)
val_size = max(1, int(len(full_ds)*VAL_SPLIT))
train_size = len(full_ds) - val_size
if train_size < 1:
    raise ValueError(f"Dataset too small to split: {len(full_ds)} image(s) with VAL_SPLIT={VAL_SPLIT}.")
train_ds, _val_split = random_split(full_ds, [train_size, val_size],
                                    generator=torch.Generator().manual_seed(SEED))
# Same validation indices, but read through the copy that uses val_tfms (no augmentation).
val_ds = torch.utils.data.Subset(val_view_ds, _val_split.indices)

# When MTCNN lives on the GPU, alignment must run in the main process: CUDA cannot be
# used from forked DataLoader workers, and the resulting error would make every image
# fall back to the unaligned resize.
NUM_WORKERS = 0 if (USE_ALIGNMENT and DEVICE == "cuda") else 2
PIN_MEMORY = (DEVICE == "cuda")   # pinning only helps host->GPU copies

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

len(full_ds), len(train_ds), len(val_ds)


[WARN] Missing folder for class: zac -> /content/drive/MyDrive/FaceQuadrantNet/Dataset/zac


(349, 297, 52)

In [None]:
# Model, optimizer and loss. num_classes follows TARGET_CLASSES, so the checkpoint's
# class list and the classifier width always agree.
model = QuadrantFusionNet(num_classes=len(TARGET_CLASSES)).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Mixed precision is only enabled on CUDA. The torch.amp entry points replace the
# deprecated torch.cuda.amp.GradScaler / torch.cuda.amp.autocast.
use_amp = (DEVICE == "cuda")
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

# Start below any reachable accuracy so the first epoch always writes a checkpoint
# (a later cell loads it unconditionally).
best_val_acc = -1.0
ckpt_path = os.path.join(SAVE_DIR, "quadrant_fusion_faces.pth")

for epoch in range(1, EPOCHS+1):
    # ---- Train ----
    model.train()
    train_loss, correct, total = 0.0, 0, 0
    for imgs, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} [train]"):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            logits, _, _ = model(imgs)
            loss = criterion(logits, labels)
        # GradScaler scales the loss for backward and unscales before the optimizer step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Accumulate sample-weighted loss so the epoch average is exact for a short last batch.
        train_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    train_acc = correct / total
    train_loss /= total

    # ---- Validate ----
    model.eval()
    val_loss, v_correct, v_total = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels, _ in tqdm(val_loader, desc=f"Epoch {epoch}/{EPOCHS} [val]"):
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            logits, _, _ = model(imgs)
            loss = criterion(logits, labels)
            val_loss += loss.item() * imgs.size(0)
            v_correct += (logits.argmax(1) == labels).sum().item()
            v_total += imgs.size(0)
    val_acc = v_correct / v_total
    val_loss /= v_total

    print(f"Epoch {epoch:02d}: train_loss={train_loss:.4f} acc={train_acc:.3f} | val_loss={val_loss:.4f} acc={val_acc:.3f}")

    # Save best (strict improvement in validation accuracy)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            "model_state": model.state_dict(),
            "classes": TARGET_CLASSES,
            "img_size": IMG_SIZE
        }, ckpt_path)
        print(f"✅ Saved best checkpoint @ {ckpt_path} (val_acc={val_acc:.3f})")




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 223MB/s]
  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))
  with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
Epoch 1/15 [train]: 100%|██████████| 19/19 [01:06<00:00,  3.51s/it]
Epoch 1/15 [val]: 100%|██████████| 4/4 [00:11<00:00,  2.88s/it]


Epoch 01: train_loss=0.9102 acc=0.741 | val_loss=0.8156 acc=0.692
✅ Saved best checkpoint @ /content/drive/MyDrive/FaceQuadrantNet/quadrant_fusion_faces.pth (val_acc=0.692)


Epoch 2/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 15.05it/s]
Epoch 2/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 12.45it/s]


Epoch 02: train_loss=0.1582 acc=0.990 | val_loss=0.3556 acc=0.904
✅ Saved best checkpoint @ /content/drive/MyDrive/FaceQuadrantNet/quadrant_fusion_faces.pth (val_acc=0.904)


Epoch 3/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 13.28it/s]
Epoch 3/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 12.61it/s]


Epoch 03: train_loss=0.0549 acc=0.997 | val_loss=0.5231 acc=0.865


Epoch 4/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 14.71it/s]
Epoch 4/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 13.13it/s]


Epoch 04: train_loss=0.0459 acc=1.000 | val_loss=0.3153 acc=0.885


Epoch 5/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 14.72it/s]
Epoch 5/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 11.47it/s]


Epoch 05: train_loss=0.0375 acc=0.997 | val_loss=0.2102 acc=0.942
✅ Saved best checkpoint @ /content/drive/MyDrive/FaceQuadrantNet/quadrant_fusion_faces.pth (val_acc=0.942)


Epoch 6/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 12.92it/s]
Epoch 6/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 12.58it/s]


Epoch 06: train_loss=0.0261 acc=1.000 | val_loss=0.1706 acc=0.962
✅ Saved best checkpoint @ /content/drive/MyDrive/FaceQuadrantNet/quadrant_fusion_faces.pth (val_acc=0.962)


Epoch 7/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 12.98it/s]
Epoch 7/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 12.35it/s]


Epoch 07: train_loss=0.0134 acc=1.000 | val_loss=0.1523 acc=0.981
✅ Saved best checkpoint @ /content/drive/MyDrive/FaceQuadrantNet/quadrant_fusion_faces.pth (val_acc=0.981)


Epoch 8/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 13.33it/s]
Epoch 8/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 13.31it/s]


Epoch 08: train_loss=0.0213 acc=1.000 | val_loss=0.2354 acc=0.904


Epoch 9/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 14.65it/s]
Epoch 9/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 13.59it/s]


Epoch 09: train_loss=0.0136 acc=1.000 | val_loss=0.2129 acc=0.942


Epoch 10/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 13.60it/s]
Epoch 10/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 13.55it/s]


Epoch 10: train_loss=0.0160 acc=1.000 | val_loss=0.1173 acc=0.981


Epoch 11/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 14.55it/s]
Epoch 11/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 13.34it/s]


Epoch 11: train_loss=0.0117 acc=1.000 | val_loss=0.1189 acc=1.000
✅ Saved best checkpoint @ /content/drive/MyDrive/FaceQuadrantNet/quadrant_fusion_faces.pth (val_acc=1.000)


Epoch 12/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 13.37it/s]
Epoch 12/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 12.70it/s]


Epoch 12: train_loss=0.0143 acc=1.000 | val_loss=0.1874 acc=0.962


Epoch 13/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 14.26it/s]
Epoch 13/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 11.96it/s]


Epoch 13: train_loss=0.0133 acc=1.000 | val_loss=0.1426 acc=0.962


Epoch 14/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 13.71it/s]
Epoch 14/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 12.40it/s]


Epoch 14: train_loss=0.0098 acc=1.000 | val_loss=0.1051 acc=1.000


Epoch 15/15 [train]: 100%|██████████| 19/19 [00:01<00:00, 14.81it/s]
Epoch 15/15 [val]: 100%|██████████| 4/4 [00:00<00:00, 13.66it/s]

Epoch 15: train_loss=0.0071 acc=1.000 | val_loss=0.1177 acc=1.000





In [None]:
# Re-derive the checkpoint location from SAVE_DIR so the cells below can use it.
ckpt_filename = "quadrant_fusion_faces.pth"
ckpt_path = os.path.join(SAVE_DIR, ckpt_filename)

In [None]:
# Load best ckpt (if needed later).
# weights_only=True restricts unpickling to tensors and plain containers, which is all
# this checkpoint holds (state dict, class-name list, image size).
ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)

# Allow running this cell on a fresh kernel without re-running training:
# build the network if no `model` exists yet, sized from the checkpoint's class list.
if "model" not in globals():
    model = QuadrantFusionNet(num_classes=len(ckpt["classes"])).to(DEVICE)
model.load_state_dict(ckpt["model_state"])
model.eval()

@torch.no_grad()
def preprocess_image(img_path, use_alignment=USE_ALIGNMENT):
    """Load one image and turn it into a model-ready batch of size 1.

    The image is read with OpenCV, converted to RGB, optionally cropped/aligned with
    the module-level ``mtcnn`` (falling back to a plain resize when no face is
    returned), passed through ``val_tfms`` and moved to ``DEVICE``.

    Args:
        img_path: path of the image file.
        use_alignment: if True, try MTCNN crop+align first.

    Returns:
        Tensor of shape (1, 3, H, W) on ``DEVICE``.

    Raises:
        OSError: if OpenCV cannot read/decode ``img_path``.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"Failed to read image: {img_path}")
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    if use_alignment:
        aligned = mtcnn(Image.fromarray(img))
        if aligned is not None:
            arr = aligned.permute(1, 2, 0).cpu().numpy().astype(np.float32)
            # With post_process=True MTCNN returns (pixel - 127.5) / 128, i.e. roughly
            # [-1, 1]; undo that instead of multiplying by 255, which would wrap in uint8.
            if getattr(mtcnn, "post_process", True):
                arr = arr * 128.0 + 127.5
            img = np.clip(np.rint(arr), 0, 255).astype(np.uint8)
        else:
            # Same fallback interpolation as the training dataset's alignment fallback.
            img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
    else:
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    return val_tfms(img).unsqueeze(0).to(DEVICE)

@torch.no_grad()
def predict_class(img_path):
    """Classify the image at ``img_path`` with the module-level ``model``.

    Returns:
        (class_name, probabilities, quadrant_attention) where ``probabilities`` is the
        softmax over classes and ``quadrant_attention`` is ordered [TL, TR, BL, BR];
        both are NumPy arrays for the single input image.
    """
    batch = preprocess_image(img_path)
    outputs = model(batch)
    logits, quadrant_att = outputs[0], outputs[2]

    probabilities = F.softmax(logits, dim=1)[0].cpu().numpy()
    best_index = int(np.argmax(probabilities))
    return TARGET_CLASSES[best_index], probabilities, quadrant_att[0].cpu().numpy()

@torch.no_grad()
def face_embedding(img_path):
    """Return the L2-normalized fused embedding and quadrant attention for one image.

    The embedding is unit-length so that dot products between embeddings are cosine
    similarities. Both return values are NumPy arrays for the single input image.
    """
    batch = preprocess_image(img_path)
    raw_emb, quadrant_att = model.forward_embeddings(batch)
    unit_emb = F.normalize(raw_emb, dim=1)  # L2-normalize for cosine
    embedding_np = unit_emb[0].cpu().numpy()
    attention_np = quadrant_att[0].cpu().numpy()
    return embedding_np, attention_np
