In [1]:
import os

import os.path as op
import torch
import numpy as np
import random
import time

from datasets import build_dataloader
from processor.processor import do_train
from utils.checkpoint import Checkpointer
from utils.iotools import save_train_configs
from utils.logger import setup_logger
from solver import build_optimizer, build_lr_scheduler
from model import build_model
from utils.metrics import Evaluator
from utils.options import get_args
from utils.comm import get_rank, synchronize


In [2]:
def set_seed(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


In [3]:
import argparse

def get_temp_args():
    parser = argparse.ArgumentParser(description="IRRA Args")
    ######################## general settings ########################
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--name", default="baseline", help="experiment name to save")
    parser.add_argument("--output_dir", default="logs")
    parser.add_argument("--log_period", default=100)
    parser.add_argument("--eval_period", default=1)
    parser.add_argument("--val_dataset", default="test") # use val set when evaluate, if test use test set
    parser.add_argument("--resume", default=False, action='store_true')
    parser.add_argument("--resume_ckpt_file", default="", help='resume from ...')

    ######################## model general settings ########################
    parser.add_argument("--pretrain_choice", default='ViT-B/16') # whether use pretrained model
    parser.add_argument("--temperature", type=float, default=0.02, help="initial temperature value, if 0, don't use temperature")
    parser.add_argument("--img_aug", default=False, action='store_true')

    ## cross modal transfomer setting
    parser.add_argument("--cmt_depth", type=int, default=4, help="cross modal transformer self attn layers")
    parser.add_argument("--masked_token_rate", type=float, default=0.8, help="masked token rate for mlm task")
    parser.add_argument("--masked_token_unchanged_rate", type=float, default=0.1, help="masked token unchanged rate")
    parser.add_argument("--lr_factor", type=float, default=5.0, help="lr factor for random init self implement module")
    parser.add_argument("--MLM", default=False, action='store_true', help="whether to use Mask Language Modeling dataset")

    ######################## loss settings ########################
    parser.add_argument("--loss_names", default='sdm+id+mlm', help="which loss to use ['mlm', 'cmpm', 'id', 'itc', 'sdm']")
    parser.add_argument("--mlm_loss_weight", type=float, default=1.0, help="mlm loss weight")
    parser.add_argument("--id_loss_weight", type=float, default=1.0, help="id loss weight")
    
    ######################## vison trainsformer settings ########################
    parser.add_argument("--img_size", type=tuple, default=(384, 128))
    parser.add_argument("--stride_size", type=int, default=16)

    ######################## text transformer settings ########################
    parser.add_argument("--text_length", type=int, default=77)
    parser.add_argument("--vocab_size", type=int, default=49408)

    ######################## solver ########################
    parser.add_argument("--optimizer", type=str, default="Adam", help="[SGD, Adam, Adamw]")
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--bias_lr_factor", type=float, default=2.)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=4e-5)
    parser.add_argument("--weight_decay_bias", type=float, default=0.)
    parser.add_argument("--alpha", type=float, default=0.9)
    parser.add_argument("--beta", type=float, default=0.999)
    
    ######################## scheduler ########################
    parser.add_argument("--num_epoch", type=int, default=60)
    parser.add_argument("--milestones", type=int, nargs='+', default=(20, 50))
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--warmup_factor", type=float, default=0.1)
    parser.add_argument("--warmup_epochs", type=int, default=5)
    parser.add_argument("--warmup_method", type=str, default="linear")
    parser.add_argument("--lrscheduler", type=str, default="cosine")
    parser.add_argument("--target_lr", type=float, default=0)
    parser.add_argument("--power", type=float, default=0.9)

    ######################## dataset ########################
    parser.add_argument("--dataset_name", default="CUHK-PEDES", help="[CUHK-PEDES, ICFG-PEDES, RSTPReid]")
    parser.add_argument("--sampler", default="random", help="choose sampler from [idtentity, random]")
    parser.add_argument("--num_instance", type=int, default=4)
    parser.add_argument("--root_dir", default="./data")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--test_batch_size", type=int, default=512)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--test", dest='training', default=True, action='store_false')

    return parser

In [4]:
def set_seed(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


In [5]:
parser = get_temp_args()

In [6]:
args = parser.parse_args(args=[
    "--name","irra",
    "--img_aug","--MLM",
    "--batch_size","16",
    "--loss_names","sdm+mlm+id",
    "--dataset_name","CUHK-PEDES",
    "--root_dir","E:\Share\jupyterDir\DSSAM\datasets",
    "--num_epoch","60",
])

In [9]:
set_seed(1+get_rank())
name = args.name

num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
args.distributed = num_gpus > 1

if args.distributed:
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    synchronize()

device = "cuda"
cur_time = time.strftime("%Y%m%d_%H%M%S", time.localtime())
args.output_dir = op.join(args.output_dir, args.dataset_name, f'{cur_time}_{name}')
logger = setup_logger('IRRA', save_dir=args.output_dir, if_train=args.training, distributed_rank=get_rank())
logger.info("Using {} GPUs".format(num_gpus))
logger.info(str(args).replace(',', '\n'))
save_train_configs(args.output_dir, args)

# # get image-text pair datasets dataloader
# train_loader, val_img_loader, val_txt_loader, num_classes = build_dataloader(args)
model = build_model(args, 11003)
logger.info('Total params: %2.fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

logs\CUHK-PEDES\20240130_185055_irra\CUHK-PEDES\20240130_185211_irra\CUHK-PEDES\20240130_185220_irra is not exists, create given directory
2024-01-30 18:52:20,476 IRRA INFO: Using 1 GPUs
2024-01-30 18:52:20,476 IRRA INFO: Using 1 GPUs
2024-01-30 18:52:20,476 IRRA INFO: Using 1 GPUs
2024-01-30 18:52:20,478 IRRA INFO: Namespace(local_rank=0
 name='irra'
 output_dir='logs\\CUHK-PEDES\\20240130_185055_irra\\CUHK-PEDES\\20240130_185211_irra\\CUHK-PEDES\\20240130_185220_irra'
 log_period=100
 eval_period=1
 val_dataset='test'
 resume=False
 resume_ckpt_file=''
 pretrain_choice='ViT-B/16'
 temperature=0.02
 img_aug=True
 cmt_depth=4
 masked_token_rate=0.8
 masked_token_unchanged_rate=0.1
 lr_factor=5.0
 MLM=True
 loss_names='sdm+mlm+id'
 mlm_loss_weight=1.0
 id_loss_weight=1.0
 img_size=(384
 128)
 stride_size=16
 text_length=77
 vocab_size=49408
 optimizer='Adam'
 lr=1e-05
 bias_lr_factor=2.0
 momentum=0.9
 weight_decay=4e-05
 weight_decay_bias=0.0
 alpha=0.9
 beta=0.999
 num_epoch=60
 miles

In [12]:
model.mlm_head

Sequential(
  (dense): Linear(in_features=512, out_features=512, bias=True)
  (gelu): QuickGELU()
  (ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (fc): Linear(in_features=512, out_features=49408, bias=True)
)

In [16]:
sum(p.numel() for p in model.parameters()) / 1000000.0 + (sum(p.numel() for p in model.mlm_head.dense.parameters()) / 1000000.0)*2

195.060731

In [8]:
for n_iter, batch in enumerate(train_loader):
    # batch = {k: v for k, v in batch.items()}
    break

In [9]:
batch

{'caption_ids': tensor([[49406,  2308,  3309,  ...,     0,     0,     0],
         [49406,   320,   786,  ...,     0,     0,     0],
         [49406,   518,   786,  ...,     0,     0,     0],
         ...,
         [49406,   530,   320,  ...,     0,     0,     0],
         [49406,   518,  2533,  ...,     0,     0,     0],
         [49406,   320,   786,  ...,     0,     0,     0]]),
 'image_ids': tensor([28537, 29320, 29703, 14420, 15668, 33059, 22634,  2130, 29151, 16195,
            84,  6956,  9964,  8251, 19266,  6896]),
 'mlm_labels': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'mlm_ids': tensor([[49406,  2308,  3309,  ...,     0,     0,     0],
         [49406,   320,   786,  ...,     0,     0,     0],
         [49406,   518,   786,  ...,     0,     0,     0],
         ...,
         [49406,   530,   32

In [10]:
batch = {k: v for k, v in batch.items()}

In [11]:
batch

{'caption_ids': tensor([[49406,  2308,  3309,  ...,     0,     0,     0],
         [49406,   320,   786,  ...,     0,     0,     0],
         [49406,   518,   786,  ...,     0,     0,     0],
         ...,
         [49406,   530,   320,  ...,     0,     0,     0],
         [49406,   518,  2533,  ...,     0,     0,     0],
         [49406,   320,   786,  ...,     0,     0,     0]]),
 'image_ids': tensor([28537, 29320, 29703, 14420, 15668, 33059, 22634,  2130, 29151, 16195,
            84,  6956,  9964,  8251, 19266,  6896]),
 'mlm_labels': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'mlm_ids': tensor([[49406,  2308,  3309,  ...,     0,     0,     0],
         [49406,   320,   786,  ...,     0,     0,     0],
         [49406,   518,   786,  ...,     0,     0,     0],
         ...,
         [49406,   530,   32

In [13]:
caption_ids = batch['caption_ids']

In [22]:
caption_ids[0]

tensor([49406,  2308,  3309,   320,  1538,  1709, 19820,  1253, 49405, 49405,
         1538,  6400,   268,  3140, 12386,   593,  6313, 49405, 11344,  4079,
          269, 49407,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0])

In [21]:
caption_ids.argmax(dim=-1)

tensor([21, 44, 22, 21, 21, 23, 32, 52, 23, 20, 22, 33, 19, 18, 23, 22])

In [23]:
from transformers import CLIPTokenizer

In [24]:
tokenizer = CLIPTokenizer.from_pretrained(r"F:\preTrainedModels\clip-vit-base-patch16")

In [36]:
tokenizer.decode([49407])

'<|endoftext|>'

In [32]:
tokenizer

CLIPTokenizer(name_or_path='F:\preTrainedModels\clip-vit-base-patch16', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True)

In [39]:
tokenizer.decode(caption_ids[0])

'<|startoftext|>woman wearing a long sleeved top jekyll jekyll long multi - color skirt with flat jekyll toe shoes. <|endoftext|>!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'

In [42]:
1/args.temperature

50.0

In [44]:
batch['pids']

tensor([ 9222,  9482,  9603,  4642,  5044, 10692,  7301,   700,  9427,  5220,
           28,  2248,  3195,  2661,  6198,  2229])