In [1]:
!pip install --upgrade timm

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
     |████████████████████████████████| 431 kB 885 kB/s            
Installing collected packages: timm
Successfully installed timm-0.5.4


### **Library Imports**

In [2]:
import os
import re
import cv2
import timm
import torch
import numpy as np
import pandas as pd

from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from sklearn.preprocessing import LabelEncoder

### **Utilities and Constants**

In [3]:
# Reproducibility: seed every RNG this notebook touches so any stochastic component
# (e.g. layers left in training mode) gives repeatable results across runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

# Square input resolution fed to the backbone; must match the "_384" Swin variant used below.
SIZE = 384
# Length of one image feature vector; expected to equal the backbone's final embedding
# width (1536 for the Swin-L model created in the feature-extraction cell).
FV_LENGTH = 1536
# NOTE(review): not referenced anywhere in this notebook, and it names a Swin-B/224
# checkpoint, whereas the model actually built is Swin-L/384 loaded via timm's
# `pretrained=True`. Kept for compatibility only; it does not describe the extracted features.
MODEL_PATH = "../input/swin-pretrained-feature-extractor-weights/swin_base_patch4_window7_224_in22k.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# HWC uint8 RGB array -> CHW float tensor in [0, 1], then per-channel normalization with
# the standard ImageNet mean/std (given in RGB order).
TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def get_image(path: str, size: int) -> np.ndarray:
    """Load an image from disk as an RGB array resized to ``size`` x ``size``.

    Args:
        path: Path to the image file.
        size: Target side length in pixels. The image is resized to a square
            without preserving its aspect ratio.

    Returns:
        A ``(size, size, 3)`` uint8 array in RGB channel order.

    Raises:
        OSError: If OpenCV cannot read or decode the file. ``cv2.imread`` signals
            failure by returning ``None`` rather than raising.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # Without this check the failure surfaces later as an opaque cv2.error from cvtColor.
        raise OSError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; the ImageNet mean/std used for normalization are in RGB order.
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    # INTER_AREA is OpenCV's recommended interpolation for shrinking images.
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)


# Shared encoder that maps cultivar names to integer class ids. It is fitted on the
# training labels in the data-loading cell; use `le.inverse_transform` to map ids back.
le = LabelEncoder()

### **Image Dataset for Feature Extraction**

In [4]:
class FEDS(Dataset):
    """Feature-extraction dataset: yields one transformed, unlabeled image per filename.

    Item ``idx`` corresponds to ``filenames[idx]``, so iterating without shuffling
    preserves the input order.

    Args:
        base_path: Directory that contains the image files.
        filenames: Sequence of file names relative to ``base_path`` (a NumPy array
            or any other indexable sequence).
        transform: Callable applied to the ``(size, size, 3)`` RGB array returned by
            ``get_image``.
        size: Square side length images are resized to. Defaults to the global ``SIZE``.
    """

    def __init__(self, base_path: str, filenames: np.ndarray, transform, size: int = SIZE):
        self.base_path = base_path
        self.filenames = filenames
        self.transform = transform
        self.size = size

    def __len__(self) -> int:
        # len() works for arrays and plain sequences alike (``.shape`` exists only on arrays).
        return len(self.filenames)

    def __getitem__(self, idx: int):
        path = os.path.join(self.base_path, self.filenames[idx])
        return self.transform(get_image(path, self.size))

### **Load and Preprocess Data**

In [5]:
train_df = pd.read_csv("../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv")

# Some rows of the mapping CSV reference images that are missing from the image folder.
# List the directory ONCE into a set (the previous version re-listed it for every row),
# then drop all rows whose file is absent with a single vectorized mask.
available_images = set(os.listdir("../input/sorghum-id-fgvc-9/train_images"))
broken_images = [filename for filename in train_df.image if filename not in available_images]
train_df = train_df[train_df.image.isin(available_images)].reset_index(drop=True)

filenames = train_df.image.to_numpy(copy=True)
labels = le.fit_transform(train_df.cultivar.to_numpy(copy=True))
assert len(filenames) == len(labels)

# Row i of labels.npy corresponds to filenames[i] (and to row i of the extracted features).
np.save("labels.npy", labels)

### **Extract Image Features with a Pretrained Swin Backbone**

In [6]:
dataloader_setup = FEDS("../input/sorghum-id-fgvc-9/train_images", filenames, TRANSFORM)
# shuffle=False keeps feature rows in the same order as `filenames` / labels.npy.
dataloader = DL(dataloader_setup, batch_size=64, shuffle=False)

# num_classes=0 makes timm replace the classification head with an identity, so the forward
# pass returns the pooled (batch, num_features) image embedding (for timm's Swin: final norm
# followed by a mean over all patch tokens). Swin has no CLS token, so slicing token 0 of the
# last stage, as a ViT-style extractor would, only yields the top-left patch.
model = timm.create_model("swin_large_patch4_window12_384_in22k", pretrained=True, num_classes=0)
# eval() is required: nn.Modules start in training mode, where drop-path/dropout layers are
# active and would make the extracted features random.
model = model.to(DEVICE).eval()

batch_features = []
with torch.no_grad():
    for X in dataloader:
        output = model(X.to(DEVICE))
        assert output.ndim == 2 and output.shape[1] == FV_LENGTH, (
            f"expected (batch, {FV_LENGTH}) features, got {tuple(output.shape)}"
        )
        # Move each batch off the device right away and concatenate once at the end; growing a
        # tensor with torch.cat inside the loop re-copies all previous rows every iteration.
        batch_features.append(output.cpu())

features = (
    torch.cat(batch_features).numpy()
    if batch_features
    else np.zeros((0, FV_LENGTH), dtype=np.float32)
)
assert features.shape == (len(filenames), FV_LENGTH)
np.save("features.npy", features)

Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth" to /root/.cache/torch/hub/checkpoints/swin_large_patch4_window12_384_22k.pth
