In [2]:
import time
import yaml
import torch
from torch import nn
from torch.utils import checkpoint
import torch.nn.functional as F
import math
import cv2
import sys
import numpy as np
import fastitpn as fastitpn_module



In [3]:
class EncoderBase(nn.Module):

    def __init__(self, encoder: nn.Module, train_encoder: bool, open_layers: list, num_channels: int):
        super().__init__()
        open_blocks = open_layers[2:]
        open_items = open_layers[0:2]
        for name, parameter in encoder.named_parameters():

            if not train_encoder:
                freeze = True
                for open_block in open_blocks:
                    if open_block in name:
                        freeze = False
                if name in open_items:
                    freeze = False
                if freeze == True:
                    parameter.requires_grad_(False)  # here should allow users to specify which layers to freeze !

        self.body = encoder
        self.num_channels = num_channels

    def forward(self, template_list, search_list, template_anno_list):
        xs = self.body(template_list, search_list, template_anno_list)
        return xs


#fast_itpn_tiny_1600e_1k

class Encoder(EncoderBase):
    """FastITPN encoder."""
    def __init__(self, name: str,
                 train_encoder: bool,
                 pretrain_type: str,
                 search_size: int,
                 search_number: int,
                 template_size: int,
                 template_number: int,
                 open_layers: list,
                 cfg=None):
        if "fastitpn" in name.lower():
            encoder = getattr(fastitpn_module, name)(
                pretrained=True,
                search_size=search_size,
                template_size=template_size,
                drop_rate=0.0,
                drop_path_rate=0.1,
                attn_drop_rate=0.0,
                init_values=0.1,
                drop_block_rate=None,
                use_mean_pooling=True,
                grad_ckpt=cfg["MODEL"]["ENCODER"]["GRAD_CKPT"],
                pos_type=cfg["MODEL"]["ENCODER"]["POS_TYPE"],
                token_type_indicate=cfg["MODEL"]["ENCODER"]["TOKEN_TYPE_INDICATE"],
                pretrain_type = cfg["MODEL"]["ENCODER"]["PRETRAIN_TYPE"],
            )
            if "itpnb" in name:
                num_channels = 512
            elif "itpnl" in name:
                num_channels = 768
            elif "itpnt" in name:
                num_channels = 384
            elif "itpns" in name:
                num_channels = 384
            else:
                num_channels = 512
        else:
            raise ValueError()
        super().__init__(encoder, train_encoder, open_layers, num_channels)

def build_encoder(cfg):
    train_encoder = (cfg["TRAIN"]["ENCODER_MULTIPLIER"] > 0) and (cfg["TRAIN"]["FREEZE_ENCODER"] == False)
    encoder = Encoder(cfg["MODEL"]["ENCODER"]["TYPE"], train_encoder,
                      cfg["MODEL"]["ENCODER"]["PRETRAIN_TYPE"],
                      cfg["DATA"]["SEARCH"]["SIZE"], cfg["DATA"]["SEARCH"]["NUMBER"],
                      cfg["DATA"]["TEMPLATE"]["SIZE"], cfg["DATA"]["TEMPLATE"]["NUMBER"],
                      cfg["TRAIN"]["ENCODER_OPEN"], cfg)
    return encoder

class MambaBlock(nn.Module):
    def __init__(self,dt_scale, d_model,d_inner,dt_rank,d_state,bias,d_conv,conv_bias,dt_init,dt_max,dt_min,dt_init_floor):
        super().__init__()
        #  projects block input from D to 2*ED (two branches)
        self.dt_scale = dt_scale
        self.d_model = d_model
        self.d_inner = d_inner
        self.dt_rank = dt_rank
        self.d_state = d_state
        self.in_proj = nn.Linear(self.d_model, 2 * self.d_inner, bias=bias)

        self.conv1d = nn.Conv1d(in_channels=self.d_inner, out_channels=self.d_inner,
                                kernel_size=d_conv, bias=conv_bias,
                                groups=self.d_inner,
                                padding=(d_conv - 1)//2)

        #  projects x to input-dependent Δ, B, C
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + 2 * self.d_state, bias=False)

        #  projects Δ from dt_rank to d_inner
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        #  dt initialization
        #  dt weights
        dt_init_std = self.dt_rank ** -0.5 * self.dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # dt bias
        dt = torch.exp(
            torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        inv_dt = dt + torch.log(
            -torch.expm1(-dt))  #  inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # self.dt_proj.bias._no_reinit = True # initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        #  todo : explain why removed

        # S4D real initialization
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(
            torch.log(A))  # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
        self.D = nn.Parameter(torch.ones(self.d_inner))

        #  projects block output from ED back to D
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)
    def forward(self, x, h):
        #  x : (B,L, D)
        # h : (B,L, ED, N)

        #  y : (B, L, D)


        xz = self.in_proj(x)  # (B, L,2*ED)
        x, z = xz.chunk(2, dim=-1)  #  (B,L, ED), (B,L, ED)
        x_cache = x.permute(0,2,1)#(B, ED,L)

        #  x branch
        x = self.conv1d( x_cache).permute(0,2,1) #  (B,L , ED)

        x = F.silu(x)
        y, h = self.ssm_step(x, h)
        #y->B,L,ED;h->B,L,ED,N

        #  z branch
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output)  #  (B, L, D)

        return output, h

    def ssm_step(self, x, h):
        #  x : (B, L, ED)
        #  h : (B, L, ED, N)

        A = -torch.exp(
            self.A_log.float())  # (ED, N) # todo : ne pas le faire tout le temps, puisque c'est indépendant de la timestep
        D = self.D.float()
        #  TODO remove .float()

        deltaBC = self.x_proj(x)  #  (B, L, dt_rank+2*N)

        delta, B, C = torch.split(deltaBC, [self.dt_rank, self.d_state, self.d_state],
                                  dim=-1)  #  (B, L,dt_rank), (B, L, N), (B, L, N)
        delta = F.softplus(self.dt_proj(delta))  #  (B, L, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A)  #  (B,L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  #  (B,L, ED, N)

        BX = deltaB * (x.unsqueeze(-1))  #  (B, L,ED, N)

        if h is None:
            h = torch.zeros(x.size(0), x.size(1), self.d_inner, self.d_state, device=deltaA.device)  #  (B, L, ED, N)

        h = deltaA * h + BX  #  (B, L, ED, N)

        y = (h @ C.unsqueeze(-1)).squeeze(3)  #  (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x#B,L,ED

        #  todo : pq h.squeeze(1) ??
        return y, h
    
class DWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        x = x.permute(1,0,2)
        B, N, C = x.shape
        x = x.transpose(1,2).view(B,C,int(N**0.5),int(N**0.5)).contiguous()
        x = self.dwconv(x).flatten(2).transpose(1, 2)#B,N,C
        x = x.permute(1,0,2)
        return x
class ConvFFN(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
    
class ConvFFN(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):

    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor

class DropPath(nn.Module):
   
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
    

class Extractor(nn.Module):
    def __init__(self, d_model, num_heads=8, dropout=0.1, drop_path=0.1,
                 norm_layer=lambda x: nn.LayerNorm(x, eps=1e-6)):  # Замена partial на лямбду
        super().__init__()
        self.query_norm = norm_layer(d_model)
        self.feat_norm = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        # convffn
        self.ffn = ConvFFN(in_features=d_model, hidden_features=int(d_model * 0.25), drop=0.)
        self.ffn_norm = norm_layer(d_model)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, query, feat):

        def _inner_forward(query, feat):
            # query:l,b,d;feat:l,b,d
            attn = self.attn(self.query_norm(query),
                             self.feat_norm(feat), self.feat_norm(feat))[0]
            query = query + attn

            query = query + self.drop_path(self.ffn(self.ffn_norm(query)))
            return query

        query = _inner_forward(query, feat)

        return query
 
class Injector(nn.Module):
    def __init__(self, d_model, n_heads=8, norm_layer=lambda x: nn.LayerNorm(x, eps=1e-6), dropout=0.1, init_values=0.):
        super().__init__()
        self.query_norm = norm_layer(d_model)
        self.feat_norm = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.gamma = nn.Parameter(init_values * torch.ones((d_model)), requires_grad=True)
        
    def forward(self, query,feat):
            #query:l,b,d;feat:l,b,d
        def _inner_forward(query, feat):

            attn = self.attn(self.query_norm(query),
                             self.feat_norm(feat),self.feat_norm(feat))[0]
            return query + self.gamma * attn
        query = _inner_forward(query, feat)
        return query    
    
    

class InteractionBlock(nn.Module):
    def __init__(self, d_model, extra_extractor, grad_ckpt):
        super().__init__()
        self.grad_ckpt = grad_ckpt
        self.injector = Injector(d_model=d_model)
        self.extractor = Extractor(d_model=d_model)
        if extra_extractor:
            self.extra_extractors = nn.Sequential(*[
                Extractor(d_model=d_model)
                for _ in range(2)])
        else:
            self.extra_extractors = None

    def forward(self,x,xs,blocks):
        x = self.injector(x.permute(1,0,2),xs.permute(1,0,2)).permute(1,0,2)
        for idx,blk in enumerate(blocks):
            x = checkpoint.checkpoint(blk, x, None,use_reentrant=False) if self.grad_ckpt else blk(x,None)
        xs = checkpoint.checkpoint(self.extractor, xs.permute(1,0,2),x.permute(1,0,2),use_reentrant=False).permute(1,0,2) \
            if self.grad_ckpt else self.extractor(xs.permute(1, 0, 2), x.permute(1, 0, 2)).permute(1, 0, 2)  # b,n,c
        # xs = self.extractor(xs.permute(1,0,2),x.permute(1,0,2)).permute(1,0,2)#b,n,c
        if self.extra_extractors is not None:
            for extractor in self.extra_extractors:
                xs = checkpoint.checkpoint(extractor, xs.permute(1, 0, 2), x.permute(1, 0, 2), use_reentrant=False).permute(1, 0, 2) \
                    if self.grad_ckpt else extractor(xs.permute(1, 0, 2), x.permute(1, 0, 2)).permute(1, 0,2)  # b,n,c
                # xs = extractor(xs.permute(1,0,2),x.permute(1,0,2)).permute(1,0,2)
        return x,xs
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()

        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output
    

class ResidualBlock(nn.Module):
    def __init__(self,dt_scale, d_model,d_inner,dt_rank,d_state,bias,d_conv,conv_bias,dt_init,dt_max,dt_min,dt_init_floor,grad_ckpt):
        super().__init__()

        self.grad_ckpt = grad_ckpt
        self.mixer = MambaBlock(dt_scale,d_model,d_inner,dt_rank,d_state,bias,d_conv,conv_bias,dt_init,dt_max,dt_min,dt_init_floor)
        self.norm = RMSNorm(d_model)

    def forward(self, x, h):
        #  x : (B, L, D)
        # h : (B, L, ED, N)
        #  output : (B,L, D)

        x = self.norm(x)
        output, h = checkpoint.checkpoint(self.mixer,x,h,use_reentrant=False) if self.grad_ckpt else self.mixer(x, h)
        output = output + x
        return output, h
    
class Mamba_Neck(nn.Module):
    def __init__(self, in_channel=512,d_model=512,d_inner=1024,bias=False,n_layers=4,dt_rank=32,d_state=16,d_conv=3,dt_min=0.001,
                 dt_max=0.1,dt_init='random',dt_scale=1.0,conv_bias=True,dt_init_floor=0.0001,grad_ckpt=False):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.bias = bias
        self.dt_rank = dt_rank
        self.d_state = d_state
        self.dt_scale = dt_scale
        self.num_channels = self.d_model
        self.layers = nn.ModuleList(
            [ResidualBlock(dt_scale,d_model,d_inner,dt_rank,d_state,bias,d_conv,conv_bias,dt_init,dt_max,dt_min,dt_init_floor,grad_ckpt)
             for _ in range(n_layers)])
        self.interactions = nn.ModuleList([
            InteractionBlock(d_model=d_model,extra_extractor=(True if i == n_layers - 1 else False),grad_ckpt=grad_ckpt)
            for i in range(n_layers)
        ])
        # self.norm_f = RMSNorm(config.d_model)

    def forward(self, x,xs,h,blocks,interaction_indexes):
        #  x : (B, L, D)
        #  caches : [cache(layer) for all layers], cache : (h, inputs)

        #  y : (B, L, D)
        #  caches : [cache(layer) for all layers], cache : (h, inputs)
        for i,index in enumerate(interaction_indexes):
            xs, h[i] = self.layers[i](xs, h[i])
            x,xs = self.interactions[i](x,xs,blocks[index[0]:index[1]])

        return x, xs, h
def build_neck(cfg,encoder):
    in_channel = encoder.num_channels
    d_model = cfg["MODEL"]["NECK"]["D_MODEL"]
    n_layers = cfg["MODEL"]["NECK"]["N_LAYERS"]
    d_state = cfg["MODEL"]["NECK"]["D_STATE"]
    grad_ckpt = cfg["MODEL"]["ENCODER"]["GRAD_CKPT"]
    neck = Mamba_Neck(in_channel=in_channel,d_model=d_model,d_inner=2*d_model,n_layers=n_layers,dt_rank=d_model//16,d_state=d_state,grad_ckpt=grad_ckpt)
    return neck

def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)

class MLPPredictor(nn.Module):
    def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16):
        super(MLPPredictor, self).__init__()
        self.feat_sz = feat_sz
        self.stride = stride
        self.img_sz = self.feat_sz * self.stride

        self.num_layers = 3
        h = [channel] * (self.num_layers - 1)
        self.layers_cls = nn.ModuleList(nn.Linear(n, k)
                                        for n, k in zip([inplanes] + h, h + [1]))
        self.layers_reg = nn.ModuleList(nn.Linear(n, k)
                                        for n, k in zip([inplanes] + h, h + [4]))

        # for p in self.parameters():
        #     if p.dim() > 1:
        #         nn.init.xavier_uniform_(p)

    def forward(self, x, gt_score_map=None):
        """ Forward pass with input x. """
        score_map, offset_map = self.get_score_map(x)

        # assert gt_score_map is None
        if gt_score_map is None:
            bbox = self.cal_bbox(score_map, offset_map)
        else:
            bbox = self.cal_bbox(gt_score_map.unsqueeze(1), offset_map)

        return score_map, bbox, offset_map

    def cal_bbox(self, score_map, offset_map, return_score=False):
        max_score, idx = torch.max(score_map.flatten(1), dim=1, keepdim=True)
        idx_y = torch.div(idx, self.feat_sz, rounding_mode='floor')
        idx_x = idx % self.feat_sz

        idx = idx.unsqueeze(1).expand(idx.shape[0], 4, 1) # torch.Size([32, 4, 1])
        offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1)
        # offset: (l,t,r,b)

        # x1, y1, x2, y2
        bbox = torch.cat([idx_x.to(torch.float) / self.feat_sz - offset[:, :1], # the offset should not divide the self.feat_sz, since I use the sigmoid to limit it in (0,1)
                          idx_y.to(torch.float) / self.feat_sz - offset[:, 1:2],
                          idx_x.to(torch.float) / self.feat_sz + offset[:, 2:3],
                          idx_y.to(torch.float) / self.feat_sz + offset[:, 3:4],
                          ], dim=1)
        bbox = box_xyxy_to_cxcywh(bbox)
        if return_score:
            return bbox, max_score
        return bbox

    def get_score_map(self, x):

        def _sigmoid(x):
            y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
            return y

        x_cls = x
        for i, layer in enumerate(self.layers_cls):
            x_cls = F.relu(layer(x_cls)) if i < self.num_layers - 1 else layer(x_cls)
        x_cls = x_cls.permute(0,2,1).reshape(-1,1,self.feat_sz,self.feat_sz)

        x_reg = x
        for i, layer in enumerate(self.layers_reg):
            x_reg = F.relu(layer(x_reg)) if i < self.num_layers - 1 else layer(x_reg)
        x_reg = x_reg.permute(0, 2, 1).reshape(-1, 4, self.feat_sz, self.feat_sz)

        return _sigmoid(x_cls), _sigmoid(x_reg)
    
class FrozenBatchNorm2d(torch.nn.Module):


    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()  # rsqrt(x): 1/sqrt(x), r: reciprocal
        bias = b - rm * scale
        return x * scale + bias

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1,
         freeze_bn=False):
    if freeze_bn:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=True),
            FrozenBatchNorm2d(out_planes),
            nn.ReLU(inplace=True))
    else:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=True),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True))
    
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1,
         freeze_bn=False):
    if freeze_bn:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=True),
            FrozenBatchNorm2d(out_planes),
            nn.ReLU(inplace=True))
    else:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=True),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True))
    
class CenterPredictor(nn.Module, ):
    def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16, freeze_bn=False):
        super(CenterPredictor, self).__init__()
        self.feat_sz = feat_sz
        self.stride = stride
        self.img_sz = self.feat_sz * self.stride

        # corner predict
        self.conv1_ctr = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_ctr = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_ctr = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_ctr = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_ctr = nn.Conv2d(channel // 8, 1, kernel_size=1)

        # size regress
        self.conv1_offset = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_offset = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_offset = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_offset = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_offset = nn.Conv2d(channel // 8, 2, kernel_size=1)

        # size regress
        self.conv1_size = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_size = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_size = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_size = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_size = nn.Conv2d(channel // 8, 2, kernel_size=1)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x, gt_score_map=None):
        """ Forward pass with input x. """
        score_map_ctr, size_map, offset_map = self.get_score_map(x) # x: torch.Size([b, c, h, w])
        # score_map_ctr: torch.Size([32, 1, 16, 16]) size_map: torch.Size([32, 2, 16, 16]) offset_map: torch.Size([32, 2, 16, 16])

        # assert gt_score_map is None
        if gt_score_map is None:
            bbox = self.cal_bbox(score_map_ctr, size_map, offset_map)
        else:
            bbox = self.cal_bbox(gt_score_map.unsqueeze(1), size_map, offset_map)

        return score_map_ctr, bbox, size_map, offset_map

    def cal_bbox(self, score_map_ctr, size_map, offset_map, return_score=False):
        max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True) # score_map_ctr.flatten(1): torch.Size([32, 256]) idx: torch.Size([32, 1]) max_score: torch.Size([32, 1])
        idx_y = torch.div(idx, self.feat_sz, rounding_mode='floor')
        idx_x = idx % self.feat_sz

        idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1)
        size = size_map.flatten(2).gather(dim=2, index=idx) # size_map: torch.Size([32, 2, 16, 16])  size_map.flatten(2): torch.Size([32, 2, 256])
        offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1)

        # bbox = torch.cat([idx_x - size[:, 0] / 2, idx_y - size[:, 1] / 2,
        #                   idx_x + size[:, 0] / 2, idx_y + size[:, 1] / 2], dim=1) / self.feat_sz
        # cx, cy, w, h
        bbox = torch.cat([(idx_x.to(torch.float) + offset[:, :1]) / self.feat_sz,
                          (idx_y.to(torch.float) + offset[:, 1:]) / self.feat_sz,
                          size.squeeze(-1)], dim=1)

        if return_score:
            return bbox, max_score
        return bbox

    def get_pred(self, score_map_ctr, size_map, offset_map):
        max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True)
        idx_y = idx // self.feat_sz
        idx_x = idx % self.feat_sz

        idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1)
        size = size_map.flatten(2).gather(dim=2, index=idx)
        offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1)

        # bbox = torch.cat([idx_x - size[:, 0] / 2, idx_y - size[:, 1] / 2,
        #                   idx_x + size[:, 0] / 2, idx_y + size[:, 1] / 2], dim=1) / self.feat_sz
        return size * self.feat_sz, offset

    def get_score_map(self, x):

        def _sigmoid(x):
            y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
            return y

        # ctr branch
        x_ctr1 = self.conv1_ctr(x)
        x_ctr2 = self.conv2_ctr(x_ctr1)
        x_ctr3 = self.conv3_ctr(x_ctr2)
        x_ctr4 = self.conv4_ctr(x_ctr3)
        score_map_ctr = self.conv5_ctr(x_ctr4)

        # offset branch
        x_offset1 = self.conv1_offset(x)
        x_offset2 = self.conv2_offset(x_offset1)
        x_offset3 = self.conv3_offset(x_offset2)
        x_offset4 = self.conv4_offset(x_offset3)
        score_map_offset = self.conv5_offset(x_offset4)

        # size branch
        x_size1 = self.conv1_size(x)
        x_size2 = self.conv2_size(x_size1)
        x_size3 = self.conv3_size(x_size2)
        x_size4 = self.conv4_size(x_size3)
        score_map_size = self.conv5_size(x_size4)
        return _sigmoid(score_map_ctr), _sigmoid(score_map_size), score_map_offset
    
def build_decoder(cfg, encoder):
    num_channels_enc = encoder.num_channels
    stride = cfg["MODEL"]["ENCODER"]["STRIDE"]
    if cfg["MODEL"]["DECODER"]["TYPE"] == "MLP":
        in_channel = num_channels_enc
        hidden_dim = cfg["MODEL"]["DECODER"]["NUM_CHANNELS"]
        feat_sz = int(cfg["DATA"]["SEARCH"]["SIZE"] / stride)
        mlp_head = MLPPredictor(inplanes=in_channel, channel=hidden_dim,
                                feat_sz=feat_sz, stride=stride)
        return mlp_head
    elif "CORNER" in cfg["MODEL"]["DECODER"]["TYPE"]:
        feat_sz = int(cfg["DATA"]["SEARCH"]["SIZE"] / stride)
        channel = getattr(cfg["MODEL"], "NUM_CHANNELS", 256)
        print("head channel: %d" % channel)
        if cfg["MODEL"]["HEAD"]["TYPE"] == "CORNER":
            corner_head = Corner_Predictor(inplanes=cfg["MODEL"]["HIDDEN_DIM"], channel=channel,
                                           feat_sz=feat_sz, stride=stride)
        else:
            raise ValueError()
        return corner_head
    elif cfg["MODEL"]["DECODER"]["TYPE"] == "CENTER":
        in_channel = num_channels_enc
        out_channel = cfg["MODEL"]["DECODER"]["NUM_CHANNELS"]
        feat_sz = int(cfg["DATA"]["SEARCH"]["SIZE"] / stride)
        center_head = CenterPredictor(inplanes=in_channel, channel=out_channel,
                                      feat_sz=feat_sz, stride=stride)
        return center_head
    else:
        raise ValueError("HEAD TYPE %s is not supported." % cfg["MODEL"]["HEAD_TYPE"])
    
class Corner_Predictor(nn.Module):
    """ Corner Predictor module"""

    def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16, freeze_bn=False):
        super(Corner_Predictor, self).__init__()
        self.feat_sz = feat_sz
        self.stride = stride
        self.img_sz = self.feat_sz * self.stride
        '''top-left corner'''
        self.conv1_tl = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_tl = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_tl = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_tl = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_tl = nn.Conv2d(channel // 8, 1, kernel_size=1)

        '''bottom-right corner'''
        self.conv1_br = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_br = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_br = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_br = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_br = nn.Conv2d(channel // 8, 1, kernel_size=1)

        '''about coordinates and indexs'''
        with torch.no_grad():
            self.indice = torch.arange(0, self.feat_sz).view(-1, 1) * self.stride
            # generate mesh-grid
            self.coord_x = self.indice.repeat((self.feat_sz, 1)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()
            self.coord_y = self.indice.repeat((1, self.feat_sz)) \
                .view((self.feat_sz * self.feat_sz,)).float().cuda()

    def forward(self, x, return_dist=False, softmax=True):
        """ Forward pass with input x. """
        score_map_tl, score_map_br = self.get_score_map(x)
        if return_dist:
            coorx_tl, coory_tl, prob_vec_tl = self.soft_argmax(score_map_tl, return_dist=True, softmax=softmax)
            coorx_br, coory_br, prob_vec_br = self.soft_argmax(score_map_br, return_dist=True, softmax=softmax)
            return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz, prob_vec_tl, prob_vec_br
        else:
            coorx_tl, coory_tl = self.soft_argmax(score_map_tl)
            coorx_br, coory_br = self.soft_argmax(score_map_br)
            return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz

    def get_score_map(self, x):
        # top-left branch
        x_tl1 = self.conv1_tl(x)
        x_tl2 = self.conv2_tl(x_tl1)
        x_tl3 = self.conv3_tl(x_tl2)
        x_tl4 = self.conv4_tl(x_tl3)
        score_map_tl = self.conv5_tl(x_tl4)

        # bottom-right branch
        x_br1 = self.conv1_br(x)
        x_br2 = self.conv2_br(x_br1)
        x_br3 = self.conv3_br(x_br2)
        x_br4 = self.conv4_br(x_br3)
        score_map_br = self.conv5_br(x_br4)
        return score_map_tl, score_map_br

    def soft_argmax(self, score_map, return_dist=False, softmax=True):
        """ get soft-argmax coordinate for a given heatmap """
        score_vec = score_map.view((-1, self.feat_sz * self.feat_sz))  # (batch, feat_sz * feat_sz)
        prob_vec = nn.functional.softmax(score_vec, dim=1)
        exp_x = torch.sum((self.coord_x * prob_vec), dim=1)
        exp_y = torch.sum((self.coord_y * prob_vec), dim=1)
        if return_dist:
            if softmax:
                return exp_x, exp_y, prob_vec
            else:
                return exp_x, exp_y, score_vec
        else:
            return exp_x, exp_y
        
class Preprocessor(object):
    def __init__(self):
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda()
        self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda()
        self.mm_mean = torch.tensor([0.485, 0.456, 0.406, 0.485, 0.456, 0.406]).view((1, 6, 1, 1)).cuda()
        self.mm_std = torch.tensor([0.229, 0.224, 0.225, 0.229, 0.224, 0.225]).view((1, 6, 1, 1)).cuda()

    def process(self, img_arr: np.ndarray):
        if img_arr.shape[-1] == 6:
            mean = self.mm_mean
            std = self.mm_std
        else:
            mean = self.mean
            std = self.std
        # Deal with the image patch
        img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0)
        # img_tensor = torch.tensor(img_arr).float().permute((2,0,1)).unsqueeze(dim=0)
        img_tensor_norm = ((img_tensor / 255.0) - mean) / std  # (1,3,H,W)
        return img_tensor_norm
    
def hann1d(sz: int, centered = True) -> torch.Tensor:
    """1D cosine window."""
    if centered:
        return 0.5 * (1 - torch.cos((2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float()))
    w = 0.5 * (1 + torch.cos((2 * math.pi / (sz + 2)) * torch.arange(0, sz//2 + 1).float()))
    return torch.cat([w, w[1:sz-sz//2].flip((0,))])
    
def hann2d(sz: torch.Tensor, centered = True) -> torch.Tensor:
    """2D cosine window."""
    return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d(sz[1].item(), centered).reshape(1, 1, 1, -1)    

def sample_target(im, target_bb, search_area_factor, output_sz=None):
   
    if not isinstance(target_bb, list):
        x, y, w, h = target_bb.tolist()
    else:
        x, y, w, h = target_bb
    # Crop image
    crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)

    if crop_sz < 1:
        raise Exception('Too small bounding box.')

    x1 = round(x + 0.5 * w - crop_sz * 0.5)
    x2 = x1 + crop_sz

    y1 = round(y + 0.5 * h - crop_sz * 0.5)
    y2 = y1 + crop_sz

    x1_pad = max(0, -x1)
    x2_pad = max(x2 - im.shape[1] + 1, 0)

    y1_pad = max(0, -y1)
    y2_pad = max(y2 - im.shape[0] + 1, 0)

    # Crop target
    im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]

    # Pad
    im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)
    # deal with attention mask
    H, W, _ = im_crop_padded.shape

    if output_sz is not None:
        resize_factor = output_sz / crop_sz
        im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))

        return im_crop_padded, resize_factor

    else:
        return im_crop_padded, 1.0
def transform_image_to_crop(box_in: torch.Tensor, box_extract: torch.Tensor, resize_factor: float,
                            crop_sz: torch.Tensor, normalize=False) -> torch.Tensor:
   
    box_extract_center = box_extract[0:2] + 0.5 * box_extract[2:4]

    box_in_center = box_in[0:2] + 0.5 * box_in[2:4]

    box_out_center = (crop_sz - 1) / 2 + (box_in_center - box_extract_center) * resize_factor
    box_out_wh = box_in[2:4] * resize_factor

    box_out = torch.cat((box_out_center - 0.5 * box_out_wh, box_out_wh))
    if normalize:
        return box_out / (crop_sz[0]-1)
    else:
        return box_out
def clip_box(box: list, H, W, margin=0):
    x1, y1, w, h = box
    x2, y2 = x1 + w, y1 + h
    x1 = min(max(0, x1), W-margin)
    x2 = min(max(margin, x2), W)
    y1 = min(max(0, y1), H-margin)
    y2 = min(max(margin, y2), H)
    w = max(margin, x2-x1)
    h = max(margin, y2-y1)
    return [x1, y1, w, h]

class BaseTracker():
    """Base class for all trackers."""

    def __init__(self, params):
        self.params = params
        self.visdom = None

    def predicts_segmentation_mask(self):
        return False

    def initialize(self, image, info: dict) -> dict:
        """Overload this function in your tracker. This should initialize the model."""
        raise NotImplementedError

    def track(self, image, info: dict = None) -> dict:
        """Overload this function in your tracker. This should track in the frame and update the model."""
        raise NotImplementedError

    def visdom_draw_tracking(self, image, box, segmentation=None):
        # Упрощенная обработка box без OrderedDict
        if isinstance(box, dict):  # Проверяем на обычный dict вместо OrderedDict
            box = list(box.values())  # Берем только значения
        elif not isinstance(box, (list, tuple)):  # Если не коллекция
            box = (box,)  # Превращаем в кортеж
        
        # Визуализация
        if segmentation is None:
            self.visdom.register((image, *box), 'Tracking', 1, 'Tracking')
        else:
            self.visdom.register((image, *box, segmentation), 'Tracking', 1, 'Tracking')


class MCITRACK(BaseTracker):
    def __init__(self, params):
        
        super(MCITRACK, self).__init__(params)
        network = build_mcitrack(params.cfg)
        network.load_state_dict(torch.load("MCITRACK_ep0300.pth.tar", map_location='cpu')['net'], strict=True)
        
        self.cfg = params.cfg
        self.network = network.cuda()
        self.network.eval()
        self.preprocessor = Preprocessor()
        self.state = None

        self.fx_sz = self.cfg["TEST"]["SEARCH_SIZE"] // self.cfg["MODEL"]["ENCODER"]["STRIDE"]
        if self.cfg["TEST"]["WINDOW"] == True:  # for window penalty
            self.output_window = hann2d(torch.tensor([self.fx_sz, self.fx_sz]).long(), centered=True).cuda()

        self.num_template = self.cfg["TEST"]["NUM_TEMPLATES"]

   
        self.frame_id = 0
        # for update
        self.h_state = [None] * self.cfg["MODEL"]["NECK"]["N_LAYERS"]



        self.memory_bank = self.cfg["TEST"]["MB"]["DEFAULT"]
        self.update_h_t = self.cfg["TEST"]["UPH"]["DEFAULT"]
        self.update_threshold = self.cfg["TEST"]["UPT"]["DEFAULT"]
        self.update_intervals = self.cfg["TEST"]["INTER"]["DEFAULT"]
        print("Update threshold is: ", self.memory_bank)

    def initialize(self, image, info: dict):


        # get the initial templates
        z_patch_arr, resize_factor = sample_target(image, info['init_bbox'], self.params.template_factor,
                                                   output_sz=self.params.template_size)
        z_patch_arr = z_patch_arr
        template = self.preprocessor.process(z_patch_arr)
        self.template_list = [template] * self.num_template

        self.state = info['init_bbox']
        prev_box_crop = transform_image_to_crop(torch.tensor(info['init_bbox']),
                                                torch.tensor(info['init_bbox']),
                                                resize_factor,
                                                torch.Tensor([self.params.template_size, self.params.template_size]),
                                                normalize=True)
        self.template_anno_list = [prev_box_crop.to(template.device).unsqueeze(0)] * self.num_template
        self.frame_id = 0
        self.memory_template_list = self.template_list.copy()
        self.memory_template_anno_list = self.template_anno_list.copy()


    def track(self, image, info: dict = None):
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor = sample_target(image, self.state, self.params.search_factor,
                                                   output_sz=self.params.search_size)  # (x1, y1, w, h)
        search = self.preprocessor.process(x_patch_arr)
        search_list = [search]

        # run the encoder

        with torch.no_grad():
            out_dict = self.network.forward(
                template_list=self.template_list,
                search_list=search_list,
                template_anno_list=self.template_anno_list,
                
                gt_score_map=None
            )

        

        # add hann windows
        pred_score_map = out_dict['score_map']
        if self.cfg["TEST"]["WINDOW"] == True:  # for window penalty
            response = self.output_window * pred_score_map
        else:
            response = pred_score_map
        if 'size_map' in out_dict.keys():
            pred_boxes, conf_score = self.network.decoder.cal_bbox(response, out_dict['size_map'],
                                                                   out_dict['offset_map'], return_score=True)
        else:
            pred_boxes, conf_score = self.network.decoder.cal_bbox(response,
                                                                   out_dict['offset_map'],
                                                                   return_score=True)
        pred_boxes = pred_boxes.view(-1, 4)
        # Baseline: Take the mean of all pred boxes as the final result
        pred_box = (pred_boxes.mean(dim=0) * self.params.search_size / resize_factor).tolist()  # (cx, cy, w, h) [0,1]
        # get the final box result
        self.state = clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)
        # update hiden state
        # self.h_state = h
        # if conf_score.item() < self.update_h_t:
        #     self.h_state = [None] * self.cfg["MODEL"]["NECK"]["N_LAYERS"]

        # update the template
        if self.num_template > 1:
            if (conf_score > self.update_threshold):
                z_patch_arr, resize_factor = sample_target(image, self.state, self.params.template_factor,
                                                           output_sz=self.params.template_size)
                template = self.preprocessor.process(z_patch_arr)
                self.memory_template_list.append(template)
                prev_box_crop = transform_image_to_crop(torch.tensor(self.state),
                                                        torch.tensor(self.state),
                                                        resize_factor,
                                                        torch.Tensor(
                                                            [self.params.template_size, self.params.template_size]),
                                                        normalize=True)
                self.memory_template_anno_list.append(prev_box_crop.to(template.device).unsqueeze(0))
                if len(self.memory_template_list) > self.memory_bank:
                    self.memory_template_list.pop(0)
                    self.memory_template_anno_list.pop(0)
        if (self.frame_id % self.update_intervals == 0):
            assert len(self.memory_template_anno_list) == len(self.memory_template_list)
            len_list = len(self.memory_template_anno_list)
            interval = len_list // self.num_template
            for i in range(1, self.num_template):
                idx = interval * i
                if idx > len_list:
                    idx = len_list
                self.template_list.append(self.memory_template_list[idx])
                self.template_list.pop(1)
                self.template_anno_list.append(self.memory_template_anno_list[idx])
                self.template_anno_list.pop(1)
        assert len(self.template_list) == self.num_template



        return {"target_bbox": self.state,
                "best_score": conf_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)  # (N,4) --> (N,)
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)

class MCITrack(nn.Module):
    
    def __init__(self, encoder, decoder, neck, cfg,
                 num_frames=1, num_template=1, decoder_type="CENTER"):
      
        super().__init__()
        self.encoder = encoder
        self.decoder_type = decoder_type
        self.neck = neck

        self.num_patch_x = self.encoder.body.num_patches_search
        self.num_patch_z = self.encoder.body.num_patches_template
        self.fx_sz = int(math.sqrt(self.num_patch_x))
        self.fz_sz = int(math.sqrt(self.num_patch_z))

        self.decoder = decoder

        self.num_frames = num_frames
        self.num_template = num_template
        self.freeze_en = cfg["TRAIN"]["FREEZE_ENCODER"]
        self.interaction_indexes = cfg["MODEL"]["ENCODER"]["INTERACTION_INDEXES"]

    def forward(self, template_list, search_list, template_anno_list, gt_score_map=None):
    
        # Step 1: Forward pass through the encoder
        
        neck_h_state=[None,None,None,None]
        
        xz = self.encoder(template_list, search_list, template_anno_list)

        # Step 2: Forward pass through the neck
        xs = xz[:, 0:self.num_patch_x]  # Extract patch embeddings
        x, xs, h = self.neck(
            xz, xs, neck_h_state, 
            self.encoder.body.blocks, 
            self.interaction_indexes
        )
        x = self.encoder.body.fc_norm(x)
        xs = xs + x[:, 0:self.num_patch_x]  # Updated patch embeddings

        # Step 3: Forward pass through the decoder
        bs, HW, C = xs.size()
        if self.decoder_type in ['CORNER', 'CENTER']:
            xs = xs.permute((0, 2, 1)).contiguous()
            xs = xs.view(bs, C, self.fx_sz, self.fx_sz)

        if self.decoder_type == "CORNER":
            # Run the corner head
            pred_box, score_map = self.decoder(xs, True)
            outputs_coord = box_xyxy_to_cxcywh(pred_box)
            outputs_coord_new = outputs_coord.view(bs, 1, 4)
            return {
                'pred_boxes': outputs_coord_new,
                'score_map': score_map
            }

        elif self.decoder_type == "CENTER":
            # Run the center head
            score_map_ctr, bbox, size_map, offset_map = self.decoder(xs, gt_score_map)
            outputs_coord = bbox
            outputs_coord_new = outputs_coord.view(bs, 1, 4)
            return {
                'pred_boxes': outputs_coord_new,
                'score_map': score_map_ctr,
                'size_map': size_map,
                'offset_map': offset_map
            }

        elif self.decoder_type == "MLP":
            # Run the MLP head
            score_map, bbox, offset_map = self.decoder(xs, gt_score_map)
            outputs_coord = bbox
            outputs_coord_new = outputs_coord.view(bs, 1, 4)
            return {
                'pred_boxes': outputs_coord_new,
                'score_map': score_map,
                'offset_map': offset_map
            }

        else:
            raise NotImplementedError(f"Decoder type not supported: {self.decoder_type}")

def build_mcitrack(cfg):
    encoder = build_encoder(cfg)
    neck = build_neck(cfg,encoder)
    decoder = build_decoder(cfg, neck)
    model = MCITrack(
        encoder,
        decoder,
        neck,
        cfg,
        num_frames = cfg["DATA"]["SEARCH"]["NUMBER"],
        num_template = cfg["DATA"]["TEMPLATE"]["NUMBER"],
        decoder_type=cfg["MODEL"]["DECODER"]["TYPE"],
    )
    return model

def get_tracker_class():
    return MCITRACK



In [4]:
cfg = {}

# MODEL
cfg["MODEL"] = {}

# MODEL.ENCODER
cfg["MODEL"]["ENCODER"] = {
    "TYPE": "dinov2_vitb14",  # encoder model
    "DROP_PATH": 0,
    "PRETRAIN_TYPE": "mae",  # mae, default, or scratch. This parameter is not activated for dinov2.
    "USE_CHECKPOINT": False,  # to save the memory.
    "STRIDE": 14,
    "POS_TYPE": 'interpolate',  # type of loading the positional encoding. "interpolate" or "index".
    "TOKEN_TYPE_INDICATE": False,  # add a token_type_embedding to indicate the search, template_foreground, template_background
    "INTERACTION_INDEXES": [[0, 6], [6, 12], [12, 18], [18, 24]],
    "GRAD_CKPT": False
}

# MODEL.NECK
cfg["MODEL"]["NECK"] = {
    "N_LAYERS": 4,
    "D_MODEL": 512,
    "D_STATE": 16  # MAMABA_HIDDEN_STATE
}

# MODEL.DECODER
cfg["MODEL"]["DECODER"] = {
    "TYPE": "CENTER",  # MLP, CORNER, CENTER
    "NUM_CHANNELS": 256
}

# TRAIN
cfg["TRAIN"] = {
    "LR": 0.0001,
    "WEIGHT_DECAY": 0.0001,
    "EPOCH": 500,
    "LR_DROP_EPOCH": 400,
    "BATCH_SIZE": 8,
    "NUM_WORKER": 8,
    "OPTIMIZER": "ADAMW",
    "ENCODER_MULTIPLIER": 0.1,  # encoder's LR = this factor * LR
    "FREEZE_ENCODER": False,  # for freezing the parameters of encoder
    "ENCODER_OPEN": [],  # only for debug, open some layers of encoder when FREEZE_ENCODER is True
    "CE_WEIGHT": 1.0,  # weight for cross-entropy loss
    "GIOU_WEIGHT": 2.0,
    "L1_WEIGHT": 5.0,
    "PRINT_INTERVAL": 50,  # interval to print the training log
    "GRAD_CLIP_NORM": 0.1,
    "FIX_BN": False,
    "ENCODER_W": "",
    "TYPE": "normal",  # normal, peft or fft
    "PRETRAINED_PATH": None
}

# TRAIN.SCHEDULER
cfg["TRAIN"]["SCHEDULER"] = {
    "TYPE": "step",
    "DECAY_RATE": 0.1
}

# DATA
cfg["DATA"] = {
    "MEAN": [0.485, 0.456, 0.406],
    "STD": [0.229, 0.224, 0.225],
    "MAX_SAMPLE_INTERVAL": 200,
    "SAMPLER_MODE": "order",
    "LOADER": "tracking"
}

# DATA.TRAIN
cfg["DATA"]["TRAIN"] = {
    "DATASETS_NAME": ["LASOT", "GOT10K_vottrain"],
    "DATASETS_RATIO": [1, 1],
    "SAMPLE_PER_EPOCH": 60000
}

# DATA.SEARCH
cfg["DATA"]["SEARCH"] = {
    "NUMBER": 1,  # number of search region, only support 1 for now.
    "SIZE": 256,
    "FACTOR": 4.0,
    "CENTER_JITTER": 3.5,
    "SCALE_JITTER": 0.5
}

# DATA.TEMPLATE
cfg["DATA"]["TEMPLATE"] = {
    "NUMBER": 1,
    "SIZE": 128,
    "FACTOR": 2.0,
    "CENTER_JITTER": 0,
    "SCALE_JITTER": 0
}

# TEST
cfg["TEST"] = {
    "TEMPLATE_FACTOR": 4.0,
    "TEMPLATE_SIZE": 256,
    "SEARCH_FACTOR": 2.0,
    "SEARCH_SIZE": 128,
    "EPOCH": 500,
    "WINDOW": False,  # window penalty
    "NUM_TEMPLATES": 1
}

# TEST.UPT
cfg["TEST"]["UPT"] = {
    "DEFAULT": 1,
    "LASOT": 0,
    "LASOT_EXTENSION_SUBSET": 0,
    "TRACKINGNET": 0,
    "TNL2K": 0,
    "NFS": 0,
    "UAV": 0,
    "VOT20": 0,
    "GOT10K_TEST": 0
}

# TEST.UPH
cfg["TEST"]["UPH"] = {
    "DEFAULT": 1,
    "LASOT": 0,
    "LASOT_EXTENSION_SUBSET": 0,
    "TRACKINGNET": 0,
    "TNL2K": 0,
    "NFS": 0,
    "UAV": 0,
    "VOT20": 0,
    "GOT10K_TEST": 0
}

# TEST.INTER
cfg["TEST"]["INTER"] = {
    "DEFAULT": 999999,
    "LASOT": 0,
    "LASOT_EXTENSION_SUBSET": 0,
    "TRACKINGNET": 0,
    "TNL2K": 0,
    "NFS": 0,
    "UAV": 0,
    "VOT20": 0,
    "GOT10K_TEST": 0
}

# TEST.MB
cfg["TEST"]["MB"] = {
    "DEFAULT": 500,
    "LASOT": 0,
    "LASOT_EXTENSION_SUBSET": 0,
    "TRACKINGNET": 0,
    "TNL2K": 0,
    "NFS": 0,
    "UAV": 0,
    "VOT20": 0,
    "GOT10K_TEST": 0
}

In [5]:
#Params
class TrackerParams:
    """Class for tracker parameters."""
    def set_default_values(self, default_vals: dict):
        for name, val in default_vals.items():
            if not hasattr(self, name):
                setattr(self, name, val)

    def get(self, name: str, *default):
        """Get a parameter value with the given name. If it does not exists, it return the default value given as a
        second argument or returns an error if no default value is given."""
        if len(default) > 1:
            raise ValueError('Can only give one default value.')

        if not default:
            return getattr(self, name)

        return getattr(self, name, default[0])

    def has(self, name: str):
        """Check if there exist a parameter with the given name."""
        return hasattr(self, name)

def _update_config(base_cfg, exp_cfg):
    if isinstance(base_cfg, dict) and isinstance(exp_cfg, dict):
        for k, v in exp_cfg.items():
            if k in base_cfg:
                if not isinstance(v, dict):
                    base_cfg[k] = v
                else:
                    _update_config(base_cfg[k], v)
            else:
                raise ValueError("{} not exist in config.py".format(k))
    else:
        return

def update_config_from_file(filename):
    exp_config = None
    with open(filename) as f:
        exp_config = yaml.safe_load(f)
        _update_config(cfg, exp_config)
    
def parameters(yaml_name: str):
    params = TrackerParams()

    yaml_file = "mcitrack_t224.yaml"
    update_config_from_file(yaml_file)
    params.cfg = cfg
    print("test config: ", cfg)

    params.yaml_name = yaml_name
    # template and search region
    params.template_factor = cfg["TEST"]["TEMPLATE_FACTOR"]
    params.template_size = cfg["TEST"]["TEMPLATE_SIZE"]
    params.search_factor = cfg["TEST"]["SEARCH_FACTOR"]
    params.search_size = cfg["TEST"]["SEARCH_SIZE"]

    # Network checkpoint path
    params.checkpoint = "fast_itpn_tiny_1600e_1k.pt"
    # whether to save boxes from all queries
    params.save_all_boxes = False

    return params

params = parameters("./mcitrack_t224.yaml")

test config:  {'MODEL': {'ENCODER': {'TYPE': 'fastitpnt', 'DROP_PATH': 0.1, 'PRETRAIN_TYPE': './fast_itpn_tiny_1600e_1k.pt', 'USE_CHECKPOINT': False, 'STRIDE': 16, 'POS_TYPE': 'index', 'TOKEN_TYPE_INDICATE': True, 'INTERACTION_INDEXES': [[4, 7], [7, 10], [10, 13], [13, 16]], 'GRAD_CKPT': False}, 'NECK': {'N_LAYERS': 4, 'D_MODEL': 384, 'D_STATE': 16}, 'DECODER': {'TYPE': 'CENTER', 'NUM_CHANNELS': 256}}, 'TRAIN': {'LR': 0.0004, 'WEIGHT_DECAY': 0.0001, 'EPOCH': 300, 'LR_DROP_EPOCH': 240, 'BATCH_SIZE': 64, 'NUM_WORKER': 10, 'OPTIMIZER': 'ADAMW', 'ENCODER_MULTIPLIER': 0.1, 'FREEZE_ENCODER': False, 'ENCODER_OPEN': [], 'CE_WEIGHT': 1.0, 'GIOU_WEIGHT': 2.0, 'L1_WEIGHT': 5.0, 'PRINT_INTERVAL': 50, 'GRAD_CLIP_NORM': 0.1, 'FIX_BN': False, 'ENCODER_W': '', 'TYPE': 'normal', 'PRETRAINED_PATH': None, 'SCHEDULER': {'TYPE': 'step', 'DECAY_RATE': 0.1}}, 'DATA': {'MEAN': [0.485, 0.456, 0.406], 'STD': [0.229, 0.224, 0.225], 'MAX_SAMPLE_INTERVAL': 400, 'SAMPLER_MODE': 'order', 'LOADER': 'tracking', 'TRAIN

In [6]:
treacker = MCITRACK(params)

Update threshold is:  500


In [None]:
network = build_mcitrack(params.cfg)
network.load_state_dict(torch.load("MCITRACK_ep0300.pth.tar", map_location='cpu')['net'], strict=True)
cfg = params.cfg
network = network.cuda()
network.eval()

In [7]:
list1 = [torch.zeros(1, 3, 112, 112).to('cuda') for _ in range(5)]  # 5 тензоров размером [1, 3, 112, 112]
list2 = [torch.zeros(1, 3, 224, 224).to('cuda')]                    # 1 тензор размером [1, 3, 224, 224]
list3 = [torch.zeros(1, 4).to('cuda') for _ in range(5)]

In [8]:
res = network.forward(list1,list2,list3)

In [None]:
for i in  res:
    print(i)
    print(res[i].shape)

In [None]:
#ModelWrapper jit
class ModelWrapper(torch.nn.Module):
    def __init__(self, original_model):
        super(ModelWrapper, self).__init__()
        self.original_model = original_model

    def forward(self, template_list, search_list,template_anno_list):
        
        output_dict = self.original_model(template_list, search_list,template_anno_list)
        
        
        return (output_dict['pred_boxes'],
                output_dict['score_map'],
                output_dict['size_map'],
                output_dict['offset_map'])


model = network
model.eval()


wrapped_model = ModelWrapper(model)


template_list = [torch.zeros(1, 3, 112, 112).to('cuda') for _ in range(5)]  # 5 тензоров размером [1, 3, 112, 112]
search_list = [torch.zeros(1, 3, 224, 224).to('cuda')]                    # 1 тензор размером [1, 3, 224, 224]
template_anno_list = [torch.zeros(1, 4).to('cuda') for _ in range(5)]

traced_model = torch.jit.trace(wrapped_model, (template_list, search_list,template_anno_list))


optimized_model = torch.jit.optimize_for_inference(traced_model)


optimized_model.save("MCITrack.pt")


loaded_model = torch.jit.load("MCITrack.pt")

with torch.no_grad():
    outputs = loaded_model(template_list, search_list,template_anno_list)

for output in outputs:
    print(output)


  x = x.transpose(1,2).view(B,C,int(N**0.5),int(N**0.5)).contiguous()


In [11]:
#ModelWrapper onnx
import torch


class ModelWrapper(torch.nn.Module):
    def __init__(self, original_model):
        super(ModelWrapper, self).__init__()
        self.original_model = original_model

    def forward(self, template_list, search_list,template_anno_list):
        
        output_dict = self.original_model(template_list, search_list,template_anno_list)
        
        
        return (output_dict['pred_boxes'],
                output_dict['score_map'],
                output_dict['size_map'],
                output_dict['offset_map'])



model = network
model.eval()


wrapped_model = ModelWrapper(model)

template_list = [torch.zeros(1, 3, 112, 112).to('cuda') for _ in range(5)]  # 5 тензоров размером [1, 3, 112, 112]
search_list = [torch.zeros(1, 3, 224, 224).to('cuda')]                    # 1 тензор размером [1, 3, 224, 224]
template_anno_list = [torch.zeros(1, 4).to('cuda') for _ in range(5)]

# Важно: для onnx-модели модель должна быть на cpu или cuda, и входы должны быть на том же устройстве.
wrapped_model = wrapped_model.to('cuda')
wrapped_model.eval()

# Указываем пути для сохранения
onnx_path = "MCITrack.onnx"

# Экспортируем модель в ONNX
torch.onnx.export(
    wrapped_model,                                   # Модель
    (template_list, search_list,template_anno_list),                                # Входные данные (tuple)
    onnx_path,                                       # Имя файла
    export_params=True,                              # Экспортировать параметры (веса)
    opset_version=17,                                # Версия ONNX opset
    do_constant_folding=True,                        # Оптимизация констант
    input_names = ['template_list', 'search_list','template_anno_list'],                        # Имена входов
    output_names = ['pred_boxes','score_map','size_map','offset_map'],                   # Имена выходов
    #dynamic_axes={'z': {0: 'batch_size'},            # Динамическая ось для батча
    #              'x': {0: 'batch_size'},
    #              'pred_boxes': {0: 'batch_size'}},
    verbose=True                                     # Показывать подробности
)

print(f'Model has been exported to {onnx_path}')

  x = x.transpose(1,2).view(B,C,int(N**0.5),int(N**0.5)).contiguous()


Model has been exported to MCITrack.onnx


In [None]:
#TRT
import tensorrt as trt
#trtexec --onnx=MCITrack.onnx  --saveEngine=MCITrack.trt  --fp16
""" 

def build_engine(onnx_path, trt_path):
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    
    with open(onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None
    
    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 30  # 1GB
    engine = builder.build_engine(network, config)
    
    with open(trt_path, 'wb') as f:
        f.write(engine.serialize())
    
    return engine  """

In [7]:
#Трекинг по видео
file = "0516.mp4"
video = cv2.VideoCapture(file)
#fourcc = cv2.VideoWriter_fourcc(*'XVID')
#fps=video.get(cv2.CAP_PROP_FPS)
#video_vriter = cv2.VideoWriter(file.split('.')[0]+"_"+".avi", fourcc, fps, (1920, 1080))


ok, image = video.read()
if not video.isOpened():
    print("Could not open video")
    sys.exit()
    
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

x, y, w, h = cv2.selectROI( image, fromCenter=False)
init_state = [x, y, w, h]
def _build_init_info(box):
            return {'init_bbox': box}
treacker.initialize(image, _build_init_info(init_state))
counter = 0
while True:
            ok, image = video.read()
            if not ok:
                print("Can't read frame")
                break

            
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            start = time.time() 
            out  = treacker.track(image)
            state = [int(s) for s in out['target_bbox']]
            best_score=out["best_score"].cpu().numpy()[0][0]
            end_time = (time.time() - start)
            
            
            org = (50, 50)

            # fontScale
            fontScale = 1
            font = cv2.FONT_HERSHEY_SIMPLEX
            # Blue color in BGR
            color = (255, 0, 0)
            # Line thickness of 2 px
            thickness = 2              
            # Using cv2.putText() method
            image = cv2.putText(image, str(best_score), org, font, 
                            fontScale, color, thickness, cv2.LINE_AA)
            image = cv2.putText(image, str(end_time), (50,100), font, 
                            fontScale, color, thickness, cv2.LINE_AA)

            x, y, w, h = [int(x) for x in state]

            color = (0, 0, 255)  # Цвет в формате BGR
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)


            cv2.imshow("tracking", image)
            #video_vriter.write(image)


            k = cv2.waitKey(1)            
            if k == 32:  # SPACE
                ok, image = video.read()                             
                x, y, w, h = cv2.selectROI( image, fromCenter=False)
                init_state = [x, y, w, h]
                treacker.initialize(image, _build_init_info(init_state))
            if k == 27:  # ESC
                break
        
                
                

cv2.destroyAllWindows()
#video.release()
#video_vriter.release()



In [8]:
#Метрики
import numpy as np

def iou(boxA, boxB):
    # boxA, boxB: [x, y, w, h]
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
    yB = min(boxA[1] + boxA[3], boxB[1] + boxB[3])

    interW = max(0, xB - xA)
    interH = max(0, yB - yA)
    interArea = interW * interH

    boxAArea = boxA[2] * boxA[3]
    boxBArea = boxB[2] * boxB[3]
    unionArea = boxAArea + boxBArea - interArea

    if unionArea == 0:
        return 0.0
    return interArea / unionArea

def precision(boxA, boxB):
    # центры bbox
    centerA = (boxA[0] + boxA[2]/2, boxA[1] + boxA[3]/2)
    centerB = (boxB[0] + boxB[2]/2, boxB[1] + boxB[3]/2)
    dist = np.sqrt((centerA[0] - centerB[0])**2 + (centerA[1] - centerB[1])**2)
    return dist
sr_thresh = 0.5
prec_thresh = 20

In [10]:
#Трекинг got10k с метриками
import glob
import time
import  os
gt_bboxes = []
pred_bboxes = []
seq_path = "val/GOT-10k_Val_000006"
txt_files = glob.glob(os.path.join(seq_path, '*.txt'))
if not txt_files:
    raise FileNotFoundError(f"No .txt files found in {seq_path}")

img_files = sorted(glob.glob(os.path.join(seq_path, '*.jpg')))
with open(txt_files[0], 'r') as f:
    gt_bboxes = [list(map(float, line.strip().split(','))) for line in f]

# Получаем размер первого изображения
sample_img = cv2.imread(img_files[0])
if sample_img is None:
    raise ValueError(f"Failed to read sample image: {img_files[0]}")

#height, width = sample_img.shape[:2]
#fourcc = cv2.VideoWriter_fourcc(*'XVID')
#output_filename = f"{seq_path.split('/')[-1]}_output.avi"
#video_vriter = cv2.VideoWriter(output_filename, fourcc, 10, (width, height))  

assert len(img_files) == len(gt_bboxes), "Количество кадров и bbox'ов не совпадает"

x, y, w, h = map(int, gt_bboxes[0])
init_state = [x, y, w, h]

def _build_init_info(box):
            return {'init_bbox': box}

counter = 0


treacker.initialize(sample_img, _build_init_info(init_state))

start_time = time.time()  # Начало замера

for img_file, bbox in zip(img_files, gt_bboxes):
        
        # Читаем изображение
        img = cv2.imread(img_file)
        if img is None:
            print(f"Не удалось загрузить изображение: {img_file}")
            continue
        
        
        out  = treacker.track(img)
        state = [int(s) for s in out['target_bbox']]   
                           
        # Рисуем bounding box        
        x, y, w, h = [int(x) for x in state]

        cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 200), 2)
        
        x1, y1, w1, h1 = map(int, bbox)
        cv2.rectangle(img, (x1, y1), (x1+w1, y1+h1), (0, 200, 0), 2)
        bbox_pred = x, y, w, h
        
        gt_bboxes.append(bbox)
        pred_bboxes.append(bbox_pred)

        #cv2.imshow(seq_path, img)
        #video_vriter.write(img)
        counter+=1


        # Выход по нажатию 'q' или ESC
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q') or key == 27:
            break
       
        
                
end_time = time.time()    # Конец замера    
total_frames = counter       # Общее количество обработанных кадров
total_time = end_time - start_time
fps = total_frames / total_time
ious = [iou(gt, pred) for gt, pred in zip(gt_bboxes, pred_bboxes)]
ao = np.mean(ious)
sr = np.mean([1 if val >= sr_thresh else 0 for val in ious])
precisions = [precision(gt, pred) for gt, pred in zip(gt_bboxes, pred_bboxes)]
prec = np.mean([1 if d <= prec_thresh else 0 for d in precisions])

print(f"FPS: {fps:.2f}")
print(f'Average Overlap (AO): {ao:.2f}')
print(f'Success Rate (SR@0.5): {sr:.2f}')
print(f'Precision @20px: {prec:.2f}')

cv2.destroyAllWindows()
#video_vriter.release()
#print(f"Video saved as: {output_filename}")

FPS: 28.16
Average Overlap (AO): 0.92
Success Rate (SR@0.5): 1.00
Precision @20px: 0.99


In [1]:
import torch
print(f"CUDA доступна: {torch.cuda.is_available()}")
print(f"Версия CUDA (PyTorch): {torch.version.cuda}")
print(f"PyTorch version: {torch.__version__}")    
try:
    import tensorrt as trt
    print(f"TensorRT версия: {trt.__version__}")
except:
    print("TensorRT не установлен")
    
try:
    import onnx
    print(f"ONNX версия: {onnx.__version__}")
except:
    print("ONNX не установлен") 
print(f"cuDNN включён в PyTorch: {torch.backends.cudnn.enabled}")
print(f"Версия cuDNN (из PyTorch): {torch.backends.cudnn.version()}")
try:
    import onnxruntime as ort
    print(f"Версия onnxruntime: {ort.__version__}")
    print(f"Device onnxruntime: {ort.get_device()}")  # Должно вернуть 'GPU'
except:
    print("onnxruntime не установлен")
try:
    import tensorflow as tf
    print(f"Версия tensorflow: {tf.__version__}")
except:
    print("tensorflow не установлен")


CUDA доступна: True
Версия CUDA (PyTorch): 12.8
PyTorch version: 2.7.1+cu128
TensorRT версия: 10.11.0.33
ONNX версия: 1.18.0
cuDNN включён в PyTorch: True
Версия cuDNN (из PyTorch): 90701
Версия onnxruntime: 1.20.1
Device onnxruntime: GPU
tensorflow не установлен


In [None]:
# https://pytorch.org/get-started/locally/

In [2]:
import onnxruntime as ort

# Проверка провайдеров
print("Доступные провайдеры:", ort.get_available_providers())

# Создание сессии с TensorRT
try:
    sess = ort.InferenceSession(
        "MCITrack.onnx",
        providers=['CUDAExecutionProvider'],
        provider_options = [{}],
        )
    print("Используется:", sess.get_providers())
except Exception as e:
    print("Ошибка:", e)

Доступные провайдеры: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
Используется: ['CPUExecutionProvider']


In [1]:
import pycuda.driver as cuda
import pycuda.autoinit  # Автоматически инициализирует GPU

# Проверка количества устройств
print(f"Доступно GPU: {cuda.Device.count()}")

# Информация о GPU
gpu = cuda.Device(0)
print(f"Название GPU: {gpu.name()}")
print(f"Вычислительная способность: {gpu.compute_capability()}")
print(f"Общая память: {gpu.total_memory() / 1024**2:.2f} МБ")

# Проверка контекста (должен быть создан pycuda.autoinit)
ctx = cuda.Context.get_current()
print(f"Контекст GPU активен: {ctx is not None}")

Доступно GPU: 1
Название GPU: NVIDIA GeForce RTX 3060 Ti
Вычислительная способность: (8, 6)
Общая память: 8191.50 МБ
Контекст GPU активен: True
