Test an implementation of the mixture density network.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'



In [None]:
import numpy as np
import torch
import torch.distributions as dists
import torch.nn as nn

import kcgof
import kcgof.log as klog
import kcgof.util as util
import kcgof.cdensity as cden
import kcgof.cdata as cdat
import kcgof.cgoftest as cgof
import kcgof.kernel as ker
import kcgof.plot as plot

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Global matplotlib styling for every figure produced in this notebook.
# Only the font size is overridden; 'family' and 'weight' keep their defaults.
font = dict(size=20)

plt.rc('font', **font)
plt.rc('lines', linewidth=2)
# Font type 42 for both the PDF and PostScript backends.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

## Mixture Density Network

In [None]:
# https://discuss.pytorch.org/t/what-is-reshape-layer-in-pytorch/1110/4
class Reshape(nn.Module):
    """Layer that reshapes all non-batch dimensions of its input.

    The first (batch) dimension is kept as is; the remaining dimensions are
    replaced by ``shape``. Intended for use inside ``nn.Sequential``.

    Args:
        *shape: sizes of the dimensions that follow the batch dimension in
            the output.
    """

    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.shape[0], *self.shape)``."""
        batch = x.shape[0]
        # reshape() returns a view whenever the memory layout allows it and
        # copies only when it has to, so non-contiguous inputs (e.g. the
        # result of a transpose) are handled where view() would raise.
        return x.reshape(batch, *self.shape)

    def extra_repr(self):
        """Show the target shape when the module is printed."""
        return 'shape={}'.format(tuple(self.shape))
    
class MDNSigma2(nn.Module):
    """Variance network of the mixture density network.

    Maps inputs of shape (batch, dx) to per-component variances of shape
    (batch, n_comps, dy). The raw network output is squared and shifted by
    0.1, so every returned variance is at least 0.1.

    Args:
        dx: input dimension.
        dy: output dimension.
        n_comps: number of mixture components.

    Any argument left as ``None`` is read from the module-level variable of
    the same name when the object is constructed, so ``MDNSigma2()`` keeps
    working exactly as before.
    """

    def __init__(self, dx=None, dy=None, n_comps=None):
        super(MDNSigma2, self).__init__()
        self.dx = self._resolve('dx', dx)
        self.dy = self._resolve('dy', dy)
        self.n_comps = self._resolve('n_comps', n_comps)
        self.net = nn.Sequential(
            nn.Linear(self.dx, 16),
            nn.Tanh(),
            nn.Linear(16, self.dy*self.n_comps),
        )

    @staticmethod
    def _resolve(name, value):
        """Return ``value``, or the module-level variable ``name`` if None."""
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            # Same failure mode as referencing the missing global directly.
            raise NameError("name '{}' is not defined".format(name))

    def forward(self, x):
        """Return variances of shape (batch, n_comps, dy), each >= 0.1."""
        raw = self.net(x).view(x.shape[0], self.n_comps, self.dy)
        # Squaring makes the output non-negative; the 0.1 offset keeps the
        # variance bounded away from zero.
        return raw**2 + 0.1

In [None]:
# Problem sizes: number of mixture components, input dimension, output dimension.
n_comps, dx, dy = 12, 1, 1

# Mixing proportion function: maps (batch, dx) to (batch, n_comps).
# The final softmax over dim 1 makes each row a probability vector.
pi = nn.Sequential(
    nn.Linear(in_features=dx, out_features=32),
    nn.Tanh(),
    nn.Linear(in_features=32, out_features=8),
    nn.Tanh(),
    nn.Linear(in_features=8, out_features=n_comps),
    nn.Softmax(dim=1),
)

In [None]:
# Component means: maps (batch, dx) to (batch, n_comps, dy).
mu = nn.Sequential(
    nn.Linear(in_features=dx, out_features=32),
    nn.Tanh(),
    nn.Linear(in_features=32, out_features=16),
    nn.Tanh(),
    nn.Linear(in_features=16, out_features=8),
    nn.Tanh(),
    nn.Linear(in_features=8, out_features=n_comps*dy),
    Reshape(n_comps, dy),
)
# Component variances.
sigma2 = MDNSigma2()

Test the parameter functions (mixing proportions `pi`, means `mu`, variances `sigma2`) on random inputs.

In [None]:
# 200 one-dimensional test inputs: standard normal draws shifted by +1.
T = 1 + torch.randn(200, 1)
# The inputs are fixed data, so no gradient is tracked for them.
T.requires_grad_(False)
# Mixing proportions for each test input.
print(pi(T))

In [None]:
# Per-component variances for the test inputs T. MDNSigma2 squares the raw
# network output and adds 0.1, so every printed entry should be >= 0.1.
print(sigma2(T))

In [None]:
# Per-component means for the test inputs T. The trailing Reshape layer in
# `mu` arranges the output as (batch, n_comps, dy).
print(mu(T))

In [None]:
# Mixture density network model assembled from the three parameter networks.
# `p` is a short alias for the same object.
p = mdn_den = cden.CDMixtureDensityNetwork(
    n_comps=n_comps,
    pi=pi,
    mu=mu,
    variance=sigma2,
    dx=dx,
    dy=dy,
)

In [None]:
# Conditional source used to draw samples from the model. It is built from
# the same pi / mu / sigma2 networks and dimensions as the density model.
cs = cdat.CSMixtureDensityNetwork(
    n_comps=n_comps,
    pi=pi,
    mu=mu,
    variance=sigma2,
    dx=dx,
    dy=dy,
)

In [None]:
# Sample size.
n = 500
# Marginal distribution of X: standard normal. Note that the second argument
# of Normal is the standard deviation (scale), not the variance; with a value
# of 1 the two coincide.
rx = dists.Normal(loc=0, scale=1)
# Draw the inputs, then sample Y given X from the model with a fixed seed.
X = rx.sample(sample_shape=(n, dx))
Y = cs(X, seed=23, verbose=True)
    

In [None]:
# Margin added on each side of the observed data range.
ep = 0.4
# Make grids (100 points each) that cover the sampled X and Y.
# Both ranges are converted to Python floats with .item() so that
# torch.linspace receives plain numbers in either case.
domX = torch.linspace(torch.min(X).item()-ep, torch.max(X).item()+ep, 100)
domY = torch.linspace(torch.min(Y).item()-ep, torch.max(Y).item()+ep, 100)

fig, axes = plot.plot_2d_cond_model(
    p,
    # Density of the X marginal; the parameter is named `x` so it does not
    # shadow the data tensor X.
    lambda x: torch.exp(rx.log_prob(x)),
    X, Y, domX=domX, domY=domY,
    height_ratios=[2,1],
    cmap='pink_r', levels=30)
# plt.xlabel('$x$')
# plot.plot_2d_cond_data(X, Y)