In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
!pip install torchsde
!pip install torch==2.0.1+cu118 --index-url https://download.pytorch.org/whl/cu118

Collecting torchsde
  Downloading torchsde-0.2.6-py3-none-any.whl (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.2/61.2 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
Collecting trampoline>=0.1.2 (from torchsde)
  Downloading trampoline-0.1.2-py3-none-any.whl (5.2 kB)
Installing collected packages: trampoline, torchsde
Successfully installed torchsde-0.2.6 trampoline-0.1.2

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Looking in indexes: https://download.pytorch.org/whl/cu118

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [12]:
import matplotlib.pyplot as plt
import torch
import torchsde
import numpy as np

In [52]:
        
def estimate_grad_Rt(x, t, beta=1.0, num_mc_samples=100):
    def grad_fxn(_x, _t):
        repeated_x = _x.unsqueeze(0).repeat_interleave(num_mc_samples, dim=0)
        repeated_t = _t.unsqueeze(0).repeat_interleave(num_mc_samples, dim=0)
        h_t = beta * repeated_t
        
        # samples = repeated_x + torch.normal(torch.zeros_like(repeated_x), h_t.unsqueeze(1) ** 0.5)
        samples = repeated_x + torch.randn(repeated_x.shape).to(device) * h_t.unsqueeze(1) ** 0.5
        log_rewards = target_distrib.energy(samples)#unnormalize(samples, global_mins, global_maxs))
        
        #import pdb; pdb.set_trace()
        return torch.logsumexp(log_rewards, dim=-1) - np.log(num_mc_samples)
    
    t = t * torch.ones(x.shape[0]).to(device)
    est_scores = torch.vmap(torch.func.grad(grad_fxn), randomness='different')(x, t)
    print(est_scores.max(-1)[0].mean())
    return est_scores#.clamp(min=-100.0, max=100.0)

def estimate_grad_Rt_vp(x, t, num_mc_samples=100):
    def grad_fxn(_x, _t):
        h_t = _t.repeat_interleave(num_mc_samples).unsqueeze(1)# / global_maxs
        repeated_x = _x.unsqueeze(0).repeat_interleave(num_mc_samples, dim=0)
        repeated_t = _t.unsqueeze(0).repeat_interleave(num_mc_samples, dim=0)
        
        samples = (repeated_x / (1 - repeated_t).sqrt().unsqueeze(1)) + (torch.normal(torch.zeros_like(repeated_x), h_t.sqrt()).to(device))
        log_rewards = target_distrib.log_prob(samples)#unnormalize(samples, global_mins, global_maxs))
        
        return torch.logsumexp(log_rewards, dim=-1) - np.log(num_mc_samples)
    
    t = t * torch.ones(x.shape[0]).to(device)
    est_scores = torch.vmap(torch.func.grad(grad_fxn), randomness='different')(x, t)
    return est_scores.clamp(min=-100.0, max=100.0)

def true_Rt(x,  h_t):
    target_distrib.convolve(h_t[0])# * (50 ** 2))
    energy = target_distrib.log_prob(x)#unnormalize(x, global_mins, global_maxs))
    target_distrib.reset()
    return energy

def true_Rt_vp(x, h_t):
    target_distrib.convolve(h_t)
    energy = target_distrib.log_prob(x / (1 - h_t).sqrt())#unnormalize(x / (1 - t).sqrt(), global_mins, global_maxs))
    target_distrib.reset()
    return energy

def true_grad_Rt(x, h_t):
    with torch.enable_grad():
        x.requires_grad_()
        samples_energy = true_Rt(x, h_t)
        true_scores = torch.autograd.grad(samples_energy.sum(), x, retain_graph=True)[0].detach()
    return true_scores.detach()

def true_grad_Rt_vp(x, h_t, global_mins=-50, global_maxs=50):
    samples_energy = true_Rt_vp(x, h_t)
    true_scores = torch.autograd.grad(samples_energy.sum(), x, retain_graph=True)[0].detach()
    return true_scores
        
def reward_matching_loss(target, vectorfield, constant_noise_scale, x, t):     
    t = t.unsqueeze(1)
    h_t = constant_noise_scale * t
    
    x.requires_grad = True

    # noisy sample: x(t)
    noisy_x = torch.normal(x, h_t ** 0.5)
    pred_scores = vectorfield(noisy_x, t.squeeze())

    estimated_scores = estimate_grad_Rt(noisy_x, t.squeeze(1), beta, 100)
    return ((estimated_scores - pred_scores)**2).mean()

In [30]:
def w1(x):
    return ((2 * np.pi * x[..., 0]) / 4.0).sin()

def w2(x):
    exp_arg = -0.5 * ((x[..., 0] - 1) / 0.6).pow(2)
    return 3 * exp_arg.exp()

def w3(x):
    def sigmoid(v):
        return 1 / (1 + (-v).exp())
    
    return 3 * sigmoid((x[..., 0] - 1) / 0.3)

def parenthesis_energy(x):
    two_norm = torch.linalg.norm(x, ord=2, dim=-1)
    first_term = 0.5 * ((two_norm - 2) / 0.4).pow(2)
    
    first_exp_arg_term = -0.5 * ((x[..., 0] - 2) / 0.6).pow(2)
    second_exp_arg_term = -0.5 * ((x[..., 0] + 2) / 0.6).pow(2)
    
    ln_term = torch.stack([first_exp_arg_term, second_exp_arg_term]).logsumexp(0)
    
    return first_term - ln_term

def contiguous_squiglies_energy(x):
    return 0.5 * ((x[..., 1] - w1(x)) / 0.4).pow(2)

def middle_divergent_squiglies_energy(x):
    first_exp_arg = -0.5 * ((x[..., 1] - w1(x)) / 0.35).pow(2)
    scnd_exp_arg = -0.5 * ((x[..., 1] - w1(x) + w2(x)) / 0.35).pow(2)
    
    return -torch.stack([first_exp_arg, scnd_exp_arg]).logsumexp(0)

def end_divergent_squiglies_energy(x):
    first_exp_arg = -0.5 * ((x[..., 1] - w1(x)) / 0.4).pow(2)
    scnd_exp_arg = -0.5 * ((x[..., 1] - w1(x) + w3(x)) / 0.35).pow(2)
    
    return -torch.stack([first_exp_arg, scnd_exp_arg]).logsumexp(0)

class EnergyDistribution:
    def __init__(self, energy_fxn):
        self.energy_fxn = energy_fxn
        
    def energy(self, x):
        return -self.energy_fxn(x)
    
class VEReverseSDE(torch.nn.Module):
    noise_type = 'diagonal'
    sde_type = 'ito'
    
    def __init__(self, score_net):
        super().__init__()
        self.score_model = score_net
    
    def f(self, t, x):
        # pdb.set_trace()
        t = t.repeat(len(x)).to(x.device)
        score = self.score_model(x, 1-t)
        return self.g(t, x).pow(2) * score 
    
    def g(self, t, x):
        return torch.full_like(x, beta ** 0.5) #+ 0.1

In [46]:
target_distrib = EnergyDistribution(contiguous_squiglies_energy)

In [53]:
num_mc_samples = 10
device = torch.device('cpu')

def est_ve_score_wrapper(x, t):
    with torch.enable_grad():
        return estimate_grad_Rt(x, t, beta=beta, num_mc_samples=1000)

ve_sde_est_score = VEReverseSDE(est_ve_score_wrapper)

beta = 40.0
num_samples = 1000
x1_samples = torch.randn((num_samples, 2)) * beta

In [None]:
t = torch.linspace(0.0, 1.0, 500)
with torch.no_grad():
    est_samples = torchsde.sdeint(ve_sde_est_score, x1_samples, t, method='euler')

tensor(60.9182)
tensor(59.0613)
tensor(51.8459)
tensor(32.6450)
tensor(27.6173)
tensor(20.8187)
tensor(17.2524)
tensor(11.5707)
tensor(9.7494)
tensor(6.6325)
tensor(6.5244)
tensor(5.0616)
tensor(4.0987)
tensor(3.4386)
tensor(3.4342)
tensor(3.2770)
tensor(2.6044)
tensor(2.5946)
tensor(2.5443)
tensor(2.0609)
tensor(1.9744)
tensor(1.7316)
tensor(1.8941)
tensor(1.7207)
tensor(1.6402)
tensor(1.5335)
tensor(1.6300)
tensor(1.5883)
tensor(1.6327)
tensor(1.2260)
tensor(1.4389)
tensor(1.3530)
tensor(1.3245)
tensor(1.1856)
tensor(1.2036)
tensor(0.9563)
tensor(1.2796)
tensor(1.3048)
tensor(1.2535)
tensor(1.2163)
tensor(1.0982)
tensor(1.0516)
tensor(1.4005)
tensor(1.0585)
tensor(0.9413)
tensor(1.0245)
tensor(1.0503)
tensor(1.0210)
tensor(0.9272)
tensor(1.0064)
tensor(0.9507)
tensor(0.9093)
tensor(0.8861)
tensor(1.0306)
tensor(0.8855)
tensor(0.9042)
tensor(0.7764)
tensor(0.9641)
tensor(0.7415)
tensor(0.7917)
tensor(0.8518)
tensor(0.8155)
tensor(0.8421)
tensor(0.7927)
tensor(0.8528)
tensor(0.9020)
te

In [None]:
plt.scatter(*est_samples[-1].T.detach().cpu())