In [1]:
# `import_ipynb` is never referenced by name in this notebook; it is imported only for its
# side effect. Presumably it installs an import hook so the project notebooks imported
# below (`train`, `test`, `infer`, ...) can be loaded as modules — confirm.
import import_ipynb
import warnings
# Silence every UserWarning for the rest of the kernel session.
# NOTE(review): this filter is global and also hides potentially useful warnings from
# torch/apex; consider narrowing it with `message=` or `module=`.
warnings.filterwarnings("ignore", category=UserWarning) 

In [None]:
import os
import sys
from threading import local
import time
import datetime
import argparse
import logging
import os.path as osp
import numpy as np
import gc

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import distributed as dist
from apex import amp

from configs.default_img import get_img_config
from data import build_dataloader
from models import build_model
from losses import build_losses
from tools.utils import save_checkpoint, set_seed, get_logger
from train import fit_train
from test import test, test_prcc  
from infer import infer, infer_prcc

In [3]:
# Simulate a command line so `parse_option()` (argparse) can run inside the notebook.
# The dataset root may be overridden with the REID_DATA_ROOT environment variable instead
# of editing this machine-specific absolute path; the default is the original location.
DATA_ROOT = os.environ.get(
    'REID_DATA_ROOT',
    'C:/Users/USER/OneDrive/Documents/My Nural Net/Person Re-Identification/dataset',
)

sys.argv = [
    'main.ipynb',
    '--cfg', 'configs/res50_cels_cal.yaml',
    '--root', DATA_ROOT,
    '--dataset', 'prcc',
    '--output', 'outputs',
    # Optional switches. `--resume` takes a path. `--amp`, `--eval` and `--infer` are
    # store_true flags, so enable them by adding the flag ALONE (e.g. `'--eval',`).
    # Passing a value such as False breaks parsing: argv entries must be strings,
    # and the stray value would not belong to the flag.
    #'--resume', 'C:/Users/USER/OneDrive/Documents/My Nural Net/Person Re-Identification/outputs/prcc/eval/best_model.pth.tar',
    #'--amp',
    #'--eval',
    #'--infer',
    '--tag', 'eval',
    '--gpu', '0',
    '--seed', '1',
    '--k_cal', '0.1',
    '--k_kl', '0.01',
]

In [4]:
def parse_option():
    """Parse the command line (``sys.argv``) and build the image re-id config.

    Unknown flags are tolerated and discarded (``parse_known_args``). The parsed
    namespace is handed to ``get_img_config``.

    Returns:
        Whatever ``get_img_config`` builds from the parsed arguments.
    """
    ap = argparse.ArgumentParser(
        description='Train clothes-changing re-id model with clothes-based adversarial loss')
    add = ap.add_argument

    add('--cfg', type=str, required=True, metavar="FILE", help='path to config file')

    # --- datasets ---
    add('--root', type=str, help="your root path to data directory")
    add('--dataset', type=str, default='ltcc', help="ltcc, prcc, vcclothes, last, deepchange")

    # --- miscellaneous ---
    add('--output', type=str, help="your output path to save model and logs")
    add('--resume', type=str, metavar='PATH')
    for flag, text in (('--amp', "automatic mixed precision"),
                       ('--eval', "evaluation only"),
                       ('--infer', "inference only")):
        add(flag, action='store_true', help=text)
    add('--tag', type=str, help='tag for log file')
    add('--name', type=str, help='your model name for record')
    add('--gpu', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')

    # --- options and hyper-parameters ---
    # NOTE(review): --seed/--k_cal/--k_kl are parsed as str; presumably get_img_config
    # converts them to numbers — confirm.
    add('--seed', type=str, help='seed for single-shot')
    add('--single_shot', action='store_true', help='single-shot option')
    add('--k_cal', type=str)
    add('--k_kl', type=str)

    known, _ignored = ap.parse_known_args()
    return get_img_config(known)

In [None]:
import torch


def _pick_device():
    """Return the preferred torch device: MPS if available, else CUDA, else CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Module-level device used by `main()` to place every model component.
device = _pick_device()
print(f"Using device: {device}")

In [6]:
def _make_optimizer(config, params):
    """Build the optimizer named by ``config.TRAIN.OPTIMIZER.NAME`` over ``params``.

    All optimizers in ``main`` share the same learning rate and weight decay, so
    construction is centralised here instead of being repeated per optimizer.

    Args:
        config: Experiment config; reads ``TRAIN.OPTIMIZER.{NAME, LR, WEIGHT_DECAY}``.
        params: Iterable of parameters (or parameter groups) to optimise.

    Returns:
        A ``torch.optim`` optimizer instance.

    Raises:
        KeyError: If the optimizer name is not 'adam', 'adamw' or 'sgd'.
    """
    opt_cfg = config.TRAIN.OPTIMIZER
    if opt_cfg.NAME == 'adam':
        return optim.Adam(params, lr=opt_cfg.LR, weight_decay=opt_cfg.WEIGHT_DECAY)
    if opt_cfg.NAME == 'adamw':
        return optim.AdamW(params, lr=opt_cfg.LR, weight_decay=opt_cfg.WEIGHT_DECAY)
    if opt_cfg.NAME == 'sgd':
        return optim.SGD(params, lr=opt_cfg.LR, momentum=0.9,
                         weight_decay=opt_cfg.WEIGHT_DECAY, nesterov=True)
    raise KeyError("Unknown optimizer: {}".format(opt_cfg.NAME))


def _should_evaluate(config, epoch):
    """Return True if the model should be tested and checkpointed after 0-based ``epoch``."""
    periodic = (epoch > config.TEST.START_EVAL and config.TEST.EVAL_STEP > 0
                and epoch % config.TEST.EVAL_STEP == 0)
    # The training loop iterates range(start_epoch, MAX_EPOCH), so the final epoch index is
    # MAX_EPOCH - 1. Comparing against MAX_EPOCH (as before) could never be true, and the
    # last epoch was never evaluated or checkpointed.
    is_last = epoch == config.TRAIN.MAX_EPOCH - 1
    return periodic or is_last


def main(config):
    """Build data, model, losses and optimizers, then evaluate, infer, or train.

    Exactly one of the following runs:
      * evaluation only, when ``config.EVAL_MODE`` is set;
      * inference only, when ``config.INFER_MODE`` is set;
      * otherwise the training loop, with periodic testing and checkpointing.

    NOTE(review): reads the module-level globals ``device`` (device-selection cell) and
    ``logger`` (created in the ``__main__`` cell); both must be defined before calling.

    Args:
        config: Experiment config returned by ``parse_option()``.

    Raises:
        KeyError: If ``config.TRAIN.OPTIMIZER.NAME`` is unknown.
    """
    # ---- Data -----------------------------------------------------------------------
    # PRCC yields two query loaders (`queryloader_same`, `queryloader_diff`); other
    # datasets yield a single `queryloader`.
    if config.DATA.DATASET == 'prcc':
        trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler = build_dataloader(config)
    else:
        trainloader, queryloader, galleryloader, dataset, train_sampler = build_dataloader(config)

    # pid2clothes has shape (num_pids, num_clothes). pid2clothes[i, j] == 1 when the j-th
    # clothes label belongs to the i-th identity; otherwise it is 0.
    pid2clothes = torch.from_numpy(dataset.pid2clothes)

    # ---- Model and losses -----------------------------------------------------------
    model, model2, fuse, classifier, clothes_classifier, clothes_classifier2 = build_model(
        config, dataset.num_train_pids, dataset.num_train_clothes)
    print("model loaded")
    # Identity classification, pairwise, clothes classification and adversarial losses,
    # plus the `kl` term.
    criterion_cla, criterion_pair, criterion_clothes, criterion_adv, kl = build_losses(config, dataset.num_train_clothes)
    print("loss built")

    # ---- Optimizers and LR schedulers -----------------------------------------------
    # Group 1: backbone `model` + `fuse` + identity `classifier`.
    parameters = list(model.parameters()) + list(fuse.parameters()) + list(classifier.parameters())
    # Group 2: second backbone `model2` + `clothes_classifier2`.
    parameters2 = list(model2.parameters()) + list(clothes_classifier2.parameters())

    optimizer = _make_optimizer(config, parameters)
    optimizer2 = _make_optimizer(config, parameters2)
    # `clothes_classifier` has its own optimizer, which is not attached to any LR scheduler.
    optimizer_cc = _make_optimizer(config, clothes_classifier.parameters())

    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE,
                                         gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE)
    scheduler2 = lr_scheduler.MultiStepLR(optimizer2, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE,
                                          gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE)

    # ---- Optional resume -------------------------------------------------------------
    start_epoch = config.TRAIN.START_EPOCH

    if config.MODEL.RESUME:
        logger.info("Loading checkpoint from '{}'".format(config.MODEL.RESUME))
        # Map tensors to CPU: load_state_dict copies them into the existing parameters
        # wherever those live, and this lets a CUDA-saved checkpoint load on CPU/MPS-only
        # machines (which the device-selection cell explicitly supports).
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        classifier.load_state_dict(checkpoint['classifier_state_dict'])
        fuse.load_state_dict(checkpoint['fuse_state_dict'])
        model2.load_state_dict(checkpoint['model2_state_dict'])
        clothes_classifier.load_state_dict(checkpoint['clothes_classifier_state_dict'])
        clothes_classifier2.load_state_dict(checkpoint['clothes_classifier2_state_dict'])
        # Continue with the epoch after the one that was saved.
        start_epoch = checkpoint['epoch'] + 1

    model = model.to(device)
    model2 = model2.to(device)
    classifier = classifier.to(device)
    clothes_classifier2 = clothes_classifier2.to(device)
    fuse = fuse.to(device)
    clothes_classifier = clothes_classifier.to(device)

    # ---- Evaluation-only / inference-only modes -------------------------------------
    if config.EVAL_MODE:
        logger.info("Evaluate only")
        with torch.no_grad():
            if config.DATA.DATASET == 'prcc':
                test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset)
            else:
                test(config, model, queryloader, galleryloader, dataset)
        return

    if config.INFER_MODE:
        logger.info("Infer only")
        with torch.no_grad():
            if config.DATA.DATASET == 'prcc':
                infer_prcc(config, model, queryloader_same, queryloader_diff, galleryloader, dataset)
            else:
                infer(config, model, queryloader, galleryloader, dataset)
        return

    # ---- Training ------------------------------------------------------------------
    # MultiStepLR milestones are absolute epoch indices, and the loop below steps the
    # schedulers once per completed epoch. When starting past epoch 0 (e.g. after resuming),
    # replay the skipped steps so the LR matches `start_epoch` instead of restarting at
    # the base LR.
    for _ in range(start_epoch):
        scheduler.step()
        scheduler2.step()

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    logger.info("==> Start training")
    for epoch in range(start_epoch, config.TRAIN.MAX_EPOCH):
        train_sampler.set_epoch(epoch)
        start_train_time = time.time()

        fit_train(config, epoch, model, model2, classifier, clothes_classifier, clothes_classifier2, fuse, criterion_cla, criterion_pair,
            criterion_clothes, criterion_adv, optimizer, optimizer2, optimizer_cc, trainloader, pid2clothes, kl)

        train_time += round(time.time() - start_train_time)

        if _should_evaluate(config, epoch):
            logger.info("==> Test")
            if config.DATA.DATASET == 'prcc':
                rank1 = test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset)
            else:
                rank1 = test(config, model, queryloader, galleryloader, dataset)
            torch.cuda.empty_cache()
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch

            # Every evaluated epoch gets its own checkpoint; `is_best` lets save_checkpoint
            # single out the best one.
            save_checkpoint({
                    'model_state_dict': model.state_dict(),
                    'model2_state_dict': model2.state_dict(),
                    'fuse_state_dict': fuse.state_dict(),
                    'classifier_state_dict': classifier.state_dict(),
                    'clothes_classifier_state_dict': clothes_classifier.state_dict(),
                    'clothes_classifier2_state_dict': clothes_classifier2.state_dict(),
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best, osp.join(config.OUTPUT, 'checkpoint_ep' + str(epoch) + '.pth.tar'))
        scheduler.step()
        scheduler2.step()

    logger.info("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    logger.info("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))

    # Drop references to large objects so the kernel can reclaim memory between runs.
    del model, model2, classifier, clothes_classifier2, fuse, clothes_classifier, optimizer, optimizer2, optimizer_cc
    gc.collect()

In [None]:
if __name__ == '__main__':
    gc.collect()

    config = parse_option()
    set_seed(config.SEED)

    # Choose the log file from the run mode. Evaluation wins over inference, and a run
    # that is neither is a training run.
    if config.EVAL_MODE:
        log_name = 'log_test.log'
    elif config.INFER_MODE:
        log_name = 'log_infer.log'
    else:
        log_name = 'log_train_.log'
    output_file = osp.join(config.OUTPUT, log_name)

    # `logger` is deliberately module-level: main() reads it as a global.
    logger = get_logger(output_file, 'reid')
    for entry in ("Config:\n-----------------------------------------",
                  config,
                  "-----------------------------------------"):
        logger.info(entry)

    main(config)