In [1]:
import torch
import sigkernel2
import sigkernel
import csv
import scipy
import timeit
import math
# Default compute device: the current CUDA GPU when one is available, otherwise CPU.
# torch.cuda.device(...) is a context manager for switching the active GPU, not a
# device handle, so it cannot be passed as a `device=` argument; torch.device can.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [161]:
def generate(batch_size, length, dimension, device=torch.device('cpu')):
  """Sample a batch of Gaussian random-walk paths that start at the origin.

  Each of the `length` increments is i.i.d. standard normal divided by
  sqrt(length); the path is the cumulative sum of [0, increments...].

  Args:
    batch_size: number of independent paths.
    length: number of increments per path.
    dimension: number of channels per point.
    device: torch device on which the tensors are allocated.

  Returns:
    float64 tensor of shape (batch_size, length + 1, dimension) whose first
    point along dim 1 is all zeros.
  """
  step_scale = math.sqrt(length)
  increments = torch.randn(batch_size, length, dimension, dtype=torch.double, device=device) / step_scale
  origin = torch.zeros([batch_size, 1, dimension], device=device, dtype=torch.double)
  # Prepend the zero start point, then integrate the increments along time.
  stacked = torch.cat((origin, increments), dim=1)
  return torch.cumsum(stacked, dim=1)

def median_distance(X, Y):
    """Median of all pairwise *squared* Euclidean distances between points of X and Y.

    Per batch element a, the (M, N) matrix ||X[a,i]||^2 + ||Y[a,j]||^2 - 2<X[a,i], Y[a,j]>
    is built, and torch.median is taken over every entry of every batch
    (for an even number of entries this is the lower of the two middle values).

    NOTE(review): when X is Y the zero self-distances on the diagonal are
    included, and the expansion can leave tiny negative values from rounding.

    Args:
        X: tensor of shape (A, M, D).
        Y: tensor of shape (A, N, D) with the same batch size A.

    Returns:
        0-dim tensor holding the median squared distance.
    """
    sq_x = torch.sum(X**2, dim=2)              # (A, M)
    sq_y = torch.sum(Y**2, dim=2)              # (A, N)
    inner = torch.bmm(X, Y.transpose(1, 2))    # (A, M, N)
    pairwise = -2. * inner
    pairwise += sq_x.unsqueeze(2) + sq_y.unsqueeze(1)
    return pairwise.flatten().median()
  
def tensor_norm(x, C=2, alpha=1):
  """Element-wise map x -> C - (C - 1) / sqrt(x)**alpha.

  With the defaults this is 2 - 1/sqrt(x): it equals 1 at x == 1 and
  increases towards C as x grows. `x` must be a torch tensor (its `.sqrt()`
  method is used); a negative entry yields NaN.
  """
  root = x.sqrt()
  shrink = (C - 1) / root**alpha
  return C - shrink

def guess(norm, value):
  """Scaling estimate sqrt(r(value) / r(norm)), with r(v) = -1 -/+ sqrt(1 - 2(1 - v)).

  r(v) is a root of t**2/2 + t + 1 = v. The "-" root is used for both
  arguments when norm > value, and the "+" root otherwise (the latter is
  exactly guess2). NOTE(review): presumably models the self-kernel as
  1 + t + t**2/2 with t proportional to lambda**2 -- confirm.

  Raises:
    ValueError: from math.sqrt when norm or value is below 0.5, or when the
      ratio of roots is negative.
    ZeroDivisionError: when norm == 1 and norm <= value.
  """
  root_value = math.sqrt(1 - 2 * (1 - value))
  root_norm = math.sqrt(1 - 2 * (1 - norm))
  if norm > value:
    ratio = (-1.0 - root_value) / (-1.0 - root_norm)
  else:
    ratio = (-1.0 + root_value) / (-1.0 + root_norm)
  return math.sqrt(ratio)

def guess2(norm, value):
  """Scaling estimate sqrt(r(value) / r(norm)) using r(v) = -1 + sqrt(1 - 2(1 - v)).

  r(v) is the "+" root of t**2/2 + t + 1 = v. Raises ValueError (math.sqrt)
  for arguments below 0.5 or a negative ratio, and ZeroDivisionError when norm == 1.
  """
  numerator = -1.0 + math.sqrt(1 - 2 * (1 - value))
  denominator = -1.0 + math.sqrt(1 - 2 * (1 - norm))
  return math.sqrt(numerator / denominator)

def guess3(norm, value):
  """Linear ratio (value - 1) / (norm - 1); raises ZeroDivisionError when norm == 1."""
  shifted_value = value - 1
  shifted_norm = norm - 1
  return shifted_value / shifted_norm

In [162]:
# One random-walk path with 50 increments in 10 channels -> shape (1, 51, 10).
# Seed the RNG (torch.manual_seed seeds every device) so the sampled path, and every
# number printed below, is reproducible under Restart & Run All; fall back to CPU
# when no GPU is present.
torch.manual_seed(0)
X = generate(1, 50, 10, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))

In [170]:
# Signature kernel on an RBF static kernel. The RBF argument is the median pairwise
# *squared* distance between the points of X (see median_distance), presumably used as
# the RBF bandwidth -- confirm in sigkernel2.RBFKernel. The second SigKernel2 argument (5)
# is passed through unchanged; NOTE(review): its meaning is defined in sigkernel2.
# An earlier experiment used sigkernel2.LinearKernel(0.3) with 3 as the second argument.
sig_x = sigkernel2.SigKernel2(
    sigkernel2.RBFKernel(median_distance(X, X).cpu().item()),
    5,
)

In [171]:
# Evaluate the self-kernel k(X, X) once and reuse it for both printouts,
# rather than running the signature-kernel computation twice on the same inputs.
k_xx = sig_x.kernel(X, X)
print(tensor_norm(k_xx))
print(k_xx)

tensor([1.8166], device='cuda:0', dtype=torch.float64)
tensor([29.7160], device='cuda:0', dtype=torch.float64)


In [172]:
# Self-kernel of X and the value tensor_norm maps it to, each extracted once as a
# Python float (.item() assumes a single-element tensor, i.e. batch size 1 as in X above).
norms = sig_x.kernel(X, X)
norm_value = norms.cpu().item()
target_value = tensor_norm(norms).cpu().item()
# Compare the three candidate scalings.
print(guess(norm_value, target_value))
print(guess2(norm_value, target_value))
print(guess3(norm_value, target_value))

0.5508252332134318
0.3061380222985923
0.02843558704144373


In [173]:
# Starting scale for robust_kernel: guess() evaluated on the self-kernel value and its
# tensor_norm image, packed as a 1-element float64 tensor on X's device.
# NOTE(review): n=10 is passed straight to sigkernel2's robust_kernel -- presumably an
# iteration count; confirm there.
target_norm = tensor_norm(norms)
start_lambda = guess(norms.cpu().item(), target_norm.cpu().item())
lambdas = torch.tensor([start_lambda], dtype=torch.float64, device=X.device)
sig_x.robust_kernel(X, target_norm, lambdas, n=10)

(tensor([1.8168], device='cuda:0', dtype=torch.float64),
 tensor([9.0088], device='cuda:0', dtype=torch.float64),
 tensor([0.4966], device='cuda:0', dtype=torch.float64))

In [78]:
# NOTE(review): the stored output of this cell comes from execution In[78], which predates
# In[171]; it (1.8572) disagrees with the 1.8166 printed there for the same quantity,
# so it is stale -- re-run the notebook top to bottom.
tensor_norm(norms, C=2, alpha=1)

tensor([1.8572], device='cuda:0', dtype=torch.float64)

In [24]:
# Self-kernel of X after scaling every point by a fixed factor.
# NOTE(review): 0.3586 is an unexplained hard-coded constant, and the stored output comes
# from execution In[24], i.e. before X was regenerated in In[162] -- likely stale.
x_scale = 0.3586
sig_x.kernel(X * x_scale, X * x_scale)

tensor([1.9302], device='cuda:0', dtype=torch.float64)