diff --git a/tests/test_models.py b/tests/test_models.py
index 3ba3615db4..d6d6e3b6f7 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -53,13 +53,13 @@
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 't2t_vit',
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*',
+    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 't2t_vit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
     'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
 ]
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 3db5af6049..3613209146 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -64,6 +64,7 @@
 from .swin_transformer import *
 from .swin_transformer_v2 import *
 from .swin_transformer_v2_cr import *
+from .t2t_vit import *
 from .tiny_vit import *
 from .tnt import *
 from .tresnet import *
diff --git a/timm/models/t2t_vit.py b/timm/models/t2t_vit.py
new file mode 100644
index 0000000000..3de255cabe
--- /dev/null
+++ b/timm/models/t2t_vit.py
@@ -0,0 +1,702 @@
+"""T2T-ViT
+Paper: `Tokens-to-Token ViT: Training Vision Transformers From Scratch on ImageNet`
+    - https://arxiv.org/pdf/2101.11986
+    - https://openaccess.thecvf.com/content/ICCV2021/papers/Yuan_Tokens-to-Token_ViT_Training_Vision_Transformers_From_Scratch_on_ImageNet_ICCV_2021_paper.pdf
+
+Model from official source:
+    - https://github.com/yitu-opensource/T2T-ViT
+
+@InProceedings{Yuan_2021_ICCV,
+    author = {Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Jiang, Zi-Hang and Tay, Francis E.H. and Feng, Jiashi and Yan, Shuicheng},
+    title = {Tokens-to-Token ViT: Training Vision Transformers From Scratch on ImageNet},
+    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
+    month = {October},
+    year = {2021},
+    pages = {558-567}
+}
+
+Original implementation by Li Yuan et al.,
+adapted for timm by Ryan Hou and Ross Wightman, original copyright below
+"""
+# Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd.
+#
+# This source code is licensed under the Clear BSD License
+# LICENSE file in the root directory of this file
+# All rights reserved.
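+
+# Example usage (a minimal sketch; `create_model` is the standard timm entrypoint and
+# `t2t_vit_14` is one of the variants registered at the bottom of this file):
+#
+#   import timm, torch
+#   model = timm.create_model('t2t_vit_14', pretrained=False)
+#   logits = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 1000)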
+
+import math
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import Mlp, LayerNorm, DropPath, trunc_normal_, to_2tuple
+
+from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
+from ._manipulate import checkpoint
+from ._registry import generate_default_cfgs, register_model
+
+
+def get_sinusoid_encoding(n_position: int, d_hid: int) -> torch.Tensor:
+    """ Sinusoid position encoding table, computed with PyTorch ops """
+    # Position indices of shape (n_position, 1)
+    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)
+
+    # Divisor term: 1 / (10000^(2i / d_hid))
+    div_term = torch.exp(torch.arange(0, d_hid, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_hid))
+
+    # Interleave sin / cos across the hidden dimension
+    sinusoid_table = torch.zeros(n_position, d_hid)
+    sinusoid_table[:, 0::2] = torch.sin(position * div_term)  # sin on even indices
+    sinusoid_table[:, 1::2] = torch.cos(position * div_term)  # cos on odd indices
+
+    return sinusoid_table.unsqueeze(0)  # add batch dimension
+
+
+class Token_attention(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            in_dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            qk_scale: Optional[float] = None,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.in_dim = in_dim
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(in_dim, in_dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, _ = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        # skip connection: the input x has a different channel dim (dim) than the output (in_dim),
+        # so v is used as the residual instead of x
+        x = v.squeeze(1) + x
+        return x
+
+
+class Token_transformer(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            in_dim: int,
+            num_heads: int = 1,
+            mlp_ratio: float = 1.,
+            qkv_bias: bool = False,
+            qk_scale: Optional[float] = None,
+            drop_rate: float = 0.,
+            drop_path: float = 0.,
+            attn_drop: float = 0.,
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = partial(LayerNorm, eps=1e-5),
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Token_attention(
+            dim,
+            in_dim=in_dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop_rate,
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
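+        # Unlike a standard pre-norm ViT block, the attention above maps dim -> in_dim and
+        # applies its own residual (via v) inside Token_attention, so only the MLP residual
+        # is applied in forward(), and norm2/mlp below operate on in_dim features.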
+        self.norm2 = norm_layer(in_dim)
+        self.mlp = Mlp(
+            in_features=in_dim,
+            hidden_features=int(in_dim * mlp_ratio),
+            out_features=in_dim,
+            drop=drop_rate,
+            act_layer=act_layer,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # no outer residual on the attention; Token_attention handles its own skip via v
+        x = self.attn(self.norm1(x))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class Token_performer(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            in_dim: int,
+            head_cnt: int = 1,
+            kernel_ratio: float = 0.5,
+            dp1: float = 0.1,
+            dp2: float = 0.1,
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = partial(LayerNorm, eps=1e-5),
+    ):
+        super().__init__()
+        self.emb = in_dim * head_cnt  # head_cnt is 1 in T2T-ViT, so emb == in_dim
+        self.kqv = nn.Linear(dim, 3 * self.emb)
+        self.dp = nn.Dropout(dp1)
+        self.proj = nn.Linear(self.emb, self.emb)
+        self.norm1 = norm_layer(dim)
+        self.norm2 = norm_layer(self.emb)
+        self.epsilon = 1e-8  # for numerical stability in the division
+
+        self.mlp = nn.Sequential(
+            nn.Linear(self.emb, 1 * self.emb),
+            act_layer(),
+            nn.Linear(1 * self.emb, self.emb),
+            nn.Dropout(dp2),
+        )
+
+        self.m = int(self.emb * kernel_ratio)
+        # self.w = torch.randn(self.m, self.emb)
+        # self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False)
+        self.register_buffer('w', nn.init.orthogonal_(torch.randn(self.m, self.emb)) * math.sqrt(self.m))
+
+    def prm_exp(self, x: torch.Tensor) -> torch.Tensor:
+        # partially borrowed from https://github.com/lucidrains/performer-pytorch
+        # and Simo Ryu (https://github.com/cloneofsimo)
+        # ==== positive random features for gaussian kernels ====
+        # x = (B, T, hs)
+        # w = (m, hs)
+        # returns x of shape (B, T, m)
+        # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)]
+        # therefore return exp(w^T x - |x|/2) / sqrt(m)
+        xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2
+        wtx = torch.einsum('bti,mi->btm', x.float(), self.w)
+
+        return torch.exp(wtx - xd) / math.sqrt(self.m)
+
+    def single_attn(self, x: torch.Tensor) -> torch.Tensor:
+        k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
+        kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
+        D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
+        kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
+        y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb) / diag
+        # skip connection: same as Token_transformer in the T2T stage, use v as the residual
+        y = v + self.dp(self.proj(y))
+        return y
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.single_attn(self.norm1(x))
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class T2T_module(nn.Module):
+    """ Tokens-to-Token encoding module """
+    def __init__(
+            self,
+            img_size: Optional[Union[int, Tuple[int, int]]] = 224,
+            patch_size: int = 16,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            token_dim: int = 64,
+            tokens_type: Literal['performer', 'transformer'] = 'performer',
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = partial(LayerNorm, eps=1e-5),
+    ):
+        super().__init__()
+        self.patch_size = to_2tuple(patch_size)
+        self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
+
+        self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
+        self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
+        self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
+
+        token_module = Token_performer if tokens_type == 'performer' else Token_transformer
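+
+        # Progressive tokenization, e.g. for a 224x224 input:
+        #   soft_split0 (7x7, stride 4, pad 2): 224 -> 56, 3136 tokens of dim in_chans * 7 * 7
+        #   soft_split1 (3x3, stride 2, pad 1):  56 -> 28,  784 tokens of dim token_dim * 3 * 3
+        #   soft_split2 (3x3, stride 2, pad 1):  28 -> 14,  196 tokens of dim token_dim * 3 * 3
+        # Overall reduction is 4 * 2 * 2 = 16, matching the nominal patch_size.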
+        self.attention1 = token_module(dim=in_chans * 7 * 7, in_dim=token_dim, act_layer=act_layer, norm_layer=norm_layer)
+        self.attention2 = token_module(dim=token_dim * 3 * 3, in_dim=token_dim, act_layer=act_layer, norm_layer=norm_layer)
+        self.project = nn.Linear(token_dim * 3 * 3, embed_dim)
+
+    def _init_img_size(self, img_size: Optional[Union[int, Tuple[int, int]]]):
+        assert self.patch_size
+        if img_size is None:
+            return None, None, None
+        img_size = to_2tuple(img_size)
+        grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
+        num_patches = grid_size[0] * grid_size[1]
+        return img_size, grid_size, num_patches
+
+    def feat_ratio(self, as_scalar: bool = True) -> Union[Tuple[int, int], int]:
+        if as_scalar:
+            return max(self.patch_size)
+        else:
+            return self.patch_size
+
+    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
+        return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, _, H, W = x.shape
+        # step0: soft split
+        x = self.soft_split0(x).transpose(1, 2)
+
+        # iteration1: re-structurization/reconstruction
+        x = self.attention1(x)
+        x = x.transpose(1, 2).reshape(B, -1, H // 4, W // 4)
+        # iteration1: soft split
+        x = self.soft_split1(x).transpose(1, 2)
+
+        # iteration2: re-structurization/reconstruction
+        x = self.attention2(x)
+        x = x.transpose(1, 2).reshape(B, -1, H // 8, W // 8)
+        # iteration2: soft split
+        x = self.soft_split2(x).transpose(1, 2)
+
+        # final tokens
+        x = self.project(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            qk_scale: Optional[float] = None,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            qk_scale: Optional[float] = None,
+            drop_rate: float = 0.,
+            drop_path: float = 0.,
+            attn_drop: float = 0.,
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = partial(LayerNorm, eps=1e-5),
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop_rate,
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
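+        # Standard pre-norm ViT block used for the backbone; in contrast to Token_transformer
+        # above, both the attention and MLP residuals are applied externally in forward(),
+        # and all features stay at width `dim`.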
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            drop=drop_rate,
+            act_layer=act_layer,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class T2T_ViT(nn.Module):
+    def __init__(
+            self,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 16,
+            tokens_type: Literal['performer', 'transformer'] = 'performer',
+            token_dim: int = 64,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: str = 'token',
+            embed_dim: int = 768,
+            depth: int = 12,
+            num_heads: int = 12,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            qk_scale: Optional[float] = None,
+            drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = partial(LayerNorm, eps=1e-5),
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
+        self.num_prefix_tokens = 1
+        self.grad_checkpointing = False
+
+        self.patch_embed = T2T_module(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            tokens_type=tokens_type,
+            token_dim=token_dim,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+        )
+        num_patches = self.patch_embed.num_patches
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        # fixed (non-learned) sinusoid position embedding, registered as a buffer
+        self.register_buffer('pos_embed', get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop_rate=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[i],
+                act_layer=act_layer,
+                norm_layer=norm_layer,
+            )
+            for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
+
+        # classifier head
+        use_fc_norm = False
+        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
+        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.head_drop = nn.Dropout(drop_rate)
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self) -> Set:
+        return {'cls_token'}
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False) -> Dict:
+        matcher = dict(
+            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def get_classifier(self) -> nn.Module:
+        return self.head
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if an int, if is a sequence, select by matching indices
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or a tuple of (final features, intermediates)
+            if intermediates_only is False.
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+
+        x = self.patch_embed(x)
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.patch_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates. """
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.fc_norm = nn.Identity()
+            self.reset_classifier(0, '')
+        return take_indices
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.patch_embed(x)
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
+        x = self.norm(x)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        if self.global_pool:
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+        x = self.fc_norm(x)
+        x = self.head_drop(x)
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.project', 'classifier': 'head',
+        'paper_ids': 'https://arxiv.org/pdf/2101.11986',
+        'paper_name': 'Tokens-to-Token ViT: Training Vision Transformers From Scratch on ImageNet',
+        'origin_url': 'https://github.com/yitu-opensource/T2T-ViT',
+        **kwargs
+    }
+
+
+def checkpoint_filter_fn(
+        state_dict: Dict[str, torch.Tensor],
+        model: T2T_ViT,
+) -> Dict[str, torch.Tensor]:
+    """ Remap original T2T-ViT checkpoints to timm module names. """
+    if 'state_dict_ema' in state_dict:
+        state_dict = state_dict['state_dict_ema']
+
+    if 'patch_embed.project.weight' in state_dict:
+        return state_dict  # already in timm format
+
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = k.replace('module.', '')
+        k = k.replace('tokens_to_token.', 'patch_embed.')
+        out_dict[k] = v
+
+    return out_dict
+
+
+default_cfgs = generate_default_cfgs({
+    't2t_vit_7.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/yitu-opensource/T2T-ViT/releases/download/main/71.7_T2T_ViT_7.pth.tar',
+    ),
+    't2t_vit_10.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/yitu-opensource/T2T-ViT/releases/download/main/75.2_T2T_ViT_10.pth.tar',
+    ),
+    't2t_vit_12.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/yitu-opensource/T2T-ViT/releases/download/main/76.5_T2T_ViT_12.pth.tar',
+    ),
+    't2t_vit_14.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/yitu-opensource/T2T-ViT/releases/download/main/81.5_T2T_ViT_14.pth.tar',
+    ),
+    't2t_vit_19.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/yitu-opensource/T2T-ViT/releases/download/main/81.9_T2T_ViT_19.pth.tar',
+    ),
+    't2t_vit_24.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.3_T2T_ViT_24.pth.tar',
+    ),
+    't2t_vit_t_14.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/yitu-opensource/T2T-ViT/releases/download/main/81.7_T2T_ViTt_14.pth.tar',
+    ),
+    't2t_vit_t_19.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.4_T2T_ViTt_19.pth.tar',
+    ),
+    't2t_vit_t_24.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.6_T2T_ViTt_24.pth.tar',
+    ),
+})
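+
+# Feature-extraction sketch (a minimal example, assuming timm's standard
+# `create_model(..., features_only=True)` interface, which is wired up via the
+# 'getter' feature_cfg in _create_t2t_vit below):
+#
+#   import timm, torch
+#   m = timm.create_model('t2t_vit_7', features_only=True, out_indices=(1, 3, 5))
+#   feats = m(torch.randn(1, 3, 224, 224))  # 3 maps of shape (1, 256, 14, 14)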
+
+
+def _create_t2t_vit(variant: str, pretrained: bool, **kwargs: Any) -> T2T_ViT:
+    out_indices = kwargs.pop('out_indices', 3)
+    model = build_model_with_cfg(
+        T2T_ViT, variant, pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
+    return model
+
+
+@register_model
+def t2t_vit_7(pretrained: bool = False, **kwargs: Any) -> T2T_ViT:
+    model_kwargs = dict(tokens_type='performer', embed_dim=256, depth=7, num_heads=4, mlp_ratio=2., **kwargs)
+    model = _create_t2t_vit('t2t_vit_7', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def t2t_vit_10(pretrained: bool = False, **kwargs: Any) -> T2T_ViT:
+    model_kwargs = dict(embed_dim=256, depth=10, num_heads=4, mlp_ratio=2., **kwargs)
+    model = _create_t2t_vit('t2t_vit_10', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def t2t_vit_12(pretrained: bool = False, **kwargs: Any) -> T2T_ViT:
+    model_kwargs = dict(embed_dim=256, depth=12, num_heads=4, mlp_ratio=2., **kwargs)
+    model = _create_t2t_vit('t2t_vit_12', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def t2t_vit_14(pretrained: bool = False, **kwargs: Any) -> T2T_ViT:
+    model_kwargs = dict(embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs)
+    model = _create_t2t_vit('t2t_vit_14', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def t2t_vit_19(pretrained: bool = False, **kwargs: Any) -> T2T_ViT:
+    model_kwargs = dict(embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs)
+    model = _create_t2t_vit('t2t_vit_19', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def t2t_vit_24(pretrained: bool = False, **kwargs: Any) -> T2T_ViT:
+    model_kwargs = dict(embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs)
+    model = _create_t2t_vit('t2t_vit_24', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def t2t_vit_t_14(pretrained: bool = False, **kwargs: Any) -> T2T_ViT:
+    model_kwargs = dict(tokens_type='transformer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs)
+    model = _create_t2t_vit('t2t_vit_t_14', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def t2t_vit_t_19(pretrained: bool = False, **kwargs: Any) -> T2T_ViT:
+    model_kwargs = dict(tokens_type='transformer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs)
+    model = _create_t2t_vit('t2t_vit_t_19', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def t2t_vit_t_24(pretrained: bool = False, **kwargs: Any) -> T2T_ViT:
+    model_kwargs = dict(tokens_type='transformer', embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs)
+    model = _create_t2t_vit('t2t_vit_t_24', pretrained=pretrained, **model_kwargs)
+    return model