-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from shelhamer/fix
fix release of pytorch 1.0, semantic seg., and instance seg.
- Loading branch information
Showing
21 changed files
with
434 additions
and
223 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Few-Shot Data\n", | ||
"\n", | ||
"Few-shot datasets return tasks consisting of support (image and annotatations), and query (image and ground truth target)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import subprocess\n", | ||
"\n", | ||
"root_dir = subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).strip()\n", | ||
"os.chdir(root_dir)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"%matplotlib inline\n", | ||
"from PIL import Image, ImageDraw\n", | ||
"\n", | ||
"from torchvision.transforms import Compose\n", | ||
"\n", | ||
"\n", | ||
"from revolver.data.pascal import VOCSemSeg, VOCInstSeg, SBDDSemSeg, SBDDInstSeg\n", | ||
"from revolver.data.seg import MaskSemSeg, MaskInstSeg\n", | ||
"from revolver.data.filter import TargetFilter\n", | ||
"from revolver.data.sparse import SparseSeg\n", | ||
"from revolver.data.interactive import InteractiveSeg\n", | ||
"from revolver.data.conditional import ConditionalSemSeg" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Here are some helpers we'll need to visualize output of datasets." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def draw_circle(d, r, loc, color='white'):\n", | ||
" '''\n", | ||
" Draw circle of radius r at location loc\n", | ||
" on ImageDraw object d\n", | ||
" d = ImageDraw.Draw(im)\n", | ||
" '''\n", | ||
" y, x = loc[0], loc[1]\n", | ||
" d.ellipse((x-r, y-r, x+r, y+r), fill=tuple(color))\n", | ||
" \n", | ||
"def load_and_show(ds, shot):\n", | ||
" plt.rcParams.update({'font.size': 16})\n", | ||
" \n", | ||
" # get data\n", | ||
" in_ = ds[np.random.choice(range(len(ds)))]\n", | ||
" qry, supp, tgt, _ = in_[0], in_[1:-2], in_[-2], in_[-1]\n", | ||
" \n", | ||
" # plot support\n", | ||
" fig, axes = plt.subplots(1, shot+1, figsize=(30, 10))\n", | ||
" for i, s in enumerate(supp):\n", | ||
" # conditional, qry != supp\n", | ||
" if isinstance(s, tuple):\n", | ||
" im, anno = s[0], s[1]\n", | ||
" # interactive: qry == supp\n", | ||
" else:\n", | ||
" anno = s\n", | ||
" im = qry\n", | ||
" im = np.copy(qry)\n", | ||
" im = Image.fromarray(im.astype(np.uint8))\n", | ||
" d = ImageDraw.Draw(im)\n", | ||
" for loc in zip(*np.where(anno != 0)):\n", | ||
" draw_circle(d, 10, loc[1:], color=ds.palette[loc[0]])\n", | ||
" axes[i].imshow(im)\n", | ||
" axes[i].set_title('Support')\n", | ||
" \n", | ||
" for _, ax in np.ndenumerate(axes):\n", | ||
" ax.set_axis_off()\n", | ||
" \n", | ||
" # plot query image and target\n", | ||
" fig, axes = plt.subplots(1, 2, figsize=(30, 20))\n", | ||
" axes[0].imshow(qry)\n", | ||
" axes[0].set_title('Query')\n", | ||
" tgt = Image.fromarray(tgt.astype(np.uint8))\n", | ||
" tgt.putpalette(ds.palette)\n", | ||
" axes[1].imshow(tgt)\n", | ||
" axes[1].set_title('Target')\n", | ||
" \n", | ||
" for _, ax in np.ndenumerate(axes):\n", | ||
" ax.set_axis_off()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"When the support image and query image are the same, we recover interactive segmentation." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"sem_ds = VOCSemSeg(split='train')\n", | ||
"inst_ds = VOCInstSeg(split='train')\n", | ||
"mask_ds = MaskInstSeg(sem_ds, inst_ds)\n", | ||
"sparse_ds = SparseSeg(mask_ds, count=3)\n", | ||
"inter_ds = InteractiveSeg(mask_ds, sparse_ds)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"load_and_show(inter_ds, 1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"When the query is a new image, we have a few-shot learning task. \n", | ||
"Here the task is to segment the semantic cateogory indicated by the support annotations." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"shot = 2\n", | ||
"sem_ds = VOCSemSeg(split='train')\n", | ||
"mask_ds = MaskSemSeg(sem_ds)\n", | ||
"support_datasets = [TargetFilter(mask_ds, [c]) for c in range(1, len(sem_ds.classes))]\n", | ||
"sparse_datasets = [SparseSeg(ds, count=3) for ds in support_datasets]\n", | ||
"cond_ds = ConditionalSemSeg(mask_ds, sparse_datasets, shot=shot)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"scrolled": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"load_and_show(cond_ds, shot)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.