In [1]:
import sys
sys.path.append("..")

In [9]:
import utils
from argparse import ArgumentParser
from DatasetLoader import VGG_dataset
import torchvision.transforms as transforms
from models.EmbedNet import EmbedNet
import torch.optim as optim
from tqdm import tqdm
from loguru import logger
import timm
import os

from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
import torch

In [3]:
def train(model, loss_func, mining_func, loader, optimizer):
    loss = 0
    counter = 0
    with tqdm(loader, unit="batch") as tepoch:
        for data, labels in tepoch:
            data, labels = data.to("cuda:1"), labels.to("cuda:1")
            optimizer.zero_grad()
            embeddings = model(data)
            indices_tuple = mining_func(embeddings, labels)
            loss = loss_func(embeddings, labels, indices_tuple)
            loss.backward()
            optimizer.step()
            loss    += loss.detach().cpu().item()
            counter += 1
            tepoch.set_postfix(loss=loss/counter)

    return loss/counter


### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    tester = testers.BaseTester()
    return tester.get_all_embeddings(dataset, model)

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    test_embeddings, test_labels = get_all_embeddings(test_set, model)
    train_labels = train_labels.squeeze(1)
    test_labels = test_labels.squeeze(1)
    logger.info("Computing accuracy")
    accuracies = accuracy_calculator.get_accuracy(
        test_embeddings, train_embeddings, test_labels, train_labels, False
    )
    logger.info("Test set accuracy (Precision@1) = {}".format(accuracies["precision_at_1"]))
    return accuracies["precision_at_1"]


In [4]:

config = utils.load_yaml("../experiments/baseline.yaml")
utils.seed_everything(config["seed"])

# ------ DATASET -------
train_transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((224, 224)),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

train_dataset, valid_dataset = VGG_dataset("../vgg_data/train.csv", train_transform),\
    VGG_dataset("../vgg_data/valid.csv", test_transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=100, num_workers=5
)
test_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=100)

In [5]:
model_name = "visformer_tiny"
model = timm.create_model(model_name, num_classes = 768) #EmbedNet(config)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [6]:
### pytorch-metric-learning stuff ###
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
loss_func = losses.TripletMarginLoss(margin=0.3, distance=distance, reducer=reducer)
mining_func = miners.TripletMarginMiner(
    margin=0.3, distance=distance, type_of_triplets="semihard"
)
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

In [10]:
num_epochs = 1
for epoch in range(1, num_epochs + 1):
    model.to("cuda:1")
    logger.info(f'Start of {epoch}/{num_epochs}')
    train_loss = train(model, loss_func, mining_func, train_loader, optimizer)
    logger.info(f'Train loss {train_loss}')
    logger.info(f'End of {epoch}/{num_epochs}')
    
model.cpu()
torch.save({
    "epoch": epoch,
    "state_dict": model.state_dict(),
}, os.path.join("../logs/", "best_pretrained_model.pth"))
logger.info('Training is finished')

2022-11-25 20:44:45.094 | INFO     | __main__:<module>:4 - Start of 1/1
  2%|▉                                             | 99/5028 [00:38<32:12,  2.55batch/s, loss=tensor(0.0025, device='cuda:1', grad_fn=<DivBackward0>)]
2022-11-25 20:45:23.911 | INFO     | __main__:<module>:6 - Train loss 0.0025270201731473207
2022-11-25 20:45:23.912 | INFO     | __main__:<module>:7 - End of 1/1
2022-11-25 20:45:24.001 | INFO     | __main__:<module>:14 - Training is finished
