In [1]:
# !git clone https://github.com/hwjiang1510/VQLoC.git

In [2]:
import os
# os.chdir('/kaggle/working/VQLoC')
os.getcwd()

'/kaggle/working'

In [4]:
import os
import pprint
import random
import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import itertools
import argparse
import kornia
import kornia.augmentation as K
from kornia.constants import DataKey
from einops import rearrange


NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

def process_data(config, sample, iter=0, split='train', device='cuda'):
    '''
    sample: 
        'clip': clip,                           # [B,T,3,H,W]
        'clip_with_bbox': clip_with_bbox,       # [B,T], binary value 0 / 1
        'clip_bbox': clip_bbox,                 # [B,T,4]
        'query': query                          # [B,3,H2,W2]
    '''    
    B, T, _, H, W = sample['clip'].shape
    B, _, H2, W2 = sample['query'].shape
    normalization = kornia.enhance.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)

    brightness = config.train.aug_brightness
    contrast = config.train.aug_contrast
    saturation = config.train.aug_saturation
    query_size = config.dataset.query_size
    crop_sacle = config.train.aug_crop_scale
    crop_ratio_min = config.train.aug_crop_ratio_min
    crop_ratio_max = config.train.aug_crop_ratio_max
    affine_degree = config.train.aug_affine_degree
    affine_translate = config.train.aug_affine_translate
    affine_scale_min = config.train.aug_affine_scale_min
    affine_scale_max = config.train.aug_affine_scale_max
    affine_shear_min = config.train.aug_affine_shear_min
    affine_shear_max = config.train.aug_affine_shear_max
    prob_color = config.train.aug_prob_color
    prob_flip = config.train.aug_prob_flip
    prob_crop = config.train.aug_prob_crop
    prob_affine = config.train.aug_prob_affine

    transform_clip = K.AugmentationSequential(
                K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0, p=1.0),
                K.RandomHorizontalFlip(p=0.5),
                K.RandomResizedCrop((H, W), scale=(0.66, 1.0), ratio=(crop_ratio_min, crop_ratio_max), p=1.0),
                # K.RandomAffine(affine_degree, [affine_translate, affine_translate], [affine_scale_min, affine_scale_max], 
                #                 [affine_shear_min, affine_shear_max], p=prob_affine),
                data_keys=[DataKey.INPUT, DataKey.BBOX_XYXY],  # Just to define the future input here.
                same_on_batch=True,
                )
    transform_query = K.AugmentationSequential(
                K.ColorJitter(brightness, contrast, saturation, hue=0, p=prob_color),
                K.RandomHorizontalFlip(p=prob_flip),
                K.RandomResizedCrop((query_size, query_size), scale=(crop_sacle, 1.0), ratio=(crop_ratio_min, crop_ratio_max), p=prob_crop),
                # K.RandomAffine(affine_degree, [affine_translate, affine_translate], [affine_scale_min, affine_scale_max], 
                #                 [affine_shear_min, affine_shear_max], p=prob_affine
                # K.RandomAffine(affine_degree, [0, 0], [1.0, 1.0], 
                #                 [1.0, 1.0], p=prob_affine),
                data_keys=["input"],  # Just to define the future input here.
                same_on_batch=False,
                )
    
    transform_query_frame = K.AugmentationSequential(
                K.ColorJitter(brightness, contrast, saturation, hue=0, p=prob_color),
                K.RandomHorizontalFlip(p=prob_flip),
                # K.RandomAffine(affine_degree, [affine_translate, affine_translate], [affine_scale_min, affine_scale_max], 
                #                 [affine_shear_min, affine_shear_max], p=prob_affine
                # K.RandomAffine(affine_degree, [0, 0], [1.0, 1.0], 
                #                 [1.0, 1.0], p=prob_affine),
                data_keys=[DataKey.INPUT, DataKey.BBOX_XYXY],  # Just to define the future input here.
                same_on_batch=False,
                )
    
    clip = sample['clip']                           # [B,T,C,H,W]
    query = sample['query']                         # [B,C,H',W']
    clip_with_bbox = sample['clip_with_bbox']       # [B,T]
    clip_bbox = sample['clip_bbox']                 # [B,T,4], with value range [0,1], torch axis
    clip_bbox = recover_bbox(clip_bbox, H, W)       # [B,T,4], with range in image pixels, torch axis
    clip_bbox = bbox_torchTocv2(clip_bbox)          # [B,T,4], with range in image pixels, cv2 axis
    if config.train.use_query_roi and 'query_frame' in sample.keys():
        query_frame = sample['query_frame']                         # [B,C,H,W]
        query_frame_bbox = sample['query_frame_bbox']   
        query_frame_bbox = recover_bbox(query_frame_bbox, H, W)
        query_frame_bbox = bbox_torchTocv2(query_frame_bbox)        # [B,4]

    # augment clips
    if split == 'train' and config.train.aug_clip and (iter > config.train.aug_clip_iter):        
        clip_aug, clip_bbox_aug  = [], []
        for clip_cur, clip_bbox_cur in zip(clip, clip_bbox):    # [T,C,H,W], [T,4,2]
            clip_cur_aug, clip_bbox_cur_aug = transform_clip(clip_cur.to(device), clip_bbox_cur.to(device).unsqueeze(1))
            clip_aug.append(clip_cur_aug)
            clip_bbox_aug.append(clip_bbox_cur_aug.squeeze())

        clip_aug = torch.stack(clip_aug)                     # [B,T,C,H,W]
        clip_bbox_aug = torch.stack(clip_bbox_aug)           # [B,T,4]
        clip_bbox_aug = bbox_cv2Totorch(clip_bbox_aug)
        clip_bbox_aug, with_bbox_update = check_bbox(clip_bbox_aug, H, W)
        clip_bbox_aug = normalize_bbox(clip_bbox_aug, H, W)                 # back in range [0,1]
        clip_with_bbox_aug = torch.logical_and(with_bbox_update.to(clip_with_bbox.device), clip_with_bbox)
        sample['clip'] = clip_aug.to(device)
        sample['clip_with_bbox'] = clip_with_bbox_aug.to(device).float()
        sample['clip_bbox'] = clip_bbox_aug.to(device)
    
    # augment the query
    if split == 'train' and config.train.aug_query:
        query = transform_query(query)
        sample['query'] = query.to(device)
    
    # augment the query frame
    if split == 'train' and config.train.aug_query and 'query_frame' in sample.keys():
        query_frame, query_frame_bbox = transform_query_frame(query)
        sample['query_frame'] = query_frame.to(device)
        query_frame_bbox = bbox_cv2Totorch(query_frame_bbox)
        query_frame_bbox = normalize_bbox(query_frame_bbox, H, W).clamp(min=0.0, max=1.0)
        sample['query_frame_bbox'] = query_frame_bbox.to(device).float()

    # normalize the input clips
    sample['clip_origin'] = sample['clip'].clone()
    clip = rearrange(sample['clip'], 'b t c h w -> (b t) c h w').to(device)
    clip = normalization(clip)
    sample['clip'] = rearrange(clip, '(b t) c h w -> b t c h w', b=B, t=T)

    # normalize input query
    sample['query_origin'] = sample['query'].clone()
    sample['query'] = normalization(sample['query'])

    # normalize input query frame
    if 'query_frame' in sample.keys():
        sample['query_frame_origin'] = sample['query_frame'].clone()
        sample['query_frame'] = normalization(sample['query_frame'])

    return sample


def replicate_sample_for_hnm(gts):
    '''
        gts = {
            'clip':                 in [b,t,c,h,w]
            'clip_rigin':           in [b,t,c,h,w]
            'clip_with_bbox':       in [b,t]
            'before_query':         in [b,t]
            'clip_bbox':            in [b,t,4]
            'query':                in [b,c,h,w]
            'query_origin':         in [b,c,h,w]
            'clip_h':               in [b]
            'clip_w':               in [b]
        }
    '''
    clip = gts['clip']
    clip_origin = gts['clip_origin']
    clip_with_bbox = gts['clip_with_bbox']
    before_query = gts['before_query']
    clip_bbox = gts['clip_bbox']
    query = gts['query']
    query_origin = gts['query_origin']
    clip_h, clip_w = gts['clip_h'], gts['clip_w']

    b, t = clip.shape[:2]
    device = clip.device

    new_clip = []
    new_clip_origin = []
    new_clip_with_bbox = []
    new_before_query = []
    new_clip_bbox = []
    new_query = []
    new_query_origin = []
    new_clip_h, new_clip_w = [], []

    for i in range(b):
        for j in range(b):
            new_clip.append(clip[i])
            new_clip_origin.append(clip_origin[i])
            new_query.append(query[j])
            new_query_origin.append(query_origin[j])
            if i == j:
                new_clip_with_bbox.append(clip_with_bbox[i])
                new_before_query.append(before_query[i])
                new_clip_bbox.append(clip_bbox[i])
            else:
                new_clip_with_bbox.append(torch.zeros(t).float().to(device))
                new_before_query.append(torch.ones(t).bool().to(device))
                new_clip_bbox.append(torch.tensor([[0.0, 0.0, 0.0001, 0.0001]]).repeat(t,1).float().to(device))
            new_clip_h.append(clip_h[i])
            new_clip_w.append(clip_w[i])
    
    new_clip = torch.stack(new_clip)
    new_clip_origin = torch.stack(new_clip_origin)
    new_clip_with_bbox = torch.stack(new_clip_with_bbox)
    new_before_query = torch.stack(new_before_query)
    new_clip_bbox = torch.stack(new_clip_bbox)
    new_clip_h = torch.stack(new_clip_h)
    new_clip_w = torch.stack(new_clip_w)
    new_query = torch.stack(new_query)
    new_query_origin = torch.stack(new_query_origin)

    new_gts = {
            'clip': new_clip,                       # in [b^2,t,c,h,w]
            'clip_origin': new_clip_origin,         # in [b^2,t,c,h,w]
            'clip_with_bbox': new_clip_with_bbox,   # in [b^2,t]
            'before_query': new_before_query,       # in [b^2,t]
            'clip_bbox': new_clip_bbox,             # in [b^2,t,4]
            'query': new_query,                     # in [b^2,c,h,w]
            'query_origin': new_query_origin,       # in [b^2,c,h,w]
            'clip_h': new_clip_h,                   # in [b^2]
            'clip_w': new_clip_w,                   # in [b^2]
        }
    return new_gts


def normalize_bbox(bbox, h, w):
    '''
    bbox torch tensor in shape [4] or [...,4], under torch axis
    '''
    bbox_cp = bbox.clone()
    if len(bbox.shape) > 1: # [N,4]
        bbox_cp[..., 0] /= h
        bbox_cp[..., 1] /= w
        bbox_cp[..., 2] /= h
        bbox_cp[..., 3] /= w
        return bbox_cp
    else:
        return torch.tensor([bbox_cp[0]/h, bbox_cp[1]/w, bbox_cp[2]/h, bbox_cp[3]/w])


def recover_bbox(bbox, h, w):
    '''
    bbox torch tensor in shape [4] or [...,4], under torch axis
    '''
    bbox_cp = bbox.clone()
    if len(bbox.shape) > 1: # [N,4]
        bbox_cp[..., 0] *= h
        bbox_cp[..., 1] *= w
        bbox_cp[..., 2] *= h
        bbox_cp[..., 3] *= w
        return bbox_cp
    else:
        return torch.tensor([bbox_cp[0]*h, bbox_cp[1]*w, bbox_cp[2]*h, bbox_cp[3]*w])
    

def bbox_torchTocv2(bbox):
    '''
    torch, idx 0/2 for height, 1/3 for width (x,y,x,y)
    cv2: idx 0/2 for width, 1/3 for height (y,x,y,x)
    bbox torch tensor in shape [4] or [...,4], under torch axis
    '''
    bbox_cp = bbox.clone()
    if len(bbox.shape) > 1:
        bbox_x1 = bbox_cp[...,0].unsqueeze(-1)
        bbox_y1 = bbox_cp[...,1].unsqueeze(-1)
        bbox_x2 = bbox_cp[...,2].unsqueeze(-1)
        bbox_y2 = bbox_cp[...,3].unsqueeze(-1)
        return torch.cat([bbox_y1, bbox_x1, bbox_y2, bbox_x2], dim=-1)
    else:
        return torch.tensor([bbox_cp[1], bbox_cp[0], bbox_cp[3], bbox_cp[2]])
    

def bbox_cv2Totorch(bbox):
    '''
    torch, idx 0/2 for height, 1/3 for width (x,y,x,y)
    cv2: idx 0/2 for width, 1/3 for height (y,x,y,x)
    bbox torch tensor in shape [4] or [...,4], under cv2 axis
    '''
    bbox_cp = bbox.clone()
    if len(bbox.shape) > 1:
        bbox_x1 = bbox_cp[...,1].unsqueeze(-1)
        bbox_y1 = bbox_cp[...,0].unsqueeze(-1)
        bbox_x2 = bbox_cp[...,3].unsqueeze(-1)
        bbox_y2 = bbox_cp[...,2].unsqueeze(-1)
        return torch.cat([bbox_x1, bbox_y1, bbox_x2, bbox_y2], dim=-1)
    else:
        return torch.tensor([bbox_cp[1], bbox_cp[0], bbox_cp[3], bbox_cp[2]])


def check_bbox(bbox, h, w):
    B, T, _ = bbox.shape
    bbox = bbox.reshape(-1,4)

    x1, y1, x2, y2 = bbox[...,0], bbox[...,1], bbox[...,2], bbox[...,3]
    left_invalid = y2 <= 0.0
    right_invalid = y1 >= w - 1
    top_invalid = x2 <= 0.0
    bottom_invalid = x1 >= h - 1

    x_invalid = torch.logical_or(top_invalid, bottom_invalid)
    y_invalid = torch.logical_or(left_invalid, right_invalid)
    invalid = torch.logical_or(x_invalid, y_invalid)
    valid = ~invalid

    x1_clip = x1.clip(min=0.0, max=h).unsqueeze(-1)
    x2_clip = x2.clip(min=0.0, max=h).unsqueeze(-1)
    y1_clip = y1.clip(min=0.0, max=w).unsqueeze(-1)
    y2_clip = y2.clip(min=0.0, max=w).unsqueeze(-1)
    bbox_clip = torch.cat([x1_clip, y1_clip, x2_clip, y2_clip], dim=-1)

    return bbox_clip.reshape(B,T,4), valid.reshape(B,T)


def check_bbox_permute(bbox_p):
    '''
    bbox_p: [N,4], (x1,y1,x2,y2)
    '''
    x1p = torch.minimum(bbox_p[:, 0], bbox_p[:, 2]).reshape(-1,1)
    x2p = torch.maximum(bbox_p[:, 0], bbox_p[:, 2]).reshape(-1,1)
    y1p = torch.minimum(bbox_p[:, 1], bbox_p[:, 3]).reshape(-1,1)
    y2p = torch.maximum(bbox_p[:, 1], bbox_p[:, 3]).reshape(-1,1)
    bbox_p = torch.cat([x1p, y1p, x2p, y2p], axis=1)
    return bbox_p


def bbox_xyxyTopoints(bbox):
    '''
    bbox: torch.Tensor, in shape [..., 4]
    return: bbox in shape [...,4,2] with 4 points location
    p1---p2
    |     |
    p4---p3
    '''
    bbox_x1 = bbox[...,0].unsqueeze(-1)     # [...,1]
    bbox_y1 = bbox[...,1].unsqueeze(-1)
    bbox_x2 = bbox[...,2].unsqueeze(-1)
    bbox_y2 = bbox[...,3].unsqueeze(-1)

    pt1 = torch.cat([bbox_x1, bbox_y1], dim=-1).unsqueeze(-2)     # [...,1,2]
    pt2 = torch.cat([bbox_x2, bbox_y1], dim=-1).unsqueeze(-2)     # [...,1,2]
    pt3 = torch.cat([bbox_x2, bbox_y2], dim=-1).unsqueeze(-2)     # [...,1,2]
    pt4 = torch.cat([bbox_x1, bbox_y2], dim=-1).unsqueeze(-2)     # [...,1,2]

    pts = torch.cat([pt1, pt2, pt3, pt4], dim=-2)                 # [...,4,2]
    return pts


def bbox_pointsToxyxy(pts):
    '''
    pts: torch.Tensor, in shape [...,4,2]
    return: bbox in shape [...,4] for x1y1x2y2
    '''
    shape_in = list(pts.shape[:-2])
    pts = pts.reshape(-1,4,2)

    pt1 = pts[:,0,:]           # [N,2]
    pt3 = pts[:,3,:]

    x1 = pt1[:, 0].unsqueeze(-1)  # [N,1]
    y1 = pt1[:, 1].unsqueeze(-1)  
    x2 = pt3[:, 0].unsqueeze(-1)  
    y2 = pt3[:, 1].unsqueeze(-1)

    bbox = torch.cat([x1,y1,x2,y2], dim=-1)     # [N,4]
    bbox = bbox.reshape(shape_in + [4])
    return bbox


def create_square_bbox(bbox, img_h, img_w):
    '''
    bbox in [4], in torch coordinate
    '''
    x1, y1, x2, y2 = bbox
    center_x = (x1 + x2) / 2
    center_y = (y1 + y2) / 2
    h = center_x - x1
    w = center_y - y1
    r = max(h, w)

    new_x1 = max(center_x - r, 0)
    new_x2 = min(center_x + r, img_h-1)
    new_y1 = max(center_y - r, 0)
    new_y2 = min(center_y + r, img_w-1)

    new_bbox = torch.tensor([new_x1, new_y1, new_x2, new_y2])
    return new_bbox


def bbox_xyhwToxyxy(bbox_xyhw):
    '''
    bbox_xyhw in shape [..., 4]
    height and width of bbox is the full height and width
    '''
    bbox_center = bbox_xyhw[..., :2]
    bbox_hw = bbox_xyhw[..., 2:]
    bbox_hw_half = 0.5 * bbox_hw

    bbox_xyxy = torch.cat([bbox_center - bbox_hw_half, bbox_center + bbox_hw_half], dim=-1)
    return bbox_xyxy

In [5]:
import numpy as np
import torch
import torch.nn as nn
import math
from einops import rearrange


class PositionalEncoding1D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 2) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 3d tensor of size (batch_size, x, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, ch)
        """
        if len(tensor.shape) != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_x = get_emb(sin_inp_x)
        emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type())
        emb[:, : self.channels] = emb_x

        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc
    

def positionalencoding1d(d_model, length):
    """
    positional encoding for 1-d sequence
    :param d_model: dimension of the model (C)
    :param length: length of positions (N)
    :return: length*d_model position matrix, shape [N, C]
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                         -(math.log(10000.0) / d_model)))       # [N,C//2]
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    return pe


def positionalencoding2d(d_model, height, width, type='sinusoidal'):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix, shape [H*W, C]
    """
    if type == 'sinusoidal':
        pe = torch.zeros(d_model, height, width)
        # Each dimension use half of d_model
        d_model_origin = d_model
        d_model = int(d_model / 2)
        div_term = torch.exp(torch.arange(0., d_model, 2) *
                            -(math.log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(1).repeat(1, width, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(1).repeat(1, width, 1)
        pe[d_model::2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, height)
        pe[d_model + 1::2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, height)
        pe = rearrange(pe, 'c h w -> (h w) c')[:,:d_model_origin]
    elif type == 'zero':
        pe = torch.zeros(height * width, d_model)
    return pe


def positionalencoding3d(d_model, height, width, depth, type='sinusoidal'):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :param depth: depth of the positions
    :return: d_model*height*width position matrix, shape [H*W, C]
    """
    if type == 'sinusoidal':
        d_model_interv = int(np.ceil(d_model / 6) * 2)
        if d_model_interv % 2:
            d_model_interv += 1
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model_interv, 2).float() / d_model_interv))
        pos_x = torch.arange(height).type(inv_freq.type())
        pos_y = torch.arange(width).type(inv_freq.type())
        pos_z = torch.arange(depth).type(inv_freq.type())
        sin_inp_x = torch.einsum("i,j->ij", pos_x, inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, inv_freq)
        sin_inp_z = torch.einsum("i,j->ij", pos_z, inv_freq)
        emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1)
        emb_y = get_emb(sin_inp_y).unsqueeze(1)
        emb_z = get_emb(sin_inp_z)
        emb = torch.zeros(height, width, depth, d_model_interv * 3)
        emb[:, :, :, : d_model_interv] = emb_x
        emb[:, :, :, d_model_interv : 2 * d_model_interv] = emb_y
        emb[:, :, :, 2 * d_model_interv :] = emb_z
        emb = rearrange(emb, 'h w d c -> (h w d) c')[:,:d_model]
    elif type == 'zero':
        emb = torch.zeros(height * width * depth, d_model)
    return emb


def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)


def BasicBlock_Conv2D(in_dim, out_dim):
    module = nn.Sequential(
                    nn.Conv2d(in_dim, out_dim, 3, padding=1),
                    nn.BatchNorm2d(out_dim),
                    nn.LeakyReLU(inplace=True)
                    )
    return module

def BasicBlock_MLP(dims):
    dims_ = dims[:-1]
    dims1, dims2 = dims_[:-1], dims_[1:]
    mlp = []
    for (dim1, dim2) in zip(dims1, dims2):
        mlp.append(
            nn.Sequential(
                nn.Linear(dim1, dim2),
                nn.BatchNorm1d(dim1),
                nn.LeakyReLU(inplace=True),
        ))
    mlp.append(
        nn.Sequential(
            nn.Linear(dims[-2], dims[-1]),
        ))
    mlp = nn.Sequential(*mlp)
    return mlp

In [6]:
import torch
from einops import rearrange


def generate_anchor_boxes_on_regions(image_size, 
                                     num_regions, 
                                     base_sizes=torch.tensor([[16, 16], [32, 32], [64, 64], [128, 128]], dtype=torch.float32),
                                     aspect_ratios=torch.tensor([0.5, 1, 2], dtype=torch.float32),
                                     dtype=torch.float32, 
                                     device='cpu'):
    """
    Generate a set of anchor boxes with different sizes and aspect ratios for each region of a split image.

    Arguments:
    image_size -- tuple of two integers, the height and width of the original image
    num_regions -- tuple of two integers, the number of regions in the height and width directions
    aspect_ratios -- torch.Tensor of shape [M], containing M aspect ratios for each base size
    dtype -- the data type of the output tensor
    device -- the device of the output tensor

    Returns:
    anchor_boxes -- torch.Tensor of shape [R^2*N*M,4], containing R^2*N*M anchor boxes represented as (center_h, center_w, box_h, box_w)
    """

    # Calculate the base sizes for each region
    region_size = (image_size[0] / num_regions[0], image_size[1] / num_regions[1])

    # Calculate the anchor boxes for each region
    anchor_boxes = torch.empty((0, 4), dtype=dtype, device=device)
    for i in range(num_regions[0]):
        for j in range(num_regions[1]):
            center_h = (i + 0.5) * region_size[0]
            center_w = (j + 0.5) * region_size[1]
            base_boxes = generate_anchor_boxes(base_sizes, aspect_ratios, dtype=dtype, device=device)
            base_boxes[:, 0] += center_h
            base_boxes[:, 1] += center_w
            anchor_boxes = torch.cat([anchor_boxes, base_boxes], dim=0)

    return anchor_boxes


def generate_anchor_boxes(base_sizes, aspect_ratios, dtype=torch.float32, device='cpu'):
    """
    Generate a set of anchor boxes with different sizes and aspect ratios.

    Arguments:
    base_sizes -- torch.Tensor of shape [N,2], containing N base sizes for the anchor boxes
    aspect_ratios -- torch.Tensor of shape [M], containing M aspect ratios for each base size
    dtype -- the data type of the output tensor
    device -- the device of the output tensor

    Returns:
    anchor_boxes -- torch.Tensor of shape [N*M,4], containing N*M anchor boxes represented as (center_h, center_w, box_h, box_w)
    """

    num_base_sizes = base_sizes.shape[0]
    num_aspect_ratios = aspect_ratios.shape[0]

    # Generate base anchor boxes
    base_boxes = torch.zeros((num_base_sizes * num_aspect_ratios, 4), dtype=dtype, device=device)
    for i in range(num_base_sizes):
        for j in range(num_aspect_ratios):
            w = torch.sqrt(base_sizes[i, 0] * base_sizes[i, 1] / aspect_ratios[j])
            h = aspect_ratios[j] * w
            idx = i * num_aspect_ratios + j
            base_boxes[idx] = torch.tensor([0, 0, h, w], dtype=dtype, device=device)

    return base_boxes


# def assign_labels(proposals, gt_boxes, iou_threshold=0.5):
#     """
#     Assign labels to a set of bounding box proposals based on their IoU with ground truth boxes.

#     Arguments:
#     proposals -- torch.Tensor of shape [B,T,N,4], representing the bounding box proposals for each frame in each clip
#     gt_boxes -- torch.Tensor of shape [B,T,4], representing the ground truth boxes for each frame in each clip
#     iou_threshold -- float, the IoU threshold for a proposal to be considered a positive match with a ground truth box

#     Returns:
#     labels -- torch.Tensor of shape [B,T,N], containing the assigned labels for each proposal (0 for background, 1 for object)
#     """

#     # Initialize the labels tensor with background labels
#     labels = torch.zeros_like(proposals[:, :, :, 0], dtype=torch.long, device=proposals.device)

#     # Loop over the batches and frames
#     for b in range(proposals.shape[0]):
#         for t in range(proposals.shape[1]):
#             # Calculate the IoU between each proposal and the ground truth box
#             iou = calculate_iou(proposals[b, t], gt_boxes[b, t])    # [N]

#             # Assign labels to the proposals based on their IoU with the ground truth box
#             labels[b, t] = iou > iou_threshold

#     return labels


def assign_labels(anchors, gt_boxes, iou_threshold=0.5, topk=5):
    """
    Assign labels to a set of bounding box proposals based on their IoU with ground truth boxes.

    Arguments:
    anchors -- torch.Tensor of shape [B,T,N,4], representing the bounding box proposals for each frame in each clip
    gt_boxes -- torch.Tensor of shape [B,T,4], representing the ground truth boxes for each frame in each clip
    iou_threshold -- float, the IoU threshold for a proposal to be considered a positive match with a ground truth box

    Returns:
    labels -- torch.Tensor of shape [B,T,N], containing the assigned labels for each proposal (0 for background, 1 for object)
    """
    anchors = anchors.detach()
    gt_boxes = gt_boxes.detach()

    b,t = gt_boxes.shape[:2]    #[B,T,N,4]

    # Calculate the IoU between each proposal and the ground truth box
    iou = calculate_iou(anchors.view(-1, anchors.shape[-2], anchors.shape[-1]),   # [B*T,N,4]
                        gt_boxes.view(-1, gt_boxes.shape[-1]))                    # [B*T,4] -> [B*T,N]
    iou = iou.view(anchors.shape[:-1])    # [B,T,N]

    # Assign labels to the proposals based on their IoU with the ground truth box
    labels = iou > iou_threshold

    if not labels.any():
        labels = process_labels(labels, iou, topk)

    return labels


def calculate_iou(boxes1, boxes2):
    """
    Calculate the IoU between two sets of bounding boxes.

    Arguments:
    boxes1 -- torch.Tensor of shape [...,N,4], containing N bounding boxes represented as [x1, y1, x2, y2]
    boxes2 -- torch.Tensor of shape [...,4], containing a single ground truth box represented as [x1, y1, x2, y2]

    Returns:
    iou -- torch.Tensor of shape [...,N], containing the IoU between each box and the ground truth box
    """

    # Add a new dimension to boxes2 for broadcasting
    boxes2 = boxes2.unsqueeze(-2)    # shape: [...,1,4]

    # Compute the coordinates of the top-left and bottom-right corners of the boxes
    boxes1_tl = boxes1[..., :2]
    boxes1_br = boxes1[..., 2:]
    boxes2_tl = boxes2[..., :2]
    boxes2_br = boxes2[..., 2:]

    # Compute the coordinates of the intersection rectangle
    tl = torch.max(boxes1_tl, boxes2_tl)
    br = torch.min(boxes1_br, boxes2_br)

    # Compute the width and height of the intersection rectangle
    wh = br - tl
    wh[wh < 0] = 0

    # Compute the area of the intersection and union rectangles
    intersection_area = wh[..., 0] * wh[..., 1]
    area1 = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
    area2 = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])
    union_area = area1 + area2 - intersection_area

    # Compute the IoU between each box and the ground truth box
    iou = intersection_area / union_area

    return iou


def process_labels(labels, iou, topk=10):
    '''
    labels: in shape [B,T,N], bool
    iou: in shape [B,T,N]
    '''
    B,T,N = labels.shape

    labels = rearrange(labels, 'b t n -> (b t n)')
    iou = rearrange(iou, 'b t n -> (b t n)')

    if not labels.any():
        # no pos assigned, choose topk anchors with largest iou as positives
        _, topk_indices = torch.topk(iou, k=topk)
        labels[topk_indices] = True
    
    labels = rearrange(labels, '(b t n) -> b t n', b=B, t=T, n=N)
    return labels

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import numpy as np


class Block(nn.Module):
    def __init__(self, dim, num_heads=1, mlp_ratio=4., act_layer=nn.GELU, norm_layer=nn.LayerNorm, return_attn=False):
        super().__init__()

        self.channels = dim

        self.encode_query = nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
        self.encode_key = nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
        self.encode_value = nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
        self.norm = norm_layer(dim)

        self.attn = Attention(dim, num_heads=num_heads)
        
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos.to(tensor)
    
    def get_attn(self, query, key, query_embed=None, key_embed=None):
        b, c, n = query.shape

        q = self.with_pos_embed(query, query_embed)
        k = self.with_pos_embed(key, key_embed)

        q = self.norm(q.permute(0, 2, 1)).permute(0, 2, 1)
        k = self.norm(k.permute(0, 2, 1)).permute(0, 2, 1)

        q = self.encode_query(q).view(b, self.channels, -1)
        q = q.permute(0, 2, 1)
        k = self.encode_key(k).view(b, self.channels, -1)
        k = k.permute(0, 2, 1)
        return self.attn.get_attn(query=q, key=k)   # [b,n,n]
    
    def forward(self, query, key, query_embed=None, key_embed=None):
        b, c, n = query.shape

        q = self.with_pos_embed(query, query_embed)
        k = self.with_pos_embed(key, key_embed)

        q = self.norm(q.permute(0, 2, 1)).permute(0, 2, 1)
        k = self.norm(k.permute(0, 2, 1)).permute(0, 2, 1)

        v = self.encode_value(key).view(b, self.channels, -1)
        v = v.permute(0, 2, 1)

        q = self.encode_query(q).view(b, self.channels, -1)
        q = q.permute(0, 2, 1)

        k = self.encode_key(k).view(b, self.channels, -1)
        k = k.permute(0, 2, 1)

        query = query.view(b, self.channels, -1).permute(0, 2, 1)
        query = query + self.attn(query=q, key=k, value=v)

        query = query + self.mlp(self.norm2(query))
        query = query.permute(0, 2, 1).contiguous().view(b, self.channels, -1)

        return query


class Attention(nn.Module):
    def __init__(self, dim, num_heads=1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

    def get_attn(self, query, key):
        B, N, C = query.shape
        attn = torch.matmul(query, key.transpose(-2, -1)) #* self.scale
        attn = attn.softmax(dim=-1)
        #__import__('pdb').set_trace()
        return attn
    
    def forward(self, query, key, value):
        B, N, C = query.shape
        query = query.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        key = key.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        value = value.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        attn = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = torch.matmul(attn, value).transpose(1, 2).reshape(B, N, C)
        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self._init_weights()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

    def _init_weights(self):
        # nn.init.xavier_uniform_(self.fc1.weight)
        # nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.weight, mean=0.0, std=1e-3)
        nn.init.normal_(self.fc2.weight, mean=0.0, std=1e-3)
        nn.init.normal_(self.fc1.bias, std=1e-4)
        nn.init.normal_(self.fc2.bias, std=1e-4)

In [8]:
import torch
import torch.nn as nn
from functools import partial
import timm.models.vision_transformer


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for global average pooling
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        
        # [b,h*w+1,c]
        return x


def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

In [9]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math
import torchvision

base_sizes=torch.tensor([[16, 16], [32, 32], [64, 64], [128, 128]], dtype=torch.float32)    # 4 types of size
aspect_ratios=torch.tensor([0.5, 1, 2], dtype=torch.float32)                                # 3 types of aspect ratio
n_base_sizes = base_sizes.shape[0]
n_aspect_ratios = aspect_ratios.shape[0]


def build_backbone(config):
    name, type = config.model.backbone_name, config.model.backbone_type
    if name == 'dino':
        assert type in ['vitb8', 'vitb16', 'vits8', 'vits16']
        backbone = torch.hub.load('facebookresearch/dino:main', 'dino_{}'.format(type))
        down_rate = int(type.replace('vitb', '').replace('vits', ''))
        backbone_dim = 768
        if type == 'vitb16' and config.model.bakcbone_use_mae_weight:
            mae_weight = torch.load('/vision/hwjiang/episodic-memory/VQ2D/checkpoint/mae_pretrain_vit_base.pth')['model']
            backbone.load_state_dict(mae_weight)
    elif name == 'dinov2':
        assert type in ['vits14', 'vitb14', 'vitl14', 'vitg14']
        backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_{}'.format(type))
        down_rate = 14
        if type == 'vitb14':
            backbone_dim = 768
        elif type == 'vits14':
            backbone_dim = 384
    elif name == 'mae':
        backbone = vit_base_patch16()
        cpt = torch.load('/vision/hwjiang/download/model_weight/mae_pretrain_vit_base.pth')['model']
        backbone.load_state_dict(cpt, strict=False)
        down_rate = 16
        backbone_dim = 768
    return backbone, down_rate, backbone_dim


class ClipMatcher(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

        self.backbone, self.down_rate, self.backbone_dim = build_backbone(config)
        self.backbone_name = config.model.backbone_name

        self.query_size = config.dataset.query_size
        self.clip_size_fine = config.dataset.clip_size_fine
        self.clip_size_coarse = config.dataset.clip_size_coarse

        self.query_feat_size = self.query_size // self.down_rate
        self.clip_feat_size_fine = self.clip_size_fine // self.down_rate
        self.clip_feat_size_coarse = self.clip_size_coarse // self.down_rate

        self.type_transformer = config.model.type_transformer
        assert self.type_transformer in ['local', 'global']
        self.window_transformer = config.model.window_transformer
        self.resolution_transformer = config.model.resolution_transformer
        self.resolution_anchor_feat = config.model.resolution_anchor_feat

        self.anchors_xyhw = generate_anchor_boxes_on_regions(image_size=[self.clip_size_coarse, self.clip_size_coarse],
                                                        num_regions=[self.resolution_anchor_feat, self.resolution_anchor_feat])
        self.anchors_xyhw = self.anchors_xyhw / self.clip_size_coarse   # [R^2*N*M,4], value range [0,1], represented by [c_x,c_y,h,w] in torch axis
        self.anchors_xyxy = bbox_xyhwToxyxy(self.anchors_xyhw)

        # query down heads
        self.query_down_heads = []
        for _ in range(int(math.log2(self.query_feat_size))):
            self.query_down_heads.append(
                nn.Sequential(
                    nn.Conv2d(self.backbone_dim, self.backbone_dim, 3, stride=2, padding=1),
                    nn.BatchNorm2d(self.backbone_dim),
                    nn.LeakyReLU(inplace=True),
                )
            )
        self.query_down_heads = nn.ModuleList(self.query_down_heads)

        # feature reduce layer
        self.reduce = nn.Sequential(
            nn.Conv2d(self.backbone_dim, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(inplace=True),
        )
        
        # clip-query correspondence
        self.CQ_corr_transformer = []
        for _ in range(1):
            self.CQ_corr_transformer.append(
                torch.nn.TransformerDecoderLayer(
                    d_model=256,
                    nhead=4,
                    dim_feedforward=1024,
                    dropout=0.0,
                    activation='gelu',
                    batch_first=True
                )
            )
        self.CQ_corr_transformer = nn.ModuleList(self.CQ_corr_transformer)

        # feature downsample layers
        self.num_head_layers, self.down_heads = int(math.log2(self.clip_feat_size_coarse)), []
        for i in range(self.num_head_layers-1):
            self.in_channel = 256 if i != 0 else self.backbone_dim
            self.down_heads.append(
                nn.Sequential(
                nn.Conv2d(256, 256, 3, stride=2, padding=1),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(inplace=True),
            ))
        self.down_heads = nn.ModuleList(self.down_heads)

        # spatial-temporal PE
        self.pe_3d = positionalencoding3d(d_model=256, 
                                          height=self.resolution_transformer, 
                                          width=self.resolution_transformer, 
                                          depth=config.dataset.clip_num_frames,
                                          type=config.model.pe_transformer).unsqueeze(0)
        self.pe_3d = nn.parameter.Parameter(self.pe_3d)

        # spatial-temporal transformer layer
        self.feat_corr_transformer = []
        self.num_transformer = config.model.num_transformer
        for _ in range(self.num_transformer):
            self.feat_corr_transformer.append(
                    torch.nn.TransformerEncoderLayer(
                        d_model=256, 
                        nhead=8,
                        dim_feedforward=2048,
                        dropout=0.0,
                        activation='gelu',
                        batch_first=True
                ))
        self.feat_corr_transformer = nn.ModuleList(self.feat_corr_transformer)
        self.temporal_mask = None

        # output head
        self.head = Head(in_dim=256, in_res=self.resolution_transformer, out_res=self.resolution_anchor_feat)

    def init_weights_linear(self, m):
        if type(m) == nn.Linear:
            #nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.weight, mean=0.0, std=1e-6)
            nn.init.normal_(m.bias, mean=0.0, std=1e-6)

    def extract_feature(self, x, return_h_w=False):
        if self.backbone_name == 'dino':
            b, _, h_origin, w_origin = x.shape
            out = self.backbone.get_intermediate_layers(x, n=1)[0]
            out = out[:, 1:, :]  # we discard the [CLS] token   # [b, h*w, c]
            h, w = int(h_origin / self.backbone.patch_embed.patch_size), int(w_origin / self.backbone.patch_embed.patch_size)
            dim = out.shape[-1]
            out = out.reshape(b, h, w, dim).permute(0,3,1,2)
            if return_h_w:
                return out, h, w
            return out
        elif self.backbone_name == 'dinov2':
            b, _, h_origin, w_origin = x.shape
            out = self.backbone.get_intermediate_layers(x, n=1)[0]
            h, w = int(h_origin / self.backbone.patch_embed.patch_size[0]), int(w_origin / self.backbone.patch_embed.patch_size[1])
            dim = out.shape[-1]
            out = out.reshape(b, h, w, dim).permute(0,3,1,2)
            if return_h_w:
                return out, h, w
            return out
        elif self.backbone_name == 'mae':
            b, _, h_origin, w_origin = x.shape
            out = self.backbone.forward_features(x) # [b,1+h*w,c]
            h, w = int(h_origin / self.backbone.patch_embed.patch_size[0]), int(w_origin / self.backbone.patch_embed.patch_size[1])
            dim = out.shape[-1]
            out = out[:,1:].reshape(b, h, w, dim).permute(0,3,1,2)  # [b,c,h,w]
            out = F.interpolate(out, size=(16,16), mode='bilinear')
            if return_h_w:
                return out, h, w
            return out
        
        
    def replicate_for_hnm(self, query_feat, clip_feat):
        '''
        query_feat in shape [b,c,h,w]
        clip_feat in shape [b*t,c,h,w]
        '''
        b = query_feat.shape[0]
        bt = clip_feat.shape[0]
        t = bt // b
        
        clip_feat = rearrange(clip_feat, '(b t) c h w -> b t c h w', b=b, t=t)

        new_clip_feat, new_query_feat = [], []
        for i in range(b):
            for j in range(b):
                new_clip_feat.append(clip_feat[i])
                new_query_feat.append(query_feat[j])

        new_clip_feat = torch.stack(new_clip_feat)      # [b^2,t,c,h,w]
        new_query_feat = torch.stack(new_query_feat)    # [b^2,c,h,w]

        new_clip_feat = rearrange(new_clip_feat, 'b t c h w -> (b t) c h w')
        return new_clip_feat, new_query_feat


    def forward(self, clip, query, query_frame_bbox=None, training=False, fix_backbone=True):
        '''
        clip: in shape [b,t,c,h,w]
        query: in shape [b,c,h2,w2]
        '''
        b, t = clip.shape[:2]
        clip = rearrange(clip, 'b t c h w -> (b t) c h w')

        # get backbone features
        if fix_backbone:
            with torch.no_grad():
                query_feat = self.extract_feature(query)
                clip_feat = self.extract_feature(clip)
        else:
            query_feat = self.extract_feature(query)        # [b c h w]
            clip_feat = self.extract_feature(clip)          # (b t) c h w
        h, w = clip_feat.shape[-2:]

        if torch.is_tensor(query_frame_bbox) and self.config.train.use_query_roi:
            idx_tensor = torch.arange(b, device=clip.device).float().view(-1, 1)
            query_frame_bbox = recover_bbox(query_frame_bbox, h, w)
            roi_bbox = torch.cat([idx_tensor, query_frame_bbox], dim=1)
            query_feat = torchvision.ops.roi_align(query_feat, roi_bbox, (h,w))

        # reduce channel size
        all_feat = torch.cat([query_feat, clip_feat], dim=0)
        all_feat = self.reduce(all_feat)
        query_feat, clip_feat = all_feat.split([b, b*t], dim=0)

        if self.config.train.use_hnm and training:
            clip_feat, query_feat = self.replicate_for_hnm(query_feat, clip_feat)   # b -> b^2
            b = b**2
        
        # find spatial correspondence between query-frame
        query_feat = rearrange(query_feat.unsqueeze(1).repeat(1,t,1,1,1), 'b t c h w -> (b t) (h w) c')  # [b*t,n,c]
        clip_feat = rearrange(clip_feat, 'b c h w -> b (h w) c')                                         # [b*t,n,c]
        for layer in self.CQ_corr_transformer:
            clip_feat = layer(clip_feat, query_feat)                                                     # [b*t,n,c]
        clip_feat = rearrange(clip_feat, 'b (h w) c -> b c h w', h=h, w=w)                               # [b*t,c,h,w]

        # down-size features and find spatial-temporal correspondence
        for head in self.down_heads:
            clip_feat = head(clip_feat)
            if list(clip_feat.shape[-2:]) == [self.resolution_transformer]*2:
                clip_feat = rearrange(clip_feat, '(b t) c h w -> b (t h w) c', b=b) + self.pe_3d
                mask = self.get_mask(clip_feat, t)
                for layer in self.feat_corr_transformer:
                    clip_feat = layer(clip_feat, src_mask=mask)
                clip_feat = rearrange(clip_feat, 'b (t h w) c -> (b t) c h w', b=b, t=t, h=self.resolution_transformer, w=self.resolution_transformer)
                break
        
        # refine anchors
        anchors_xyhw = self.anchors_xyhw.to(clip_feat.device)                   # [N,4]
        anchors_xyxy = self.anchors_xyxy.to(clip_feat.device)                   # [N,4]
        anchors_xyhw = anchors_xyhw.reshape(1,1,-1,4)                           # [1,1,N,4]
        anchors_xyxy = anchors_xyxy.reshape(1,1,-1,4)                           # [1,1,N,4]
        
        bbox_refine, prob = self.head(clip_feat)                                # [b*t,N=h*w*n*m,c]
        bbox_refine = rearrange(bbox_refine, '(b t) N c -> b t N c', b=b, t=t)  # [b,t,N,4], in xyhw frormulation
        prob = rearrange(prob, '(b t) N c -> b t N c', b=b, t=t)                # [b,t,N,1]
        bbox_refine += anchors_xyhw                                             # [b,t,N,4]

        center, hw = bbox_refine.split([2,2], dim=-1)                           # represented by [c_x, c_y, h, w]
        hw = 0.5 * hw                                                           # anchor's hw is defined as real hw
        bbox = torch.cat([center - hw, center + hw], dim=-1)                    # [b,t,N,4]

        result = {
            'center': center,           # [b,t,N,2]
            'hw': hw,                   # [b,t,N,2]
            'bbox': bbox,               # [b,t,N,4]
            'prob': prob.squeeze(-1),   # [b,t,N]
            'anchor': anchors_xyxy      # [1,1,N,4]
        }
        return result
    

    def get_mask(self, src, t):
        if not torch.is_tensor(self.temporal_mask):
            hw = src.shape[1] // t
            thw = src.shape[1]
            mask = torch.ones(thw, thw).float() * float('-inf')

            window_size = self.window_transformer // 2

            for i in range(t):
                min_idx = max(0, (i-window_size)*hw)
                max_idx = min(thw, (i+window_size+1)*hw)
                mask[i*hw: (i+1)*hw, min_idx: max_idx] = 0.0
            mask = mask.to(src.device)
            self.temporal_mask = mask
        return self.temporal_mask
    


class Head(nn.Module):
    def __init__(self, in_dim=256, in_res=8, out_res=16, n=n_base_sizes, m=n_aspect_ratios):
        super(Head, self).__init__()

        self.in_dim = in_dim
        self.n = n
        self.m = m
        self.num_up_layers = int(math.log2(out_res // in_res))
        self.num_layers = 3
        
        if self.num_up_layers > 0:
            self.up_convs = []
            for _ in range(self.num_up_layers):
                self.up_convs.append(torch.nn.ConvTranspose2d(in_dim, in_dim, kernel_size=4, stride=2, padding=1))
            self.up_convs = nn.Sequential(*self.up_convs)

        self.in_conv = BasicBlock_Conv2D(in_dim=in_dim, out_dim=2*in_dim)

        self.regression_conv = []
        for i in range(self.num_layers):
            self.regression_conv.append(BasicBlock_Conv2D(in_dim, in_dim))
        self.regression_conv = nn.Sequential(*self.regression_conv)

        self.classification_conv = []
        for i in range(self.num_layers):
            self.classification_conv.append(BasicBlock_Conv2D(in_dim, in_dim))
        self.classification_conv = nn.Sequential(*self.classification_conv)

        self.droupout_feat = torch.nn.Dropout(p=0.2)
        self.droupout_cls = torch.nn.Dropout(p=0.2)

        self.regression_head = nn.Conv2d(in_dim, n * m * 4, kernel_size=3, padding=1)
        self.classification_head = nn.Conv2d(in_dim, n * m * 1, kernel_size=3, padding=1)

        self.regression_head.apply(self.init_weights_conv)
        self.classification_head.apply(self.init_weights_conv)

    def init_weights_conv(self, m):
        if type(m) == nn.Conv2d:
            nn.init.normal_(m.weight, mean=0.0, std=1e-6)
            nn.init.normal_(m.bias, mean=0.0, std=1e-6)

    def forward(self, x):
        '''
        x in shape [B,c,h=8,w=8]
        '''
        if self.num_up_layers > 0:
            x = self.up_convs(x)     # [B,c,h=16,w=16]

        B, c, h, w = x.shape

        feat_reg, feat_cls = self.in_conv(x).split([c, c], dim=1)   # both [B,c,h,w]
        # dpout pos 1, seems better
        feat_reg = self.droupout_feat(feat_reg)
        feat_cls = self.droupout_cls(feat_cls)

        feat_reg = self.regression_conv(feat_reg)        # [B,n*m*4,h,w]
        feat_cls = self.classification_conv(feat_cls)    # [B,n*m*1,h,w]

        # dpout pos 2

        out_reg = self.regression_head(feat_reg)
        out_cls = self.classification_head(feat_cls)

        out_reg = rearrange(out_reg, 'B (n m c) h w -> B (h w n m) c', h=h, w=w, n=self.n, m=self.m, c=4)
        out_cls = rearrange(out_cls, 'B (n m c) h w -> B (h w n m) c', h=h, w=w, n=self.n, m=self.m, c=1)

        return out_reg, out_cls

In [10]:
import yaml
import os
import numpy as np
from easydict import EasyDict as edict

config = edict()

# experiment config
config.exp_name = 'vq2d'
config.exp_group = 'baseline'
config.output_dir = './output/'
config.log_dir = './log'
config.workers = 8
config.print_freq = 100
config.vis_freq = 300
config.eval_vis_freq = 20
config.seed = 42
config.inference_cache_path = ''
config.debug = False

# dataset config
config.dataset = edict()
config.dataset.name = 'ego4d_vq2d'
config.dataset.name_val = 'ego4d_vq2d'
config.dataset.query_size = 448
config.dataset.clip_size_fine = 448
config.dataset.clip_size_coarse = 448
config.dataset.clip_num_frames = 30
config.dataset.clip_num_frames_val = 30
config.dataset.clip_sampling = 'rand'
config.dataset.clip_reader = 'decord_balance'
config.dataset.clip_reader_val = 'decord_balance'
config.dataset.frame_interval = 5
config.dataset.query_padding = False
config.dataset.query_square = False
config.dataset.padding_value = 'zero'

# model config
config.model = edict()
config.model.backbone_name = 'dinov2'
config.model.backbone_type = 'vits14'
config.model.bakcbone_use_mae_weight = False
config.model.fix_backbone = True
config.model.num_transformer = 3
config.model.type_transformer = 'global'
config.model.resolution_transformer = 8
config.model.resolution_anchor_feat = 16
config.model.pe_transformer = 'sinusoidal'
config.model.window_transformer = 10
config.model.positive_threshold = 0.2
config.model.positive_topk = 5
config.model.cpt_path = '/kaggle/input/vqloc/pytorch/default/1/cpt_best_prob.pth.tar'

config.train = edict()
config.train.resume = False
config.train.batch_size = 4
config.train.total_iteration = 50000
config.train.lr = 0.001
config.train.weight_decay = 0.0001
config.train.schedular_warmup_iter = 1000
config.train.schedualr_milestones = [15000, 30000, 45000]
config.train.schedular_gamma = 0.3
config.train.grad_max = 20.0
config.train.accumulation_step = 1
config.train.aug_clip = True
config.train.aug_query = True
config.train.aug_clip_iter = 10000
config.train.aug_brightness = 0.2
config.train.aug_contrast = 0.2
config.train.aug_saturation = 0.2
config.train.aug_crop_scale = 0.8
config.train.aug_crop_ratio_min = 0.8
config.train.aug_crop_ratio_max = 1.2
config.train.aug_affine_degree = 90
config.train.aug_affine_translate = 0.1
config.train.aug_affine_scale_min = 0.9
config.train.aug_affine_scale_max = 1.1
config.train.aug_affine_shear_min = -15.0
config.train.aug_affine_shear_max = 15.0
config.train.aug_prob_color = 0.2
config.train.aug_prob_flip = 0.2
config.train.aug_prob_crop = 0.2
config.train.aug_prob_affine = 0.2
config.train.use_hnm = False
config.train.use_query_roi = False
config.train.use_hnm = False

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [12]:
model = ClipMatcher(config).to(device)

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth
100%|██████████| 84.2M/84.2M [00:00<00:00, 380MB/s]


In [13]:
print('Model with {} parameters'.format(sum(p.numel() for p in model.parameters())))

Model with 43938364 parameters


In [14]:
from torchinfo import summary

In [15]:
summary(model)

Layer (type:depth-idx)                                       Param #
ClipMatcher                                                  491,520
├─DinoVisionTransformer: 1-1                                 526,848
│    └─PatchEmbed: 2-1                                       --
│    │    └─Conv2d: 3-1                                      226,176
│    │    └─Identity: 3-2                                    --
│    └─ModuleList: 2-2                                       --
│    │    └─NestedTensorBlock: 3-3                           1,775,232
│    │    └─NestedTensorBlock: 3-4                           1,775,232
│    │    └─NestedTensorBlock: 3-5                           1,775,232
│    │    └─NestedTensorBlock: 3-6                           1,775,232
│    │    └─NestedTensorBlock: 3-7                           1,775,232
│    │    └─NestedTensorBlock: 3-8                           1,775,232
│    │    └─NestedTensorBlock: 3-9                           1,775,232
│    │    └─NestedTensorBlock: 3-10

In [16]:
checkpoint = torch.load(config.model.cpt_path, map_location='cpu')
model.load_state_dict(checkpoint["state_dict"], strict=True)
model.eval()

RuntimeError: Error(s) in loading state_dict for ClipMatcher:
	size mismatch for backbone.cls_token: copying a param with shape torch.Size([1, 1, 768]) from checkpoint, the shape in current model is torch.Size([1, 1, 384]).
	size mismatch for backbone.pos_embed: copying a param with shape torch.Size([1, 1370, 768]) from checkpoint, the shape in current model is torch.Size([1, 1370, 384]).
	size mismatch for backbone.mask_token: copying a param with shape torch.Size([1, 768]) from checkpoint, the shape in current model is torch.Size([1, 384]).
	size mismatch for backbone.patch_embed.proj.weight: copying a param with shape torch.Size([768, 3, 14, 14]) from checkpoint, the shape in current model is torch.Size([384, 3, 14, 14]).
	size mismatch for backbone.patch_embed.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.0.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.0.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.0.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.0.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.0.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.0.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.0.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.0.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.0.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.0.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.0.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.1.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.1.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.1.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.1.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.1.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.1.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.1.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.1.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.1.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.1.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.1.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.1.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.1.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.1.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.2.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.2.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.2.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.2.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.2.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.2.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.2.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.2.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.2.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.2.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.2.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.2.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.2.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.2.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.3.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.3.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.3.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.3.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.3.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.3.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.3.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.3.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.3.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.3.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.3.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.3.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.3.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.3.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.4.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.4.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.4.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.4.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.4.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.4.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.4.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.4.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.4.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.4.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.4.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.4.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.4.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.4.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.5.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.5.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.5.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.5.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.5.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.5.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.5.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.5.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.5.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.5.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.5.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.5.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.5.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.5.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.6.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.6.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.6.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.6.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.6.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.6.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.6.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.6.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.6.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.6.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.6.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.6.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.6.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.6.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.7.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.7.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.7.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.7.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.7.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.7.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.7.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.7.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.7.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.7.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.7.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.7.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.7.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.7.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.8.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.8.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.8.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.8.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.8.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.8.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.8.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.8.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.8.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.8.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.8.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.8.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.8.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.8.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.9.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.9.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.9.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.9.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.9.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.9.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.9.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.9.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.9.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.9.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.9.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.9.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.9.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.9.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.10.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.10.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.10.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.10.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.10.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.10.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.10.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.10.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.10.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.10.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.10.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.10.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.10.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.10.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.11.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.11.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.11.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1152, 384]).
	size mismatch for backbone.blocks.11.attn.qkv.bias: copying a param with shape torch.Size([2304]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for backbone.blocks.11.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([384, 384]).
	size mismatch for backbone.blocks.11.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.11.ls1.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.11.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.11.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.11.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 384]).
	size mismatch for backbone.blocks.11.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
	size mismatch for backbone.blocks.11.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([384, 1536]).
	size mismatch for backbone.blocks.11.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.blocks.11.ls2.gamma: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for backbone.norm.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.0.0.weight: copying a param with shape torch.Size([768, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([384, 384, 3, 3]).
	size mismatch for query_down_heads.0.0.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.0.1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.0.1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.0.1.running_mean: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.0.1.running_var: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.1.0.weight: copying a param with shape torch.Size([768, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([384, 384, 3, 3]).
	size mismatch for query_down_heads.1.0.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.1.1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.1.1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.1.1.running_mean: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.1.1.running_var: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.2.0.weight: copying a param with shape torch.Size([768, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([384, 384, 3, 3]).
	size mismatch for query_down_heads.2.0.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.2.1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.2.1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.2.1.running_mean: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.2.1.running_var: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.3.0.weight: copying a param with shape torch.Size([768, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([384, 384, 3, 3]).
	size mismatch for query_down_heads.3.0.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.3.1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.3.1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.3.1.running_mean: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.3.1.running_var: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.4.0.weight: copying a param with shape torch.Size([768, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([384, 384, 3, 3]).
	size mismatch for query_down_heads.4.0.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.4.1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.4.1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.4.1.running_mean: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for query_down_heads.4.1.running_var: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
	size mismatch for reduce.0.weight: copying a param with shape torch.Size([256, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 384, 3, 3]).

In [None]:
import os
import glob
import random
import cv2
import json
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import decord
from torchvision import transforms as T
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.io as io

In [None]:
def normalize_bbox(bbox, h, w):
    '''
    bbox torch tensor in shape [4] or [...,4], under torch axis
    '''
    bbox_cp = bbox.clone()
    if len(bbox.shape) > 1: # [N,4]
        bbox_cp[..., 0] /= h
        bbox_cp[..., 1] /= w
        bbox_cp[..., 2] /= h
        bbox_cp[..., 3] /= w
        return bbox_cp
    else:
        return torch.tensor([bbox_cp[0]/h, bbox_cp[1]/w, bbox_cp[2]/h, bbox_cp[3]/w])

def recover_bbox(bbox, h, w):
    '''
    bbox torch tensor in shape [4] or [...,4], under torch axis
    '''
    bbox_cp = bbox.clone()
    if len(bbox.shape) > 1: # [N,4]
        bbox_cp[..., 0] *= h
        bbox_cp[..., 1] *= w
        bbox_cp[..., 2] *= h
        bbox_cp[..., 3] *= w
        return bbox_cp
    else:
        return torch.tensor([bbox_cp[0]*h, bbox_cp[1]*w, bbox_cp[2]*h, bbox_cp[3]*w])

def bbox_torchTocv2(bbox):
    '''
    torch, idx 0/2 for height, 1/3 for width (x,y,x,y)
    cv2: idx 0/2 for width, 1/3 for height (y,x,y,x)
    bbox torch tensor in shape [4] or [...,4], under torch axis
    '''
    bbox_cp = bbox.clone()
    if len(bbox.shape) > 1:
        bbox_x1 = bbox_cp[...,0].unsqueeze(-1)
        bbox_y1 = bbox_cp[...,1].unsqueeze(-1)
        bbox_x2 = bbox_cp[...,2].unsqueeze(-1)
        bbox_y2 = bbox_cp[...,3].unsqueeze(-1)
        return torch.cat([bbox_y1, bbox_x1, bbox_y2, bbox_x2], dim=-1)
    else:
        return torch.tensor([bbox_cp[1], bbox_cp[0], bbox_cp[3], bbox_cp[2]])

In [None]:
def sample_frames_balance(num_frames, frame_interval, sample, sampling='rand'):
    '''
    sample clips with balanced negative and positive samples
    params:
        num_frames: total number of frames to sample
        query_frame: query time index
        frame_interval: frame interval, where value 1 is for no interval (consecutive frames)
        sample: data annotations
        sampling: only effective for frame_interval larger than 1
    return: 
        frame_idxs: length [num_frames]
    '''
    required_len = (num_frames - 1) * frame_interval + 1
    anno_valid_idx_range = sample["response_track_valid_range"]
    anno_len = anno_valid_idx_range[1] - anno_valid_idx_range[0] + 1
    
    if anno_len <= required_len:
        if anno_len < required_len:
            num_valid = anno_len // frame_interval
        else:
            num_valid = num_frames
        num_invalid = num_frames - num_valid
        if anno_valid_idx_range[1] < required_len:
            idx_start = random.choice(range(anno_valid_idx_range[0])) if anno_valid_idx_range[0] > 0 else 0
            idx_end = idx_start + required_len
        else:
            num_prior = random.choice(range(num_invalid)) if num_invalid != 0 else 0
            num_post = num_invalid - num_prior
            idx_start = anno_valid_idx_range[0] - frame_interval * num_prior
            idx_end = anno_valid_idx_range[1] + frame_interval * num_post + 1
        intervals = np.linspace(start=idx_start, stop=idx_end, num=num_frames+1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1]))
        if sampling == 'rand':
            frame_idxs_pos = [random.choice(range(x[0], x[1])) for x in ranges]
        elif sampling == 'uniform':
            frame_idxs_pos = [(x[0] + x[1]) // 2 for x in ranges]
    else:
        num_addition = anno_len - required_len
        start = random.choice(range(num_addition))
        frame_idxs_pos = [anno_valid_idx_range[0] + start + it for it in range(num_frames)]
    return frame_idxs_pos

decord.bridge.set_bridge("torch")

def read_frames_decord_balance(video_path, num_frames, frame_interval, sample, sampling='rand'):
    video_reader = decord.VideoReader(video_path, num_threads=1)
    vlen = len(video_reader)
    # origin_fps = int(video_reader.get_avg_fps())
    # gt_fps = int(sample['clip_fps'])
    # down_rate = origin_fps // gt_fps
    # query_frame = int(sample['query_frame'])
    frame_idxs = sample_frames_balance(num_frames, frame_interval, sample, sampling)      # downsampled fps idxs, used to get bbox annotation
    # before_query = torch.tensor(frame_idxs) < query_frame
    frame_idxs_origin = [min(it, vlen - 1) for it in frame_idxs]        # origin clip fps frame idxs
    #video_reader.skip_frames(1)
    frames = video_reader.get_batch(frame_idxs_origin)
    frames = frames.float() / 255
    frames = frames.permute(0, 3, 1, 2)
    return frames, frame_idxs#, before_query

In [None]:
class VisualQuery2DDataset(Dataset):
    def __init__(self, clip_params, query_params, data_paths, mode='train', transform=None):
        self.clip_params = clip_params
        self.query_params = query_params
        self.data_paths = data_paths
        self.reduced_data_paths = [path.split('/')[-1] for path in self.data_paths]
        self.mode = mode
        
        if transform is None:
            self.transform = T.Compose([
                T.Resize((self.query_params['query_size'], self.query_params['query_size'])),
                T.ToTensor()
            ])
        else:
            self.transform = transform

        if self.clip_params['padding_value'] == 'zero':
            self.padding_value = 0
        elif self.clip_params['padding_value'] == 'mean':
            self.padding_value = 0.5

        if self.mode == 'train' or self.mode == 'val':
            self.annotations_path = os.path.join(self.data_paths[0].split('/samples')[0], 'annotations/annotations.json')
            self.annotations = self._read_annotations(self.annotations_path)
        else:
            self.annotations = None

    def _read_annotations(self, annotation_path):
        with open(annotation_path, 'r') as f:
            anno_json = json.load(f)
        self.annotations = []
        for video in anno_json:
            if video['video_id'] in self.reduced_data_paths:
                for clip_id, clip in enumerate(video['annotations']):
                    response_track_frame_ids = []
                    bboxes = clip['bboxes']
                    for bbox in bboxes:
                        response_track_frame_ids.append(int(bbox['frame']))
                    frame_id_min = min(response_track_frame_ids)
                    frame_id_max = max(response_track_frame_ids)
                    curr_anno = {
                        'video_id': video['video_id'],
                        'clip_id': clip_id, 
                        'response_track': clip['bboxes'],
                        'response_track_valid_range': [frame_id_min, frame_id_max],
                        'object_title': video['video_id'].split('_')[0],
                    }
                    self.annotations.append(curr_anno)
        return self.annotations

    def _get_clip_bbox(self, sample, clip_idxs, clip_h, clip_w):
        
        clip_with_bbox, clip_bbox = [], []
        response_track = sample['response_track']
        clip_bbox_all = {}
        
        for it in response_track:
            clip_bbox_all[int(it['frame'])] = [it['y1'], it['x1'], it['y2'], it['x2']]
        
        for idx in clip_idxs:
            if int(idx) in clip_bbox_all:
                clip_with_bbox.append(True)
                curr_bbox = torch.tensor(clip_bbox_all[int(idx)])
                curr_bbox_normalize = normalize_bbox(curr_bbox, clip_h, clip_w)
                clip_bbox.append(curr_bbox_normalize)
            else:
                clip_with_bbox.append(False)
                clip_bbox.append(torch.tensor([0.0, 0.0, 0.00001, 0.00001]))
        clip_with_bbox = torch.tensor(clip_with_bbox).float()
        clip_bbox = torch.stack(clip_bbox, dim=0)
        return clip_with_bbox, clip_bbox

    def _get_clip_path(self, data_path):
        clip_path = glob.glob(os.path.join(data_path, '*.mp4'))[0]
        return clip_path

    def _get_query_path(self, data_path):
        query_path = glob.glob(os.path.join(data_path, 'object_images', '*.jpg'))
        return query_path
    
    def _process_clip(self, clip, clip_bbox, clip_with_bbox):
        '''
        clip: in [T,C,H,W]
        bbox: in [T,4] with torch coordinate with value range [0,1] normalized
        clip_with_bbox: in [T]
        '''
        target_size = self.clip_params['fine_size']

        t, _, h, w = clip.shape
        clip_bbox = recover_bbox(clip_bbox, h, w)

        try:
            fg_idxs = torch.where(clip_with_bbox)[0].numpy().tolist()
            idx = random.choice(fg_idxs)
            frame = (clip[idx] * 255).permute(1,2,0).numpy().astype(np.uint8)
            frame = Image.fromarray(frame)
            bbox = bbox_torchTocv2(clip_bbox[idx]).tolist()
            query = frame.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
            query_size = self.query_params['query_size']
            query = query.resize((query_size, query_size))
            query = torch.from_numpy(np.asarray(query) / 255.0).permute(2,0,1)
        except:
            query = None

        max_size, min_size = max(h, w), min(h, w)
        pad_height = True if h < w else False
        pad_size = (max_size - min_size) // 2
        if pad_height:
            pad_input = [0, pad_size] * 2                   # for the left, top, right and bottom borders respectively
            clip_bbox[:,0] += (max_size - min_size) / 2.0   # in padded image size
            clip_bbox[:,2] += (max_size - min_size) / 2.0
        else:
            pad_input = [pad_size, 0] * 2
            clip_bbox[:,1] += (max_size - min_size) / 2.0
            clip_bbox[:,3] += (max_size - min_size) / 2.0
        
        transform_pad = T.Pad(pad_input, fill=self.padding_value)
        clip = transform_pad(clip)        # square image
        h_pad, w_pad = clip.shape[-2:]
        clip = F.interpolate(clip, size=(target_size, target_size), mode='bilinear')#.squeeze(0)
        clip_bbox = clip_bbox / float(h_pad)                # in range [0,1]

        # if self.split == 'train':
        #     clip_bbox, clip_with_bbox = self._process_bbox(clip_bbox, clip_with_bbox)
        return clip, clip_bbox, clip_with_bbox, query, h, w

    def __len__(self):
        # Return the total number of samples
        return len(self.annotations)

    def __getitem__(self, idx):
        # Load and return a sample
        sample = self.annotations[idx]
        data_path = os.path.join('/'.join(self.data_paths[0].split('/')[:-1]), sample['video_id'])
        clip_path = self._get_clip_path(data_path)
        query_path = self._get_query_path(data_path)
        # query_images = [Image.open(img_path).convert("RGB") for img_path in query_path]
        # query_images = [self.transform(img) for img in query_images]
        # query_images = torch.stack(query_images, dim=0)
        query_images = Image.open(query_path[0]).convert("RGB")
        query_images = self.transform(query_images)
        
        sample_method = self.clip_params['sampling']

        clip, clip_idxs = read_frames_decord_balance(clip_path,
                                                    self.clip_params['num_frames'],
                                                    self.clip_params['frame_interval'],
                                                    sample,
                                                sampling=sample_method)
    
        clip_h, clip_w = clip.shape[-2], clip.shape[-1]
        clip_with_bbox, clip_bbox = self._get_clip_bbox(sample, clip_idxs, clip_h, clip_w)
        clip, clip_bbox, clip_with_bbox, query, clip_h, clip_w = self._process_clip(clip, clip_bbox, clip_with_bbox)
        
        results = {
            'clip': clip,
            'clip_with_bbox': clip_with_bbox,
            'clip_bbox': clip_bbox.float(),
            'clip_idxs': clip_idxs,
            'query_images': query_images,
            'clip_h': clip_h,
            'clip_w': clip_w,
            # 'query': query
        }

        return results

In [None]:
abs_dir = '/kaggle/input/aeroeyes/observing/train/samples'
data_paths = glob.glob(abs_dir+'/*')
np.random.shuffle(data_paths)
train_paths = data_paths[:int(0.8*len(data_paths))]
val_paths = data_paths[int(0.8*len(data_paths)):]

In [None]:
clip_params = {
    'num_frames': 30,
    'frame_interval': 30,
    'sampling': 'rand',
    'fine_size': 448,
    'padding_value': 'zero',
}

query_params = {
    'query_size': 448
}

train_transform = T.Compose([
    T.Resize((query_params['query_size'], query_params['query_size'])),
    T.RandomRotation(30),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.ToTensor()
])

train_dataset = VisualQuery2DDataset(clip_params, query_params, train_paths, mode='train', transform=train_transform)

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [None]:
sample_data = next(iter(train_dataloader))

In [None]:
# Take indices of bounding boxes
batch_size = sample_data['clip'].size(0)
bbox_indices = [torch.nonzero(sample_data['clip_with_bbox'][b]).view(-1) for b in range(batch_size)]
bbox_indices

In [None]:
num_sample = 4
fig, axs = plt.subplots(4, 5, figsize=(15,10))
frame = [random.choice(bbox_indices[i]).item() if len(bbox_indices[i])>0 else 0 for i in range(num_sample)]

for i in range(num_sample):
    clip = sample_data['clip'][i, frame[i]].permute(1,2,0).numpy()
    h, w, _ = clip.shape
    axs[i, 0].imshow(clip)
    axs[i, 0].axis('off')
    axs[i, 0].set_title('Clip Frame with BBox')
    bbox = sample_data['clip_bbox'][i, frame[i]]
    bbox = recover_bbox(bbox, h, w)
    rect = plt.Rectangle((bbox[1], bbox[0]), (bbox[3]-bbox[1]), (bbox[2]-bbox[0]), linewidth=1, edgecolor='r', facecolor='none')
    axs[i, 0].add_patch(rect)
    # Crop the image using the bounding box
    crop_image = sample_data['clip'][i, frame[i]].permute(1,2,0).numpy()
    y1, x1, y2, x2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
    crop_image = crop_image[y1:y2, x1:x2]
    print(w, h, crop_image.shape)
    resized_crop = cv2.resize(crop_image, (w, h))
    axs[i, 1].imshow(resized_crop)
    axs[i, 1].axis('off')
    axs[i, 1].set_title('Cropped BBox from Clip')
    axs[i, 2].imshow(sample_data['query_images'][i].permute(1,2,0).numpy())
    axs[i, 2].axis('off')
    # axs[i, 2].set_title('Query Image 1')
    # axs[i, 3].imshow(sample_data['query_images'][i, 1].permute(1,2,0).numpy())
    # axs[i, 3].axis('off')
    # axs[i, 3].set_title('Query Image 2')
    # axs[i, 4].imshow(sample_data['query_images'][i, 2].permute(1,2,0).numpy())
    # axs[i, 4].axis('off')
    # axs[i, 4].set_title('Query Image 3')

plt.tight_layout()
plt.show()

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
batch = next(iter(train_dataloader))
# batch

In [None]:
batch['clip_with_bbox']

In [None]:
output = model(batch['clip'].to(device), batch['query_images'].to(device), training=False, fix_backbone=True)

In [None]:
max_k_values, max_k_idxs = torch.topk(output['prob'], k=5, dim=-1)

In [None]:
max_values, max_idxs = torch.max(output['prob'], dim=-1)
# bboxes = output['bbox'][max_idxs]

In [None]:
print(max_values)
print(max_idxs)

In [None]:
data_tensor = output['bbox']
# Index tensor: [1, 30]. Values must be between 0 and 767
batch_indices = torch.zeros_like(max_idxs)
# Item indices (0 to 29)
item_indices = torch.arange(30).repeat(1, 1)

# 3. Apply advanced indexing
retrieved_data_adv = data_tensor[batch_indices, item_indices, index_tensor, :]

# Output Shape: [1, 30, 4]
print(f"Retrieved Data Shape: {retrieved_data_adv.shape}")

In [None]:
data_tensor = output['bbox']
index_tensor = max_k_idxs
prepared_indices = max_k_idxs.unsqueeze(-1)

retrieved_data = torch.gather(data_tensor, dim=2, index=prepared_indices)

# Resulting shape: [1, 30, 5, 4]
# This is correct: (Batch, BBoxes, 5 Indices Retrieved, 4 Coordinates)
print(f"Data Tensor Shape: {data_tensor.shape}")
print(f"Index Tensor Shape: {index_tensor.shape}")
print(f"Retrieved Data Shape: {retrieved_data.shape}")

In [None]:
retrieved_data.shape

In [None]:
data_tensor = output['bbox']
index_tensor = max_k_idxs

# 1. Create helper indices for the first two dimensions

# a. Batch Index: All 0s. Shape [1, 1, 1] for broadcasting
batch_indices = torch.zeros(1, 1, 1, dtype=torch.long)

# b. Item Index (0 to 29): Shape [1, 30, 1] for broadcasting
item_indices = torch.arange(30).view(1, 30, 1)

# c. Value Index (Your [1, 30, 5] tensor)
value_indices = index_tensor.long()

# 2. Apply advanced indexing
retrieved_data_adv = data_tensor[batch_indices, item_indices, value_indices, :]

# Resulting shape: [1, 30, 5, 4]
print(f"Corrected Retrieved Data Shape (Advanced Indexing): {retrieved_data_adv.shape}")

In [None]:
num_frames = 30

fig, ax = plt.subplots(30, 3, figsize=(20, 100))

for i in range (num_frames):
    clip = batch['clip'][0, i].permute(1, 2, 0).cpu().numpy()
    h, w, _ = clip.shape
    ax[i, 0].imshow(clip)
    ax[i, 0].axis('off')
    if batch['clip_with_bbox'][0, i] == 1:
        gt_bbox = batch['clip_bbox'][0, i]
        gt_bbox = recover_bbox(gt_bbox, h, w)
        rect = plt.Rectangle((gt_bbox[1], gt_bbox[0]), (gt_bbox[3]-gt_bbox[1]), (gt_bbox[2]-gt_bbox[0]), linewidth=1, edgecolor='r', facecolor='none')
        ax[i, 0].add_patch(rect)
    ax[i, 0].set_title('GT')
    ax[i, 1].imshow(clip)
    for j in range (5):
        bbox = retrieved_data_adv[0, i, j]
        bbox = recover_bbox(bbox, h, w)
        rect = plt.Rectangle((bbox[1], bbox[0]), (bbox[3]-bbox[1]), (bbox[2]-bbox[0]), linewidth=1, edgecolor='r', facecolor='none')
        ax[i, 1].add_patch(rect)
    ax[i, 1].axis('off')
    ax[i, 1].set_title('Predicted')
    ax[i, 2].imshow(batch['query_images'][0].permute(1,2,0).numpy())
    ax[i, 2].axis('off')
    ax[i, 2].set_title('Query Image')
plt.tight_layout()
plt.show()

In [None]:
batch['clip'].shape