In [1]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torchvision import transforms

from trainer import fit

# True when a CUDA device is usable; later cells use it to move the model to the GPU
# and to choose DataLoader worker/pin-memory options.
cuda = torch.cuda.is_available()

# Fix all RNG seeds up front so shuffling, weight init of any non-pretrained layers
# and (if the dataset uses these RNGs) triplet sampling are repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the default generator on all devices, including CUDA

# Triplet network training (triplet loss, EfficientNet-B4 embedding)

In [2]:
# Set up data loaders
from datasets import TripletDataset

# Root directory of the AIC20 ReID training images. Set the AIC20_REID_TRAIN_DIR
# environment variable to point elsewhere instead of editing the notebook; the
# default is the original author's local path.
root_dir = os.environ.get(
    'AIC20_REID_TRAIN_DIR',
    '/home/cuong/AIC20-Track2/AIC20_track2/AIC20_ReID/image_train',
)
train_csv = 'cls_train.csv'
val_csv = 'cls_val.csv'
label_json = 'train_image_metadata.json'

size = (224, 224)

# Train and val use the same deterministic preprocessing. Resize and ToTensor keep
# no per-sample state, so one Compose instance is shared by both datasets.
# NOTE(review): there is no transforms.Normalize here. Pretrained EfficientNet weights
# are normally used with ImageNet mean/std-normalized inputs — confirm whether
# EfficientNetExtractor normalizes internally.
triplet_transform = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])
triplet_train_dataset = TripletDataset(root_dir, train_csv, label_json,
                                       transform=triplet_transform)
triplet_val_dataset = TripletDataset(root_dir, val_csv, label_json,
                                     transform=triplet_transform)

batch_size = 8
# On GPU: one background loader worker plus pinned host memory for faster
# host-to-device copies. On CPU: DataLoader defaults (load in the main process).
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
triplet_train_loader = torch.utils.data.DataLoader(
    triplet_train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
# Validation batches are not shuffled, so they come in the same order every epoch.
triplet_val_loader = torch.utils.data.DataLoader(
    triplet_val_dataset, batch_size=batch_size, shuffle=False, **kwargs)

In [3]:
# Set up the network and training parameters
from networks import EfficientNetExtractor, TripletNet
from losses import TripletLoss

# Training hyperparameters
margin = 1.          # margin passed to TripletLoss
lr = 1e-3            # initial Adam learning rate
n_epochs = 20
log_interval = 100   # logging interval handed to fit()

# EfficientNet 'b4' feature extractor wrapped by the project's TripletNet.
embedding_net = EfficientNetExtractor('b4')
model = TripletNet(embedding_net)
if cuda:
    model.cuda()

loss_fn = TripletLoss(margin)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by 0.1 every 8 scheduler steps; a fresh schedule
# (StepLR's default last_epoch=-1).
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

Loaded pretrained weights for efficientnet-b4


In [None]:
# Run trainer.fit with the loaders, model, loss, optimizer, LR scheduler and
# run settings configured in the cells above (argument order unchanged).
fit(
    triplet_train_loader, triplet_val_loader,
    model, loss_fn, optimizer, scheduler,
    n_epochs, cuda, log_interval,
)





In [None]:
# Persist the trained model.
# torch.save(model, ...) pickles the whole TripletNet object, so loading that file
# later requires the same `networks` classes importable at the same module paths.
# The pickled file is still written for existing loaders. A state_dict checkpoint is
# written too: it stores only parameters and buffers, and can be loaded into a newly
# built TripletNet(EfficientNetExtractor('b4')) as long as the parameter names match.
checkpoint_stem = 'triplet-b4-200404'
torch.save(model, checkpoint_stem + '.pth')
torch.save(model.state_dict(), checkpoint_stem + '-state_dict.pth')