In [1]:
import os

import torch
from dotenv import load_dotenv
import pandas as pd
import torchmetrics

from dataset.dataset import LandmarkDataset
from dataset.augmentations import aug_version_1
from models.networks import LandmarkResidual
from src.utils import read_artifacts_s3, set_seed




# Populate os.environ from a local .env file, then fetch the run config from S3.
load_dotenv()

config_key = os.environ.get("CONFIG_VERSION_0")
args = read_artifacts_s3(object_key=config_key)

# Seed before building the dataset/augmentations so shuffling and augmentation are reproducible.
set_seed(args["seed"])

# Training table -> augmented dataset -> shuffled mini-batch loader.
df = pd.read_csv(args["df_path"])
train_transform = aug_version_1(args)
train_dataset = LandmarkDataset(dataframe=df, transform=train_transform)

trainloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=32,
    pin_memory=False,
    num_workers=0,  # load in the main process
)

In [2]:
# Model, optimizer, loss and metrics for the single-batch overfit sanity check.
# Tunables are named once here instead of being repeated as literals.
DEVICE = torch.device('cpu')
NUM_CLASSES = 495  # size of the label space for the F1 metric; presumably must match the model head — TODO confirm
LEARNING_RATE = 3e-4

net = LandmarkResidual(model='resnet50')
net.to(DEVICE)
optim = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.CrossEntropyLoss()
# NOTE(review): these constructor forms (no `task=` argument) appear to target the
# pre-0.11 torchmetrics API; newer releases require `task=` — pin or update accordingly.
acc = torchmetrics.Accuracy().to(DEVICE)
f1 = torchmetrics.F1Score(num_classes=NUM_CLASSES, average='weighted').to(DEVICE)

Loading weights from S3: 100%|####################################| 102M/102M [00:09<00:00, 11.2MB/s]


In [3]:
# Take one fixed mini-batch; the next cell trains on it repeatedly as an overfitting sanity check.
batch_iterator = iter(trainloader)
overfit_batch = next(batch_iterator)

In [4]:
# Overfit a single fixed batch: if model, loss and optimizer are wired correctly,
# the loss should fall steadily and batch accuracy/F1 should climb toward 1.0.
N_STEPS = 50

# Follow whatever device the model lives on instead of hard-coding it.
device = next(net.parameters()).device

# The batch never changes, so transfer it once rather than on every step.
images = overfit_batch['images'].to(device)
labels = overfit_batch['labels'].to(device)

# Clear metric state left over from earlier runs so re-running this cell is idempotent.
acc.reset()
f1.reset()

# Explicit train mode; matters if the model has BatchNorm/Dropout and an earlier cell called eval().
net.train()

for i in range(N_STEPS):
    net.zero_grad(set_to_none=True)
    outputs = net(images)
    preds = outputs.argmax(dim=1)
    # Calling a torchmetrics metric returns the value for this batch only.
    batch_acc = acc(preds, labels)
    batch_f1 = f1(preds, labels)
    loss = criterion(outputs, labels)
    loss.backward()
    optim.step()
    print(f"[{i+1}/{N_STEPS}] loss:{loss.item():.6f} acc:{batch_acc:.6f} f1:{batch_f1:.6f}")

[1/50] loss:6.199275 acc:0.000000 f1:0.000000
[2/50] loss:6.031512 acc:0.156250 f1:0.094792
[3/50] loss:5.862623 acc:0.656250 f1:0.594345
[4/50] loss:5.693649 acc:0.937500 f1:0.916667
[5/50] loss:5.514441 acc:1.000000 f1:1.000000
[6/50] loss:5.338186 acc:1.000000 f1:1.000000
[7/50] loss:5.152004 acc:1.000000 f1:1.000000
[8/50] loss:4.965442 acc:1.000000 f1:1.000000
[9/50] loss:4.772569 acc:1.000000 f1:1.000000
[10/50] loss:4.576386 acc:1.000000 f1:1.000000
[11/50] loss:4.377149 acc:1.000000 f1:1.000000
[12/50] loss:4.177557 acc:1.000000 f1:1.000000
[13/50] loss:3.979196 acc:1.000000 f1:1.000000
[14/50] loss:3.784426 acc:1.000000 f1:1.000000
[15/50] loss:3.593911 acc:1.000000 f1:1.000000
[16/50] loss:3.407115 acc:1.000000 f1:1.000000
[17/50] loss:3.224083 acc:1.000000 f1:1.000000
[18/50] loss:3.045442 acc:1.000000 f1:1.000000
[19/50] loss:2.871886 acc:1.000000 f1:1.000000
[20/50] loss:2.704116 acc:1.000000 f1:1.000000
[21/50] loss:2.541639 acc:1.000000 f1:1.000000
[22/50] loss:2.384249 