In [1]:
import sigpde.torch as sig
import torch
import math
# Use the first GPU when one is available; otherwise fall back to CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
def sample(n, rho=1, scale=8, T=0.5, start=0, stop=1, dt=0.005, randgrid = 30, thin=0, device=torch.device('cuda:0'), dtype=torch.float64):
    """Draw ``n`` noisy (x, y) path pairs where y is an integral functional of x.

    Each clean input path is a random combination of a constant and two
    trigonometric components on the grid ``ts = linspace(start, stop, l)``:

        x(t) = c0 + c1 * sqrt(2/T) * sin(4*pi*t/T) + c2 * sqrt(2/T) * cos(6*pi*t/T),

    with c0, c1, c2 ~ N(0, 1) per path. With two random centres
    a1, a2 ~ U[0.25, 0.75] per path, the clean response is

        y(t) = rho * scale * ( int x(s) (s - a1)^2 ds  -  (t - a2)^2 * int x(s) ds ),

    where both integrals run from ``start`` to ``t`` and are approximated by
    left Riemann sums on ``ts``. Independent N(0, 1) noise is then added to
    y and to x (in that order; y is built from the noise-free x).

    Args:
        n: number of path pairs.
        rho, scale: multiplicative factors applied to y before noise.
        T: period parameter of the trigonometric components.
        start, stop: endpoints of the time grid.
        dt: target grid step; the grid has ``ceil((stop - start) / dt) + 1``
            points, and the integrals use the resulting actual spacing.
        randgrid: unused; kept for backward compatibility.
        thin: keep every ``2**thin``-th grid point in the returned paths.
        device, dtype: where / how the tensors are allocated.

    Returns:
        (x, y): two tensors of shape ``(n, len(range(0, l, 2**thin)), 1)``.
    """
    num_points = math.ceil((stop - start) / dt) + 1

    # The grid must span [start, stop]; previously it was always [0, 1].
    ts = torch.linspace(start, stop, num_points, device=device, dtype=dtype)
    # Actual grid spacing. It equals dt whenever (stop - start) is a multiple
    # of dt, and keeps the Riemann sums consistent with ts when ceil() refines
    # the grid (e.g. through floating-point round-up).
    step = (stop - start) / (num_points - 1) if num_points > 1 else 0.0

    # Column 0: constant level; columns 1 and 2: sine / cosine weights.
    coeffs = torch.randn((n, 3), dtype=dtype, device=device)
    amp = math.sqrt(2 / T)
    sin_part = amp * torch.sin(4 * torch.pi * ts / T) * coeffs[:, 1:2]
    cos_part = amp * torch.cos(6 * torch.pi * ts / T) * coeffs[:, 2:3]
    x = sin_part + cos_part + coeffs[:, 0:1]   # (n, l)

    # Random centres a1, a2 ~ U[0.25, 0.75] of the quadratic weights.
    centres = (0.75 - 0.25) * torch.rand((n, 2), device=device, dtype=dtype) + 0.25
    w1 = (ts - centres[:, 0:1]) ** 2   # (n, l): (t - a1)^2
    w2 = (ts - centres[:, 1:2]) ** 2   # (n, l): (t - a2)^2

    # Left Riemann sums: the value at grid point k uses samples 0..k-1.
    zero_column = torch.zeros(n, 1, device=device, dtype=dtype)
    int_x = torch.cat((zero_column, (x[:, :-1] * step).cumsum(dim=1)), dim=1)
    int_xw1 = torch.cat((zero_column, (x[:, :-1] * w1[:, :-1] * step).cumsum(dim=1)), dim=1)

    y = rho * (int_xw1 - int_x * w2) * scale

    # Observation noise: y first, then x, preserving the original draw order.
    y = y + torch.randn(y.shape, device=device, dtype=dtype)
    x = x + torch.randn(x.shape, device=device, dtype=dtype)

    # Thin along time and add a trailing channel axis; .contiguous() keeps the
    # previous (copying) indexing's memory layout for downstream consumers.
    stride = 2 ** thin
    return x[:, ::stride].unsqueeze(-1).contiguous(), y[:, ::stride].unsqueeze(-1).contiguous()

In [21]:
def generate(batch_size, length, dimension, device = torch.device('cpu'), dtype=torch.double):
    """Simulate a batch of Gaussian random walks started at the origin.

    Each walk has ``length`` i.i.d. increments drawn from N(0, 1) and scaled
    by ``1 / sqrt(length)``; a zero starting point is prepended, so every
    walk has ``length + 1`` points.

    Args:
        batch_size: number of walks.
        length: number of increments per walk.
        dimension: number of channels of each walk.
        device: device on which the walks are allocated.
        dtype: floating-point dtype of the walks (defaults to ``torch.double``,
            the previously hard-coded dtype).

    Returns:
        Tensor of shape ``(batch_size, length + 1, dimension)``.
    """
    increments = torch.randn(batch_size, length, dimension, dtype=dtype, device=device) / math.sqrt(length)
    origin = torch.zeros([batch_size, 1, dimension], device=device, dtype=dtype)
    return torch.cumsum(torch.cat((origin, increments), dim=1), dim=1)

def add_time(x, start=0, stop=1):
    """Append a uniformly spaced time coordinate as an extra last channel.

    Args:
        x: 3-D tensor indexed as (batch, time, channel); the time grid has
            ``x.shape[1]`` points.
        start, stop: first and last value of the time coordinate.

    Returns:
        Tensor like ``x`` (same device and dtype) with one more channel that
        holds ``linspace(start, stop, x.shape[1])`` for every batch element.
    """
    batch, steps = x.shape[0], x.shape[1]
    clock = torch.linspace(start, stop, steps, device=x.device, dtype=x.dtype)
    clock = clock.view(1, steps, 1).expand(batch, steps, 1)
    return torch.cat([x, clock], dim=-1)

def variation(x):
    """Total Euclidean variation of each path in a batch.

    Sums, along dim 1, the L2 norms (taken over the last dim) of consecutive
    increments.

    Args:
        x: 3-D tensor indexed as (batch, time, channel).

    Returns:
        Tensor of shape ``(batch,)``.
    """
    steps = x[:, 1:, :] - x[:, :-1, :]
    return steps.norm(p=2, dim=-1).sum(dim=-1)

def var_norm(x):
    """Rescale each path by the square root of its total variation.

    Args:
        x: 3-D tensor indexed as (batch, time, channel).

    Returns:
        ``x`` divided, per batch element, by ``sqrt(variation(x))``.
        NOTE(review): a constant path has zero variation, so it is divided
        by zero (inf/nan).
    """
    scale = variation(x).sqrt()
    return x / scale.view(x.shape[0], 1, 1)

def norm(x):
    """Largest pointwise Euclidean norm reached by each path.

    Args:
        x: tensor, assumed (batch, time, channel); the L2 norm is taken over
            dim 2 and the maximum over dim 1.

    Returns:
        Tensor of shape ``(batch,)`` with ``max_t ||x[b, t, :]||_2``.
    """
    pointwise = x.pow(2).sum(dim=2).sqrt()   # (batch, time)
    largest, _ = pointwise.max(dim=1)
    return largest

def std_norm(x):
    """Map ``x`` elementwise to ``2 - 1 / (1 + log x)``.

    The result is 1 at ``x = 1`` and increases towards 2 as ``x`` grows.
    For ``0 < x < 1/e`` the denominator is negative, and at ``x = 1/e`` it
    is zero.

    Args:
        x: tensor (``.log()`` is called on it).
    """
    denominator = 1 + x.log()
    return 2 - 1 / denominator

def time_norm(x, time=True):
    """Optionally append a time channel, then scale each path to unit sup-norm.

    Args:
        x: 3-D tensor indexed as (batch, time, channel).
        time: if True, first append a [0, 1] time coordinate with ``add_time``.

    Returns:
        Tensor (with one extra channel when ``time`` is True) in which each
        path satisfies ``max_t ||x[b, t, :]||_2 == 1``.
        NOTE(review): with ``time=False`` an all-zero path has sup-norm 0 and
        yields nan. With ``time=True`` the appended clock ends at 1, so the
        sup-norm is at least 1.
    """
    if time:
        x = add_time(x)
    # Reuse ``norm`` (max over time of the pointwise Euclidean norm) instead of
    # duplicating the same reduction inline.
    return x / norm(x).view(x.shape[0], 1, 1)

In [4]:
# Standard and robust signature-PDE kernels, both built on a linear static kernel.
# NOTE(review): the meaning of the positional argument `2` is defined by sigpde
# (presumably a PDE-solver refinement / dyadic order) — TODO confirm.
kernel = sig.SigPDE(sig.kernels.LinearKernel(), 2)
kernel2 = sig.RobustSigPDE(sig.kernels.LinearKernel(), 2)

In [23]:
# Fixed seed so re-running this cell reproduces the same data.
torch.manual_seed(0)
# Random walks of shape (batch, 512 + 1, 50).
# NOTE(review): `y` and `w` are not used in the cells shown here.
x = generate(30, 512, 50, device=device)
y = generate(1, 512, 50, device=device)
# Pass `device` explicitly rather than relying on sample()'s 'cuda:0' default.
z, w = sample(30, device=device)

In [31]:
# SigPDE kernel on the time-augmented, sup-norm-normalised paths
# (one value per path, as in the output below). A bare last expression uses
# Jupyter's rich display; a tensor's repr matches its printed text.
kernel.pairwise(time_norm(z))

tensor([1.4604, 1.4138, 1.0275, 1.0753, 1.0407, 1.0538, 1.3269, 1.1372, 1.0367,
        1.1369, 1.1123, 1.3336, 1.3742, 1.0676, 1.2107, 1.0324, 1.0700, 1.0197,
        1.0843, 1.0520, 1.0714, 1.0196, 1.0922, 1.1803, 1.1995, 1.0496, 1.0226,
        1.0356, 1.0638, 1.0626], device='cuda:0', dtype=torch.float64)


In [33]:
# Robust SigPDE kernel on time-augmented but un-normalised paths; `tol=0.01`
# is forwarded to RobustSigPDE.pairwise.
# NOTE(review): the recorded output is bimodal (exactly 1.0000 vs ~1.95-1.99),
# unlike the time_norm-ed run above — confirm whether normalisation was intended.
kernel2.pairwise(add_time(z), tol=0.01)

tensor([1.0000, 1.0000, 1.9900, 1.0000, 1.9855, 1.0000, 1.9885, 1.0000, 1.9915,
        1.0000, 1.9482, 1.9750, 1.9892, 1.0000, 1.9848, 1.0000, 1.0000, 1.9907,
        1.0000, 1.9876, 1.9699, 1.0000, 1.0000, 1.0000, 1.9784, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000], device='cuda:0', dtype=torch.float64)

In [30]:
# Robust SigPDE kernel on the raw random walks (no time channel, no normalisation).
# A bare last expression uses Jupyter's rich display instead of print().
kernel2.pairwise(x)

tensor([1.9805, 1.9798, 1.9801, 1.9795, 1.9802, 1.9799, 1.9802, 1.9785, 1.9791,
        1.9803, 1.9788, 1.9798, 1.9798, 1.9789, 1.9795, 1.9805, 1.9786, 1.9802,
        1.9797, 1.9800, 1.9798, 1.9805, 1.9796, 1.9787, 1.9800, 1.9798, 1.9786,
        1.9788, 1.9799, 1.9801], device='cuda:0', dtype=torch.float64)
