In [1]:
import torch
import sigkernel2
import sigkernel
import csv
import scipy
import torch.distributions as tdist
import math
from HSIC import HSIC
from dCov import dCov
from samplers import LinearSDE
from dCovMod import dCovMod

# A real torch.device (torch.cuda.device(...) is a context manager, not a device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [105]:
class sampler2():
    """Simulate paired paths (X, Y) driven by random trigonometric noise.

    X follows the Euler scheme X_{i+1} = X_i + func1(X_i, t_i) dt + sigma * z1_i
    (starting from 0), and Y is the cumulative sum of func2(X) * dt plus
    sigma times the cumulative sum of an independent noise path z2.

    Parameters
    ----------
    l : int
        Number of points on the uniform time grid over [0, 1].
    p : int
        Number of Fourier frequencies used to build each noise path.
    device, dtype :
        Allocation target for all tensors. Defaults reproduce the original
        hard-coded ``cuda:0`` / ``float64``.
    """

    def __init__(self, l=201, p=25, device='cuda:0', dtype=torch.float64):
        self.dtype = dtype
        self.device = torch.device(device)

        self.l = l
        self.p = p

        self.ts = torch.linspace(0, 1, l, device=self.device, dtype=self.dtype)
        # Frequencies 1..p as a column so they broadcast against the grid -> (p, l).
        self.coef = torch.arange(1, p + 1, 1, dtype=self.dtype, device=self.device).view(-1, 1)

        self.cos = torch.cos(2 * torch.pi * self.coef * self.ts)
        self.sin = torch.sin(2 * torch.pi * self.coef * self.ts)

    def get_z(self, n):
        """Return ``n`` random trigonometric noise paths, shape (n, l)."""
        alf = torch.randn((n, self.p, 1), device=self.device, dtype=self.dtype)
        beta = torch.randn((n, self.p, 1), device=self.device, dtype=self.dtype)

        alf = (alf * self.cos).sum(dim=1)
        beta = (beta * self.sin).sum(dim=1)

        # NOTE(review): the scale depends on the batch size n, so the noise
        # level changes with how many samples are drawn -- confirm intended.
        return (alf + beta) / (2 * n**2)

    def get_x(self, z1, z2, func1, func2, sigma):
        """Build the X and Y paths from pre-drawn noise.

        Parameters
        ----------
        z1, z2 : tensor of shape (n, l)
            Noise paths for X and Y (e.g. from ``get_z``).
        func1 : callable(x_t, t)
            Drift of X; called with the current column x[:, i] and time i * dt.
        func2 : callable(x)
            Applied to the whole X path; its values times dt are accumulated into Y.
        sigma : float
            Noise scale applied to both z1 and z2.

        Returns
        -------
        x : (n, l) Euler path of X starting at 0.
        y : (n, l) cumulative sum of func2(x) * dt plus sigma * cumsum(z2).
        w : (n, l) cumulative sum of z1 (the noise path driving X).
        """
        n = z1.shape[0]
        dt = 1 / (self.l - 1)

        # Paths start at 0; zeros() already provides x[:, 0] = 0.
        x = torch.zeros((n, self.l), device=self.device, dtype=self.dtype)

        for i in range(self.l - 1):
            x[:, i + 1] = x[:, i] + func1(x[:, i], i * dt) * dt + sigma * z1[:, i]

        # cumsum (not sum) keeps Y a path of shape (n, l); a plain sum gives
        # shape (n,), which does not broadcast against the (n, l) noise path.
        y = (func2(x) * dt).cumsum(dim=1) + sigma * z2.cumsum(dim=1)

        # Accumulate z1 exactly once.
        return x, y, z1.cumsum(dim=1)

    def sample(self, func1, func2, n=30, sigma=0.5):
        """Draw fresh noise and return ``get_x``'s (x, y, w) for ``n`` paths."""
        z1 = self.get_z(n)
        z2 = self.get_z(n)
        return self.get_x(z1, z2, func1, func2, sigma)

In [109]:
class sampler():
    """Draw paired random functions (X, Y) expanded in a cosine basis on [0, 1].

    X = sum_k xi_k * phi_k(t) and Y = sum_k gam_k * phi_k(t), where
    phi_k(t) = sqrt(2) * cos(pi * k * t) for k = 1..p. When ``m > 0`` the first
    ``m`` coefficients of Y are replaced by ``func`` applied to the first ``m``
    coefficients of X, which introduces dependence between X and Y.
    """

    def __init__(self, xi_dist, gam_dist, func, p=50, l=201):
        self.xi_dist = xi_dist
        self.gam_dist = gam_dist

        # One draw is used only to discover dtype/device; note it advances the RNG.
        test_sample = xi_dist.sample()
        self.dtype = test_sample.dtype
        self.device = test_sample.device

        self.func = func
        self.p = p
        self.l = l

        self.ts = torch.linspace(0, 1, l, device=self.device, dtype=self.dtype)
        # Frequencies 1..p as a column so they broadcast against the grid -> (p, l).
        self.coef = torch.arange(1, p + 1, 1, dtype=self.dtype, device=self.device).view(-1, 1)

        self.basis = math.sqrt(2) * torch.cos(torch.pi * self.coef * self.ts)

    def sample(self, n=30, m=0, garbage=0):
        """Draw ``n`` pairs (X, Y), each returned with shape (n, l, 1).

        Parameters
        ----------
        n : int
            Number of sample pairs.
        m : int
            Number of leading Y coefficients that depend on X; clamped to [0, p].
        garbage : unused
            Accepted only for backward compatibility with existing callers.
        """
        # Clamp into [0, p]: a negative m would otherwise request p - m > p
        # noise coefficients and fail when reshaping to p coefficients.
        m = max(0, min(m, self.p))

        xis = self.xi_dist.sample((n, self.p))
        gams = self.gam_dist.sample((n, self.p - m))

        if m > 0:
            gams = torch.cat((self.func(xis[:, :m]), gams), dim=1)

        X = (xis.view(n, self.p, 1) * self.basis).sum(dim=1).view(n, -1, 1)
        Y = (gams.view(n, self.p, 1) * self.basis).sum(dim=1).view(n, -1, 1)

        return X, Y

In [112]:
class SigHSIC():
    """HSIC independence test between two batches of paths using signature kernels.

    Each path batch gets a signature kernel over an RBF static kernel whose
    bandwidth is the squared median pairwise distance between all observed
    points (every time step of every path).
    """

    def __init__(self, x, y, normalize=False):
        self.x = self.normalize(x) if normalize else x
        self.y = self.normalize(y) if normalize else y

        self.x_dist = self.median_dist(self.x).cpu().item()
        self.y_dist = self.median_dist(self.y).cpu().item()
        self.x_static = sigkernel2.RBFKernel(self.x_dist)
        self.y_static = sigkernel2.RBFKernel(self.y_dist)

    def normalize(self, x):
        """Scale each path/channel by its maximum absolute value over dim 1.

        Assumes x is (batch, time, channel), as ``median_dist`` also does.
        Channels that are identically zero are left at zero instead of
        becoming NaN (0 / 0).
        """
        scale = x.abs().max(dim=1).values.unsqueeze(1)
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        return x / scale

    def median_dist(self, x):
        """Return the squared median Euclidean distance between all points.

        All time steps of all paths are pooled as points in R^(x.shape[2]).
        Returns a 0-dim tensor.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are accepted.
        points = x.reshape(-1, x.shape[2])

        sq_norms = (points ** 2).sum(dim=1).unsqueeze(1)
        squared_dist_matrix = sq_norms + sq_norms.T - 2 * torch.mm(points, points.T)
        # Clamp round-off negatives before taking the square root.
        squared_dist_matrix = squared_dist_matrix.clamp(min=0)

        n = squared_dist_matrix.size(0)
        triu_indices = torch.triu_indices(n, n, offset=1, device=points.device)
        pairwise_squared_distances = squared_dist_matrix[triu_indices[0], triu_indices[1]]

        return pairwise_squared_distances.sqrt().median()**2

    def gram(self, x, dyadic_order=0, normalize=True):
        """Signature-kernel Gram matrix of x with a median-heuristic RBF static kernel."""
        x = self.normalize(x) if normalize else x

        x_dist = self.median_dist(x).cpu().item()
        x_static = sigkernel2.RBFKernel(x_dist)

        sig_x = sigkernel2.SigKernel2(x_static, dyadic_order=dyadic_order)
        return sig_x.gram(x)

    def test(self, m=1000, dyadic_order=0, sig_max_batch=100, perm_max_batch=1000, alpha=0.05):
        """Run the permutation HSIC test and return ``HSIC.test``'s result.

        Parameters
        ----------
        m : int
            Number of permutations.
        dyadic_order : int
            Refinement level passed to the signature kernel solver.
        sig_max_batch, perm_max_batch : int
            Batch sizes for the Gram computation and the permutation step.
        alpha : float
            Significance level forwarded to ``HSIC.test`` (default 0.05).
        """
        self.sig_x = sigkernel2.SigKernel2(self.x_static, dyadic_order=dyadic_order)
        self.sig_y = sigkernel2.SigKernel2(self.y_static, dyadic_order=dyadic_order)

        KX = self.sig_x.gram(self.x, max_batch=sig_max_batch)
        KY = self.sig_y.gram(self.y, max_batch=sig_max_batch)

        self.HSIC = HSIC(KX, KY)
        return self.HSIC.test(alpha=alpha, perms=m, max_batch=perm_max_batch)

In [113]:
class RBFHSIC():
    """HSIC independence test with Gaussian (RBF) kernels on flattened samples.

    Each sample is flattened to a vector; the bandwidth is the squared median
    pairwise distance between samples, and k(a, b) = exp(-|a - b|^2 / bandwidth).
    """

    def __init__(self, x, y):
        # reshape (not view) so non-contiguous inputs are accepted.
        self.x = x.reshape(x.shape[0], -1)
        self.y = y.reshape(y.shape[0], -1)

        self.x_dist = self.median_dist(self.x)
        self.y_dist = self.median_dist(self.y)

    def test(self, m=1000, perm_max_batch=1000, alpha=0.05):
        """Run the permutation HSIC test with ``m`` permutations at level ``alpha``."""
        KX = self.rbf_kernel(self.x, self.x_dist)
        KY = self.rbf_kernel(self.y, self.y_dist)
        self.HSIC = HSIC(KX, KY)
        return self.HSIC.test(alpha=alpha, perms=m, max_batch=perm_max_batch)

    def _squared_distances(self, x):
        """Pairwise squared Euclidean distances between rows of the 2-D tensor x."""
        sq_norms = (x ** 2).sum(dim=1).unsqueeze(1)
        sq_dists = sq_norms + sq_norms.T - 2 * torch.mm(x, x.T)
        # The expansion can go slightly negative (or be nonzero on the
        # diagonal) through round-off; distances are >= 0 and d(x, x) = 0.
        sq_dists = sq_dists.clamp(min=0)
        sq_dists.fill_diagonal_(0)
        return sq_dists

    def median_dist(self, x):
        """Return the squared median pairwise Euclidean distance between rows of x."""
        squared_dist_matrix = self._squared_distances(x)

        n = squared_dist_matrix.size(0)
        triu_indices = torch.triu_indices(n, n, offset=1, device=x.device)
        pairwise_squared_distances = squared_dist_matrix[triu_indices[0], triu_indices[1]]

        return pairwise_squared_distances.sqrt().median()**2

    def rbf_kernel(self, x, sigma):
        """Gram matrix exp(-|x_i - x_j|^2 / sigma) for the rows of x."""
        return torch.exp(-self._squared_distances(x) / sigma)

In [114]:
device = torch.device('cuda:0')
dtype = torch.float64
dyadic_order = 2
n = 30
perms = 500            # permutations per HSIC test
thin = 1
reps = 500
rho = [0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]
normalize = False
out_path = "../speciale/rplots/hsic_linear_sde2.csv"

torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# The lambdas read `perms`, `normalize` and `dyadic_order` from the notebook
# globals at call time. Passing `perms`/`normalize` explicitly ensures the
# values configured above are actually used (test() defaults to m=1000).
tests = {
    "dCovHD": lambda x, y : dCovMod(x, y).test(),
    "dCov": lambda x, y : dCov(x, y).test(),
    "SigHSIC": lambda x, y : SigHSIC(x, y, normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "SigHSICAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y), normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "RBF HSIC": lambda x, y : RBFHSIC(x, y).test(m=perms)
}

with open(out_path, "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["test", "rho", "run", "result"])

    for r in rho:
        print(f"rho: {r}")
        torch.cuda.empty_cache()
        for i in range(reps):
            if i % 100 == 0:
                print(f"rep: {i}")
            X, Y = LinearSDE.sample(n=n, rho=r, thin=thin)
            for t_name, test in tests.items():
                p_value = test(X, Y)
                writer.writerow([t_name, r, i, p_value])
        # Persist finished rho levels so a crash in a long run keeps partial results.
        file.flush()
            

rho: 0
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rho: 0.25
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rho: 0.5
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rho: 0.75
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rho: 1
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rho: 1.25
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rho: 1.5
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rho: 1.75
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rho: 2
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400


In [111]:
device = torch.device('cuda:0')
dtype = torch.float64
dyadic_order = 3
n = 30
p = 50                 # number of cosine basis functions
perms = 500            # permutations per HSIC test
l = 101                # points on the time grid
reps = 500
ms = [0, 2, 6, 8, 10]  # number of dependent leading coefficients
normalize = True
out_path = "../speciale/rplots/hsic_harmonics2.csv"

torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

xi_dists = {
    "Normal": tdist.Normal(torch.tensor([0], dtype=dtype, device=device), torch.tensor([1], dtype=dtype, device=device))
}

gam_dists = {
    "Normal": tdist.Normal(torch.tensor([0], dtype=dtype, device=device), torch.tensor([1], dtype=dtype, device=device))
}

# Pass the configured `perms` and `normalize` explicitly; previously both were
# declared above but ignored (tests ran with m=1000 and no normalization).
tests = {
    "dCovHD": lambda x, y : dCovMod(x, y).test(),
    "dCov": lambda x, y : dCov(x, y).test(),
    "SigHSIC": lambda x, y : SigHSIC(x, y, normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "SigHSICAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y), normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "RBF HSIC": lambda x, y : RBFHSIC(x, y).test(m=perms)
}

funcs = {
    "f(x)=x^3": lambda x : x.pow(3),
    # Other dependence functions tried previously:
    # "f(x)=x^2": lambda x : x.pow(2),
    # r"f(x)=\sin(x)": lambda x : x.sin(),
    # r"f(x)=\cos(x)": lambda x : x.cos(),
}

with open(out_path, "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["test", "m", "xi_dist", "gam_dist", "f_name", "run", "result"])

    for m in ms:
        print(f"m: {m}")
        for xi_name, xi_d in xi_dists.items():
            for gam_name, gam_d in gam_dists.items():
                print(f"xi: {xi_name}, gam: {gam_name}")
                for f_name, func in funcs.items():
                    print(f"func: {f_name}")
                    # Pass l explicitly; the sampler's default grid has 201 points.
                    new_sampler = sampler(xi_d, gam_d, func, p=p, l=l)
                    torch.cuda.empty_cache()
                    for i in range(reps):
                        if i % 100 == 0:
                            print(f"rep: {i}")
                        X, Y = new_sampler.sample(n, m)

                        for t_name, test in tests.items():
                            p_value = test(X, Y)
                            writer.writerow([t_name, m, xi_name, gam_name, f_name, i, p_value])
        # Persist finished m levels so a crash in a long run keeps partial results.
        file.flush()

m: 0
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
m: 2
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
m: 6
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
m: 8
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
m: 10
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
