In [1]:
%load_ext autoreload
%autoreload 2

On Colab, you will need to clone and install [probabll/dists.pt](https://github.com/probabll/dists.pt.git)

In [2]:
import torch
torch.__version__

'1.8.1+cu102'

In [3]:
import numpy as np
import torch
import torch.distributions as td
import probabll.distributions as pd
import matplotlib.pyplot as plt
import torch.nn as nn

In [4]:
from collections import namedtuple, OrderedDict, defaultdict
from tqdm.auto import tqdm
from itertools import chain
from tabulate import tabulate

In [5]:
import sys
sys.path.append("../")

In [6]:
from components import GenerativeModel, InferenceModel, VAE
from data import load_mnist
from hparams import load_cfg, make_args
from main import make_state, get_batcher, validate

In [7]:
from analysis import probe_prior, compare_marginals, compare_samples

In [8]:
import pathlib

In [9]:
import random
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
rng = np.random.RandomState(0)

In [10]:
import pickle
knn_model = pickle.load(open('knnclassifier.pickle', 'rb'))

# Helper code

In [11]:
from analysis import collect_samples

# Load model and data

* Load hyperparameters
* Load model state
* Load MNIST data

In [12]:
train_loader, valid_loader, test_loader = load_mnist(
    batch_size=100, 
    save_to='../tmp', 
    height=28, 
    width=28
)

In [13]:
num_samples_test = 1000

In [14]:
ls ../neurips-mixed-rv/iclr/

[0m[01;34mcategorical[0m/  [01;34mgaussian[0m/           [01;34mmixed-maxent[0m/
[01;34mdirichlet[0m/    [01;34mgaussiansp-maxent[0m/  [01;34monehotcat[0m/


In [15]:
valid_results = defaultdict(list)
test_results = defaultdict(list)

In [16]:
dirs = [('mixed-maxent', d, True) for d in pathlib.Path('../neurips-mixed-rv/iclr/mixed-maxent/').iterdir() if d.is_dir()]
# dirs = [('gaussian', d, False) for d in pathlib.Path('../neurips-mixed-rv/iclr/gaussian/').iterdir() if d.is_dir()]
# dirs += [('dirichlet', d, False) for d in pathlib.Path('../neurips-mixed-rv/iclr/dirichlet/').iterdir() if d.is_dir()]
# dirs += [('mixed-maxent', d, False) for d in pathlib.Path('../neurips-mixed-rv/iclr/mixed-maxent/').iterdir() if d.is_dir()]
# dirs += [('onehotcat', d, False) for d in pathlib.Path('../neurips-mixed-rv/iclr/onehotcat').iterdir() if d.is_dir()]
# dirs += [('categorical', d, True) for d in pathlib.Path('../neurips-mixed-rv/iclr/categorical').iterdir() if d.is_dir()]

In [17]:
dirs

[('mixed-maxent',
  PosixPath('../neurips-mixed-rv/iclr/mixed-maxent/robust-wildflower-4'),
  True),
 ('mixed-maxent',
  PosixPath('../neurips-mixed-rv/iclr/mixed-maxent/smooth-armadillo-2'),
  True),
 ('mixed-maxent',
  PosixPath('../neurips-mixed-rv/iclr/mixed-maxent/leafy-sky-3'),
  True),
 ('mixed-maxent',
  PosixPath('../neurips-mixed-rv/iclr/mixed-maxent/electric-firefly-5'),
  True),
 ('mixed-maxent',
  PosixPath('../neurips-mixed-rv/iclr/mixed-maxent/lucky-grass-1'),
  True)]

In [None]:
for cls, directory, redo in tqdm(dirs):
    if not redo:
        continue
    args = make_args(
        load_cfg(
            f"{directory}/cfg.json", 
            # use this to specify a decide for analysis
            device='cuda:0',
            # use this to change paths if you need
            data_dir='../tmp',
            # you don't really need to change the output_dir
        )
    )
    experiment = directory.name
    print(f"Experiment: {cls}/{experiment}")

    state = make_state(
        args, 
        device=args.device, 
        ckpt_path=f"{directory}/ckpt.last"
    )
        
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    rng = np.random.RandomState(0) 
        
    print('Validating...')
    val_metrics = validate(
        state.vae, get_batcher(valid_loader, args), 
        num_samples=num_samples_test, 
        compute_DR=True,
        progressbar=True,
    )            
    
    print()
                
    r = [
        val_metrics[0].numpy(),  # NLL
        val_metrics[1].numpy(),  # BPD
        val_metrics[2]['ELBO'].mean(),  # ELBO
        val_metrics[2]['D'].mean(),  # D
        val_metrics[2]['R'].mean(),  # R
        val_metrics[2].get('R_F', np.zeros(1)).mean(),  # R
        val_metrics[2].get('R_Y|f', np.zeros(1)).mean(),  # R
        val_metrics[2].get('R_Y', np.zeros(1)).mean(),  # R
        val_metrics[2].get('R_Z', np.zeros(1)).mean(),  # R
    ]
    valid_results[cls].append(r)
    
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    rng = np.random.RandomState(0) 
    
    print('Testing...')
    test_metrics = validate(
        state.vae, get_batcher(test_loader, args), 
        num_samples=num_samples_test, 
        compute_DR=True,
        progressbar=True,
    )            
    
    print()
                
    r = [
        test_metrics[0].numpy(),  # NLL
        test_metrics[1].numpy(),  # BPD
        test_metrics[2]['ELBO'].mean(),  # ELBO
        test_metrics[2]['D'].mean(),  # D
        test_metrics[2]['R'].mean(),  # R
        test_metrics[2].get('R_F', np.zeros(1)).mean(),  # R
        test_metrics[2].get('R_Y|f', np.zeros(1)).mean(),  # R
        test_metrics[2].get('R_Y', np.zeros(1)).mean(),  # R
        test_metrics[2].get('R_Z', np.zeros(1)).mean(),  # R
    ]
    test_results[cls].append(r)
    

  0%|          | 0/5 [00:00<?, ?it/s]

Overriding device to user choice cuda:0
Overriding data_dir to user choice ../tmp
Experiment: mixed-maxent/robust-wildflower-4
Validating...


  0%|          | 0/50 [00:00<?, ?it/s]


Testing...


  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
from tabulate import tabulate

In [None]:
headers = ['NLL', 'BPD', 'ELBO', 'D', 'R', 'R_F', 'R_Y|f', 'R_Y', 'R_Z']
# print("Validation")
# print(tabulate(valid_results['gaussian'], headers=headers, floatfmt='.2f'))
# print(tabulate(valid_results['dirichlet'], headers=headers, floatfmt='.2f'))
# print(tabulate(valid_results['mixed-maxent'], headers=headers, floatfmt='.2f'))

In [None]:
# print('mixed-dir')
# print(tabulate(
#     [
#         ['mean'] + [x for x in np.mean(valid_results['mixed-maxent'], 0)],
#         ['std'] + [x for x in np.std(valid_results['mixed-maxent'], 0)],
#         ['min'] + [x for x in np.min(valid_results['mixed-maxent'], 0)],
#         ['max'] + [x for x in np.max(valid_results['mixed-maxent'], 0)]
#     ], 
#     headers=headers, floatfmt='.2f'))

In [None]:
idx = np.array([3, 4, 0], dtype=int)

print(tabulate(
    [
        ['gaussian', 'valid'] + [x for x in np.array(valid_results['gaussian'])[:,idx].mean(0)],        
        ['dirichlet', 'valid'] + [x for x in np.array(valid_results['dirichlet'])[:,idx].mean(0)],
        ['mixed-dir', 'valid'] + [x for x in np.array(valid_results['mixed-maxent'])[:,idx].mean(0)],
        ['onehotcat', 'valid'] + [x for x in np.array(valid_results['onehotcat'])[:,idx].mean(0)],
        ['categorical', 'valid'] + [x for x in np.array(valid_results['categorical'])[:,idx].mean(0)],

        ['gaussian', 'test'] + [x for x in np.array(test_results['gaussian'])[:,idx].mean(0)],
        ['dirichlet', 'test'] + [x for x in np.array(test_results['dirichlet'])[:,idx].mean(0)],
        ['mixed-dir', 'test'] + [x for x in np.array(test_results['mixed-maxent'])[:,idx].mean(0)],
        ['onehotcat', 'test'] + [x for x in np.array(test_results['onehotcat'])[:,idx].mean(0)],
        ['categorical', 'test'] + [x for x in np.array(test_results['categorical'])[:,idx].mean(0)],
    ], 
    headers=['Model', 'Dataset', 'D', 'R', 'NLL'], floatfmt='.2f'))

```
Model      Dataset        D      R     NLL
---------  ---------  -----  -----  ------
gaussian   valid      77.05  19.93   91.68
dirichlet  valid      79.15  20.13   94.48
mixed-dir  valid      90.97  19.16  107.12
gaussian   test       76.67  19.94   91.12
dirichlet  test       78.62  19.94   93.81
mixed-dir  test       90.34  19.39  106.59
```

In [None]:
import pickle
knn_model = pickle.load(open('knnclassifier.pickle', 'rb'))

In [None]:
x_gen_prior = np.concatenate(prior['x'])
x_gen_prior.shape

In [None]:
knn_pred = knn_model.predict(x_gen_prior)

In [None]:
clustered = defaultdict(list)
for x_, y_ in zip(x_gen_prior, knn_pred):
    clustered[y_].append(x_)
clustered = {cls: np.stack(digits) for cls, digits in clustered.items()}

In [None]:
p_emp = np.array([len(clustered.get(c, []))/x_gen_prior.shape[0] for c in range(10)])

In [None]:
_ = plt.plot(np.arange(10) + 1, p_emp, 'o')

In [None]:
# KL(p_emp||U) and KL(U||p_emp)
kl1 = (p_emp * (np.log(p_emp) - np.log(0.1))).sum(0)
kl2 = (0.1 * (np.log(0.1) - np.log(p_emp))).sum(0)
print(tabulate([[kl1, kl2, (kl1+kl2)/2]], headers=['KL from uniform', 'KL from empirical', 'JS']))

In [None]:
fig, axs = plt.subplots(
    2, 5, 
    sharex=True, sharey=True,
    gridspec_kw={'hspace': 0, 'wspace': 0})
for c in range(10):
    axs[c // 5, c % 5].imshow(clustered.get(c, np.zeros((1, 28 * 28))).mean(0).reshape(args.height, args.width), cmap='Greys')
    axs[c // 5, c % 5].set_xlabel(f"({c}) {p_emp[c] * 100:.2f}") #set_title(f"X'|X={c}")
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
# Hide x labels and tick labels for top plots and y ticks for right plots.
#for ax in axs.flat:
#    ax.label_outer()
_ = fig.suptitle(r'Samples from prior clustered by KNN classifier with F1=95%')

In [None]:
fig, axs = plt.subplots(
    10, 10, 
    figsize=(15, 10),
    sharex=True, sharey=True,
    gridspec_kw={'hspace': 0, 'wspace': 0},
)
for c in range(10):
    axs[c, 0].imshow(clustered.get(c, np.zeros((1, 28 * 28))).mean(0).reshape(args.height, args.width), cmap='Greys')    
    for i, k in enumerate(np.random.choice(len(clustered.get(c, [])), size=9)):
        axs[c, i + 1].imshow(clustered[c][k].reshape(args.height, args.width), cmap='Greys')

for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
_ = fig.suptitle(r'Samples from prior clustered by KNN classifier with F1=95%')

In [None]:
K = density_estimation(
    state.vae, 
    get_batcher(test_loader, args),  # we MC estimate KL(q(y)||p(y)) by sampling x~dev, y|x ~ infnet
    get_batcher(valid_loader, args),  # we estimate log q(y) via mean(log q(y|x) for x in training)
    batch_size_y=5000, 
    progressbar=True
)

In [None]:
_ = plt.hist(K, density=True, alpha=0.5, bins=100, label='hopeful-puddle-2')
_ = plt.legend()

In [None]:
test_metrics = validate(state.vae, get_batcher(test_loader, args), num_samples, compute_DR=True)
print(f'Saved ckpt - Test: nll={test_metrics[0]:.2f} bpd={test_metrics[1]:.2f}')

In [None]:
rows = [('IS-NLL', test_metrics[0], None), ('IS-BPD', test_metrics[1], None)]
for k, v in test_metrics[2].items():
    rows.append((k, v.mean(), v.std()))
print(tabulate(rows, headers=['metric', 'mean', 'std']))    

# Training Curves

In [None]:
np_stats_tr = {k: np.array(v) for k, v in state.stats_tr.items()}
np_stats_val = {k: np.array(v) for k, v in state.stats_val.items()}

In [None]:
def smooth(v, kernel_size=100):
    if kernel_size is None:
        return v
    return np.convolve(v, np.ones(kernel_size)/kernel_size, 'valid')

In [None]:
for k, v in np_stats_tr.items():
    v = smooth(v)
    plt.plot(np.arange(1, v.size + 1), v, '.')    
    plt.ylabel(f'Training {k}')
    plt.xlabel('iteration')
    plt.show()

# Validation Curves

In [None]:
print(tabulate(
    [(k, np.mean(v[-100:]), np.min(v[-100:])) for k, v in np_stats_val.items()],
    headers=['metric', 'mean', 'min']
))

In [None]:
for k, v in np_stats_val.items():
    plt.plot(np.arange(1, v.size + 1), v, 'o')
    plt.ylabel(f'Validation {k}')
    plt.xlabel('epoch')
    plt.show()
for k, v in np_stats_val.items():
    v = v[v.size//2:]
    plt.plot(np.arange(1 + v.size//2, v.size + 1 + v.size//2), v, 'o')
    plt.ylabel(f'Validation {k}')
    plt.xlabel('epoch')
    plt.show()    

In [None]:
val_nll, val_bpd, val_DR = validate(
    state.vae, get_batcher(valid_loader, args), num_samples, compute_DR=True, progressbar=True)

In [None]:
rows = [('IS-NLL', val_nll, None), ('IS-BPD', val_bpd, None)]
for k, v in val_DR.items():
    rows.append((k, v.mean(), v.std()))
print(tabulate(rows, headers=['metric', 'mean', 'std']))    

In [None]:
_ = plt.hist(val_DR['D'], bins='auto')
_ = plt.xlabel('D')
plt.show()

_ = plt.hist(val_DR['R'], bins='auto')
_ = plt.xlabel('R')
plt.show()

if state.vae.p.z_dim:
    _ = plt.hist(val_DR['R_Z'], bins='auto')
    _ = plt.xlabel('R(Z)')
    plt.show()
    
if state.vae.p.y_dim:
    _ = plt.hist(val_DR['R_F'], bins='auto')
    _ = plt.xlabel('R(F)')
    plt.show()

    _ = plt.hist(val_DR['R_Y|f'], bins='auto')
    _ = plt.xlabel('R(Y|F)')
    plt.show()

# Analysis

In [None]:
import torch_two_sample as t2s
from analysis import collect_samples

In [None]:
prior, posterior = collect_samples(state.vae, get_batcher(valid_loader, args), args, num_samples=num_samples)

In [None]:
for rv, dim in [('f', state.p.y_dim), ('y', state.p.y_dim), ('z', state.p.z_dim)]:
    if dim == 0:
        continue
    p_f = np.stack(prior[rv]).reshape(-1, dim)
    q_f = np.stack(posterior[rv]).reshape(-1, dim)
    mmd = t2s.statistics_diff.MMDStatistic(p_f.shape[0], q_f.shape[0])
    v = np.array([mmd(torch.tensor(p_f), torch.tensor(q_f), [alpha]).cpu().numpy() for alpha in np.random.gamma(10., 1./10, size=20)])

    _ = plt.hist(v, color='blue', alpha=0.3, label='')
    _ = plt.axvline(x=v.mean(), c='blue')
    #_ = plt.hist(v2, color='red', alpha=0.3, label='volcanic-firefly-4')
    #_ = plt.axvline(x=v2.mean(), c='red')
    _ = plt.xlabel(f"MMD {rv}")
    plt.show()

In [None]:
#v2 = np.array([mmd(torch.tensor(p_f), torch.tensor(q_f), [alpha]).cpu().numpy() for alpha in np.random.gamma(10., 1./10, size=100)])

## KL

In [None]:
compare_marginals(state.vae, get_batcher(valid_loader, args), args, cols=5, num_samples=num_samples)

## Posterior and Prior Samples

In [None]:
compare_samples(state.vae, get_batcher(valid_loader, args), args, N=5, num_figs=2, num_samples=1000)

## TSNE

In [None]:
from analysis import samples_per_digit

In [None]:
f, y, z, x, marginal_f, scores, concs = samples_per_digit(
    state.vae, get_batcher(valid_loader, args, onehot=False), 
    args, return_marginal=args.y_dim > 0)

In [None]:
f.shape, f.mean(1).shape, z.shape, z.mean(1).shape, x.shape, x.mean(1).shape

In [None]:
marginal_f.shape, marginal_f.mean(1).shape, scores.shape, scores.mean(1).shape, concs.shape, concs.mean(1).shape

In [None]:
if state.vae.p.y_dim:
    _ = plt.imshow(f.mean(1))
    _ = plt.ylabel('Class')
    _ = plt.xlabel(r'$k$')
    _ = plt.title(r'$E[F_k = 1|X]$')
    _ = plt.colorbar()
    plt.show()
    
#     _ = plt.imshow(f.sum(-1).mean(1, keepdims=True))
#     _ = plt.ylabel('Class')
#     #_ = plt.xlabel(r'$\max_k$')
#     _ = plt.xticks([], [])
#     _ = plt.title(r'mean argmax')
#     _ = plt.colorbar()
#     plt.show()
    
    for k in range(10):
        _ = plt.hist(f[k].sum(-1), label=f'{k}')
    _ = plt.legend()
    plt.show()
    
    _ = plt.imshow(scores.mean(1))
    _ = plt.ylabel('Class')
    _ = plt.xlabel(r'$k$')
    _ = plt.title(r'$E[\omega_k|X]$')
    _ = plt.colorbar()
    plt.show()
    
    _ = plt.imshow(concs.mean(1))
    _ = plt.ylabel('Class')
    _ = plt.xlabel(r'$k$')
    _ = plt.title(r'$E[\alpha_k|X]$')
    _ = plt.colorbar()
    plt.show()

In [None]:
from itertools import product

if state.vae.p.y_dim:
    mc_marginal_f = f.mean(1)
    Fs = [td.Independent(td.Bernoulli(probs=(torch.tensor(mc_marginal_f[c])*0.99 + 1e-4)), 1) for c in range(10)]
    JS_F = np.array([[(0.5*td.kl_divergence(Fs[c], Fs[c_])+0.5*td.kl_divergence(Fs[c_], Fs[c])).numpy() for c_ in range(10)] for c in range(10)])
    #KL_F = np.array([[td.kl_divergence(Fs[c], Fs[c_]).numpy() for c_ in range(10)] for c in range(10)])
    _ = plt.imshow(JS_F)
    _ = plt.ylabel('Class')
    _ = plt.yticks(np.arange(10), np.arange(10))
    _ = plt.xlabel('Class')
    _ = plt.xticks(np.arange(10), np.arange(10))
    _ = plt.title(r'KL')
    _ = plt.colorbar()
    plt.show()
    
    mc_marginal_f = marginal_f.mean(1)
    Fs = [td.Independent(td.Bernoulli(probs=(torch.tensor(mc_marginal_f[c])*0.99 + 1e-4)), 1) for c in range(10)]
    JS_F = np.array([[(0.5*td.kl_divergence(Fs[c], Fs[c_])+0.5*td.kl_divergence(Fs[c_], Fs[c])).numpy() for c_ in range(10)] for c in range(10)])
    #KL_F = np.array([[td.kl_divergence(Fs[c], Fs[c_]).numpy() for c_ in range(10)] for c in range(10)])
    _ = plt.imshow(JS_F)
    _ = plt.ylabel('Class')
    _ = plt.yticks(np.arange(10), np.arange(10))
    _ = plt.xlabel('Class')
    _ = plt.xticks(np.arange(10), np.arange(10))
    _ = plt.title(r'KL')
    _ = plt.colorbar()

In [None]:
from sklearn.manifold import TSNE

In [None]:
def tsne_plot(samples, title, legend=True, filename=None):
    """
    :param samples: [10, N,D]
    """
    assert samples.shape[0] == 10, "I need 10 digits"
    D = samples.shape[-1]
    assert D > 0, "0-dimensional features?"
    tsne_results = TSNE(n_components=2, random_state=1).fit_transform(samples.reshape(-1, D)).reshape(10, -1, 2)
    plt.figure(figsize=(6, 5))
    colors = 'r', 'g', 'b', 'c', 'm', 'y', 'k', 'gray', 'orange', 'purple'
    for i, c in zip(np.arange(10), colors):
        plt.scatter(tsne_results[i, :, 0], tsne_results[i, :, 1], c=c, label=i)
    plt.xticks([], [])    
    plt.yticks([], [])
    if legend:
        #plt.legend(bbox_to_anchor=(1, 0.85), loc='upper left', framealpha=0.5)
        plt.legend(loc='upper right', framealpha=1.0)
    #plt.title(title)
    
    if filename:
        plt.savefig(f'{filename}.pdf', bbox_inches='tight') 
    plt.show()

    return tsne_results

In [None]:
if state.vae.p.y_dim:
    _ = tsne_plot(f, r"$f \sim Q_{F|X=x_{obs}}$", legend=False, filename='tsne_f')

In [None]:
if state.vae.p.y_dim:
    _ = tsne_plot(marginal_f, r"$ \Pr(e_k \in F |X_{obs})$", legend=False, filename='tsne_mu')    

In [None]:
if state.vae.p.y_dim:
    _ = tsne_plot(scores, r"$ w_k \phi_k(f) |X=x_{obs}$", legend=False, filename='tsne_scores')    

In [None]:
if state.vae.p.y_dim:
    #_ = tsne_plot(f, r"$F|X_{obs}$")
    #_ = tsne_plot(marginal_f, r"$\Pr(k|X_{obs})$")
    _ = tsne_plot(y, r"$y \sim Q_{Y|X=x_{obs}}$", legend=True, filename='tsne_y')

In [None]:
if state.vae.p.z_dim:
    _ = tsne_plot(z, r"$Z|X_{obs}$")

In [None]:
if state.vae.p.z_dim:
    _ = tsne_plot((z > 0) * 1.0, r"$Z > 0 | X_{obs}$")

In [None]:
#if state.vae.p.z_dim:
#    _ = tsne_plot((z > 0.01) * 1.0, r"$Z > 0.01 | X_{obs}$")

In [None]:
if state.vae.p.z_dim:
    _ = plt.hist((z > 0.).sum(-1).flatten(), label="0")
    #_ = plt.hist((z > 0.01).sum(-1).flatten(), label="0.01")
    #_ = plt.hist((z > 0.1).sum(-1).flatten(), label="0.1")
    #_ = plt.hist((z > 0.5).sum(-1).flatten(), label="0.5")
    _ = plt.legend()

## Marginal samples per class

In [None]:
marginal_x = x.mean(1)
fig, axs = plt.subplots(
    2, 5, 
    sharex=True, sharey=True,
    gridspec_kw={'hspace': 0, 'wspace': 0})
for c in range(10):
    axs[c // 5, c % 5].imshow(marginal_x[c].reshape(args.height, args.width), cmap='Greys')
    #axs[c // 5, c % 5].set_title(f"X'|X={c}")
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
# Hide x labels and tick labels for top plots and y ticks for right plots.
for ax in axs.flat:
    ax.label_outer()
_ = fig.suptitle(r'$E[X|X_{obs}, \lambda, \theta]$')    

In [None]:
from analysis import probe_corners

In [None]:
corner_samples = probe_corners(state.vae, get_batcher(valid_loader, args), args)

In [None]:
corner_samples[0]['x'].shape

In [None]:
fig, axs = plt.subplots(
    2, 5, 
    sharex=True, sharey=True,
    gridspec_kw={'hspace': 0, 'wspace': 0})
for c in range(10):
    axs[c // 5, c % 5].imshow(corner_samples[c]['x'].mean(0).reshape(args.height, args.width), cmap='Greys')
    #axs[c // 5, c % 5].set_xlabel(c) #set_title(f"X'|X={c}")
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
# Hide x labels and tick labels for top plots and y ticks for right plots.
for ax in axs.flat:
    ax.label_outer()
_ = fig.suptitle(r'Samples from vertices')    

In [None]:
#np.save(open('to-vlad2-f.npy', 'wb'), f)

In [None]:
#np.save(open('to-vlad2-marginal.npy', 'wb'), marginal_f)

In [None]:
#np.save(open('to-vlad2-y.npy', 'wb'), y)

In [None]:
#np.save(open('to-vlad2-scores.npy', 'wb'), scores)

In [None]:
#frs = t2s.statistics_nondiff.FRStatistic(p_f.shape[0], q_f.shape[0])

In [None]:
#frs(torch.tensor(p_f), torch.tensor(q_f), norm=1)