# Imports

In [None]:
import os, sys
import random
from collections import Counter
from os.path import dirname, abspath, join

from IPython.display import display
from PIL import Image
from rich.progress import track

In [None]:
import numpy as np
import torch
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage, ToTensor, Pad

In [None]:
# The notebook runs from a folder one level below the repository root; expose that
# root on sys.path so the `learnable_typewriter` imports that follow can be resolved.
PARENT = abspath('..')
sys.path.append(PARENT)
from learnable_typewriter.utils.loading import load_pretrained_model
from learnable_typewriter.evaluate.qualitative.decompositor import Decompositor
from learnable_typewriter.typewriter.inference import inference 
from learnable_typewriter.evaluate.quantitative.sprite_matching.metrics import error_rate 
from learnable_typewriter.data.dataloader import collate_fn_pad_to_max

In [None]:
# CPU-only inference: an empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA.
# Set through os.environ instead of `%env VAR=""`: the magic stores the two quote
# characters literally (and may keep a trailing `# ...` comment) as the value.
# This must run before torch first initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
RUNS = abspath(join(PARENT, "runs/"))  # directory holding the pretrained runs
PLOT_DIR = 'plots/'  # where plot() saves figures; set to None to only display them

# Helper Functions

In [None]:
def plot(x, tag, ext='.png'):
    """Display one or several images inline and, when PLOT_DIR is set, save them to disk.

    Args:
        x: a single image or a list of images; each must support ``.save(path)``
            (PIL images in this notebook).
        tag: output path stem relative to PLOT_DIR (e.g. ``'paper/fig3la'``); must be a
            list of the same length when ``x`` is a list.
        ext: file extension appended to every tag.

    Raises:
        ValueError: if ``x`` is a list and ``tag`` is not a list of the same length.
    """
    if isinstance(x, list):
        # zip() below would silently drop unmatched images or tags, so insist on 1:1 pairing.
        if not isinstance(tag, list) or len(tag) != len(x):
            raise ValueError(f'plot: expected one tag per image, got {len(x)} image(s) and tag={tag!r}')
    else:
        x, tag = [x], [tag]

    for im, t in zip(x, tag):
        display(im)
        if PLOT_DIR is not None:
            path = join(PLOT_DIR, t + ext)
            folder = dirname(path)
            # os.makedirs('') raises, so only create a folder when the path actually has one.
            if folder:
                os.makedirs(folder, exist_ok=True)
            im.save(path)

In [None]:
def usage_order(trainer):
    """Return sprite indices sorted from most to least used on the training split.

    Usage counts come from running ``inference`` over the training loader once; they are
    cached on ``trainer.stats`` and reused by later calls.

    Args:
        trainer: loaded trainer exposing ``model.sprites``, ``get_dataloader``,
            ``n_workers`` and ``unsupervised``.

    Returns:
        list[int]: ``range(len(trainer.model.sprites))`` sorted by descending usage count.
        Sprites never seen count as 0, and ties keep ascending index order (stable sort).
    """
    if not hasattr(trainer, 'stats'):
        stats = Counter()
        loader = trainer.get_dataloader(split='train', batch_size=1, num_workers=trainer.n_workers, shuffle=True, remove_crop=True)[0]
        for sample in track(loader, description='collecting stats'):
            predictions = inference(trainer.model, sample, aggregate=not trainer.unsupervised)
            # Tally in place instead of rebuilding a Counter with `+=` for every sample.
            stats.update(e for p in predictions for e in p)
        trainer.stats = stats

    stats = trainer.stats
    return sorted(range(len(trainer.model.sprites)), key=lambda k: stats.get(k, 0), reverse=True)

def get_order(trainer, masks, which, order_type=None, ignore_sprites=None):
    """Choose which sprites to show and in which order.

    Args:
        trainer: loaded trainer; ``len(trainer.model.sprites)`` is the sprite count.
        masks: unused; kept for call compatibility.
        which: unused (dataset name); kept for call compatibility.
        order_type: ``None`` keeps index order. Any other value orders sprites by
            descending usage (``usage_order``); ``'reorder'`` additionally re-sorts the
            kept sprites back into index order.
        ignore_sprites: ``'top-K'`` keeps only the first K sprites of the chosen order.
            ``None``, or any other value, keeps all of them.

    Returns:
        list[int]: sprite indices to plot.
    """
    if order_type is None:
        order = list(range(len(trainer.model.sprites)))
    else:
        order = usage_order(trainer)

    # isinstance guard: callers may pass a non-string "no filter" value such as [],
    # which has no .startswith.
    prefix = 'top-'
    if isinstance(ignore_sprites, str) and ignore_sprites.startswith(prefix):
        order = order[:int(ignore_sprites[len(prefix):])]

    if order_type == 'reorder':
        order = sorted(order)

    return order

def plot_sprites(trainer, which, n_row=4, ignore_sprites=None, n_z=0, order_type=None):
    """Render the selected sprite masks as one grid image.

    Masks are inverted (``1 - mask``), so mask == 1 becomes black and mask == 0 white.

    Args:
        trainer: loaded trainer; masks are read from ``trainer.model.sprites.masks``.
        which: dataset name, forwarded to ``get_order``.
        n_row: tiles per grid row (``make_grid``'s ``nrow``).
        ignore_sprites: ``'top-K'`` keeps only the first K sprites of the chosen order.
        n_z: number of blank (all-white) tiles appended after the sprites.
        order_type: sprite ordering, see ``get_order``.

    Returns:
        PIL image of the sprite grid (2-pixel white padding between tiles).
    """
    inverted = 1 - trainer.model.sprites.masks
    order = get_order(trainer, inverted, which, ignore_sprites=ignore_sprites, order_type=order_type)
    # An inverted tile of all ones renders as pure white, i.e. an empty cell.
    blank = [torch.ones_like(inverted[0].unsqueeze(0))] * n_z
    tiles = torch.cat([inverted[i].unsqueeze(0) for i in order] + blank, dim=0)
    return ToPILImage()(make_grid(tiles, nrow=n_row, padding_value=1, padding=2))

def plot_colored_sprites(trainer, n_row=4, ignore_sprites=None, order_type=None, which='copiale', n_z=0):
    """Render the sprite masks as a grid, each sprite tinted with its decompositor colour.

    Args:
        trainer: loaded trainer; reads ``trainer.model.masks`` and ``trainer.decompositor.colors``.
        n_row: tiles per grid row.
        ignore_sprites: ``'top-K'`` keeps only the first K sprites of the chosen order;
            ``None`` (the default; previously ``[]``, which crashed in ``get_order``) keeps all.
        order_type: sprite ordering, see ``get_order``.
        which: dataset name forwarded to ``get_order``.
        n_z: number of blank (all-white) tiles appended after the sprites.

    Returns:
        PIL image of the coloured sprite grid.
    """
    masks = trainer.model.masks
    # assumes colors holds one RGB row per mask, values in [0, 1] — TODO confirm.
    colors = torch.Tensor(trainer.decompositor.colors).unsqueeze(-1).unsqueeze(-1)  # (K, 3, 1, 1)
    inv_colors = 1 - colors.expand(masks.size()[0], 3, *masks.size()[2:])  # (K, 3, H_sprite, W_sprite)
    # Pixels with mask == 1 take the sprite colour; pixels with mask == 0 stay white.
    tinted = 1 - inv_colors * masks.expand(-1, 3, -1, -1)
    order = get_order(trainer, tinted, which, ignore_sprites=ignore_sprites, order_type=order_type)

    blank = [torch.ones_like(tinted[0].unsqueeze(0))] * n_z
    tiles = torch.cat([tinted[i].unsqueeze(0) for i in order] + blank, dim=0)
    return ToPILImage()(make_grid(tiles, nrow=n_row, padding_value=1, padding=2))

def seg(trainer, test_idx):
    """Decompose one test image with the trainer's decompositor.

    Args:
        trainer: loaded trainer exposing ``test_loader`` and ``decompositor``.
        test_idx: index into the dataset of the first test loader.

    Returns:
        tuple of PIL images: (ground truth, reconstruction, segmentation).
    """
    sample = trainer.test_loader[0].dataset[test_idx]
    batch = collate_fn_pad_to_max([sample], supervised=True)
    output = trainer.decompositor(batch)
    to_pil = ToPILImage()
    # The batch was built from a single sample; keep element 0 of each tensor.
    tensors = (batch['x'], output['reconstruction'], output['segmentation'])
    return tuple(to_pil(t.cpu()[0]) for t in tensors)

def resize_w(img, w):
    """Return an RGB copy of ``img`` scaled to width ``w`` with its aspect ratio kept.

    The new height is truncated (not rounded) to an int; resampling uses Lanczos.
    """
    rgb = img.convert('RGB')
    old_w, old_h = rgb.size
    scale = w / float(old_w)
    new_h = int(float(old_h) * float(scale))
    return rgb.resize((w, new_h), Image.Resampling.LANCZOS)

def add_pad_h(img, pad):
    """Append ``pad`` rows filled with 255 (white) below ``img``; other sides are untouched."""
    # Pad takes [left, top, right, bottom].
    bottom_padder = Pad(padding=[0, 0, 0, pad], fill=255, padding_mode='constant')
    return bottom_padder(img)

def crop(x, w=None):
    """Keep columns ``w[0]`` to ``w[1]`` of image ``x`` at full height; return ``x`` as-is if ``w`` is None."""
    if w is not None:
        _, full_height = x.size
        return x.crop((w[0], 0, w[1], full_height))
    return x

def assort(imgs):
    """Stack images top to bottom into one PIL image (they must share width and channel count)."""
    as_tensor, as_pil = ToTensor(), ToPILImage()
    # ToTensor yields (C, H, W), so dim=1 concatenates along the height.
    stacked = torch.cat(list(map(as_tensor, imgs)), dim=1)
    return as_pil(stacked)

In [None]:
def paper(trainer, which, test_idx=0, place=('gt', 'sgm', 'sprites'), pad_h=(3, 3, 0), n_row=4, ignore_sprites=None, crop_w=None, n_z=0, order_type=None):
    """Build the separate panels of a paper figure for one test image.

    Args:
        trainer: loaded trainer.
        which: dataset name, forwarded to ``plot_sprites``.
        test_idx: index of the test image to decompose.
        place: panels to return, in order; any of ``'gt'``, ``'rec'``, ``'sgm'``, ``'sprites'``.
        pad_h: unused here (panels are returned separately, not stacked); kept for call
            compatibility.
        n_row, ignore_sprites, n_z, order_type: forwarded to ``plot_sprites``.
        crop_w: optional ``(left, right)`` column range applied to the gt/rec/sgm panels.

    Returns:
        list of PIL images, one per entry of ``place``. The sprite grid is resized to the
        width of the (cropped) ground-truth panel.
    """
    sprites = plot_sprites(trainer, which=which, n_row=n_row, ignore_sprites=ignore_sprites, order_type=order_type, n_z=n_z)
    gt, rec, sgm = seg(trainer, test_idx)
    figs = {'gt': crop(gt, crop_w), 'rec': crop(rec, crop_w), 'sgm': crop(sgm, crop_w)}
    # Match the sprite grid to the image width so the panels line up in the figure.
    figs['sprites'] = resize_w(sprites, figs['gt'].size[0])
    return [figs[p] for p in place]

def teaser(trainer, test_idx=0, place=('gt', 'rec', 'sgm'), pad_h=(3, 3, 0), n_row=4):
    """Stack the ground truth, reconstruction and segmentation of one test image vertically.

    Args:
        trainer: loaded trainer.
        test_idx: index of the test image to decompose.
        place: panels to stack, top to bottom; any of ``'gt'``, ``'rec'``, ``'sgm'``.
        pad_h: white rows added below each panel, one entry per element of ``place``.
        n_row: unused; kept for call compatibility.

    Returns:
        PIL image with the padded panels stacked top to bottom.

    Raises:
        ValueError: if ``place`` and ``pad_h`` differ in length.
    """
    if len(place) != len(pad_h):
        # zip() would otherwise silently drop the unmatched panels.
        raise ValueError(f'teaser: place {place!r} and pad_h {pad_h!r} must have the same length')
    gt, rec, sgm = seg(trainer, test_idx)
    figs = {'gt': gt, 'rec': rec, 'sgm': sgm}
    return assort([add_pad_h(figs[p], h) for p, h in zip(place, pad_h)])

# Google

In [None]:
# Load the supervised and unsupervised Google models (supervised first).
trainer = {
    mode: load_pretrained_model(path=join(RUNS, f'google/{mode}/'), device=None)
    for mode in ('supervised', 'unsupervised')
}

# Permute the unsupervised model's sprite colours with a fixed seed so the figures are reproducible.
# Shuffle an index list rather than `colors` itself: random.shuffle swaps items in place, which
# silently duplicates rows if `colors` is a multi-dimensional numpy array (its rows are views).
# Shuffling range(len(colors)) consumes exactly the same random draws, so a plain list ends up
# in the same order random.shuffle(colors) would have produced.
random.seed(24)
google_colors = trainer['unsupervised'].decompositor.colors
color_perm = list(range(len(google_colors)))
random.shuffle(color_perm)
google_colors[:] = [google_colors[i] for i in color_perm]

### Paper

In [None]:
# Figure 3 panels (tags fig3la/fig3lb/fig3ra) for the supervised Google model on test image 5:
# the 60 most-used sprites, displayed in index order ('reorder').
plot(
    paper(
        trainer['supervised'],
        which='google',
        test_idx=5,
        n_row=10,
        ignore_sprites='top-60',
        order_type='reorder',
        crop_w=[0, 690],
    ),
    tag=['paper/fig3la', 'paper/fig3lb', 'paper/fig3ra'],
)

In [None]:
# Figure 3 panels for the unsupervised Google model on the same test image: all sprites, by usage.
# NOTE(review): 'paper/fig3la' is also written by the supervised cell; presumably it is the same
# ground-truth crop, so this save just overwrites it — confirm.
plot(
    paper(
        trainer['unsupervised'],
        which='google',
        test_idx=5,
        n_row=10,
        ignore_sprites=None,
        order_type='usage',
        crop_w=[0, 690],
    ),
    tag=['paper/fig3la', 'paper/fig3lc', 'paper/fig3rb'],
)

### Supplementary material

In [None]:
# Coloured sprite grid of the supervised Google model, in index order, plus one blank tile.
plot(
    plot_colored_sprites(
        trainer['supervised'],
        ignore_sprites=None,
        n_row=21,
        which='google',
        n_z=1,
    ),
    tag='supmat/fig1/supervised/sprites',
)

In [None]:
# Coloured sprite grid of the unsupervised Google model. get_order treats every non-None
# order_type other than 'reorder' as usage order, so 'eow' here means "sorted by usage".
plot(
    plot_colored_sprites(
        trainer['unsupervised'],
        ignore_sprites=None,
        n_row=20,
        which='google',
        order_type='eow',
    ),
    tag='supmat/fig1/unsupervised/sprites',
)

In [None]:
# Decompose randomly chosen test images (fixed seed) with both Google models.
# Both trainers are indexed with the same ids, which assumes they share one test set — TODO confirm.
N_SAMPLES = 20
np.random.seed(42)
N = len(trainer['supervised'].test_loader[0].dataset)
# choice(..., replace=False) raises if asked for more items than exist, so cap at N;
# when N >= N_SAMPLES the draw is exactly the original one.
for idx, i in enumerate(np.random.choice(N, size=min(N_SAMPLES, N), replace=False)):
    for tag in ['supervised', 'unsupervised']:
        plot(teaser(trainer[tag], i), tag=f'supmat/fig1/{tag}/{idx}')

# Copiale

In [None]:
# Load the supervised and unsupervised Copiale models (supervised first); this rebinds
# `trainer`, replacing the Google models.
trainer = {
    mode: load_pretrained_model(path=join(RUNS, f'copiale/{mode}/'), device=None)
    for mode in ('supervised', 'unsupervised')
}

### Paper

In [None]:
# Figure 4 panels (tags fig4la/fig4lb/fig4ra) for the supervised Copiale model on test image 6:
# the 9 * 12 = 108 most-used sprites, in usage order.
# NOTE(review): leftover tuning notes of unclear meaning: "[433, 1275]" and "54".
plot(
    paper(
        trainer['supervised'],
        crop_w=[15, 585],
        which='copiale',
        ignore_sprites=f'top-{9 * 12}',
        test_idx=6,
        n_row=18,
        order_type='usage',
    ),
    tag=['paper/fig4la', 'paper/fig4lb', 'paper/fig4ra'],
)

In [None]:
# Figure 4 panels for the unsupervised Copiale model: same image, crop and top-108 usage ordering.
plot(
    paper(
        trainer['unsupervised'],
        crop_w=[15, 585],
        which='copiale',
        test_idx=6,
        n_row=18,
        ignore_sprites=f'top-{9 * 12}',
        order_type='usage',
    ),
    tag=['paper/fig4la', 'paper/fig4lc', 'paper/fig4rb'],
)

### Supplementary material

In [None]:
# Coloured sprite grid of the supervised Copiale model, sorted by usage.
plot(
    plot_colored_sprites(
        trainer['supervised'],
        n_row=28,
        ignore_sprites=None,
        which='copiale',
        order_type='usage',
    ),
    tag='supmat/fig2/supervised/sprites',
)

In [None]:
# Coloured sprite grid of the unsupervised Copiale model, sorted by usage.
plot(
    plot_colored_sprites(
        trainer['unsupervised'],
        n_row=30,
        ignore_sprites=None,
        which='copiale',
        order_type='usage',
    ),
    tag='supmat/fig2/unsupervised/sprites',
)

In [None]:
# Decompose randomly chosen test images (fixed seed) with both Copiale models.
# Both trainers are indexed with the same ids, which assumes they share one test set — TODO confirm.
N_SAMPLES = 20
np.random.seed(42)
N = len(trainer['supervised'].test_loader[0].dataset)
# choice(..., replace=False) raises if asked for more items than exist, so cap at N;
# when N >= N_SAMPLES the draw is exactly the original one.
for idx, i in enumerate(np.random.choice(N, size=min(N_SAMPLES, N), replace=False)):
    for tag in ['supervised', 'unsupervised']:
        plot(teaser(trainer[tag], i), tag=f'supmat/fig2/{tag}/{idx}')