In [8]:
import os
import re
import glob
import analysis
import numpy as np
import glob
import h5py

import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
import sys

import matplotlib


import analysis
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import numpy as np
import torch
import glob
import pandas as pd

# Analysis Results

The analysis we report in our paper can be reproduced on several levels:

   - Training a new model. This produces a new log in `log/` with model checkpoints (`.pth`) and loss histories (`losshistory.csv`).
   - Generating model predictions on a test set from a model checkpoint. This writes an HDF5 (`.hdf5`) dataset to `results/` containing the labels and the model predictions.
   - Generating result tables and plots. This is done interactively in this notebook.


## Digit Datasets

Load the predictions and labels, and compute the accuracy for every (train, adapt, test) domain combination:

In [10]:
def load_results():
    with h5py.File("results/digits/digit_predictions.hdf5") as ds:

        for key in ds.keys():

            source, adapt, test = key.split("-")

            preds = ds[key]["prediction"][...].squeeze()
            lbls = ds[key]["labels"][...].squeeze()

            acc = (preds.argmax(axis=-1) == lbls).mean()

            yield [source, adapt, test, acc]
            
df = pd.DataFrame(load_results(), columns = ["source", "adapt", "test", "acc"])

Generate accuracy heatmaps (in %) and save them to `results/digit-plots.pdf`:

In [11]:
from matplotlib.backends.backend_pdf import PdfPages

# Fixed domain order used for both rows and columns of every heatmap.
order = ['mnist', 'usps', 'synth', 'svhn']

source_lbl = "Train"
adapt_test_lbl = "Adapt + Test"
test_lbl = "Test"
adapt_lbl = "Adapt"

# Applied to each figure right after it has been written to the PDF (here: close it).
# Kept at module level because the noise-adaptation cell further down reuses it.
action = plt.close


def _save_accuracy_heatmap(pdf, acc, figsize, xlabel, ylabel, title=None, **heatmap_kwargs):
    """Draw `acc` (fractions in [0, 1]) as a percentage heatmap and save it as one PDF page.

    Extra keyword arguments (e.g. ``vmin``, ``vmax``, ``cbar``) are forwarded to
    ``sns.heatmap``. After saving, ``action`` is applied to the current figure.
    """
    sns.set_context('poster')
    plt.figure(figsize=figsize)
    sns.heatmap(data=100 * acc, cmap='Blues', annot=True, fmt='.1f', square=True,
                linewidths=1, **heatmap_kwargs)
    if title is not None:
        plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    pdf.savefig(bbox_inches="tight")
    action()


with PdfPages("results/digit-plots.pdf") as pdf:

    # Overview page: adaptation domain == test domain; rows are the training domain.
    # pivot() takes keyword-only arguments since pandas 2.0.
    overview = (df[df.adapt == df.test]
                .pivot(index="source", columns="adapt", values="acc")[order]
                .reindex(order))
    _save_accuracy_heatmap(pdf, overview, figsize=(5, 5),
                           xlabel=adapt_test_lbl, ylabel=source_lbl,
                           cbar=False, vmin=65, vmax=100)

    # One page per training domain: rows = adaptation domain, columns = test domain.
    for source in order:
        per_source = (df[df.source == source]
                      .pivot(index="adapt", columns="test", values="acc")[order]
                      .reindex(order))
        _save_accuracy_heatmap(pdf, per_source, figsize=(7, 7),
                               xlabel=test_lbl, ylabel=adapt_lbl,
                               title="Train: " + source, vmin=30, vmax=100)

## Multi-Task Learning for Noise Adaptation

We evaluate models after 90 epochs of training and adaptation on the SVHN dataset, with varying degrees and types of noise, namely:

- Clean data vs. White Noise ("white")
- Clean data vs. Salt and Pepper Noise ("snp")
- White Noise vs. Salt and Pepper Noise ("mixed")

In [12]:
def plot_full_analysis(ACC, imgs, noise_vars, P, title, sym):
    r"""Yield the figures of the full noise-adaptation analysis, one at a time.

    The first figure is the overview returned by ``analysis.plot_overview``, with one
    panel title ``$<sym>=<level>$`` per noise level. It is followed by one figure per
    entry of ``P``, zipped with the names Mean, Variance, $\gamma$, $\beta$, Scale and
    Shift. Because of the zip, at most six entries are plotted. Each of these figures
    has three panels:

    * two panels drawn by ``analysis.transfer_plot`` from the angle computed for
      that parameter;
    * a regression plot (``analysis.plot_reg``) of the values returned by
      ``analysis.get_corr(ACC.T, angle, np.triu)``.

    Parameters
    ----------
    ACC : array
        Accuracy matrix, passed through to the ``analysis`` helpers
        (assumed square with one row/column per noise level -- TODO confirm).
    imgs : array
        Example images, one per noise level, for the overview figure.
    noise_vars : sequence of float
        Noise level of each domain; used for panel titles and transfer plots.
    P : sequence of arrays
        Adaptable parameters. Each is squeezed before ``analysis.compute_angle``.
    title : str
        Title set on the overview axes.
    sym : str
        LaTeX symbol of the noise parameter, e.g. ``r"\sigma^2"`` or ``"p"``.

    Yields
    ------
    matplotlib.figure.Figure
        Each figure in turn; the caller is responsible for saving/closing it.
    """
    titles = ['${}={}$'.format(sym, level) for level in noise_vars]

    fig, ax = analysis.plot_overview(ACC, imgs, noise_vars, titles)
    ax.set_title(title)
    yield fig

    # Raw strings: the LaTeX backslashes must not be parsed as (invalid) escape sequences.
    names = ['Mean', 'Variance', r'$\gamma$', r'$\beta$', 'Scale', 'Shift']

    sns.set_context('poster', font_scale=.9)
    for m, name in zip(P, names):
        m = m.squeeze()

        fig, (ax_angle, ax_acc, ax_tril) = plt.subplots(1, 3, figsize=(15, 5))
        a = analysis.compute_angle(m)
        analysis.transfer_plot(a, ACC, ax_angle, ax_acc, noise_vars)

        # Text annotations at hard-coded data coordinates of the angle panel.
        ax_angle.text(s=name, x=-3, y=-.75, color='gray')
        ax_angle.text(s=r'$\sigma_{train} > \sigma_{test}$', x=-.2, y=6.5, color='white')
        ax_angle.text(s=r'$\sigma_{test} > \sigma_{train}$', x=3, y=1.25, color='white')

        # Relation between accuracy and parameter angle, using the np.triu part of ACC.T.
        x, y = analysis.get_corr(ACC.T, a, np.triu)
        analysis.plot_reg(x, y, ax=ax_tril)
        ax_tril.set_title(r'Correlation ($\sigma_{test} > \sigma_{train}$)')

        # Trim the spines to the inner tick range (second to second-to-last tick).
        xticks = ax_tril.get_xticks()
        yticks = ax_tril.get_yticks()
        ax_tril.spines['bottom'].set_bounds(xticks[1], xticks[-2])
        ax_tril.spines['left'].set_bounds(yticks[1], yticks[-2])

        plt.tight_layout()
        yield fig
        
        
def plot_generalization_curves(ACC, noise_vars):
    r"""Plot accuracy against noise level in three side-by-side panels (shared y-axis).

    1. ``np.diag(ACC)`` against ``noise_vars`` (adapt level equals test level).
    2. For every column ``i``: ``ACC[:, i]`` against ``noise_vars``, plus a dashed
       vertical marker at ``noise_vars[i]``.
    3. For every row ``i``: ``ACC[i, :]`` against ``noise_vars``, plus a dashed
       vertical marker at ``noise_vars[i]``.

    Parameters
    ----------
    ACC : numpy.ndarray
        Square 2-D accuracy matrix with one row/column per noise level. Which axis is
        "adapt" and which is "test" follows the panel labels below -- TODO confirm
        against ``analysis.load_file``.
    noise_vars : sequence of float
        Noise level of each row/column of ``ACC``; the colour normalisation uses its
        first and last entries (assumed ascending).

    Returns
    -------
    matplotlib.figure.Figure
        The created figure. It also remains the current pyplot figure, so callers that
        call ``pdf.savefig()`` without a figure argument keep working.
    """
    # NOTE(review): vmin is the *negated* first noise level. For a clean first domain
    # (level 0) this is simply 0; confirm the sign is intended otherwise.
    norm = matplotlib.colors.Normalize(vmin=-noise_vars[0], vmax=noise_vars[-1])
    cmap = sns.cubehelix_palette(9, as_cmap=True)

    def rgba(level):
        """Map a noise level to an RGBA colour of the cubehelix palette."""
        return cmap(norm(level))

    sns.set_context("paper", font_scale=2.5)

    fig, axes = plt.subplots(1, 3, figsize=(23, 3.5), sharey=True)
    ax_diag, ax_adapt, ax_test = axes

    ax_diag.plot(noise_vars, np.diag(ACC))
    ax_diag.set_xlabel(r"Test Noise")
    # NOTE(review): the dashed markers below span y in [0, 1]; if ACC holds fractions
    # rather than percentages, this "[%]" label is misleading.
    ax_diag.set_ylabel("Accuracy [%]")
    ax_diag.set_title("Baseline [Adapt $=$ Test]")

    # Iterate in reverse: later-drawn lines sit on top, so index 0 ends up uppermost.
    for i in reversed(range(len(ACC))):
        # NOTE(review): the solid curve is coloured by .3 - level while its marker uses
        # the level itself; confirm the asymmetric colouring is intended.
        ax_adapt.plot(noise_vars, ACC[:, i], color=rgba(.3 - noise_vars[i]), linewidth=2)
        ax_adapt.plot([noise_vars[i]] * 2, [0, 1], color=rgba(noise_vars[i]), linestyle="--")
    ax_adapt.set_xlabel(r"Adapt Noise Level")
    ax_adapt.set_title(r"Adapt $\neq$ Test")

    for i in range(len(ACC)):
        ax_test.plot(noise_vars, ACC[i, :], color=rgba(noise_vars[i]), linewidth=2)
        ax_test.plot([noise_vars[i]] * 2, [0, 1], color=rgba(noise_vars[i]), linestyle="--")
    ax_test.set_xlabel(r"Test Noise Level")
    ax_test.set_title(r"Adapt $\neq$ Test")

    for ax in axes:
        sns.despine(ax=ax)

    return fig

In [13]:
# (only for reference)

#! cp ./log/multitask/white/clean/20180820-200745_MultidomainBCESolver/20180820-200745-checkpoint-ep90.pth results/multitask/20180820-200745-checkpoint-ep90-white-clean.pth
#! cp ./log/multitask/white/noise/20180820-200902_MultidomainBCESolver/20180820-200902-checkpoint-ep90.pth results/multitask/20180820-200902-checkpoint-ep90-white-noise.pth

#! cp ./log/multitask/snp/noise/20180905-170020_MultidomainBCESolver/20180905-170020-checkpoint-ep90.pth results/multitask/20180905-170020-checkpoint-ep90-snp-noise.pth
#! cp ./log/multitask/snp/clean/20180905-165820_MultidomainBCESolver/20180905-165820-checkpoint-ep90.pth results/multitask/20180905-165820-checkpoint-ep90-snp-clean.pth

In [14]:
# 2176 x 4 adaptable parameters

## Plot Noise Adaptation Results

In [17]:
titles = {
    'noise' : 'Adaptation from High to Low',
    'clean' : 'Adaptation from Low to High'
}

from matplotlib.backends.backend_pdf import PdfPages

import datasets_white, datasets_snp

# Both training sources use the same clean->noise domain sequence; the "noise"
# source is handled by reversing the domain order (see `revert` below).
noisemodels = {
    "white" : {"clean" : datasets_white.clean2noise(),
               "noise" : datasets_white.clean2noise()
              },
    "snp" : {"clean" : datasets_snp.clean2noise(),
             "noise" : datasets_snp.clean2noise()
            }
}

tmpl_fname = 'results/noise/*-checkpoint-ep90-{noise}-{source}.{fmt}'


def _single_match(pattern):
    """Return the only file matching `pattern`; raise ValueError on zero or several matches."""
    matches = sorted(glob.glob(pattern))
    if len(matches) != 1:
        raise ValueError("expected exactly one file matching {!r}, found {}: {}"
                         .format(pattern, len(matches), matches))
    return matches[0]


with PdfPages("results/noise-plots.pdf") as pdf:
    for noise, sources in noisemodels.items():
        for source, noisemodel in sources.items():

            assert noise in ["white", "snp"]
            assert source in ["clean", "noise"]

            model_fname = _single_match(tmpl_fname.format(source=source, noise=noise, fmt="pth"))
            eval_fname = _single_match(tmpl_fname.format(source=source, noise=noise, fmt="hdf5"))

            n_domains = len(noisemodel)

            # For the "noise" source, the accuracies (via load_file's `revert`) and the
            # parameters (reversed along their first axis) are flipped.
            revert = (source == 'noise')
            ACC = analysis.load_file(eval_fname, n_domains=n_domains, revert=revert)
            # np.stack requires a sequence: passing a generator was deprecated in
            # NumPy 1.16 and raises TypeError in recent releases.
            imgs = np.stack([N(.5 + torch.zeros(28, 28)).numpy() for N in noisemodel])
            # Noise level per domain, read from the transform's `sigma` or `prob`
            # attribute (0 when the attribute is absent).
            noise_vars = [vars(g).get('sigma', 0) + vars(g).get('prob', 0) for g in noisemodel]

            P = analysis.load_params([model_fname])
            if revert:
                P = [p[::-1].copy() for p in P]
            M, B = analysis.compute_linear(P)
            P = list(P) + [M, B]

            # Raw string so "\s" is not treated as an (invalid) escape sequence.
            sym = r"\sigma^2" if noise == "white" else "p"
            for fig in plot_full_analysis(ACC, imgs, noise_vars, P, titles[source], sym=sym):
                pdf.savefig(fig, bbox_inches='tight')
                action()

            plot_generalization_curves(ACC, noise_vars)
            pdf.savefig(bbox_inches='tight')
            action()

0 4


  return np.add.reduce(sorted[indexer] * weights, axis=axis) / sumval


0 4
0 4
0 4
