In [1]:
import math
import torch
from tqdm.notebook import tqdm
# NOTE(review): overwrites the `locks` list on tqdm's shared lock object with an
# empty list; presumably a workaround for a progress-bar locking problem in this
# notebook environment -- confirm it is still required.
tqdm.get_lock().locks = []

In [2]:
def random_categorical(dim):
    """Draw a random probability vector with ``dim`` entries.

    Logits are sampled with ``torch.randn`` and handed to a ``Categorical``
    distribution; the distribution's ``probs`` tensor is returned.
    """
    logits = torch.randn(dim)
    distribution = torch.distributions.categorical.Categorical(logits=logits)
    return distribution.probs

# Seed the global RNG before sampling so that nu and mu -- and therefore every
# divergence value printed below -- are the same on each Restart & Run All.
torch.manual_seed(0)

dim = 64  # number of categories shared by both distributions
nu = random_categorical(dim)  # weights every expectation sum(nu * ...) below
mu = random_categorical(dim)  # numerator of the ratio mu / nu fed to phi

In [5]:
class Gamma_KL(torch.autograd.Function):
    """Shift ``gamma(f)`` for the KL generator, found with Newton's method.

    ``apply(f, nu)`` returns the scalar ``gamma`` that is the root of
    :meth:`F`, i.e. ``sum(nu * exp(f - gamma)) = 1``.  ``backward`` returns
    ``d gamma / d f`` as ``-F_gradient / F_derivative`` (implicit
    differentiation of ``F(f, gamma) = 0``) and ``None`` for ``nu``.
    :meth:`conjugate` evaluates ``sum(nu * phi_conjugate(f - gamma)) + gamma``.
    """

    @staticmethod
    def init_gamma(f, nu):
        """Starting point of the Newton iteration: the plain mean of ``f``."""
        return torch.mean(f)

    @staticmethod
    def F(f, nu, gamma):
        """Residual ``1 - sum(nu * exp(f - gamma))`` whose root in ``gamma`` is sought."""
        return -torch.sum(nu * torch.exp(f - gamma), dim=-1) + 1

    @staticmethod
    def F_derivative(f, nu, gamma):
        """Derivative of :meth:`F` with respect to ``gamma``."""
        # Reduce over the last axis, exactly like F does.
        return torch.sum(nu * torch.exp(f - gamma), dim=-1)

    @staticmethod
    def F_gradient(f, nu, gamma):
        """Elementwise derivative of :meth:`F` with respect to ``f``."""
        return -nu * torch.exp(f - gamma)

    @staticmethod
    def phi(x):
        """Generator ``x * log(x) - x + 1``, with ``phi(0) = 1`` by continuity."""
        # 0 * log(0) evaluates to 0 * -inf = NaN although the limit is 0.
        # Only the exact-zero entries are patched; negative inputs still
        # produce NaN so that invalid ratios are not silently hidden.
        x_log_x = torch.where(x == 0, torch.zeros_like(x), x * torch.log(x))
        return x_log_x - x + 1

    @staticmethod
    def phi_conjugate(y):
        """Conjugate of :meth:`phi`: ``exp(y) - 1``."""
        return torch.exp(y) - 1

    @staticmethod
    def forward(ctx, f, nu):
        """Solve ``F(f, nu, gamma) = 0`` for ``gamma`` by Newton's method.

        Stops once a Newton step is smaller than ``tol`` or after
        ``max_steps`` iterations; prints a message when the last step was not
        below ``tol``.  The detached inputs and the solution are saved for
        :meth:`backward`.
        """
        f, nu = f.detach(), nu.detach()

        tol = 1e-6
        max_steps = 10000

        gamma = Gamma_KL.init_gamma(f, nu)
        for _ in range(max_steps):
            gamma_prev = gamma
            F = Gamma_KL.F(f, nu, gamma)
            F_derivative = Gamma_KL.F_derivative(f, nu, gamma)
            gamma = gamma_prev - F / F_derivative
            if torch.abs(gamma - gamma_prev) < tol:
                break
        if torch.abs(gamma - gamma_prev) >= tol:
            print("Newton: tolerance not reached")

        ctx.save_for_backward(f, nu, gamma)

        return gamma

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * d gamma / d f`` and ``None`` for ``nu``."""
        with torch.no_grad():
            f, nu, gamma = ctx.saved_tensors
            grad_input_f = -Gamma_KL.F_gradient(f, nu, gamma) / Gamma_KL.F_derivative(f, nu, gamma)
        return grad_input_f * grad_output, None

    @staticmethod
    def conjugate(f, nu):
        """Return ``sum(nu * phi_conjugate(f - gamma)) + gamma`` with ``gamma = apply(f, nu)``."""
        gamma = Gamma_KL.apply(f, nu)
        return torch.sum(nu * Gamma_KL.phi_conjugate(f - gamma), dim=-1) + gamma

In [6]:
# Reference value: the divergence evaluated directly as sum(nu * phi(mu / nu)).
print(f"KL exact: {torch.sum(nu * Gamma_KL.phi(mu / nu)).item():.5f}")

KL exact: 0.72302


In [7]:
# Variational estimate: maximise sum(mu * f) - Gamma_KL.conjugate(f, nu) over
# the vector f with SGD + momentum, starting from f = 0.
iterations = 2000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    objective = torch.sum(mu * f, dim=-1) - Gamma_KL.conjugate(f, nu)
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# Re-evaluate the objective at the final f and report it.
print(f"KL approx: {(torch.sum(mu * f, dim=-1) - Gamma_KL.conjugate(f, nu)).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=2000.0), HTML(value='')), layout=Layout(d…


KL approx: 0.72302


In [8]:
# Same optimisation as the Newton-based KL cell, but with the conjugate term
# written in closed form as log(sum(nu * exp(f))).
iterations = 2000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    objective = torch.sum(mu * f, dim=-1) - torch.log(torch.sum(nu * torch.exp(f)))
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# Re-evaluate the objective at the final f and report it.
print(f"KL approx closed form: {(torch.sum(mu * f, dim=-1) - torch.log(torch.sum(nu * torch.exp(f)))).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=2000.0), HTML(value='')), layout=Layout(d…


KL approx closed form: 0.72302


In [9]:
class Gamma_reverseKL(torch.autograd.Function):
    """Newton solver for the shift ``gamma`` that goes with ``phi(x) = x - 1 - log(x)``.

    ``apply(f, nu)`` returns the root of :meth:`F` in ``gamma``; ``backward``
    supplies ``d gamma / d f`` as ``-F_gradient / F_derivative`` and no
    gradient for ``nu``.
    """

    @staticmethod
    def init_gamma(f, nu):
        """Initial guess ``max(f) - 1 + 0.01``, so every ``1 - f + gamma`` is at least 0.01."""
        return f.max() - 1 + 0.01

    @staticmethod
    def F(f, nu, gamma):
        """Residual ``1 - sum(nu / (1 - f + gamma))`` over the last axis."""
        slack = 1 - f + gamma
        weighted = nu * (1 / slack)
        return -weighted.sum(dim=-1) + 1

    @staticmethod
    def F_derivative(f, nu, gamma):
        """Derivative of :meth:`F` with respect to ``gamma`` (summed over all entries)."""
        slack = 1 - f + gamma
        return (nu * (1 / (slack ** 2))).sum()

    @staticmethod
    def F_gradient(f, nu, gamma):
        """Elementwise derivative of :meth:`F` with respect to ``f``."""
        slack = 1 - f + gamma
        return -nu * (1 / (slack ** 2))

    @staticmethod
    def phi(x):
        """Generator ``x - 1 - log(x)``."""
        return x - 1 - torch.log(x)

    @staticmethod
    def phi_conjugate(y):
        """Conjugate of :meth:`phi`: ``-log(1 - y)``."""
        return -torch.log(1 - y)

    @staticmethod
    def forward(ctx, f, nu):
        """Run Newton's method on :meth:`F` and return the resulting ``gamma``.

        Iterates until a step is shorter than ``tol`` or ``max_steps`` is
        exhausted; a message is printed if the last step was not below ``tol``.
        """
        f = f.detach()
        nu = nu.detach()

        tol, max_steps = 1e-6, 10000
        solver = Gamma_reverseKL

        gamma = solver.init_gamma(f, nu)
        for _ in range(max_steps):
            gamma_prev = gamma
            residual = solver.F(f, nu, gamma_prev)
            slope = solver.F_derivative(f, nu, gamma_prev)
            gamma = gamma_prev - residual / slope
            last_step = torch.abs(gamma - gamma_prev)
            if last_step < tol:
                break
        if last_step >= tol:
            print("Newton: tolerance not reached")

        ctx.save_for_backward(f, nu, gamma)
        return gamma

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * d gamma / d f`` and ``None`` for ``nu``."""
        with torch.no_grad():
            f, nu, gamma = ctx.saved_tensors
            partial_f = Gamma_reverseKL.F_gradient(f, nu, gamma)
            partial_gamma = Gamma_reverseKL.F_derivative(f, nu, gamma)
            grad_input_f = -partial_f / partial_gamma
        return grad_input_f * grad_output, None

    @staticmethod
    def conjugate(f, nu):
        """Return ``sum(nu * phi_conjugate(f - gamma)) + gamma`` with ``gamma = apply(f, nu)``."""
        gamma = Gamma_reverseKL.apply(f, nu)
        integrand = nu * Gamma_reverseKL.phi_conjugate(f - gamma)
        return torch.sum(integrand, dim=-1) + gamma

In [10]:
# Reference value: sum(nu * phi(mu / nu)) with the reverse-KL generator.
print(f"reverse KL exact: {torch.sum(nu * Gamma_reverseKL.phi(mu / nu)).item():.5f}")

reverse KL exact: 0.64295


In [11]:
# Variational estimate for the reverse-KL generator: maximise
# sum(mu * f) - conjugate over f with SGD + momentum, starting from f = 0.
iterations = 10000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    # The conjugate is evaluated at f - max(f) and max(f) is added back afterwards.
    objective = torch.sum(mu * f, dim=-1) - (Gamma_reverseKL.conjugate(f - torch.max(f), nu) + torch.max(f))
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# NOTE(review): the report below calls conjugate(f, nu) directly, without the
# max(f) shift used inside the loop -- confirm the two forms are meant to agree.
print(f"reverse KL approx: {(torch.sum(mu * f, dim=-1) - Gamma_reverseKL.conjugate(f, nu)).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=10000.0), HTML(value='')), layout=Layout(…


reverse KL approx: 0.64273


In [19]:
class Gamma_Chi2(torch.autograd.Function):
    """Newton solver for the shift ``gamma`` that goes with ``phi(x) = (x - 1) ** 2``.

    ``apply(f, nu)`` returns the root of :meth:`F` in ``gamma``; ``backward``
    supplies ``d gamma / d f`` as ``-F_gradient / F_derivative`` and no
    gradient for ``nu``.
    """

    @staticmethod
    def init_gamma(f, nu):
        """Initial guess: the largest entry of ``f``."""
        return f.max()

    @staticmethod
    def F(f, nu, gamma):
        """Residual ``1 - sum(nu * ((f - gamma) / 2 + 1))`` over the last axis."""
        slope_terms = 0.5 * (f - gamma) + 1
        return -(nu * slope_terms).sum(dim=-1) + 1

    @staticmethod
    def F_derivative(f, nu, gamma):
        """Constant ``1/2`` (the derivative of :meth:`F` in ``gamma`` when ``nu`` sums to one)."""
        return 0.5

    @staticmethod
    def F_gradient(f, nu, gamma):
        """Elementwise derivative of :meth:`F` with respect to ``f``: ``-nu / 2``."""
        return -nu * 0.5

    @staticmethod
    def phi(x):
        """Generator ``(x - 1) ** 2``."""
        centred = x - 1
        return centred ** 2

    @staticmethod
    def phi_conjugate(y):
        """Conjugate expression ``y ** 2 / 4 + y``."""
        return 0.25 * (y ** 2) + y

    @staticmethod
    def forward(ctx, f, nu):
        """Run Newton's method on :meth:`F` and return the resulting ``gamma``.

        Iterates until a step is shorter than ``tol`` or ``max_steps`` is
        exhausted; a message is printed if the last step was not below ``tol``.
        """
        f = f.detach()
        nu = nu.detach()

        tol, max_steps = 1e-5, 10000
        solver = Gamma_Chi2

        gamma = solver.init_gamma(f, nu)
        for _ in range(max_steps):
            gamma_prev = gamma
            residual = solver.F(f, nu, gamma_prev)
            slope = solver.F_derivative(f, nu, gamma_prev)
            gamma = gamma_prev - residual / slope
            last_step = torch.abs(gamma - gamma_prev)
            if last_step < tol:
                break
        if last_step >= tol:
            print("Newton: tolerance not reached")

        ctx.save_for_backward(f, nu, gamma)
        return gamma

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * d gamma / d f`` and ``None`` for ``nu``."""
        with torch.no_grad():
            f, nu, gamma = ctx.saved_tensors
            partial_f = Gamma_Chi2.F_gradient(f, nu, gamma)
            partial_gamma = Gamma_Chi2.F_derivative(f, nu, gamma)
            grad_input_f = -partial_f / partial_gamma
        return grad_input_f * grad_output, None

    @staticmethod
    def conjugate(f, nu):
        """Return ``sum(nu * phi_conjugate(f - gamma)) + gamma`` with ``gamma = apply(f, nu)``."""
        gamma = Gamma_Chi2.apply(f, nu)
        integrand = nu * Gamma_Chi2.phi_conjugate(f - gamma)
        return torch.sum(integrand, dim=-1) + gamma

In [20]:
# Reference value: sum(nu * phi(mu / nu)) with the Chi2 generator.
print(f"Chi2 exact: {torch.sum(nu * Gamma_Chi2.phi(mu / nu)).item():.5f}")

Chi2 exact: 2.64669


In [21]:
# Variational estimate for the Chi2 generator: maximise
# sum(mu * f) - conjugate over f with SGD + momentum, starting from f = 0.
iterations = 10000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    # The conjugate is evaluated at f - max(f) and max(f) is added back afterwards.
    objective = torch.sum(mu * f, dim=-1) - (Gamma_Chi2.conjugate(f - torch.max(f), nu) + torch.max(f))
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# NOTE(review): the report below calls conjugate(f, nu) directly, without the
# max(f) shift used inside the loop -- confirm the two forms are meant to agree.
print(f"Chi2 approx: {(torch.sum(mu * f, dim=-1) - Gamma_Chi2.conjugate(f, nu)).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=10000.0), HTML(value='')), layout=Layout(…


Chi2 approx: 2.64669


In [22]:
# Same optimisation as the Newton-based Chi2 cell, but with the conjugate term
# written out explicitly: sum(nu * f) + (1/4) * sum(nu * (f - sum(nu * f)) ** 2).
iterations = 10000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    objective = torch.sum(mu * f, dim=-1) - (torch.sum(nu * f, dim=-1) + (1 / 4) * torch.sum(nu * ((f - torch.sum(nu * f, dim=-1)) ** 2), dim=-1))
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# Re-evaluate the same expression at the final f and report it.
print(f"Chi2 approx closed form: {(torch.sum(mu * f, dim=-1) - (torch.sum(nu * f, dim=-1) + (1 / 4) * torch.sum(nu * ((f - torch.sum(nu * f, dim=-1)) ** 2), dim=-1))).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=10000.0), HTML(value='')), layout=Layout(…


Chi2 approx closed form: 2.64669


In [23]:
class Gamma_reverseChi2(torch.autograd.Function):
    """Newton solver for the shift ``gamma`` that goes with ``phi(x) = 1/x + x - 2``.

    ``apply(f, nu)`` returns the root in ``gamma`` of :meth:`F`, i.e.
    ``sum(nu / sqrt(1 - f + gamma)) = 1``.  ``backward`` supplies
    ``d gamma / d f`` as ``-F_gradient / F_derivative`` (implicit
    differentiation of ``F(f, gamma) = 0``) and ``None`` for ``nu``.
    """

    @staticmethod
    def init_gamma(f, nu):
        """Newton starting point that is inside the domain but not glued to the pole.

        With non-negative ``nu`` every term of
        ``sum(nu / sqrt(1 - f + gamma)) = 1`` is at most 1 at the root, hence
        the root satisfies ``gamma >= f_i - 1 + nu_i ** 2`` for every ``i``.
        The largest of these lower bounds is used.

        Starting only ``1e-7`` above ``max(f) - 1`` (the previous choice) made
        the very first Newton step about ``2e-7``, which is already below the
        step tolerance in :meth:`forward`, so the iteration stopped at the
        start point instead of the root.  That value is kept only as a floor
        so the start stays inside the domain when ``nu`` is zero at the
        largest entry of ``f``.
        """
        pole_guard = torch.max(f) - 1 + 1e-7
        root_bound = torch.max(f - 1 + nu ** 2)
        return torch.max(pole_guard, root_bound)

    @staticmethod
    def F(f, nu, gamma):
        """Residual ``1 - sum(nu / sqrt(1 - f + gamma))`` over the last axis."""
        return -torch.sum(nu * (1 / torch.sqrt(1 - f + gamma)), dim=-1) + 1

    @staticmethod
    def F_derivative(f, nu, gamma):
        """Derivative of :meth:`F` with respect to ``gamma`` (summed over all entries)."""
        return torch.sum(nu * (1 / (2 * (torch.sqrt(1 - f + gamma) ** 3))))

    @staticmethod
    def F_gradient(f, nu, gamma):
        """Elementwise derivative of :meth:`F` with respect to ``f``."""
        return -nu * (1 / (2 * (torch.sqrt(1 - f + gamma) ** 3)))

    @staticmethod
    def phi(x):
        """Generator ``1 / x + x - 2``."""
        return 1 / x + x - 2

    @staticmethod
    def phi_conjugate(y):
        """Conjugate of :meth:`phi`: ``2 - 2 * sqrt(1 - y)``."""
        return 2 - 2 * torch.sqrt(1 - y)

    @staticmethod
    def forward(ctx, f, nu):
        """Solve ``F(f, nu, gamma) = 0`` for ``gamma`` by Newton's method.

        The loop ends when a step is shorter than ``tol`` or the residual at
        the new iterate is exactly zero; both count as convergence.  A message
        is printed only when ``max_steps`` iterations pass without either.
        """
        with torch.no_grad():
            f, nu = f.detach(), nu.detach()

            tol = 1e-6
            max_steps = 10000

            gamma = Gamma_reverseChi2.init_gamma(f, nu)
            converged = False
            for _ in range(max_steps):
                gamma_prev = gamma
                F = Gamma_reverseChi2.F(f, nu, gamma)
                F_derivative = Gamma_reverseChi2.F_derivative(f, nu, gamma)
                gamma = gamma_prev - F / F_derivative
                if torch.abs(gamma - gamma_prev) < tol or Gamma_reverseChi2.F(f, nu, gamma) == 0.0:
                    # An exact-zero residual is a valid stop even if the step
                    # that reached it was longer than tol.
                    converged = True
                    break
            if not converged:
                print("Newton: tolerance not reached")

        ctx.save_for_backward(f, nu, gamma)

        return gamma

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * d gamma / d f`` and ``None`` for ``nu``."""
        with torch.no_grad():
            f, nu, gamma = ctx.saved_tensors
            grad_input_f = -Gamma_reverseChi2.F_gradient(f, nu, gamma) / Gamma_reverseChi2.F_derivative(f, nu, gamma)
        return grad_input_f * grad_output, None

    @staticmethod
    def conjugate(f, nu):
        """Return ``sum(nu * phi_conjugate(f - gamma)) + gamma`` with ``gamma = apply(f, nu)``."""
        gamma = Gamma_reverseChi2.apply(f, nu)
        return torch.sum(nu * Gamma_reverseChi2.phi_conjugate(f - gamma), dim=-1) + gamma

In [24]:
# Reference value: sum(nu * phi(mu / nu)) with the reverse-Chi2 generator.
print(f"reverseChi2 exact: {torch.sum(nu * Gamma_reverseChi2.phi(mu / nu)).item():.5f}")

reverseChi2 exact: 1.78507


In [73]:
# Variational estimate for the reverse-Chi2 generator: maximise
# sum(mu * f) - conjugate over f with SGD + momentum, starting from f = 0.
# This cell runs one million optimisation steps.
iterations = 1000000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    # The conjugate is evaluated at f - max(f) and max(f) is added back afterwards.
    objective = torch.sum(mu * f, dim=-1) - (Gamma_reverseChi2.conjugate(f - torch.max(f), nu) + torch.max(f))
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# Re-evaluate the same shifted objective at the final f and report it.
print(f"reverseChi2 approx: {(torch.sum(mu * f, dim=-1) - (Gamma_reverseChi2.conjugate(f - torch.max(f), nu) + torch.max(f))).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1000000.0), HTML(value='')), layout=Layou…


reverseChi2 approx: 1.78482


In [28]:
class Gamma_Hellinger2(torch.autograd.Function):
    """Newton solver for the shift ``gamma`` that goes with ``phi(x) = (sqrt(x) - 1) ** 2``.

    ``apply(f, nu)`` returns the root of :meth:`F` in ``gamma``; ``backward``
    supplies ``d gamma / d f`` as ``-F_gradient / F_derivative`` and no
    gradient for ``nu``.
    """

    @staticmethod
    def init_gamma(f, nu):
        """Initial guess ``max(f) - 1 + 0.1``, so every ``1 - f + gamma`` is at least 0.1."""
        return f.max() - 1 + 0.1

    @staticmethod
    def F(f, nu, gamma):
        """Residual ``1 - sum(nu / (1 - f + gamma) ** 2)`` over the last axis."""
        slack = 1 - f + gamma
        weighted = nu * (1 / (slack ** 2))
        return -weighted.sum(dim=-1) + 1

    @staticmethod
    def F_derivative(f, nu, gamma):
        """Derivative of :meth:`F` with respect to ``gamma``, summed over the last axis."""
        slack = 1 - f + gamma
        return (nu * (2 / (slack ** 3))).sum(dim=-1)

    @staticmethod
    def F_gradient(f, nu, gamma):
        """Elementwise derivative of :meth:`F` with respect to ``f``."""
        slack = 1 - f + gamma
        return -nu * (2 / (slack ** 3))

    @staticmethod
    def phi(x):
        """Generator ``(sqrt(x) - 1) ** 2``."""
        root = x ** 0.5
        return (root - 1) ** 2

    @staticmethod
    def phi_conjugate(y):
        """Conjugate of :meth:`phi`: ``y / (1 - y)``."""
        return y / (1 - y)

    @staticmethod
    def forward(ctx, f, nu):
        """Run Newton's method on :meth:`F` and return the resulting ``gamma``.

        Iterates until a step is shorter than ``tol`` or ``max_steps`` is
        exhausted; a message is printed if the last step was not below ``tol``.
        """
        f = f.detach()
        nu = nu.detach()

        tol, max_steps = 1e-6, 10000
        solver = Gamma_Hellinger2

        gamma = solver.init_gamma(f, nu)
        for _ in range(max_steps):
            gamma_prev = gamma
            residual = solver.F(f, nu, gamma_prev)
            slope = solver.F_derivative(f, nu, gamma_prev)
            gamma = gamma_prev - residual / slope
            last_step = torch.abs(gamma - gamma_prev)
            if last_step < tol:
                break
        if last_step >= tol:
            print("Newton: tolerance not reached")

        ctx.save_for_backward(f, nu, gamma)
        return gamma

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * d gamma / d f`` and ``None`` for ``nu``."""
        with torch.no_grad():
            f, nu, gamma = ctx.saved_tensors
            partial_f = Gamma_Hellinger2.F_gradient(f, nu, gamma)
            partial_gamma = Gamma_Hellinger2.F_derivative(f, nu, gamma)
            grad_input_f = -partial_f / partial_gamma
        return grad_input_f * grad_output, None

    @staticmethod
    def conjugate(f, nu):
        """Return ``sum(nu * phi_conjugate(f - gamma)) + gamma`` with ``gamma = apply(f, nu)``."""
        gamma = Gamma_Hellinger2.apply(f, nu)
        integrand = nu * Gamma_Hellinger2.phi_conjugate(f - gamma)
        return torch.sum(integrand, dim=-1) + gamma

In [29]:
# Reference value: sum(nu * phi(mu / nu)) with the squared-Hellinger generator.
print(f"Hellinger2 exact: {torch.sum(nu * Gamma_Hellinger2.phi(mu / nu)).item():.5f}")

Hellinger2 exact: 0.32209


In [32]:
# Variational estimate for the squared-Hellinger generator: maximise
# sum(mu * f) - Gamma_Hellinger2.conjugate(f, nu) over f with SGD + momentum,
# starting from f = 0.
iterations = 3000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    objective = torch.sum(mu * f, dim=-1) - Gamma_Hellinger2.conjugate(f, nu)
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# Re-evaluate the objective at the final f and report it.
print(f"Hellinger2 approx: {(torch.sum(mu * f, dim=-1) - Gamma_Hellinger2.conjugate(f, nu)).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=3000.0), HTML(value='')), layout=Layout(d…


Hellinger2 approx: 0.32209


In [33]:
class Gamma_JS(torch.autograd.Function):
    """Newton solver for the shift ``gamma`` of the Jensen-Shannon-type generator.

    The generator is ``phi(x) = x * log(x) - (x + 1) * log((x + 1) / 2)``.
    ``apply(f, nu)`` returns the scalar root of :meth:`F` in ``gamma``;
    ``backward`` supplies ``d gamma / d f`` as ``-F_gradient / F_derivative``
    (implicit differentiation of ``F(f, gamma) = 0``) and ``None`` for ``nu``.
    """

    @staticmethod
    def init_gamma(f, nu):
        """Initial guess ``max(f) - log(2) + 1e-5``, so ``f - gamma`` stays below ``log(2)``."""
        # A Python float keeps gamma a 0-dim tensor with the dtype and device
        # of f.  The former torch.Tensor([2.0]) was a float32 CPU tensor of
        # shape [1], which turned gamma into shape [1] and dragged float64
        # inputs down to float32.
        return torch.max(f) - math.log(2.0) + 0.00001

    @staticmethod
    def F(f, nu, gamma):
        """Residual ``1 - sum(nu / (2 * exp(gamma - f) - 1))`` over the last axis."""
        return -torch.sum(nu * (1 / (2 * torch.exp(gamma - f) - 1)), dim=-1) + 1

    @staticmethod
    def F_derivative(f, nu, gamma):
        """Derivative of :meth:`F` with respect to ``gamma``, summed over the last axis."""
        ratio = torch.exp(f - gamma)
        return torch.sum(nu * ((2 * ratio) / ((ratio - 2) ** 2)), dim=-1)

    @staticmethod
    def F_gradient(f, nu, gamma):
        """Elementwise derivative of :meth:`F` with respect to ``f``."""
        ratio = torch.exp(f - gamma)
        return -nu * ((2 * ratio) / ((ratio - 2) ** 2))

    @staticmethod
    def phi(x):
        """Generator ``x * log(x) - (x + 1) * log((x + 1) / 2)``, with ``phi(0) = log(2)``."""
        # 0 * log(0) evaluates to NaN although the limit is 0; patch only the
        # exact-zero entries so that negative inputs still surface as NaN.
        x_log_x = torch.where(x == 0, torch.zeros_like(x), x * torch.log(x))
        return x_log_x - (x + 1) * torch.log((x + 1) / 2)

    @staticmethod
    def phi_conjugate(y):
        """Conjugate of :meth:`phi`: ``-log(2 - exp(y))``."""
        return -torch.log(2 - torch.exp(y))

    @staticmethod
    def forward(ctx, f, nu):
        """Solve ``F(f, nu, gamma) = 0`` for ``gamma`` by Newton's method.

        Stops once a Newton step is smaller than ``tol`` or after
        ``max_steps`` iterations; prints a message when the last step was not
        below ``tol``.
        """
        f, nu = f.detach(), nu.detach()

        tol = 1e-6
        max_steps = 10000

        gamma = Gamma_JS.init_gamma(f, nu)
        for _ in range(max_steps):
            gamma_prev = gamma
            F = Gamma_JS.F(f, nu, gamma)
            F_derivative = Gamma_JS.F_derivative(f, nu, gamma)
            gamma = gamma_prev - F / F_derivative
            if torch.abs(gamma - gamma_prev) < tol:
                break
        if torch.abs(gamma - gamma_prev) >= tol:
            print("Newton: tolerance not reached")

        ctx.save_for_backward(f, nu, gamma)

        return gamma

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * d gamma / d f`` and ``None`` for ``nu``."""
        with torch.no_grad():
            f, nu, gamma = ctx.saved_tensors
            grad_input_f = -Gamma_JS.F_gradient(f, nu, gamma) / Gamma_JS.F_derivative(f, nu, gamma)
        return grad_input_f * grad_output, None

    @staticmethod
    def conjugate(f, nu):
        """Return ``sum(nu * phi_conjugate(f - gamma)) + gamma`` with ``gamma = apply(f, nu)``."""
        gamma = Gamma_JS.apply(f, nu)
        return torch.sum(nu * Gamma_JS.phi_conjugate(f - gamma), dim=-1) + gamma

In [34]:
# Reference value: sum(nu * phi(mu / nu)) with the JS generator.
print(f"JS exact: {torch.sum(nu * Gamma_JS.phi(mu / nu)).item():.5f}")

JS exact: 0.30584


In [40]:
# Variational estimate for the JS generator: maximise
# sum(mu * f) - Gamma_JS.conjugate(f, nu) over f with SGD + momentum,
# starting from f = 0.
iterations = 2000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    objective = torch.sum(mu * f, dim=-1) - Gamma_JS.conjugate(f, nu)
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# Re-evaluate the objective at the final f and report it.
print(f"JS approx: {(torch.sum(mu * f, dim=-1) - Gamma_JS.conjugate(f, nu)).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=2000.0), HTML(value='')), layout=Layout(d…


JS approx: 0.30584


In [41]:
class LambertW(torch.autograd.Function):
    """Elementwise solution ``w`` of ``w * exp(w) = z``, computed by Newton's method.

    ``forward`` starts from ``log(1 + z)`` and iterates until the largest
    elementwise step is below ``tol``; ``backward`` multiplies the incoming
    gradient by ``dw/dz``.
    """

    @staticmethod
    def forward(ctx, z):
        """Return ``w`` with ``w * exp(w) = z`` for every element of ``z``.

        Prints a message when the largest step of the last iteration was not
        below ``tol`` after ``max_steps`` iterations.
        """
        z = z.detach()

        tol = 1e-5
        max_steps = 10000

        w = torch.log(1 + z)

        for _ in range(max_steps):
            w_prev = w
            # exp(w) is shared by the residual and by its derivative.
            exp_w = torch.exp(w)
            w = w - (w * exp_w - z) / (exp_w + w * exp_w)
            if torch.abs(w - w_prev).max() < tol:
                break
        if torch.abs(w - w_prev).max() >= tol:
            print("LambertW: tolerance not reached")

        ctx.save_for_backward(z, w)

        return w

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * dw/dz``."""
        with torch.no_grad():
            z, w = ctx.saved_tensors
            # Differentiating w * exp(w) = z gives dw/dz = 1 / (exp(w) * (1 + w)).
            # For z != 0 this equals w / (z * (1 + w)), but that form is 0/0
            # (NaN) at z = 0, where the derivative is 1.
            grad_input_z = 1 / (torch.exp(w) * (1 + w))
        return grad_output * grad_input_z


def lambertw(z):
    """Functional shortcut for ``LambertW.apply``: returns its result for ``z``."""
    result = LambertW.apply(z)
    return result
    

class Gamma_Jeffreys(torch.autograd.Function):
    """Newton solver for the shift ``gamma`` that goes with ``phi(x) = (x - 1) * log(x)``.

    All formulas go through ``lambertw(exp(1 - f + gamma))``.  ``apply(f, nu)``
    returns the root of :meth:`F` in ``gamma``; ``backward`` supplies
    ``d gamma / d f`` as ``-F_gradient / F_derivative`` (implicit
    differentiation of ``F(f, gamma) = 0``) and ``None`` for ``nu``.
    """

    @staticmethod
    def init_gamma(f, nu):
        """Starting point of the Newton iteration: ``sum(nu * f)``."""
        return torch.sum(nu * f)

    @staticmethod
    def F(f, nu, gamma):
        """Residual ``1 - sum(nu / W(exp(1 - f + gamma)))`` over the last axis."""
        return -torch.sum(nu * (1 / lambertw(torch.exp(1 - f + gamma))), dim=-1) + 1

    @staticmethod
    def F_derivative(f, nu, gamma):
        """Derivative of :meth:`F` with respect to ``gamma``, summed over the last axis."""
        # lambertw is itself an iterative solve, so evaluate it once and reuse it.
        w = lambertw(torch.exp(1 - f + gamma))
        return torch.sum(nu * ((1 / w) - (1 / (w + 1))), dim=-1)

    @staticmethod
    def F_gradient(f, nu, gamma):
        """Elementwise derivative of :meth:`F` with respect to ``f``."""
        # lambertw is itself an iterative solve, so evaluate it once and reuse it.
        w = lambertw(torch.exp(1 - f + gamma))
        return -nu * ((1 / w) - (1 / (w + 1)))

    @staticmethod
    def phi(x):
        """Generator ``(x - 1) * log(x)``."""
        return (x - 1) * torch.log(x)

    @staticmethod
    def phi_conjugate(y):
        """Conjugate expression ``y + W + 1 / W - 2`` with ``W = lambertw(exp(1 - y))``."""
        # One Lambert-W evaluation feeds both terms.
        w = lambertw(torch.exp(1 - y))
        return y + w + 1 / w - 2

    @staticmethod
    def forward(ctx, f, nu):
        """Solve ``F(f, nu, gamma) = 0`` for ``gamma`` by Newton's method.

        Stops once a Newton step is smaller than ``tol`` or after
        ``max_steps`` iterations; prints a message when the last step was not
        below ``tol``.
        """
        f, nu = f.detach(), nu.detach()

        tol = 1e-6
        max_steps = 10000

        gamma = Gamma_Jeffreys.init_gamma(f, nu)
        for _ in range(max_steps):
            gamma_prev = gamma
            F = Gamma_Jeffreys.F(f, nu, gamma)
            F_derivative = Gamma_Jeffreys.F_derivative(f, nu, gamma)
            gamma = gamma_prev - F / F_derivative
            if torch.abs(gamma - gamma_prev) < tol:
                break
        if torch.abs(gamma - gamma_prev) >= tol:
            print("Newton: tolerance not reached")

        ctx.save_for_backward(f, nu, gamma)

        return gamma

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * d gamma / d f`` and ``None`` for ``nu``."""
        with torch.no_grad():
            f, nu, gamma = ctx.saved_tensors
            grad_input_f = -Gamma_Jeffreys.F_gradient(f, nu, gamma) / Gamma_Jeffreys.F_derivative(f, nu, gamma)
        return grad_input_f * grad_output, None

    @staticmethod
    def conjugate(f, nu):
        """Return ``sum(nu * phi_conjugate(f - gamma)) + gamma`` with ``gamma = apply(f, nu)``."""
        gamma = Gamma_Jeffreys.apply(f, nu)
        return torch.sum(nu * Gamma_Jeffreys.phi_conjugate(f - gamma), dim=-1) + gamma

In [42]:
# Reference value: sum(nu * phi(mu / nu)) with the Jeffreys generator.
print(f"Jeffreys exact: {torch.sum(nu * Gamma_Jeffreys.phi(mu / nu)).item():.5f}")

Jeffreys exact: 1.36596


In [49]:
# Variational estimate for the Jeffreys generator: maximise
# sum(mu * f) - conjugate over f with SGD + momentum, starting from f = 0.
iterations = 30000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    # The conjugate is evaluated at f - max(f) and max(f) is added back afterwards.
    objective = torch.sum(mu * f, dim=-1) - (Gamma_Jeffreys.conjugate(f - torch.max(f), nu) + torch.max(f))
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# NOTE(review): the report below calls conjugate(f, nu) directly, without the
# max(f) shift used inside the loop -- confirm the two forms are meant to agree.
print(f"Jeffreys approx: {(torch.sum(mu * f, dim=-1) - Gamma_Jeffreys.conjugate(f, nu)).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=30000.0), HTML(value='')), layout=Layout(…


Jeffreys approx: 1.36596


In [50]:
class Gamma_Triangular(torch.autograd.Function):
    """Newton solver for the shift ``gamma`` that goes with ``phi(x) = (x - 1) ** 2 / (x + 1)``.

    Every formula is piecewise in ``f - gamma`` with the break at ``-3``:
    entries below it contribute a constant (zero slope), entries at or above
    it use the square-root expressions.  ``apply(f, nu)`` returns the root of
    :meth:`F`; ``backward`` supplies ``d gamma / d f`` as
    ``-F_gradient / F_derivative`` and no gradient for ``nu``.
    """

    @staticmethod
    def init_gamma(f, nu):
        """Initial guess ``max(f) - 1 + 1e-5``."""
        return f.max() - 1 + 0.00001

    @staticmethod
    def F(f, nu, gamma):
        """Residual ``1 - sum(nu * s)`` where ``s`` is ``2 / sqrt(1 - f + gamma) - 1`` or 0."""
        clipped = f - gamma < -3
        active = f - gamma >= -3
        slope = 2 / ((1 - f + gamma) ** (1/2)) - 1
        pieces = 0 * clipped + slope * active
        return -torch.sum(nu * pieces, dim=-1) + 1

    @staticmethod
    def F_derivative(f, nu, gamma):
        """Derivative of :meth:`F` with respect to ``gamma``, summed over the last axis."""
        clipped = f - gamma < -3
        active = f - gamma >= -3
        curvature = 1 / (((1 - f + gamma) ** (1/2)) ** 3)
        pieces = 0 * clipped + curvature * active
        return torch.sum(nu * pieces, dim=-1)

    @staticmethod
    def F_gradient(f, nu, gamma):
        """Elementwise derivative of :meth:`F` with respect to ``f``."""
        clipped = f - gamma < -3
        active = f - gamma >= -3
        curvature = 1 / (((1 - f + gamma) ** (1/2)) ** 3)
        return -nu * (0 * clipped + curvature * active)

    @staticmethod
    def phi(x):
        """Generator ``(x - 1) ** 2 / (x + 1)``."""
        numerator = (x - 1) ** 2
        return numerator / (x + 1)

    @staticmethod
    def phi_conjugate(y):
        """Piecewise conjugate: ``-1`` for ``y < -3``, else ``4 - 4 * sqrt(1 - y) - y``."""
        clipped = y < -3
        active = y >= -3
        smooth = 4 - 4 * ((1 - y) ** (1/2)) - y
        return -1 * clipped + smooth * active

    @staticmethod
    def forward(ctx, f, nu):
        """Run Newton's method on :meth:`F` and return the resulting ``gamma``.

        Iterates until a step is shorter than ``tol`` or ``max_steps`` is
        exhausted; a message is printed if the last step was not below ``tol``.
        """
        f = f.detach()
        nu = nu.detach()

        tol, max_steps = 1e-6, 10000
        solver = Gamma_Triangular

        gamma = solver.init_gamma(f, nu)
        for _ in range(max_steps):
            gamma_prev = gamma
            residual = solver.F(f, nu, gamma_prev)
            slope = solver.F_derivative(f, nu, gamma_prev)
            gamma = gamma_prev - residual / slope
            last_step = torch.abs(gamma - gamma_prev)
            if last_step < tol:
                break
        if last_step >= tol:
            print("Newton: tolerance not reached")

        ctx.save_for_backward(f, nu, gamma)
        return gamma

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * d gamma / d f`` and ``None`` for ``nu``."""
        with torch.no_grad():
            f, nu, gamma = ctx.saved_tensors
            partial_f = Gamma_Triangular.F_gradient(f, nu, gamma)
            partial_gamma = Gamma_Triangular.F_derivative(f, nu, gamma)
            grad_input_f = -partial_f / partial_gamma
        return grad_input_f * grad_output, None

    @staticmethod
    def conjugate(f, nu):
        """Return ``sum(nu * phi_conjugate(f - gamma)) + gamma`` with ``gamma = apply(f, nu)``."""
        gamma = Gamma_Triangular.apply(f, nu)
        integrand = nu * Gamma_Triangular.phi_conjugate(f - gamma)
        return torch.sum(integrand, dim=-1) + gamma

In [51]:
# Reference value: sum(nu * phi(mu / nu)) with the triangular generator.
print(f"Triangular exact: {torch.sum(nu * Gamma_Triangular.phi(mu / nu)).item():.5f}")

Triangular exact: 0.56001


In [54]:
# Variational estimate for the triangular generator: maximise
# sum(mu * f) - Gamma_Triangular.conjugate(f, nu) over f with SGD + momentum,
# starting from f = 0.
iterations = 2000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    objective = torch.sum(mu * f, dim=-1) - Gamma_Triangular.conjugate(f, nu)
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    t.set_description(f"objective: {objective.item():.5f}")

# Re-evaluate the objective at the final f and report it.
print(f"Triangular approx: {(torch.sum(mu * f, dim=-1) - Gamma_Triangular.conjugate(f, nu)).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=2000.0), HTML(value='')), layout=Layout(d…


Triangular approx: 0.56001


In [55]:
def conjugate_tv(f, nu):
    """Conjugate term for the total-variation case with the fixed shift ``gamma = max(f) - 1``.

    Entries of ``f - gamma`` below ``-1`` contribute ``-1``; the remaining
    entries contribute ``f - gamma``.  The ``nu``-weighted sum of these
    contributions plus ``gamma`` is returned.
    """
    gamma = f.max() - 1
    below = (f - gamma < -1).float()
    at_or_above = (f - gamma >= -1).float()
    integrand = -1 * below + (f - gamma) * at_or_above
    return torch.sum(nu * integrand) + gamma

In [56]:
# Reference value: sum(nu * |mu / nu - 1|), evaluated directly.
print(f"TV exact: {torch.sum(nu * torch.abs(mu / nu - 1)).item():.5f}")

TV exact: 0.95936


In [72]:
# Variational estimate for the total-variation case: maximise
# sum(mu * f) - conjugate_tv over f with SGD + momentum, starting from f = 0.
iterations = 50000
f = torch.zeros(dim, requires_grad=True)
optimizer = torch.optim.SGD((f,), lr=1e-1, momentum=0.9)
# Learning-rate multiplier decays linearly from 1 towards 0 over the run; the
# lambda reads the global `iterations` each time it is called.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda i: 1 - i / iterations)
t = tqdm(range(iterations), ncols='100%')
for _ in t:
    optimizer.zero_grad()
    # The conjugate is evaluated at f - max(f) and max(f) is added back afterwards.
    objective = torch.sum(mu * f, dim=-1) - (conjugate_tv(f - torch.max(f), nu) + torch.max(f))
    # The optimizer minimises, so ascend on the objective via its negation.
    loss = -objective
    loss.backward()
    optimizer.step()
    scheduler.step()
    t.set_description(f"objective: {objective.item():.5f}")

# NOTE(review): the report below calls conjugate_tv(f, nu) directly, without the
# max(f) shift used inside the loop -- confirm the two forms are meant to agree.
print(f"TV approx: {(torch.sum(mu * f, dim=-1) - conjugate_tv(f, nu)).item():.5f}")

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=50000.0), HTML(value='')), layout=Layout(…


TV approx: 0.95933
