# Image embedding with a timm Swin Transformer

In [None]:
! pip install timm

In [None]:
import timm
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
import matplotlib.pyplot as plt
import glob
import tqdm
import pandas as pd

In [None]:
class Config:
    """Static settings for the image-embedding run."""

    # Keyword arguments unpacked into the DataLoader. Shuffling and
    # drop_last are off so every image is visited exactly once, in
    # path_list order.
    train_loader = {
        'batch_size': 64,
        'shuffle': False,
        'num_workers': 4,
        'pin_memory': False,
        'drop_last': False,
    }

    # glob returns matches in arbitrary (filesystem-dependent) order; sorting
    # makes the row order of the resulting embeddings reproducible across runs.
    path_list = sorted(
        glob.glob("../input/h-and-m-personalized-fashion-recommendations/images/*/*.jpg")
    )

    # Name of the timm model used as the feature extractor.
    model_name = "swin_tiny_patch4_window7_224"
    

In [None]:
class HAndMImageDataset(Dataset):
    """Dataset of article images addressed by file path.

    Indexing yields an ``(image, article_id)`` pair: the decoded image after
    the single-channel expansion and resize transforms, and the file name
    with its ``.jpg`` suffix removed.
    """

    def __init__(self, path_list, image_size=224):
        """Store the image paths and build the per-image transform.

        Args:
            path_list: sequence of image file paths.
            image_size: side length (pixels) of the square resize target.
        """
        self._path_list = path_list
        steps = [
            # https://discuss.pytorch.org/t/convert-grayscale-images-to-rgb/113422
            T.Lambda(self._expand_single_channel),
            T.Resize([image_size, image_size]),
        ]
        self._transform = T.Compose(steps)

    @staticmethod
    def _expand_single_channel(x):
        """Repeat a one-channel image three times along dim 0; pass others through."""
        if x.size(0) == 1:
            return x.repeat(3, 1, 1)
        return x

    def __len__(self):
        """Return the number of image paths."""
        return len(self._path_list)

    def __getitem__(self, idx):
        """Load, transform and return the image at ``idx`` with its article id."""
        image_path = self._path_list[idx]
        image = self._transform(read_image(image_path))

        # The article id is the file name without the '.jpg' extension.
        file_name = image_path.split('/')[-1]
        article_id = file_name.replace('.jpg', '')
        return image, article_id


In [None]:
# Wrap every discovered image path in the dataset, then batch it using the
# loader settings kept in Config.
dataset = HAndMImageDataset(path_list=Config.path_list)
loader = DataLoader(dataset=dataset, **Config.train_loader)

In [None]:
# Preview the first batch: up to 16 images on a 4x4 grid, titled by article id.
# The built-in next() is used because the iterator's Python-2-style `.next()`
# method is not available on recent PyTorch DataLoader iterators.
images, article_ids = next(iter(loader))
plt.figure(figsize=(12, 12))
for it, (image, article_id) in enumerate(zip(images[:16], article_ids[:16])):
    plt.subplot(4, 4, it+1)
    # Move the leading (channel) dimension last, which is what imshow expects.
    plt.imshow(image.permute(1, 2, 0))
    plt.axis('off')
    plt.title(f'article id: {article_id}')

In [None]:
# Instantiate the pretrained network named in Config.
# NOTE(review): per the timm docs, num_classes=0 drops the classifier head so
# the forward pass returns pooled features - confirm for the installed version.
model = timm.create_model(
    Config.model_name,
    pretrained=True,
    num_classes=0,
    in_chans=3,
)

In [None]:
# Per-channel normalization statistics, in RGB order.
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # RGB
IMAGENET_STD = [0.229, 0.224, 0.225]  # RGB

# Batch-level preprocessing applied right before the model: convert the image
# tensor to floating point, then normalize each channel.
to_float = T.ConvertImageDtype(torch.float32)
normalize = T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
transform = T.Compose([to_float, normalize])

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (more slowly) on machines where CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [None]:
# Echo the selected device as the cell output.
device

In [None]:
# Embed every image batch by batch; each batch becomes one DataFrame whose
# columns are the embedding dimensions plus an "article_id" column.
dfs = []
model.eval()
with torch.no_grad():
    for images, article_ids in tqdm.tqdm(loader):
        images = images.to(device)
        images = transform(images)
        # No .detach() needed: nothing is tracked by autograd under no_grad.
        emb = model(images).cpu().numpy()

        batch_df = pd.DataFrame(emb)
        batch_df["article_id"] = list(article_ids)
        dfs.append(batch_df)

# ignore_index=True gives one continuous 0..N-1 index; without it every batch
# would contribute its own 0..batch_size-1 labels, leaving duplicates.
df = pd.concat(dfs, ignore_index=True)

In [None]:
# Persist the embeddings next to the notebook, without the row index.
output_path = f"{Config.model_name}_emb.csv.gz"
df.to_csv(output_path, index=False)

In [None]:
# Sanity check: there should be exactly one embedding row per input image path.
len(df) == len(Config.path_list)

In [None]:
# Show the first five embedding rows.
df.head(n=5)

# Check nearest neighbors

In [None]:
from annoy import AnnoyIndex

# Everything except the "article_id" column is an embedding dimension.
embeddings = df.drop(columns="article_id")

# Length of the item vectors to index, taken from the data so the cell keeps
# working when Config.model_name selects a model with a different width.
f = embeddings.shape[1]
t = AnnoyIndex(f, 'angular')

# The Annoy item id is the row position in df, so neighbours can be mapped
# back to article ids with .iloc.
for i, v in tqdm.tqdm(enumerate(embeddings.to_numpy()), total=len(embeddings)):
    t.add_item(i, v)

# Drop the temporary copy of the embedding matrix.
del embeddings

In [None]:
# Build the Annoy index from the items added so far.
t.build(10) # 10 trees

In [None]:
# Query item: its row position in df, which is also its Annoy item id.
i = 1
print(df["article_id"].iloc[i])

# Ask the index for 10 neighbours and map their positions back to article ids.
nns = t.get_nns_by_item(i, 10)
nn_article_ids = df["article_id"].iloc[nns]

# Images live in a sub-folder named after the first three characters of the id.
nn_paths = []
for article_id in nn_article_ids:
    nn_paths.append(
        f"../input/h-and-m-personalized-fashion-recommendations/images/{article_id[:3]}/{article_id}.jpg"
    )

# Draw the neighbours on a 4x4 grid, one subplot per image.
plt.figure(figsize=(12, 12))
for slot, (path, article_id) in enumerate(zip(nn_paths, nn_article_ids), start=1):
    plt.subplot(4, 4, slot)
    plt.imshow(read_image(path).permute(1, 2, 0))
    plt.axis('off')
    plt.title(f'article id: {article_id}')

In [None]:
# Query item: its row position in df, which is also its Annoy item id.
i = 2
print(df["article_id"].iloc[i])

# Ask the index for 10 neighbours and map their positions back to article ids.
nns = t.get_nns_by_item(i, 10)
nn_article_ids = df["article_id"].iloc[nns]

# Images live in a sub-folder named after the first three characters of the id.
nn_paths = []
for article_id in nn_article_ids:
    nn_paths.append(
        f"../input/h-and-m-personalized-fashion-recommendations/images/{article_id[:3]}/{article_id}.jpg"
    )

# Draw the neighbours on a 4x4 grid, one subplot per image.
plt.figure(figsize=(12, 12))
for slot, (path, article_id) in enumerate(zip(nn_paths, nn_article_ids), start=1):
    plt.subplot(4, 4, slot)
    plt.imshow(read_image(path).permute(1, 2, 0))
    plt.axis('off')
    plt.title(f'article id: {article_id}')

In [None]:
# Query item: its row position in df, which is also its Annoy item id.
i = 3
print(df["article_id"].iloc[i])

# Ask the index for 10 neighbours and map their positions back to article ids.
nns = t.get_nns_by_item(i, 10)
nn_article_ids = df["article_id"].iloc[nns]

# Images live in a sub-folder named after the first three characters of the id.
nn_paths = []
for article_id in nn_article_ids:
    nn_paths.append(
        f"../input/h-and-m-personalized-fashion-recommendations/images/{article_id[:3]}/{article_id}.jpg"
    )

# Draw the neighbours on a 4x4 grid, one subplot per image.
plt.figure(figsize=(12, 12))
for slot, (path, article_id) in enumerate(zip(nn_paths, nn_article_ids), start=1):
    plt.subplot(4, 4, slot)
    plt.imshow(read_image(path).permute(1, 2, 0))
    plt.axis('off')
    plt.title(f'article id: {article_id}')