In [1]:
import os
import sys
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import pandas as pd
from PIL import Image
import numpy as np

import timm
from huggingface_hub import login, hf_hub_download


local_dir = "../assets/uni-2-h_model/"
timm_kwargs = {
   'model_name': 'vit_giant_patch14_224',
   'img_size': 224, 
   'patch_size': 14, 
   'depth': 24,
   'num_heads': 24,
   'init_values': 1e-5, 
   'embed_dim': 1536,
   'mlp_ratio': 2.66667*2,
   'num_classes': 0, 
   'no_embed_class': True,
   'mlp_layer': timm.layers.SwiGLUPacked, 
   'act_layer': torch.nn.SiLU, 
   'reg_tokens': 8, 
   'dynamic_img_size': True
  }
model = timm.create_model(**timm_kwargs)
model.load_state_dict(torch.load(os.path.join(local_dir, "pytorch_model.bin"), map_location="cpu"), strict=True)


transform = transforms.Compose(
 [
  transforms.Resize(224),
  transforms.CenterCrop(224),
  transforms.ToTensor(),
  transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
 ]
)
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1536, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1536, out_features=4608, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1536, out_features=1536, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (mlp): GluMlp(
        (fc1): Linear(in_features=1536, out_features=8192, bias=True)
        (act): SiLU()
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
    

In [16]:
for param in model.parameters():
    print(param.dtype)

torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.float32
torch.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1536, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1536, out_features=4608, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1536, out_features=1536, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (mlp): GluMlp(
        (fc1): Linear(in_features=1536, out_features=8192, bias=True)
        (act): SiLU()
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
    

In [3]:
# # Not enough memory on 4070 Max-Q to trace, attempt with A100/H100

# torch.jit.enable_onednn_fusion(True)
# traced_model = torch.jit.trace(model, torch.rand(150, 3, 224, 224).to(device))
# traced_model = torch.jit.freeze(traced_model)

In [42]:
#Batched inference

dataset = {'eval' : datasets.ImageFolder("../images/", transform)}
dataloader = {'eval': torch.utils.data.DataLoader(dataset['eval'], batch_size = 150, shuffle=False, num_workers=8, pin_memory=True)}

for inputs, labels in dataloader['eval']:
  inputs = inputs.to(device, non_blocking=True)
  print(inputs.shape)
  with torch.inference_mode():
    outputs = model(inputs)
  
  outputs = outputs.cpu().numpy()
  break


torch.Size([150, 3, 224, 224])


In [6]:
# Inference sequentially, non-batched
np.set_printoptions(threshold=sys.maxsize)
pd.set_option('display.max_columns', None) 
pd.set_option('display.width', None) 
pd.set_option('display.max_colwidth', None)

files = os.listdir("../images/Image_A-22-00025A4-1/")
files.sort()
output_dataframe = pd.DataFrame(columns=["filename", "output"])
length = len(files)

for i in (range(0, length)):
    if i % 500 == 0:
        print(i)
    image = transform(Image.open("../images/Image_A-22-00025A4-1/" + files[i]))
    image = image.unsqueeze(0)
    image = image.to(device)
    with torch.inference_mode():
        output = model(image)
    output_dataframe.loc[i] = {"filename": files[i], "output": [output[0].cpu().numpy()]}

# output_dataframe.to_csv("../output/Image_A-22-00025A4-1_embeddings.csv", index=False)

0


500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500
6000
6500
7000
7500
8000
8500
9000


In [12]:
len(output_dataframe["output"][0][0])

1536

In [12]:
output.cpu()[0][0]

tensor(0.2940)

In [9]:
output_dataframe["output"][0][0][0]

np.float32(-0.26522276)