Merge pull request #4 from shelhamer/fix
fix release of pytorch 1.0, semantic seg., and instance seg.
shelhamer committed Apr 2, 2019
2 parents 386cb43 + 2e68d39 commit 74fc12a
Showing 21 changed files with 434 additions and 223 deletions.
70 changes: 40 additions & 30 deletions evaluate.py
@@ -5,56 +5,63 @@
 
 import click
 import numpy as np
+from PIL import Image
 
 import torch
-from torch.autograd import Variable
+import torch.nn as nn
 
 from revolver.data import datasets, datatypes, prepare_loader
 from revolver.model import models, prepare_model
-from revolver.model.loss import CrossEntropyLoss2D
 from revolver.metrics import SegScorer
 
 
-def evaluate(model, weights, dataset, datatype, split, count, shot, multi, seed, gpu, output):
+def evaluate(model, weights, dataset, datatype, split, count, shot, seed, gpu, hist_path, seg_path):
     print("evaluating {} with weights {} on {} {}-{}".format(model, weights, datatype, dataset, split))
     os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
+    device = torch.device('cuda:0')
 
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
 
     prepare_data = datatypes[datatype]
-    dataset = prepare_data(dataset, split, count=count, shot=shot, multi=multi)
+    dataset = prepare_data(dataset, split, count=count, shot=shot)
     loader = prepare_loader(dataset, evaluation=True)
 
-    model = prepare_model(model, dataset.num_classes, weights=weights).cuda()
+    model = prepare_model(model, dataset.num_classes, weights=weights).to(device)
     model.eval()
 
-    loss_fn = CrossEntropyLoss2D(size_average=True, ignore_index=dataset.ignore_index)
+    loss_fn = nn.CrossEntropyLoss(reduction='mean', ignore_index=dataset.ignore_index)
 
     total_loss = 0.
     metrics = SegScorer(len(dataset.classes)) # n.b. this is the full no. of classes, not the no. of model outputs
-    for i, data in enumerate(loader):
-        inputs, target, aux = data[:-2], data[-2], data[-1]
-        inputs = [Variable(inp, volatile=True).cuda() if not isinstance(inp, list) else
-                  [[Variable(i_, volatile=True).cuda() for i_ in in_] for in_ in inp] for inp in inputs]
-        target = Variable(target, volatile=True).cuda(async=True)
-
-        scores = model(*inputs)
-        loss = loss_fn(scores, target)
-        total_loss += loss.data[0]
-
-        # segmentation evaluation
-        _, seg = scores.data[0].max(0)
-        metrics.update(seg.cpu().numpy(), target.data.cpu().numpy(), aux)
+    with torch.no_grad():
+        for i, data in enumerate(loader):
+            inputs, target, aux = data[:-2], data[-2], data[-1]
+            inputs = [inp.to(device) if not isinstance(inp, list) else
+                      [[i_.to(device) for i_ in in_] for in_ in inp] for inp in inputs]
+            target = target.to(device)
+
+            scores = model(*inputs)
+            loss = loss_fn(scores, target)
+            total_loss += loss.item()
+
+            # segmentation evaluation
+            _, seg = scores.data[0].max(0)
+            metrics.update(seg.to('cpu').numpy(), target.to('cpu').numpy(), aux)
+            # optionally save segs
+            if seg_path is not None:
+                seg = Image.fromarray(seg.to('cpu').numpy().astype(np.uint8), mode='P')
+                save_id = f"{aux['slug']}_{aux.get('cls', 'all')}_{aux.get('inst', 'all')}"
+                seg.save(f"{seg_path}/{save_id}.png")
 
     print("loss {}".format(total_loss / len(dataset)))
     for metric, score in metrics.score().items():
         score = np.nanmean(score)
         print("{:10s} {:.3f}".format(metric, score))
 
-    if output:
-        metrics.save(output)
+    if hist_path is not None:
+        metrics.save(hist_path)
 
 
 @click.command()
@@ -66,10 +73,10 @@ def evaluate(model, weights, dataset, datatype, split, count, shot, multi, seed,
 @click.option('--split', type=str, default='valid')
 @click.option('--count', type=int, default=None)
 @click.option('--shot', type=int, default=1)
-@click.option('--multi', is_flag=True, default=False)
+@click.option('--save_seg', is_flag=True, default=False)
 @click.option('--seed', default=1337)
 @click.option('--gpu', default=0)
-def main(experiment, model, weights, dataset, datatype, split, count, shot, multi, seed, gpu):
+def main(experiment, model, weights, dataset, datatype, split, count, shot, save_seg, seed, gpu):
     setproctitle.setproctitle("eval-{}".format(experiment))
     args = locals()
     print("args: {}".format(args))
@@ -88,19 +95,22 @@ def main(experiment, model, weights, dataset, datatype, split, count, shot, mult
 
     # template the output path
     count_ = 'dense' if count == -1 else "{}sparse".format(count) if count else 'randsparse'
-    multi_ = '-multi' if multi else ''
-    output_fmt = '-{}-{}-{}-{}-{}shot-{}{}'.format(dataset, datatype, split, count_, shot, seed, multi_)
-    output_fmt = exp_dir + 'hist-' + model + '-iter{}' + output_fmt
+    output_fmt = '-{}-{}-{}-{}-{}shot-{}'.format(dataset, datatype, split, count_, shot, seed)
+    output_fmt = model + '-iter{}' + output_fmt
 
     for weights in evaluations:
         # make output path
         iter_ = re.search('iter(\d+).pth', weights).group(1)
-        output = output_fmt.format(iter_)
+        hist_path = exp_dir + 'hist-' + output_fmt.format(iter_)
+        seg_path = None
+        if save_seg:
+            seg_path = exp_dir + output_fmt.format(iter_)
+            os.makedirs(seg_path, exist_ok=True)
         # skip existing
-        if os.path.isfile(output + '.npz'):
-            print("skipping existing {}".format(output))
+        if os.path.isfile(hist_path + '.npz'):
+            print("skipping existing {}".format(hist_path))
             continue
-        evaluate(model, weights, dataset, datatype, split, count, shot, multi, seed, gpu, output)
+        evaluate(model, weights, dataset, datatype, split, count, shot, seed, gpu, hist_path, seg_path)
 
 if __name__ == '__main__':
     main()
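
Editor's note: the evaluate.py changes above are the standard PyTorch 0.3-to-1.0 migration: Variable(..., volatile=True) becomes a torch.no_grad() context, .cuda() and .cuda(async=True) become .to(device), loss.data[0] becomes loss.item(), and size_average=True becomes reduction='mean'. Below is a minimal runnable sketch of the same pattern; the toy model and one-batch loader are placeholders, not the repo's objects.

# Minimal sketch of the PyTorch 0.3 -> 1.0 evaluation-loop migration applied
# in the diff above. The model and data here are toy placeholders.
import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = nn.Conv2d(3, 2, 1).to(device)              # was: model.cuda()
model.eval()
loss_fn = nn.CrossEntropyLoss(reduction='mean')    # was: size_average=True

# one fake batch: image and per-pixel target
loader = [(torch.randn(1, 3, 8, 8), torch.zeros(1, 8, 8, dtype=torch.long))]

total_loss = 0.
with torch.no_grad():                              # was: Variable(..., volatile=True)
    for image, target in loader:
        image = image.to(device)                   # was: Variable(image, volatile=True).cuda()
        target = target.to(device)                 # was: .cuda(async=True); async is reserved in Python 3.7
        scores = model(image)                      # N x C x H x W logits
        loss = loss_fn(scores, target)
        total_loss += loss.item()                  # was: loss.data[0]
print(total_loss)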
194 changes: 194 additions & 0 deletions notebooks/data-fewshot.ipynb
@@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Few-Shot Data\n",
"\n",
"Few-shot datasets return tasks consisting of support (image and annotatations), and query (image and ground truth target)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import subprocess\n",
"\n",
"root_dir = subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).strip()\n",
"os.chdir(root_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"from PIL import Image, ImageDraw\n",
"\n",
"from torchvision.transforms import Compose\n",
"\n",
"\n",
"from revolver.data.pascal import VOCSemSeg, VOCInstSeg, SBDDSemSeg, SBDDInstSeg\n",
"from revolver.data.seg import MaskSemSeg, MaskInstSeg\n",
"from revolver.data.filter import TargetFilter\n",
"from revolver.data.sparse import SparseSeg\n",
"from revolver.data.interactive import InteractiveSeg\n",
"from revolver.data.conditional import ConditionalSemSeg"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are some helpers we'll need to visualize output of datasets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def draw_circle(d, r, loc, color='white'):\n",
" '''\n",
" Draw circle of radius r at location loc\n",
" on ImageDraw object d\n",
" d = ImageDraw.Draw(im)\n",
" '''\n",
" y, x = loc[0], loc[1]\n",
" d.ellipse((x-r, y-r, x+r, y+r), fill=tuple(color))\n",
" \n",
"def load_and_show(ds, shot):\n",
" plt.rcParams.update({'font.size': 16})\n",
" \n",
" # get data\n",
" in_ = ds[np.random.choice(range(len(ds)))]\n",
" qry, supp, tgt, _ = in_[0], in_[1:-2], in_[-2], in_[-1]\n",
" \n",
" # plot support\n",
" fig, axes = plt.subplots(1, shot+1, figsize=(30, 10))\n",
" for i, s in enumerate(supp):\n",
" # conditional, qry != supp\n",
" if isinstance(s, tuple):\n",
" im, anno = s[0], s[1]\n",
" # interactive: qry == supp\n",
" else:\n",
" anno = s\n",
" im = qry\n",
" im = np.copy(qry)\n",
" im = Image.fromarray(im.astype(np.uint8))\n",
" d = ImageDraw.Draw(im)\n",
" for loc in zip(*np.where(anno != 0)):\n",
" draw_circle(d, 10, loc[1:], color=ds.palette[loc[0]])\n",
" axes[i].imshow(im)\n",
" axes[i].set_title('Support')\n",
" \n",
" for _, ax in np.ndenumerate(axes):\n",
" ax.set_axis_off()\n",
" \n",
" # plot query image and target\n",
" fig, axes = plt.subplots(1, 2, figsize=(30, 20))\n",
" axes[0].imshow(qry)\n",
" axes[0].set_title('Query')\n",
" tgt = Image.fromarray(tgt.astype(np.uint8))\n",
" tgt.putpalette(ds.palette)\n",
" axes[1].imshow(tgt)\n",
" axes[1].set_title('Target')\n",
" \n",
" for _, ax in np.ndenumerate(axes):\n",
" ax.set_axis_off()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the support image and query image are the same, we recover interactive segmentation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sem_ds = VOCSemSeg(split='train')\n",
"inst_ds = VOCInstSeg(split='train')\n",
"mask_ds = MaskInstSeg(sem_ds, inst_ds)\n",
"sparse_ds = SparseSeg(mask_ds, count=3)\n",
"inter_ds = InteractiveSeg(mask_ds, sparse_ds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"load_and_show(inter_ds, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the query is a new image, we have a few-shot learning task. \n",
"Here the task is to segment the semantic cateogory indicated by the support annotations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"shot = 2\n",
"sem_ds = VOCSemSeg(split='train')\n",
"mask_ds = MaskSemSeg(sem_ds)\n",
"support_datasets = [TargetFilter(mask_ds, [c]) for c in range(1, len(sem_ds.classes))]\n",
"sparse_datasets = [SparseSeg(ds, count=3) for ds in support_datasets]\n",
"cond_ds = ConditionalSemSeg(mask_ds, sparse_datasets, shot=shot)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"load_and_show(cond_ds, shot)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
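
Editor's note: for quick inspection outside the plotting helper, a task can be drawn from the composed dataset and unpacked the same way load_and_show does. The sketch below repeats the construction from the notebook's last cells; the printed array shapes are assumptions about the data layout, not documented guarantees.

# Sketch: draw one few-shot task and unpack it per the notebook's convention
# (query, support shots, target, aux). Construction follows the last cells.
import numpy as np

from revolver.data.pascal import VOCSemSeg
from revolver.data.seg import MaskSemSeg
from revolver.data.filter import TargetFilter
from revolver.data.sparse import SparseSeg
from revolver.data.conditional import ConditionalSemSeg

shot = 2
sem_ds = VOCSemSeg(split='train')
mask_ds = MaskSemSeg(sem_ds)
support_datasets = [TargetFilter(mask_ds, [c]) for c in range(1, len(sem_ds.classes))]
sparse_datasets = [SparseSeg(ds, count=3) for ds in support_datasets]
cond_ds = ConditionalSemSeg(mask_ds, sparse_datasets, shot=shot)

task = cond_ds[np.random.choice(range(len(cond_ds)))]
qry, supp, tgt, aux = task[0], task[1:-2], task[-2], task[-1]
print('support shots:', len(supp))        # shot (image, annotation) pairs
print('query:', np.asarray(qry).shape)    # assumed H x W x 3 image
print('target:', np.asarray(tgt).shape)   # assumed H x W label mask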