In [2]:
cd ..

/home/pierreo/semisupervised


In [74]:
import numpy as np
from typing import Callable, Optional, Tuple
from module.quadrature import Quadrature, QuadratureExploration, QuadratureExplorationBis
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP
import torch
from gpytorch.mlls import ExactMarginalLogLikelihood
from torch.distributions.multivariate_normal import MultivariateNormal
import matplotlib.pyplot as plt
import scipy.stats as stats
from gpytorch.kernels import RBFKernel
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.priors.torch_priors import GammaPrior
import geoopt


In [75]:

quad_distrib = MultivariateNormal(torch.tensor([0., 0.]), torch.diag(torch.tensor([1., 1.])))
train_X = torch.linspace(-3,3, 10).reshape(5,2)
train_Y = (-(train_X)**2 + 1.).sum(dim=1, keepdim=True)
covar_module = ScaleKernel(
                RBFKernel(
                    ard_num_dims=train_X.shape[-1],
                    batch_shape=None,
                    lengthscale_prior=GammaPrior(3.0, 6.0),
                ),
                batch_shape=None,
                outputscale_prior=GammaPrior(2.0, 0.15),
            )

model = SingleTaskGP(train_X, train_Y, covar_module=covar_module)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)

quad = Quadrature(model=model,
            distribution=quad_distrib)



## Test gradient quadrature with vs. without manifold structure input space

In [78]:
### Manifold gradients
euclidean = geoopt.manifolds.Euclidean()
spd = geoopt.manifolds.SymmetricPositiveDefinite()
manifold = geoopt.manifolds.ProductManifold((euclidean, 2), (spd, (2,2)))

t = torch.zeros(6)
t[2], t[5] = .1, .1
t.requires_grad = True

mani_tensor = geoopt.ManifoldTensor(t, manifold=manifold)
param = geoopt.ManifoldParameter(mani_tensor)

In [79]:
m, v = quad._quadrature(manifold.take_submanifold_value(param, 0), manifold.take_submanifold_value(param, 1))

In [80]:
m.backward()

In [81]:
param.grad

tensor([-2.1763,  2.1762, -1.5402, -1.7121, -1.7121, -1.5402])

In [82]:
### Euclidean vanilla gradient
t = torch.zeros(6)
t[2], t[5] = .1, .1
t.requires_grad = True

mani_tensor = geoopt.ManifoldTensor(t)
param = geoopt.ManifoldParameter(mani_tensor)

m, v = quad._quadrature(manifold.take_submanifold_value(param, 0), manifold.take_submanifold_value(param, 1))
m.backward()
param.grad

param.grad.zero_()
m, v = quad._quadrature(manifold.take_submanifold_value(param, 0), manifold.take_submanifold_value(param, 1))
m.backward()
param.grad

tensor([-2.1763,  2.1762, -1.5402, -1.7121, -1.7121, -1.5402])

In [67]:
#### Comparison with implementation


quad.quadrature()
quad.gradient_direction()


In [68]:
quad.d_epsilon

tensor([[-0.1402,  0.0328],
        [ 0.0328, -0.1402]])

In [71]:
param.grad.zero_()

tensor([0., 0., 0., 0., 0., 0.])

In [72]:
param.grad

tensor([0., 0., 0., 0., 0., 0.])

## 1d

In [84]:

quad_distrib = MultivariateNormal(torch.tensor([0.]), torch.diag(torch.tensor([1.])))
train_X = torch.linspace(-3,3, 10).reshape(-1,1)
train_Y = (-(train_X)**2 + 1.).sum(dim=1, keepdim=True)
covar_module = ScaleKernel(
                RBFKernel(
                    ard_num_dims=train_X.shape[-1],
                    batch_shape=None,
                    lengthscale_prior=GammaPrior(3.0, 6.0),
                ),
                batch_shape=None,
                outputscale_prior=GammaPrior(2.0, 0.15),
            )

model = SingleTaskGP(train_X, train_Y, covar_module=covar_module)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)

quad = Quadrature(model=model,
            distribution=quad_distrib)



In [86]:
## Test gradient quadrature with vs. without manifold structure input space
### Manifold gradients
euclidean = geoopt.manifolds.Euclidean()
spd = geoopt.manifolds.SymmetricPositiveDefinite()
manifold = geoopt.manifolds.ProductManifold((euclidean, 1), (spd, (1,1)))

t = torch.zeros(2)
t[1] = .1
t.requires_grad = True

mani_tensor = geoopt.ManifoldTensor(t, manifold=manifold)
param = geoopt.ManifoldParameter(mani_tensor)
m, v = quad._quadrature(manifold.take_submanifold_value(param, 0), manifold.take_submanifold_value(param, 1))
m.backward()
param.grad



tensor([ 6.1035e-05, -1.0131e+00])

In [87]:

quad.quadrature()
quad.gradient_direction()
quad.d_mu, quad.d_epsilon

(tensor([0.]), tensor([[-0.9972]]))

In [None]:
### Euclidean vanilla gradient
t = torch.zeros(6)
t[2], t[5] = .1, .1
t.requires_grad = True

mani_tensor = geoopt.ManifoldTensor(t)
param = geoopt.ManifoldParameter(mani_tensor)

m, v = quad._quadrature(manifold.take_submanifold_value(param, 0), manifold.take_submanifold_value(param, 1))
m.backward()
param.grad

param.grad.zero_()
m, v = quad._quadrature(manifold.take_submanifold_value(param, 0), manifold.take_submanifold_value(param, 1))
m.backward()
param.grad