### **Setup**

In [1]:
!pip install timm -q

[0m

### **Library Imports**

In [2]:
import os
import re
import cv2
import timm
import torch
import numpy as np
import pandas as pd

from PIL import Image
from time import time, sleep

from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from torchvision import transforms

### **Models**

In [3]:
class Model(nn.Module):
    """Frozen, headless timm backbone used as a fixed feature extractor.

    The timm model is created, every parameter is frozen (``requires_grad=False``),
    and the model's last child module is dropped by rebuilding the remaining
    children as an ``nn.Sequential``.

    Args:
        model_name: timm model identifier; must be one of ``SUPPORTED_MODELS``.
        pretrained: forwarded to ``timm.create_model`` (default ``True``).

    Raises:
        ValueError: if ``model_name`` is not in ``SUPPORTED_MODELS`` (otherwise the
            module would have no ``self.model`` and fail only inside ``forward``).
    """

    SUPPORTED_MODELS = (
        "swin_base_patch4_window12_384_in22k",
        "swin_large_patch4_window12_384_in22k",
    )

    def __init__(self, model_name: str, pretrained: bool = True):
        super().__init__()

        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected one of {self.SUPPORTED_MODELS}"
            )

        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Freeze before truncation so every kept child is already non-trainable.
        self.freeze()
        # NOTE(review): assumes the last child is the classification head in the timm
        # version this notebook runs with — confirm if timm is upgraded.
        self.model = nn.Sequential(*list(self.model.children())[:-1])

    def freeze(self):
        """Disable gradients for every parameter currently registered on this module."""
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Run the truncated backbone on ``x`` and return its raw output."""
        return self.model(x)

### **Utils**

In [4]:
def breaker(num: int=50, char: str="*") -> None:
    """Print a divider made of ``num`` copies of ``char``, with a blank line above and below."""
    divider = char * num
    print(f"\n{divider}\n")

    
def get_image(path: str, size: int=224) -> np.ndarray:
    """Load an image as an RGB array resized to ``size x size``.

    Args:
        path: filesystem path of the image file.
        size: output side length in pixels. The image is resized to a square and
            is not cropped, so aspect ratio is not preserved.

    Returns:
        ``np.ndarray`` of shape ``(size, size, 3)`` in RGB channel order.

    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
        ValueError: if OpenCV cannot decode the file as an image.
    """
    # cv2.imread does not raise on failure; it returns None. Check explicitly so
    # errors name the offending path instead of surfacing as a cv2.error later.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Image file not found: {path}")
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"OpenCV could not decode image: {path}")
    # OpenCV loads BGR; convert to RGB to match the normalization stats used downstream.
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)


def get_features(model=None, dataloader=None, feature_length: int=None) -> np.ndarray:
    """Run ``model`` over every batch of ``dataloader`` and stack one feature row per sample.

    For each batch, the model output is indexed as ``output[:, -1, :]``: the last
    position along dim 1 is kept for every sample.

    Args:
        model: ``nn.Module`` to evaluate. It is switched to eval mode. Inputs are moved
            to the device of its first parameter, or to CUDA/CPU if it has none.
        dataloader: iterable yielding input tensors only (no labels).
        feature_length: expected width of each feature row. It is used to shape the
            result when ``dataloader`` yields nothing, and to validate batch outputs.

    Returns:
        ``np.ndarray`` of shape ``(num_samples, feature_length)`` in dataloader order.

    Raises:
        ValueError: if a batch's feature width differs from ``feature_length``.
    """
    model.eval()

    first_param = next(model.parameters(), None)
    device = (
        first_param.device
        if first_param is not None
        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )

    batches = []
    with torch.no_grad():
        for X in dataloader:
            # NOTE(review): for the Swin backbones used here dim 1 is presumably the
            # token axis, so this keeps only the last token (no pooling). Kept as-is
            # so features stay comparable with any previously saved .npy files.
            output = model(X.to(device))[:, -1, :]
            if feature_length is not None and output.shape[-1] != feature_length:
                raise ValueError(
                    f"Model produced features of width {output.shape[-1]}, expected {feature_length}"
                )
            # Move each batch to CPU immediately so features don't pile up in GPU memory,
            # and concatenate once at the end instead of re-copying on every batch.
            batches.append(output.detach().cpu())

    if not batches:
        return np.zeros((0, feature_length), dtype=np.float32)
    return torch.cat(batches, dim=0).numpy()

### **Dataset Template**

In [5]:
class DS(Dataset):
    """Map-style dataset that loads images from file paths (no labels).

    Each item is the image at ``filepaths[idx]`` loaded with ``get_image`` at
    ``size x size``, then passed through ``transform`` when one is given.

    Args:
        filepaths: list or 1-D array of image paths.
        size: side length forwarded to ``get_image``.
        transform: optional callable applied to the loaded ``np.ndarray``. When
            ``None`` (the default), the raw array is returned unchanged.
    """

    def __init__(
        self,
        filepaths: np.ndarray,
        size: int,
        transform=None
    ):
        self.filepaths = filepaths
        self.size = size
        self.transform = transform

    def __len__(self):
        # len() covers both lists and ndarrays (for an array it equals shape[0]).
        return len(self.filepaths)

    def __getitem__(self, idx):
        image = get_image(self.filepaths[idx], self.size)
        # transform defaults to None; calling it unconditionally raised TypeError.
        if self.transform is None:
            return image
        return self.transform(image)

In [6]:
# Each backbone is listed once together with the width of its extracted feature
# rows, so the two derived lists below cannot drift out of alignment.
_backbones = [
    ("swin_base_patch4_window12_384_in22k", 1024),
    ("swin_large_patch4_window12_384_in22k", 1536),
]
model_names: list = [name for name, _ in _backbones]
feature_lengths: list = [width for _, width in _backbones]
del _backbones

### **Train Features**

In [7]:
# NOTE(review): this whole cell is commented out, so Restart & Run All only writes
# the test-feature files (ts_*.npy) in the next cell. Uncomment to regenerate the
# train-feature files (tr_*.npy) from ../input/fic-dataframe/train.csv.
# df = pd.read_csv("../input/fic-dataframe/train.csv")

# data_setup = DS(
#     filepaths=df.filepaths.copy().values, 
#     size=384,
#     transform=transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize([0.45807, 0.41868, 0.29889], [0.24757, 0.21952, 0.22436]),
#     ])
# )
    
# for i in range(len(model_names)):
#     model = Model(model_names[i]).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
#     data = DL(data_setup, batch_size=16, shuffle=False)
#     features = get_features(model=model, dataloader=data, feature_length=feature_lengths[i])
#     np.save(f"tr_{model_names[i]}.npy", features)
    
#     del model, data
#     torch.cuda.empty_cache()
    
#     sleep(30)

### **Test Features**

In [8]:
# Test image paths come from the sample-submission ids; each id maps to "<id>.jpg".
ts_df = pd.read_csv("../input/5-flowers-image-classification/Sample_submission.csv")
ts_filenames = [f"{name}.jpg" for name in ts_df["id"].tolist()]

ts_filepaths = [os.path.join("../input/5-flowers-image-classification/test", name) for name in ts_filenames]

data_setup = DS(
    filepaths=ts_filepaths,
    size=384,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.45807, 0.41868, 0.29889], [0.24757, 0.21952, 0.22436]),
    ])
)

assert len(model_names) == len(feature_lengths), "model_names / feature_lengths out of sync"

# Device and loader do not depend on the model, so build them once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# shuffle=False keeps feature rows aligned with ts_filepaths (submission order).
data = DL(data_setup, batch_size=16, shuffle=False)

for model_name, feature_length in zip(model_names, feature_lengths):
    model = Model(model_name).to(device)
    features = get_features(model=model, dataloader=data, feature_length=feature_length)
    # Exactly one feature row per test image, in submission order.
    assert features.shape == (len(ts_filepaths), feature_length), features.shape
    np.save(f"ts_{model_name}.npy", features)

    del model
    torch.cuda.empty_cache()

    # NOTE(review): fixed pause between models kept from the original; its purpose
    # is not evident from this notebook.
    sleep(30)

del data

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth" to /root/.cache/torch/hub/checkpoints/swin_base_patch4_window12_384_22k.pth
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth" to /root/.cache/torch/hub/checkpoints/swin_large_patch4_window12_384_22k.pth
