In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

<img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">


# Movie Poster Feature Extraction with ResNet

In this notebook, we will use a pretrained ResNet-50 network to extract image features from the movie poster images.

Note: this notebook should be executed from within the nvidia_resnet50 container, built as follows (make sure you execute these commands in the directory where this notebook resides, else you might encounter import issues)
```
git clone https://github.com/NVIDIA/DeepLearningExamples
git checkout 5d6d417ff57e8824ef51573e00e5e21307b39697
cd DeepLearningExamples/PyTorch/Classification/ConvNets
docker build . -t nvidia_resnet50
```

Start the container, mounting the current directory where this notebook resides:
```
docker run --gpus=all -it --rm --net=host --ipc=host -v $PWD:/workspace  -w /workspace --ipc=host nvidia_resnet50
```

Then from within the container:

```
jupyter notebook --allow-root --no-browser --NotebookApp.token='' --ip='0.0.0.0'
```

Let's install an extra package that will help us visualize what is happening as we extract the data.

In [2]:
!pip install ipywidgets tqdm
import IPython

IPython.Application.instance().kernel.do_shutdown(True)



{'status': 'ok', 'restart': True}

# Setting up model

In [1]:
from PIL import Image
import argparse
import numpy as np
import json
import torch
from torch.cuda.amp import autocast
import torch.backends.cudnn as cudnn

import sys
sys.path.append('/workspace/DeepLearningExamples/PyTorch/Classification/ConvNets')
from image_classification import models
import torchvision.transforms as transforms

In [2]:
from image_classification.models import resnet50

In [9]:
def load_jpeg_from_file(path, image_size, cuda=True):
    img_transforms = transforms.Compose(
        [
            transforms.Resize(image_size + 32),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    )

    img = img_transforms(Image.open(path))
    with torch.no_grad():
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        if cuda:
            mean = mean.cuda()
            std = std.cuda()
            img = img.cuda()
        img = img.float()

        if img.shape[0] == 1: #mono image
            #pad channels
            img = img.repeat([3, 1, 1])
        input = img.unsqueeze(0).sub_(mean).div_(std)

    return input

def check_quant_weight_correctness(checkpoint_path, model):
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    state_dict = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()}
    quantizers_sd_keys = {f'{n[0]}._amax' for n in model.named_modules() if 'quantizer' in n[0]}
    sd_all_keys = quantizers_sd_keys | set(model.state_dict().keys())
    assert set(state_dict.keys()) == sd_all_keys, (f'Passed quantized architecture, but following keys are missing in '
                                                   f'checkpoint: {list(sd_all_keys - set(state_dict.keys()))}')
    
model_args = {}
model_args["pretrained_from_file"] = './nvidia_resnet50_200821.pth.tar'
model = resnet50(model_args)

model = model.cuda()
model.eval()

Downloading: "https://api.ngc.nvidia.com/v2/models/nvidia/resnet50_pyt_amp/versions/20.06.0/files/nvidia_resnet50_200821.pth.tar" to /root/.cache/torch/hub/checkpoints/nvidia_resnet50_200821.pth.tar


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=102491118.0), HTML(value='')))




ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layers): Sequential(
    (0): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d

# Feature Extraction

Next, we will extract feature for all movie posters, using the last layer just before the classification head, containing 2048 float values.

In [10]:
import glob

filelist = glob.glob('data/poster_images/*.jpg')
len(filelist)

1592

In [11]:
filelist[:10]

['./poster_images/899.jpg',
 './poster_images/254.jpg',
 './poster_images/1069.jpg',
 './poster_images/1004.jpg',
 './poster_images/1234.jpg',
 './poster_images/535.jpg',
 './poster_images/242.jpg',
 './poster_images/442.jpg',
 './poster_images/122.jpg',
 './poster_images/1377.jpg']

In [12]:
from tqdm import tqdm

batchsize = 64
num_bathces = len(filelist)//batchsize
batches = np.array_split(filelist, num_bathces)

In [16]:
### strip the last layer
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])

feature_dict = {}
error = 0
for batch in tqdm(batches):
    inputs = []
    imgs = []
    for i, f in enumerate(batch):
        try:
            img = load_jpeg_from_file(f, 224, cuda=True)
            imgs.append(f.split('/')[-1].split('.')[0])
            inputs.append(img.squeeze())
        except Exception as e:
            print(e)
            error +=1
    features = feature_extractor(torch.stack(inputs, dim=0)).cpu().detach().numpy().squeeze()  
    for i, f in enumerate(imgs):
        feature_dict[int(f)] =features[i,:]

print('Unable to extract features for %d images'%error)

  0%|          | 0/24 [00:00<?, ?it/s]

cannot identify image file './poster_images/1069.jpg'
cannot identify image file './poster_images/242.jpg'
cannot identify image file './poster_images/1349.jpg'
cannot identify image file './poster_images/478.jpg'
cannot identify image file './poster_images/74.jpg'
cannot identify image file './poster_images/317.jpg'
cannot identify image file './poster_images/450.jpg'


  4%|▍         | 1/24 [00:00<00:10,  2.11it/s]

cannot identify image file './poster_images/427.jpg'
cannot identify image file './poster_images/141.jpg'
cannot identify image file './poster_images/448.jpg'
cannot identify image file './poster_images/905.jpg'
cannot identify image file './poster_images/312.jpg'
cannot identify image file './poster_images/39.jpg'
cannot identify image file './poster_images/1361.jpg'


  8%|▊         | 2/24 [00:00<00:10,  2.06it/s]

cannot identify image file './poster_images/939.jpg'
cannot identify image file './poster_images/228.jpg'
cannot identify image file './poster_images/616.jpg'
cannot identify image file './poster_images/361.jpg'
cannot identify image file './poster_images/484.jpg'
cannot identify image file './poster_images/470.jpg'
cannot identify image file './poster_images/454.jpg'
cannot identify image file './poster_images/1133.jpg'


 12%|█▎        | 3/24 [00:01<00:10,  1.95it/s]

cannot identify image file './poster_images/367.jpg'


 17%|█▋        | 4/24 [00:01<00:09,  2.05it/s]

cannot identify image file './poster_images/763.jpg'
cannot identify image file './poster_images/802.jpg'
cannot identify image file './poster_images/838.jpg'
cannot identify image file './poster_images/655.jpg'
cannot identify image file './poster_images/837.jpg'
cannot identify image file './poster_images/984.jpg'
cannot identify image file './poster_images/220.jpg'
cannot identify image file './poster_images/263.jpg'
cannot identify image file './poster_images/525.jpg'
cannot identify image file './poster_images/211.jpg'
cannot identify image file './poster_images/570.jpg'
cannot identify image file './poster_images/133.jpg'


 21%|██        | 5/24 [00:02<00:09,  2.10it/s]

cannot identify image file './poster_images/1196.jpg'
cannot identify image file './poster_images/608.jpg'
cannot identify image file './poster_images/293.jpg'
cannot identify image file './poster_images/690.jpg'
cannot identify image file './poster_images/565.jpg'
cannot identify image file './poster_images/713.jpg'
cannot identify image file './poster_images/591.jpg'
cannot identify image file './poster_images/1229.jpg'


 25%|██▌       | 6/24 [00:03<00:09,  1.96it/s]

cannot identify image file './poster_images/1061.jpg'
cannot identify image file './poster_images/1673.jpg'
cannot identify image file './poster_images/940.jpg'
cannot identify image file './poster_images/521.jpg'
cannot identify image file './poster_images/654.jpg'
cannot identify image file './poster_images/1449.jpg'
cannot identify image file './poster_images/124.jpg'
cannot identify image file './poster_images/1252.jpg'
cannot identify image file './poster_images/485.jpg'


 29%|██▉       | 7/24 [00:03<00:10,  1.67it/s]

cannot identify image file './poster_images/663.jpg'
cannot identify image file './poster_images/557.jpg'
cannot identify image file './poster_images/847.jpg'
cannot identify image file './poster_images/436.jpg'
cannot identify image file './poster_images/1512.jpg'
cannot identify image file './poster_images/1489.jpg'
cannot identify image file './poster_images/725.jpg'


 33%|███▎      | 8/24 [00:04<00:10,  1.56it/s]

cannot identify image file './poster_images/1119.jpg'
cannot identify image file './poster_images/324.jpg'
cannot identify image file './poster_images/1275.jpg'
cannot identify image file './poster_images/1048.jpg'
cannot identify image file './poster_images/632.jpg'
cannot identify image file './poster_images/716.jpg'


 38%|███▊      | 9/24 [00:05<00:09,  1.65it/s]

cannot identify image file './poster_images/76.jpg'
cannot identify image file './poster_images/519.jpg'
cannot identify image file './poster_images/572.jpg'
cannot identify image file './poster_images/128.jpg'


 42%|████▏     | 10/24 [00:05<00:08,  1.70it/s]

cannot identify image file './poster_images/467.jpg'
cannot identify image file './poster_images/872.jpg'
cannot identify image file './poster_images/531.jpg'
cannot identify image file './poster_images/113.jpg'
cannot identify image file './poster_images/212.jpg'
cannot identify image file './poster_images/23.jpg'
cannot identify image file './poster_images/207.jpg'


 46%|████▌     | 11/24 [00:06<00:07,  1.76it/s]

cannot identify image file './poster_images/227.jpg'
cannot identify image file './poster_images/1453.jpg'
cannot identify image file './poster_images/180.jpg'
cannot identify image file './poster_images/731.jpg'
cannot identify image file './poster_images/609.jpg'
cannot identify image file './poster_images/1188.jpg'
cannot identify image file './poster_images/88.jpg'
cannot identify image file './poster_images/494.jpg'


 50%|█████     | 12/24 [00:06<00:07,  1.61it/s]

cannot identify image file './poster_images/1143.jpg'
cannot identify image file './poster_images/617.jpg'
cannot identify image file './poster_images/611.jpg'


 54%|█████▍    | 13/24 [00:07<00:06,  1.63it/s]

cannot identify image file './poster_images/709.jpg'
cannot identify image file './poster_images/1628.jpg'
cannot identify image file './poster_images/1097.jpg'
cannot identify image file './poster_images/650.jpg'
cannot identify image file './poster_images/647.jpg'


 58%|█████▊    | 14/24 [00:07<00:05,  1.77it/s]

cannot identify image file './poster_images/432.jpg'
cannot identify image file './poster_images/157.jpg'
cannot identify image file './poster_images/977.jpg'
cannot identify image file './poster_images/1295.jpg'
cannot identify image file './poster_images/821.jpg'
cannot identify image file './poster_images/77.jpg'
cannot identify image file './poster_images/828.jpg'


 62%|██████▎   | 15/24 [00:08<00:05,  1.57it/s]

cannot identify image file './poster_images/134.jpg'
cannot identify image file './poster_images/1334.jpg'
cannot identify image file './poster_images/1646.jpg'
cannot identify image file './poster_images/965.jpg'
cannot identify image file './poster_images/1383.jpg'
cannot identify image file './poster_images/651.jpg'


 67%|██████▋   | 16/24 [00:09<00:04,  1.65it/s]

cannot identify image file './poster_images/560.jpg'
cannot identify image file './poster_images/1394.jpg'
cannot identify image file './poster_images/1109.jpg'


 71%|███████   | 17/24 [00:09<00:04,  1.61it/s]

cannot identify image file './poster_images/1505.jpg'
cannot identify image file './poster_images/1006.jpg'
cannot identify image file './poster_images/1461.jpg'
cannot identify image file './poster_images/1558.jpg'
cannot identify image file './poster_images/64.jpg'


 75%|███████▌  | 18/24 [00:10<00:03,  1.70it/s]

cannot identify image file './poster_images/718.jpg'
cannot identify image file './poster_images/1120.jpg'
cannot identify image file './poster_images/701.jpg'
cannot identify image file './poster_images/580.jpg'


 79%|███████▉  | 19/24 [00:10<00:02,  1.80it/s]

cannot identify image file './poster_images/29.jpg'
cannot identify image file './poster_images/182.jpg'
cannot identify image file './poster_images/1504.jpg'
cannot identify image file './poster_images/1531.jpg'
cannot identify image file './poster_images/682.jpg'
cannot identify image file './poster_images/1240.jpg'


 83%|████████▎ | 20/24 [00:11<00:02,  1.84it/s]

cannot identify image file './poster_images/176.jpg'
cannot identify image file './poster_images/223.jpg'
cannot identify image file './poster_images/653.jpg'
cannot identify image file './poster_images/1084.jpg'
cannot identify image file './poster_images/737.jpg'
cannot identify image file './poster_images/1190.jpg'
cannot identify image file './poster_images/127.jpg'
cannot identify image file './poster_images/371.jpg'
cannot identify image file './poster_images/187.jpg'


 88%|████████▊ | 21/24 [00:11<00:01,  1.95it/s]

cannot identify image file './poster_images/158.jpg'
cannot identify image file './poster_images/924.jpg'
cannot identify image file './poster_images/225.jpg'
cannot identify image file './poster_images/183.jpg'
cannot identify image file './poster_images/1064.jpg'
cannot identify image file './poster_images/1469.jpg'
cannot identify image file './poster_images/820.jpg'
cannot identify image file './poster_images/620.jpg'
cannot identify image file './poster_images/513.jpg'


 92%|█████████▏| 22/24 [00:12<00:01,  1.99it/s]

cannot identify image file './poster_images/185.jpg'
cannot identify image file './poster_images/380.jpg'
cannot identify image file './poster_images/67.jpg'
cannot identify image file './poster_images/882.jpg'
cannot identify image file './poster_images/92.jpg'
cannot identify image file './poster_images/349.jpg'
cannot identify image file './poster_images/426.jpg'
cannot identify image file './poster_images/194.jpg'
cannot identify image file './poster_images/1202.jpg'
cannot identify image file './poster_images/613.jpg'
cannot identify image file './poster_images/1549.jpg'
cannot identify image file './poster_images/665.jpg'
cannot identify image file './poster_images/962.jpg'


 96%|█████████▌| 23/24 [00:12<00:00,  2.08it/s]

cannot identify image file './poster_images/700.jpg'
cannot identify image file './poster_images/1424.jpg'
cannot identify image file './poster_images/1210.jpg'
cannot identify image file './poster_images/1367.jpg'
cannot identify image file './poster_images/115.jpg'


100%|██████████| 24/24 [00:13<00:00,  1.79it/s]

Unable to extract features for 159 images





We obtain a feature dictionary with keys being the movie indices.

The values are the image descriptors with a dimensionality of 2048.

In [23]:
feature_dict[1].shape

(2048,)

Let's store these descriptors so that we can use them as input to our model.

In [41]:
import pickle
with open('data/movies_poster_features.pkl', 'wb') as f:
    pickle.dump(feature_dict, f)