In [15]:
import torch
import sigkernel2
import sigkernel
import csv
import scipy
import timeit
import math
from samplers import Fourier, NonLinearSDE, sampler
# Device descriptor for tensor allocation. torch.cuda.device is a context
# manager that switches the current CUDA device, not a device descriptor, so it
# cannot be handed to `device=` arguments; torch.device is the descriptor type
# (the same type generate() uses for its default).
device = torch.device('cuda')

In [2]:
def generate(batch_size, length, dimension, device = torch.device('cpu')):
  """Sample a batch of Gaussian random walks that start at the origin.

  Increments are standard-normal draws scaled by 1/sqrt(length); a zero
  row is prepended along the time axis before the cumulative sum, so every
  walk has value 0 at its first time step.

  Args:
    batch_size: number of walks.
    length: number of increments per walk.
    dimension: number of coordinates per time step.
    device: torch device on which the tensors are allocated.

  Returns:
    A double-precision tensor of shape (batch_size, length + 1, dimension).
  """
  tensor_opts = dict(dtype=torch.double, device=device)
  increments = torch.randn(batch_size, length, dimension, **tensor_opts) / math.sqrt(length)
  origin = torch.zeros([batch_size, 1, dimension], **tensor_opts)
  # Prepend the zero start, then integrate the increments along time (dim 1).
  return torch.cat((origin, increments), dim=1).cumsum(dim=1)

def median_distance(X, Y):
    """Median of all pairwise squared Euclidean distances between X and Y.

    X is indexed as (A, M, D) and Y as (A, N, D): for every batch entry a
    the squared distance between X[a, m] and Y[a, n] is formed via
    |x|^2 + |y|^2 - 2 <x, y>, and a single median is taken over all
    A * M * N values.

    Note that, despite the name, the returned value is a *squared* distance.

    Returns:
        A 0-dimensional tensor holding the median.
    """
    batch, len_x, len_y = X.shape[0], X.shape[1], Y.shape[1]

    # Squared norms of every point, one per (batch, time) position.
    sq_norm_x = (X ** 2).sum(dim=2)
    sq_norm_y = (Y ** 2).sum(dim=2)

    # Cross term first, then add the broadcast norms in place.
    sq_dist = -2. * torch.bmm(X, Y.permute(0, 2, 1))
    sq_dist += sq_norm_x.reshape(batch, len_x, 1) + sq_norm_y.reshape(batch, 1, len_y)

    return sq_dist.view(batch, -1).median()
  
def tensor_norm(x, C=2, alpha=1):
  """Evaluate C - (C - 1) / sqrt(x)**alpha element-wise.

  Args:
    x: tensor (anything exposing a `.sqrt()` method).
    C: constant the expression is subtracted from; also sets the numerator C - 1.
    alpha: exponent applied to sqrt(x).

  Returns:
    The transformed tensor, same shape as `x`.
  """
  damping = x.sqrt() ** alpha
  return C - (C - 1) / damping

def guess(norm, value):
  """Return the pair of candidate scalings (a, b) for the given norm and value.

  With r(t) = sqrt(1 - 2 * (1 - t)):
    a = sqrt((-1 - r(value)) / (-1 - r(norm)))
    b = sqrt((-1 + r(value)) / (-1 + r(norm)))

  Raises:
    ValueError: if 1 - 2 * (1 - t) is negative for either argument, or if the
      ratio under the outer square root of b is negative.
    ZeroDivisionError: if r(norm) == 1 (e.g. norm == 1), which zeroes b's denominator.
  """
  root_value = math.sqrt(1 - 2 * (1 - value))
  root_norm = math.sqrt(1 - 2 * (1 - norm))
  ratio_minus = (-1.0 - root_value) / (-1.0 - root_norm)
  ratio_plus = (-1.0 + root_value) / (-1.0 + root_norm)
  return math.sqrt(ratio_minus), math.sqrt(ratio_plus)
def guess2(norm, value):
  """Return sqrt((-1 + r(value)) / (-1 + r(norm))) with r(t) = sqrt(1 - 2 * (1 - t)).

  Raises:
    ValueError: if an argument of either square root is negative.
    ZeroDivisionError: if r(norm) == 1 (e.g. norm == 1).
  """
  numerator = -1.0 + math.sqrt(1 - 2 * (1 - value))
  denominator = -1.0 + math.sqrt(1 - 2 * (1 - norm))
  return math.sqrt(numerator / denominator)

def guess3(norm, value):
  """Return the ratio (value - 1) / (norm - 1)."""
  excess_value = value - 1
  excess_norm = norm - 1
  return excess_value / excess_norm

In [11]:
# Draw the paired samples X, Y from the Fourier sampler; X feeds the kernel
# cells that follow. The commented-out lines are alternative inputs.
# NOTE(review): the first argument (30) is presumably the number of sampled
# paths, by analogy with n=30 in the NonLinearSDE alternative — confirm in samplers.
# NOTE(review): no RNG seed is set in the visible cells, so this draw may
# differ between runs.
#X = generate(4, 50, 10, device = torch.device('cuda:0'))
X, Y = Fourier.sample(30, rho=0.5, thin=1)
#X, Y = NonLinearSDE.sample(n=30, theta=0.5, thin=1)

In [12]:
# Median of the pairwise squared distances within X (see median_distance),
# evaluated once and shared by both RBF kernels rather than being recomputed
# (with a second device-to-host copy) for each constructor.
median_sq_dist_X = median_distance(X, X).cpu().item()

# Robust signature kernel on an RBF static kernel, with tensor_norm
# (C=2, alpha=0.5) as its normalizer.
# NOTE(review): the second positional argument is 3 here but 2 for sig_x below;
# presumably a truncation order — confirm the difference is intentional.
sig_x2 = sigkernel2.RobustSigKernel(sigkernel2.RBFKernel(median_sq_dist_X), 3, normalizer=lambda x : tensor_norm(x, C=2, alpha=0.5))
#sig_x2 = sigkernel2.RobustSigKernel(sigkernel2.LinearKernel(1), 2, normalizer=lambda x : tensor_norm(x, C=2, alpha=1))
# Plain (non-robust) counterpart built on the same RBF parameter.
sig_x = sigkernel2.SigKernel2(sigkernel2.RBFKernel(median_sq_dist_X), 2)
#sig_x = sigkernel2.SigKernel2(sigkernel2.LinearKernel(1), 2)

In [14]:
# Evaluate the robust kernel's gram() on X; as the last expression of the cell
# its result is shown as the cell output.
sig_x2.gram(X)

tensor([[1.6870, 0.9143, 0.9461, 1.0229, 0.9921, 1.0196, 1.1577, 1.0251, 0.9160,
         1.2770, 0.9697, 0.9622, 0.9422, 1.0825, 1.0387, 0.9317, 0.9627, 1.1025,
         1.0036, 1.1547, 1.2503, 1.1027, 0.8294, 1.0513, 1.1658, 0.9340, 1.0546,
         1.0145, 1.0192, 0.9745],
        [0.9143, 1.5258, 1.1014, 0.8890, 1.4236, 1.1896, 1.1983, 0.8205, 0.8092,
         0.7558, 1.0717, 0.7500, 1.1467, 0.7539, 0.9602, 1.1865, 1.1645, 0.8802,
         1.2199, 1.0186, 0.8510, 0.8985, 0.9028, 1.1018, 0.8775, 0.9626, 0.9522,
         0.9552, 1.1262, 1.1101],
        [0.9461, 1.1014, 1.7021, 0.9642, 0.9763, 0.9612, 1.2126, 0.9068, 1.0297,
         1.1430, 0.9930, 0.9087, 1.0354, 0.8680, 0.9103, 0.9895, 1.0455, 0.9147,
         0.9748, 1.1465, 0.9970, 1.0130, 1.1019, 1.0311, 0.9055, 1.0344, 0.9626,
         1.0141, 1.1380, 1.0164],
        [1.0229, 0.8890, 0.9642, 1.9393, 0.7758, 0.8594, 0.8661, 1.0629, 1.0480,
         1.1528, 0.9847, 1.1841, 0.8318, 1.1877, 1.1654, 0.9109, 1.0241, 1.1925,
       