In [None]:
%matplotlib inline
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

dataset_path = './data'
anns_file_path = f'{dataset_path}/annotations.json'

# Read the COCO-style annotation file. JSON text is UTF-8 by specification,
# so decode explicitly instead of relying on the platform's locale encoding
# (which is not UTF-8 on e.g. Windows).
with open(anns_file_path, 'r', encoding='utf-8') as f:
    dataset = json.load(f)

# Top-level sections of the annotation file and their sizes.
categories = dataset['categories']
anns = dataset['annotations']
imgs = dataset['images']
nr_cats = len(categories)
nr_annotations = len(anns)
nr_images = len(imgs)

# Load categories and super categories.
# cat_names keeps every category name in file order. Each distinct super
# category gets a sequential id in order of first appearance.
cat_names = []
super_cat_names = []
super_cat_ids = {}
for cat_it in categories:
    cat_names.append(cat_it['name'])
    super_cat_name = cat_it['supercategory']
    # Register a super category the first time it is seen. The check is
    # against every super category seen so far, not just the previous one.
    # So categories need not be grouped by super category, and an empty
    # super category name is not silently skipped.
    if super_cat_name not in super_cat_ids:
        super_cat_ids[super_cat_name] = len(super_cat_names)
        super_cat_names.append(super_cat_name)
nr_super_cats = len(super_cat_names)

print('Number of super categories:', nr_super_cats)
print('Number of categories:', nr_cats)
print('Number of annotations:', nr_annotations)
print('Number of images:', nr_images)

# Pin PyYAML before installing detectron2 (presumably needed by detectron2's
# YAML config loading -- confirm the pin is still required).
!pip install pyyaml==5.1

import torch
# Major.minor torch version, e.g. "1.9". Used below to pick the matching
# detectron2 wheel index.
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
# Text after the '+' local-version suffix (e.g. a CUDA tag).
# NOTE(review): if torch.__version__ has no '+' suffix, this is the whole
# version string, not a CUDA tag. It is only printed, never used for install.
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
# Install detectron2 that matches the above pytorch version
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
# This pulls the CPU-only wheel (the `/wheels/cpu/` index). It matches the
# predictor later being configured with MODEL.DEVICE = 'cpu'. IPython expands
# $TORCH_VERSION from the Python variable above.
!pip install --ignore-installed detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch$TORCH_VERSION/index.html --use-deprecated=html5lib
# GPU alternative (CUDA 10.1 wheels), kept for reference:
#!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu10.1/torch$TORCH_VERSION/index.html
# If there is not yet a detectron2 release that matches the given torch + CUDA version, you need to install a different pytorch.

#exit(0)  # After installation, you may need to "restart runtime" in Colab. This line can also restart runtime


# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
#from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

import sklearn
from sklearn.model_selection import train_test_split

imgs = dataset['images']
len_images = len(imgs)

# Mask R-CNN (R50-FPN, 3x schedule) from the detectron2 model zoo, run on CPU.
# Detections below a 0.5 score are discarded.
cfg = get_cfg()
cfg.MODEL.DEVICE = 'cpu'
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)

# Class-index -> class-name table for the dataset the model was trained on.
# It does not depend on the image, so look it up once rather than per image.
metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
class_catalog = metadata.thing_classes

# Only a prefix of the dataset is run through the (slow, CPU) detector.
# Clamp to the dataset size so a smaller dataset does not raise IndexError.
max_detect_images = 100
for i in range(min(max_detect_images, len_images)):
    img_path = os.path.join(dataset_path, imgs[i]['file_name'])
    im = cv2.imread(img_path)
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; fail here with the path rather than deep inside the model.
    if im is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    outputs = predictor(im)
    instances = outputs["instances"]
    # One class name per detection, in detection order (duplicates kept).
    imgs[i]['objects'] = [
        class_catalog[class_index]
        for class_index in instances.pred_classes.tolist()
    ]

print("model done")

# One row per image record. Only images that were run through the detector
# have an 'objects' list; pandas fills the missing key with NaN for the rest.
# The labeling functions iterate over 'objects', so keep only the processed
# images.
df = pd.DataFrame(imgs)
if 'objects' in df.columns:
    processed = df['objects'].apply(lambda objs: isinstance(objs, list))
    df = df[processed].reset_index(drop=True)
df_train, df_valid = train_test_split(df, test_size = 0.2, shuffle = True)

# Integer class labels (the label model below uses cardinality 3).
# ABSTAIN is what a labeling function returns when it has no opinion.
GARBAGE, RECYCLING, COMPOST = 0, 1, 2
ABSTAIN = -1

from snorkel.labeling import labeling_function

# Substrings looked for inside each detected object's class name by the
# labeling functions.
# NOTE(review): the detector is COCO-pretrained; check that these phrases
# actually occur in its class names, or the corresponding LF will always
# abstain.
garbage_phrases = [
    "plastic",
    "disposable",
    "styrofoam",
]
recycling_phrases = [
    "paper",
    "carton",
    "can",
    "glass",
    "cardboard",
]
compost_phrases = [
    "food",
]

# Category-based LFs
@labeling_function()
def lf_garbage_object(x):
    """Vote GARBAGE if any detected object's name contains a garbage phrase.

    ``x.objects`` holds the detector's class names for the image. A row whose
    image was never run through the detector may hold NaN there (pandas fills
    the missing key). Such a row is treated as having no detections, so the
    function abstains instead of raising ``TypeError``.

    Returns:
        GARBAGE on a match, otherwise ABSTAIN.
    """
    objects = x.objects if isinstance(x.objects, (list, tuple)) else []
    if any(phrase in obj for obj in objects for phrase in garbage_phrases):
        return GARBAGE
    return ABSTAIN


@labeling_function()
def lf_recycling_object(x):
    """Vote RECYCLING if any detected object's name contains a recycling phrase.

    ``x.objects`` holds the detector's class names for the image. A row whose
    image was never run through the detector may hold NaN there (pandas fills
    the missing key). Such a row is treated as having no detections, so the
    function abstains instead of raising ``TypeError``.

    Returns:
        RECYCLING on a match, otherwise ABSTAIN.
    """
    objects = x.objects if isinstance(x.objects, (list, tuple)) else []
    if any(phrase in obj for obj in objects for phrase in recycling_phrases):
        return RECYCLING
    return ABSTAIN


@labeling_function()
def lf_compost_object(x):
    """Vote COMPOST if any detected object's name contains a compost phrase.

    ``x.objects`` holds the detector's class names for the image. A row whose
    image was never run through the detector may hold NaN there (pandas fills
    the missing key). Such a row is treated as having no detections, so the
    function abstains instead of raising ``TypeError``.

    Returns:
        COMPOST on a match, otherwise ABSTAIN.
    """
    objects = x.objects if isinstance(x.objects, (list, tuple)) else []
    if any(phrase in obj for obj in objects for phrase in compost_phrases):
        return COMPOST
    return ABSTAIN

from snorkel.labeling import PandasLFApplier

lfs = [
    lf_garbage_object,
    lf_recycling_object,
    lf_compost_object
]

# Label matrices: one row per image, one column per LF (a label or ABSTAIN).
applier = PandasLFApplier(lfs)
L_train = applier.apply(df_train)
L_valid = applier.apply(df_valid)

from snorkel.labeling import LFAnalysis

# Gold labels exist only if the image records carry a 'label' field. Nothing
# in this notebook creates one, so fall back to the label-free LF summary
# instead of failing with AttributeError.
Y_valid = df_valid["label"].values if "label" in df_valid.columns else None
LFAnalysis(L_valid, lfs).lf_summary(Y_valid)

from snorkel.labeling.model import LabelModel

# Combine the LF votes into one label per image over the three classes.
label_model = LabelModel(cardinality=3, verbose=True)
label_model.fit(L_train, seed=123, lr=0.01, log_freq=10, n_epochs=100)

print(L_valid)

# Scoring needs gold labels; skip it when none are available.
if Y_valid is not None:
    print(label_model.score(L_valid, Y_valid, metrics=["f1_micro"]))

## Training Classifier

#### Create DataLoaders for Classifier

from snorkel.classification import DictDataLoader
from model import SceneGraphDataset, create_model

# Use the label model's predictions as the classifier's training targets.
df_train["labels"] = label_model.predict(L_train)

# `sample` is never assigned elsewhere in this notebook, so the original
# `if sample:` raised NameError. Honour a value set by an earlier cell, and
# otherwise use the full training image directory.
# NOTE(review): the VRD paths and SceneGraphDataset look carried over from a
# visual-relation example; confirm they match this dataset's images/columns.
sample = globals().get("sample", False)
if sample:
    TRAIN_DIR = "data/VRD/sg_dataset/samples"
else:
    TRAIN_DIR = "data/VRD/sg_dataset/sg_train_images"

dl_train = DictDataLoader(
    SceneGraphDataset("train_dataset", "train", TRAIN_DIR, df_train),
    batch_size=16,
    shuffle=True,
)

dl_valid = DictDataLoader(
    SceneGraphDataset("valid_dataset", "valid", TRAIN_DIR, df_valid),
    batch_size=16,
    shuffle=False,
)

#### Define Model Architecture

import torchvision.models as models

# Pretrained ResNet-18 as the image feature extractor, wrapped by the
# project's create_model into the classifier.
cnn = models.resnet18(pretrained=True)
model = create_model(cnn)

### Train and Evaluate Model

from snorkel.classification import Trainer

# Training settings. n_epochs is kept small; increase it for improved
# performance.
trainer_config = {
    "n_epochs": 1,
    "lr": 1e-3,
    "checkpointing": True,
    "checkpointer_config": {"checkpoint_dir": "checkpoint"},
}
trainer = Trainer(**trainer_config)
trainer.fit(model, [dl_train])

model.score([dl_valid])