In [1]:
#!/usr/bin/env python3
import torch
import math
from torch import nn, Tensor
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
from typing import Tuple
from matplotlib import pyplot as plt

import combinators.trace.utils as trace_utils
from combinators.trace.utils import RequiresGrad
from combinators.tensor.utils import autodevice, kw_autodevice, copy, show
from combinators.densities import MultivariateNormal, Tempered, RingGMM
from combinators.densities.kernels import MultivariateNormalKernel, MultivariateNormalLinearKernel
from combinators.nnets import ResMLPJ
from combinators.objectives import nvo_rkl
from combinators import Forward, Reverse, Propose
from combinators.stochastic import RandomVariable, ImproperRandomVariable
from combinators.metrics import effective_sample_size, log_Z_hat
import visualize as V

def mk_kernel(from_:int, to_:int, std:float, num_hidden:int):
    embedding_dim = 2
    return MultivariateNormalKernel(
        ext_from=f'g{from_}',
        ext_to=f'g{to_}',
        loc=torch.zeros(2, **kw_autodevice()),
        cov=torch.eye(2, **kw_autodevice())*std**2,
        net=ResMLPJ(
            dim_in=2,
            dim_hidden=num_hidden,
            dim_out=embedding_dim).to(autodevice()))
    # return MultivariateNormalLinearKernel(
    #     ext_from=f'g{from_}',
    #     ext_to=f'g{to_}',
    #     loc=torch.zeros(2, **kw_autodevice()),
    #     cov=torch.eye(2, **kw_autodevice())*std**2)

def mk_model(num_targets:int):
    proposal_std = 16
    g0 = MultivariateNormal(name='g0', loc=torch.zeros(2, **kw_autodevice()), cov=torch.eye(2, **kw_autodevice())*proposal_std**2)
    gK = RingGMM(scale=8, count=8, name=f"g{num_targets - 1}").to(autodevice())

    # Make an annealing path
    betas = torch.arange(0., 1., 1./(num_targets - 1))[1:] # g_0 is beta=0
    path = [Tempered(f'g{k}', g0, gK, beta) for k, beta in zip(range(1,num_targets-1), betas)]
    path = [g0] + path + [gK]
    assert len(path) == num_targets # sanity check that the betas line up

    return dict(
        targets=path,
        forwards=[mk_kernel(from_=i, to_=i+1, std=1., num_hidden=64) for i in range(num_targets-1)],
        reverses=[mk_kernel(from_=i+1, to_=i, std=1., num_hidden=64) for i in range(num_targets-1)],
    )

In [2]:
import torch
import math
from torch import nn, Tensor
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
from typing import Tuple
from matplotlib import pyplot as plt

import combinators.trace.utils as trace_utils
from combinators.tensor.utils import autodevice, kw_autodevice
from combinators.densities import MultivariateNormal, Tempered, RingGMM
from combinators.densities.kernels import MultivariateNormalKernel
from combinators.nnets import ResMLPJ
from combinators.objectives import nvo_rkl
from combinators import Forward, Reverse, Propose
from combinators.stochastic import RandomVariable, ImproperRandomVariable
from combinators.metrics import effective_sample_size, log_Z_hat
import visualize as V

In [3]:
from main import mk_model, mk_kernel
from tqdm.notebook import trange, tqdm


In [4]:
from combinators import Forward
    
def sample_along(proposal, kernels, sample_shape=(2000,)):
    samples = []
    tr, out = proposal(sample_shape=sample_shape)
    samples.append(out)
    for k in forwards:
        proposal = Forward(k, proposal)
        tr, out = proposal(sample_shape=sample_shape)
        samples.append(out)
    return samples

In [5]:
# main() arguments
seed=1
num_iterations=5000
eval_break = 500

In [6]:
# Setup
torch.manual_seed(seed)
K = 8
num_samples = 256
sample_shape=(num_samples,)

# Models
out = mk_model(K)
targets, forwards, reverses = [[m.to(autodevice()) for m in out[n]] for n in ['targets', 'forwards', 'reverses']]

assert all([len(list(k.parameters())) >  0 for k in [*forwards, *reverses]])
optimizer = torch.optim.Adam([dict(params=x.parameters()) for x in [*forwards, *reverses]], lr=1e-4)

# logging
writer = SummaryWriter()
loss_ct, loss_sum, loss_avgs, loss_all = 0, 0.0, [], []

In [None]:
with trange(num_iterations) as bar:
    for i in bar:
        q0 = targets[0]
        p_prv_tr, out0 = q0(sample_shape=sample_shape)

        loss = torch.zeros(1, **kw_autodevice())
        lw, lvss = torch.zeros(sample_shape, **kw_autodevice()), []
        for k, (fwd, rev, q, p) in enumerate(zip(forwards, reverses, targets[:-1], targets[1:])):
            q.with_observations(trace_utils.copytrace(p_prv_tr, detach=p_prv_tr.keys()))
            q_ext = Forward(fwd, q, _step=k)
            p_ext = Reverse(p, rev, _step=k)
            extend = Propose(target=p_ext, proposal=q_ext, _step=k)
            state, lv = extend(sample_shape=sample_shape, sample_dims=0)

            p_prv_tr = state.target.trace
            p.clear_observations()
            q.clear_observations()

            lw += lv

            from combinators.objectives import mb0, mb1, _estimate_mc
            breakpoint()
            batch_dim=None
            sample_dims=0
            rv_proposal=state.proposal.trace[f'g{k}']
            rv_target=state.target.trace[f'g{k+1}']
            # TODO: move back from the proposal and target RVs to joint logprobs?
            reducedims = (sample_dims,)

            lw = lw.detach()
            ldZ = lv.detach().logsumexp(dim=sample_dims) - math.log(lv.shape[sample_dims])
            f = -lv

            # rv_proposal = next(iter(proposal_trace.values())) # tr[\gamma_{k-1}]
            # rv_target = next(iter(target_trace.values()))     # tr[\gamma_{k}]

            kwargs = dict(
                sample_dims=sample_dims,
                reducedims=reducedims,
                keepdims=False
            )

            baseline = _estimate_mc(f.detach(), lw, **kwargs).detach()

            kl_term = _estimate_mc(mb1(rv_proposal._log_prob) * (f - baseline), lw, **kwargs)

            grad_log_Z1 = _estimate_mc(rv_proposal._log_prob, lw, **kwargs)
            grad_log_Z2 = _estimate_mc(eval_nrep(rv_target)._log_prob, lw+lv.detach(), **kwargs)

            loss = kl_term + mb0(baseline * grad_log_Z1 - grad_log_Z2) + baseline + ldZ
            
            if k == (K-2):
                breakpoint()
                loss += nvo_rkl(lw, lv, state.proposal.trace[f'g{k}'], state.target.trace[f'g{k+1}'])
            lvss.append(lv)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        with torch.no_grad():
            # REPORTING
            # ---------------------------------------
            # ESS
            lvs = torch.stack(lvss, dim=0)
            lws = torch.cumsum(lvs, dim=1)
            ess = effective_sample_size(lws, sample_dims=-1)
            for step, x in zip(range(1,len(ess)+1), ess):
                writer.add_scalar(f'ess/step-{step}', x, i)

            # logZhat
            lzh = log_Z_hat(lws, sample_dims=-1)
            for step, x in zip(range(1,len(lzh)+1), lzh):
                writer.add_scalar(f'log_Z_hat/step-{step}', x, i)

            # loss
            loss_ct += 1
            loss_scalar = loss.detach().cpu().mean().item()
            writer.add_scalar('training/loss', loss_scalar, i)
            loss_sum += loss_scalar

            # progress bar
            loss_avg = loss_sum / loss_ct
            loss_template = 'loss={}{:.4f}'.format('' if loss_avg < 0 else ' ', loss_avg)
            logZh_template = 'logZhat[-1]={:.4f}'.format(lzh[-1].cpu().item())
            ess_template = 'ess[-1]={:.4f}'.format(ess[-1].cpu().item())
            loss_ct, loss_sum  = 0, 0.0
            bar.set_postfix_str("; ".join([loss_template, ess_template, logZh_template]))

            # show samples
            if i % (eval_break + 1) == 0:
                samples = sample_along(targets[0], forwards)
                fig = V.scatter_along(samples)
                writer.add_figure('overview', fig, global_step=i, close=True)
#                 for ix, xs in enumerate(samples):
#                     writer.add_figure(f'step-{ix+1}', V.scatter(xs), global_step=i, close=True)



  0%|          | 0/5000 [00:00<?, ?it/s]

> [0;32m<ipython-input-7-dcf3b518b661>[0m(23)[0;36m<module>[0;34m()[0m
[0;32m     22 [0;31m            [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 23 [0;31m            [0mbatch_dim[0m[0;34m=[0m[0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0msample_dims[0m[0;34m=[0m[0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(24)[0;36m<module>[0;34m()[0m
[0;32m     23 [0;31m            [0mbatch_dim[0m[0;34m=[0m[0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 24 [0;31m            [0msample_dims[0m[0;34m=[0m[0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m            [0mrv_proposal[0m[0;34m=[0m[0mstate[0m[0;34m.[0m[0mproposal[0m[0;34m.[0m[0mtrace[0m[0;34m[[0m[0;34mf'g{k}'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(25)[0;36m<module>[0;34m()[0m
[0;32m     24 [0;31m            [0msample_dims[0m[0;34m=[0m[0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m            [0mrv_proposal[0m[0;34m=[0m[0mstate[0m[0;34m.[0m[0mproposal[0m[0;34m.[0m[0mtrace[0m[0;34m[[0m[0;34mf'g{k}'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m            [0mrv_target[0m[0;34m=[0m[0mstate[0m[0;34m.[0m[0mtarget[0m[0;34m.[0m[0mtrace[0m[0;34m[[0m[0;34mf'g{k+1}'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  lvs.shape


*** NameError: name 'lvs' is not defined


ipdb>  lv.shape


torch.Size([256])


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(26)[0;36m<module>[0;34m()[0m
[0;32m     25 [0;31m            [0mrv_proposal[0m[0;34m=[0m[0mstate[0m[0;34m.[0m[0mproposal[0m[0;34m.[0m[0mtrace[0m[0;34m[[0m[0;34mf'g{k}'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 26 [0;31m            [0mrv_target[0m[0;34m=[0m[0mstate[0m[0;34m.[0m[0mtarget[0m[0;34m.[0m[0mtrace[0m[0;34m[[0m[0;34mf'g{k+1}'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m            [0;31m# TODO: move back from the proposal and target RVs to joint logprobs?[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(28)[0;36m<module>[0;34m()[0m
[0;32m     27 [0;31m            [0;31m# TODO: move back from the proposal and target RVs to joint logprobs?[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 28 [0;31m            [0mreducedims[0m [0;34m=[0m [0;34m([0m[0msample_dims[0m[0;34m,[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     29 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(30)[0;36m<module>[0;34m()[0m
[0;32m     29 [0;31m[0;34m[0m[0m
[0m[0;32m---> 30 [0;31m            [0mlw[0m [0;34m=[0m [0mlw[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m            [0mldZ[0m [0;34m=[0m [0mlv[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mlogsumexp[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0msample_dims[0m[0;34m)[0m [0;34m-[0m [0mmath[0m[0;34m.[0m[0mlog[0m[0;34m([0m[0mlv[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0msample_dims[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(31)[0;36m<module>[0;34m()[0m
[0;32m     30 [0;31m            [0mlw[0m [0;34m=[0m [0mlw[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 31 [0;31m            [0mldZ[0m [0;34m=[0m [0mlv[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mlogsumexp[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0msample_dims[0m[0;34m)[0m [0;34m-[0m [0mmath[0m[0;34m.[0m[0mlog[0m[0;34m([0m[0mlv[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0msample_dims[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     32 [0;31m            [0mf[0m [0;34m=[0m [0;34m-[0m[0mlv[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  lw.shape


torch.Size([256])


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(32)[0;36m<module>[0;34m()[0m
[0;32m     31 [0;31m            [0mldZ[0m [0;34m=[0m [0mlv[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mlogsumexp[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0msample_dims[0m[0;34m)[0m [0;34m-[0m [0mmath[0m[0;34m.[0m[0mlog[0m[0;34m([0m[0mlv[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0msample_dims[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 32 [0;31m            [0mf[0m [0;34m=[0m [0;34m-[0m[0mlv[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     33 [0;31m[0;34m[0m[0m
[0m


ipdb>  ldZ.shape


torch.Size([])


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(37)[0;36m<module>[0;34m()[0m
[0;32m     36 [0;31m[0;34m[0m[0m
[0m[0;32m---> 37 [0;31m            kwargs = dict(
[0m[0;32m     38 [0;31m                [0msample_dims[0m[0;34m=[0m[0msample_dims[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(38)[0;36m<module>[0;34m()[0m
[0;32m     37 [0;31m            kwargs = dict(
[0m[0;32m---> 38 [0;31m                [0msample_dims[0m[0;34m=[0m[0msample_dims[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m                [0mreducedims[0m[0;34m=[0m[0mreducedims[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(39)[0;36m<module>[0;34m()[0m
[0;32m     38 [0;31m                [0msample_dims[0m[0;34m=[0m[0msample_dims[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 39 [0;31m                [0mreducedims[0m[0;34m=[0m[0mreducedims[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m                [0mkeepdims[0m[0;34m=[0m[0;32mFalse[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(40)[0;36m<module>[0;34m()[0m
[0;32m     39 [0;31m                [0mreducedims[0m[0;34m=[0m[0mreducedims[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m                [0mkeepdims[0m[0;34m=[0m[0;32mFalse[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m            )
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(37)[0;36m<module>[0;34m()[0m
[0;32m     36 [0;31m[0;34m[0m[0m
[0m[0;32m---> 37 [0;31m            kwargs = dict(
[0m[0;32m     38 [0;31m                [0msample_dims[0m[0;34m=[0m[0msample_dims[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(43)[0;36m<module>[0;34m()[0m
[0;32m     42 [0;31m[0;34m[0m[0m
[0m[0;32m---> 43 [0;31m            [0mbaseline[0m [0;34m=[0m [0m_estimate_mc[0m[0;34m([0m[0mf[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mlw[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(45)[0;36m<module>[0;34m()[0m
[0;32m     44 [0;31m[0;34m[0m[0m
[0m[0;32m---> 45 [0;31m            [0mkl_term[0m [0;34m=[0m [0m_estimate_mc[0m[0;34m([0m[0mmb1[0m[0;34m([0m[0mrv_proposal[0m[0;34m.[0m[0m_log_prob[0m[0;34m)[0m [0;34m*[0m [0;34m([0m[0mf[0m [0;34m-[0m [0mbaseline[0m[0;34m)[0m[0;34m,[0m [0mlw[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     46 [0;31m[0;34m[0m[0m
[0m


ipdb>  baseline


tensor(38.2808, device='cuda:0')


ipdb>  n


> [0;32m<ipython-input-7-dcf3b518b661>[0m(47)[0;36m<module>[0;34m()[0m
[0;32m     46 [0;31m[0;34m[0m[0m
[0m[0;32m---> 47 [0;31m            [0mgrad_log_Z1[0m [0;34m=[0m [0m_estimate_mc[0m[0;34m([0m[0mrv_proposal[0m[0;34m.[0m[0m_log_prob[0m[0;34m,[0m [0mlw[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m            [0mgrad_log_Z2[0m [0;34m=[0m [0m_estimate_mc[0m[0;34m([0m[0meval_nrep[0m[0;34m([0m[0mrv_target[0m[0;34m)[0m[0;34m.[0m[0m_log_prob[0m[0;34m,[0m [0mlw[0m[0;34m+[0m[0mlv[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  kl_term


tensor(-8.8289e-07, device='cuda:0', grad_fn=<SumBackward1>)


ipdb>  (mb1(rv_proposal._log_prob) * (f - baseline)).shape


torch.Size([256])


ipdb>  (mb1(rv_proposal._log_prob) ).shape


torch.Size([256])


ipdb>  rv_proposal._log_prob.shape


torch.Size([256])


ipdb>  rv_proposal._log_prob


tensor([-5.4635, -2.2040, -1.9204, -5.7113, -5.7258, -2.1578, -2.8195, -3.4620,
        -1.9583, -3.4753, -2.3907, -2.7854, -2.7172, -6.6422, -1.9373, -3.9268,
        -2.0027, -5.6784, -1.9400, -3.9187, -2.2655, -1.8483, -4.5269, -2.7946,
        -2.4407, -2.2646, -2.4601, -2.4522, -2.1176, -2.2776, -2.3705, -2.8373,
        -4.0239, -1.8732, -5.1909, -2.7361, -3.4399, -3.2864, -1.9102, -3.1194,
        -2.8840, -2.0699, -2.4503, -4.9765, -4.1048, -3.0285, -3.3769, -3.0123,
        -2.0621, -2.4099, -4.8356, -3.0698, -1.8830, -1.8808, -2.9520, -3.0570,
        -3.5481, -3.4461, -3.6942, -1.9640, -2.6313, -2.8533, -2.0251, -2.6844,
        -2.3050, -2.3685, -1.8867, -2.0564, -2.5897, -2.9680, -3.1532, -2.7890,
        -2.2845, -2.5548, -2.1770, -4.0307, -3.3111, -2.0785, -2.2710, -4.5252,
        -2.0763, -2.0993, -1.8598, -2.0463, -2.1192, -3.4935, -2.8658, -2.0514,
        -1.8390, -1.9937, -2.5651, -3.5768, -3.3661, -2.4932, -2.4793, -1.9957,
        -2.7700, -2.2445, -2.4369, -3.63