In [1]:
import torch
from dataset import CelebMBTIDataset

In [2]:
from utils import assign_splits, get_label_mapping

# Both helpers read the same metadata file and the same label column.
metadata_path = "faces_yolo_metadata.csv"
label_field = "mbti"

# NOTE(review): assign_splits is a project helper; it presumably writes/updates
# a train/val/test "split" column and prints the per-split counts -- confirm in utils.
assign_splits(file_path=metadata_path, id_column="id", label_column=label_field)

# Mapping from each MBTI type string to an integer class index.
mbti_to_idx = get_label_mapping(file_path=metadata_path, label_column=label_field)

# Collapse the per-type indices into a binary target in place:
# types starting with "I" become class 0, every other type becomes class 1.
for mbti_type in list(mbti_to_idx):
    mbti_to_idx[mbti_type] = 0 if mbti_type.startswith("I") else 1

print("MBTI to index mapping:", mbti_to_idx)

split
train    5609
test      715
val       694
Name: count, dtype: int64
MBTI to index mapping: {'ENFJ': 1, 'ENFP': 1, 'ENTJ': 1, 'ENTP': 1, 'ESFJ': 1, 'ESFP': 1, 'ESTJ': 1, 'ESTP': 1, 'INFJ': 0, 'INFP': 0, 'INTJ': 0, 'INTP': 0, 'ISFJ': 0, 'ISFP': 0, 'ISTJ': 0, 'ISTP': 0}


In [3]:
from torchvision import transforms

# Per-channel mean/std used to normalize tensors in both pipelines.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Individual augmentations, named so the training pipeline below reads as a list of steps.
color_jitter = transforms.ColorJitter(
    brightness=0.25, contrast=0.25, saturation=0.25, hue=0.05
)
small_affine = transforms.RandomAffine(
    degrees=10,  # small rotation
    translate=(0.05, 0.05),  # small shifts
    scale=(0.9, 1.1),
)
light_blur = transforms.GaussianBlur(kernel_size=3)
random_erase = transforms.RandomErasing(
    p=0.25, scale=(0.02, 0.1), ratio=(0.3, 3.3), value="random"
)

# Training pipeline: random augmentations, tensor conversion, random erasing,
# then normalization.
# NOTE(review): RandomErasing runs before Normalize here, so the random fill
# values are normalized as well -- confirm this ordering is intended.
train_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomApply([small_affine], p=0.5),
    transforms.RandomApply([light_blur], p=0.2),
    transforms.ToTensor(),
    random_erase,
    transforms.Normalize(norm_mean, norm_std),
]
train_tf = transforms.Compose(train_steps)

# Evaluation pipeline: deterministic, tensor conversion and normalization only.
eval_steps = [
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
val_tf = transforms.Compose(eval_steps)

In [4]:
from torch.utils.data import DataLoader

def _build_split(split_name, split_transform):
    """Create the CelebMBTIDataset for one split of the shared metadata file.

    Args:
        split_name: Value passed as ``split`` ("train", "val" or "test").
        split_transform: Transform pipeline applied by the dataset.

    Returns:
        A CelebMBTIDataset reading from ``faces_yolo`` with the binary
        ``mbti_to_idx`` label mapping.
    """
    return CelebMBTIDataset(
        root_dir="faces_yolo",
        metadata_csv="faces_yolo_metadata.csv",
        split=split_name,
        transform=split_transform,
        mbti_to_idx=mbti_to_idx,
    )


# Only the training split gets the augmenting pipeline; val/test share the
# deterministic one.
train_ds = _build_split("train", train_tf)
val_ds = _build_split("val", val_tf)
test_ds = _build_split("test", val_tf)

# Shuffle only the training data; evaluation order stays fixed.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4)

# Sanity check: inspect the tensor shapes of the first training batch only.
for images, labels in train_loader:
    print(images.shape)  # [Batch, 3, 224, 224]
    print(labels.shape)  # [Batch] -- one class index per image
    break

torch.Size([32, 3, 224, 224])
torch.Size([32])


In [None]:
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn

# Binary target: introvert (0) vs. extravert (1).
num_classes = 2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet-pretrained ResNet-18 backbone; only the classification head is replaced.
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
in_features = model.fc.in_features

# New head: dropout for regularization, then one logit per class.
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(in_features, num_classes),
)

model.to(device)

# The head emits [batch, num_classes] logits while the loaders yield a 1-D
# tensor of class indices, and the training loop scores predictions with
# argmax(1). That is the CrossEntropyLoss contract. BCEWithLogitsLoss needs a
# float target with the same shape as the logits, so it raises on the first
# batch with this setup.
# NOTE(review): the logged run may have used extra loss options (e.g. label
# smoothing) -- confirm before comparing new numbers against the old output.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

In [7]:
from tqdm import tqdm

def _run_batches(progress, training):
    """Iterate once over a tqdm-wrapped loader and accumulate loss/accuracy.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``. The caller is responsible for switching the model between
    train/eval mode and for disabling gradients during evaluation.

    Args:
        progress: tqdm progress bar wrapping a DataLoader of (inputs, targets).
        training: When True, run backward and an optimizer step per batch.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over all samples seen.
    """
    loss_sum, hits, seen = 0.0, 0, 0
    for inputs, targets in progress:
        inputs, targets = inputs.to(device), targets.to(device)
        if training:
            optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        if training:
            batch_loss.backward()
            optimizer.step()

        batch_size = targets.size(0)
        # Predicted class is the index of the largest logit.
        hits += (outputs.argmax(1) == targets).sum().item()
        # Weight the batch-mean loss by batch size so the epoch mean is per-sample.
        loss_sum += batch_loss.item() * batch_size
        seen += batch_size

        progress.set_postfix(loss=batch_loss.item())

    return loss_sum / seen, hits / seen


for epoch in range(50):
    # Training pass: dropout/batch-norm in train mode, weights updated per batch.
    model.train()
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch:02d} [train]", leave=False)
    tr_loss, tr_acc = _run_batches(train_bar, training=True)

    # Validation pass: eval mode and no gradient tracking.
    model.eval()
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch:02d} [val]", leave=False)
    with torch.no_grad():
        va_loss, va_acc = _run_batches(val_bar, training=False)

    print(
        f"Epoch {epoch:02d} | "
        f"train loss {tr_loss:.4f}/ train accuracy {tr_acc:.3f} | "
        f"val loss {va_loss:.4f}/ val accuracy {va_acc:.3f}"
    )

                                                                               

Epoch 00 | train loss 0.7688/ train accuracy 0.559 | val loss 0.6822/ val accuracy 0.588


                                                                               

Epoch 01 | train loss 0.7292/ train accuracy 0.579 | val loss 0.7218/ val accuracy 0.548


                                                                               

Epoch 02 | train loss 0.6986/ train accuracy 0.610 | val loss 0.6685/ val accuracy 0.617


                                                                               

Epoch 03 | train loss 0.6628/ train accuracy 0.629 | val loss 0.7096/ val accuracy 0.615


                                                                               

Epoch 04 | train loss 0.6334/ train accuracy 0.664 | val loss 0.7277/ val accuracy 0.578


                                                                               

Epoch 05 | train loss 0.5984/ train accuracy 0.703 | val loss 0.6857/ val accuracy 0.633


                                                                               

Epoch 06 | train loss 0.5594/ train accuracy 0.743 | val loss 0.7045/ val accuracy 0.622


                                                                               

Epoch 07 | train loss 0.5406/ train accuracy 0.755 | val loss 0.6536/ val accuracy 0.666


                                                                               

Epoch 08 | train loss 0.4978/ train accuracy 0.792 | val loss 0.6991/ val accuracy 0.667


                                                                               

Epoch 09 | train loss 0.4549/ train accuracy 0.832 | val loss 0.7046/ val accuracy 0.674


                                                                               

Epoch 10 | train loss 0.4400/ train accuracy 0.844 | val loss 0.7868/ val accuracy 0.634


                                                                               

Epoch 11 | train loss 0.4096/ train accuracy 0.866 | val loss 0.8094/ val accuracy 0.700


                                                                               

Epoch 12 | train loss 0.3810/ train accuracy 0.887 | val loss 0.6986/ val accuracy 0.702


                                                                               

Epoch 13 | train loss 0.3635/ train accuracy 0.903 | val loss 0.7156/ val accuracy 0.710


                                                                               

Epoch 14 | train loss 0.3616/ train accuracy 0.905 | val loss 0.7898/ val accuracy 0.697


                                                                               

Epoch 15 | train loss 0.3404/ train accuracy 0.918 | val loss 0.7195/ val accuracy 0.709


                                                                               

Epoch 16 | train loss 0.3307/ train accuracy 0.931 | val loss 0.7089/ val accuracy 0.699


                                                                               

Epoch 17 | train loss 0.3153/ train accuracy 0.937 | val loss 0.6661/ val accuracy 0.710


                                                                               

Epoch 18 | train loss 0.3042/ train accuracy 0.950 | val loss 0.6738/ val accuracy 0.715


                                                                               

Epoch 19 | train loss 0.3042/ train accuracy 0.945 | val loss 0.7253/ val accuracy 0.709


                                                                               

Epoch 20 | train loss 0.2945/ train accuracy 0.953 | val loss 0.6567/ val accuracy 0.710


                                                                               

Epoch 21 | train loss 0.2928/ train accuracy 0.956 | val loss 0.6671/ val accuracy 0.726


                                                                               

Epoch 22 | train loss 0.2877/ train accuracy 0.958 | val loss 0.6656/ val accuracy 0.723


                                                                               

Epoch 23 | train loss 0.2865/ train accuracy 0.962 | val loss 0.6375/ val accuracy 0.720


                                                                               

Epoch 24 | train loss 0.2813/ train accuracy 0.961 | val loss 0.6381/ val accuracy 0.736


                                                                               

Epoch 25 | train loss 0.2801/ train accuracy 0.965 | val loss 0.6515/ val accuracy 0.723


                                                                               

Epoch 26 | train loss 0.2732/ train accuracy 0.969 | val loss 0.6503/ val accuracy 0.728


                                                                               

Epoch 27 | train loss 0.2791/ train accuracy 0.966 | val loss 0.6125/ val accuracy 0.759


                                                                               

Epoch 28 | train loss 0.2718/ train accuracy 0.967 | val loss 0.5995/ val accuracy 0.764


                                                                               

Epoch 29 | train loss 0.2694/ train accuracy 0.971 | val loss 0.6348/ val accuracy 0.736


                                                                               

Epoch 30 | train loss 0.2587/ train accuracy 0.979 | val loss 0.6358/ val accuracy 0.752


                                                                               

Epoch 31 | train loss 0.2628/ train accuracy 0.974 | val loss 0.6413/ val accuracy 0.722


                                                                               

Epoch 32 | train loss 0.2656/ train accuracy 0.972 | val loss 0.6192/ val accuracy 0.741


                                                                               

Epoch 33 | train loss 0.2635/ train accuracy 0.975 | val loss 0.6050/ val accuracy 0.758


                                                                               

Epoch 34 | train loss 0.2551/ train accuracy 0.980 | val loss 0.6324/ val accuracy 0.745


                                                                               

Epoch 35 | train loss 0.2530/ train accuracy 0.981 | val loss 0.6357/ val accuracy 0.744


                                                                               

Epoch 36 | train loss 0.2524/ train accuracy 0.981 | val loss 0.6157/ val accuracy 0.744


                                                                               

Epoch 37 | train loss 0.2527/ train accuracy 0.982 | val loss 0.6193/ val accuracy 0.751


                                                                               

Epoch 38 | train loss 0.2604/ train accuracy 0.978 | val loss 0.6163/ val accuracy 0.744


                                                                               

Epoch 39 | train loss 0.2558/ train accuracy 0.978 | val loss 0.6201/ val accuracy 0.764


                                                                               

Epoch 40 | train loss 0.2529/ train accuracy 0.979 | val loss 0.6594/ val accuracy 0.725


                                                                               

Epoch 41 | train loss 0.2528/ train accuracy 0.979 | val loss 0.6018/ val accuracy 0.775


                                                                               

Epoch 42 | train loss 0.2488/ train accuracy 0.981 | val loss 0.6140/ val accuracy 0.754


                                                                               

Epoch 43 | train loss 0.2550/ train accuracy 0.978 | val loss 0.6470/ val accuracy 0.723


                                                                               

Epoch 44 | train loss 0.2486/ train accuracy 0.982 | val loss 0.6183/ val accuracy 0.767


                                                                               

Epoch 45 | train loss 0.2465/ train accuracy 0.984 | val loss 0.6316/ val accuracy 0.758


                                                                               

Epoch 46 | train loss 0.2460/ train accuracy 0.981 | val loss 0.6107/ val accuracy 0.755


                                                                               

Epoch 47 | train loss 0.2487/ train accuracy 0.981 | val loss 0.5946/ val accuracy 0.762


                                                                               

Epoch 48 | train loss 0.2454/ train accuracy 0.985 | val loss 0.5974/ val accuracy 0.768


                                                                               

Epoch 49 | train loss 0.2390/ train accuracy 0.988 | val loss 0.5790/ val accuracy 0.767


