In [1]:
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as ag
from torch.distributions.categorical import Categorical


In [2]:
from parameters import BasicParameters
from nuts import NUTS3

In [3]:
# Build the parameter container: a trainable nn.Parameter of shape (3, 2, 1),
# passed positionally, plus random left/right flank tensors of shape (3, 2, 2).
# NOTE(review): no RNG seed is set in any visible cell, so these values (and
# every recorded output that depends on them) change on each fresh kernel run.
# Consider calling torch.manual_seed(...) once near the top of the notebook.
my_params = BasicParameters(
    nn.Parameter(torch.randn([3,2,1])),
    left_flank=torch.randn([3,2,2]),
    right_flank=torch.randn([3,2,2]) 
)

In [4]:
# Display the tensor produced by the container's forward pass. The recorded
# output has shape (3, 2, 5) and a Cat grad_fn, i.e. the (3, 2, 1) parameter
# combined with the two (3, 2, 2) flanks along the last dimension.
# NOTE(review): the ordering of the pieces is defined in BasicParameters — confirm there.
my_params.forward()

tensor([[[-0.1142, -0.4008, -0.1074,  0.4499, -0.4895],
         [-0.0029,  0.8730, -2.8936,  0.5281, -1.0204]],

        [[ 1.5358,  1.3369,  0.1343, -0.5221,  0.2896],
         [ 1.0553, -0.2518,  0.1042,  0.8339, -0.8317]],

        [[-0.6116, -0.5087, -0.5341, -0.1156, -0.7292],
         [ 1.3639,  1.3699, -2.0414, -0.3959,  1.5372]]],
       grad_fn=<CatBackward>)

In [5]:
def my_energy(in_tensor):
    """Return the mean squared value of each batch entry.

    Every element of ``in_tensor`` is squared and the result is averaged over
    dimensions 1 and 2, leaving one value per index along dimension 0.
    """
    squared = torch.pow(in_tensor, 2)
    return torch.mean(squared, dim=(1, 2))

In [6]:
# Sanity check: evaluate the energy on the assembled tensor. The recorded output
# holds one value per entry along dimension 0 (3 values) and still carries a grad_fn.
my_energy( my_params.forward() )

tensor([1.1083, 0.7096, 1.1887], grad_fn=<MeanBackward1>)

In [7]:
# Create the sampler from the parameter container and the energy function.
# NUTS3 is project-local (imported from `nuts`); its constructor contract is not visible here.
my_sampler = NUTS3( my_params, my_energy )

In [8]:
# Random test inputs with the same shape as the sampler's parameter, (3, 2, 1).
# Named theta / r — presumably position and momentum; confirm against NUTS3.
test_theta = torch.randn([3,2,1])
test_r     = torch.randn([3,2,1])
# Print the inputs and the sampler's current parameter as a baseline, so later
# cells can show which of them a sampler call modifies.
print(test_theta)
print(test_r)
print(my_sampler.params.theta)

tensor([[[-0.2118],
         [-0.5361]],

        [[-0.4070],
         [-0.7517]],

        [[-2.0998],
         [-0.1409]]])
tensor([[[-1.8698],
         [-0.5579]],

        [[-0.0570],
         [ 0.5207]],

        [[-0.0685],
         [ 1.4281]]])
Parameter containing:
tensor([[[-0.1142],
         [-0.0029]],

        [[ 1.5358],
         [ 1.0553]],

        [[-0.6116],
         [ 1.3639]]], requires_grad=True)


In [9]:
# Run a single leapfrog call from (test_theta, test_r); 1e-3 is presumably the
# step size — confirm against NUTS3.leapfrog. Per the recorded output of the
# following cell, the result is a 3-tuple: two (3, 2, 1) tensors and a
# length-3 tensor that carries a grad_fn.
result = my_sampler.leapfrog(test_theta, test_r, 1e-3)

In [10]:
# Compare against the baseline printed before the leapfrog call. In the recorded
# output, test_theta and test_r are unchanged, while my_sampler.params.theta now
# equals the first element of `result`: leapfrog overwrote the sampler's
# parameter as a side effect, so later cells depend on this call having run.
print(result)
print(test_theta)
print(test_r)
print(my_sampler.params.theta)

(tensor([[[-0.2137],
         [-0.5367]],

        [[-0.4071],
         [-0.7512]],

        [[-2.0998],
         [-0.1395]]]), tensor([[[-1.8698],
         [-0.5578]],

        [[-0.0569],
         [ 0.5208]],

        [[-0.0681],
         [ 1.4282]]]), tensor([1.1403, 0.4353, 1.4082], grad_fn=<MeanBackward1>))
tensor([[[-0.2118],
         [-0.5361]],

        [[-0.4070],
         [-0.7517]],

        [[-2.0998],
         [-0.1409]]])
tensor([[[-1.8698],
         [-0.5579]],

        [[-0.0570],
         [ 0.5207]],

        [[-0.0685],
         [ 1.4281]]])
Parameter containing:
tensor([[[-0.2137],
         [-0.5367]],

        [[-0.4071],
         [-0.7512]],

        [[-2.0998],
         [-0.1395]]], requires_grad=True)


In [11]:
# Initialise a trajectory at test_theta and unpack the 9-tuple it returns.
# NOTE(review): the check_* names are this notebook's labels for the tuple
# positions (presumably slice variable u, reverse/forward theta and r, depth j,
# candidate theta_m, count n and stop flag s) — confirm against NUTS3.init_trajectory.
init_vals = my_sampler.init_trajectory(test_theta)
check_u, check_theta_r, check_r_r, check_theta_f, check_r_f, check_j, check_theta_m, check_n, check_s = init_vals
# Bare expression so the notebook displays the whole tuple.
init_vals

(tensor([0.0023, 0.0806, 0.1081]),
 tensor([[[-0.2118],
          [-0.5361]],
 
         [[-0.4070],
          [-0.7517]],
 
         [[-2.0998],
          [-0.1409]]]),
 tensor([[[ 3.0937],
          [-0.3521]],
 
         [[ 0.3903],
          [-0.0836]],
 
         [[-0.3162],
          [-0.8425]]]),
 tensor([[[-0.2118],
          [-0.5361]],
 
         [[-0.4070],
          [-0.7517]],
 
         [[-2.0998],
          [-0.1409]]]),
 tensor([[[ 3.0937],
          [-0.3521]],
 
         [[ 0.3903],
          [-0.0836]],
 
         [[-0.3162],
          [-0.8425]]]),
 0,
 tensor([[[-0.2118],
          [-0.5361]],
 
         [[-0.4070],
          [-0.7517]],
 
         [[-2.0998],
          [-0.1409]]]),
 tensor([1, 1, 1]),
 tensor([1, 1, 1]))

In [12]:
# Draw a random direction of +1 or -1: sample one standard-normal value on
# test_theta's layout/device, test it against zero, then map the true/false
# result to +1/-1 via (x * 2) - 1.
v = torch.randn([1], dtype=torch.float, layout=test_theta.layout, device=test_theta.device) \
                  .ge(0.).mul(2.).add(-1.)
# Bare expression so the notebook displays the drawn direction.
v

tensor([-1.])

In [13]:
# Shape check on the difference of the two theta endpoints returned by init_trajectory.
(check_theta_f - check_theta_r).shape

torch.Size([3, 2, 1])

In [14]:
# Shape check on check_r_r; the recorded output matches the theta difference shape, (3, 2, 1).
check_r_r.shape

torch.Size([3, 2, 1])

In [15]:
# Per-batch dot product: flatten everything after dimension 0, then contract
# (theta_f - theta_r) with r_r over the flattened axis, giving one scalar per batch entry.
# The recorded result is all zeros because init_trajectory returned identical
# theta_f and theta_r. Presumably a prototype of the sampler's U-turn test — confirm.
torch.einsum('bs,bs->b', (check_theta_f - check_theta_r).flatten(1), check_r_r.flatten(1))

tensor([0., 0., 0.])

In [16]:
# Re-print test_theta before calling buildtree; the recorded values match the original draw.
print(test_theta)

tensor([[[-0.2118],
         [-0.5361]],

        [[-0.4070],
         [-0.7517]],

        [[-2.0998],
         [-0.1409]]])


In [17]:
# Call buildtree on CPU from (test_theta, test_r), passing the first element of
# init_vals (unpacked earlier as check_u) and the direction v.
# NOTE(review): 3 and 1e-2 are presumably tree depth and step size — confirm
# against NUTS3.buildtree. The recorded output is a 7-tuple: five (3, 2, 1)
# tensors followed by two length-3 integer tensors.
my_sampler.buildtree(test_theta, test_r, init_vals[0], v, 3, 1e-2)

(tensor([[[-0.3239],
          [-0.5694]],
 
         [[-0.4103],
          [-0.7202]],
 
         [[-2.1031],
          [-0.0552]]]),
 tensor([[[-1.8666],
          [-0.5513]],
 
         [[-0.0521],
          [ 0.5295]],
 
         [[-0.0433],
          [ 1.4293]]]),
 tensor([[[-0.2866],
          [-0.5583]],
 
         [[-0.4087],
          [-0.7360]],
 
         [[-2.1004],
          [-0.1266]]]),
 tensor([[[-1.8694],
          [-0.5568]],
 
         [[-0.0561],
          [ 0.5222]],
 
         [[-0.0643],
          [ 1.4284]]]),
 tensor([[[-0.2866],
          [-0.5583]],
 
         [[-0.4087],
          [-0.7360]],
 
         [[-2.1004],
          [-0.1266]]]),
 tensor([6, 6, 0]),
 tensor([0, 0, 0]))

In [18]:
# Re-print test_theta after buildtree; the recorded values are unchanged, so
# buildtree did not modify its theta argument in place.
print(test_theta)

tensor([[[-0.2118],
         [-0.5361]],

        [[-0.4070],
         [-0.7517]],

        [[-2.0998],
         [-0.1409]]])


In [19]:
# Collect samples from the sampler.
# NOTE(review): 1e-3 and 10 are presumably step size and number of samples —
# confirm against NUTS3.collect_samples.
my_samples = my_sampler.collect_samples(1e-3, 10)

In [20]:
# Display the collected samples (a list of (3, 2, 1) tensors in the recorded output).
# NOTE(review): every sample shown is identical, and the recorded output is cut
# off partway through the list. Identical entries may mean the list holds
# references to one tensor that is updated in place, or that no proposal was
# accepted — verify in collect_samples.
my_samples

[tensor([[[-0.3239],
          [-0.5694]],
 
         [[-0.4103],
          [-0.7202]],
 
         [[-2.1031],
          [-0.0552]]]),
 tensor([[[-0.3239],
          [-0.5694]],
 
         [[-0.4103],
          [-0.7202]],
 
         [[-2.1031],
          [-0.0552]]]),
 tensor([[[-0.3239],
          [-0.5694]],
 
         [[-0.4103],
          [-0.7202]],
 
         [[-2.1031],
          [-0.0552]]]),
 tensor([[[-0.3239],
          [-0.5694]],
 
         [[-0.4103],
          [-0.7202]],
 
         [[-2.1031],
          [-0.0552]]]),
 tensor([[[-0.3239],
          [-0.5694]],
 
         [[-0.4103],
          [-0.7202]],
 
         [[-2.1031],
          [-0.0552]]]),
 tensor([[[-0.3239],
          [-0.5694]],
 
         [[-0.4103],
          [-0.7202]],
 
         [[-2.1031],
          [-0.0552]]]),
 tensor([[[-0.3239],
          [-0.5694]],
 
         [[-0.4103],
          [-0.7202]],
 
         [[-2.1031],
          [-0.0552]]]),
 tensor([[[-0.3239],
          [-0.5694]],
 
         [

In [21]:
# Move the sampler's parameter container to the GPU.
# NOTE(review): this cell and the next require a CUDA device; on a CPU-only
# machine Restart & Run All stops here. Consider guarding on torch.cuda.is_available().
my_sampler.params.cuda()

BasicParameters()

In [22]:
# Repeat the buildtree call with every tensor argument moved to the GPU; the
# recorded outputs all report device='cuda:0'.
# NOTE(review): the recorded values differ from the CPU call — confirm whether
# that comes from sampler state changed by earlier cells or from randomness
# inside buildtree.
my_sampler.buildtree(test_theta.cuda(), test_r.cuda(), init_vals[0].cuda(), v.cuda(), 3, 1e-2)

(tensor([[[-0.2492],
          [-0.5472]],
 
         [[-0.4082],
          [-0.7412]],
 
         [[-2.1011],
          [-0.1123]]], device='cuda:0'),
 tensor([[[-1.8689],
          [-0.5558]],
 
         [[-0.0553],
          [ 0.5237]],
 
         [[-0.0601],
          [ 1.4286]]], device='cuda:0'),
 tensor([[[-0.2305],
          [-0.5417]],
 
         [[-0.4076],
          [-0.7465]],
 
         [[-2.1004],
          [-0.1266]]], device='cuda:0'),
 tensor([[[-1.8694],
          [-0.5568]],
 
         [[-0.0561],
          [ 0.5222]],
 
         [[-0.0643],
          [ 1.4284]]], device='cuda:0'),
 tensor([[[-0.2305],
          [-0.5417]],
 
         [[-0.4076],
          [-0.7465]],
 
         [[-2.1004],
          [-0.1266]]], device='cuda:0'),
 tensor([2, 2, 0], device='cuda:0'),
 tensor([0, 0, 0], device='cuda:0'))