## Library Imports

In [1]:
from time import time
notebook_start_time = time()  # wall-clock start, read by the last cell to report total notebook run time

In [2]:
import os
import re
import random as r
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from torch.nn.utils import weight_norm as WN
from torchvision import models, transforms

import warnings
# Silence all Python warnings so the notebook output stays clean.
# NOTE(review): a blanket "ignore" also hides library deprecation warnings (torchvision
# may warn about the `pretrained=` argument used below) — consider filtering by category.
warnings.filterwarnings("ignore")

## Constants and Utilities

In [3]:
def breaker(num=50, char="*") -> None:
    """Print a horizontal separator of `num` copies of `char`, framed by blank lines."""
    rule = char * num
    print(f"\n{rule}\n")


def head(x, no_of_ele=5) -> None:
    """Print the leading `no_of_ele` elements of any sliceable `x`."""
    leading = x[slice(None, no_of_ele)]
    print(leading)

    
def show(image: np.ndarray) -> None:
    """Display `image` alone in a 9x6-inch figure with the axes hidden."""
    _, ax = plt.subplots(figsize=(9, 6))
    ax.imshow(image)
    ax.axis("off")
    plt.show()

In [4]:
# Preprocessing applied to every image before it reaches a backbone:
# ToTensor turns an HxWxC image into a CxHxW float tensor (assumes uint8 input, which
# ToTensor rescales to [0, 1]), then each channel is normalized with the mean/std below —
# presumably the standard ImageNet statistics the pretrained torchvision weights expect.
TRANSFORM_PRE = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Dataset Template

In [5]:
class DS(Dataset):
    """Map-style dataset over an in-memory image array; each item is one (optionally transformed) image.

    There are no labels: `__getitem__` returns only the image.
    """

    def __init__(self, images=None, transform=None):
        # `images` must support `.shape[0]` and integer indexing (e.g. a NumPy array).
        self.images = images
        # Optional callable applied per item; None means "return the raw image".
        self.transform = transform

    def __len__(self):
        """Number of images, i.e. the size of the first axis of `images`."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return image `idx`, passed through `transform` when one was given."""
        image = self.images[idx]
        # `transform` defaults to None; calling it unconditionally would raise TypeError.
        if self.transform is None:
            return image
        return self.transform(image)

## Build DataLoader

In [6]:
def build_dataloader(images: np.ndarray, transform=None, batch_size: int = 64):
    """Wrap `images` in a `DS` dataset and return an unshuffled DataLoader over it.

    Args:
        images: array indexed along axis 0, one image per index.
        transform: optional callable applied to each image by `DS`.
        batch_size: number of images per batch (default 64).

    Returns:
        A DataLoader that yields batches in the original image order, so any features
        computed from it line up row-for-row with `images`.
    """
    dataset = DS(images=images, transform=transform)
    # shuffle=False keeps output rows aligned with the input image order.
    return DL(dataset, batch_size=batch_size, shuffle=False)

## Build Model

In [7]:
def build_model(model_name: str, pretrained=True):
    """Build a headless torchvision backbone that maps an image batch to flat feature vectors.

    Supported `model_name` values (matched case-insensitively) and the width of the
    flattened per-image output:
        "resnet50"    -> 2048
        "vgg16"       -> 2048  (vgg16_bn conv stack, adaptively pooled to 2x2)
        "mobilenet"   -> 1280  (mobilenet_v2 conv stack, adaptively pooled to 1x1)
        "densenet169" -> 1664  (densenet169 conv stack + ReLU, adaptively pooled to 1x1)

    Args:
        model_name: one of the names above.
        pretrained: if True, load torchvision's pretrained weights and set
            requires_grad=False on every parameter so the backbone is a fixed extractor.

    Returns:
        An ``nn.Module`` whose forward returns the flattened features of its input batch.

    Raises:
        ValueError: if `model_name` is not one of the supported names.
    """
    class ImageModel(nn.Module):
        def __init__(self, model_name=None, pretrained=False):
            super(ImageModel, self).__init__()

            if re.match(r"^resnet50$", model_name, re.IGNORECASE):
                self.features = models.resnet50(pretrained=pretrained, progress=True)
                if pretrained:
                    # Freezing before slicing is fine: the Sequential below reuses the
                    # same child modules, so their parameters stay frozen.
                    self.freeze()
                # Drop only the final child (the classifier); ResNet's own global pool stays.
                self.features = nn.Sequential(*[*self.features.children()][:-1])
                self.features.add_module("Flatten", nn.Flatten())

            elif re.match(r"^vgg16$", model_name, re.IGNORECASE):
                self.features = models.vgg16_bn(pretrained=pretrained, progress=True)
                if pretrained:
                    self.freeze()
                # Keep only the conv stack (drop VGG's avgpool and classifier), pool to 2x2.
                self.features = nn.Sequential(*[*self.features.children()][:-2])
                self.features.add_module("Adaptive Average Pool", nn.AdaptiveAvgPool2d(output_size=(2, 2)))
                self.features.add_module("Flatten", nn.Flatten())

            elif re.match(r"^mobilenet$", model_name, re.IGNORECASE):
                self.features = models.mobilenet_v2(pretrained=pretrained, progress=True)
                if pretrained:
                    self.freeze()
                # Drop the classifier head and pool the conv maps to 1x1.
                self.features = nn.Sequential(*[*self.features.children()][:-1])
                self.features.add_module("Adaptive Average Pool", nn.AdaptiveAvgPool2d(output_size=(1, 1)))
                self.features.add_module("Flatten", nn.Flatten())

            elif re.match(r"^densenet169$", model_name, re.IGNORECASE):
                self.features = models.densenet169(pretrained=pretrained, progress=True)
                if pretrained:
                    self.freeze()
                self.features = nn.Sequential(*[*self.features.children()][:-1])
                # torchvision's DenseNet.forward applies a ReLU to the output of its
                # `features` stack (which ends in a BatchNorm) before pooling; re-add it
                # so the pooled vector matches the embedding the pretrained classifier sees.
                self.features.add_module("ReLU", nn.ReLU())
                self.features.add_module("Adaptive Average Pool", nn.AdaptiveAvgPool2d(output_size=(1, 1)))
                self.features.add_module("Flatten", nn.Flatten())

            else:
                # Fail fast: otherwise the module has no `features` and breaks only in forward().
                raise ValueError(
                    "Unsupported model_name {!r}; expected one of "
                    "'resnet50', 'vgg16', 'mobilenet', 'densenet169'.".format(model_name)
                )

        def freeze(self):
            """Set requires_grad=False on every parameter currently registered on this module."""
            for params in self.parameters():
                params.requires_grad = False

        def forward(self, x):
            """Return the flattened backbone features for input batch `x`."""
            return self.features(x)

    model = ImageModel(model_name=model_name, pretrained=pretrained)

    return model

## Acquire Features Helper

In [8]:
def get_features(model=None, dataloader=None, num_features=None, device=None):
    """Run `model` over every batch of `dataloader` and stack the outputs row-wise.

    Args:
        model: module mapping an input batch to per-sample features; it is moved to
            `device` and switched to eval mode.
        dataloader: iterable of input tensors (no labels).
        num_features: width each batch output is reshaped to, i.e. ``(-1, num_features)``.
            If None, each batch output is flattened from dim 1 instead.
        device: torch.device to run on; defaults to the notebook-level `DEVICE`.

    Returns:
        NumPy array of shape (total_rows, num_features), rows in dataloader order.
        An empty dataloader yields a float32 array of shape (0, num_features)
        (width 0 when `num_features` is None).
    """
    if device is None:
        device = DEVICE
    model.to(device)
    model.eval()

    # Collect per-batch results on the CPU and concatenate once at the end: calling
    # torch.cat on a growing result inside the loop re-copies it every batch
    # (quadratic), and keeping it on the device holds every batch in device memory.
    batch_outputs = []
    with torch.no_grad():
        for X in dataloader:
            output = model(X.to(device))
            if num_features is None:
                output = output.flatten(start_dim=1)
            else:
                # reshape (not view) also handles non-contiguous model outputs.
                output = output.reshape(-1, num_features)
            batch_outputs.append(output.cpu())

    if not batch_outputs:
        return torch.zeros(0, num_features or 0).numpy()
    return torch.cat(batch_outputs, dim=0).numpy()

## Obtain and Save Features

In [9]:
def save_features(
    images_path: str = "../input/petfinder-pretrained-images-ccropped/Images_224x224.npy",
    output_dir: str = ".",
) -> None:
    """Extract pooled features from several pretrained backbones and save each as .npy.

    For every backbone, writes ``<output_dir>/<name>_features.npy`` (one feature row per
    input image, in the order of the loaded array) and prints the elapsed minutes.
    A separator line is printed after the last backbone.

    Args:
        images_path: path of the image array, loaded with ``np.load``.
        output_dir: directory for the feature files; created if missing.
    """
    images = np.load(images_path)
    dataloader = build_dataloader(images=images, transform=TRANSFORM_PRE)
    os.makedirs(output_dir, exist_ok=True)

    # (name passed to build_model, flattened feature width build_model produces for it)
    backbones = [
        ("mobilenet", 1280),
        ("densenet169", 1664),
        ("resnet50", 2048),
        ("vgg16", 2048),
    ]

    for model_name, width in backbones:
        start_time = time()
        model = build_model(model_name=model_name, pretrained=True)
        features = get_features(model, dataloader, num_features=width)
        np.save(os.path.join(output_dir, "{}_features.npy".format(model_name)), features)
        print("{}, {:.2f} minutes".format(model_name.capitalize(), (time() - start_time) / 60))

        # Release this backbone before loading the next (the VGG16 checkpoint alone is ~528 MB).
        del model, features
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    breaker()

In [10]:
# Run the full extraction: writes one `<name>_features.npy` per backbone to the
# working directory and prints the per-model timing shown below.
save_features()

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


  0%|          | 0.00/13.6M [00:00<?, ?B/s]

Mobilenet, 0.32 minutes


Downloading: "https://download.pytorch.org/models/densenet169-b2777c0a.pth" to /root/.cache/torch/hub/checkpoints/densenet169-b2777c0a.pth


  0%|          | 0.00/54.7M [00:00<?, ?B/s]

Densenet169, 0.41 minutes


Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

Resnet50, 0.37 minutes


Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

Vgg16, 0.90 minutes

**************************************************



In [11]:
# Report total wall-clock time since the first cell set `notebook_start_time`.
breaker()
print(f"Notebook Run Time : {(time() - notebook_start_time) / 60:.2f} minutes")
breaker()


**************************************************

Notebook Run Time : 2.25 minutes

**************************************************

