In [22]:
import torch
from dataset import CelebMBTIDataset

In [None]:
from utils import assign_splits, get_label_mapping

assign_splits(
    file_path="faces_yolo_metadata.csv",
    id_column="id",
    label_column="mbti",
)    

split
train    5609
test      715
val       694
Name: count, dtype: int64


In [None]:
mbti_to_idx = get_label_mapping(
    file_path="faces_yolo_metadata.csv",
    label_column="mbti",
)

MBTI to index mapping: {'ENFJ': 0, 'ENFP': 1, 'ENTJ': 2, 'ENTP': 3, 'ESFJ': 4, 'ESFP': 5, 'ESTJ': 6, 'ESTP': 7, 'INFJ': 8, 'INFP': 9, 'INTJ': 10, 'INTP': 11, 'ISFJ': 12, 'ISFP': 13, 'ISTJ': 14, 'ISTP': 15}


In [25]:
from torchvision import transforms

train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
train_ds = CelebMBTIDataset(
    root_dir="faces_yolo",
    metadata_csv="faces_yolo_metadata.csv",
    split="train",
    transform=train_transform,
    mbti_to_idx=mbti_to_idx,
)

val_ds = CelebMBTIDataset(
    root_dir="faces_yolo",
    metadata_csv="faces_yolo_metadata.csv",
    split="val",
    transform=val_transform,
    mbti_to_idx=mbti_to_idx,
)

test_ds = CelebMBTIDataset(
    root_dir="faces_yolo",
    metadata_csv="faces_yolo_metadata.csv",
    split="test",
    transform=val_transform,
    mbti_to_idx=mbti_to_idx,
)

In [27]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4)

device = "cuda" if torch.cuda.is_available() else "cpu"
num_classes = len(mbti_to_idx)

## Baseline: majority guessing

In [28]:
import collections

labels = [lbl for _, lbl in train_ds]
majority_label = collections.Counter(labels).most_common(1)[0][0]


def eval_majority_baseline():
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in val_loader:
            bs = y.size(0)
            pred = torch.full_like(y, fill_value=majority_label)
            correct += (pred == y).sum().item()
            total += bs
    print(f"Majority baseline acc: {correct/total:.3f}")


eval_majority_baseline()

Majority baseline acc: 0.076


## Resnet 50, pre-trained

In [29]:
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\leanh/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth


100.0%


In [30]:
def run_epoch(loader, train=True):
    if train:
        model.train()
    else:
        model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        if train:
            optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        if train:
            loss.backward()
            optimizer.step()
        with torch.no_grad():
            preds = logits.argmax(1)
            total_correct += (preds == y).sum().item()
            total_loss += loss.item() * y.size(0)
            total += y.size(0)
    return total_loss / total, total_correct / total

In [32]:
from tqdm import tqdm

for epoch in range(50):
    # training phase with tqdm
    pbar = tqdm(train_loader, desc=f"Epoch {epoch:02d} [train]", leave=False)
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        preds = logits.argmax(1)
        total_correct += (preds == y).sum().item()
        total_loss += loss.item() * y.size(0)
        total += y.size(0)

        pbar.set_postfix(loss=loss.item())

    tr_loss = total_loss / total
    tr_acc = total_correct / total

    # validation phase with tqdm
    pbar = tqdm(val_loader, desc=f"Epoch {epoch:02d} [val]", leave=False)
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            preds = logits.argmax(1)
            total_correct += (preds == y).sum().item()
            total_loss += loss.item() * y.size(0)
            total += y.size(0)

            pbar.set_postfix(loss=loss.item())

    va_loss = total_loss / total
    va_acc = total_correct / total

    print(
        f"Epoch {epoch:02d} | "
        f"train loss {tr_loss:.4f}/ train accuracy {tr_acc:.3f} | "
        f"val loss {va_loss:.4f}/ val accuracy {va_acc:.3f}"
    )

                                                                              

Epoch 00 | train 2.7329/0.099 | val 2.7114/0.108


                                                                              

Epoch 01 | train 2.5730/0.204 | val 2.6550/0.150


                                                                              

Epoch 02 | train 2.3101/0.327 | val 2.6765/0.161


                                                                              

Epoch 03 | train 1.9112/0.508 | val 2.6941/0.163


                                                                              

Epoch 04 | train 1.4483/0.699 | val 2.8142/0.183


                                                                               

Epoch 05 | train 1.0665/0.842 | val 2.8884/0.199


                                                                               

Epoch 06 | train 0.8753/0.923 | val 2.7824/0.220


                                                                               

Epoch 07 | train 0.7668/0.960 | val 2.7878/0.242


                                                                               

Epoch 08 | train 0.6966/0.985 | val 2.7277/0.251


                                                                               

Epoch 09 | train 0.6681/0.993 | val 2.6936/0.256


                                                                               

Epoch 10 | train 0.6467/0.996 | val 2.6895/0.249


                                                                               

Epoch 11 | train 0.6309/0.997 | val 2.6254/0.264


                                                                               

Epoch 12 | train 0.6218/0.998 | val 2.6501/0.246


                                                                               

Epoch 13 | train 0.6151/0.998 | val 2.6110/0.268


                                                                               

Epoch 14 | train 0.6128/0.999 | val 2.6243/0.265


                                                                               

Epoch 15 | train 0.6057/0.999 | val 2.6315/0.254


                                                                               

Epoch 16 | train 0.6017/0.999 | val 2.6270/0.254


                                                                               

Epoch 17 | train 0.5991/0.999 | val 2.6077/0.282


                                                                               

Epoch 18 | train 0.5971/0.999 | val 2.6151/0.265


                                                                               

Epoch 19 | train 0.5970/0.999 | val 2.6235/0.262


                                                                               

Epoch 20 | train 0.5969/0.999 | val 2.6511/0.256


                                                                               

Epoch 21 | train 0.5931/0.999 | val 2.6181/0.265


                                                                               

Epoch 22 | train 0.5934/0.999 | val 2.6385/0.281


                                                                               

Epoch 23 | train 0.5935/0.999 | val 2.7064/0.255


                                                                               

Epoch 24 | train 0.6085/0.996 | val 2.7527/0.258


                                                                               

Epoch 25 | train 0.7391/0.949 | val 3.0588/0.225


                                                                               

Epoch 26 | train 0.7992/0.929 | val 2.9839/0.235


                                                                               

Epoch 27 | train 0.6830/0.976 | val 2.8150/0.259


                                                                               

Epoch 28 | train 0.6286/0.993 | val 2.7096/0.274


                                                         

KeyboardInterrupt: 

## Resnet with tougher regularizations

In [34]:
from torchvision import transforms

train_tf = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([transforms.ColorJitter(0.25, 0.25, 0.25, 0.05)], p=0.8),
        transforms.RandomApply(
            [
                transforms.RandomAffine(
                    degrees=10,  # small rotation
                    translate=(0.05, 0.05),  # small shifts
                    scale=(0.9, 1.1),
                )
            ],
            p=0.5,
        ),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.2),
        transforms.ToTensor(),
        transforms.RandomErasing(
            p=0.25, scale=(0.02, 0.1), ratio=(0.3, 3.3), value="random"
        ),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

val_tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

In [35]:
train_ds = CelebMBTIDataset(
    root_dir="faces_yolo",
    metadata_csv="faces_yolo_metadata.csv",
    split="train",
    transform=train_transform,
    mbti_to_idx=mbti_to_idx,
)

val_ds = CelebMBTIDataset(
    root_dir="faces_yolo",
    metadata_csv="faces_yolo_metadata.csv",
    split="val",
    transform=val_transform,
    mbti_to_idx=mbti_to_idx,
)

test_ds = CelebMBTIDataset(
    root_dir="faces_yolo",
    metadata_csv="faces_yolo_metadata.csv",
    split="test",
    transform=val_transform,
    mbti_to_idx=mbti_to_idx,
)

In [36]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4)

device = "cuda" if torch.cuda.is_available() else "cpu"
num_classes = len(mbti_to_idx)

In [37]:
from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(in_features, num_classes),
)

for name, p in model.named_parameters():
    p.requires_grad = False

# unfreeze last block + head
for name, p in model.named_parameters():
    if name.startswith("layer4") or name.startswith("fc"):
        p.requires_grad = True

### Freeze all but last layer

In [38]:
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=3e-4,  # slightly higher while only last layers are trainable
    weight_decay=1e-4,
)

In [39]:
from tqdm import tqdm

model.to(device)

for epoch in range(5):
    # training phase with tqdm
    pbar = tqdm(train_loader, desc=f"Epoch {epoch:02d} [train]", leave=False)
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        preds = logits.argmax(1)
        total_correct += (preds == y).sum().item()
        total_loss += loss.item() * y.size(0)
        total += y.size(0)

        pbar.set_postfix(loss=loss.item())

    tr_loss = total_loss / total
    tr_acc = total_correct / total

    # validation phase with tqdm
    pbar = tqdm(val_loader, desc=f"Epoch {epoch:02d} [val]", leave=False)
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            preds = logits.argmax(1)
            total_correct += (preds == y).sum().item()
            total_loss += loss.item() * y.size(0)
            total += y.size(0)

            pbar.set_postfix(loss=loss.item())

    va_loss = total_loss / total
    va_acc = total_correct / total

    print(
        f"Epoch {epoch:02d} | "
        f"train loss {tr_loss:.4f}/ train accuracy {tr_acc:.3f} | "
        f"val loss {va_loss:.4f}/ val accuracy {va_acc:.3f}"
    )

                                                                              

Epoch 00 | train loss 2.7296/ train accuracy 0.102 | val loss 2.6895/ val accuracy 0.127


                                                                              

Epoch 01 | train loss 2.5744/ train accuracy 0.190 | val loss 2.7063/ val accuracy 0.122


                                                                              

Epoch 02 | train loss 2.3168/ train accuracy 0.312 | val loss 2.7044/ val accuracy 0.161


                                                                              

Epoch 03 | train loss 1.8546/ train accuracy 0.522 | val loss 2.7575/ val accuracy 0.171


                                                                              

Epoch 04 | train loss 1.3431/ train accuracy 0.736 | val loss 2.9254/ val accuracy 0.169




### Unfreeze all layers and re-train with smaller weights

In [None]:
for name, p in model.named_parameters():
    p.requires_grad = True

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(
    lr=3e-5,  # slightly higher while only last layers are trainable
    weight_decay=1e-4,
)

In [41]:
for epoch in range(15):
    # training phase with tqdm
    pbar = tqdm(train_loader, desc=f"Epoch {epoch:02d} [train]", leave=False)
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        preds = logits.argmax(1)
        total_correct += (preds == y).sum().item()
        total_loss += loss.item() * y.size(0)
        total += y.size(0)

        pbar.set_postfix(loss=loss.item())

    tr_loss = total_loss / total
    tr_acc = total_correct / total

    # validation phase with tqdm
    pbar = tqdm(val_loader, desc=f"Epoch {epoch:02d} [val]", leave=False)
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            preds = logits.argmax(1)
            total_correct += (preds == y).sum().item()
            total_loss += loss.item() * y.size(0)
            total += y.size(0)

            pbar.set_postfix(loss=loss.item())

    va_loss = total_loss / total
    va_acc = total_correct / total

    print(
        f"Epoch {epoch:02d} | "
        f"train loss {tr_loss:.4f}/ train accuracy {tr_acc:.3f} | "
        f"val loss {va_loss:.4f}/ val accuracy {va_acc:.3f}"
    )

                                                                               

Epoch 00 | train loss 0.9566/ train accuracy 0.902 | val loss 2.7491/ val accuracy 0.197


                                                                               

Epoch 01 | train loss 0.8297/ train accuracy 0.953 | val loss 2.7742/ val accuracy 0.212


                                                                               

Epoch 02 | train loss 0.7712/ train accuracy 0.972 | val loss 2.7496/ val accuracy 0.200


                                                                               

Epoch 03 | train loss 0.7281/ train accuracy 0.983 | val loss 2.7429/ val accuracy 0.215


                                                                               

Epoch 04 | train loss 0.7001/ train accuracy 0.992 | val loss 2.7257/ val accuracy 0.206


                                                                               

Epoch 05 | train loss 0.6847/ train accuracy 0.993 | val loss 2.7461/ val accuracy 0.218


                                                                               

Epoch 06 | train loss 0.6675/ train accuracy 0.996 | val loss 2.7241/ val accuracy 0.218


                                                                               

Epoch 07 | train loss 0.6605/ train accuracy 0.997 | val loss 2.7220/ val accuracy 0.222


                                                                               

Epoch 08 | train loss 0.6524/ train accuracy 0.998 | val loss 2.7136/ val accuracy 0.223


                                                                               

Epoch 09 | train loss 0.6463/ train accuracy 0.999 | val loss 2.7115/ val accuracy 0.219


                                                                               

Epoch 10 | train loss 0.6406/ train accuracy 0.998 | val loss 2.6946/ val accuracy 0.215


                                                                               

Epoch 11 | train loss 0.6378/ train accuracy 0.999 | val loss 2.6626/ val accuracy 0.232


                                                                               

Epoch 12 | train loss 0.6326/ train accuracy 0.999 | val loss 2.6670/ val accuracy 0.215


                                                                               

Epoch 13 | train loss 0.6275/ train accuracy 0.999 | val loss 2.6760/ val accuracy 0.222


                                                                               

Epoch 14 | train loss 0.6272/ train accuracy 0.999 | val loss 2.6614/ val accuracy 0.231


