In [15]:
import os
import re
import csv
import random
from PIL import Image
import numpy as np
import pandas as pd
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn as nn

from google.protobuf import text_format

In [16]:
from modules import model

In [17]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)

In [18]:
def get_vector(input_image):
    """Compute the embedding of a single image with the global ``base_model``.

    Relies on three module-level names that must exist before the first call:
    ``transform`` (preprocessing pipeline), ``base_model`` (the network) and
    ``DEVICE`` (where the batch is placed).

    Args:
        input_image: a PIL image in any mode; it is converted to RGB first.

    Returns:
        The model output for the image as a NumPy array, with all singleton
        dimensions (including the batch dimension) squeezed out.
    """
    image = input_image.convert("RGB")  # in case input image is not in RGB format
    img_t = transform(image)
    # The model expects a batch, so add a leading batch dimension of size 1.
    batch_t = torch.unsqueeze(img_t, 0).to(DEVICE)
    # Inference only: skip autograd bookkeeping so no graph is built or kept
    # alive for the (trainable) head layers.
    with torch.no_grad():
        my_embedding = base_model(batch_t)
    return my_embedding.squeeze().detach().cpu().numpy()

In [20]:
# Build the embedding network: a ResNet-18 backbone whose existing parameters
# are frozen, with the final `fc` layer replaced by a two-layer linear head
# (in_features -> 256 -> 128).
base_model = models.resnet18(pretrained=True)
for param in base_model.parameters():
    param.requires_grad = False
num_ftrs = base_model.fc.in_features
base_model.fc = nn.Sequential(nn.Linear(num_ftrs, 256), nn.Linear(256, 128))
# Move once, after the head has been attached, so every layer ends up on DEVICE.
base_model = base_model.to(DEVICE)

# Load the trained weights. `map_location` remaps the stored tensors onto the
# current device, so a checkpoint saved on a GPU can also be loaded on a
# CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
checkpoint = torch.load("./weights/big_dataset_2/model_best.pth", map_location=DEVICE)
base_model.load_state_dict(checkpoint['state_dict'])
# Switch to evaluation mode; as the last expression this also displays the model.
base_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [32]:
# Preprocessing pipeline used by get_vector(): resize to a fixed 224x224
# input, convert the PIL image to a tensor, then normalise each channel with
# the mean / std values listed below.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)
def list_files(root_dir):
    """Return the path of every file found anywhere below ``root_dir``.

    A missing directory simply yields an empty list (``os.walk`` produces
    nothing for it).
    """
    return [os.path.join(root, name)
            for root, _dirs, files in os.walk(root_dir)
            for name in files]


def ids_from_path(im_name):
    """Derive the ``[cat_id, pid]`` pair from one image path.

    The path is split into runs of word characters / apostrophes. ``cat_id``
    joins tokens 1 and 3 with an underscore and ``pid`` joins tokens 2 and 4,
    e.g. ``'./dataset/big_dataset/train/12/345/a.jpg'`` gives
    ``['big_dataset_12', 'train_345']``.

    Raises:
        IndexError: if the path has fewer than five such tokens.
    """
    # Tokenise once per path instead of re-running the regex for every field.
    tokens = re.findall(r"[\w']+", im_name)
    return [tokens[1] + "_" + tokens[3], tokens[2] + "_" + tokens[4]]


# Collect every file from the training split ...
im_path = './dataset/big_dataset/train/'
im_names = list_files(im_path)

# ... and append the (shuffled) validation split.
im_path = "./dataset/big_dataset/val/"
val_im_names = list_files(im_path)
random.shuffle(val_im_names)
im_names.extend(val_im_names)

print(len(im_names))

# Shuffle the combined list; no seed is set, so the order differs per run.
random.shuffle(im_names)

# One row per image: the two ids parsed from the path, plus the path itself.
existing_images_df = pd.DataFrame([ids_from_path(im_name) for im_name in im_names],
                                  columns=['cat_id', 'pid'])
existing_images_df['impath'] = im_names
# Embed every image listed in existing_images_df, one vector per row and in
# the same row order, so the vectors line up with the dataframe.
vecs = []
for impath in existing_images_df['impath']:
    # The context manager releases the image file as soon as it is embedded.
    with Image.open(impath) as img:
        vecs.append(list(get_vector(img)))

# Create the directory that is actually written to below (the previous check
# tested './vis/vis_big' but created './vis_merged', so the open() failed
# whenever the output directory did not exist yet).
os.makedirs("./vis/vis_big", exist_ok=True)

# Write one tab-separated row per embedding. newline='' lets the csv module
# control line endings itself, as the csv documentation recommends.
with open('./vis/vis_big/feature_vecs.tsv', 'w', newline='') as fw:
    csv_writer = csv.writer(fw, delimiter='\t')
    csv_writer.writerows(vecs)

7122


In [33]:
# Build a single square sprite sheet holding a thumbnail of every image, laid
# out row by row in the same order as existing_images_df.
THUMB_SIZE = (50, 50)  # (width, height) of each thumbnail in pixels

images = []
for filename in existing_images_df['impath']:
    # resize() returns a new image, so the source file can be closed right away.
    with Image.open(filename) as source_image:
        images.append(source_image.resize(THUMB_SIZE))

# Take the tile size from the constant rather than images[0], so an empty
# image list does not raise an IndexError here.
image_width, image_height = THUMB_SIZE
# Smallest grid side length whose square can hold all thumbnails.
one_square_size = int(np.ceil(np.sqrt(len(images))))
master_width = image_width * one_square_size
master_height = image_height * one_square_size
spriteimage = Image.new(
    mode='RGBA',
    size=(master_width, master_height),
    color=(0, 0, 0, 0))  # fully transparent
for count, image in enumerate(images):
    row, col = divmod(count, one_square_size)
    # Vertical offset must use the tile height (it previously used the width,
    # which only worked because the thumbnails are square).
    h_loc = image_height * row
    w_loc = image_width * col
    spriteimage.paste(image, (w_loc, h_loc))

# Make sure the output directory exists even when this cell runs on its own.
os.makedirs('./vis/vis_big', exist_ok=True)
# NOTE(review): the sprite is converted to RGB before saving as JPEG, so the
# `transparency` argument presumably has no effect - confirm and drop it.
spriteimage.convert("RGB").save('./vis/vis_big/sprite.jpg', transparency=0)

In [22]:
# Write the per-image labels (cat_id, pid) as a tab-separated file with a
# header row and no index column.
# NOTE(review): to_csv is given a path here, so `metadata` is presumably None
# rather than the CSV text - confirm nothing downstream relies on it.
label_columns = existing_images_df[['cat_id', 'pid']]
metadata = label_columns.to_csv('./vis/vis_big/metadata.tsv', sep='\t', index=False)

In [23]:
# Preview the first rows of the image table (columns: cat_id, pid, impath).
existing_images_df.head()

Unnamed: 0,cat_id,pid,impath


In [94]:
# Number of image paths currently held in im_names.
# NOTE(review): the recorded output of this cell (1000) does not match the
# 7122 printed by the cell that builds im_names, and the execution count is
# out of sequence - presumably left over from a run on a truncated list.
# Re-run the notebook top to bottom to refresh it.
len(im_names)

1000