In [1]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
from tqdm import tqdm

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.heads = nn.Identity()

In [2]:
embedding_dim = 768
num_classes = 250  # the number of identity in the fine-tuning dataset

class FaceIDModel(nn.Module):
    def __init__(self, base, num_classes):
        super().__init__()
        self.base = base
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        features = self.base(x)
        return self.classifier(features)


In [3]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_ds = ImageFolder("data/vggface2_subset/train", transform=transform)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)

model = FaceIDModel(model, num_classes=len(train_ds.classes)).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

print("start training")
for epoch in tqdm(range(10), desc="fine-tuning:"):
    model.train()
    total, correct = 0, 0
    for x, y in train_dl:
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(1)
        correct += (preds == y).sum().item()
        total += y.size(0)
    
    acc = 100 * correct / total
    print(f"Epoch {epoch+1} | Train Acc: {acc:.2f}% | Loss: {loss.item():.4f}")

start training


fine-tuning::  10%|█         | 1/10 [02:01<18:12, 121.43s/it]

Epoch 1 | Train Acc: 0.78% | Loss: 5.4293


fine-tuning::  20%|██        | 2/10 [04:15<17:12, 129.11s/it]

Epoch 2 | Train Acc: 1.66% | Loss: 4.9742


fine-tuning::  30%|███       | 3/10 [07:03<17:06, 146.60s/it]

Epoch 3 | Train Acc: 3.94% | Loss: 4.8587


fine-tuning::  40%|████      | 4/10 [10:27<16:55, 169.25s/it]

Epoch 4 | Train Acc: 6.64% | Loss: 4.4627


fine-tuning::  50%|█████     | 5/10 [14:18<15:57, 191.54s/it]

Epoch 5 | Train Acc: 10.38% | Loss: 4.6953


fine-tuning::  60%|██████    | 6/10 [18:10<13:41, 205.26s/it]

Epoch 6 | Train Acc: 14.77% | Loss: 4.1910


fine-tuning::  70%|███████   | 7/10 [21:57<10:37, 212.51s/it]

Epoch 7 | Train Acc: 20.56% | Loss: 3.4828


fine-tuning::  80%|████████  | 8/10 [25:44<07:13, 216.95s/it]

Epoch 8 | Train Acc: 27.68% | Loss: 3.6351


fine-tuning::  90%|█████████ | 9/10 [29:31<03:40, 220.08s/it]

Epoch 9 | Train Acc: 34.11% | Loss: 2.8922


fine-tuning:: 100%|██████████| 10/10 [33:17<00:00, 199.72s/it]

Epoch 10 | Train Acc: 42.08% | Loss: 3.3783





In [4]:
torch.save(model.base.state_dict(), "vit_b_face_finetuned.pth")
torch.save(model.state_dict(), "vit_b_face_finetuned_whole.pth")

In [5]:
for epoch in tqdm(range(10), desc="fine-tuning:"):
    model.train()
    total, correct = 0, 0
    for x, y in train_dl:
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(1)
        correct += (preds == y).sum().item()
        total += y.size(0)
    
    acc = 100 * correct / total
    print(f"Epoch {epoch+1} | Train Acc: {acc:.2f}% | Loss: {loss.item():.4f}")

fine-tuning::  10%|█         | 1/10 [03:43<33:30, 223.34s/it]

Epoch 1 | Train Acc: 51.46% | Loss: 2.4806


fine-tuning::  20%|██        | 2/10 [07:27<29:49, 223.71s/it]

Epoch 2 | Train Acc: 62.94% | Loss: 1.6940


fine-tuning::  30%|███       | 3/10 [10:36<24:14, 207.83s/it]

Epoch 3 | Train Acc: 72.69% | Loss: 2.2148


fine-tuning::  40%|████      | 4/10 [13:32<19:31, 195.25s/it]

Epoch 4 | Train Acc: 82.32% | Loss: 1.5436


fine-tuning::  50%|█████     | 5/10 [16:32<15:49, 189.94s/it]

Epoch 5 | Train Acc: 91.12% | Loss: 0.8930


fine-tuning::  60%|██████    | 6/10 [19:33<12:27, 186.85s/it]

Epoch 6 | Train Acc: 97.06% | Loss: 0.6144


fine-tuning::  70%|███████   | 7/10 [22:30<09:11, 183.72s/it]

Epoch 7 | Train Acc: 99.52% | Loss: 0.1627


fine-tuning::  80%|████████  | 8/10 [25:26<06:02, 181.11s/it]

Epoch 8 | Train Acc: 99.95% | Loss: 0.1931


fine-tuning::  90%|█████████ | 9/10 [28:10<02:55, 175.81s/it]

Epoch 9 | Train Acc: 99.97% | Loss: 0.0921


fine-tuning:: 100%|██████████| 10/10 [30:53<00:00, 185.30s/it]

Epoch 10 | Train Acc: 100.00% | Loss: 0.0664





In [6]:
torch.save(model.base.state_dict(), "vit_b_face_finetuned20.pth")
torch.save(model.state_dict(), "vit_b_face_finetuned20_whole.pth")

In [7]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_ds = ImageFolder("data/vggface2_subset/train_val", transform=transform)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4)

In [8]:
# 10 epoch fine-tune
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.heads = nn.Identity()
model.load_state_dict(torch.load("vit_b_face_finetuned.pth"))
model = FaceIDModel(model, num_classes=len(val_ds.classes))
model = model.cuda()
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in tqdm(val_dl, desc="Evaluating"):
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

acc = correct / total
print(f"\nTop-1 Accuracy on VGGFace2 Val: {acc * 100:.2f}%")

Evaluating: 100%|██████████| 79/79 [00:30<00:00,  2.62it/s]


Top-1 Accuracy on VGGFace2 Val: 0.20%





In [9]:
# 20 epoch fine-tune
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.heads = nn.Identity()
model.load_state_dict(torch.load("vit_b_face_finetuned20.pth"))
model = FaceIDModel(model, num_classes=len(val_ds.classes))
model = model.cuda()
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in tqdm(val_dl, desc="Evaluating"):
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

acc = correct / total
print(f"\nTop-1 Accuracy on VGGFace2 Val: {acc * 100:.2f}%")

Evaluating: 100%|██████████| 79/79 [00:30<00:00,  2.57it/s]


Top-1 Accuracy on VGGFace2 Val: 0.40%





In [10]:
# 10 epoch fine-tune
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.heads = nn.Identity()
model = FaceIDModel(model, num_classes=len(val_ds.classes))
model.load_state_dict(torch.load("vit_b_face_finetuned_whole.pth"))
model = model.cuda()
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in tqdm(val_dl, desc="Evaluating"):
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

acc = correct / total
print(f"\nTop-1 Accuracy on VGGFace2 Val: {acc * 100:.2f}%")

Evaluating: 100%|██████████| 79/79 [00:26<00:00,  2.99it/s]


Top-1 Accuracy on VGGFace2 Val: 20.92%





In [11]:
# 10 epoch fine-tune
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.heads = nn.Identity()
model = FaceIDModel(model, num_classes=len(val_ds.classes))
model.load_state_dict(torch.load("vit_b_face_finetuned20_whole.pth"))
model = model.cuda()
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in tqdm(val_dl, desc="Evaluating"):
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

acc = correct / total
print(f"\nTop-1 Accuracy on VGGFace2 Val: {acc * 100:.2f}%")

Evaluating: 100%|██████████| 79/79 [00:25<00:00,  3.09it/s]


Top-1 Accuracy on VGGFace2 Val: 33.16%



