In [1]:
import os
from tqdm import tqdm

import torch
import torch.nn as nn

from backbone import ResNet50
from loss import ArcFaceLoss

from utils.dataloaders import create_dataloader
from utils.device import select_device

In [3]:
class Args():
    def __init__(self):
        self.epochs = 3
        self.learning_rate = 1e-3
        self.data = 'data'
        self.batch_size = 32
        self. image_size = 224
        self.num_workers = -1
        self.embedding_size = 512
        self.margin_loss = 0.3
        self.scale_loss = 30
        self.device = None
        
args = Args()

In [5]:
args.num_workers = os.cpu_count() if args.num_workers == -1 else args.num_workers
    
args.image_size = (args.image_size, args.image_size) if isinstance(args.image_size, int) else args.image_size

train_dir = os.path.join(args.data, 'train')
train_dataloader, train_datasets = create_dataloader(train_dir, args.image_size, args.batch_size, args.num_workers)

valid_dir = os.path.join(args.data, 'valid')
valid_dataloader, _ = create_dataloader(valid_dir, args.image_size, args.batch_size, args.num_workers)

device = select_device(args.device)

num_classes = len(train_datasets.classes)

In [26]:
feature_extraction = ResNet50(args.embedding_size).to(device)
criterion = ArcFaceLoss(num_classes, args.embedding_size).to(device)

feature_extraction = torch.compile(feature_extraction) # torch 2.0

optimizer = torch.optim.AdamW(
    params=[{'params': feature_extraction.parameters(), 
                'params': criterion.parameters()}],
    lr=args.learning_rate
)

In [13]:
X = torch.randn(32, 3, 224, 224).to(device)

embeddings = feature_extraction(X)




In [15]:
embeddings.shape

torch.Size([32, 512])

In [38]:
y = torch.ones(32, dtype=torch.int64)

In [39]:
y = y.to(device)

In [40]:
y

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')

In [41]:
logits, loss = criterion(embeddings, y)

In [3]:
import torch

x = torch.tensor([True, False, True])

x.mean(dtype=torch.bool)

RuntimeError: mean(): could not infer output dtype. Optional dtype must be either a floating point or complex dtype. Got: Bool