In [1]:
# Debugging aid: CUDA_LAUNCH_BLOCKING=1 asks CUDA to launch kernels synchronously, so a
# failing kernel is reported at the Python call that triggered it.
# It is set here, ahead of `import torch` further down, so it is in place before CUDA is used.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import os.path as op
import torch
import numpy as np
import random
import time

from datasets import build_dataloader
from processor.processor import do_train
from utils.checkpoint import Checkpointer
from utils.iotools import save_train_configs
from utils.logger import setup_logger
from solver import build_optimizer, build_lr_scheduler
from model import build_model
from utils.metrics import Evaluator
from utils.options import get_args
from utils.comm import get_rank, synchronize
import tqdm

In [2]:
def set_seed(seed=0):
    """Seed every random number generator used here and make cuDNN reproducible.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then configures the cuDNN backend flags.

    Args:
        seed: Integer seed applied to all generators. Defaults to 0.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The auto-tuner may select different convolution algorithms from one run to the
    # next, which defeats the purpose of seeding; keep it off for reproducible runs.
    torch.backends.cudnn.benchmark = False


In [3]:
import argparse

def get_temp_args():
    """Build the argument parser for the IRRA experiments run in this notebook.

    Returns:
        argparse.ArgumentParser: the configured parser. Call ``parse_args`` on it
        (optionally with an explicit ``args=[...]`` list) to obtain the namespace.
    """
    parser = argparse.ArgumentParser(description="IRRA Args")
    ######################## general settings ########################
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--name", default="baseline", help="experiment name to save")
    parser.add_argument("--output_dir", default="logs")
    # type=int so that values supplied on the command line are not left as strings
    parser.add_argument("--log_period", type=int, default=100)
    parser.add_argument("--eval_period", type=int, default=1)
    parser.add_argument("--val_dataset", default="test") # use val set when evaluate, if test use test set
    parser.add_argument("--resume", default=False, action='store_true')
    parser.add_argument("--resume_ckpt_file", default="", help='resume from ...')

    ######################## model general settings ########################
    parser.add_argument("--pretrain_choice", default='ViT-B/16') # whether use pretrained model
    parser.add_argument("--temperature", type=float, default=0.02, help="initial temperature value, if 0, don't use temperature")
    parser.add_argument("--img_aug", default=False, action='store_true')

    ## cross modal transformer setting
    parser.add_argument("--cmt_depth", type=int, default=4, help="cross modal transformer self attn layers")
    parser.add_argument("--masked_token_rate", type=float, default=0.8, help="masked token rate for mlm task")
    parser.add_argument("--masked_token_unchanged_rate", type=float, default=0.1, help="masked token unchanged rate")
    parser.add_argument("--lr_factor", type=float, default=5.0, help="lr factor for random init self implement module")
    parser.add_argument("--MLM", default=False, action='store_true', help="whether to use Mask Language Modeling dataset")

    ######################## loss settings ########################
    parser.add_argument("--loss_names", default='sdm+id+mlm', help="which loss to use ['mlm', 'cmpm', 'id', 'itc', 'sdm']")
    parser.add_argument("--mlm_loss_weight", type=float, default=1.0, help="mlm loss weight")
    parser.add_argument("--id_loss_weight", type=float, default=1.0, help="id loss weight")

    ######################## vision transformer settings ########################
    # Two integers, e.g. ``--img_size 384 128``. The former ``type=tuple`` split the raw
    # string into characters ("384" -> ('3', '8', '4')). Values given on the command
    # line arrive as a 2-element list; the default remains a tuple.
    parser.add_argument("--img_size", type=int, nargs=2, default=(384, 128))
    parser.add_argument("--stride_size", type=int, default=16)

    ######################## text transformer settings ########################
    parser.add_argument("--text_length", type=int, default=77)
    parser.add_argument("--vocab_size", type=int, default=49408)

    ######################## solver ########################
    parser.add_argument("--optimizer", type=str, default="Adam", help="[SGD, Adam, Adamw]")
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--bias_lr_factor", type=float, default=2.)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=4e-5)
    parser.add_argument("--weight_decay_bias", type=float, default=0.)
    parser.add_argument("--alpha", type=float, default=0.9)
    parser.add_argument("--beta", type=float, default=0.999)

    ######################## scheduler ########################
    parser.add_argument("--num_epoch", type=int, default=60)
    parser.add_argument("--milestones", type=int, nargs='+', default=(20, 50))
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--warmup_factor", type=float, default=0.1)
    parser.add_argument("--warmup_epochs", type=int, default=5)
    parser.add_argument("--warmup_method", type=str, default="linear")
    parser.add_argument("--lrscheduler", type=str, default="cosine")
    parser.add_argument("--target_lr", type=float, default=0)
    parser.add_argument("--power", type=float, default=0.9)

    ######################## dataset ########################
    parser.add_argument("--dataset_name", default="CUHK-PEDES", help="[CUHK-PEDES, ICFG-PEDES, RSTPReid]")
    parser.add_argument("--sampler", default="random", help="choose sampler from [idtentity, random]")
    parser.add_argument("--num_instance", type=int, default=4)
    parser.add_argument("--root_dir", default="./data")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--test_batch_size", type=int, default=512)
    parser.add_argument("--num_workers", type=int, default=8)
    # --test flips the `training` flag off; it is on by default
    parser.add_argument("--test", dest='training', default=True, action='store_false')

    return parser

In [4]:
# get_temp_args() returns the ArgumentParser itself; parsing happens in the next cell.
parser = get_temp_args()

In [5]:
# Parse an explicit argument list (rather than sys.argv) so the notebook controls the run.
# The dataset root can be overridden through the AG_REID_ROOT environment variable; when
# it is unset the original local path is used, so existing setups behave exactly as before.
args = parser.parse_args(args=[
    "--name","irra",
    "--img_aug","--MLM",
    "--batch_size","6",
    "--loss_names","itc+proj",
    "--dataset_name","AGTBPR",
    "--root_dir",os.environ.get("AG_REID_ROOT", r"F:\Datasets\AG-ReID.v1"),
    "--num_epoch","60",
    "--cmt_depth","1"
])

In [6]:
# Display the parsed namespace to confirm the overrides took effect.
args

Namespace(local_rank=0, name='irra', output_dir='logs', log_period=100, eval_period=1, val_dataset='test', resume=False, resume_ckpt_file='', pretrain_choice='ViT-B/16', temperature=0.02, img_aug=True, cmt_depth=1, masked_token_rate=0.8, masked_token_unchanged_rate=0.1, lr_factor=5.0, MLM=True, loss_names='itc+proj', mlm_loss_weight=1.0, id_loss_weight=1.0, img_size=(384, 128), stride_size=16, text_length=77, vocab_size=49408, optimizer='Adam', lr=1e-05, bias_lr_factor=2.0, momentum=0.9, weight_decay=4e-05, weight_decay_bias=0.0, alpha=0.9, beta=0.999, num_epoch=60, milestones=(20, 50), gamma=0.1, warmup_factor=0.1, warmup_epochs=5, warmup_method='linear', lrscheduler='cosine', target_lr=0, power=0.9, dataset_name='AGTBPR', sampler='random', num_instance=4, root_dir='F:\\Datasets\\AG-ReID.v1', batch_size=6, test_batch_size=512, num_workers=8, training=True)

In [7]:
# build_dataloader returns four values: the training loader, two validation loaders
# (image side / text side, going by the names) and a class count.
# The subset summary table below is printed when this cell runs.
train_loader, val_img_loader, val_txt_loader, num_classes = build_dataloader(args)

+--------+-----+--------+----------+
| subset | ids | images | captions |
+--------+-----+--------+----------+
| train  | 199 |  8154  |   8154   |
|  test  | 189 |  1149  |   1149   |
|  val   | 189 |  7204  |   7204   |
+--------+-----+--------+----------+


In [8]:
# Grab only the first training batch and move every entry onto GPU 0.
cuda0 = torch.device("cuda:0")
for batch in train_loader:
    batch = {name: value.to(cuda0) for name, value in batch.items()}
    break  # a single batch is all the cells below need

# Test model

In [9]:
# Grow the configured image height by two strides (stride_size*2 pixels); the width is kept.
# NOTE(review): this reads and overwrites args.img_size, so every re-run of the cell adds
# another 2*stride_size -- re-parse args before running it again.
# NOTE(review): build_dataloader(args) already ran above with the original size -- confirm
# that only the model is meant to see the enlarged value.
args.img_size = (args.img_size[0]+args.stride_size*2,args.img_size[1])

In [16]:
# token_num is not one of the parser options, so it is attached to the namespace by hand.
# NOTE(review): this cell carries execution count 16 but sits before build_model (In [10]);
# presumably build_model reads args.token_num -- confirm, and run this first on a fresh kernel.
args.token_num = 16

In [10]:
# Instantiate the model from the adjusted args.
# NOTE(review): the second argument is hard-coded to 1000 although build_dataloader returned
# num_classes; if this parameter is the class count, num_classes should be passed -- confirm.
model = build_model(args, 1000)

Training Model with ['itc', 'proj'] tasks
Resized position embedding from size:torch.Size([1, 197, 768]) to size: torch.Size([1, 209, 768]) with height:26 width: 8


In [11]:
# Move the model to the GPU before feeding it the CUDA batch prepared above.
model = model.cuda()

In [12]:
# Count all parameter elements and report the total in millions (no decimals).
param_total = sum(weight.numel() for weight in model.parameters())
'Total params: %2.fM' % (param_total / 1e6)

'Total params: 179M'

In [13]:
# Forward a single batch through the model; the result is inspected in the next cell.
ret = model(batch)

In [14]:
# Show the forward output (per the display below: the temperature plus the individual loss terms).
ret

{'temperature': tensor(0.0200),
 'loss_proj': tensor(1.9277, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>),
 'itc_loss': tensor(1.6424, device='cuda:0', grad_fn=<DivBackward0>)}

In [15]:
# Smoke test: push every training batch through the model once, with a progress bar.
# `batch` and `ret` keep the values from the final iteration.
cuda0 = torch.device("cuda:0")
for batch in tqdm.tqdm(train_loader):
    batch = {name: value.to(cuda0) for name, value in batch.items()}
    ret = model(batch)

100%|██████████████████████████████████████████████████████████████████████████████| 1359/1359 [03:04<00:00,  7.36it/s]


# Debug model

In [13]:
# Check the dtype of model.proj_prefix; it is cast to float16 before use further down.
model.proj_prefix.dtype

torch.float32

In [14]:
# Print the module tree of the cross-modal transformer.
model.cross_modal_transformer

Transformer(
  (resblocks): Sequential(
    (0): ResidualAttentionBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (c_fc): Linear(in_features=512, out_features=2048, bias=True)
        (gelu): QuickGELU()
        (c_proj): Linear(in_features=2048, out_features=512, bias=True)
      )
      (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
)

In [26]:
# Run the base model by hand on the current batch: images are cast to float16,
# caption ids are passed through unchanged.
images, caption_ids = batch['images'], batch['caption_ids']
image_feats, text_feats = model.base_model(
    images.to(dtype=torch.float16),
    caption_ids,
)

In [27]:
# Inspect the image features returned by the base model.
# NOTE(review): this prints the whole tensor; image_feats.shape would be a more compact check.
image_feats

tensor([[[ 0.4673, -0.2153,  0.1763,  ..., -0.1285,  0.2556,  0.0911],
         [-0.2076, -0.4395, -0.1622,  ..., -0.1357, -0.4082, -0.0923],
         [-0.1691, -0.2532,  0.1851,  ...,  0.1747,  0.0487, -0.1838],
         ...,
         [-0.2209, -0.1714,  0.1213,  ...,  0.0792, -0.1569,  0.0273],
         [ 0.0170, -0.1510, -0.0058,  ...,  0.0731, -0.0576, -0.2205],
         [-0.2681, -0.4751,  0.0470,  ..., -0.0914, -0.0970, -0.1152]],

        [[ 0.3862, -0.2991, -0.2688,  ..., -0.0126,  0.3752,  0.1027],
         [-0.0489,  0.1192, -0.0728,  ...,  0.3779,  0.2159, -0.2103],
         [-0.1732,  0.0795,  0.0630,  ...,  0.0494,  0.0748, -0.2546],
         ...,
         [-0.3430,  0.0327, -0.3711,  ...,  0.1998, -0.3416,  0.0274],
         [-0.2734, -0.0792, -0.1304,  ...,  0.1578, -0.2462, -0.0635],
         [-0.2852, -0.2664, -0.1409,  ...,  0.1362, -0.1469,  0.0829]],

        [[ 0.3259,  0.0110, -0.0614,  ..., -0.1625,  0.3955,  0.2661],
         [-0.6494, -0.2656, -0.1009,  ...,  0

In [28]:
# Keep a handle on the model's logit scale (not referenced again in the cells shown here).
logit_scale = model.logit_scale

In [31]:
# Batch size = leading dimension of the image features.
bs = image_feats.shape[0]

In [33]:
# Shape check: add a leading axis to proj_prefix and tile it bs times (one copy per sample).
model.proj_prefix.unsqueeze(0).repeat(bs,1,1).shape

torch.Size([6, 16, 512])

In [35]:
# Feed the tiled prefix tokens (cast to float16) as the first input of cross_former and the
# image features as both the second and third inputs.
# presumably (query, key, value) cross-attention -- confirm against cross_former's signature
x = model.cross_former(model.proj_prefix.unsqueeze(0).repeat(bs,1,1).to(torch.float16), image_feats, image_feats)

In [39]:
# Decode the cross_former output with is_casual=True.
# NOTE(review): "is_casual" is the keyword as proj_dec spells it; presumably it toggles
# causal masking -- confirm.
x_casual = model.proj_dec(
    inputs_embeds = x,
    is_casual=True
)

In [42]:
# Same decoder call with is_casual=False; the two outputs are averaged further down.
x_attn = model.proj_dec(
    inputs_embeds = x,
    is_casual=False
)

In [44]:
# Take the 'pair_img' entry of the batch; it is encoded below to build the regression target.
y_pair = batch['pair_img']

In [46]:
# Encode the paired images with the base model's image encoder.
y_pair_feats = model.base_model.encode_image(y_pair)

In [53]:
# Inspect the encoded pair features.
y_pair_feats

tensor([[[ 1.7297e-01,  4.2297e-02,  2.8534e-02,  ...,  1.3220e-01,
          -2.5520e-03,  4.9561e-01],
         [-7.7979e-01,  5.8105e-02, -2.1229e-03,  ...,  1.4075e-01,
          -1.6650e-01,  4.7943e-02],
         [-4.1553e-01,  1.3245e-01, -3.6316e-02,  ..., -5.8899e-02,
           2.2614e-02,  8.0750e-02],
         ...,
         [-1.2769e-01,  3.2251e-01,  4.8950e-01,  ...,  2.3340e-01,
          -2.2598e-02,  3.8379e-01],
         [-3.2666e-01,  2.5928e-01,  1.9690e-01,  ...,  2.0166e-01,
           1.2421e-01,  3.1445e-01],
         [-5.1465e-01, -1.6382e-01,  3.3752e-02,  ...,  7.4341e-02,
          -1.2671e-01,  4.0710e-02]],

        [[ 6.8555e-01, -4.9927e-01, -1.7200e-01,  ...,  2.3633e-01,
           2.9858e-01,  2.8516e-01],
         [ 1.2427e-01,  6.7406e-03, -6.8909e-02,  ...,  3.9697e-01,
          -3.3691e-01, -5.1178e-02],
         [ 1.8713e-01,  9.4971e-02, -4.7925e-01,  ...,  3.2422e-01,
           7.3730e-02, -6.9397e-02],
         ...,
         [ 7.0251e-02, -2

In [48]:
# Token 0 along the sequence axis for every sample.
# presumably the class/global token -- confirm against the image encoder
i_y_feats = y_pair_feats[:, 0, :]

In [49]:
# Shape check for the token-0 features.
i_y_feats.shape

torch.Size([6, 512])

In [56]:
# Shape of the remaining tokens (everything after index 0 on the sequence axis).
y_pair_feats[:,1:,:].shape

torch.Size([6, 192, 512])

In [55]:
# Shape after inserting a length-1 token axis, so it broadcasts against the remaining tokens.
i_y_feats.unsqueeze(1).shape

torch.Size([6, 1, 512])

In [61]:
# Cosine similarity between token 0 and every other token of the same image, taken over the
# last (feature) axis; the result has one score per remaining token.
sim_tkn = torch.nn.functional.cosine_similarity(i_y_feats.unsqueeze(1),y_pair_feats[:,1:,:],dim=-1)

In [62]:
# Sort the similarities per sample (ascending, as the values below show); idx holds the
# positions within the y_pair_feats[:, 1:] slice, not within the full token sequence.
vlu, idx = torch.sort(sim_tkn,dim=1)

In [63]:
# Sorted similarity values.
vlu

tensor([[0.4136, 0.4229, 0.4329,  ..., 0.6919, 0.7075, 0.7241],
        [0.3496, 0.3809, 0.3845,  ..., 0.6587, 0.6899, 0.7080],
        [0.3452, 0.3728, 0.3779,  ..., 0.6069, 0.6235, 0.6641],
        [0.3972, 0.4104, 0.4126,  ..., 0.6729, 0.6787, 0.7007],
        [0.3535, 0.3582, 0.3669,  ..., 0.5757, 0.5781, 0.5820],
        [0.3777, 0.3828, 0.4294,  ..., 0.7197, 0.7251, 0.7334]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SortBackward0>)

In [64]:
# Token positions matching the sorted values above.
idx

tensor([[164, 147, 122,  ...,   3, 176,  79],
        [ 33,  30,  80,  ..., 126, 169, 109],
        [116,  45,  43,  ..., 185, 159,  94],
        [ 75,  77,  78,  ..., 144, 116, 105],
        [ 34, 124,  75,  ..., 154, 155, 152],
        [126, 155, 157,  ...,  30,  86,  78]], device='cuda:0')

In [68]:
# Keep the last proj_token_num columns of the ascending sort, i.e. the positions of the
# most similar tokens for each sample.
y_idx = idx[:,-model.proj_token_num:]

In [82]:
# Row i is filled with the sample index i, repeated proj_token_num times: a (bs, proj_token_num)
# grid of batch indices that lines up with y_idx. Note that this rebinds `x`.
x = torch.arange(bs)[:, None].repeat(1, model.proj_token_num)

In [84]:
# Shape check for the batch-index grid.
x.shape

torch.Size([6, 16])

In [85]:
# Flattened batch indices: each sample index appears proj_token_num times in a row.
x.reshape(-1)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])

In [87]:
# idx was computed on y_pair_feats[:, 1:], so add 1 to convert the positions back to
# indices into the full token sequence.
y_idx.reshape(-1)+1

tensor([184, 136, 145, 151, 117, 128, 192, 168, 160,  16, 144, 178, 146,   4,
        177,  80,  19, 122, 151, 129, 128, 143, 104, 119,  92,  96, 105, 184,
        112, 127, 170, 110, 176, 152, 185,  80,  33, 144, 103, 188,  87,  77,
         88,  71, 120, 186, 160,  95, 136,  31, 100,  74, 155,   8, 107,  99,
        108, 110, 119, 132,  15, 145, 117, 106, 187, 110, 186, 188, 138, 146,
        190, 130, 164, 106, 145,  22, 128, 155, 156, 153,  46,  80,  77,  83,
         88,  89,  11,  30,  82, 154,  90,  84, 169,  31,  87,  79],
       device='cuda:0')

In [92]:
# Index with (sample index, token position) pairs to pick the selected tokens, one row per
# pair; these rows serve as the target in the loss below.
y_gt = y_pair_feats[x.reshape(-1),y_idx.reshape(-1)+1]

In [98]:
# Average the first output of the two decoder passes, then flatten to one row per
# (sample, projection token) so it lines up with y_gt.
dec_x = (x_casual[0] + x_attn[0]) * 0.5
dec_x = dec_x.reshape(bs * model.proj_token_num, -1)

In [101]:
# Sum of the L1 and MSE losses between the decoded tokens and the selected target tokens.
torch.nn.functional.l1_loss(dec_x,y_gt) + torch.nn.functional.mse_loss(dec_x,y_gt)

tensor(1.8516, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)

In [19]:
# Scratch check: concatenating ten (4, 64) tensors along dim 0 stacks the rows.
torch.cat(tuple(torch.randn(4, 64) for _ in range(10)), dim=0).shape

torch.Size([40, 64])

In [20]:
# Scratch check: concatenating ten (4, 64) tensors along dim 1 widens the columns.
torch.cat(tuple(torch.randn(4, 64) for _ in range(10)), dim=1).shape

torch.Size([4, 640])