In [1]:
%run ../startup.py

INFO:root:/cloud/syncthing/git/combinators appended to python path
INFO:root:%load_ext autoreload
INFO:root:%autoreload 2
INFO:root:from IPython.core.debugger import set_trace
INFO:root:from IPython.core.display import display, HTML
INFO:root:import torch
INFO:root:import numpy as np
INFO:root:import scipy as sp
INFO:root:import matplotlib
INFO:root:import matplotlib.pyplot as plt
INFO:root:%matplotlib inline
INFO:root:import seaborn as sns
INFO:root:import pandas as pd


## From test_annealing_experiment

In [2]:
import torch
import os
import math
from torch import nn, Tensor, optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
from typing import Tuple
from matplotlib import pyplot as plt
from pytest import mark, fixture

import combinators.trace.utils as trace_utils
from combinators.trace.utils import RequiresGrad
from combinators.tensor.utils import autodevice, kw_autodevice, copy, show
from combinators.densities import MultivariateNormal, Tempered, RingGMM, Normal
from combinators.densities.kernels import MultivariateNormalKernel, MultivariateNormalLinearKernel, NormalLinearKernel
from combinators.nnets import ResMLPJ
from combinators.objectives import nvo_rkl, nvo_avo, mb0, mb1, _estimate_mc, eval_nrep, _nvo_rkl
from combinators import Forward, Reverse, Propose, Condition, RequiresGrad, Resample
from combinators.stochastic import RandomVariable, ImproperRandomVariable
from combinators.metrics import effective_sample_size, log_Z_hat
from tests.utils import is_smoketest, seed
import combinators.debug as debug

import experiments.annealing.visualize as V
from experiments.annealing.models import mk_model, sample_along, paper_model


## Post-APG additions

In [3]:
from combinators.utils import load_models, save_models

## test_annealing_experiment code

In [4]:
def report(writer, ess, lzh, loss_scalar, i, eval_break, targets, forwards, saveable_models, comment):
    """Write one training iteration's diagnostics through ``writer``.

    Every call logs the scalar loss and one scalar per annealing step for ESS and
    log Z-hat (tags are 1-indexed: ``ess/step-1``, ...). On iterations that are a
    multiple of ``eval_break`` it also logs an ``overview`` scatter figure of samples
    drawn along ``forwards`` starting at ``targets[0]``, and checkpoints
    ``saveable_models`` under the filename ``comment``.

    Args:
        writer: object with ``add_scalar`` / ``add_figure`` (a SummaryWriter here).
        ess: per-step ESS values, step 1 first.
        lzh: per-step log Z-hat values, step 1 first.
        loss_scalar: the iteration's scalar loss.
        i: global step used for every logged value.
        eval_break: period, in iterations, of the figure + checkpoint.
        targets, forwards: models passed to ``sample_along`` for the overview.
        saveable_models: mapping handed to ``save_models``.
        comment: checkpoint filename handed to ``save_models``.
    """
    with torch.no_grad():
        writer.add_scalar('loss', loss_scalar, i)

        for step, value in enumerate(ess, start=1):
            writer.add_scalar(f'ess/step-{step}', value, i)

        for step, value in enumerate(lzh, start=1):
            writer.add_scalar(f'log_Z_hat/step-{step}', value, i)

        # Plotting and checkpointing only happen every `eval_break` iterations.
        if i % eval_break != 0:
            return

        overview = V.scatter_along(sample_along(targets[0], forwards))
        writer.add_figure('overview', overview, global_step=i, close=True)
        # POST-APG addition: checkpoint the models next to each overview figure.
        save_models(saveable_models, filename=comment)


In [5]:
def experiment_runner(is_smoketest, trainer, resample, objective_tpl, device, budget, num_iterations, num_targets=6, lr=1e-3, name="", eval_break=50):
    """Train the forward/reverse kernels of ``paper_model`` with ``trainer`` and log progress.

    Builds the targets and kernels, tries to resume from a checkpoint named after the
    run's hyper-parameters, then runs Adam on the loss returned by ``trainer``. Loss,
    per-step ESS and log Z-hat are logged every iteration via ``report``. A sample
    overview and a checkpoint are written every ``eval_break`` iterations.

    Args:
        is_smoketest: if true, run only 3 iterations.
        trainer: callable ``(i, targets, forwards, reverses, sample_shape, resample,
            objective, device) -> (lvss, loss)``. ``lvss`` is a list of per-step
            weight tensors (stacked on dim 0 and cumulatively summed here) and
            ``loss`` is the tensor to backpropagate.
        resample: forwarded to ``trainer``; also recorded in the run name.
        objective_tpl: ``(name, objective_fn)`` pair.
        device: passed to ``kw_autodevice`` and recorded in the run name.
        budget: total sample budget, divided evenly across the targets.
        num_iterations: number of optimisation steps (ignored for smoke tests).
        num_targets: number of annealing targets passed to ``paper_model``.
        lr: Adam learning rate.
        name: prefix of the run / checkpoint name.
        eval_break: period, in iterations, of overview figures and checkpoints.
    """
    import warnings

    debug.seed()
    num_iterations = 3 if is_smoketest else num_iterations
    oname = objective_tpl[0]
    objective = objective_tpl[1]

    # Models; the per-step sample count splits the total budget evenly across targets.
    out = paper_model(num_targets=num_targets, **kw_autodevice(device))
    num_samples = budget // len(out['targets'])
    sample_shape = (num_samples,)
    # Run name: used as the TensorBoard comment and as the checkpoint filename.
    comment = f"icml-annealing-{name}_{oname}-{len(out['targets'])}-{'r' if resample else '_'}-d{device}-s{num_samples}-i{num_iterations}-lr{lr}"

    targets, forwards, reverses = [[m.to(autodevice()) for m in out[n]] for n in ['targets', 'forwards', 'reverses']]
    saveable_models = {f'{k}-{i}': m for k, ms in dict(targets=targets, forwards=forwards, reverses=reverses).items() for i, m in enumerate(ms)}
    # Best-effort resume. `except Exception` rather than a bare except (which also
    # swallowed KeyboardInterrupt), and say why nothing was loaded.
    # NOTE(review): targets/forwards/reverses keep referencing the original modules, so
    # this assumes load_models restores weights into them in place; confirm.
    try:
        saveable_models = load_models(saveable_models, comment)
    except Exception as err:
        warnings.warn(f"no checkpoint loaded for {comment!r}; starting from scratch ({err!r})")

    assert all(len(list(k.parameters())) > 0 for k in [*forwards, *reverses])

    optimizer = optim.Adam([dict(params=x.parameters()) for x in [*forwards, *reverses]], lr=lr)

    # Logging. The writer is closed in `finally` so buffered events are flushed even
    # when training is interrupted or the trainer raises.
    writer = SummaryWriter(comment=comment)
    loss_ct, loss_sum = 0, 0.0
    try:
        with trange(num_iterations) as bar:
            for i in bar:
                lvss, loss = trainer(i, targets, forwards, reverses, sample_shape, resample, objective, device)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                # REPORTING
                # ---------------------------------------
                with torch.no_grad():
                    # Stack per-step tensors on dim 0 and cumulate over steps.
                    lvs = torch.stack(lvss, dim=0)
                    lws = torch.cumsum(lvs, dim=0)
                    ess = effective_sample_size(lws, sample_dims=-1)
                    lzh = log_Z_hat(lws, sample_dims=-1)

                    loss_scalar = loss.detach().cpu().mean().item()

                    report(writer, ess, lzh, loss_scalar, i, eval_break, targets, forwards, saveable_models, comment=comment)

                    loss_ct += 1
                    loss_sum += loss_scalar
                    # Every 10 iterations: mean loss since the last update, plus final-step ESS / log Z-hat.
                    if i % 10 == 0:
                        loss_avg = loss_sum / loss_ct
                        loss_template = 'loss={: .4f}'.format(loss_avg)
                        logZh_template = 'logZhat[-1]={: .4f}'.format(lzh[-1].cpu().item())
                        ess_template = 'ess[-1]={: .4f}'.format(ess[-1].cpu().item())
                        loss_ct, loss_sum = 0, 0.0
                        bar.set_postfix_str("; ".join([loss_template, ess_template, logZh_template]))
    finally:
        writer.close()

In [6]:
def nvi_eager_resample(i, targets, forwards, reverses, sample_shape, resample, objective, device):
    """One NVI sweep over the annealing sequence, summing the objective over steps.

    Starts from a trace sampled from ``targets[0]``. At step ``k`` the proposal is
    ``Forward(forwards[k], Condition(targets[k], previous_trace))`` and the target is
    ``Reverse(targets[k+1], reverses[k])``. They are combined with ``Propose`` and,
    if ``resample`` is true, wrapped in ``Resample``. The step's weights are added to
    a running ``lw``, and ``objective(lw, lv, g_k, g_{k+1})`` is added to the loss.

    Args:
        i: iteration index. Unused here; part of the trainer signature that
            ``experiment_runner`` calls.
        targets: annealing densities; ``len(targets) - 1`` steps are run.
        forwards, reverses: per-step forward / reverse kernels.
        sample_shape: sample shape for every program call.
        resample: wrap each step in ``Resample``.
        objective: callable ``(lw, lv, g_k_rv, g_kp1_rv) -> loss term``.
        device: passed to ``kw_autodevice`` for the accumulators.

    Returns:
        ``(lvss, loss)``: the list of per-step ``state.weights`` and the summed
        objective. ``loss`` starts as a 1-element zero tensor.
    """
    q0 = targets[0]
    p_prv_tr, _, _ = q0(sample_shape=sample_shape)

    loss = torch.zeros(1, **kw_autodevice(device))
    lw = torch.zeros(sample_shape, **kw_autodevice(device))
    lvss = []

    for k, (fwd, rev, q, p) in enumerate(zip(forwards, reverses, targets[:-1], targets[1:])):
        q_ext = Forward(fwd, Condition(q, p_prv_tr, requires_grad=RequiresGrad.NO), _step=k)
        p_ext = Reverse(p, rev, _step=k)
        extend = Propose(target=p_ext, proposal=q_ext, _step=k)
        if resample:
            extend = Resample(extend)

        state = extend(sample_shape=sample_shape, sample_dims=0, _debug=True)
        lv = state.weights

        p_prv_tr = state.trace

        # Out-of-place on purpose: `lw` was already handed to `objective` on the
        # previous step. If autograd saved it there, an in-place `+=` would bump its
        # version counter and make loss.backward() fail.
        lw = lw + lv
        if resample:
            ext_prp = state.program
            ext_tar = state.program.target
        else:
            ext_prp = state.proposal
            ext_tar = state.target

        loss = loss + objective(lw, lv, ext_prp.trace[f'g{k}'], ext_tar.trace[f'g{k+1}'])

        lvss.append(lv)

    return lvss, loss

# (name, objective_fn) pairs; experiment_runner unpacks them via `objective_tpl`.
rkl_obj = ('rkl', nvo_rkl)
_rkl_obj = ('NEWrkl', _nvo_rkl)


def _avo_objective(lw, lv, g_k, g_kp1):
    """Adapt nvo_avo to the trainer's objective signature; only `lv` is used."""
    return nvo_avo(lv)


avo_obj = ('avo', _avo_objective)

# Total sample budget (split across targets) and number of training iterations.
budget, num_iterations = 288, 20_000



In [None]:
# NVI run: `_rkl_obj` objective, 4 annealing targets, no resampling, on CPU.
experiment_runner(
    name='nvi',
    trainer=nvi_eager_resample,
    objective_tpl=_rkl_obj,
    num_targets=4,
    resample=False,
    budget=budget,
    num_iterations=num_iterations,
    lr=1e-3,
    device="cpu",
    is_smoketest=False,
)

 29%|██▊       | 5709/20000 [06:40<17:53, 13.31it/s, loss=-28891.5802; ess[-1]= 14.2814; logZhat[-1]= 1.7870] 

In [None]:
# AVO run: `avo_obj` objective, 4 annealing targets, no resampling, on CPU.
experiment_runner(
    name='avo',
    trainer=nvi_eager_resample,
    objective_tpl=avo_obj,
    num_targets=4,
    resample=False,
    budget=budget,
    num_iterations=num_iterations,
    lr=1e-3,
    device="cpu",
    is_smoketest=False,
)

In [None]:
def forward_model(resample, objective_tpl, device, budget, lr=1e-3, num_iterations=10000, name="", num_targets=8):
    """Rebuild the annealing models and load previously saved weights into them.

    Args:
        resample, budget, lr, num_iterations, name: hyper-parameters of the training
            run; here they only determine the checkpoint name.
        objective_tpl: ``(name, objective_fn)`` pair; only the name is used.
        device: passed to ``kw_autodevice`` and recorded in the checkpoint name.
        num_targets: number of annealing targets for ``paper_model``. The default of
            8 is the value previously hard-coded here.

    Returns:
        ``(targets, forwards, reverses)``: lists of modules moved to ``autodevice()``.

    NOTE(review): the checkpoint name omits the ``-{number of targets}-`` segment that
    ``experiment_runner`` puts in its run name, so it will not match checkpoints
    written by that function as defined in this notebook. It presumably targets files
    saved under an older naming scheme; check ``./weights``.
    """
    oname = objective_tpl[0]

    # Build the model once, with the same target count used for the sample budget and
    # the checkpoint name. Previously a second paper_model call with default arguments
    # replaced this model, so the returned modules could disagree with the name.
    out = paper_model(num_targets=num_targets, **kw_autodevice(device))
    num_samples = budget // len(out['targets'])
    comment = f"icml-annealing-{name}_{oname}-{'r' if resample else '_'}-d{device}-s{num_samples}-i{num_iterations}-lr{lr}"

    targets, forwards, reverses = [[m.to(autodevice()) for m in out[n]] for n in ['targets', 'forwards', 'reverses']]
    saveable_models = {f'{k}-{i}': m for k, ms in dict(targets=targets, forwards=forwards, reverses=reverses).items() for i, m in enumerate(ms)}

    # NOTE(review): the return value is not used; this relies on load_models restoring
    # weights into the passed modules in place. Confirm.
    load_models(saveable_models, comment)

    return targets, forwards, reverses

In [None]:

# Total sample budget (split evenly across targets) and number of training iterations.
budget, num_iterations = 288, 20_000

In [None]:
!ls ./weights

In [None]:

budget, num_iterations = 288, 20_000

# Rebuild the models of the 'nvi' / rkl run and load their saved weights. The
# hyper-parameters here must match the training run because they determine the
# checkpoint name that forward_model looks up.
targets, forwards, reverses = forward_model(
    name='nvi',
    objective_tpl=rkl_obj,
    resample=False,
    budget=budget,
    num_iterations=num_iterations,
    lr=1e-3,
    device="cuda",
)

def weights_along(proposal, kernels, sample_shape=(1000,)):
    """Sample after each prefix of the kernel chain, together with the trace log joint.

    Element 0 of the returned list is the raw output of ``proposal``. Element ``k``
    (k >= 1) is ``(output, log_joint)`` for ``proposal`` extended by the first ``k``
    kernels through nested ``Forward``. Each extended program is called anew, so
    (presumably) every element comes from a fresh pass through the chain rather than
    reusing the previous element's samples. Confirm against ``Forward``.

    ``log_joint`` is computed with ``sample_dims=0`` and, when ``sample_shape`` has
    more than one entry, ``batch_dim=1``.
    """
    _, _, initial = proposal(sample_shape=sample_shape)
    results = [initial]
    chain = proposal
    for kernel in kernels:
        chain = Forward(kernel, chain)
        trace, _, output = chain(sample_shape=sample_shape)
        batch_dim = 1 if len(sample_shape) != 1 else None
        results.append((output, trace.log_joint(sample_dims=0, batch_dim=batch_dim)))
    return results


def weights_last(proposal, kernels, sample_shape=(1000,)):
    """Run ``proposal`` extended by every kernel once and return ``(output, log_joint)``.

    ``log_joint`` is computed with ``sample_dims=0`` and, when ``sample_shape`` has
    more than one entry, ``batch_dim=1``.
    """
    chain = proposal
    for kernel in kernels:
        chain = Forward(kernel, chain)
    trace, _, output = chain(sample_shape=sample_shape)
    batch_dim = 1 if len(sample_shape) != 1 else None
    return output, trace.log_joint(sample_dims=0, batch_dim=batch_dim)


In [None]:
len(forwards)  # sanity check: number of forward kernels returned by forward_model above

In [None]:
# Average the final-step ESS over `count` independent draws of (1000, 100) samples.
# weights_along calls log_joint(sample_dims=0, batch_dim=1), so once the steps are
# stacked on dim 0 the samples sit on dim 1 and the batch on dim 2; the last dim is
# the batch, not the samples. (Assumes log_joint returns shape (1000, 100); confirm.)
tot = 0
count = 100

with torch.no_grad():  # diagnostics only: don't keep autograd graphs for 100 draws
    for _ in range(count):
        samples = weights_along(targets[0], forwards, sample_shape=(1000, 100))[1:]
        lvss = [pair[1] for pair in samples]
        lvs = torch.stack(lvss, dim=0)
        lws = torch.cumsum(lvs, dim=0)
        ess = effective_sample_size(lws, sample_dims=1)
        tot += ess[-1].mean()  # final step, averaged over the batch

print(tot / count)

In [None]:
# One draw along the chain for inspection. Later cells only compute diagnostics or
# plot (detaching), so skip building the autograd graph that would otherwise be kept
# alive by `samples`.
with torch.no_grad():
    samples = weights_along(targets[0], forwards, sample_shape=(1000, 100))


In [None]:
# Log joints from every (output, log_joint) pair, stacked on a new step dim 0 and
# cumulated over steps.
lvs = torch.stack([pair[1] for pair in samples[1:]])
lws = lvs.cumsum(dim=0)

# ESS over dim 1 (presumably the 1000 samples), averaged over the batch dim -> one value per step.
effective_sample_size(lws, sample_dims=1).mean(dim=1)

In [None]:
# Final-step ESS for the single draw in `lvs`. Samples sit on dim 1 after stacking
# (weights_along uses log_joint(sample_dims=0, batch_dim=1)). The result is displayed
# rather than added into `tot`: accumulating here silently changed the averaging
# loop's running total on every re-run of this cell.
lws = torch.cumsum(lvs, dim=0)
ess = effective_sample_size(lws, sample_dims=1)
ess[-1].mean()

In [None]:
# samples[-1] is an (output, log_joint) pair from weights_along, so indexing it like a
# tensor raised TypeError. Unpack the output and flatten leading sample/batch dims.
last = samples[-1]
points = last[0] if isinstance(last, tuple) else last
# assumes the trailing dimension holds the 2-D (x, y) coordinates; TODO confirm
points = points.reshape(-1, points.shape[-1])
x, y = points.T
x.shape

In [None]:
def plot_sample_hist(ax, samples, sort=True, bins=20, range=None, weight_cm=False, **kwargs):
    """Draw a 2-D density histogram of ``samples`` on ``ax`` as an image, without ticks or grid.

    Args:
        ax: matplotlib Axes to draw on.
        samples: tensor of shape (N, 2); column 0 is x, column 1 is y.
        sort: unused; kept for call compatibility.
        bins, range: forwarded to ``np.histogram2d``.
        weight_cm: weighted colour maps are not implemented; ``True`` raises.
        **kwargs: forwarded to ``ax.imshow`` (e.g. ``cmap``). ``origin`` defaults
            to ``'lower'`` so y increases upward.

    Returns:
        The image returned by ``ax.imshow`` (e.g. for a colorbar).

    Raises:
        NotImplementedError: if ``weight_cm`` is true. The old code wrote
            ``raise NotImplemented()``, which raises TypeError instead, because
            ``NotImplemented`` is not callable.
    """
    if weight_cm:
        raise NotImplementedError("weight_cm=True is not supported")

    ax.tick_params(bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    ax.grid(False)
    x, y = samples.detach().cpu().numpy().T
    density, _, _ = np.histogram2d(x, y, bins=bins, density=True, range=range)
    # histogram2d indexes the result as [x_bin, y_bin], while imshow draws rows as y.
    # Transpose and put the origin at the bottom so x is horizontal and y points up.
    kwargs.setdefault('origin', 'lower')
    return ax.imshow(density.T, **kwargs)

import matplotlib.pyplot as plt

# Histogram of the final-step samples. samples[-1] is an (output, log_joint) pair from
# weights_along (passing the tuple itself failed on `.detach`), so take the output
# tensor and flatten the leading sample/batch dims into a list of points.
last = samples[-1]
points = last[0] if isinstance(last, tuple) else last
points = points.reshape(-1, points.shape[-1])  # assumes last dim is (x, y); TODO confirm

fig, ax = plt.subplots(figsize=(3, 3))
plot_sample_hist(ax, points, bins=150, cmap='viridis')
plt.show()

In [None]:

from matplotlib.image import NonUniformImage

import matplotlib.pyplot as plt

# Final-step samples as a pcolormesh over the real bin edges (axes in data coordinates).
# The old cell read an undefined `sample` and unpacked a single Axes into two.
last = samples[-1]
points = last[0] if isinstance(last, tuple) else last
points = points.reshape(-1, points.shape[-1])  # assumes last dim is (x, y); TODO confirm
x, y = [points[:, i].detach().cpu().numpy() for i in [0, 1]]

H, xedges, yedges = np.histogram2d(x, y, bins=50, density=True)

H = H.T  # Let each row list bins with common y range.

fig, ax1 = plt.subplots(ncols=1, nrows=1)

X, Y = np.meshgrid(xedges, yedges)

ax1.pcolormesh(X, Y, H)

plt.show()