In [None]:
import torch
import matplotlib.pyplot as plt

from torch.distributions import MixtureSameFamily, Categorical

from distributions import CircularProjectedNormal

# Seed torch's global RNG so the sampled training data and the random
# parameter initialisation are identical on every Restart & Run All.
# The trailing semicolon suppresses the Generator repr in the cell output.
torch.manual_seed(42);

# Optimization

In [None]:
import numpy as np

import torch.nn as nn
import torch.optim as O
import torch.distributions as D

## Generate Training Data

### Define Ground Truth Distribution

In [None]:
# Ground-truth location vectors, one row per mixture component.
mu = torch.tensor([[-2.19, -2.09], [-0.19, 2.09]])

# Ground-truth scale parameters.
# Naming: sig<axis><component> is a standard deviation, rho<component> a correlation.
sig11, sig21, rho1 = .9, .4, .5
sig12, sig22, rho2 = .58, .4, -.84

# Off-diagonal covariance entries (rho * sigma_1 * sigma_2) for each component.
offdiag1 = rho1*sig11*sig21
offdiag2 = rho2*sig12*sig22

# One 2x2 covariance matrix per component, stacked along the first axis.
sigma = torch.tensor([
    [[sig11**2, offdiag1], [offdiag1, sig21**2]],
    [[sig12**2, offdiag2], [offdiag2, sig22**2]],
])

# Equal-weight, two-component mixture of circular projected normals.
mix = Categorical(torch.ones(2))
comp = CircularProjectedNormal(mu, sigma)
true_dist = MixtureSameFamily(mix, comp)

### Sample Ground Truth Distribution

In [None]:
# Draw 500 planar Gaussian samples per ground-truth component
# (component 0 first, then component 1, so the RNG stream order is X then Y).
X, Y = (
    D.MultivariateNormal(m, s).sample((500,))
    for m, s in zip(mu, sigma)
)

# Radially project every sample onto the unit circle by dividing each row
# by its own Euclidean norm.
U, V = (pts / pts.norm(dim=1, keepdim=True) for pts in (X, Y))

### Plot

In [None]:
# Two side-by-side panels: planar samples (left) and their projections (right).
fig, (ax_plane, ax_circle) = plt.subplots(1, 2, figsize=(8, 4))

# Left: raw Gaussian samples, with dotted lines marking the coordinate axes.
for pts in (X, Y):
    ax_plane.plot(pts[:, 0], pts[:, 1], '.')
ax_plane.axvline(0, ls=':')
ax_plane.axhline(0, ls=':')

# Right: the same samples after projection onto the unit circle.
for pts in (U, V):
    ax_circle.plot(pts[:, 0], pts[:, 1], '.')
ax_circle.set_xlim([-1.1, 1.1])
ax_circle.set_ylim([-1.1, 1.1])

## Define Variables

### Hyperparameters

In [None]:
# number of optimisation steps run by the training loop
num_epochs = 4048
# number of mixture components in the fitted model (sizes every parameter tensor)
n_components = 2

### Decision Variables

In [None]:
# Random initial location vectors, one row per mixture component.
loc = torch.randn(n_components, 2)

# Normalise each ROW to unit length. keepdim=True keeps the norms as an
# (n_components, 1) column; without it the (n_components,) vector broadcasts
# across columns, which silently mis-scales the rows when n_components == 2
# and raises a shape error for any other number of components.
param_loc = nn.Parameter(loc / loc.norm(dim=1, keepdim=True))

# Per-component scalars used by the training loop to build each covariance
# matrix as [[sig**2 + gam**2, gam], [gam, 1]].
param_sig = nn.Parameter(torch.ones(n_components))
param_gam = nn.Parameter(torch.zeros(n_components))

## Define Optimizer

In [None]:
# Adam over all three trainable tensors, treated as a single parameter group.
optimizer = O.Adam(
    [param_loc, param_sig, param_gam],
    lr=1e-3,
)

In [None]:
# Pool the projected samples of both components into one training set.
# The first half is rebuilt from X instead of reading U back, so re-running
# this cell does not keep appending V to an already-stacked U.
U = torch.vstack((X / X.norm(dim=1, keepdim=True), V))

## Optimize Parameters

In [None]:
# Per-epoch history:
#   column 0 - mean log-likelihood of the training data under the fitted model
#   column 1 - KL(true || fitted) estimated on an angular grid
hist = np.zeros((num_epochs, 2))

# Evaluation grid: unit vectors at 1000 angles spanning [0, 2*pi].
theta = torch.linspace(0, 2*torch.pi, steps = 1000)
xy = torch.stack((torch.cos(theta), torch.sin(theta))).T

# Loop invariants, built once instead of on every epoch.
ones = torch.ones(n_components)
mix = Categorical(torch.ones(n_components))  # equal mixing weights

# The ground-truth distribution is fixed, so its density on the grid is
# evaluated a single time.
with torch.no_grad():
    lnp = true_dist.log_prob(xy)
    p = lnp.exp()

for epoch in range(num_epochs):
    # zero gradients
    optimizer.zero_grad()

    # Per-component matrix [[sig^2 + gam^2, gam], [gam, 1]].
    # The nested stack yields shape (2, 2, n_components); permute moves the
    # component axis to the front -> (n_components, 2, 2).
    # (For an isotropic baseline, S can be replaced by stacked identity matrices.)
    S = torch.stack((
        torch.stack((param_sig**2 + param_gam**2, param_gam)),
        torch.stack((param_gam, ones))
    )).permute(2,0,1)

    # define the fitted mixture for the current parameter values
    comp = CircularProjectedNormal(param_loc, S)
    dist = MixtureSameFamily(mix, comp)

    # loss: negative mean log-likelihood of the training directions
    loss = -dist.log_prob(U).mean()

    # Diagnostic only (no gradients): KL(true || fitted) as the integral of
    # p * (ln p - ln q) over the angle, via the trapezoidal rule on the grid.
    # A plain .mean() would drop the d(theta) weight and be off by ~2*pi.
    # NOTE(review): assumes log_prob is a density with respect to the angle
    # in radians - confirm against CircularProjectedNormal.
    with torch.no_grad():
        lnq = dist.log_prob(xy)
        kl = torch.trapezoid(p*(lnp - lnq), theta)

    loss.backward()
    optimizer.step()

    # values recorded here are for the parameters *before* this epoch's step
    hist[epoch, 0] = -loss.item()
    hist[epoch, 1] = kl.item()

# Training curves: log-likelihood on top, KL divergence below.
fig, (ax_ll, ax_kl) = plt.subplots(2, 1, sharex=True)
ax_ll.plot(hist[:,0])
ax_ll.set_ylabel('Mean log-likelihood')
ax_kl.plot(hist[:,1])
ax_kl.set_ylabel('KL(true || fitted)')
ax_kl.set_xlabel('Epoch')
plt.show()

In [None]:
# display the fitted location parameters (one row per mixture component)
param_loc

In [None]:
# Angular grid on [0, 2*pi] and the corresponding points on the unit circle.
theta = torch.linspace(0, 2*torch.pi, steps=1000)
xy = torch.stack((theta.cos(), theta.sin()), dim=1)

# Evaluate both densities on the grid; no gradients are needed for plotting.
with torch.no_grad():
    p = true_dist.log_prob(xy).exp()  # ground-truth likelihood
    q = dist.log_prob(xy).exp()       # fitted likelihood

# Overlay the two curves against the angle.
ax_cmp = plt.gca()
ax_cmp.plot(theta, p, 'k', label='Ground Truth Distribution')
ax_cmp.plot(theta, q, 'r:', label='Estimated Distribution')
ax_cmp.set_xlabel('Angle [rad]')
ax_cmp.set_ylabel('Likelihood')
ax_cmp.legend()

In [None]:
# Polar view of both densities. The radius is 1 + density, so zero density
# lies on the unit circle and the curves bulge outward where mass concentrates.
polar_ax = plt.axes(projection='polar')
polar_ax.plot(theta, 1 + p)
polar_ax.plot(theta, 1 + q)