In [1]:
import torch
import torchvision.transforms as transforms
from PIL import Image

from models import token_adapter_vit_l16_in1k
from tools.parser import get_default_parser, get_model_args

In [2]:
class ModelArgs:
    """Attribute container passed to ``get_model_args`` to configure the model.

    Every setting is a keyword-only constructor argument whose default is the
    value this notebook has always used, so ``ModelArgs()`` yields exactly the
    same attributes (in the same order) as before, while a single setting can
    be overridden without editing the class, e.g. ``ModelArgs(resolution=128)``.

    NOTE(review): the meaning of each field is defined by ``tools.parser`` and
    ``models``, neither of which is visible from this notebook. The captured
    output suggests ``l_b`` / ``l_m`` are the layer indices at which tokens are
    injected / ejected (11 and 13) -- confirm against the model code. The
    semantics of ``rh``, ``rw``, ``rp_hr``, ``rp_wr``, ``l_a`` and
    ``threshold`` could not be verified from here.
    """

    def __init__(
        self,
        *,
        rh=0.25,
        rw=0.25,
        rp_hr=0.7,
        rp_wr=0.7,
        l_b=11,
        l_m=13,
        l_a=0,
        threshold=0.5,
        model='token_adapter_vit_l16_in1k',
        pretrained=False,
        resolution=64,
    ):
        # Assignment order is kept identical to the original so that
        # ``vars(args)`` iterates the fields in the same order.
        self.rh = rh
        self.rw = rw
        self.rp_hr = rp_hr
        self.rp_wr = rp_wr
        self.l_b = l_b
        self.l_m = l_m
        self.l_a = l_a
        self.threshold = threshold
        self.model = model              # name of the model constructor to use
        self.pretrained = pretrained    # forwarded as ``pretrained=`` to the model constructor
        self.resolution = resolution    # forwarded as ``image_size=`` and used by the transforms

# Build the configuration object and let the project helper turn it into the
# keyword arguments expected by the model constructor.
args = ModelArgs()
model_args = get_model_args(args)

# The whole notebook runs on the CPU; keep the device in one place so the model
# and the input tensor cannot end up on different devices.
device = torch.device('cpu')

model = token_adapter_vit_l16_in1k(pretrained=args.pretrained, image_size=args.resolution, **model_args)
model = model.to(device)

# Switch to evaluation mode before running inference.
model.eval()

print('Number of Layers:', len(model.encoder.layers))

# Count only parameters that would receive gradients.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params / 1e6:.2f}M")


# Preprocessing pipeline: resize, centre-crop to a square of side
# ``args.resolution``, convert to a tensor, then normalise each channel with
# the commonly used ImageNet mean / std values.
trnf = transforms.Compose(
    [
        transforms.Resize(
            args.resolution,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.CenterCrop(args.resolution),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)

# Open the file in a context manager so its handle is released deterministically;
# ``convert`` returns an independent in-memory copy that outlives the ``with``.
with Image.open("assets/Hourglass_vit_framework.png") as image_file:
    image = image_file.convert("RGB")
x = trnf(image).unsqueeze(0)  # shape: [1, 3, H, W]
x = x.to(device)
print(f"Input shape: {x.shape}")

Number of Layers: 24
Number of trainable parameters: 304.14M
Input shape: torch.Size([1, 3, 64, 64])


In [3]:
# Inference only: disable autograd so no computation graph (and no intermediate
# activations for backward) is recorded during this forward pass.
with torch.no_grad():
    x_out = model(x)

Original input sequence: torch.Size([1, 17, 1024])
Injected token at layer 11: torch.Size([1, 9, 1024])
Ejected token at layer 13: torch.Size([1, 16, 1024])


In [4]:
# Starting token adapter at layer 11 | torch.Size([1, 4, 4, 1024])
# Reduced token shape: torch.Size([1, 9, 1024])
# Input token shape at layer 12: torch.Size([1, 10, 1024])
# torch.Size([1, 16, 1024]) torch.Size([1, 9, 1024])
# distance matrix: torch.Size([1, 16, 9])
# Ejecting tokens at layer 13 | token shape: torch.Size([1, 16, 1024])
# Post ejector layer torch.Size([1, 17, 1024])
# Post ejector layer torch.Size([1, 17, 1024])
# Post ejector layer torch.Size([1, 17, 1024])
# Post ejector layer torch.Size([1, 17, 1024])
# Post ejector layer torch.Size([1, 17, 1024])
# Post ejector layer torch.Size([1, 17, 1024])
# Post ejector layer torch.Size([1, 17, 1024])
# Post ejector layer torch.Size([1, 17, 1024])
# Post ejector layer torch.Size([1, 17, 1024])
# Post ejector layer torch.Size([1, 17, 1024])