In [1]:
import matplotlib.pyplot as plt
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.resnet import resnet50
from my_dataset import MyDataset
from my_lossfunc import JointLoss, MultilabelLoss, DiscriminativeLoss
from my_transform import data_transforms
from PIL import Image
from scipy.spatial.distance import pdist, cdist  # 一集合点距, 两集合点距
from torch.utils.data import DataLoader
from tqdm import tnrange
from tqdm import tqdm_notebook as tqdm
from utils import *

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', DEVICE)


BASE = '/home/zengrui/datasets'
DUKE_DIR_TRAIN = f'{BASE}/ReID_Duke/bounding_box_train'
DUKE_DIR_TEST = f'{BASE}/ReID_Duke/bounding_box_test'
DUKE_DIR_QUERY = f'{BASE}/ReID_Duke/query'
DUKE_IMG_AMOUNT = 16522
DUKE_ID_AMOUNT = 702
MARKET_DIR_TRAIN = f'{BASE}/ReID_Market/bounding_box_train'
MARKET_DIR_TEST = f'{BASE}/ReID_Market/bounding_box_test'
MARKET_DIR_QUERY = f'{BASE}/ReID_Market/query'
MARKET_IMG_AMOUNT = 12936
MARKET_ID_AMOUNT = 751

SOURCE_DIR_TRAIN = DUKE_DIR_TRAIN
TARGET_DIR_TRAIN = MARKET_DIR_TRAIN
TARGET_DIR_GALLERY = MARKET_DIR_TEST
TARGET_DIR_PROBE = MARKET_DIR_QUERY
SOURCE_ID_AMOUNT = DUKE_ID_AMOUNT
TARGET_IMG_AMOUNT = MARKET_IMG_AMOUNT
ML_PATH = 'data/ml_Market.dat'
PRETRAIN_PATH = 'data/pretrained_weight.pkl'
PRETRAIN_OUT_PATH = 'data/pretrained_weight_{}.pkl'

BATCH_SIZE = 168
EPOCH = 80
LR = 0.1

BETA = 0.2
LAMB1 = 2e-4
LAMB2 = 50
MARGIN = 1
SCALA_CE = 30

Device: cuda


# Prepare Work

In [2]:
# data loader
data_loader = {
    'source': DataLoader(
        dataset=MyDataset(SOURCE_DIR_TRAIN, 
                          transform=data_transforms('train'),
                          require_view=False,
                          encode_label=True),
        batch_size=BATCH_SIZE,
        shuffle=True,
    ),
    'target': DataLoader(
        dataset=MyDataset(TARGET_DIR_TRAIN,
                          transform=data_transforms('train'),
                          require_view=True,
                          encode_label=True),
        batch_size=BATCH_SIZE,
        shuffle=True,
    ),
    'gallery': DataLoader(
        dataset=MyDataset(TARGET_DIR_GALLERY,
                          transform=data_transforms('val'),
                          require_view=True),
        batch_size=BATCH_SIZE,
        shuffle=False,
    ),
    'probe': DataLoader(
        dataset=MyDataset(TARGET_DIR_PROBE,
                          transform=data_transforms('val'),
                          require_view=True),
        batch_size=BATCH_SIZE,
        shuffle=False,
    ),
}
print('data_loader: ok.')

data_loader: ok.


# Trainer

In [3]:
%matplotlib inline
class Trainer(object):
    def __init__(self):
        
        # 网络
        self.net = resnet50(pretrained=False, 
                            num_classes=SOURCE_ID_AMOUNT)
        self.net = nn.DataParallel(self.net).to(DEVICE)
        if PRETRAIN_PATH is not None and os.path.exists(PRETRAIN_PATH):
            self.net.load_state_dict(torch.load(PRETRAIN_PATH))
            print('Pretrained model loaded.')
        else:
            print('Pretrained model not found. Train from scratch.')
            
        # 损失
        self.mdl_loss = DiscriminativeLoss(0.001).to(DEVICE)
        self.al_loss = nn.CrossEntropyLoss().to(DEVICE)
        self.rj_loss = JointLoss(MARGIN).to(DEVICE)  # lack 1 param
        self.cml_loss = MultilabelLoss(BATCH_SIZE).to(DEVICE)
        
        # 优化器
        self.optimizer = torch.optim.SGD(
            self.net.parameters(), lr=LR, momentum=0.9)
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[int(EPOCH / 8 * 5), int(EPOCH / 8 * 7)])
        
        # 存储器
        self.ml_mem = torch.zeros(TARGET_IMG_AMOUNT, SOURCE_ID_AMOUNT)
        self.inited = self.ml_mem.sum(dim=1) != 0
    
    def train(self):
        '''进行一次完整训练.'''
        print('Training start. Epochs: %d' % EPOCH)
        self.net.train()
        for epoch in tnrange(EPOCH):
            self.train_epoch(epoch)
    
    def train_epoch(self, epoch):
        '''训练一个epoch.'''
        stats = ('total', 'src', 'st', 'ml', 'tgt')
        running_loss = {stat: AverageMeter() for stat in stats}
        
        if not self.mdl_loss.initialized:
            self.init_losses(data_loader['target'])
            self.net.train()
        
        with tqdm(total=len(data_loader['source'])) as pbar:
            tgt_iter = iter(data_loader['target'])
            for step, (ax, ay) in enumerate(data_loader['source']):
                # a - source, b - target
                ax = ax.to(DEVICE)
                ay = ay.to(DEVICE)
                try:
                    b = next(tgt_iter)
                except StopIteration:
                    tgt_iter = iter(data_loader['target'])
                    b = next(tgt_iter)
                (bx, by, b_view, b_idx) = b
                bx, by, b_view = bx.to(DEVICE), by.to(DEVICE), b_view.to(DEVICE)

                a_f, a_sim, _ = self.net(ax)
                b_f, b_sim, _ = self.net(bx)

                loss = {stat: torch.Tensor([0]).to(DEVICE) 
                        for stat in stats}
                
                loss['src'] = self.al_loss(a_sim * SCALA_CE, ay)  # 有监督 交叉熵
                
                agents = self.net.module.fc.weight.renorm(2, 0, 1e-5).mul(1e5)  # 归一化 shape=(702, 2048)
                loss['st'] = self.rj_loss(agents.detach(), a_f, a_sim.detach(), ay, 
                                          b_f, b_sim.detach())
                
                with torch.no_grad():
                    ml = F.softmax(b_f.mm(agents.t_() * SCALA_CE), dim=1)  # t_(): 转置并inplace
                loss['ml'] = self.cml_loss(torch.log(ml), b_view)
    
                if epoch > 0:  # 为什么第一轮不算 mdl_loss 呢
                    ml_cpu = ml.detach().cpu()
                    is_inited_batch = self.inited[b_idx]
                    inited_idx = b_idx[is_inited_batch]
                    uninited_idx = b_idx[~is_inited_batch]
                    self.ml_mem[uninited_idx] = ml_cpu[~is_inited_batch]  # 0标签满更新
                    self.inited[uninited_idx] = True
                    self.ml_mem[inited_idx] = 0.9 * self.ml_mem[inited_idx] \
                                            + 0.1 * ml_cpu[is_inited_batch]  # 非空标签小更新
                    loss['tgt'] = self.mdl_loss(b_f, self.ml_mem[b_idx], by)

                self.optimizer.zero_grad()
                loss['total'] = loss['tgt'] + LAMB1 * loss['ml'] \
                              + LAMB2 * (loss['src'] + BETA * loss['st'])
                loss['total'].backward()
                self.optimizer.step()

                for stat in stats:
                    loss_cpu = float(loss[stat].data.cpu().numpy())
                    running_loss[stat].update(loss_cpu)
                pbar.set_description('Loss: %.4f' % running_loss['total'].avg)
                pbar.update()

            self.lr_scheduler.step()
            pbar.set_description('Progress:')
            print('Epoch: %d, Loss: %.4f (%.4f + %.4f + %.4f + %.4f)' 
                  % (epoch, 
                     running_loss['total'].avg, 
                     running_loss['src'].avg * LAMB2, 
                     running_loss['st'].avg * LAMB2 * BETA, 
                     running_loss['ml'].avg * LAMB1, 
                     running_loss['tgt'].avg))
            
    def eval_performance(self, gallery_loader, probe_loader):
        stats = ('r1', 'r5', 'r10', 'MAP')
        val = {stat: AverageMeter() for stat in stats}
        self.net.eval()
        
        gallery_f, gallery_y, gallery_views = extract_features(
            gallery_loader, self.net, index_feature=0)
        probe_f, probe_y, probe_views = extract_features(
            probe_loader, self.net, index_feature=0)
        dist = cdist(gallery_f, probe_f, metric='cosine')  # 实际是 1-cos ∈ [0, 2], 越小越相似
        CMC, MAP, example = eval_cmc_map(
            dist, gallery_y, probe_y, gallery_views, probe_views, 
            ignore_MAP=False, show_example=True)
#         CMC, MAP, example = eval_cmc_map(
#             dist, gallery_y, probe_y, 
#             ignore_MAP=False, show_example=True)
        r1, r5, r10 = CMC[0], CMC[4], CMC[9]
        self.r1, self.r5, self.r10, self.MAP = r1, r5, r10, MAP
        
        for stat in stats:
            val[stat].update(locals()[stat].item(), BATCH_SIZE)
            
        # 显示rank多图
#         if show_img_result:
#             plt.subplot(1, 11, 1)
#             plt.title('Query')
#             plt.imshow(Image.open(file.path, 'r'))

#             for i in range(10):
#                 plt.subplot(1, 11, i + 2)
#                 plt.imshow(Image.open(DATA_DIR_TEST +
#                                       '\\' + sort_list[i][0], 'r'))

#             plt.show()
            
        return val
            
    def init_losses(self, tgt_loader):
        '''训练前初始化loss参数.'''
        print('#' * 8, 'Initializing losses', '#' * 8)
        if os.path.isfile(ML_PATH):
            (ml, views, pairwise_agreements) = torch.load(ML_PATH)
            print('Ml loaded.')
        else:
            print('Ml not found, computing...')
            sim, _, views = extract_features(
                data_loader['target'], self.net, index_feature=1, return_numpy=False)
            ml = F.softmax(sim * SCALA_CE, dim=1)
            ml_np = ml.cpu().numpy()
            pairwise_agreements = 1 - pdist(ml_np, 'minkowski', p=1) / 2  # 相似比较特征
            print('Ml saving to %s...' % ML_PATH)
            torch.save((ml, views, pairwise_agreements), ML_PATH)

        self.cml_loss.init_centers(torch.log(ml), views)
        print('Cml_loss centers inited.')
        self.mdl_loss.init_threshold(pairwise_agreements)
        print('Mdl_loss threshold inited.')
        print('#' * 8, 'OK', '#' * 8)
        
    def save_model(self, cover=False):
        '''
        保存当前模型net的参数.
        
        :param cover: True覆盖默认文件, False新增带时间戳文件
        '''
        if cover:
            torch.save(self.net.state_dict(), PRETRAIN_PATH)
        else:
            torch.save(self.net.state_dict(), PRETRAIN_OUT_PATH.format(time.time()))
        print('Model weight saved.')

# Train

In [4]:
trainer = Trainer()
trainer.train()

Pretrained model loaded.
Training start. Epochs: 80


HBox(children=(IntProgress(value=0, max=80), HTML(value='')))

######## Initializing losses ########
Ml loaded.
Cml_loss centers inited.
Mdl_loss threshold inited.
######## OK ########


HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 0, Loss: 10.6842 (3.3365 + 6.2431 + 1.1046 + 0.0000)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 1, Loss: 10.6556 (3.6345 + 6.2713 + 0.1417 + 0.6080)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 2, Loss: 10.3961 (3.4397 + 6.2342 + 0.1068 + 0.6154)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 3, Loss: 10.5488 (3.6624 + 6.1746 + 0.1060 + 0.6058)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 4, Loss: 10.2710 (3.3356 + 6.2030 + 0.1071 + 0.6252)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 5, Loss: 10.1738 (3.3412 + 6.0904 + 0.1057 + 0.6365)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 6, Loss: 10.0553 (3.2058 + 6.1078 + 0.1047 + 0.6371)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 7, Loss: 10.0538 (3.2937 + 6.0298 + 0.1046 + 0.6257)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 8, Loss: 10.1289 (3.3386 + 6.0359 + 0.1063 + 0.6481)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 9, Loss: 9.7106 (2.9927 + 5.9714 + 0.1039 + 0.6426)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 10, Loss: 9.7775 (3.0708 + 5.9710 + 0.1036 + 0.6321)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 11, Loss: 9.8388 (3.2023 + 5.8830 + 0.1056 + 0.6479)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 12, Loss: 9.8276 (3.1011 + 5.9872 + 0.1043 + 0.6350)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 13, Loss: 9.7024 (3.0172 + 5.9475 + 0.1058 + 0.6320)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 14, Loss: 9.9612 (3.2571 + 5.9621 + 0.1033 + 0.6387)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 15, Loss: 9.8938 (3.2324 + 5.9165 + 0.1017 + 0.6432)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 16, Loss: 9.7829 (3.1485 + 5.8871 + 0.1036 + 0.6437)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 17, Loss: 9.6673 (3.0295 + 5.8913 + 0.1029 + 0.6436)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 18, Loss: 9.4638 (2.9967 + 5.7319 + 0.1020 + 0.6332)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 19, Loss: 9.4939 (2.9018 + 5.8379 + 0.1050 + 0.6492)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 20, Loss: 9.5467 (3.0098 + 5.7937 + 0.1042 + 0.6390)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 21, Loss: 9.0816 (2.6123 + 5.7358 + 0.1025 + 0.6309)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 22, Loss: 9.1971 (2.7024 + 5.7545 + 0.1018 + 0.6384)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 23, Loss: 9.0665 (2.6658 + 5.6596 + 0.1035 + 0.6376)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 24, Loss: 9.4058 (2.9340 + 5.7254 + 0.1034 + 0.6430)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 25, Loss: 9.3238 (2.8530 + 5.7286 + 0.1020 + 0.6402)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 26, Loss: 8.9747 (2.5523 + 5.6737 + 0.1034 + 0.6455)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 27, Loss: 8.9781 (2.5918 + 5.6329 + 0.1040 + 0.6494)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 28, Loss: 8.6662 (2.3333 + 5.5996 + 0.1015 + 0.6317)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 29, Loss: 9.1265 (2.7620 + 5.6297 + 0.1025 + 0.6322)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 30, Loss: 8.9457 (2.6579 + 5.5435 + 0.1028 + 0.6416)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 31, Loss: 9.0799 (2.6917 + 5.6459 + 0.1030 + 0.6394)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 32, Loss: 8.7476 (2.4257 + 5.5796 + 0.1014 + 0.6409)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 33, Loss: 8.7122 (2.4248 + 5.5585 + 0.0998 + 0.6291)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 34, Loss: 8.4906 (2.2539 + 5.4978 + 0.1027 + 0.6362)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 35, Loss: 8.7863 (2.4880 + 5.5473 + 0.1021 + 0.6488)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 36, Loss: 8.7549 (2.5321 + 5.4850 + 0.1017 + 0.6361)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 37, Loss: 8.5545 (2.3434 + 5.4728 + 0.1057 + 0.6326)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 38, Loss: 8.3167 (2.1930 + 5.3916 + 0.1016 + 0.6305)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 39, Loss: 8.6157 (2.4032 + 5.4677 + 0.1029 + 0.6419)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 40, Loss: 8.6744 (2.4469 + 5.4944 + 0.1043 + 0.6288)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 41, Loss: 8.5044 (2.3597 + 5.4139 + 0.1018 + 0.6290)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 42, Loss: 8.4537 (2.2984 + 5.4282 + 0.1008 + 0.6263)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 43, Loss: 8.7389 (2.5921 + 5.4148 + 0.1041 + 0.6279)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 44, Loss: 8.4775 (2.2764 + 5.4697 + 0.1030 + 0.6284)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 45, Loss: 8.7049 (2.4917 + 5.4857 + 0.0991 + 0.6285)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 46, Loss: 8.4370 (2.3326 + 5.3756 + 0.1037 + 0.6251)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 47, Loss: 8.5326 (2.3809 + 5.4263 + 0.1032 + 0.6221)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 48, Loss: 8.4311 (2.3069 + 5.4021 + 0.1020 + 0.6200)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 49, Loss: 8.1466 (2.1018 + 5.3180 + 0.1009 + 0.6260)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 50, Loss: 7.8367 (1.8498 + 5.2612 + 0.1003 + 0.6254)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 51, Loss: 7.7018 (1.7740 + 5.2122 + 0.0968 + 0.6188)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 52, Loss: 7.7439 (1.8009 + 5.2245 + 0.0973 + 0.6213)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 53, Loss: 7.4958 (1.6131 + 5.1638 + 0.0970 + 0.6219)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 54, Loss: 7.5231 (1.6315 + 5.1754 + 0.0986 + 0.6176)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 55, Loss: 7.5334 (1.6550 + 5.1646 + 0.0987 + 0.6150)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 56, Loss: 7.4486 (1.5906 + 5.1465 + 0.0973 + 0.6143)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 57, Loss: 7.5424 (1.7079 + 5.1182 + 0.0978 + 0.6185)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 58, Loss: 7.5773 (1.7154 + 5.1379 + 0.0974 + 0.6265)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 59, Loss: 7.4506 (1.6252 + 5.1176 + 0.0984 + 0.6094)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 60, Loss: 7.4396 (1.6281 + 5.1061 + 0.0964 + 0.6089)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 61, Loss: 7.2857 (1.4774 + 5.0949 + 0.0967 + 0.6168)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 62, Loss: 7.1426 (1.3783 + 5.0501 + 0.0983 + 0.6159)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 63, Loss: 7.2481 (1.4738 + 5.0594 + 0.0970 + 0.6178)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 64, Loss: 7.2642 (1.4801 + 5.0723 + 0.0970 + 0.6149)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 65, Loss: 7.2702 (1.5173 + 5.0438 + 0.0973 + 0.6118)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 66, Loss: 7.2731 (1.5275 + 5.0307 + 0.0966 + 0.6184)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 67, Loss: 7.2233 (1.4917 + 5.0239 + 0.0957 + 0.6120)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 68, Loss: 7.2326 (1.5301 + 5.0003 + 0.0969 + 0.6054)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 69, Loss: 7.1834 (1.4678 + 5.0081 + 0.0953 + 0.6122)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 70, Loss: 7.1518 (1.4461 + 5.0063 + 0.0951 + 0.6043)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 71, Loss: 7.1381 (1.4315 + 5.0070 + 0.0947 + 0.6048)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 72, Loss: 7.1263 (1.4124 + 5.0039 + 0.0966 + 0.6134)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 73, Loss: 7.1922 (1.4680 + 5.0154 + 0.0961 + 0.6126)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 74, Loss: 7.1926 (1.5026 + 4.9914 + 0.0950 + 0.6036)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 75, Loss: 7.1768 (1.4702 + 5.0173 + 0.0965 + 0.5928)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 76, Loss: 7.2185 (1.5183 + 4.9993 + 0.0952 + 0.6057)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 77, Loss: 7.0047 (1.3046 + 4.9862 + 0.0958 + 0.6181)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 78, Loss: 6.9338 (1.2128 + 5.0057 + 0.0960 + 0.6193)



HBox(children=(IntProgress(value=0, max=99), HTML(value='')))

Epoch: 79, Loss: 7.3788 (1.6714 + 4.9951 + 0.0959 + 0.6164)




In [5]:
trainer.save_model()
trainer.save_model(cover=True)

Model weight saved.
Model weight saved.


# Eval

In [6]:
trainer = Trainer()

Pretrained model loaded.


In [7]:
trainer.eval_performance(data_loader['gallery'], data_loader['probe'])

{'r1': <utils.AverageMeter at 0x7f41febdd350>,
 'r5': <utils.AverageMeter at 0x7f41febddad0>,
 'r10': <utils.AverageMeter at 0x7f41febdda50>,
 'MAP': <utils.AverageMeter at 0x7f41febdda90>}

In [8]:
trainer.r1, trainer.r5, trainer.r10, trainer.MAP

(29.453681710213775, 46.080760095011875, 54.89904988123515, 14.61836043303487)

In [9]:
exit()