In [1]:
from google.colab import drive

# Mount Google Drive so that model weights, the dataset JSON and the output
# files used further down are reachable from this runtime.
# NOTE(review): later cells use relative './drive/...' paths, which resolve to
# this mount point only while the working directory is /content.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


# Python setup

In [2]:
!python -m pip install pyyaml==5.1
!pip install 'git+https://github.com/facebookresearch/detectron2.git@5aeb252b194b93dc2879b4ac34bc51a31b5aee13'

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyyaml==5.1
  Downloading PyYAML-5.1.tar.gz (274 kB)
[K     |████████████████████████████████| 274 kB 4.9 MB/s 
[?25hBuilding wheels for collected packages: pyyaml
  Building wheel for pyyaml (setup.py) ... [?25l[?25hdone
  Created wheel for pyyaml: filename=PyYAML-5.1-cp37-cp37m-linux_x86_64.whl size=44092 sha256=aa058714d177a36be189a2db6916abac00ab5708cbd71a39a96999c95f1a5d0e
  Stored in directory: /root/.cache/pip/wheels/77/f5/10/d00a2bd30928b972790053b5de0c703ca87324f3fead0f2fd9
Successfully built pyyaml
Installing collected packages: pyyaml
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 6.0
    Uninstalling PyYAML-6.0:
      Successfully uninstalled PyYAML-6.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
dask 202

In [3]:
# !rm -rf artwork_captioning
!git clone https://github.com/rbarile17/artwork_captioning.git

Cloning into 'artwork_captioning'...
remote: Enumerating objects: 242, done.[K
remote: Counting objects: 100% (242/242), done.[K
remote: Compressing objects: 100% (174/174), done.[K
remote: Total 242 (delta 120), reused 183 (delta 65), pack-reused 0[K
Receiving objects: 100% (242/242), 1.23 MiB | 11.02 MiB/s, done.
Resolving deltas: 100% (120/120), done.


# Model and dataloaders

In [7]:
from tqdm import tqdm

import numpy as np

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog

import torch

In [8]:
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

# Model-zoo config that defines the Faster R-CNN (R101-FPN, 3x) architecture.
FASTER_RCNN_CONFIG = "COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml"
# Checkpoint stored under the adaIn output directory on Drive. The stock COCO
# weights would instead come from model_zoo.get_checkpoint_url(FASTER_RCNN_CONFIG).
ADAIN_WEIGHTS_PATH = './drive/MyDrive/artwork_captioning/object_detection_output/adaIn/model_final.pth'

# Start from detectron2's defaults and overlay the model-zoo settings.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(FASTER_RCNN_CONFIG))
cfg.MODEL.WEIGHTS = ADAIN_WEIGHTS_PATH
# Score threshold applied by the ROI heads at test time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.6

# Build the network from the config, then load the checkpoint into it.
model = build_model(cfg)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)
print('model loaded')

model loaded


In [9]:
def get_object_features(model, batch):
    """Return the box-head features of the detections kept for each image.

    The function runs the detector stage by stage (preprocessing, backbone,
    proposal generator, ROI pooling, box head, box predictor) so that the
    intermediate box-head features can be returned instead of the final
    predictions.

    Args:
        model: detection model exposing ``preprocess_image``, ``backbone``,
            ``proposal_generator`` and ``roi_heads`` (with ``box_in_features``,
            ``box_pooler``, ``box_head`` and ``box_predictor``).
        batch: list of per-image inputs accepted by ``model.preprocess_image``.

    Returns:
        list of tensors, one per image, holding the rows of the box-head
        output selected by the indices that ``box_predictor.inference``
        returns for that image.
    """
    images = model.preprocess_image(batch)  # model-specific input preprocessing
    features = model.backbone(images.tensor)  # backbone feature maps, keyed by name

    # Proposal generator outputs one set of candidate boxes per image.
    proposals, _ = model.proposal_generator(images, features, None)

    # The pooler expects the feature maps as a list, in the order the ROI heads use.
    pooler_inputs = [features[name] for name in model.roi_heads.box_in_features]
    box_features = model.roi_heads.box_pooler(
        pooler_inputs, [p.proposal_boxes for p in proposals])
    # Box-head features of every proposal, all images concatenated along dim 0.
    box_features = model.roi_heads.box_head(box_features)

    predictions = model.roi_heads.box_predictor(box_features)
    _, pred_inds = model.roi_heads.box_predictor.inference(predictions, proposals)

    # Split the concatenated features by each image's proposal count, not into
    # equal chunks: counts can differ between images, and equal chunks would
    # pair an image's kept indices with another image's features.
    proposals_per_image = [len(p) for p in proposals]
    per_image_features = box_features.split(proposals_per_image)

    # Keep only the rows that belong to the detections retained for each image.
    return [
        image_box_features[image_pred_inds]
        for image_box_features, image_pred_inds in zip(per_image_features, pred_inds)
    ]

In [10]:
import os

# Point detectron2's dataset root at ./data.
# NOTE(review): detectron2.data was already imported in an earlier cell; if
# detectron2 reads this variable at import time, setting it here has no
# effect on that import. TODO confirm.
os.environ.update({'DETECTRON2_DATASETS': './data'})

In [11]:
from detectron2.data import build_detection_test_loader
from artwork_captioning.src.modules.dataset_loading import URLMapper
import json

# The Artpedia JSON is a mapping; only its values (one record per artwork)
# are handed to the loader.
ARTPEDIA_JSON = './drive/MyDrive/artwork_captioning/data/datasets/artpedia.json'
with open(ARTPEDIA_JSON, 'r') as artpedia_file:
    artpedia = list(json.load(artpedia_file).values())

# NOTE(review): the mapper is built with is_train=True although it feeds a
# test loader - confirm URLMapper applies no training-only augmentation.
url_mapper = URLMapper(cfg, is_train=True)
artpedia_dataloader = build_detection_test_loader(artpedia, mapper=url_mapper)

# Object features

In [12]:
import h5py
from tqdm import tqdm
from PIL.Image import DecompressionBombError
from urllib.request import HTTPError

# Images whose area (shape[1] * shape[2] of the loaded image tensor) reaches
# this many pixels are skipped and their URL is logged instead.
# NOTE(review): presumably chosen to keep very large images off the model - confirm.
MAX_IMAGE_AREA = 83132049
INVALID_URLS_PATH = './drive/MyDrive/invalid_urls_adaIn.txt'


def _log_invalid_url(url):
    """Append ``url`` as a single line to the invalid-URL log file."""
    with open(INVALID_URLS_PATH, 'a') as invalid_urls:
        invalid_urls.write(url + '\n')


# The iterator is advanced by hand with next() rather than a for loop over the
# dataloader: the handlers below treat a failed sample as recoverable, and an
# exception raised inside a for loop's iteration would end the loop.
iterator = iter(artpedia_dataloader)
# Use the loader's own length (also the progress-bar total) rather than a
# hard-coded sample count, so the loop stays correct if the dataset changes.
num_batches = len(artpedia_dataloader)

model.eval()
with torch.no_grad(), \
        tqdm(desc='Progress', unit='iteration', total=num_batches) as pbar, \
        h5py.File('./drive/MyDrive/artpedia_detections_adaIn.hdf5', 'w') as h5_file:
    for _ in range(num_batches):
        try:
            batch = next(iterator)
            image = batch[0]['image']
            if image.shape[1] * image.shape[2] >= MAX_IMAGE_AREA:
                _log_invalid_url(batch[0]['img_url'])
                continue

            # Only the first entry of the batch is stored.
            feats = get_object_features(model, batch)[0]
            dataset_name = f"{batch[0]['title'].replace(' ', '').lower()}_features"
            # Move the features to host memory once and write them in a single
            # call instead of copying row by row.
            h5_file.create_dataset(
                dataset_name,
                data=feats.cpu().numpy(),
                dtype=np.float32)
        except DecompressionBombError as e:
            # The error message is logged as the URL.
            # NOTE(review): assumes the mapper raises this error with the image
            # URL as its message - confirm.
            _log_invalid_url(str(e))
            continue
        except HTTPError as err:
            # A missing image (404) is logged and skipped; any other HTTP
            # failure is unexpected and aborts the run.
            if err.code == 404:
                _log_invalid_url(err.url)
            else:
                raise
        finally:
            # Count every attempted sample, whether stored, skipped or failed.
            pbar.update()

Progress: 100%|██████████| 2930/2930 [34:07<00:00,  1.43iteration/s]
