In [1]:
!pip install timm -q

[0m

In [2]:
import os
import re
import cv2
import timm
import torch
import numpy as np
import pandas as pd

from time import time, sleep
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from torchvision import transforms

In [3]:
class Model(nn.Module):
    """Frozen pretrained timm backbone with its last child module removed.

    The network is created with ``timm.create_model(model_name, pretrained=True)``,
    every parameter has ``requires_grad`` set to False, and the direct child
    modules except the final one are re-packed into an ``nn.Sequential``.

    NOTE(review): ``children()`` yields only direct sub-modules, so anything the
    original network does in its own ``forward`` outside those sub-modules is
    not reproduced by the ``nn.Sequential`` - confirm this is intended for each
    architecture.

    Args:
        model_name: One of ``Model.SUPPORTED_MODELS``.

    Raises:
        ValueError: If ``model_name`` is not supported. Previously an unknown
            name produced an instance without ``self.model`` that only failed
            later, inside ``forward``.
    """

    # Architectures this wrapper accepts; all of them are built the same way.
    SUPPORTED_MODELS = (
        "swin_base_patch4_window12_384_in22k",
        "swin_large_patch4_window12_384_in22k",
        "vit_base_patch16_384",
        "vit_base_patch32_384",
        "vit_large_patch16_384",
        "vit_large_patch32_384",
        "deit3_base_patch16_384_in21ft1k",
        "deit3_large_patch16_384_in21ft1k",
        "deit3_small_patch16_384_in21ft1k",
        "deit_base_distilled_patch16_384",
    )

    def __init__(self, model_name: str):
        super(Model, self).__init__()

        # Fail fast instead of leaving the instance without a backbone.
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; "
                f"expected one of {self.SUPPORTED_MODELS}"
            )

        self.model = timm.create_model(model_name, pretrained=True)
        # Freeze while self.model is still the full network so every
        # parameter is covered before the child modules are re-packed.
        self.freeze()
        self.model = nn.Sequential(*list(self.model.children())[:-1])

    def freeze(self):
        """Set ``requires_grad = False`` on every parameter of this module."""
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the truncated backbone and return its output."""
        return self.model(x)

In [4]:
def breaker(num: int=50, char: str="*") -> None:
    """Print a separator line: ``char`` repeated ``num`` times, surrounded by blank lines."""
    separator = char * num
    # A leading newline plus a doubled line ending gives the blank line before and after.
    print("\n" + separator, end="\n\n")

    
def get_image(path: str, size: int=224) -> np.ndarray:
    """Load an image from disk as RGB and resize it to ``size`` x ``size``.

    Args:
        path: Location of the image file.
        size: Side length, in pixels, of the square output.

    Returns:
        The image read with ``cv2.IMREAD_COLOR``, converted from BGR to RGB
        and resized with ``cv2.INTER_AREA`` interpolation.

    Raises:
        OSError: If OpenCV could not read the file (``cv2.imread`` returned None).
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    # imread signals failure by returning None instead of raising; surface the
    # offending path here rather than letting cvtColor fail with an opaque error.
    if image is None:
        raise OSError(f"Could not read image (missing or unreadable file): {path}")
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)

In [5]:
class DS(Dataset):
    """Dataset that loads images from file paths via ``get_image``.

    Args:
        filepaths: Indexable sequence of image paths (NumPy array or list).
        size: Side length passed to ``get_image`` for every item.
        transform: Optional callable applied to each loaded image. When it is
            None the image returned by ``get_image`` is passed through unchanged.
    """

    def __init__(
        self,
        filepaths: np.ndarray,
        size: int,
        transform=None
    ):

        self.filepaths = filepaths
        self.size = size
        self.transform = transform

    def __len__(self):
        """Return the number of file paths."""
        # len() matches .shape[0] for arrays and also accepts plain sequences.
        return len(self.filepaths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and apply ``transform`` if one was given."""
        image = get_image(self.filepaths[idx], self.size)
        # transform defaults to None, so it must not be called unconditionally.
        if self.transform is None:
            return image
        return self.transform(image)

In [6]:
def get_features(model=None, dataloader=None, feature_length: int=None) -> np.ndarray: 
    """Run ``model`` over ``dataloader`` and stack one feature vector per sample.

    For each batch the model output is indexed with ``[:, -1, :]``, i.e. only
    index -1 along dimension 1 is kept (presumably the last token of a
    (batch, tokens, features) output - confirm against the model).

    Args:
        model: Module to evaluate; it is switched to eval mode. It is assumed
            to already live on the device chosen below (CUDA when available,
            otherwise CPU).
        dataloader: Iterable yielding input batches as tensors.
        feature_length: Expected width of each feature vector. When given, it
            is validated against every batch; when None, the width is taken
            from the model output.

    Returns:
        Array of shape (num_samples, feature_length). An empty dataloader
        yields an array with zero rows.

    Raises:
        ValueError: If a batch of features is not 2-D or its width differs
            from ``feature_length``.
    """
    # Resolve the device once instead of on every batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    # Collect per-batch results on the CPU and concatenate once at the end;
    # growing a device tensor with torch.cat per batch copies everything
    # accumulated so far on each iteration.
    batches = []
    with torch.no_grad():
        for X in dataloader:
            output = model(X.to(device))[:, -1, :]
            if output.ndim != 2:
                raise ValueError(
                    f"Expected 2-D features per batch, got shape {tuple(output.shape)}"
                )
            if feature_length is not None and output.shape[1] != feature_length:
                raise ValueError(
                    f"Expected feature length {feature_length}, got {output.shape[1]}"
                )
            batches.append(output.detach().cpu())

    if not batches:
        return np.zeros((0, feature_length or 0), dtype=np.float32)

    return torch.cat(batches, dim=0).numpy()

In [7]:
# Embed every image listed in known.csv with each backbone and save one
# .npy file per model under known-embeddings/.
df = pd.read_csv("../input/fv-dataframes/known.csv")

data_setup = DS(
    filepaths=df.filepaths.copy().values, 
    size=384,
    transform=transforms.Compose([
        transforms.ToTensor(),
        # Per-channel mean and std handed to Normalize
        # (presumably statistics of this dataset - TODO confirm).
        transforms.Normalize([0.48583, 0.37705, 0.32632], [0.27852, 0.24197, 0.22445]),
    ])
)

# exist_ok avoids the separate check-then-create step.
os.makedirs("known-embeddings", exist_ok=True)

model_names: list = [
    "swin_base_patch4_window12_384_in22k",
    "swin_large_patch4_window12_384_in22k",
    "vit_base_patch16_384",
    "vit_base_patch32_384",
    "vit_large_patch16_384",
    "vit_large_patch32_384",
    "deit3_base_patch16_384_in21ft1k",
    "deit3_large_patch16_384_in21ft1k",
    "deit3_small_patch16_384_in21ft1k",
    "deit_base_distilled_patch16_384",
]

# Width of the feature vector expected from the model at the same position
# in model_names.
feature_lengths: list = [
    1024,
    1536,
    768,
    768,
    1024,
    1024,
    768,
    1024,
    384,
    1000,
]

# The two lists are paired by position; catch an accidental length mismatch
# before any model is downloaded.
assert len(model_names) == len(feature_lengths), "model_names and feature_lengths must align"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for model_name, feature_length in zip(model_names, feature_lengths):
    data = DL(data_setup, batch_size=16, shuffle=False)
    model = Model(model_name).to(device)
    features = get_features(model=model, dataloader=data, feature_length=feature_length)
    np.save(f"known-embeddings/{model_name}.npy", features)

    # Drop the current model before the next one is created and release
    # cached CUDA memory.
    del model, data
    torch.cuda.empty_cache()

    # NOTE(review): fixed 30 s pause between models - purpose not evident from
    # the code; confirm it is still needed.
    sleep(30)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth" to /root/.cache/torch/hub/checkpoints/swin_base_patch4_window12_384_22k.pth
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth" to /root/.cache/torch/hub/checkpoints/swin_large_patch4_window12_384_22k.pth
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth" to /root/.cache/torch/hub/checkpoints/jx_vit_large_p32_384-9b920ba8.pth
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth" to /root/.cache/torch/hub/checkpoints/deit_3_base_384_21k.pth
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth" to /root/.cache/torch/hub/checkpoints/deit_3_large_384_21k.pth
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_3_

In [8]:
# Embed every image listed in unknown.csv with each backbone and save one
# .npy file per model under unknown-embeddings/.
df = pd.read_csv("../input/fv-dataframes/unknown.csv")

data_setup = DS(
    filepaths=df.filepaths.copy().values, 
    size=384,
    transform=transforms.Compose([
        transforms.ToTensor(),
        # Per-channel mean and std handed to Normalize
        # (presumably statistics of this dataset - TODO confirm).
        transforms.Normalize([0.48583, 0.37705, 0.32632], [0.27852, 0.24197, 0.22445]),
    ])
)

# exist_ok avoids the separate check-then-create step.
os.makedirs("unknown-embeddings", exist_ok=True)

# Redefined here so this cell can run without relying on state from another cell.
model_names: list = [
    "swin_base_patch4_window12_384_in22k",
    "swin_large_patch4_window12_384_in22k",
    "vit_base_patch16_384",
    "vit_base_patch32_384",
    "vit_large_patch16_384",
    "vit_large_patch32_384",
    "deit3_base_patch16_384_in21ft1k",
    "deit3_large_patch16_384_in21ft1k",
    "deit3_small_patch16_384_in21ft1k",
    "deit_base_distilled_patch16_384",
]

# Width of the feature vector expected from the model at the same position
# in model_names.
feature_lengths: list = [
    1024,
    1536,
    768,
    768,
    1024,
    1024,
    768,
    1024,
    384,
    1000,
]

# The two lists are paired by position; catch an accidental length mismatch
# before any model is loaded.
assert len(model_names) == len(feature_lengths), "model_names and feature_lengths must align"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for model_name, feature_length in zip(model_names, feature_lengths):
    data = DL(data_setup, batch_size=16, shuffle=False)
    model = Model(model_name).to(device)
    features = get_features(model=model, dataloader=data, feature_length=feature_length)
    np.save(f"unknown-embeddings/{model_name}.npy", features)

    # Drop the current model before the next one is created and release
    # cached CUDA memory.
    del model, data
    torch.cuda.empty_cache()

    # NOTE(review): fixed 30 s pause between models - purpose not evident from
    # the code; confirm it is still needed.
    sleep(30)