# Few-Shot Data

Few-shot datasets return tasks consisting of support (image and annotations) and query (image and ground-truth target).

In [None]:
import os
import subprocess

# Switch the kernel's working directory to the top of the enclosing git
# checkout (presumably so the `revolver` package imported in the next cell and
# any relative data paths resolve no matter where the notebook was launched).
# The command's stdout (repo root + trailing newline) is kept as bytes, which
# os.chdir accepts; a non-zero git exit status raises CalledProcessError.
root_dir = subprocess.run(
    ['git', 'rev-parse', '--show-toplevel'],
    stdout=subprocess.PIPE,
    check=True,
).stdout.strip()
os.chdir(root_dir)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image, ImageDraw

from torchvision.transforms import Compose


from revolver.data.pascal import VOCSemSeg, VOCInstSeg, SBDDSemSeg, SBDDInstSeg
from revolver.data.seg import MaskSemSeg, MaskInstSeg
from revolver.data.filter import TargetFilter
from revolver.data.sparse import SparseSeg
from revolver.data.interactive import InteractiveSeg
from revolver.data.conditional import ConditionalSemSeg

Here are some helpers we'll need to visualize the output of the datasets.

In [None]:
def draw_circle(d, r, loc, color='white'):
    '''
    Draw a filled circle of radius r centred at loc on ImageDraw object d.

    d     : drawing context, e.g. d = ImageDraw.Draw(im)
    r     : radius in pixels
    loc   : (row, col), i.e. (y, x), pixel coordinates of the centre;
            only the first two entries are used
    color : either a colour string understood by PIL (e.g. 'white',
            '#ff0000') or a sequence of channel values such as an RGB
            triple or a row of a numpy palette array
    '''
    y, x = loc[0], loc[1]
    # tuple() on a string would split it into characters ('w', 'h', ...),
    # which PIL cannot interpret, so strings are passed through unchanged.
    # Sequences become a tuple of plain ints so numpy scalars (e.g. from a
    # uint8 palette row) are handed to PIL as ordinary integers.
    if isinstance(color, str):
        fill = color
    else:
        fill = tuple(int(c) for c in color)
    d.ellipse((x - r, y - r, x + r, y + r), fill=fill)
    
def _render_support(im, anno, palette, radius=10):
    '''Return a PIL copy of `im` with every nonzero entry of `anno` drawn as a palette-coloured dot.'''
    # astype() always returns a fresh array, so the caller's image is never drawn on.
    canvas = Image.fromarray(np.asarray(im).astype(np.uint8))
    d = ImageDraw.Draw(canvas)
    # assumes anno is (channels, H, W): the first index picks the palette
    # colour, the remaining two give the (y, x) dot position — TODO confirm
    for loc in zip(*np.where(anno != 0)):
        draw_circle(d, radius, loc[1:], color=palette[loc[0]])
    return canvas


def load_and_show(ds, shot, idx=None):
    '''
    Plot one task from a few-shot dataset.

    Figure 1 has one panel per support. Each panel shows the support image
    with its sparse annotations drawn as palette-coloured dots. Figure 2
    shows the query image and its palettised target.

    ds   : dataset whose items are laid out as
           (query, support_1, ..., support_k, target, extra). A support is
           either an (image, annotation) tuple (conditional task) or a bare
           annotation on the query image (interactive task). ds.palette is
           indexed per annotation channel for dot colours and also passed
           to Image.putpalette for the target.
    shot : number of support panels to reserve. More are added if the task
           holds more supports than `shot`.
    idx  : optional item index. When None, a random item is drawn with
           np.random (seed it beforehand for reproducibility).

    Note: updates the global matplotlib rcParams (font size).
    '''
    plt.rcParams.update({'font.size': 16})

    # get data
    if idx is None:
        idx = np.random.choice(len(ds))
    item = ds[idx]
    qry, supp, tgt = item[0], item[1:-2], item[-2]

    # plot support; squeeze=False keeps `axes` an array even for one panel
    n_panels = max(shot, len(supp), 1)
    fig, axes = plt.subplots(1, n_panels, figsize=(30, 10), squeeze=False)
    axes = axes.ravel()
    for ax, s in zip(axes, supp):
        if isinstance(s, tuple):
            # conditional: the support carries its own image (qry != supp)
            im, anno = s[0], s[1]
        else:
            # interactive: the annotation refers to the query image (qry == supp)
            im, anno = qry, s
        ax.imshow(_render_support(im, anno, ds.palette))
        ax.set_title('Support')
    for ax in axes:
        ax.set_axis_off()

    # plot query image and target
    fig, axes = plt.subplots(1, 2, figsize=(30, 20))
    axes[0].imshow(qry)
    axes[0].set_title('Query')
    tgt_im = Image.fromarray(np.asarray(tgt).astype(np.uint8))
    tgt_im.putpalette(ds.palette)
    axes[1].imshow(tgt_im)
    axes[1].set_title('Target')
    for ax in axes.flat:
        ax.set_axis_off()

When the support image and query image are the same, we recover interactive segmentation.

In [None]:
# VOC training split, loaded as both semantic and instance segmentation.
sem_ds = VOCSemSeg(split='train')
inst_ds = VOCInstSeg(split='train')
# Combine semantic + instance labels into per-instance masks
# (presumably one mask per object — TODO confirm in revolver.data.seg).
mask_ds = MaskInstSeg(sem_ds, inst_ds)
# Sparsify each mask; count=3 is presumably the number of annotated
# points per mask — confirm against revolver.data.sparse.SparseSeg.
sparse_ds = SparseSeg(mask_ds, count=3)
# Interactive setting: the sparse support annotates the query image itself.
inter_ds = InteractiveSeg(mask_ds, sparse_ds)

In [None]:
# Interactive task: a single support annotation drawn on the query image itself.
load_and_show(inter_ds, shot=1)

When the query is a new image, we have a few-shot learning task.
Here the task is to segment the semantic category indicated by the support annotations.

In [None]:
# Number of support examples per task (also passed to load_and_show below).
shot = 2
sem_ds = VOCSemSeg(split='train')
# Semantic masks (presumably one mask per class present — TODO confirm in revolver.data.seg).
mask_ds = MaskSemSeg(sem_ds)
# One support pool per semantic class. Index 0 is skipped
# (presumably background — confirm against sem_ds.classes).
support_datasets = [TargetFilter(mask_ds, [c]) for c in range(1, len(sem_ds.classes))]
# Sparsify each class's supports with the same count=3 used in the interactive example.
sparse_datasets = [SparseSeg(ds, count=3) for ds in support_datasets]
# Conditional task: segment the class indicated by `shot` sparse supports in a new query image.
cond_ds = ConditionalSemSeg(mask_ds, sparse_datasets, shot=shot)

In [None]:
# Few-shot conditional task: `shot` annotated support images plus a separate query.
load_and_show(cond_ds, shot=shot)