In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.resnet import resnet50
from my_dataset import MyDataset
from my_lossfunc import JointLoss, MultilabelLoss, DiscriminativeLoss
from my_transform import data_transforms
from scipy.spatial.distance import pdist
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import *

# Expose only physical GPUs 1-3 to this process. This has to be set before
# CUDA is first initialised for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', DEVICE)


# Dataset root. Defaults to the original author's location; set the
# REID_DATASET_ROOT environment variable to run the notebook on another machine
# without editing this cell.
BASE = os.environ.get('REID_DATASET_ROOT', '/home/zengrui/datasets')
DUKE_DIR_TRAIN = f'{BASE}/ReID_Duke/bounding_box_train'
DUKE_DIR_TEST = f'{BASE}/ReID_Duke/bounding_box_test'
DUKE_IMG_AMOUNT = 16522
DUKE_ID_AMOUNT = 702      # number of classes of the network head / columns of the label memory
MARKET_DIR_TRAIN = f'{BASE}/ReID_Market/bounding_box_train'
MARKET_DIR_TEST = f'{BASE}/ReID_Market/bounding_box_test'
MARKET_IMG_AMOUNT = 12936  # rows of the per-image soft-label memory
MARKET_ID_AMOUNT = 751
ML_PATH = 'data/ml_Market.dat'                         # cache of (ml, views, pairwise_agreements)
PRETRAIN_PATH = 'data/pretrained_weight.pkl'           # weights loaded at start-up if present
PRETRAIN_OUT_PATH = 'data/pretrained_weight_{}.pkl'    # '{}' is filled with a timestamp on save

# Optimisation schedule.
BATCH_SIZE = 96
EPOCH = 30
LR = 0.1

# Loss weighting, as combined in Trainer.train_epoch:
#   total = tgt + LAMB1 * ml + LAMB2 * (src + BETA * st)
BETA = 0.2
LAMB1 = 2e-4
LAMB2 = 50
MARGIN = 1      # passed to JointLoss
SCALA_CE = 30   # scale applied to similarities before cross-entropy / softmax

Device: cuda


## Preparation

In [2]:
# Data loaders: one shuffled training loader per domain, both using the
# 'train' transform pipeline and the same batch size.
#   source -> Duke training images; batches are consumed as (image, label).
#   target -> Market training images, built with require_view=True; batches
#             are consumed as (image, label, view, index) by the training loop.
data_loader = dict(
    source=DataLoader(
        MyDataset(DUKE_DIR_TRAIN, transform=data_transforms('train')),
        shuffle=True,
        batch_size=BATCH_SIZE,
    ),
    target=DataLoader(
        MyDataset(
            MARKET_DIR_TRAIN,
            transform=data_transforms('train'),
            require_view=True,
        ),
        shuffle=True,
        batch_size=BATCH_SIZE,
    ),
)
print('data_loader: ok.')

data_loader: ok.


## Train

In [3]:
class Trainer(object):
    """Joint source/target trainer built around a ResNet-50 backbone.

    Each step consumes one labelled source batch and one target batch and
    optimises

        total = tgt + LAMB1 * ml + LAMB2 * (src + BETA * st)

    where ``src`` is a cross-entropy on the source labels, ``st`` is the
    ``JointLoss`` between source and target features, ``ml`` is currently
    disabled (constant zero) and ``tgt`` is the ``DiscriminativeLoss`` computed
    against a per-image soft-label memory (``ml_mem``) that is maintained with
    a moving average.

    The class relies on notebook-level globals: the hyper-parameter constants,
    ``DEVICE`` and ``data_loader``.
    """

    def __init__(self):
        """Build the network, the losses, the optimiser and the label memory."""

        # Network
        self.net = resnet50(pretrained=False,
                            num_classes=DUKE_ID_AMOUNT)
        self.net = nn.DataParallel(self.net).to(DEVICE)
        if PRETRAIN_PATH is not None and os.path.exists(PRETRAIN_PATH):
            # map_location lets a checkpoint written on a GPU index that is not
            # visible to this process (or on a CPU-only machine) still load.
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            self.net.load_state_dict(
                torch.load(PRETRAIN_PATH, map_location=DEVICE))
            print('Pre-trained model loaded.')
        else:
            print('Pre-trained model not found. Train from scratch.')

        # Losses
        self.mdl_loss = DiscriminativeLoss(0.001).to(DEVICE)
        self.al_loss = nn.CrossEntropyLoss().to(DEVICE)
        # NOTE(review): original author marked this call as "lack 1 param".
        self.rj_loss = JointLoss(MARGIN).to(DEVICE)
        self.cml_loss = MultilabelLoss(BATCH_SIZE).to(DEVICE)  # built but currently unused

        # Optimiser; the learning rate is decayed at 5/8 and 7/8 of training.
        self.optimizer = torch.optim.SGD(
            self.net.parameters(), lr=LR, momentum=0.9)
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[int(EPOCH / 8 * 5), int(EPOCH / 8 * 7)])

        # Soft-label memory: one row per target image, one column per source
        # identity, kept on the CPU. `inited` flags the rows written so far
        # (all False right after construction).
        self.ml_mem = torch.zeros(MARKET_IMG_AMOUNT, DUKE_ID_AMOUNT)
        self.inited = self.ml_mem.sum(dim=1) != 0

    def train(self):
        """Run ``EPOCH`` training epochs back to back."""
        for epoch in range(EPOCH):
            self.train_epoch(epoch)

    def _update_ml_memory(self, ml_cpu, b_idx):
        """Fold the soft labels of one target batch into ``self.ml_mem``.

        Rows never written before are set to the new labels outright; rows
        that already hold a value are blended as ``0.9 * old + 0.1 * new``.

        Args:
            ml_cpu: CPU tensor of soft labels, one row per batch element.
            b_idx: index tensor giving each batch element's row in the memory.
        """
        # Snapshot the flags before marking anything, so rows initialised in
        # this call are not blended as well.
        is_inited_batch = self.inited[b_idx]
        inited_idx = b_idx[is_inited_batch]
        uninited_idx = b_idx[~is_inited_batch]
        self.ml_mem[uninited_idx] = ml_cpu[~is_inited_batch]  # empty rows: full overwrite
        self.inited[uninited_idx] = True
        self.ml_mem[inited_idx] = 0.9 * self.ml_mem[inited_idx] \
                                + 0.1 * ml_cpu[is_inited_batch]  # filled rows: small update

    def train_epoch(self, epoch):
        """Train for one pass over the source loader and step the LR schedule.

        The target loader is cycled (restarted when exhausted) so that every
        source batch is paired with a target batch. Per-step losses are shown
        on a progress bar and the epoch averages are printed at the end.

        Args:
            epoch: zero-based epoch number; during epoch 0 the target loss is
                held at zero and the label memory is not updated.
        """
        count = 0
        running_loss = {'total': 0,
                        'src': 0,
                        'st': 0,
                        'ml': 0,
                        'tgt': 0,
                       }
        if not self.mdl_loss.initialized:
            self.init_losses(data_loader['target'])

        # Make sure layers such as batch-norm/dropout are in training mode.
        # NOTE(review): extract_features (called by init_losses) is defined
        # elsewhere and may switch the network to eval mode - confirm.
        self.net.train()

        with tqdm(total=len(data_loader['source'])) as pbar:
            tgt_iter = iter(data_loader['target'])
            for step, (ax, ay) in enumerate(data_loader['source']):
                # a - source, b - target
                ax = ax.to(DEVICE)
                ay = ay.to(DEVICE)
                try:
                    b = next(tgt_iter)
                except StopIteration:
                    # Target loader exhausted first: restart it.
                    tgt_iter = iter(data_loader['target'])
                    b = next(tgt_iter)
                # The view labels are only needed by the disabled cml_loss
                # term below, so they are not moved to the device.
                bx, by, _b_view, b_idx = b
                bx, by = bx.to(DEVICE), by.to(DEVICE)

                a_f, a_sim, _ = self.net(ax)
                b_f, b_sim, _ = self.net(bx)

                # Supervised cross-entropy on the source domain.
                loss_src = self.al_loss(a_sim * SCALA_CE, ay)

                # Classifier weights used as "agents": renorm caps each row's
                # L2 norm at 1e-5 and the multiply scales it back by 1e5, so
                # rows end up with (at most) unit norm.
                # Shape noted by the original author: (702, 2048).
                agents = self.net.module.fc.weight.renorm(2, 0, 1e-5).mul(1e5)

                loss_st = self.rj_loss(agents.detach(), a_f, a_sim.detach(), ay,
                                       b_f, b_sim.detach())

                with torch.no_grad():
                    # Out-of-place transpose: `agents` still belongs to the
                    # autograd graph, so its layout must not be changed in place.
                    ml = F.softmax(b_f.mm(agents.t() * SCALA_CE), dim=1)

                # Multilabel term is disabled; it used to be
                #   self.cml_loss(torch.log(ml), b_view)
                loss_ml = torch.zeros(1, device=DEVICE)

                if epoch < 1:
                    # NOTE(review): the first epoch skips mdl_loss; the
                    # original author also questioned why. Rationale is not
                    # visible in this notebook.
                    loss_tgt = torch.zeros(1, device=DEVICE)
                else:
                    self._update_ml_memory(ml.detach().cpu(), b_idx)
                    loss_tgt = self.mdl_loss(b_f, self.ml_mem[b_idx], by)

                self.optimizer.zero_grad()
                loss_total = loss_tgt + LAMB1 * loss_ml + LAMB2 * (loss_src + BETA * loss_st)
                loss_total.backward()
                self.optimizer.step()

                count += 1
                loss_cpu = loss_total.item()
                loss_src_cpu = loss_src.item()
                loss_st_cpu = loss_st.item()
                loss_ml_cpu = loss_ml.item()
                loss_tgt_cpu = loss_tgt.item()
                running_loss['total'] += loss_cpu
                running_loss['src'] += loss_src_cpu
                running_loss['st'] += loss_st_cpu
                running_loss['ml'] += loss_ml_cpu
                running_loss['tgt'] += loss_tgt_cpu
                pbar.set_description('Loss: %.4f (%.4f + %.4f + %.4f + %.4f)'
                                     % (loss_cpu, loss_src_cpu, loss_st_cpu, loss_ml_cpu, loss_tgt_cpu))
                pbar.update()

            self.lr_scheduler.step()
            # Guard against an empty source loader (count == 0).
            denom = max(count, 1)
            for k in running_loss.keys():
                running_loss[k] /= denom
            print('Epoch: %d, Loss: %.4f (%.4f + %.4f + %.4f + %.4f)'
                  % (epoch, running_loss['total'], running_loss['src'],
                     running_loss['st'], running_loss['ml'], running_loss['tgt']))

    def init_losses(self, tgt_loader):
        """Initialise the threshold of ``mdl_loss`` from target soft labels.

        If ``ML_PATH`` exists, the cached ``(ml, views, pairwise_agreements)``
        tuple is loaded from it. Otherwise soft labels are computed by running
        ``extract_features`` over ``tgt_loader``, pairwise agreements are
        derived from them, and the tuple is saved to ``ML_PATH``.

        Args:
            tgt_loader: loader over the target-domain images; only used when
                no cache file is present.
        """
        print('Initializing losses ....')
        if os.path.isfile(ML_PATH):
            # NOTE(review): torch.load unpickles the file; only use a cache
            # produced by this notebook.
            (ml, view, pairwise_agreements) = torch.load(ML_PATH)
            print('ml loaded.')
        else:
            print('ml not found, computing ....')
            # Use the loader handed in by the caller rather than the global one.
            sim, _, views = extract_features(
                tgt_loader, self.net, index_feature=1, return_numpy=False)
            ml = F.softmax(sim * SCALA_CE, dim=1)
            ml_np = ml.cpu().numpy()
            # Agreement of two soft labels: 1 - (L1 distance) / 2, for every
            # pair of rows (condensed form returned by pdist).
            pairwise_agreements = 1 - pdist(ml_np, 'minkowski', p=1) / 2
            print('ml saving to %s...' % ML_PATH)
            torch.save((ml, views, pairwise_agreements), ML_PATH)
        self.mdl_loss.init_threshold(pairwise_agreements)
        print('mdl_loss threshold inited.')

    def save_model(self, cover=False):
        """Write the network's state dict to disk.

        Args:
            cover: if True, overwrite ``PRETRAIN_PATH``; otherwise write a new
                file named from ``PRETRAIN_OUT_PATH`` and the current
                ``time.time()`` value.
        """
        if cover:
            torch.save(self.net.state_dict(), PRETRAIN_PATH)
        else:
            torch.save(self.net.state_dict(), PRETRAIN_OUT_PATH.format(time.time()))
        print('Model weight saved.')

In [4]:
# Build the trainer (network, losses, optimiser, label memory) and run all
# EPOCH training epochs; progress and per-epoch average losses are printed below.
trainer = Trainer()
trainer.train()

Pre-trained model not found. Train from scratch.
Initializing losses ....
ml loaded.


  0%|          | 0/173 [00:00<?, ?it/s]

mdl_loss threshold inited.


Loss: 393.0692 (7.5515 + 1.5497 + 0.0000 + 0.0000): 100%|██████████| 173/173 [06:00<00:00,  2.08s/it] 
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 0, Loss: 572.2653 (11.2167 + 1.1430 + 0.0000 + 0.0000)


Loss: 353.1386 (6.7364 + 1.5626 + 0.0000 + 0.6931): 100%|██████████| 173/173 [06:42<00:00,  2.32s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 1, Loss: 377.3244 (7.2226 + 1.5622 + 0.0000 + 0.5701)


Loss: 352.1070 (6.7072 + 1.6053 + 0.0000 + 0.6931): 100%|██████████| 173/173 [06:41<00:00,  2.32s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 2, Loss: 365.2999 (6.9722 + 1.5994 + 0.0000 + 0.6931)


Loss: 353.6849 (6.7324 + 1.6371 + 0.0000 + 0.6931): 100%|██████████| 173/173 [06:44<00:00,  2.34s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 3, Loss: 358.2038 (6.8254 + 1.6243 + 0.0000 + 0.6931)


Loss: 329.3775 (6.2483 + 1.6272 + 0.0000 + 0.6931): 100%|██████████| 173/173 [06:43<00:00,  2.34s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 4, Loss: 351.8907 (6.6958 + 1.6407 + 0.0000 + 0.6931)


Loss: 333.5637 (6.3270 + 1.6519 + 0.0000 + 0.6931): 100%|██████████| 173/173 [06:43<00:00,  2.33s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 5, Loss: 344.3794 (6.5429 + 1.6541 + 0.0000 + 0.6931)


Loss: 314.8884 (5.9568 + 1.6353 + 0.0000 + 0.6931): 100%|██████████| 173/173 [06:53<00:00,  2.39s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 6, Loss: 335.6200 (6.3676 + 1.6549 + 0.0000 + 0.6931)


Loss: 344.4575 (6.5375 + 1.6888 + 0.0000 + 0.6928): 100%|██████████| 173/173 [06:48<00:00,  2.36s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 7, Loss: 324.9527 (6.1550 + 1.6508 + 0.0000 + 0.6930)


Loss: 308.0649 (5.8183 + 1.6456 + 0.0000 + 0.6936): 100%|██████████| 173/173 [06:43<00:00,  2.33s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 8, Loss: 310.0275 (5.8597 + 1.6348 + 0.0000 + 0.6925)


Loss: 333.0367 (6.3128 + 1.6703 + 0.0000 + 0.6921): 100%|██████████| 173/173 [06:42<00:00,  2.32s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 9, Loss: 294.1223 (5.5463 + 1.6116 + 0.0000 + 0.6909)


Loss: 297.3013 (5.6075 + 1.6238 + 0.0000 + 0.6872): 100%|██████████| 173/173 [06:41<00:00,  2.32s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 10, Loss: 275.1643 (5.1735 + 1.5803 + 0.0000 + 0.6846)


Loss: 365.8995 (6.9595 + 1.7269 + 0.0000 + 0.6552): 100%|██████████| 173/173 [06:43<00:00,  2.33s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 11, Loss: 250.2498 (4.6837 + 1.5393 + 0.0000 + 0.6712)


Loss: 310.4696 (5.8718 + 1.6218 + 0.0000 + 0.6606): 100%|██████████| 173/173 [06:43<00:00,  2.33s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 12, Loss: 223.2540 (4.1515 + 1.5020 + 0.0000 + 0.6565)


Loss: 273.7122 (5.1466 + 1.5664 + 0.0000 + 0.7189): 100%|██████████| 173/173 [06:44<00:00,  2.34s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 13, Loss: 197.0531 (3.6360 + 1.4612 + 0.0000 + 0.6397)


Loss: 190.8086 (3.5112 + 1.4677 + 0.0000 + 0.5710): 100%|██████████| 173/173 [06:49<00:00,  2.37s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 14, Loss: 172.5562 (3.1541 + 1.4224 + 0.0000 + 0.6278)


Loss: 254.7523 (4.7718 + 1.5464 + 0.0000 + 0.6979): 100%|██████████| 173/173 [06:50<00:00,  2.37s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 15, Loss: 150.4413 (2.7217 + 1.3746 + 0.0000 + 0.6087)


Loss: 161.9988 (2.9558 + 1.3559 + 0.0000 + 0.6500): 100%|██████████| 173/173 [06:49<00:00,  2.37s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 16, Loss: 131.0259 (2.3415 + 1.3333 + 0.0000 + 0.6189)


Loss: 116.5644 (2.0560 + 1.3253 + 0.0000 + 0.5130): 100%|██████████| 173/173 [06:48<00:00,  2.36s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 17, Loss: 114.4985 (2.0195 + 1.2915 + 0.0000 + 0.6086)


Loss: 153.7216 (2.8047 + 1.2767 + 0.0000 + 0.7174): 100%|██████████| 173/173 [06:46<00:00,  2.35s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 18, Loss: 90.0850 (1.5404 + 1.2454 + 0.0000 + 0.6095)


Loss: 122.4461 (2.1628 + 1.3635 + 0.0000 + 0.6701): 100%|██████████| 173/173 [06:45<00:00,  2.34s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 19, Loss: 84.1443 (1.4233 + 1.2363 + 0.0000 + 0.6150)


Loss: 166.5662 (3.0424 + 1.3780 + 0.0000 + 0.6655): 100%|██████████| 173/173 [06:47<00:00,  2.36s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 20, Loss: 82.2728 (1.3870 + 1.2321 + 0.0000 + 0.5995)


Loss: 200.6777 (3.7156 + 1.4307 + 0.0000 + 0.5913): 100%|██████████| 173/173 [06:43<00:00,  2.33s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 21, Loss: 80.0774 (1.3437 + 1.2281 + 0.0000 + 0.6133)


Loss: 128.2402 (2.2912 + 1.3083 + 0.0000 + 0.5949): 100%|██████████| 173/173 [06:42<00:00,  2.33s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 22, Loss: 77.7314 (1.2977 + 1.2234 + 0.0000 + 0.6123)


Loss: 123.6079 (2.1921 + 1.3506 + 0.0000 + 0.4989): 100%|██████████| 173/173 [06:45<00:00,  2.34s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 23, Loss: 76.8484 (1.2813 + 1.2180 + 0.0000 + 0.6042)


Loss: 193.6456 (3.5804 + 1.3972 + 0.0000 + 0.6520): 100%|██████████| 173/173 [06:46<00:00,  2.35s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 24, Loss: 74.9093 (1.2428 + 1.2157 + 0.0000 + 0.6143)


Loss: 142.7090 (2.5681 + 1.3807 + 0.0000 + 0.4979): 100%|██████████| 173/173 [06:44<00:00,  2.34s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 25, Loss: 73.0781 (1.2073 + 1.2100 + 0.0000 + 0.6123)


Loss: 126.6697 (2.2507 + 1.3652 + 0.0000 + 0.4815): 100%|██████████| 173/173 [06:42<00:00,  2.33s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 26, Loss: 70.8979 (1.1648 + 1.2051 + 0.0000 + 0.6088)


Loss: 160.2090 (2.9131 + 1.3996 + 0.0000 + 0.5577): 100%|██████████| 173/173 [06:48<00:00,  2.36s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 27, Loss: 69.8352 (1.1436 + 1.2032 + 0.0000 + 0.6222)


Loss: 120.7345 (2.1270 + 1.3754 + 0.0000 + 0.6305): 100%|██████████| 173/173 [06:50<00:00,  2.37s/it]
  0%|          | 0/173 [00:00<?, ?it/s]

Epoch: 28, Loss: 69.8719 (1.1446 + 1.2031 + 0.0000 + 0.6090)


Loss: 195.6564 (3.6103 + 1.4575 + 0.0000 + 0.5658): 100%|██████████| 173/173 [06:47<00:00,  2.36s/it]

Epoch: 29, Loss: 70.3622 (1.1542 + 1.2042 + 0.0000 + 0.6111)





In [None]:
# Persist the trained weights twice:
#   1. a snapshot whose file name is PRETRAIN_OUT_PATH filled with time.time();
#   2. an overwrite of PRETRAIN_PATH, the file Trainer loads at start-up if it exists.
trainer.save_model()
trainer.save_model(cover=True)