In [None]:
import os
import glob
from tqdm import tqdm
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.models as models

In [None]:
class config:
    """Notebook-wide settings shared by the preprocessing, loader and model cells."""

    # Number of images per DataLoader batch.
    batch_size: int = 64
    # Images are resized to img_size x img_size pixels by the transform pipeline.
    img_size: int = 125
    # Output width of the model's final linear layer (also used in the CSV file name).
    embedding_size: int = 125

<h3>Get article's images</h3>

In [None]:
# Collect every image path found one directory level below images/.
# glob returns matches in arbitrary, filesystem-dependent order, so sort the
# result to make the path list (and the image_id column built from it later)
# reproducible from run to run.
img_list = sorted(glob.glob('../input/h-and-m-personalized-fashion-recommendations/images/*/*'))

<h3>Define custom dataset class and dataloader</h3>

In [None]:
class CustomImageClass(data.Dataset):
    """Dataset that loads images from a collection of file paths.

    Args:
        data_path: Indexable, sized collection of image file paths. It is
            stored unchanged as ``self.root``.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, data_path, transform=None):
        self.root = data_path
        self.transform = transform

    def __getitem__(self, indx):
        """Return the image at position ``indx`` as RGB, transformed if a transform is set."""
        # Open inside a context manager so the underlying file handle is
        # released right away instead of lingering until garbage collection;
        # convert() returns a new image that stays valid after the file closes.
        with Image.open(self.root[indx]) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.root)
    
from torchvision import transforms

# Deterministic preprocessing for feature extraction: resize to the configured
# square size, convert to a tensor, then normalise each channel with the
# ImageNet mean/std that torchvision's pretrained models expect.
# No random augmentation is applied: a random horizontal flip would make the
# embedding computed for a given image differ from one run to the next.
transform = transforms.Compose([
    transforms.Resize((config.img_size, config.img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# Dataset over every collected image path, using the pipeline above.
image_dataset = CustomImageClass(data_path=img_list,
                                 transform=transform)

In [None]:
# Keep the dataset order (shuffle=False): the embeddings are produced batch by
# batch and later paired with img_list purely by position, so shuffling here
# would attach every embedding to the wrong image_id.
custom_data_loader = torch.utils.data.DataLoader(dataset=image_dataset,
                                                 batch_size=config.batch_size,
                                                 shuffle=False,
                                                 num_workers=2)

<h3>Visualize some images in our dataset</h3>

In [None]:
def imageshow(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a channel-first image tensor with matplotlib.

    Args:
        image: Tensor laid out as (channels, height, width).
        mean: Per-channel mean that was subtracted during normalisation. The
            default matches the values used by this notebook's transform.
        std: Per-channel std that the image was divided by during normalisation.
            Pass ``mean=None`` or ``std=None`` to plot the values unchanged.
    """
    npimage = image.detach().cpu().numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    hwc_image = np.transpose(npimage, (1, 2, 0))
    if mean is not None and std is not None:
        # Undo the normalisation so colours are shown as in the source image,
        # then clip to the [0, 1] range matplotlib accepts for float RGB data.
        hwc_image = (hwc_image * np.asarray(std, dtype=hwc_image.dtype)
                     + np.asarray(mean, dtype=hwc_image.dtype))
        hwc_image = np.clip(hwc_image, 0.0, 1.0)
    plt.imshow(hwc_image)
    plt.show()

# Pull a single batch for visual inspection. Use the built-in next(): the
# Python-2 style ``.next()`` method is not part of the iterator protocol and is
# not available on current PyTorch DataLoader iterators.
a = iter(custom_data_loader)
images = next(a)

In [None]:
import torchvision

# Tile the first four images of the sampled batch into one grid and display it.
preview_grid = torchvision.utils.make_grid(images[:4])
imageshow(preview_grid)

<h3>Define the model</h3>

In [None]:
class CNNModel(nn.Module):
    """Image embedder: frozen pretrained ResNet-152 backbone plus a linear projection.

    Args:
        embedding_size: Output width of the final linear layer.

    The backbone is evaluated under ``torch.no_grad()`` and is kept permanently
    in eval mode, so its BatchNorm layers use their stored running statistics
    rather than per-batch statistics (and never update those statistics).

    NOTE(review): ``embedding_layer`` is a freshly initialised ``nn.Linear`` and
    no training step for it is visible in this notebook -- confirm that an
    untrained projection of the backbone features is what is intended.
    """

    def __init__(self, embedding_size):
        super(CNNModel, self).__init__()
        resnet = models.resnet152(pretrained=True)
        # Exclude the last child (the classification head) to get the features.
        module_list = list(resnet.children())[:-1]
        self.resnet_module = nn.Sequential(*module_list)
        self.embedding_layer = nn.Linear(resnet.fc.in_features, embedding_size)
        # A new nn.Module starts in train mode; the frozen backbone must not.
        self.resnet_module.eval()

    def train(self, mode=True):
        """Set the training mode, but always keep the frozen backbone in eval mode."""
        super(CNNModel, self).train(mode)
        self.resnet_module.eval()
        return self

    def forward(self, input_images):
        """Return one embedding row per input image."""
        # The backbone is a fixed feature extractor: no gradients are tracked.
        with torch.no_grad():
            resnet_features = self.resnet_module(input_images)
        # Flatten everything after the batch dimension before the linear layer.
        resnet_features = resnet_features.reshape(resnet_features.size(0), -1)
        embedding = self.embedding_layer(resnet_features)
        return embedding

In [None]:
# Build the embedding model and move it to the GPU when one is available,
# falling back to the CPU otherwise.
model = CNNModel(config.embedding_size)
model.to('cuda' if torch.cuda.is_available() else 'cpu')
# The model is only used for inference here: eval mode makes BatchNorm layers
# use their stored running statistics instead of per-batch statistics.
model.eval()

<h3>Generate embeddings</h3>

In [None]:
embeddings = []

# Inference only: eval mode makes BatchNorm layers use their stored running
# statistics, so an image's embedding does not depend on its batch neighbours.
model.eval()
# Send each batch to whichever device the model's parameters live on.
model_device = next(model.parameters()).device

with torch.no_grad():
    # The loop variable is called `batch`, not `data`, so that it does not
    # overwrite the `torch.utils.data as data` module alias used by the
    # dataset class (which would break re-running that cell).
    for batch in tqdm(custom_data_loader):
        preds = model(batch.to(model_device))
        preds = preds.detach().cpu().numpy()
        embeddings.append(preds)

In [None]:
# Row i of the embedding matrix is paired with img_list[i] purely by position,
# which is only valid if the loader walked the dataset in its original order.
if not isinstance(custom_data_loader.sampler, torch.utils.data.SequentialSampler):
    raise ValueError(
        "custom_data_loader does not iterate sequentially (shuffle enabled?); "
        "embeddings cannot be matched to img_list by position."
    )

# Concatenate the per-batch arrays only once so that re-running this cell does
# not flatten an already-concatenated matrix.
if isinstance(embeddings, list):
    embeddings = np.concatenate(embeddings)

if len(embeddings) != len(img_list):
    raise ValueError(
        f"Got {len(embeddings)} embeddings for {len(img_list)} image paths."
    )

img_embeddings = pd.DataFrame(embeddings)
img_embeddings['image_id'] = img_list

# save the embeddings
img_embeddings.to_csv(f"prodemb_img_{config.embedding_size}.csv", index = False)
