In [None]:
%load_ext autoreload
%autoreload 2

import math

import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
import torch
from plotly import graph_objects as go

from src.distributions import VonMisesFisher, SphereUniform

# The uniform distribution
The distributions are implemented in `src/distributions.py`. Below, we sanity-check the uniform distribution on the unit sphere by drawing 100 samples and plotting them in 3D; they should cover the sphere evenly. This distribution is also used inside the vMF sampler.

In [None]:
# Draw points uniformly on the unit 2-sphere (embedded in R^3), check that they
# really lie on the sphere, and plot them.
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the figure

su = SphereUniform(2)
# NOTE(review): the `.T` unpacking below implies shape (100, 3) — confirm.
uniform_samples = su.sample(torch.Size([100]))

# Every sample should have unit Euclidean norm.
uniform_norms = uniform_samples.norm(dim=-1)
assert torch.allclose(uniform_norms, torch.ones_like(uniform_norms), atol=1e-5), \
    "uniform samples must have unit norm"

x, y, z = uniform_samples.T
# A cube aspect ratio keeps the sphere from being rendered as an ellipsoid.
px.scatter_3d(
    x=x, y=y, z=z,
    title="100 samples from SphereUniform(2)",
).update_layout(scene_aspectmode="cube")

# Von Mises–Fisher distribution

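Below, we sample from the von Mises–Fisher (vMF) distribution for 4 different mean directions $\mu$ (unit vectors in $\mathbb{R}^3$), each with its own concentration parameter $\kappa$. The larger $\kappa$ is, the more tightly the samples cluster around $\mu$.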

In [None]:
# Parameters of the four vMF distributions: row i of `mu` and entry i of `k`
# define distribution i.

# Mean directions — each row must be a unit vector in R^3.
mu = torch.tensor([
    [0, 1, 0],
    [-1, 0, 0],
    [0, 1/math.sqrt(2), -1/math.sqrt(2)],
    [1/math.sqrt(2), 0, 1/math.sqrt(2)],
], dtype=torch.float32)

# Concentrations kappa. The explicit float dtype matters: a list of bare integer
# literals would produce an int64 tensor, which is the wrong type for a
# continuous parameter and cannot carry gradients through `rsample`.
k = torch.tensor([10, 50, 25, 10], dtype=torch.float32)

# Ambient dimension (3), kept as a 0-dim tensor as before.
m = torch.tensor( mu.shape[-1] )

# Validate the parameters before handing them to the distribution.
assert mu.shape[:-1] == k.shape, "need exactly one concentration per mean direction"
mu_norms = mu.norm(dim=-1)
assert torch.allclose(mu_norms, torch.ones_like(mu_norms)), "mean directions must be unit vectors"

sample_shape = torch.Size([100])  # number of samples drawn for each (mu, k) pair
batch_shape = mu.shape[:-1]       # one batch entry per (mu, k) pair
event_shape = mu.shape[-1:]       # each sample is a vector in R^m

In [None]:
# Draw `sample_shape` reparameterised samples from each of the four vMFs.
torch.manual_seed(0)  # fixed seed so the figure below is reproducible
vmf = VonMisesFisher(mu, k)
z = vmf.rsample(sample_shape)

# Sanity checks:
# - The plotting cell indexes samples as z[:, i, :], so they must be laid out
#   as (sample, batch, event).
# - Every sample must lie on the unit sphere.
assert z.shape == sample_shape + batch_shape + event_shape, f"unexpected sample shape {tuple(z.shape)}"
sample_norms = z.detach().norm(dim=-1)
assert torch.allclose(sample_norms, torch.ones_like(sample_norms), atol=1e-4), \
    "vMF samples must have unit norm"

In [None]:
# Plot the vMF samples with one labelled trace per (mu, kappa) pair.
# The axes are fixed to the same [-1.1, 1.1] range with a cube aspect ratio, so
# the unit sphere is not distorted and the clusters' positions are comparable.
samples = z.detach()  # plotting only; drop any autograd history before conversion

fig = go.Figure()
for i in range(samples.shape[-2]):  # loop over the batch dimension (one per mean direction)
    xs, ys, zs = samples[:, i, :].T
    fig.add_trace(go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode="markers",
        marker=dict(size=3),
        name=f"μ={[round(c, 2) for c in mu[i].tolist()]}, κ={k[i].item():g}",
    ))

axis_range = [-1.1, 1.1]
fig.update_layout(
    title="Samples from four von Mises–Fisher distributions",
    scene=dict(
        xaxis=dict(range=axis_range, title="x"),
        yaxis=dict(range=axis_range, title="y"),
        zaxis=dict(range=axis_range, title="z"),
        aspectmode="cube",
    ),
)
fig