<a href="https://colab.research.google.com/github/taweener11/darkSideUnmasked/blob/main/toy_training_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive so that
# files kept on Drive can be read and written from this notebook.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Root directory for the LFW dataset files, located on the mounted Drive.
data_dir = '/content/drive/MyDrive/LFW'  # or just '/content/data' for temporary Colab storage


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcMarginProduct(nn.Module):
    """ArcFace-style additive angular margin classification head.

    The head holds one weight vector per class. Inputs and weights are both
    L2-normalised, so their product is the cosine of the angle between an
    embedding and each class centre. During training the angle to the true
    class is enlarged by ``m`` radians before everything is scaled by ``s``.

    Args:
        in_features: Size of each input embedding.
        out_features: Number of classes (rows of the weight matrix).
        s: Scale applied to the returned logits.
        m: Additive angular margin, in radians.
        easy_margin: If True, the margin is only applied where the cosine is
            positive; otherwise the ``th`` / ``mm`` fallback is used for
            angles where ``theta + m`` would exceed pi.
    """

    # Lower bound on (1 - cos^2) before the square root. sqrt has an infinite
    # derivative at 0, which turns into NaN gradients when a cosine is +/-1.
    _SINE_EPS = 1e-12

    def __init__(self, in_features, out_features, s=30.0, m=0.5, easy_margin=False):
        import math  # local import: the block does not rely on a file-level `import math`

        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        # Plain Python floats: they broadcast against tensors on any device,
        # so nothing here has to follow the module through `.to(device)`.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cos(theta + m) is only monotonic while theta + m <= pi. `th` is the
        # cosine at that limit and `mm` the linear penalty used beyond it.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label=None):
        """Return scaled logits of shape ``(batch, out_features)``.

        Args:
            input: Embeddings of shape ``(batch, in_features)``.
            label: Integer class indices of shape ``(batch,)``. When omitted,
                no margin is applied and the plain scaled cosine similarities
                are returned, which is what inference and accuracy
                measurement need.

        Returns:
            ``s * cosine`` for every class, with the true-class entry replaced
            by its margin-penalised value when ``label`` is given.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        if label is None:
            return cosine * self.s

        # sin(theta) from cos(theta); the clamp keeps the radicand in
        # [eps, 1] so the backward pass of sqrt stays finite.
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=self._SINE_EPS, max=1.0))
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Only the true-class column receives the margin-adjusted value.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.reshape(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s

In [None]:
!apt-get install iputils-ping

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following NEW packages will be installed:
  iputils-ping
0 upgraded, 1 newly installed, 0 to remove and 34 not upgraded.
Need to get 42.9 kB of archives.
After this operation, 116 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 iputils-ping amd64 3:20211215-1 [42.9 kB]
Fetched 42.9 kB in 0s (220 kB/s)
Selecting previously unselected package iputils-ping.
(Reading database ... 126102 files and directories currently installed.)
Preparing to unpack .../iputils-ping_3%3a20211215-1_amd64.deb ...
Unpacking iputils-ping (3:20211215-1) ...
Setting up iputils-ping (3:20211215-1) ...
Processing triggers for man-db (2.10.2-1) ...


In [None]:
!ping www.google.com # Check network connectivity. Should see replies if connected


PING www.google.com (173.194.212.147) 56(84) bytes of data.
64 bytes from vq-in-f147.1e100.net (173.194.212.147): icmp_seq=1 ttl=114 time=3.66 ms
64 bytes from vq-in-f147.1e100.net (173.194.212.147): icmp_seq=2 ttl=114 time=0.548 ms
64 bytes from vq-in-f147.1e100.net (173.194.212.147): icmp_seq=3 ttl=114 time=0.677 ms
64 bytes from vq-in-f147.1e100.net (173.194.212.147): icmp_seq=4 ttl=114 time=0.469 ms
64 bytes from vq-in-f147.1e100.net (173.194.212.147): icmp_seq=5 ttl=114 time=0.457 ms

--- www.google.com ping statistics ---
5 packets transmitted, 5 received, 0% packet loss, time 4005ms
rtt min/avg/max/mdev = 0.457/1.161/3.657/1.250 ms


In [None]:
from torchvision.datasets import LFWPeople, Flowers102
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset
import numpy as np
from collections import Counter

# Preprocessing applied to every image: resize to 112x112, then convert to a tensor.
transform = T.Compose([T.Resize((112, 112)), T.ToTensor()])

# Download LFW (currently disabled; the Flowers102 cell below supplies the training data instead)
# lfw_all = LFWPeople(root=data_dir, split='train', download=True, transform=transform)



# # Get class distribution and select classes with at least 10 images
# counts = Counter(lfw_all.targets)
# min_images_per_class = 10
# top_classes = [cls for cls, cnt in counts.items() if cnt >= min_images_per_class]

# # Filter dataset to only those classes
# idxs = [i for i, t in enumerate(lfw_all.targets) if t in top_classes]
# filtered_targets = [lfw_all.targets[i] for i in idxs]
# # Relabel for contiguous integer labels
# class_map = {old: new for new, old in enumerate(sorted(set(filtered_targets)))}
# filtered_targets = [class_map[t] for t in filtered_targets]

# # Build subset
# lfw_subset = Subset(lfw_all, idxs)
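# lfw_subset.targets = filtered_targets # NOTE: Subset.__getitem__ still returns the original (un-remapped) labels from lfw_all, so setting this attribute alone does not relabel DataLoader batches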

# n_classes = len(set(filtered_targets))
# print(f"Using {n_classes} classes, {len(filtered_targets)} images.")

In [None]:
from torchvision.datasets import Flowers102

# Training split of Flowers102, downloaded into ./data on first use and
# preprocessed with the shared `transform`.
flowers = Flowers102(
    root='./data',
    split='train',
    transform=transform,
    download=True,
)
n_classes = 102  # number of output classes handed to the ArcFace head
# Reshuffled every epoch, served in mini-batches of 64.
train_loader = DataLoader(dataset=flowers, batch_size=64, shuffle=True)

100%|██████████| 345M/345M [00:10<00:00, 32.9MB/s]
100%|██████████| 502/502 [00:00<00:00, 548kB/s]
100%|██████████| 15.0k/15.0k [00:00<00:00, 13.9MB/s]


In [None]:

import torchvision.models as models
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ResNet-18 with randomly initialised weights serves as the feature extractor.
backbone = models.resnet18(weights=None)
# Read the embedding width before the classifier is swapped out below.
feature_dim = backbone.fc.in_features
# Drop the stock classifier so the backbone outputs raw embeddings.
backbone.fc = nn.Identity()

# Margin-based classification head that consumes those embeddings.
arc_head = ArcMarginProduct(in_features=feature_dim, out_features=n_classes).to(device)


In [None]:
# Embedding network: the backbone followed by a flatten so the head always
# receives (batch, feature_dim) tensors.
model = nn.Sequential(backbone, nn.Flatten(start_dim=1)).to(device)
criterion = nn.CrossEntropyLoss()
# One optimizer updates both the backbone and the margin head.
optimizer = torch.optim.Adam(list(model.parameters()) + list(arc_head.parameters()), lr=1e-3)

for epoch in range(2):
    model.train()
    arc_head.train()
    total, correct, running_loss = 0, 0, 0.0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        features = model(images)
        logits = arc_head(features, labels)
        loss = criterion(logits, labels)

        # Predict from the plain cosine similarities, not from `logits`:
        # the head lowers the true-class logit by the angular margin, so an
        # argmax over `logits` is biased against the correct class and
        # under-reports accuracy. Done before the optimizer step so the
        # prediction uses the same weights as the forward pass above.
        with torch.no_grad():
            cosine = F.linear(F.normalize(features.detach()), F.normalize(arc_head.weight))
            preds = cosine.argmax(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch figure is a
        # per-sample mean even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        correct += (preds == labels).sum().item()
        total += images.size(0)
        if (i+1) % 50 == 0: print(f"Batch {i+1}/{len(train_loader)} - Loss {loss.item():.4f}")
    print(f"Epoch {epoch+1}: Loss={running_loss/total:.4f}  Accuracy={correct/total*100:.2f}%")

Epoch 1: Loss=18.9964  Accuracy=0.00%
Epoch 2: Loss=17.8274  Accuracy=0.00%
