## MLP token selection algorithm

In [310]:
import torch
import torch.nn.functional as F
# Print tensors with 7 decimal places and allow lines up to 200 characters wide.
torch.set_printoptions(precision=7, linewidth=200)
# Draw one integer in [0, 100) from the global RNG *before* it is re-seeded on the
# next line, so this value is not controlled by manual_seed(0). It is passed to
# torch.manual_seed again further down in this cell.
# NOTE(review): the order looks deliberate (fixed logits, varying Gumbel noise) — confirm.
rand = torch.randint(0, 100, (1,))
# Fix the global seed so the random tensors created right after this are repeatable.
torch.manual_seed(0)

def fmt(t):
    """Render the values of a tensor as a bracketed, comma-separated string.

    Each element is formatted with seven decimal places, e.g. ``[0.5000000, 1.0000000]``.
    The tensor is flattened first, so 0-d tensors (whose ``tolist()`` is a bare
    number) and multi-dimensional tensors (whose ``tolist()`` is nested) are
    handled as well; for 1-D input the result is the same as before.

    Args:
        t: a ``torch.Tensor`` of any shape holding numeric values.

    Returns:
        str: the formatted values in flattened (row-major) order.
    """
    # reshape(-1) always yields a 1-D tensor, so tolist() is a flat list of numbers.
    values = t.reshape(-1).tolist()
    return "[" + ", ".join(f"{x:.7f}" for x in values) + "]"

def gumbel_softmax(logits, tau=1.0, hard=False, dim=-1):
    """Draw a Gumbel-softmax sample by delegating to ``F.gumbel_softmax``.

    Args:
        logits: tensor of unnormalised scores.
        tau: temperature forwarded unchanged.
        hard: forwarded unchanged as the ``hard`` flag.
        dim: dimension along which the softmax is taken.

    Returns:
        The tensor returned by ``F.gumbel_softmax``.
    """
    options = {"tau": tau, "hard": hard, "dim": dim}
    sample = F.gumbel_softmax(logits, **options)
    return sample

L = 10  # number of tokens
C = 50  # token embedding dimension

# feature tokens x
tokens = torch.randn(L, C)

# router = MLP(x): random logits stand in for the router output here
mu = 1.0
sigma = 0.2
logits = mu + sigma * torch.randn(L)
# logits = torch.abs(logits)
print(f"logits: {fmt(logits)}")

# Re-seed with the value drawn into `rand` earlier in this cell, so the Gumbel
# noise below follows that seed. manual_seed is documented to take an int, so the
# one-element tensor is converted explicitly.
torch.manual_seed(int(rand))

# plain softmax, shown for comparison with the Gumbel sample below
p_soft_ = F.softmax(logits, dim=0)
print(f"softmax 뒤: {fmt(p_soft_)}")
large_pos = torch.argmax(p_soft_)
print(f"결정 지점 (0-base): {large_pos.item()}")

# Gumbel-softmax sample; from here on `large_pos` is the argmax of this sample
p_soft = F.gumbel_softmax(logits, tau=1, hard=False, dim=0)
print(f"Gumbel 뒤: {fmt(p_soft)}")
large_pos = torch.argmax(p_soft)
print(f"결정 지점 (0-base): {large_pos.item()}")

# cumulative mask
cumsum_p = torch.cumsum(p_soft, dim=0)
pos = cumsum_p[large_pos].item()
print(f"순차합: {fmt(cumsum_p)}")
print(f"결정 지점 값 1: {pos:.3f}")

# keep_soft[i] = 1 - cumsum_p[i - 1]: roll shifts the cumulative sum right by one.
# The roll wraps the last cumulative value into index 0, so that entry is
# overwritten with 1.0.
keep_soft = 1.0 - torch.roll(cumsum_p, shifts=1, dims=0)
keep_soft[0] = 1.0
print(f"keep_soft: {fmt(keep_soft)}")
pos = keep_soft[large_pos].item()
print(f"결정 지점 값 2: {pos:.3f}")

# ST-trick (optional): the forward value is the hard 0/1 mask, while the
# non-detached term keeps keep_soft in the autograd graph.
keep_hard = (keep_soft >= pos).float()
keep_mask = (keep_hard - keep_soft).detach() + keep_soft
print(f"keep_mask: {fmt(keep_mask)}")

# Overwrite the decision position with its Gumbel probability and, when a next
# position exists, give it the remaining mass 1 - cumsum_p[large_pos].
keep_mask[large_pos] = p_soft[large_pos]
# Bound by L rather than a literal 10 so the guard follows the token count.
if large_pos + 1 < L:
    keep_mask[large_pos+1] = 1 - cumsum_p[large_pos]

print(f"keep_mask: {fmt(keep_mask)}")

# masked tokens: scale every token embedding by its mask value
tokens_masked = tokens * keep_mask.unsqueeze(-1)



logits: [1.1678578, 1.0495843, 1.0413324, 1.1985563, 0.8202838, 0.7594303, 1.6300246, 0.9567003, 1.5219138, 0.9448810]
softmax 뒤: [0.1021033, 0.0907140, 0.0899685, 0.1052864, 0.0721258, 0.0678676, 0.1620901, 0.0826676, 0.1454804, 0.0816963]
결정 지점 (0-base): 6
Gumbel 뒤: [0.0718179, 0.1162387, 0.0288235, 0.0141581, 0.2481723, 0.0791423, 0.1583174, 0.1305682, 0.0721859, 0.0805758]
결정 지점 (0-base): 4
순차합: [0.0718179, 0.1880566, 0.2168801, 0.2310381, 0.4792104, 0.5583526, 0.7166700, 0.8472382, 0.9194242, 1.0000000]
결정 지점 값 1: 0.479
keep_soft: [1.0000000, 0.9281821, 0.8119434, 0.7831199, 0.7689619, 0.5207896, 0.4416474, 0.2833300, 0.1527618, 0.0805758]
결정 지점 값 2: 0.769
keep_mask: [1.0000000, 1.0000000, 1.0000000, 1.0000000, 1.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000]
keep_mask: [1.0000000, 1.0000000, 1.0000000, 1.0000000, 0.2481723, 0.5207896, 0.0000000, 0.0000000, 0.0000000, 0.0000000]


## Model comparison

In [10]:
import os, sys
# NOTE(review): "test.ipynb" has no directory component, so os.path.dirname(...)
# evaluates to "" and an empty string is what gets appended to sys.path.
# If an absolute project directory was intended, this needs os.path.abspath — confirm.
sys.path.append(os.path.dirname("test.ipynb"))
from modelling.modules.timm_vit.dyvit import VisionTransformerDiffPruning
from modelling.modules.timm_vit.timm_vit_models import TimmViTEncoder as TimmViTEncoder_origin
import warnings
# Suppress warnings whose message matches this regular expression.
warnings.filterwarnings("ignore", message="Overwriting vit_.* in registry")

# Encoder given only the backbone name, latent-token count and timm kwargs;
# every other constructor argument is left at its default.
encoder = TimmViTEncoder_origin(
    model_name='vit_tiny_patch16_224',
    num_latent_tokens=256,
    model_kwargs={'img_size': 256, 'patch_size': 16, 'drop_path_rate': 0.0})

# Same backbone and model kwargs, with the remaining constructor arguments spelled out.
encoder2 = TimmViTEncoder_origin(
    in_channels=3, num_latent_tokens=256,
    model_name='vit_tiny_patch16_224',  # 'vit_small_patch14_dinov2.lvd142m', 'vit_base_patch14_dinov2.lvd142m',
    model_kwargs={'img_size': 256, 'patch_size': 16, 'drop_path_rate': 0.0},  # enc_drop_path_rate},
    pretrained=False,
    tuning_method='full',
    tuning_kwargs={'r': 8},
    use_ape=True, use_rope=False, rope_mixed=False, rope_theta=10.0,
    token_drop=0.0,
    token_drop_max=0.6,
    base_img_size=256
    )

# Plain timm model built through the factory, printed next to the wrapped encoder.
from timm.models import create_model
model = create_model('vit_tiny_patch16_224', pretrained=False,
                     img_size=128, patch_size=16, drop_path_rate=0.0)
print(encoder2)
print(model)

base_rate = 0.9
# Only SPARSE_RATIO[0] is read below: the keep rates are its 1st, 2nd and 3rd powers.
SPARSE_RATIO = [base_rate, base_rate - 0.2, base_rate - 0.4]
PRUNING_LOC = [3, 6, 9]
KEEP_RATE = [SPARSE_RATIO[0], SPARSE_RATIO[0] ** 2, SPARSE_RATIO[0] ** 3]
# NOTE: this also rebinds `model`, replacing the timm model created above.
encoder3 = model = VisionTransformerDiffPruning(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
    pruning_loc=PRUNING_LOC, token_ratio=KEEP_RATE, distill=True
    )

# `encoder == encoder2` falls back to object identity (torch.nn.Module defines no
# __eq__), so it is always False for two separately constructed instances.
# Compare what the two configurations actually produce instead: the parameter and
# buffer names together with their shapes.
encoder_layout = {name: tuple(t.shape) for name, t in encoder.state_dict().items()}
encoder2_layout = {name: tuple(t.shape) for name, t in encoder2.state_dict().items()}
print(encoder_layout == encoder2_layout)
print('='*100)
print(encoder3)



TimmViTEncoder(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=192, out_features=576, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=192, out_features=192, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=192, out_features=768, bias=True)
          (act): GELU(approximate='none')

In [14]:
# `_pos_embed` is referenced here without being called, so this prints the
# bound-method repr, which embeds the repr of the whole VisionTransformer.
print(encoder.model._pos_embed)

<bound method VisionTransformer._pos_embed of VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0

In [11]:
# Print the module tree of the VisionTransformerDiffPruning instance bound to `encoder3`.
print(encoder3)

VisionTransformerDiffPruning(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (pre_logits)

In [None]:
import torch
from modelling.modules.timm_vit.timm_vit_models_toy import TimmViTEncoder
import warnings
warnings.filterwarnings("ignore", message="Overwriting vit_.*")

# Build the toy encoder with its default constructor arguments.
# NOTE: this rebinds `encoder`, a name that other cells in this notebook also assign.
encoder = TimmViTEncoder()
# decoder = TimmViTDecoder(num_latent_tokens=256)

# Random input of shape (1, 3, 128, 128) — presumably (batch, channels, height, width).
x = torch.randn(1, 3, 128, 128)

# Forward pass; the trailing comment records the shape seen in the saved output.
o = encoder(x)
print(o.shape)  # torch.Size([1, 16, 192])
# r = decoder(o)

model_kwargs:
{'img_size': 128, 'patch_size': 16, 'drop_path_rate': 0.0}
torch.Size([1, 16, 192])


  with torch.cuda.amp.autocast(enabled=False):


In [None]:
import torch
from modelling.modules.timm_vit.timm_vit_models import TimmViTEncoder as TimmViTEncoder_origin
from modelling.modules.timm_vit.timm_vit_models import TimmViTDecoder as TimmViTDecoder_origin
import warnings
warnings.filterwarnings("ignore", message="Overwriting vit_.*")

# Build the original encoder/decoder pair, both with 256 latent tokens and
# otherwise default constructor arguments.
# NOTE: this rebinds `encoder`, a name that other cells in this notebook also assign.
encoder = TimmViTEncoder_origin(num_latent_tokens=256)
decoder = TimmViTDecoder_origin(num_latent_tokens=256)

# Random input of shape (1, 3, 224, 224) — presumably (batch, channels, height, width).
x = torch.randn(1, 3, 224, 224)
# Encode, then decode; the trailing comments record the shapes seen in the saved output.
o = encoder(x)
print(o.shape)  # torch.Size([1, 256, 384])
r = decoder(o)
print(r.shape)  # torch.Size([1, 3, 224, 224])

torch.Size([1, 256, 384])
torch.Size([1, 3, 224, 224])


In [None]:
# Prints whichever object `encoder` is currently bound to; several cells assign
# that name, so the result depends on which of them ran last.
print(encoder)

In [14]:
# Print the module tree of the object bound to `encoder2`.
print(encoder2)

TimmViTEncoder(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=192, out_features=576, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=192, out_features=192, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=192, out_features=768, bias=True)
          (act): GELU(approximate='none')

In [11]:
import torch

def same_model(m1, m2):
    """Return True when two modules have the same type and identical state.

    The modules are considered the same when they are instances of exactly the
    same class, their ``state_dict()`` mappings contain the same set of keys,
    and every pair of same-named tensors is equal according to ``torch.equal``.

    Args:
        m1: first module; must provide ``state_dict()``.
        m2: second module; must provide ``state_dict()``.

    Returns:
        bool: True if the type, state-dict keys and all tensor values match.
    """
    # Structure check: only the concrete class is compared here.
    if type(m1) is not type(m2):
        return False

    state1, state2 = m1.state_dict(), m2.state_dict()
    # Compare the entry names first. Zipping the values alone would stop at the
    # shorter state dict and could report a match on a mere prefix.
    if set(state1) != set(state2):
        return False

    # Value check, pairing tensors by name rather than by position.
    return all(torch.equal(state1[name], state2[name]) for name in state1)

# Print the boolean result of comparing the objects currently bound to
# `encoder` and `encoder2`.
print(same_model(encoder, encoder2))


False


In [15]:
import timm
from inspect import getsource, getfile

# Build the model through timm's factory, then print the path of the file that
# defines its class — presumably to check which implementation is picked up.
# NOTE: this rebinds `model`, a name that other cells also assign.
model = timm.create_model('vit_tiny_patch16_224')
print(getfile(type(model)))

/workspace/AdapTok/modelling/modules/timm_vit/vision_transformer.py


In [None]:
# NOTE(review): `_create_vision_transformer`, `pretrained`, `model_args` and `kwargs`
# are not defined in any visible cell. This line looks pasted from inside a
# model-factory function and would presumably raise NameError if run as-is — confirm.
# dict(model_args, **kwargs) merges the two mappings, with `kwargs` entries taking precedence.
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
