In [2]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
from tqdm import tqdm

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.heads = nn.Identity()

In [3]:
embedding_dim = 768
num_classes = 250  # the number of identity in the fine-tuning dataset

class FaceIDModel(nn.Module):
    def __init__(self, base, num_classes):
        super().__init__()
        self.base = base
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        features = self.base(x)
        return self.classifier(features)


In [4]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_ds = ImageFolder("data/vggface2_subset/train", transform=transform)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)

model = FaceIDModel(model, num_classes=len(train_ds.classes)).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

print("start training")
for epoch in tqdm(range(10), desc="fine-tuning:"):
    model.train()
    total, correct = 0, 0
    for x, y in train_dl:
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(1)
        correct += (preds == y).sum().item()
        total += y.size(0)
    
    acc = 100 * correct / total
    print(f"Epoch {epoch+1} | Train Acc: {acc:.2f}% | Loss: {loss.item():.4f}")

start training


fine-tuning::  10%|█         | 1/10 [02:06<18:56, 126.33s/it]

Epoch 1 | Train Acc: 0.64% | Loss: 5.4768


fine-tuning::  20%|██        | 2/10 [04:57<20:23, 152.95s/it]

Epoch 2 | Train Acc: 2.93% | Loss: 5.3049


fine-tuning::  30%|███       | 3/10 [08:07<19:47, 169.63s/it]

Epoch 3 | Train Acc: 5.42% | Loss: 4.9678


fine-tuning::  40%|████      | 4/10 [11:34<18:25, 184.26s/it]

Epoch 4 | Train Acc: 8.27% | Loss: 4.2354


fine-tuning::  50%|█████     | 5/10 [15:13<16:23, 196.79s/it]

Epoch 5 | Train Acc: 12.96% | Loss: 4.0116


fine-tuning::  60%|██████    | 6/10 [18:49<13:34, 203.63s/it]

Epoch 6 | Train Acc: 17.87% | Loss: 3.6797


fine-tuning::  70%|███████   | 7/10 [22:32<10:29, 209.88s/it]

Epoch 7 | Train Acc: 23.07% | Loss: 3.7828


fine-tuning::  80%|████████  | 8/10 [26:16<07:08, 214.18s/it]

Epoch 8 | Train Acc: 29.65% | Loss: 3.1489


fine-tuning::  90%|█████████ | 9/10 [29:52<03:34, 214.86s/it]

Epoch 9 | Train Acc: 38.50% | Loss: 3.1636


fine-tuning:: 100%|██████████| 10/10 [33:31<00:00, 201.12s/it]

Epoch 10 | Train Acc: 47.02% | Loss: 1.9352





In [5]:
torch.save(model.base.state_dict(), "vit_b_face_finetuned.pth")

In [6]:
for epoch in tqdm(range(10), desc="fine-tuning:"):
    model.train()
    total, correct = 0, 0
    for x, y in train_dl:
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(1)
        correct += (preds == y).sum().item()
        total += y.size(0)
    
    acc = 100 * correct / total
    print(f"Epoch {epoch+1} | Train Acc: {acc:.2f}% | Loss: {loss.item():.4f}")

fine-tuning::  10%|█         | 1/10 [02:23<21:35, 144.00s/it]

Epoch 1 | Train Acc: 55.57% | Loss: 3.1372


fine-tuning::  20%|██        | 2/10 [05:46<23:47, 178.39s/it]

Epoch 2 | Train Acc: 66.27% | Loss: 1.8770


fine-tuning::  30%|███       | 3/10 [09:26<23:01, 197.38s/it]

Epoch 3 | Train Acc: 75.92% | Loss: 1.4154


fine-tuning::  40%|████      | 4/10 [13:14<20:56, 209.35s/it]

Epoch 4 | Train Acc: 86.86% | Loss: 0.8995


fine-tuning::  50%|█████     | 5/10 [17:01<17:59, 215.95s/it]

Epoch 5 | Train Acc: 93.49% | Loss: 0.3825


fine-tuning::  60%|██████    | 6/10 [20:49<14:40, 220.07s/it]

Epoch 6 | Train Acc: 98.22% | Loss: 0.7334


fine-tuning::  70%|███████   | 7/10 [24:58<11:27, 229.27s/it]

Epoch 7 | Train Acc: 99.73% | Loss: 0.3040


fine-tuning::  80%|████████  | 8/10 [28:48<07:39, 229.70s/it]

Epoch 8 | Train Acc: 99.98% | Loss: 0.2130


fine-tuning::  90%|█████████ | 9/10 [32:46<03:52, 232.09s/it]

Epoch 9 | Train Acc: 99.98% | Loss: 0.0570


fine-tuning:: 100%|██████████| 10/10 [36:33<00:00, 219.38s/it]

Epoch 10 | Train Acc: 100.00% | Loss: 0.0521





In [9]:
torch.save(model.base.state_dict(), "vit_b_face_finetuned20.pth")

In [18]:
val_ds = ImageFolder("data/vggface2/val", transform=transform)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4)

In [None]:
# 10 epoch fine-tune
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.heads = nn.Identity()
model.load_state_dict(torch.load("vit_b_face_finetuned.pth"))
model = FaceIDModel(model, num_classes=len(train_ds.classes))
model = model.cuda()
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in tqdm(val_dl, desc="Evaluating"):
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

acc = correct / total
print(f"\nTop-1 Accuracy on VGGFace2 Val: {acc * 100:.2f}%")

Evaluating: 100%|██████████| 666/666 [03:06<00:00,  3.56it/s]


Top-1 Accuracy on VGGFace2 Val: 0.54%





In [19]:
# 20 epoch fine-tune
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.heads = nn.Identity()
model.load_state_dict(torch.load("vit_b_face_finetuned20.pth"))
model = FaceIDModel(model, num_classes=len(train_ds.classes))
model = model.cuda()
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in tqdm(val_dl, desc="Evaluating"):
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

acc = correct / total
print(f"\nTop-1 Accuracy on VGGFace2 Val: {acc * 100:.2f}%")

Evaluating: 100%|██████████| 666/666 [03:37<00:00,  3.06it/s]


Top-1 Accuracy on VGGFace2 Val: 0.79%



