In [8]:
import gc
import os
from io import BytesIO

import yaml
import cv2
import torch
import requests
import numpy as np
import torch.nn.functional as F
import pandas as pd
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from IPython.display import display, HTML, clear_output
from ipywidgets import widgets, Layout

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.layers import nms
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.utils.model_serialization import load_state_dict


class FeatureExtractor:
    """Extract per-region detection features from a single image.

    The image is loaded from a local path or an HTTP(S) URL, preprocessed
    (BGR channel order, mean subtraction, rescaling), run through a detection
    model built with ``maskrcnn_benchmark`` and the features of the
    highest-confidence boxes are returned.

    Args:
        config_path: Path to the detection model YAML config. Defaults to
            ``DEFAULT_CONFIG_PATH``.
        model_path: Path to the detection model checkpoint. Defaults to
            ``DEFAULT_MODEL_PATH``.
    """

    TARGET_IMAGE_SIZE = [448, 448]
    CHANNEL_MEAN = [0.485, 0.456, 0.406]
    CHANNEL_STD = [0.229, 0.224, 0.225]

    # Locations used when the constructor is called without arguments.
    DEFAULT_CONFIG_PATH = '/Users/zetong/image_captioning/content/model_data/detectron_config.yaml'
    DEFAULT_MODEL_PATH = '/Users/zetong/image_captioning/content/model_data/detectron_model.pth'

    # Per-channel values subtracted from the image after it is flipped to BGR.
    PIXEL_MEAN_BGR = (102.9801, 115.9465, 122.7717)
    # Target length of the shorter image side and cap on the longer side.
    MIN_SIZE = 800
    MAX_SIZE = 1333
    # Number of top-scoring boxes whose features are kept per image.
    MAX_BOXES = 100
    # Overlap threshold passed to ``nms`` for every class.
    NMS_THRESH = 0.5
    # Seconds to wait on the network when the image is given as a URL.
    REQUEST_TIMEOUT = 30

    def __init__(self, config_path=None, model_path=None):
        # self._init_processors()
        self.config_path = config_path or self.DEFAULT_CONFIG_PATH
        self.model_path = model_path or self.DEFAULT_MODEL_PATH
        self.detection_model = self._build_detection_model()

    def __call__(self, url):
        """Return the detection features for the image at ``url`` (path or URL)."""
        with torch.no_grad():
            detectron_features = self.get_detectron_features(url)

        return detectron_features

    def _build_detection_model(self):
        """Build the detection model, load its weights on CPU and set eval mode."""
        # ``cfg`` is a shared module-level object that is frozen below. Without
        # defrosting, building a second extractor (or re-running the notebook
        # cell) would fail when merging into the frozen config.
        cfg.defrost()
        cfg.merge_from_file(self.config_path)
        cfg.freeze()

        model = build_detection_model(cfg)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.model_path,
                                map_location=torch.device("cpu"))

        load_state_dict(model, checkpoint.pop("model"))

        model.eval()
        return model

    def get_actual_image(self, image_path):
        """Return something ``PIL.Image.open`` can read for ``image_path``.

        Args:
            image_path: A local file path, or a URL starting with ``http``.

        Returns:
            ``image_path`` itself for local paths; for URLs, the raw response
            stream of the download.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        if not image_path.startswith('http'):
            return image_path

        response = requests.get(image_path, stream=True,
                                timeout=self.REQUEST_TIMEOUT)
        # Fail here instead of handing an error page to the image decoder.
        response.raise_for_status()
        # The raw stream is not content-decoded (gzip/deflate) unless asked to.
        response.raw.decode_content = True
        return response.raw

    @classmethod
    def _compute_scale(cls, height, width):
        """Return the resize factor for an image of the given height and width.

        The shorter side is scaled to ``MIN_SIZE`` unless that would make the
        longer side exceed ``MAX_SIZE``; then the longer side is scaled to
        ``MAX_SIZE`` instead.
        """
        size_min = min(height, width)
        size_max = max(height, width)
        scale = float(cls.MIN_SIZE) / float(size_min)
        # Prevent the biggest axis from being more than max_size
        if np.round(scale * size_max) > cls.MAX_SIZE:
            scale = float(cls.MAX_SIZE) / float(size_max)
        return scale

    def _image_transform(self, image_path):
        """Load and preprocess one image.

        Args:
            image_path: Local path or URL of the image.

        Returns:
            A tuple ``(img, im_scale)`` where ``img`` is a float32 tensor in
            channel-first layout and ``im_scale`` is the resize factor applied.
        """
        source = self.get_actual_image(image_path)

        # Force three channels: grayscale, palette or RGBA inputs would break
        # the channel flip and the 3-value mean subtraction below.
        with Image.open(source) as pil_image:
            rgb_image = pil_image.convert("RGB")

        im = np.array(rgb_image).astype(np.float32)
        # RGB -> BGR; made contiguous so later steps do not work on a
        # negative-stride view.
        im = np.ascontiguousarray(im[:, :, ::-1])
        im -= np.array(self.PIXEL_MEAN_BGR)

        im_scale = self._compute_scale(im.shape[0], im.shape[1])
        im = cv2.resize(
            im,
            None,
            None,
            fx=im_scale,
            fy=im_scale,
            interpolation=cv2.INTER_LINEAR
        )
        img = torch.from_numpy(im).permute(2, 0, 1)
        return img, im_scale

    def _process_feature_extraction(self, output,
                                    im_scales,
                                    feat_name='fc6',
                                    conf_thresh=0.2):
        """Select the features of the most confident boxes for each image.

        Args:
            output: Detection model output; ``output[0]`` is indexed with the
                keys ``"proposals"``, ``"scores"`` and ``feat_name``.
            im_scales: One resize factor per image, used to map boxes back to
                the original image scale.
            feat_name: Key of the feature tensor to read from ``output[0]``.
            conf_thresh: Accepted for interface compatibility; not used.

        Returns:
            A list with one tensor per image holding the features of up to
            ``MAX_BOXES`` boxes, ordered by descending confidence.
        """
        proposals = output[0]["proposals"]
        batch_size = len(proposals)
        n_boxes_per_image = [len(boxes) for boxes in proposals]
        score_list = output[0]["scores"].split(n_boxes_per_image)
        score_list = [torch.nn.functional.softmax(x, -1) for x in score_list]
        feats = output[0][feat_name].split(n_boxes_per_image)
        cur_device = score_list[0].device

        feat_list = []

        for i in range(batch_size):
            dets = proposals[i].bbox / im_scales[i]
            scores = score_list[i]

            # Best score each box reaches in any class where it survives NMS.
            max_conf = torch.zeros(scores.shape[0], device=cur_device)

            # Column 0 is skipped; NOTE(review): presumably the background
            # class — confirm against the model's label map.
            for cls_ind in range(1, scores.shape[1]):
                cls_scores = scores[:, cls_ind]
                keep = nms(dets, cls_scores, self.NMS_THRESH)
                max_conf[keep] = torch.where(cls_scores[keep] > max_conf[keep],
                                             cls_scores[keep],
                                             max_conf[keep])

            keep_boxes = torch.argsort(max_conf, descending=True)[:self.MAX_BOXES]
            feat_list.append(feats[i][keep_boxes])
        return feat_list

    def get_detectron_features(self, image_path):
        """Run the detection model on one image and return its box features.

        Args:
            image_path: Local path or URL of the image.

        Returns:
            The feature tensor selected by ``_process_feature_extraction`` for
            this image.
        """
        im, im_scale = self._image_transform(image_path)
        img_tensor, im_scales = [im], [im_scale]
        current_img_list = to_image_list(img_tensor, size_divisible=32)
        #current_img_list = current_img_list.to('cuda')
        # Kept so that direct calls (not via __call__) also skip autograd.
        with torch.no_grad():
            output = self.detection_model(current_img_list)
        feat_list = self._process_feature_extraction(output, im_scales,
                                                     'fc6', 0.2)
        return feat_list[0]

# Build the detection-feature extractor (loads the detection model weights on CPU).
feature_extractor = FeatureExtractor()

import captioning
import captioning.utils.misc
import captioning.models

# NOTE(review): this unpickles a local file; only load files from a trusted source.
# The context manager closes the file handle once the infos are loaded.
with open('infos_trans12-best.pkl', 'rb') as infos_file:
    infos = captioning.utils.misc.pickle_load(infos_file)
infos['opt'].vocab = infos['vocab']
model = captioning.models.setup(infos['opt'])
model.load_state_dict(torch.load('model-best.pth', map_location=torch.device('cpu')))
# The notebook only runs inference, so switch off training-only behaviour
# such as dropout.
model.eval()

image_path = feature_extractor.get_actual_image("/Users/zetong/Desktop/56ED4D4057E2D074E89B0B1DF6ABE4A1.jpg")
image = Image.open(image_path)


In [9]:
# get_actual_image returns a local path unchanged, so this displays the path string.
image_path

'/Users/zetong/Desktop/56ED4D4057E2D074E89B0B1DF6ABE4A1.jpg'

In [15]:
# Caption a single sample image.
# NOTE(review): get_captions is not defined in the cells shown here; it must
# already exist in the kernel — confirm it is defined before this cell so the
# notebook survives Restart & Run All.
get_captions(feature_extractor("/Users/zetong/Desktop/4.jpg"))

['a desk with a computer and a chair',
 'a desk with a computer on top of it',
 'a desk with a computer monitor and a chair',
 'a home office with a desk chair and computer',
 'a home office with a desk chair and computer desk']

In [33]:
# Collect the full paths of every .jpg frame in the frames directory.
img_dir = "/Users/zetong/frames/"
files = [
    os.path.join(img_dir, filename)
    for filename in os.listdir(img_dir)
    # Match on the extension only; a substring test would also pick up
    # names such as "frame.jpg.bak".
    if filename.endswith(".jpg")
]

In [45]:
# Order frames numerically by the number in their file name ("300.jpg" -> 300).
# The number is taken from the base name without its extension, so the key does
# not depend on how img_dir is written or on the extension length.
files.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
files

['/Users/zetong/frames/300.jpg',
 '/Users/zetong/frames/600.jpg',
 '/Users/zetong/frames/900.jpg',
 '/Users/zetong/frames/1200.jpg',
 '/Users/zetong/frames/1500.jpg',
 '/Users/zetong/frames/1800.jpg',
 '/Users/zetong/frames/2100.jpg',
 '/Users/zetong/frames/2400.jpg',
 '/Users/zetong/frames/2700.jpg',
 '/Users/zetong/frames/3000.jpg',
 '/Users/zetong/frames/3300.jpg',
 '/Users/zetong/frames/3600.jpg',
 '/Users/zetong/frames/3900.jpg',
 '/Users/zetong/frames/4200.jpg',
 '/Users/zetong/frames/4500.jpg',
 '/Users/zetong/frames/4800.jpg',
 '/Users/zetong/frames/5100.jpg',
 '/Users/zetong/frames/5400.jpg',
 '/Users/zetong/frames/5700.jpg',
 '/Users/zetong/frames/6000.jpg',
 '/Users/zetong/frames/6300.jpg']

In [51]:
# One list of candidate captions per frame, in the same order as `files`.
descriptions = [get_captions(feature_extractor(frame_path)) for frame_path in files]

In [57]:
# Keep only the first caption of each frame.
d1 = [frame_captions[0] for frame_captions in descriptions]

In [58]:
# Display the first caption of every frame, in frame order.
d1

['a man taking a selfie in front of a store window',
 'two men sitting in front of a computer screen',
 'a man wearing glasses and a black shirt',
 'a man wearing glasses and a tie',
 'a couple of people sitting next to a window',
 'a group of people standing in a room',
 'a man and a woman standing in front of a door',
 'a group of people in a large room',
 'a man sitting in the back of a pick up truck',
 'a group of men standing next to each other',
 'a woman sitting next to a man in a car',
 'a person driving a car on a city street',
 'a man is talking on a cell phone',
 'a couple of cars driving across a desert field',
 'a group of cars driving down a street',
 'a man wearing a pair of sunglasses and a helmet',
 'a man in a plaid shirt and sunglasses',
 'a group of cars driving down a road',
 'a couple of cars parked on top of a sandy beach',
 'a person driving a car down a road',
 'a woman standing next to a bed in a room']