In [1]:
!pip install easydict
!pip install mlflow

import mlflow

exp_name = "without XBM"
mlflow.start_run(run_name=exp_name)

Collecting easydict
  Downloading easydict-1.9.tar.gz (6.4 kB)
Building wheels for collected packages: easydict
  Building wheel for easydict (setup.py) ... [?25ldone
[?25h  Created wheel for easydict: filename=easydict-1.9-py3-none-any.whl size=6350 sha256=7918bcfa25dd3e7b4bd752cb024b3fbfe6f67c383869fc42c78175d045d20d0d
  Stored in directory: /root/.cache/pip/wheels/88/96/68/c2be18e7406804be2e593e1c37845f2dd20ac2ce1381ce40b0
Successfully built easydict
Installing collected packages: easydict
Successfully installed easydict-1.9
Collecting mlflow
  Downloading mlflow-1.15.0-py3-none-any.whl (14.2 MB)
[K     |████████████████████████████████| 14.2 MB 6.8 MB/s eta 0:00:01     |████████████████████████▌       | 10.9 MB 6.3 MB/s eta 0:00:01     |████████████████████████████▎   | 12.5 MB 6.8 MB/s eta 0:00:01     |█████████████████████████████▌  | 13.0 MB 6.8 MB/s eta 0:00:01
Collecting prometheus-flask-exporter
  Downloading prometheus_flask_exporter-0.18.1.tar.gz (21 kB)
Collecting gunic

<ActiveRun: >

In [2]:
""" dataset config """

from easydict import EasyDict

# Dataset locations, label-space sizes and augmentation constants.
dataset_cfg = EasyDict({
    # root folder with the image folders (local alternative: 'E:/datasets/SOP_retrieval')
    'dataset_new_folder': '/kaggle/input/train-test-folders',
    # 'valid_query_retrieval_sets_path' is currently not configured
    # (local alternative: 'E:/datasets/valid_dataset.pickle')
    # pickled query/retrieval split (local alternative: 'E:/datasets/test_dataset.pickle')
    'test_query_retrieval_sets_path': '/kaggle/input/train-test-folders/test_dataset.pickle',

    'nb_categories': 12,
    'sz_dataset': 120053,
    # products with fewer images than this are dropped from the train folder
    'nb_elems_needed_for_product': 4,

    # augmentation
    'sz_crop': 224,
    'sz_resize': 256,
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],

    'load_image_folder': False,  # True
})

""" eval config """

from easydict import EasyDict

# Switches for the optional evaluation / visualisation steps of the training run.
eval_cfg = EasyDict({
    'compute_metrics_before_training': False,  # True
    'evaluate_on_train_data': False,
    'visualize_embeddings': False,
})

""" model config """

from easydict import EasyDict

# Embedding-model settings read when the model is built.
model_cfg = EasyDict({
    'pretrained_model': True,
    'embedding_dim': 128,
    'random_seed': 0,
})

""" train config """

from easydict import EasyDict

# Optimisation, checkpointing, logging and debugging settings for the training loop.
train_cfg = EasyDict({
    # local alternative: 'D:/Users/Admin/PycharmProjects/imageretrievalvaliullina/'
    'checkpoints_dir': '/kaggle/working/',
    'tensorboard_dir': '/kaggle/working/',
    'device': 'cuda:0',
    'nb_epochs': 100,
    'continue_training_from_epoch': False,
    'checkpoint_from_epoch': 0,
    'batch_size': 64,
    'lr': 1e-5,
    'weight_decay': 1e-4,
    'save_model': True,
    'log_to_mlflow': True,
    'use_gpu': True,
    'margin': 0.25,

    'train_on_kaggle': False,
    'path_to_image_folders': '/kaggle/input/train-test-folders',

    'use_memory_bank': False,
    'memory_bank_iter': 0,  # 1000

    'overfit_on_batch': False,
    'overfit_on_batch_iters': 10000,
})

In [3]:
""" Batch Sampler """

import torch
import numpy as np


class BatchSampler(torch.utils.data.sampler.BatchSampler):
    """Batch sampler that builds every batch as n categories x m products x l images.

    The dataset must expose:
      * ``targets`` - product label of every sample,
      * ``category_labels`` - category labels present in the dataset,
      * ``categories_to_products`` - mapping category label -> product labels.

    Each yielded batch holds ``n * m * l`` dataset indices. ``batch_size`` is only
    used to derive the number of batches per epoch (``len(dataset) // batch_size``).
    """

    def __init__(self, dataset, batch_size, n=4, m=4, l=4):
        self.n = n  # categories per batch
        self.m = m  # products per category
        self.l = l  # images per product
        self.dataset = dataset
        self.batch_size = batch_size

        self.category_labels = np.array(dataset.category_labels)
        self.unique_category_labels = list(set(self.category_labels))

        self.prod_labels = np.array(dataset.targets)
        self.unique_prod_labels = list(set(self.prod_labels))

        self.current_product_label_indices = self.get_all_labels_indices_with_current_label(
            unique_labels=self.unique_prod_labels,
            labels=self.prod_labels)
        # Look sample indices up by product *label*. Indexing the list above by
        # label is only right when labels happen to be exactly 0..N-1 in list order.
        self._indices_by_product = dict(zip(self.unique_prod_labels, self.current_product_label_indices))

    @staticmethod
    def get_all_labels_indices_with_current_label(unique_labels, labels):
        """Return, for each entry of ``unique_labels`` (same order), a shuffled array of
        the positions in ``labels`` that carry that label."""
        out = []
        # remember all sample indices of every unique product/category
        for c in unique_labels:
            prod_label_indices = np.where(labels == c)[0]
            np.random.shuffle(prod_label_indices)
            out.append(prod_label_indices)
        return out

    def __len__(self):
        return len(self.dataset) // self.batch_size

    def __iter__(self):
        for _ in range(len(self.dataset) // self.batch_size):
            assert len(self.unique_category_labels) >= self.n, 'not enough categories'
            # Draw actual category labels (not positions), so the lookup below also
            # works when category labels are not 0..K-1.
            chosen_categories = np.random.choice(self.unique_category_labels, self.n, replace=False)
            chosen_categories_products = [self.dataset.categories_to_products[category] for category in chosen_categories]

            chosen_products = []
            for category_products in chosen_categories_products:
                chosen_products.extend(np.random.choice(category_products, self.m, replace=False))

            chosen_ids = []
            for product in chosen_products:
                cur_product_ids = self._indices_by_product[product]
                assert len(cur_product_ids) >= self.l, f'not enough images, {cur_product_ids}, {chosen_categories}'
                chosen_ids.extend(np.random.choice(cur_product_ids, self.l, replace=False))

            assert len(chosen_ids) > 0
            yield chosen_ids

In [4]:
""" dataset """

from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
import pickle
from copy import deepcopy
from collections import Counter

# from data.BatchSampler import BatchSampler
# from configs.dataset_config import cfg as dataset_cfg
# from configs.train_config import cfg as train_config


def make_image_folder(dataset_type):
    """
    Build a torchvision ImageFolder with the matching augmentation and pickle it.

    'train' uses a random resized crop + horizontal flip; any other value uses
    resize + center crop. For 'train' the folder additionally gets:
      * ``imgs`` restricted to products that have at least
        ``dataset_cfg.nb_elems_needed_for_product`` images,
      * ``category_labels`` - category label of every kept image (parsed from the
        leading number of the directory two levels above the image file),
      * ``categories_to_products`` - dict category label -> unique product labels.

    The folder is written to ``image_folder_{dataset_type}`` in the current working
    directory and also returned.

    NOTE(review): only ``imgs`` is filtered; ``samples`` and ``targets`` keep every
    image, exactly as before - confirm that this is intended.
    """
    if dataset_type == 'train':
        transforms_ = transforms.Compose([
            transforms.RandomResizedCrop(dataset_cfg.sz_crop),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=dataset_cfg.mean, std=dataset_cfg.std)
        ])
    else:
        transforms_ = transforms.Compose([
            transforms.Resize(dataset_cfg.sz_resize),
            transforms.CenterCrop(dataset_cfg.sz_crop),
            transforms.ToTensor(),
            transforms.Normalize(mean=dataset_cfg.mean, std=dataset_cfg.std)
        ])
    # Root comes from the config instead of a second hard-coded copy of the same path.
    image_folder = datasets.ImageFolder(
        root=f'{dataset_cfg.dataset_new_folder}/{dataset_type}/{dataset_type}', transform=transforms_)

    if dataset_type == 'train':
        # keep only products that have enough images to fill a batch slot
        product_labels_counter = Counter(label for _, label in image_folder.imgs)
        # a set makes the membership test below O(1) instead of scanning a list per image
        product_labels_to_keep = {label for label, count in product_labels_counter.items()
                                  if count >= dataset_cfg.nb_elems_needed_for_product}
        # Filter with plain Python so entries stay (path, int label) tuples; going
        # through np.asarray would turn every entry into an array of strings.
        kept_imgs = [(path, label) for path, label in image_folder.imgs if label in product_labels_to_keep]
        image_folder.imgs = kept_imgs

        # get category labels and dict {category_1: list of category_1 products, ...}
        categories_to_products = {k: [] for k in range(dataset_cfg.nb_categories)}
        category_labels = []
        for path, product_label in kept_imgs:
            category_label = int(path.split('/')[-3].split('_')[0])
            category_labels.append(category_label)
            categories_to_products[category_label].append(product_label)
        image_folder.category_labels = category_labels
        image_folder.categories_to_products = {k: np.unique(v) for k, v in categories_to_products.items()}

    # stored so that training on Kaggle can load the prepared folder
    with open(f'image_folder_{dataset_type}', 'wb') as f:
        pickle.dump(image_folder, f)
    return image_folder


def get_dataloader(dataset_type):
    """Load the pickled image folder for ``dataset_type`` and wrap it in DataLoader(s).

    'train' returns a single DataLoader driven by ``BatchSampler``. Any other value
    returns ``(query_dataloader, retrieval_dataloader)``, each built from a copy of
    the folder restricted to the paths of the matching query/retrieval set.

    The folder is read with ``pickle.load`` from ``train_cfg.path_to_image_folders``;
    only point that setting at files you trust.
    """
    print(f'Getting {dataset_type} dataloader..')
    pickled_folder_path = train_cfg.path_to_image_folders + f'/image_folder_{dataset_type}'
    with open(pickled_folder_path, 'rb') as folder_file:
        image_folder = pickle.load(folder_file)

    if dataset_type == 'train':
        return DataLoader(image_folder, batch_sampler=BatchSampler(image_folder, train_cfg.batch_size))

    def build_subset_loader(subset):
        # independent copy of the folder that only serves the requested subset
        subset_folder = deepcopy(image_folder)
        subset_imgs = get_query_and_retrieval_sets(image_folder, dataset_type, type=subset)
        subset_folder.imgs = subset_imgs
        subset_folder.samples = subset_imgs
        return DataLoader(subset_folder, batch_size=train_cfg.batch_size)

    query_dataloader = build_subset_loader('query')
    print(f'got query dataloader, {len(query_dataloader)}')

    retrieval_dataloader = build_subset_loader('retrieval')
    print(f'got retrieval dataloader, {len(retrieval_dataloader)}')
    return (query_dataloader, retrieval_dataloader)

In [5]:
""" triplet loss """


import torch
import torch.nn as nn
import numpy as np
# from configs.train_config import cfg as train_cfg


class TripletLoss(nn.Module):
    """Triplet margin loss with per-anchor mining of margin-violating triplets.

    Every row of ``inputs_col`` is an anchor; positives and negatives are taken
    from ``inputs_row`` by comparing ``targets_col`` with ``targets_row``.
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    @staticmethod
    def get_distance(x1, x2):
        """Pairwise Euclidean distances between the rows of ``x1`` (N, D) and ``x2`` (M, D).

        Returns an (N, M) tensor. Squared distances are clamped to at least 1e-4
        before the square root, so the smallest returned distance is 0.01.
        """
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, with the norms taken from each
        # input separately so the result is also right when x1 and x2 differ.
        sq_norm_1 = x1.pow(2).sum(dim=1, keepdim=True)       # (N, 1)
        sq_norm_2 = x2.pow(2).sum(dim=1, keepdim=True).t()   # (1, M)
        dist_sq = (sq_norm_1 + sq_norm_2 - 2 * torch.mm(x1, x2.t())).clamp(min=0)
        return dist_sq.clamp(min=1e-4).sqrt()

    def forward(self, inputs_col, targets_col, inputs_row, targets_row):
        """Compute the mean triplet loss.

        Args:
            inputs_col: (N, D) anchor embeddings.
            targets_col: N anchor labels (tensor or sequence).
            inputs_row: (M, D) candidate embeddings.
            targets_row: M candidate labels (tensor or sequence).

        Returns:
            Scalar tensor. For each anchor, all (positive, negative) pairs with
            ``d_pos + margin > d_neg`` are used; if an anchor has none, each of
            its positives is paired with a randomly drawn negative instead.
            Anchors without a positive or without a negative are skipped, and if
            no triplet exists at all the result is a zero that is still attached
            to the graph.

        When N == M the diagonal is treated as self-pairs and excluded; for
        N != M (e.g. comparing against a memory bank) every pair is eligible.
        """
        # as_tensor avoids the copy-construct warning torch.tensor() emits for tensors;
        # masks are built on the CPU because they are only used to derive indices
        targets_col = torch.as_tensor(targets_col).detach().cpu().reshape(-1)
        targets_row = torch.as_tensor(targets_row).detach().cpu().reshape(-1)

        dists = self.get_distance(inputs_col, inputs_row)  # (N, M)
        n_col, n_row = dists.size(0), dists.size(1)

        # same_label[i, j] is True when anchor i and candidate j share a label
        same_label = torch.eq(targets_col.view(-1, 1), targets_row.view(1, -1))
        eligible = torch.ones_like(same_label, dtype=torch.bool)
        if n_col == n_row:
            eligible = ~torch.eye(n_col, dtype=torch.bool)
        positives_mask = same_label & eligible
        negatives_mask = ~same_label & eligible

        losses_ = []
        for i in range(n_col):
            pos_ids_ = positives_mask[i].nonzero().reshape(-1).numpy()
            neg_ids_ = negatives_mask[i].nonzero().reshape(-1).numpy()
            if len(pos_ids_) == 0 or len(neg_ids_) == 0:
                # no triplet can be formed for this anchor
                continue

            pos_dists = dists[i, pos_ids_]
            neg_dists = dists[i, neg_ids_]

            # (P, K) grid of all positive/negative combinations that violate the margin;
            # nonzero() is kept 2-D so a single violating triplet is used as-is
            violating = (pos_dists.view(-1, 1) + self.margin > neg_dists.view(1, -1)).nonzero()
            if len(violating) > 0:
                pos_dists_final = pos_dists[violating[:, 0]]
                neg_dists_final = neg_dists[violating[:, 1]]
            else:
                # nothing violates the margin: pair every positive with a random negative
                neg_ids = np.random.choice(len(neg_dists), len(pos_dists))
                pos_dists_final = pos_dists
                neg_dists_final = neg_dists[neg_ids]

            losses_.append(torch.relu(pos_dists_final - neg_dists_final + self.margin))

        if not losses_:
            # keep the result differentiable even though there is nothing to optimise
            return dists.sum() * 0.0
        return torch.cat(losses_).mean()

In [6]:
""" memory bank """


import torch
import numpy as np


class MemoryBank(object):
    """Fixed-size store of past embeddings and their product labels.

    Entries are written front to back. When a batch no longer fits, it is written
    over the last slots and the write position wraps to the start.
    """

    def __init__(self, embeddings_size, size):
        self.embeddings = torch.zeros(size, embeddings_size)  # .cuda()
        self.product_labels = torch.zeros(size)  # .cuda()
        self.size = size
        self.cur_size = 0
        # Explicit flag: a label value of 0 is a valid product label and therefore
        # cannot serve as the "slot is still empty" marker.
        self._is_full = False

    def update(self, embeddings, product_labels):
        """Store a batch of embeddings (B, embeddings_size) and their B labels."""
        # detach so the bank never keeps autograd graphs of earlier steps alive
        embeddings = embeddings.detach()
        product_labels = torch.as_tensor(product_labels).detach()
        if len(product_labels) > self.size:
            # a batch larger than the bank: only its most recent entries fit
            embeddings = embeddings[-self.size:]
            product_labels = product_labels[-self.size:]

        q_size = len(product_labels)
        if self.cur_size + q_size > self.size:
            self.embeddings[-q_size:] = embeddings
            self.product_labels[-q_size:] = product_labels
            self.cur_size = 0
            self._is_full = True
        else:
            self.embeddings[self.cur_size: self.cur_size + q_size] = embeddings
            self.product_labels[self.cur_size: self.cur_size + q_size] = product_labels
            self.cur_size += q_size
            if self.cur_size == self.size:
                self._is_full = True

    def get_embeddings(self):
        """Return ``(embeddings, product_labels)`` for every slot written so far."""
        if self._is_full:
            return self.embeddings, self.product_labels
        return self.embeddings[:self.cur_size], self.product_labels[:self.cur_size]

In [7]:
""" resnet """


from torch.nn import AvgPool2d, Dropout, Linear
import torch.nn.functional as f
import numpy as np
import torchvision
import torch


def get_resnet50(pretrained=True):
    """Build a torchvision ResNet-50 prepared for use as an embedding backbone.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``; defaults to
            True, which is what this function always used before.

    Returns:
        The model with two extra attributes: ``features`` (the convolutional
        stages conv1 .. layer4 as one Sequential) and ``sz_features_output``
        (2048). Every BatchNorm2d layer is put into eval mode and kept there.
    """
    resnet50 = torchvision.models.resnet50(pretrained=pretrained)
    resnet50.features = torch.nn.Sequential(resnet50.conv1, resnet50.bn1, resnet50.relu, resnet50.maxpool,
                                            resnet50.layer1,
                                            resnet50.layer2, resnet50.layer3, resnet50.layer4)
    resnet50.sz_features_output = 2048
    for module in resnet50.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.eval()
            # Freeze the layer in eval mode: later model.train()/model.eval() calls
            # become no-ops for it. The override accepts the same optional `mode`
            # argument and returns the module, like nn.Module.train does; `_module`
            # is bound as a default to capture this iteration's layer.
            module.train = lambda mode=True, _module=module: _module
    return resnet50


def get_params_dict(model, emb_module_name):
    """Split the model's parameters into a 'backbone' group and one group per name
    in ``emb_module_name``.

    A parameter belongs to the group named like the first component of its dotted
    parameter name if that component is listed in ``emb_module_name``; everything
    else goes to 'backbone'. Asserts that every parameter was assigned.
    """
    groups = {'backbone': []}
    for head_name in emb_module_name:
        groups[head_name] = []

    for full_name, parameter in model.named_parameters():
        top_level_module = full_name.split('.')[0]
        group_key = top_level_module if top_level_module in emb_module_name else 'backbone'
        groups[group_key].append(parameter)

    nb_total = len(list(model.parameters()))
    nb_dict_params = sum(len(group) for group in groups.values())
    assert nb_total == nb_dict_params
    return groups


def get_embedding(model, cfg):
    """Attach an embedding head to ``model`` in place and replace its forward pass.

    Adds ``features_pooling``, ``features_dropout``, ``embedding`` (a Linear layer
    from ``model.sz_features_output`` to ``cfg.embedding_dim``) and
    ``parameters_dict``. The new forward runs features -> pooling -> flatten ->
    embedding -> L2 normalisation. Returns nothing.
    """
    model.features_pooling = AvgPool2d(7, stride=1, padding=0, ceil_mode=True, count_include_pad=True)
    # NOTE(review): this dropout layer is attached but not applied in forward() below
    model.features_dropout = Dropout(0.01)

    # seed both generators before the Linear head draws its initial weights
    torch.random.manual_seed(cfg.random_seed)
    np.random.seed(cfg.random_seed)

    # place the head on the same device as the first backbone parameter
    backbone_device = list(model.parameters())[0].device
    head = Linear(model.sz_features_output, cfg.embedding_dim)
    model.embedding = head.to(backbone_device)
    model.parameters_dict = get_params_dict(model=model, emb_module_name=['embedding'])

    def forward(x):
        pooled = model.features_pooling(model.features(x))
        flattened = pooled.view(pooled.size(0), -1)
        return f.normalize(model.embedding(flattened), p=2, dim=1)

    model.forward = forward


def get_model(cfg):
    """Return a ResNet-50 backbone with the embedding head configured by ``cfg`` attached."""
    model = get_resnet50()
    # get_embedding modifies the model in place and returns nothing
    get_embedding(model, cfg)
    return model


In [8]:
""" data utils """


import pickle

# from configs.dataset_config import cfg as dataset_cfg


def get_query_and_retrieval_sets(image_folder, dataset_type, type):
    """Select the samples of ``image_folder`` that belong to the ``type`` set.

    The pickled split file (validation path for ``dataset_type == 'valid'``,
    test path otherwise) is indexed with ``type``; a sample is kept when the last
    three components of its path are contained in that entry.

    Returns a list of ``(path, int label)`` tuples.

    The split file is read with ``pickle.load``; only use files you trust.
    """
    if dataset_type == 'valid':
        sets_path = dataset_cfg.valid_query_retrieval_sets_path
    else:
        sets_path = dataset_cfg.test_query_retrieval_sets_path

    with open(sets_path, 'rb') as sets_file:
        wanted_paths = pickle.load(sets_file)[type]

    selected = []
    for sample in image_folder.imgs:
        parts = sample[0].split("/")
        # the split file is matched on the last three path components
        relative_path = '/'.join(parts[-3:])
        if relative_path in wanted_paths:
            selected.append(('/'.join(parts[:-3]) + "/" + relative_path, int(sample[1])))

    return selected


In [9]:
""" debug utils """


import torch
import numpy as np

# from configs.train_config import cfg as train_cfg
# from utils.eval_utils import get_nearest_neighbors


def overfit_on_batch(dl, criterion, optimizer, model):
    """Debug helper: repeatedly train on one batch and print loss and top-1 accuracy.

    Args:
        dl: an *iterator* over batches; only its first batch is consumed.
        criterion: loss called as ``criterion(emb, labels, emb, labels)``.
        optimizer: optimizer over the model's parameters.
        model: embedding model; it is moved to the GPU.

    Runs ``train_cfg.overfit_on_batch_iters`` steps. After each step, top-1 is the
    share of samples whose nearest neighbour among the *other* samples of the
    batch has the same label.
    """
    batch = next(dl)
    images, labels = batch[0], batch[1]
    model, images = model.cuda(), images.cuda()
    labels_np = np.asarray([l.item() for l in labels])
    sample_ids = np.arange(len(labels_np))

    for iter_ in range(train_cfg.overfit_on_batch_iters):
        optimizer.zero_grad()
        embeddings = model(images)
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        loss = criterion(embeddings, labels, embeddings, labels)

        assert not torch.isnan(loss).any(), 'loss is nan'
        loss.backward()
        optimizer.step()

        embeddings_np = embeddings.detach().cpu().numpy()
        top_1s = []
        for i in sample_ids:
            # leave sample i out of the retrieval set, then map the neighbour's
            # position back to its label before comparing
            other_ids = np.delete(sample_ids, i)
            neighbor_pos = get_nearest_neighbors(embeddings_np[other_ids], embeddings_np[i:i + 1], k=1)[0][0]
            top_1s.append(1 if labels_np[other_ids[neighbor_pos]] == labels_np[i] else 0)

        print(f'iter: {iter_}, loss: {loss.item()}, top 1: {np.mean(top_1s)}')

In [10]:
import pandas as pd
import pickle

In [11]:
""" eval utils """

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors


def get_top_k(product_labels, neighbours_ids):
    """Return the percentage of positions where the two label sequences agree.

    Both arguments are converted to arrays first, so plain lists are compared
    element-wise rather than as whole objects.
    """
    product_labels = np.asarray(product_labels)
    neighbours_ids = np.asarray(neighbours_ids)
    return np.mean(product_labels == neighbours_ids) * 100


def get_nearest_neighbors(retrieval_embeddings, query_embeddings, k=1):
    """Return, for every query embedding, the indices of its ``k`` nearest retrieval
    embeddings as an array of shape (n_queries, k)."""
    knn_index = NearestNeighbors(n_neighbors=k).fit(retrieval_embeddings)
    # only the neighbour indices are needed, not the distances
    return knn_index.kneighbors(query_embeddings, return_distance=False)


def evaluate(model, query_dl, retrieval_dl):
    """Return the top-1 retrieval accuracy in percent.

    A query counts as correct when its nearest retrieval embedding carries the
    same label as the query.
    """
    query_embeddings, query_labels = compute_embeddings(model, query_dl)
    retrieval_embeddings, retrieval_labels = compute_embeddings(model, retrieval_dl)

    neighbor_ids = get_nearest_neighbors(retrieval_embeddings, query_embeddings)
    # label of the single nearest retrieval sample of every query
    predicted_labels = np.array([retrieval_labels[row[0]] for row in neighbor_ids])
    return get_top_k(query_labels, predicted_labels)


def compute_embeddings(model, dl):
    """Run ``model`` over every batch of ``dl`` and collect embeddings and labels.

    Args:
        model: embedding model; moved to the GPU when one is available,
            otherwise it runs on the CPU.
        dl: sized iterable of batches whose first two items are inputs and labels.

    Returns:
        ``(all_embeddings, all_labels)`` - two lists with one NumPy entry per sample.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print('Computing embeddings..')
    dl_len = len(dl)
    print(f'len: {dl_len}')
    all_embeddings, all_labels = [], []
    # inference only: do not record an autograd graph for every batch
    with torch.no_grad():
        for i, batch in enumerate(dl):
            if i % 50 == 0:
                print(f'iter: {i}/{dl_len}')
            x, y = batch[0], batch[1]
            embeddings = model(x.to(device))
            all_labels.extend(y.cpu().numpy())
            all_embeddings.extend(embeddings.cpu().numpy())
    return all_embeddings, all_labels


def visualize_embeddings(embeddings_writer, model, dataloader, num_batches=5, tag=''):
    """Write L2-normalised embeddings of the first batches to ``embeddings_writer``.

    Args:
        embeddings_writer: object with an ``add_embedding(mat, metadata=..., tag=...)``
            method (e.g. a tensorboard SummaryWriter).
        model: embedding model; inputs are moved to the device of its parameters.
        dataloader: iterable of batches whose first two items are images and labels.
        num_batches: upper bound on the number of batches used; a shorter
            dataloader simply contributes fewer batches.
        tag: tag passed through to ``add_embedding``.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    all_embeddings, all_labels = [], []
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= num_batches:
                break
            images, labels = batch[0], batch[1]
            embeddings = model(images.to(device))
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            # bring results back to the CPU before converting to NumPy
            all_embeddings.append(embeddings.cpu().numpy())
            all_labels.extend(label.item() for label in labels)

    if not all_embeddings:
        return

    embeddings_writer.add_embedding(np.concatenate(all_embeddings, axis=0),
                                    metadata=np.asarray(all_labels), tag=tag)

In [12]:
""" train utils """

import torch

# from configs.train_config import cfg as train_cfg
# from losses.triplet_loss import TripletLoss


def get_optimizer(model):
    """Create an Adam optimizer over all model parameters, using the learning rate
    and weight decay from ``train_cfg``."""
    param_group = {
        'params': model.parameters(),
        'lr': train_cfg.lr,
        'weight_decay': train_cfg.weight_decay,
    }
    return torch.optim.Adam([param_group])


def get_criterion():
    """Create the triplet loss with the margin from ``train_cfg`` and move it to the GPU."""
    triplet_loss = TripletLoss(train_cfg.margin)
    return triplet_loss.cuda()


def make_training_step(batch, criterion, optimizer, model, global_step, memory_bank):
    """Run one optimisation step on ``batch`` and return the loss as a float.

    Args:
        batch: sequence whose first two items are images and labels.
        criterion: loss called as ``criterion(emb_col, labels_col, emb_row, labels_row)``.
        optimizer: optimizer over the model's parameters.
        model: embedding model; it and the images are moved to the GPU.
        global_step: step counter compared with ``train_cfg.memory_bank_iter``.
        memory_bank: MemoryBank instance or None.

    When ``train_cfg.use_memory_bank`` is set and ``global_step`` is past
    ``train_cfg.memory_bank_iter``, the batch is additionally compared with the
    embeddings already stored in the bank, and is then stored in the bank itself.
    """
    optimizer.zero_grad()
    images, labels = batch[0], batch[1]
    model, images = model.cuda(), images.cuda()
    embeddings = model(images)
    loss = criterion(embeddings, labels, embeddings, labels)

    use_bank = train_cfg.use_memory_bank and memory_bank is not None
    if use_bank and global_step > train_cfg.memory_bank_iter:
        mb_embeddings, mb_labels = memory_bank.get_embeddings()
        if len(mb_labels) > 0:
            print('sampling from bank')
            # the bank may live on another device than the current embeddings
            mb_loss = criterion(embeddings, labels, mb_embeddings.to(embeddings.device), mb_labels)
            loss = loss + mb_loss
        # Store the current batch only after it has been compared with the bank,
        # so it is never paired with itself. Detached: the bank must not keep
        # this step's autograd graph alive.
        memory_bank.update(embeddings.detach().cpu(), torch.as_tensor(labels).detach().cpu())

    assert not torch.isnan(loss).any(), 'loss is nan'
    loss.backward()
    optimizer.step()
    return loss.item()


In [None]:
""" main """

import torch
import time
import numpy as np
from tensorboardX import SummaryWriter
import tarfile
import warnings

# from data.dataset import get_dataloader
# from models.resnet import get_model
# from utils.train_utils import get_optimizer, get_criterion, make_training_step
# from utils.eval_utils import visualize_embeddings
# from utils.eval_utils import evaluate
# from utils.debug_utils import overfit_on_batch
# from configs.train_config import cfg as train_cfg
# from configs.eval_config import cfg as eval_cfg
# from configs.dataset_config import cfg as dataset_cfg
# from configs.model_config import cfg as model_cfg
# from models.memory_bank import MemoryBank
warnings.filterwarnings('ignore', category=UserWarning)


def train():
    """Train the embedding model and evaluate it on the test split after every epoch.

    Steps: build dataloaders, model, optimizer and loss; optionally resume from a
    checkpoint; optionally evaluate before training; optionally run the
    overfit-on-one-batch debug loop; then for every epoch train, optionally save a
    checkpoint, evaluate top-1 accuracy and optionally write embeddings for
    tensorboard. Metrics go to MLflow when ``train_cfg.log_to_mlflow`` is set.
    """
    train_dataloader = get_dataloader(dataset_type='train')
    query_test_dataloader, retrieval_test_dataloader = get_dataloader(dataset_type='test')

    model = get_model(model_cfg)
    optimizer = get_optimizer(model)
    criterion = get_criterion()

    if train_cfg.use_memory_bank:
        # bank width follows the configured embedding size instead of a literal
        memory_bank = MemoryBank(embeddings_size=model_cfg.embedding_dim, size=len(train_dataloader.dataset))
    else:
        memory_bank = None

    # global_step is the index of the last completed optimisation step (-1: none yet)
    start_epoch, global_step = 0, -1

    # loading saved checkpoints if needed
    if train_cfg.continue_training_from_epoch:
        try:
            checkpoint = torch.load(train_cfg.checkpoints_dir + f'checkpoint_{train_cfg.checkpoint_from_epoch}.pth')
            model.load_state_dict(checkpoint['model'])
            start_epoch = checkpoint['epoch'] + 1
            # the checkpoint stores the last completed step; the loop below
            # increments before logging, so it is restored unchanged
            global_step = checkpoint['global_step']
            optimizer.load_state_dict(checkpoint['opt'])
        except FileNotFoundError:
            print('Checkpoint not found')

    # evaluate before training if needed
    if eval_cfg.compute_metrics_before_training:
        model.eval()

        print(f'Evaluating on test data..')
        with torch.no_grad():
            top_1_test_accuracy = evaluate(model, query_test_dataloader, retrieval_test_dataloader)
        print(f'Top-1 test accuracy: {top_1_test_accuracy}')

        # visualize embeddings with tensorboard
        if eval_cfg.visualize_embeddings:
            embeddings_writer = SummaryWriter(log_dir=train_cfg.tensorboard_dir + f'/embeddings_vis/epoch_-1')
            try:
                visualize_embeddings(embeddings_writer, model, train_dataloader, num_batches=5,
                                     tag='training_batch_before_training')
            finally:
                embeddings_writer.close()
        model.train()

    if train_cfg.overfit_on_batch:
        dl = iter(train_dataloader)
        overfit_on_batch(dl, criterion, optimizer, model)

    # main loop
    for e in range(start_epoch, train_cfg.nb_epochs):
        print(f'Epoch: {e}/{train_cfg.nb_epochs}')
        epoch_start_time = time.time()

        model.train()
        print('Starting training..')
        loss_list = []
        len_ = len(train_dataloader)

        for i, batch in enumerate(train_dataloader):
            loss = make_training_step(batch, criterion, optimizer, model, global_step, memory_bank)
            loss_list.append(loss)

            # count the step first so the very first loss is logged at step 0, not -1
            global_step += 1
            if train_cfg.log_to_mlflow:
                mlflow.log_metric('loss', loss, global_step)

            if global_step % 50 == 0:
                if global_step != 0:
                    loss_mean = np.mean(loss_list[-50:])
                else:
                    loss_mean = loss
                print(f'global step: {global_step}/{len_}, loss: {loss_mean}')

        # save checkpoints
        if train_cfg.save_model:
            print('Saving current model...')
            state = {
                'model': model.state_dict(),
                'epoch': e,
                'global_step': global_step,
                'opt': optimizer.state_dict()
                }
            torch.save(state, (train_cfg.checkpoints_dir + f'checkpoint_{e}.pth'))

        # evaluate model
        model.eval()
        print(f'Evaluating on test data..')
        with torch.no_grad():
            top_1_test_accuracy = evaluate(model, query_test_dataloader, retrieval_test_dataloader)
        print(f'Top-1 test accuracy: {top_1_test_accuracy}')
        if train_cfg.log_to_mlflow:
            mlflow.log_metric('top-1 accuracy', top_1_test_accuracy, e)

        # visualize embeddings with tensorboard (still in eval mode); the writer is
        # only created when it is used and is closed right after
        if eval_cfg.visualize_embeddings:
            embeddings_writer = SummaryWriter(log_dir=train_cfg.tensorboard_dir + f'/embeddings_vis/epoch_{e}')
            try:
                visualize_embeddings(embeddings_writer, model, train_dataloader, num_batches=5,
                                     tag='training_batch_during_training')
            finally:
                embeddings_writer.close()
        model.train()

        print(f'epoch training time: {round((time.time() - epoch_start_time) / 60, 3)} min')


if __name__ == '__main__':
    # Seed PyTorch and NumPy before any sampling or weight initialisation happens.
    # (Selecting the GPU explicitly via torch.cuda.set_device(0) is left disabled.)
    torch.manual_seed(0)
    np.random.seed(0)

    # Wall-clock time of the whole run, reported in minutes.
    total_training_start_time = time.time()
    train()
    print(f'Training time: {round((time.time() - total_training_start_time) / 60, 3)} min')

Getting train dataloader..
Getting test dataloader..
got query dataloader, 72
got retrieval dataloader, 305


Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

Epoch: 0/100
Starting training..
global step: 0/1124, loss: 0.15906189382076263
global step: 50/1124, loss: 0.16462813913822175
global step: 100/1124, loss: 0.16671728640794753
global step: 150/1124, loss: 0.16550217539072037
global step: 200/1124, loss: 0.16345396831631662
global step: 250/1124, loss: 0.1599628384411335
global step: 300/1124, loss: 0.16209630742669107
global step: 350/1124, loss: 0.16195931315422057
global step: 400/1124, loss: 0.16366865307092668
global step: 450/1124, loss: 0.16001058861613274
global step: 500/1124, loss: 0.1614264416694641
global step: 550/1124, loss: 0.16144887536764144


In [None]:
# mlflow.log_metric('top-1 accuracy', 1, 1)