In [1]:
import json
import os
import sys

In [4]:
# os.chdir('../')

In [107]:
from icecream import ic
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torchsummary import summary
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import densenet
from tqdm import tqdm

In [6]:
sys.path.append('../dataset/')
import coco_data_prep

### Global Variables

In [7]:
# Input locations for the COCO 2014 training data.
train_np_data_dir = '../data/numpy_imgs/train_subset/'
train_jpg_data_dir = '../data/raw/train/train2014/'
train_annot_filepath = '../data/raw/train/annotations/instances_train2014.json'

# NOTE(review): judging by the filename this presumably maps supercategories to images — verify.
with open('../dataset/imgs_by_supercategory.json', 'r') as categories_file:
    desired_categories = json.load(categories_file)

In [21]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

### Load Data

In [166]:
%autoreload

In [167]:
# NOTE(review): the third positional argument (0.05) is presumably the fraction of the
# training split to load — confirm against coco_data_prep.COCODataset's signature.
train_ds = coco_data_prep.COCODataset(
    train_np_data_dir,
    train_annot_filepath,
    0.05,
)

loading annotations into memory...


  5%|██▊                                                      | 4047/82783 [03:19<1:04:46, 20.26it/s]


Done (t=18.01s)
creating index...
index created!


100%|████████████████████████████████████████████████████████| 82783/82783 [00:57<00:00, 1439.57it/s]


In [168]:
# 250 samples per batch; the recorded extraction run processed 17 batches (the last one partial).
train_dl = coco_data_prep.get_dataloader(
    train_ds,
    batch_size=250,
)

### Load Model

#### DenseNet-121 (pretrained, loaded via `torch.hub`)

In [170]:
# Load pretrained DenseNet-121 from a pinned torchvision hub tag; this tag accepts
# the legacy `pretrained=True` flag (the download is cached after the first run).
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'densenet121',
    pretrained=True,
)

Using cache found in /home/ec2-user/.cache/torch/hub/pytorch_vision_v0.10.0


In [172]:
def slice_model(original_model, from_layer=None, to_layer=None):
    """Build an ``nn.Sequential`` from a slice of ``original_model``'s direct children.

    Args:
        original_model: Any ``nn.Module``; only its top-level children, in the order
            returned by ``.children()``, are considered.
        from_layer: Start index of the slice (inclusive); ``None`` starts at the first child.
        to_layer: Stop index of the slice (exclusive); negative values count from the
            end, so ``to_layer=-1`` drops the last child. ``None`` runs through the last child.

    Returns:
        A new ``nn.Sequential`` wrapping the selected child modules. The children are
        shared, not copied, so their parameters and buffers are the same objects as in
        ``original_model``.

    Note:
        The result simply chains the children. It reproduces part of the original
        forward pass only if ``original_model.forward`` does nothing beyond calling
        those children in order (functional ops in ``forward`` are not captured).
    """
    top_level_children = list(original_model.children())
    return nn.Sequential(*top_level_children[slice(from_layer, to_layer)])

In [184]:
# Keep every top-level child except the last one (for torchvision's DenseNet, presumably
# the `classifier` head) to obtain a convolutional feature extractor.
# NOTE(review): torchvision's DenseNet.forward applies a ReLU and global average pooling
# functionally after `features`; those are not children, so they are skipped here and the
# flattened embeddings are the raw feature maps (50176 values per image in the run below).
#
# .eval() is required for inference: in training mode BatchNorm normalizes with per-batch
# statistics and overwrites its running statistics (even under torch.no_grad()), which
# would make the embeddings depend on batch composition.
#
# For multi-GPU inference, wrap the result with torch.nn.DataParallel(..., device_ids=[0, 1]).
model_conv_features = slice_model(model, to_layer=-1).to(device).eval()

In [188]:
# Inspect the module structure of the sliced feature extractor (saved repr is truncated).
model_conv_features

Sequential(
  (0): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): R

In [185]:
# Run every training batch through the convolutional feature extractor and collect
# one flattened embedding per image plus the matching labels, in dataloader order.
# The recorded run took ~30 s on the GPU for 17 batches.
features_list = []
labels_list = []

# eval(): BatchNorm must use its stored running statistics at inference time. In training
# mode it would normalize with per-batch statistics and overwrite them, even under no_grad.
model_conv_features.to(device).eval()
for batch in tqdm(train_dl):
    # NOTE(review): each batch appears to be a sequence of (image, label) pairs, i.e. the
    # dataloader does not stack samples itself — confirm in coco_data_prep.get_dataloader.
    image_batch = torch.stack([sample[0] for sample in batch]).to(device)
    label_batch = [sample[1] for sample in batch]

    with torch.no_grad():
        features_batch = model_conv_features(image_batch).flatten(start_dim=1)
    # Move results off the GPU so accumulating batches does not exhaust device memory.
    features_list.append(features_batch.cpu())
    labels_list.extend(label_batch)

100%|████████████████████████████████████████████████████████████████| 17/17 [00:30<00:00,  1.82s/it]


In [186]:
# Number of batches collected (one feature tensor is appended per dataloader batch).
len(features_list)

17

In [189]:
# One line per batch: (images in the batch, flattened feature length).
for batch_features in features_list:
    print(batch_features.size())

torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([250, 50176])
torch.Size([139, 50176])


In [190]:
# Persist the embeddings of *all* batches as one (num_images, 50176) tensor. Previously
# only features_list[0] was saved, i.e. just the first 250 rows. Tensors are moved to
# the CPU so the file can be loaded on machines without a GPU.
# TODO(review): labels_list is not persisted; row i of the saved tensor corresponds to
# labels_list[i], since both were filled in the same dataloader order.
assert features_list, "features_list is empty; run the feature-extraction cell first"
embeddings_path = '../data/torch_embeddings/densenet_pretrained_embs_len_50176.pt'
os.makedirs(os.path.dirname(embeddings_path), exist_ok=True)
all_features = torch.cat([batch_features.cpu() for batch_features in features_list])
torch.save(all_features, embeddings_path)