In [1]:
import logging
import argparse
import pickle

import numpy as np

from pathlib import Path
from tqdm.notebooks import tqdm
from redis import Redis

import torch
from torch.utils.data import DataLoader, sampler

from nnmodels.hash import HashEncoder
from nnmodels.datasets import PharmaPackDataset
from nnmodels.hash.helpers import encode_image_to_uuid

import utils

In [20]:
base_model = 'resnet101'
dir_alg = Path('/ndata/chaban/pharmapack/NN/Complete/MI1/')
redis_db = Redis(host='localhost', port=6379, db=7)
dataset = PharmaPackDataset(dir_alg)
loader = DataLoader(
    dataset=dataset,
    batch_size=800,
    sampler=sampler.SubsetRandomSampler(np.arange(len(dataset))),
    num_workers=0,
    drop_last=False,
    shuffle=False
)

In [4]:
descriptor_lengths = [256, 512, 1024]

In [17]:
hash_model = HashEncoder(base_model, descriptor_lengths)
parallel_model = torch.nn.DataParallel(hash_model)
cuda = torch.cuda.is_available()
parallel_model.to('cuda') if cuda else parallel_model.to('cpu')

DataParallel(
  (module): HashEncoder(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
 

In [6]:
def unzip_d(data):
    if not type(data) in (tuple, list):
        data = (data,)
    if cuda:
        data = tuple(d.cuda() for d in data)
    return data

In [None]:
with torch.no_grad():
    parallel_model.eval()

    bar = tqdm(np.arange(len(loader)), desc='Inserting', total=len(loader))

    for batch_idx, (data, filepaths) in enumerate(loader, start=1):
        data = unzip_d(data)

        tensor = parallel_model(*data)

        for descriptor_length, tensor in zip(hash_model.descriptor_lengths, tensor):
            for descriptor, filepath in zip(tensor.cpu().numpy(), filepaths):
                uuid = encode_image_to_uuid(base_model, descriptor_length, Path(filepath))
                redis_db.append(uuid, pickle.dumps(descriptor))

        bar.update()
    bar.close()

In [None]:
redis_db.flush