In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import h5py
import json

from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from dl_utils.utils.dataset import viz_dataloader, split_train_valid, hdf5_dataset
from dl_utils.training.build_model import resnet50_, xcit_small
from dl_utils.training.trainer import Trainer, accuracy
from dl_utils.packed_functions import benchmark_task
from dl_utils.utils.utils import list_to_dict, sort_tasks_by_size, find_last_epoch_file, viz_h5_structure

  from .autonotebook import tqdm as notebook_tqdm


### Generate embeddings

In [3]:
# Build the ResNet50 variant via the project helper (called with arguments 3 and 17)
# and restore the epoch-23 checkpoint; tensors are mapped onto the CPU.
resnet_checkpoint = '../../models/ResNet50/03132025-ResNet50-benchmark-10m/epoch_23.pth'
model = resnet50_(3, 17)
resnet_state = torch.load(resnet_checkpoint, map_location='cpu', weights_only=True)
model.load_state_dict(resnet_state)
# Bare last expression: show the module tree as the cell output.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Turn the classifier into a feature extractor: `fc` is swapped for a pass-through
# module, so the model presumably returns whatever used to be fed into `fc`
# (the extraction cell below stores 2048 values per image).
model.fc = nn.Identity()  # Remove the final classification layer to get embeddings
# Switch to inference behaviour; eval() returns the module, which the notebook
# displays as this cell's output.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Embed every image of the source HDF5 file with the headless ResNet50 and write
# one 2048-value float32 row per image to a new HDF5 file (dataset 'embeddings').
#
# Images are processed in batches: one contiguous HDF5 read, one forward pass and
# one HDF5 write per batch instead of per image. Values may differ from a
# batch-of-one run only by floating-point rounding.
RESNET_BATCH_SIZE = 32  # images per forward pass; lower this if memory is tight

# Do not rely on an earlier cell having been run: make sure the model is in
# inference mode before any forward pass.
# NOTE(review): this cell still expects `model.fc` to have been replaced by
# nn.Identity() beforehand, otherwise the output width will not be 2048.
model.eval()

with h5py.File('../../datasets/imagenet_v5_rot_10k_fix_vector_a100_0.h5', 'r') as h5:
    h5_group = h5['imagenet']
    data = h5_group['data']
    print(data.shape)
    n_images = data.shape[0]

    with h5py.File('../../datasets/imagenet_v5_rot_10k_fix_vector_a100_0_ResNet50_embeddings.h5', 'w') as h5_out:
        # Keep the dataset handle instead of looking it up by name on every write.
        embeddings_out = h5_out.create_dataset('embeddings', shape=(n_images, 2048), dtype='float32')

        # Gradients are never needed here, so disable autograd once for the whole loop.
        # The progress bar is advanced manually so it still counts images, not batches.
        with torch.no_grad(), tqdm(total=n_images) as progress:
            for start in range(0, n_images, RESNET_BATCH_SIZE):
                stop = min(start + RESNET_BATCH_SIZE, n_images)
                # Stored layout is (N, H, W, C); the model takes (N, C, H, W).
                # Dividing by 255.0 assumes 8-bit pixel values - TODO confirm.
                # NOTE(review): no mean/std normalisation is applied; confirm this
                # matches the preprocessing used during training.
                batch = torch.from_numpy(data[start:stop]).permute(0, 3, 1, 2).float() / 255.0
                # The (batch, 2048) result matches the target slice exactly,
                # so no implicit broadcasting is involved in the write.
                embeddings_out[start:stop] = model(batch).cpu().numpy()
                progress.update(stop - start)

(10013, 256, 256, 3)


100%|██████████| 10013/10013 [12:42<00:00, 13.13it/s]


In [None]:
# Sanity check of the ResNet50 embedding file: show its layout, then the first five rows.
resnet_embeddings_path = '../../datasets/imagenet_v5_rot_10k_fix_vector_a100_0_ResNet50_embeddings.h5'
with h5py.File(resnet_embeddings_path, 'r') as embeddings_file:
    viz_h5_structure(embeddings_file)

    first_rows = embeddings_file['embeddings'][:5]
    print(first_rows)

'Dataset': embeddings; Shape: (10013, 2048); dtype: float32
[[5.33669353e-01 9.00103629e-01 1.46383750e+00 ... 5.04102260e-02
  9.16822016e-01 6.74353912e-02]
 [5.37377633e-02 1.38878539e-01 3.91567022e-01 ... 7.00604200e-01
  4.99897897e-02 6.73411938e-04]
 [2.79406132e-03 6.05612025e-02 0.00000000e+00 ... 1.19032145e-01
  7.69485720e-04 2.72758901e-01]
 [2.89863944e-02 8.68181467e-01 2.25061458e-02 ... 6.84287727e-01
  1.10030994e-01 2.18566567e-01]
 [7.40307849e-03 1.21742189e+00 5.37331343e-01 ... 2.33994089e-02
  4.18225676e-02 0.00000000e+00]]


In [None]:
# Rebind `model` to the XCiT-small variant (project helper, called with arguments 3 and 17)
# and restore the epoch-24 checkpoint; tensors are mapped onto the CPU.
xcit_checkpoint = '../../models/XCiT/03132025-XCiT-benchmark-10m/epoch_24.pth'
model = xcit_small(3, 17)
xcit_state = torch.load(xcit_checkpoint, map_location='cpu', weights_only=True)
# Kept as the last expression so the key-matching report is shown as the cell output.
model.load_state_dict(xcit_state)

<All keys matched successfully>

In [None]:
# Put the freshly loaded XCiT into inference mode; eval() returns the module, so
# this cell also displays the architecture (classification head still attached).
# NOTE(review): the next cell calls model.eval() again after replacing the head,
# so this call is only needed for the display.
model.eval()

Xcit(
  (patch_embed): ConvPatchEmbed(
    (proj): Sequential(
      (0): Sequential(
        (0): Conv2d(3, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): GELU(approximate='none')
      (2): Sequential(
        (0): Conv2d(96, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): GELU(approximate='none')
      (4): Sequential(
        (0): Conv2d(192, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (pos_embed): PositionalEncodingFourier(
    (token_projection): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x 

In [None]:
# Turn the classifier into a feature extractor: `head` is swapped for a pass-through
# module, so the model presumably returns whatever used to be fed into `head`
# (the extraction cell below stores 384 values per image).
model.head = nn.Identity()  # Remove the final classification layer to get embeddings
# Switch to inference behaviour; eval() returns the module, which the notebook
# displays as this cell's output.
model.eval()

Xcit(
  (patch_embed): ConvPatchEmbed(
    (proj): Sequential(
      (0): Sequential(
        (0): Conv2d(3, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): GELU(approximate='none')
      (2): Sequential(
        (0): Conv2d(96, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): GELU(approximate='none')
      (4): Sequential(
        (0): Conv2d(192, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (pos_embed): PositionalEncodingFourier(
    (token_projection): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x 

In [None]:
# Embed every image of the source HDF5 file with the headless XCiT and write
# one 384-value float32 row per image to a new HDF5 file (dataset 'embeddings').
#
# Images are processed in batches: one contiguous HDF5 read, one forward pass and
# one HDF5 write per batch instead of per image. Values may differ from a
# batch-of-one run only by floating-point rounding.
XCIT_BATCH_SIZE = 32  # images per forward pass; lower this if memory is tight

# Do not rely on an earlier cell having been run: make sure the model is in
# inference mode before any forward pass.
# NOTE(review): this cell still expects `model.head` to have been replaced by
# nn.Identity() beforehand, otherwise the output width will not be 384.
model.eval()

with h5py.File('../../datasets/imagenet_v5_rot_10k_fix_vector_a100_0.h5', 'r') as h5:
    h5_group = h5['imagenet']
    data = h5_group['data']
    print(data.shape)
    n_images = data.shape[0]

    with h5py.File('../../datasets/imagenet_v5_rot_10k_fix_vector_a100_0_xcit_embeddings.h5', 'w') as h5_out:
        # Keep the dataset handle instead of looking it up by name on every write.
        embeddings_out = h5_out.create_dataset('embeddings', shape=(n_images, 384), dtype='float32')

        # Gradients are never needed here, so disable autograd once for the whole loop.
        # The progress bar is advanced manually so it still counts images, not batches.
        with torch.no_grad(), tqdm(total=n_images) as progress:
            for start in range(0, n_images, XCIT_BATCH_SIZE):
                stop = min(start + XCIT_BATCH_SIZE, n_images)
                # Stored layout is (N, H, W, C); the model takes (N, C, H, W).
                # Dividing by 255.0 assumes 8-bit pixel values - TODO confirm.
                # NOTE(review): no mean/std normalisation is applied; confirm this
                # matches the preprocessing used during training.
                batch = torch.from_numpy(data[start:stop]).permute(0, 3, 1, 2).float() / 255.0
                # The (batch, 384) result matches the target slice exactly,
                # so no implicit broadcasting is involved in the write.
                embeddings_out[start:stop] = model(batch).cpu().numpy()
                progress.update(stop - start)

(10013, 256, 256, 3)


100%|██████████| 10013/10013 [20:50<00:00,  8.01it/s]


In [None]:
# Sanity check of the XCiT embedding file: show its layout, then the first five rows.
xcit_embeddings_path = '../../datasets/imagenet_v5_rot_10k_fix_vector_a100_0_xcit_embeddings.h5'
with h5py.File(xcit_embeddings_path, 'r') as embeddings_file:
    viz_h5_structure(embeddings_file)

    first_rows = embeddings_file['embeddings'][:5]
    print(first_rows)

'Dataset': embeddings; Shape: (10013, 384); dtype: float32
[[ 8.0026867e-04  9.2046097e-04  3.5570911e-03 ... -4.0445660e-04
   1.8727551e-04  1.5299559e-03]
 [ 1.4723408e-03  1.0559600e-03  1.5489649e-04 ...  3.2264463e-04
   2.2407285e-05 -1.5636458e-03]
 [ 1.5780854e-03  1.0829333e-03  1.7200188e-03 ...  4.2441179e-04
   2.0487205e-05 -2.1541815e-03]
 [ 1.1525353e-03  9.7975903e-04  4.7592833e-04 ... -1.0869782e-04
   1.2807883e-04  6.3858570e-05]
 [ 5.0207018e-04  8.5420749e-04  1.9978317e-03 ... -7.6333422e-04
   2.6043880e-04  2.5814015e-03]]
