In [1]:
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as ag
from torch.distributions.categorical import Categorical


In [2]:
class BasicParameters(nn.Module):
    def __init__(self,
                 data,
                 left_flank=None,
                 right_flank=None,
                 batch_dim=0,
                 cat_axis=-1
                ):
        
        super().__init__()
        
        self.register_parameter('theta', data)
        self.register_buffer('left_flank', left_flank)
        self.register_buffer('right_flank', right_flank)
        
        self.cat_axis = cat_axis
        self.batch_dim = batch_dim
        
    @property
    def shape(self):
        return self().shape

    def forward(self):
        my_attr = [ getattr(self, x) for x in ['theta', 'left_flank', 'right_flank'] ]
        return torch.cat( [ x for x in my_attr if x is not None ], axis=self.cat_axis )
    
    def rebatch(self, input):
        return input


In [3]:
class NUTS3(nn.Module):
    def __init__(self,
                 params,
                 energy_fn,
                ):
        
        super().__init__()
        self.params = params
        self.energy_fn  = energy_fn
        
        self.d_max = 1000.
        
    def calc_energy(self):
        energy = self.energy_fn(self.params())
        energy = self.params.rebatch( energy )
        return energy

    def leapfrog(self, theta, r, epsilon):
        
        self.params.theta.data = theta
        energy = self.calc_energy()
        grad_U = ag.grad( energy.sum(), self.params.theta )[0]
        
        with torch.no_grad():
            r = r - grad_U.mul(epsilon).div(2.)
            
            theta = theta + r.mul(epsilon)
            
        self.params.theta.data = theta
        energy = self.calc_energy()
        grad_U = ag.grad( energy.sum(), self.params.theta )[0]
        
        with torch.no_grad():
            r = r - grad_U.mul(epsilon).div(2.)
            
        return theta, r, energy
        
    def buildtree(self, theta, r, u, v, j, epsilon):
        if j == 0:
            theta_p, r_p, energy_p = self.leapfrog(theta, r, epsilon)
            batch_dot = torch.einsum('bs,bs->b', r_p.flatten(1), r_p.flatten(1))
            hamilton  = energy_p + batch_dot.div(2.)
            n_p = (u <= torch.exp(-hamilton)).type(torch.long)
            s_p = (torch.log(u).add(-self.d_max) < -hamilton).type(torch.long)
            return theta_p, r_p, theta_p, r_p, theta_p, n_p, s_p
        
        else:
            bt_pack = self.buildtree(theta, r, u, v, j-1, epsilon)
            theta_r, r_r, theta_f, r_f, theta_p, n_p, s_p = bt_pack
            if s_p.sum() > 0:
                if v == -1:
                    bt_pack = self.buildtree(theta_r, r_r, u, v, j-1, epsilon)
                    theta_r, r_r, _, _, theta_pp, n_pp, s_pp = bt_pack
                    
                else:
                    bt_pack = self.buildtree(theta_f, r_f, u, v, j-1, epsilon)
                    _, _, theta_f, r_f, theta_pp, n_pp, s_pp = bt_pack
                
                update_flag = torch.rand(n_pp.size(), dtype=torch.float, 
                                         layout=n_pp.layout, device=n_pp.device)
                update_flag = update_flag < n_pp.div( n_p + n_pp )
                update_flag = torch.logical_and(update_flag, s_p.ge(1) )
                theta_p[ update_flag ] = theta_pp[ update_flag ]
                s_p = s_p * s_pp * \
                      torch.einsum('bs,bs->b', (theta_f - theta_r).flatten(1), r_r.flatten(1)) \
                        .ge(0.).type(torch.long) * \
                      torch.einsum('bs,bs->b', (theta_f - theta_r).flatten(1), r_f.flatten(1)) \
                        .ge(0.).type(torch.long)
                n_p = n_p + n_pp
                
            return theta_r, r_r, theta_f, r_f, theta_p, n_p, s_p
        
    def init_trajectory(self, theta):
        with torch.no_grad():
            r_0 = torch.randn_like( theta )
            energy_0 = self.calc_energy()
            batch_dot= torch.einsum('bs,bs->b', r_0.flatten(1), r_0.flatten(1))
            hamilton = energy_0 + batch_dot.div(2.)
            u   = torch.rand_like( hamilton ).mul( torch.exp(-hamilton) )
            
            theta_r, theta_f = theta, theta
            r_r, r_f = r_0, r_0
            j = 0
            theta_m = theta
            n = torch.ones(batch_dot.size(), dtype=torch.long, layout=batch_dot.layout, device=batch_dot.device)
            s = torch.ones(batch_dot.size(), dtype=torch.long, layout=batch_dot.layout, device=batch_dot.device)
        return u, theta_r, r_r, theta_f, r_f, j, theta_m, n, s
    
    def sample_trajectory(self, theta, epsilon):
        u, theta_r, r_r, theta_f, r_f, j, theta_m, n, s = self.init_trajectory(theta)
        while s.sum() >= 1:
            v = torch.randn([1], dtype=torch.float, layout=theta.layout, device=theta.device) \
                  .ge(0.).mul(2.).add(-1.)
            if v < 0:
                theta_r, r_r, _, _, theta_p, n_p, s_p = self.buildtree(theta_r, r_r, u, v, j, epsilon)
            else:
                _, _, theta_f, r_f, theta_p, n_p, s_p = self.buildtree(theta_f, r_f, u, v, j, epsilon)
            
            update_flag = torch.minimum( n / n_p, torch.ones_like(n.type(torch.float)) ) <= torch.rand_like(n.type(torch.float))
            update_flag = torch.logical_and( update_flag, s.ge(1) )
            update_flag = torch.logical_and( update_flag, s_p.ge(1) )
            theta_m[ update_flag ] = theta_p[ update_flag ]
            
            n = n + n_p
            s = s * s_p * \
                torch.einsum('bs,bs->b', (theta_f - theta_r).flatten(1), r_r.flatten(1)) \
                  .ge(0.).type(torch.long) * \
                torch.einsum('bs,bs->b', (theta_f - theta_r).flatten(1), r_f.flatten(1)) \
                  .ge(0.).type(torch.long)
            j = j + 1
        
        return theta_m.detach().clone()
    
    def collect_samples(self, epsilon, n_samples=1):
        samples = []
        theta_m = self.params.theta.clone().detach()
        for m in range(n_samples):
            theta_m = self.sample_trajectory( theta_m, epsilon )
            samples.append( theta_m )
        return samples


In [4]:
from parameters import BasicParameters
from nuts import NUTS3

In [5]:
my_params = BasicParameters(
    nn.Parameter(torch.randn([3,2,1])),
    left_flank=torch.randn([3,2,2]),
    right_flank=torch.randn([3,2,2]) 
)

In [6]:
my_params.forward()

tensor([[[ 8.6594e-01,  8.0312e-01, -3.0701e-01, -3.4206e-01,  6.3109e-01],
         [ 8.7925e-01, -9.4702e-01, -1.4804e+00, -1.6634e-03,  6.1142e-01]],

        [[ 2.1903e-01,  6.4296e-01,  3.7056e-02,  1.3612e+00,  1.1617e-02],
         [ 5.1287e-01,  9.8834e-01, -6.7040e-01, -6.6214e-01,  1.4057e-01]],

        [[-1.2719e+00,  5.7301e-01, -1.0352e+00, -9.8352e-01,  9.9522e-01],
         [ 5.4203e-01, -1.7494e+00, -6.2569e-01,  1.1204e+00,  6.5439e-01]]],
       grad_fn=<CatBackward>)

In [7]:
def my_energy(in_tensor):
    return in_tensor.pow(2).mean(dim=(1,2))

In [8]:
my_energy( my_params.forward() )

tensor([0.6240, 0.4463, 1.0405], grad_fn=<MeanBackward1>)

In [9]:
my_sampler = NUTS3( my_params, my_energy )

In [10]:
test_theta = torch.randn([3,2,1])
test_r     = torch.randn([3,2,1])
print(test_theta)
print(test_r)
print(my_sampler.params.theta)

tensor([[[-1.2053],
         [ 1.4418]],

        [[ 1.6426],
         [-0.1297]],

        [[-0.5519],
         [-0.9181]]])
tensor([[[-0.5853],
         [-0.2071]],

        [[ 0.7339],
         [-1.5180]],

        [[-0.5162],
         [ 1.1002]]])
Parameter containing:
tensor([[[ 0.8659],
         [ 0.8793]],

        [[ 0.2190],
         [ 0.5129]],

        [[-1.2719],
         [ 0.5420]]], requires_grad=True)


In [11]:
result = my_sampler.leapfrog(test_theta, test_r, 1e-3)

In [12]:
print(result)
print(test_theta)
print(test_r)
print(my_sampler.params.theta)

(tensor([[[-1.2059],
         [ 1.4416]],

        [[ 1.6433],
         [-0.1312]],

        [[-0.5524],
         [-0.9170]]]), tensor([[[-0.5851],
         [-0.2074]],

        [[ 0.7336],
         [-1.5180]],

        [[-0.5161],
         [ 1.1003]]]), tensor([0.8249, 0.6870, 0.9639], grad_fn=<MeanBackward1>))
tensor([[[-1.2053],
         [ 1.4418]],

        [[ 1.6426],
         [-0.1297]],

        [[-0.5519],
         [-0.9181]]])
tensor([[[-0.5853],
         [-0.2071]],

        [[ 0.7339],
         [-1.5180]],

        [[-0.5162],
         [ 1.1002]]])
Parameter containing:
tensor([[[-1.2059],
         [ 1.4416]],

        [[ 1.6433],
         [-0.1312]],

        [[-0.5524],
         [-0.9170]]], requires_grad=True)


In [13]:
init_vals = my_sampler.init_trajectory(test_theta)
check_u, check_theta_r, check_r_r, check_theta_f, check_r_f, check_j, check_theta_m, check_n, check_s = init_vals
init_vals

(tensor([0.0421, 0.1706, 0.1485]),
 tensor([[[-1.2053],
          [ 1.4418]],
 
         [[ 1.6426],
          [-0.1297]],
 
         [[-0.5519],
          [-0.9181]]]),
 tensor([[[-0.2866],
          [ 0.1812]],
 
         [[-0.0254],
          [ 0.6422]],
 
         [[ 1.1266],
          [-0.1823]]]),
 tensor([[[-1.2053],
          [ 1.4418]],
 
         [[ 1.6426],
          [-0.1297]],
 
         [[-0.5519],
          [-0.9181]]]),
 tensor([[[-0.2866],
          [ 0.1812]],
 
         [[-0.0254],
          [ 0.6422]],
 
         [[ 1.1266],
          [-0.1823]]]),
 0,
 tensor([[[-1.2053],
          [ 1.4418]],
 
         [[ 1.6426],
          [-0.1297]],
 
         [[-0.5519],
          [-0.9181]]]),
 tensor([1, 1, 1]),
 tensor([1, 1, 1]))

In [14]:
v = torch.randn([1], dtype=torch.float, layout=test_theta.layout, device=test_theta.device) \
                  .ge(0.).mul(2.).add(-1.)
v

tensor([-1.])

In [15]:
(check_theta_f - check_theta_r).shape

torch.Size([3, 2, 1])

In [16]:
check_r_r.shape

torch.Size([3, 2, 1])

In [17]:
torch.einsum('bs,bs->b', (check_theta_f - check_theta_r).flatten(1), check_r_r.flatten(1))

tensor([0., 0., 0.])

In [18]:
print(test_theta)

tensor([[[-1.2053],
         [ 1.4418]],

        [[ 1.6426],
         [-0.1297]],

        [[-0.5519],
         [-0.9181]]])


In [19]:
my_sampler.buildtree(test_theta, test_r, init_vals[0], v, 3, 1e-2)

(tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6572],
          [-0.1600]],
 
         [[-0.5622],
          [-0.8960]]]),
 tensor([[[-0.5805],
          [-0.2129]],
 
         [[ 0.7273],
          [-1.5174]],
 
         [[-0.5139],
          [ 1.1038]]]),
 tensor([[[-1.2111],
          [ 1.4397]],
 
         [[ 1.6499],
          [-0.1449]],
 
         [[-0.5571],
          [-0.9070]]]),
 tensor([[[-0.5829],
          [-0.2100]],
 
         [[ 0.7306],
          [-1.5177]],
 
         [[-0.5151],
          [ 1.1020]]]),
 tensor([[[-1.2111],
          [ 1.4397]],
 
         [[ 1.6499],
          [-0.1449]],
 
         [[-0.5571],
          [-0.9070]]]),
 tensor([2, 0, 2]),
 tensor([0, 0, 0]))

In [20]:
print(test_theta)

tensor([[[-1.2053],
         [ 1.4418]],

        [[ 1.6426],
         [-0.1297]],

        [[-0.5519],
         [-0.9181]]])


In [21]:
my_samples = my_sampler.collect_samples(1e-3, 10)

In [22]:
my_samples

[tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6572],
          [-0.1600]],
 
         [[-0.5622],
          [-0.8960]]]),
 tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6572],
          [-0.1600]],
 
         [[-0.5622],
          [-0.8960]]]),
 tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6572],
          [-0.1600]],
 
         [[-0.5622],
          [-0.8960]]]),
 tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6572],
          [-0.1600]],
 
         [[-0.5622],
          [-0.8960]]]),
 tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6572],
          [-0.1600]],
 
         [[-0.5622],
          [-0.8960]]]),
 tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6572],
          [-0.1600]],
 
         [[-0.5622],
          [-0.8960]]]),
 tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6572],
          [-0.1600]],
 
         [[-0.5622],
          [-0.8960]]]),
 tensor([[[-1.2169],
          [ 1.4376]],
 
         [

In [23]:
my_sampler.params.cuda()

BasicParameters()

In [24]:
my_sampler.buildtree(test_theta.cuda(), test_r.cuda(), init_vals[0].cuda(), v.cuda(), 3, 1e-2)

(tensor([[[-1.2285],
          [ 1.4333]],
 
         [[ 1.6717],
          [-0.1904]],
 
         [[-0.5725],
          [-0.8739]]], device='cuda:0'),
 tensor([[[-0.5756],
          [-0.2186]],
 
         [[ 0.7206],
          [-1.5167]],
 
         [[-0.5117],
          [ 1.1073]]], device='cuda:0'),
 tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6499],
          [-0.1449]],
 
         [[-0.5673],
          [-0.8850]]], device='cuda:0'),
 tensor([[[-0.5829],
          [-0.2100]],
 
         [[ 0.7306],
          [-1.5177]],
 
         [[-0.5151],
          [ 1.1020]]], device='cuda:0'),
 tensor([[[-1.2169],
          [ 1.4376]],
 
         [[ 1.6499],
          [-0.1449]],
 
         [[-0.5673],
          [-0.8850]]], device='cuda:0'),
 tensor([4, 0, 4], device='cuda:0'),
 tensor([0, 0, 0], device='cuda:0'))