In [1]:
import os
from pathlib import Path

import torch
from torchvision.transforms import v2

from face2embeddings.data.setup import create_dataloaders
from face2embeddings.engine import train
from face2embeddings.model import FaceSwin
from face2embeddings.utils import create_writer
# Allow TF32 for CUDA float32 matmuls: faster on GPUs that support it, at a
# small cost in matmul precision.
torch.backends.cuda.matmul.allow_tf32 = True

# Seed torch's global RNG (CPU and all CUDA devices) so torch-driven randomness
# (the v2 augmentations, the random sanity-check input, presumably DataLoader
# shuffling/worker seeds) is repeatable across Restart & Run All.
# NOTE(review): this does not make GPU kernels themselves deterministic.
SEED = 42
torch.manual_seed(SEED)

In [3]:
# Training-time augmentation pipeline; the transforms run in the order listed.
transform = v2.Compose(
    transforms=[
        # Mirror left/right with probability 0.3.
        v2.RandomHorizontalFlip(p=0.3),
        # Rotate by an angle drawn from [-25, 25] degrees, resampling bilinearly.
        v2.RandomRotation(
            degrees=(-25, 25),
            interpolation=v2.InterpolationMode.BILINEAR,
        ),
        # Brightness/contrast/saturation factors drawn from [0.7, 1.3]; hue shifted
        # by up to +/-0.3 (hue's allowed maximum is 0.5).
        # NOTE(review): a hue shift this large noticeably recolours skin tones —
        # confirm it is intended for face data.
        v2.ColorJitter(
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3,
        ),
        # Reduce each channel to 4 bits with probability 0.25.
        v2.RandomPosterize(bits=4, p=0.25),
        # Histogram-equalize with probability 0.2.
        v2.RandomEqualize(p=0.2),
    ]
)

In [4]:
# Dataset root, expected to contain `train/` and `val/` subfolders. It defaults
# to the original author's local copy; set FACE_AUTH_DATA_DIR to run the
# notebook elsewhere without editing code.
DATA_DIR = Path(
    os.environ.get(
        "FACE_AUTH_DATA_DIR",
        r"C:\Users\emely\OneDrive\Desktop\face-auth-dataset",
    )
)
BATCH_SIZE = 5
NUM_WORKERS = 6

train_dataloader, val_dataloader = create_dataloaders(
    train_dir=DATA_DIR / "train",
    val_dir=DATA_DIR / "val",
    transforms=transform,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [5]:
# Run name for this training session; create_writer reports where the
# SummaryWriter saves its logs (see the output below).
EXPERIMENT_NAME = "train-8"
writer = create_writer(experiment_name=EXPERIMENT_NAME)

[INFO] Created SummaryWriter, saving to: runs\2024-07-06\train-8...


In [6]:
# Selects how the model is initialised in the following cell:
#   True  -> FaceSwin(train_from_default=True)
#   False -> FaceSwin() with encoder weights loaded from ./model/face-swin-encoder-v2.pt
TRAIN_FROM_DEFAULT: bool = False

In [7]:
# Fail fast with a clear message instead of a deeper error on the first `.to()`.
if not torch.cuda.is_available():
    raise RuntimeError("This notebook expects a CUDA device, but torch.cuda.is_available() is False.")
DEVICE = torch.device("cuda")

# Build the model from one of the two initialisation paths, then move it to the
# GPU once, *before* constructing the optimizer over its parameters.
if TRAIN_FROM_DEFAULT:
    model = FaceSwin(train_from_default=True)
else:
    model = FaceSwin()
    model.load_encoder(Path("./model/face-swin-encoder-v2.pt"))
model = model.to(DEVICE)

# Triplet loss on Euclidean (p=2) distances between embeddings. margin=3 is in
# raw distance units.
# NOTE(review): whether that scale is sensible depends on whether FaceSwin
# normalises its embeddings — confirm.
loss_fn = torch.nn.TripletMarginWithDistanceLoss(distance_function=torch.nn.PairwiseDistance(), margin=3)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-4)

# Linear warm-up from 0.001 * lr to lr over 30 scheduler.step() calls.
# NOTE(review): the step granularity is decided inside `train` (not visible here).
# If it steps once per epoch, then with 3 epochs the LR never exceeds about 1e-5;
# if it steps per batch, the warm-up covers only 30 of ~36k batches. Confirm
# which is intended.
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, end_factor=1, total_iters=30)

In [8]:
# Sanity check: a random batch of 5 RGB 224x224 images should yield one
# embedding per image. This is a throwaway forward pass, so skip building an
# autograd graph, and allocate the input directly on the GPU.
with torch.no_grad():
    dummy_embeddings = model(torch.rand(5, 3, 224, 224, device=torch.device("cuda")))
assert dummy_embeddings.shape[0] == 5, "expected one embedding per input image"
dummy_embeddings.shape

torch.Size([5, 256])

In [9]:
# Training hyper-parameters for this run; `checkpoint_step` is interpreted by
# `train` (presumably a count of training steps between checkpoints — confirm).
EPOCHS = 3
CHECKPOINT_STEP = 5000
CHECKPOINT_PATH = Path("./model/train/train-8")

results = train(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=EPOCHS,
    device=torch.device("cuda"),
    writer=writer,
    checkpoint_path=CHECKPOINT_PATH,
    checkpoint_step=CHECKPOINT_STEP,
)

Epochs Loop:   0%|          | 0/3 [00:00<?, ?it/s]

Train Step:   0%|          | 0/35939 [00:00<?, ?it/s]

Validation Step:   0%|          | 0/5989 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.5149 | val_loss: 0.4544 | 


Train Step:   0%|          | 0/35939 [00:00<?, ?it/s]

Validation Step:   0%|          | 0/5989 [00:00<?, ?it/s]

Epoch: 2 | train_loss: 0.4292 | val_loss: 0.4078 | 


Train Step:   0%|          | 0/35939 [00:00<?, ?it/s]

Validation Step:   0%|          | 0/5989 [00:00<?, ?it/s]

Epoch: 3 | train_loss: 0.3984 | val_loss: 0.4056 | 


In [10]:
# Persist whatever weights `model` holds after `train` returns. This overwrites
# any existing file at that path.
# NOTE(review): `train` is not shown to restore the best-validation checkpoint,
# so these are presumably the final-epoch weights.
final_model_path = Path("./model/swin-face-v2.pt")
# ./model is only guaranteed to exist when the encoder was loaded from it
# (TRAIN_FROM_DEFAULT=False), so create it if needed.
final_model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), final_model_path)