In [None]:
# Install the package in the parent directory (presumably `bisturi`, imported below).
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
%pip install -q ..

In [None]:
# Root directory that contains `datasets/broden1_224/`.
# Must be set to a real path before running the dataset-loading cell below.
basepath = None

In [None]:
import os
import random
from random import choices

import numpy as np
from IPython.display import display
from ipywidgets import widgets, HBox
from torchvision.transforms import Resize
from torchvision.transforms import ToPILImage

from bisturi.dataset.broden import BrodenDataset
from bisturi.dataset.broden import BrodenOntology

In [None]:
# Fail early with an actionable message: os.path.join(None, ...) would only raise
# an opaque TypeError.
if basepath is None:
    raise ValueError("Set `basepath` to the directory that contains 'datasets/broden1_224/' before running this cell.")
dset_path = os.path.join(basepath, 'datasets/broden1_224/')
ontology = BrodenOntology(dset_path)
# NOTE(review): mean=1 / std=1 on every channel — confirm how BrodenDataset applies
# these; if it computes (x - mean) / std, an identity transform would be mean=0.
dset = BrodenDataset(dset_path, mean=[1,1,1], std=[1,1,1], ontology=ontology)

## Aligned concepts

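We select the WordNet concepts that are directly aligned to Broden labels, i.e. those with at least one original Broden label id (`original_b_ids`). For comparison, we also count the concepts flagged as `propagated`.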

In [None]:
# All concepts of the ontology as a list; filtered into aligned / propagated below.
concepts = ontology.to_list()

In [None]:
# Directly aligned concepts: those with a truthy (non-empty) `original_b_ids`.
aligned_concepts = list(filter(lambda concept: concept.original_b_ids, concepts))
# Concepts whose `propagated` flag is truthy.
propagated = list(filter(lambda concept: concept.propagated, concepts))

In [None]:
# Label the counts so the displayed output is self-explanatory.
{'aligned': len(aligned_concepts), 'propagated': len(propagated), 'total': len(concepts)}

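## Analyze a WordNet concept

Pick one of the aligned concepts by index and inspect its WordNet metadata, the Broden labels it corresponds to, and a few random images with the concept's mask applied.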

In [None]:
# Position in `aligned_concepts` of the concept inspected in the next cell.
concept_idx = 4

In [None]:
def pick_random(seq, k, rng):
    """Return up to ``k`` elements of ``seq`` drawn at random without replacement.

    Unlike ``random.choices``, no position is drawn twice, and an empty ``seq``
    yields ``[]`` instead of raising ``IndexError``. Only ``len`` and integer
    indexing are used, so lists, tuples and 1-D arrays all work.
    """
    positions = rng.sample(range(len(seq)), min(k, len(seq)))
    return [seq[i] for i in positions]


c = aligned_concepts[concept_idx]

# How many hyponyms / images to show. The RNG is seeded so that re-running the
# notebook shows the same picks; change the seed to browse other samples.
n_show = 10
rng = random.Random(0)

print('IDX:', concept_idx, end='\n\n')
print('Name:', c.synset.name(), 'n' + str(c.id)[1:], end='\n\n')
print('Definition:', c.synset.definition(), sep='\n', end='\n\n')
print('Examples:', *c.synset.examples(), sep='\n', end='\n\n')

print('Hypernyms:', *c.synset.hypernyms(), sep='\n', end='\n\n')

hyponyms = c.synset.hyponyms()
if hyponyms:
    print('Hyponyms:', *pick_random(hyponyms, n_show, rng), sep='\n', end='\n\n')
else:
    print('No hyponyms', end='\n\n')

print('Corresponding Broden:')
for broden_id in c.original_b_ids:
    print(broden_id, dset.labels[broden_id]['name'], dset.labels[broden_id]['syns'], sep='\t')

print('\nSamples:')
sample_ids = dset.reverse_index[c.id]
print(len(sample_ids), 'images found.\n')
shown_ids = pick_random(sample_ids, n_show, rng)

to_pil = ToPILImage()  # built once; reused for every displayed image
for sample_id in shown_ids:
    _, img, masks = dset[sample_id]

    # Original image
    original = to_pil(img)

    # Image multiplied by the concept mask (pixels outside the mask become zero).
    c_mask = masks.get_concept_mask(c)
    if img.shape[1:] == c_mask.shape:
        masked = img * c_mask
    else:
        # Resolutions differ: resize the image to the mask's size, apply the mask,
        # then resize back so both displayed images have the same size.
        to_cmask_size = Resize(c_mask.shape)
        to_img_size = Resize(img.shape[1:])
        masked = to_img_size(to_cmask_size(img) * c_mask)
    masked = to_pil(masked)

    display(original)
    display(masked)