In [1]:
import torch
import torch.distributions as tdist
import sigkernel2
import csv
import math
from independence_test import HSIC, dCov, dCovMod
from samplers import LinearSDE, Fourier, sampler, NonLinearSDE

# Default compute device: the GPU when one is present, otherwise the CPU.
# (torch.cuda.device(...) is a context manager for switching the current CUDA
# device, not a torch.device, so it cannot be passed as a `device=` argument.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
def tensor_norm(x, C=2, alpha=1):
  """Return C - (C - 1) / sqrt(x)**alpha, elementwise for a tensor ``x``.

  The result equals 1 where x == 1; for alpha > 0 it approaches C as x grows.
  """
  denominator = x.sqrt() ** alpha
  scaled = (C - 1) / denominator
  return C - scaled

class SigHSIC():
    """HSIC independence test between two batches of paths using signature kernels.

    Each input gets its own signature kernel built on a static kernel that is
    either a linear kernel (``linear=True``) or an RBF kernel whose bandwidth is
    the median pairwise squared distance between all path points.

    Args:
        x, y: path tensors. The indexing below pools every dimension except the
            last one, which is treated as channels; presumably (batch, length,
            channels) — TODO confirm against the samplers.
        normalize: if True, divide each (path, channel) series by its maximum
            absolute value over dim 1 before anything else.
        robust: use ``sigkernel2.RobustSigKernel`` (normalizer ``tensor_norm``)
            instead of ``sigkernel2.SigKernel2``.
        alpha: exponent handed to ``tensor_norm`` for the robust normalizer
            (unrelated to the test's significance level, see ``test(level=...)``).
        linear: use ``sigkernel2.LinearKernel`` instead of the median-heuristic RBF.
    """
    def __init__(self, x, y, normalize=False, robust=False, alpha=0.5, linear=False):
        self.x = self.normalize(x) if normalize else x
        self.y = self.normalize(y) if normalize else y
        self.robust = robust
        # Remembered so that gram() builds the same kind of static kernel as test().
        self.linear = linear
        self.normalizer = lambda z: tensor_norm(z, C=2, alpha=alpha)
        if linear:
            self.x_static = sigkernel2.LinearKernel()
            self.y_static = sigkernel2.LinearKernel()
        else:
            self.x_dist = self.median_dist(self.x).cpu().item()
            self.y_dist = self.median_dist(self.y).cpu().item()
            self.x_static = sigkernel2.RBFKernel(self.x_dist)
            self.y_static = sigkernel2.RBFKernel(self.y_dist)

    def normalize(self, x):
        """Divide every (path, channel) series by its maximum absolute value over dim 1."""
        return x / x.abs().max(dim=1).values.unsqueeze(1)

    def median_dist(self, x):
        """Median of the pairwise squared Euclidean distances between all path points.

        All points of all paths are pooled (last dimension = channels) before
        distances are computed. The result is a *squared* distance; it is what
        gets passed as the RBF bandwidth.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # strided slices, are accepted as well.
        points = x.reshape(-1, x.shape[-1])

        sq_norms = (points ** 2).sum(dim=1).unsqueeze(1)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>; clamp round-off negatives.
        squared_dist_matrix = (sq_norms + sq_norms.T - 2 * torch.mm(points, points.T)).clamp(min=0)

        n = squared_dist_matrix.size(0)
        # Strict upper triangle: each unordered pair once, no self-distances.
        triu_indices = torch.triu_indices(n, n, offset=1, device=points.device)
        return squared_dist_matrix[triu_indices[0], triu_indices[1]].median()

    def _sig_kernel(self, static_kernel, dyadic_order):
        """Wrap ``static_kernel`` in the robust or plain signature kernel, per ``self.robust``."""
        if self.robust:
            return sigkernel2.RobustSigKernel(static_kernel, dyadic_order=dyadic_order, normalizer=self.normalizer)
        return sigkernel2.SigKernel2(static_kernel, dyadic_order=dyadic_order)

    def gram(self, x, dyadic_order=0, normalize=True):
        """Signature-kernel Gram matrix of a single batch of paths ``x``.

        Uses the instance's static-kernel choice (linear, or RBF with the median
        bandwidth recomputed on ``x``) and its robust/plain setting.
        """
        x = self.normalize(x) if normalize else x

        if self.linear:
            x_static = sigkernel2.LinearKernel()
        else:
            x_static = sigkernel2.RBFKernel(self.median_dist(x).cpu().item())

        return self._sig_kernel(x_static, dyadic_order).gram(x)

    def test(self, m=1000, dyadic_order=0, sig_max_batch=100, perm_max_batch=1000, level=0.05):
        """Run the HSIC permutation test on the signature Gram matrices of x and y.

        Args:
            m: number of permutations (passed to ``HSIC.test`` as ``perms``).
            dyadic_order: passed through to the signature kernels.
            sig_max_batch: ``max_batch`` used when computing the Gram matrices.
            perm_max_batch: ``max_batch`` passed to ``HSIC.test``.
            level: significance level passed to ``HSIC.test`` as ``alpha``
                (default 0.05, the previously hard-coded value).

        Returns:
            Whatever ``HSIC.test`` returns (callers store it as the p-value).
        """
        self.sig_x = self._sig_kernel(self.x_static, dyadic_order)
        self.sig_y = self._sig_kernel(self.y_static, dyadic_order)

        KX = self.sig_x.gram(self.x, max_batch=sig_max_batch)
        KY = self.sig_y.gram(self.y, max_batch=sig_max_batch)

        self.HSIC = HSIC(KX, KY)
        return self.HSIC.test(alpha=level, perms=m, max_batch=perm_max_batch)

In [3]:
class SigHSICLinear():
    """HSIC test with signature kernels built on a linear static kernel.

    The same ``sigkernel2.LinearKernel(scale)`` static kernel is used for both
    ``x`` and ``y``.
    """
    def __init__(self, x, y, scale=1):
        self.x = x
        self.y = y

        self.scale = scale
        self.static = sigkernel2.LinearKernel(scale)

    def gram(self, x, scale=1, dyadic_order=0):
        """Signature Gram matrix of ``x`` using a linear static kernel.

        Note: ``scale`` is this method's own argument (default 1); it does NOT
        fall back to ``self.scale``.
        """
        x_static = sigkernel2.LinearKernel(scale)
        sig_x = sigkernel2.SigKernel2(x_static, dyadic_order=dyadic_order)
        return sig_x.gram(x)

    def test(self, m=1000, dyadic_order=0, sig_max_batch=100, perm_max_batch=1000, level=0.05):
        """Run the HSIC permutation test on the signature Gram matrices of x and y.

        Args:
            m: number of permutations (passed to ``HSIC.test`` as ``perms``).
            dyadic_order: passed through to ``sigkernel2.SigKernel2``.
            sig_max_batch: ``max_batch`` used when computing the Gram matrices.
            perm_max_batch: ``max_batch`` passed to ``HSIC.test``.
            level: significance level passed to ``HSIC.test`` as ``alpha``
                (default 0.05, the previously hard-coded value).

        Returns:
            Whatever ``HSIC.test`` returns.
        """
        self.sig_x = sigkernel2.SigKernel2(self.static, dyadic_order=dyadic_order)
        self.sig_y = sigkernel2.SigKernel2(self.static, dyadic_order=dyadic_order)

        KX = self.sig_x.gram(self.x, max_batch=sig_max_batch)
        KY = self.sig_y.gram(self.y, max_batch=sig_max_batch)

        self.HSIC = HSIC(KX, KY)
        return self.HSIC.test(alpha=level, perms=m, max_batch=perm_max_batch)

In [4]:
class SigHSICPlus():
    """HSIC test pairing a plain RBF kernel on flattened ``x`` with a signature kernel on ``y``.

    ``x`` is flattened to one vector per sample and compared with an RBF kernel
    whose bandwidth is the median pairwise squared distance between those
    vectors. ``y`` goes through ``sigkernel2.SigKernel2`` on an RBF static kernel
    whose bandwidth is the median pairwise squared distance between all of its
    path points.

    Args:
        x: tensor whose first dimension indexes samples.
        y: path tensor; last dimension treated as channels (presumably
            (batch, length, channels) — TODO confirm).
    """
    def __init__(self, x, y):
        # reshape rather than view so non-contiguous inputs are accepted too.
        self.x = x.reshape(x.shape[0], -1)
        self.y = y

        self.x_dist = self.median_dist2(self.x).cpu().item()
        self.y_dist = self.median_dist(self.y).cpu().item()
        self.y_static = sigkernel2.RBFKernel(self.y_dist)

    def normalize(self, x):
        """Divide every (path, channel) series by its maximum absolute value over dim 1.

        ``gram`` calls this by default; without it ``gram(x)`` raised AttributeError.
        """
        return x / x.abs().max(dim=1).values.unsqueeze(1)

    def median_dist(self, x):
        """Median pairwise squared distance between all points of all paths in ``x``.

        Every dimension except the last (channels) is pooled into the point set.
        """
        return self.median_dist2(x.reshape(-1, x.shape[-1]))

    def median_dist2(self, x):
        """Median of the pairwise squared Euclidean distances between the rows of 2-D ``x``."""
        x_norm = (x ** 2).sum(dim=1).unsqueeze(1)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>; clamp round-off negatives.
        squared_dist_matrix = (x_norm + x_norm.T - 2 * torch.mm(x, x.T)).clamp(min=0)

        n = squared_dist_matrix.size(0)
        # Strict upper triangle: each unordered pair once, no self-distances.
        triu_indices = torch.triu_indices(n, n, offset=1, device=x.device)
        return squared_dist_matrix[triu_indices[0], triu_indices[1]].median()

    def rbf_kernel(self, x, sigma):
        """RBF Gram matrix exp(-||a - b||^2 / sigma) between the rows of 2-D ``x``."""
        x_norm = (x ** 2).sum(dim=1).unsqueeze(1)
        # Clamp so floating-point round-off cannot yield negative squared
        # distances (which would push kernel values above 1).
        squared_distances = (x_norm + x_norm.T - 2 * torch.mm(x, x.T)).clamp(min=0)
        return torch.exp(-squared_distances / sigma)

    def gram(self, x, dyadic_order=0, normalize=True):
        """Signature-kernel Gram matrix of ``x`` with a median-heuristic RBF static kernel."""
        x = self.normalize(x) if normalize else x

        x_static = sigkernel2.RBFKernel(self.median_dist(x).cpu().item())
        sig_x = sigkernel2.SigKernel2(x_static, dyadic_order=dyadic_order)
        return sig_x.gram(x)

    def test(self, m=1000, dyadic_order=0, sig_max_batch=100, perm_max_batch=1000, level=0.05):
        """Run the HSIC permutation test with an RBF Gram matrix for x and a signature one for y.

        Args:
            m: number of permutations (passed to ``HSIC.test`` as ``perms``).
            dyadic_order: passed through to ``sigkernel2.SigKernel2``.
            sig_max_batch: ``max_batch`` used when computing the signature Gram matrix.
            perm_max_batch: ``max_batch`` passed to ``HSIC.test``.
            level: significance level passed to ``HSIC.test`` as ``alpha``
                (default 0.05, the previously hard-coded value).

        Returns:
            Whatever ``HSIC.test`` returns.
        """
        self.sig_y = sigkernel2.SigKernel2(self.y_static, dyadic_order=dyadic_order)

        KX = self.rbf_kernel(self.x, self.x_dist)
        KY = self.sig_y.gram(self.y, max_batch=sig_max_batch)

        self.HSIC = HSIC(KX, KY)
        return self.HSIC.test(alpha=level, perms=m, max_batch=perm_max_batch)

In [5]:
class RBFHSIC():
    """HSIC test with plain RBF kernels on the flattened inputs.

    Each sample (first dimension) is flattened to a single vector. The RBF
    bandwidth for each side is the median pairwise squared distance between
    its vectors.
    """
    def __init__(self, x, y):
        # reshape rather than view so non-contiguous inputs are accepted too.
        self.x = x.reshape(x.shape[0], -1)
        self.y = y.reshape(y.shape[0], -1)

        self.x_dist = self.median_dist(self.x)
        self.y_dist = self.median_dist(self.y)

    def test(self, m=1000, perm_max_batch=1000, level=0.05):
        """Run the HSIC permutation test on the RBF Gram matrices of x and y.

        Args:
            m: number of permutations (passed to ``HSIC.test`` as ``perms``).
            perm_max_batch: ``max_batch`` passed to ``HSIC.test``.
            level: significance level passed to ``HSIC.test`` as ``alpha``
                (default 0.05, the previously hard-coded value).

        Returns:
            Whatever ``HSIC.test`` returns.
        """
        KX = self.rbf_kernel(self.x, self.x_dist)
        KY = self.rbf_kernel(self.y, self.y_dist)
        self.HSIC = HSIC(KX, KY)
        return self.HSIC.test(alpha=level, perms=m, max_batch=perm_max_batch)

    def median_dist(self, x):
        """Median of the pairwise squared Euclidean distances between the rows of 2-D ``x``."""
        x_norm = (x ** 2).sum(dim=1).unsqueeze(1)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>; clamp round-off negatives.
        squared_dist_matrix = (x_norm + x_norm.T - 2 * torch.mm(x, x.T)).clamp(min=0)

        n = squared_dist_matrix.size(0)
        # Strict upper triangle: each unordered pair once, no self-distances.
        triu_indices = torch.triu_indices(n, n, offset=1, device=x.device)
        return squared_dist_matrix[triu_indices[0], triu_indices[1]].median()

    def rbf_kernel(self, x, sigma):
        """RBF Gram matrix exp(-||a - b||^2 / sigma) between the rows of 2-D ``x``."""
        x_norm = (x ** 2).sum(dim=1).unsqueeze(1)
        # Clamp so floating-point round-off cannot yield negative squared
        # distances (which would push kernel values above 1).
        squared_distances = (x_norm + x_norm.T - 2 * torch.mm(x, x.T)).clamp(min=0)
        return torch.exp(-squared_distances / sigma)

In [5]:
# Fourier experiment: for every `alfa`, draw `reps` samples with Fourier.sample,
# run every test in `tests`, and write one CSV row per (test, alfa, rep).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64
dyadic_order = 2
n = 30
perms = 1000
thin = 1
reps = 1000
alfas = [0, 0.25, 0.5, 0.75, 1.0]
normalize = False

torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# Test name -> callable(X, Y); its return value goes into the `result` column.
# `perms` and `normalize` are forwarded so the config above actually applies.
# NOTE(review): dCov/dCovMod .test() signatures are not visible here, so they
# keep their own default permutation count — confirm it matches `perms`.
tests = {
    "dCovHD": lambda x, y : dCovMod(x, y).test(),
    "dCov": lambda x, y : dCov(x, y).test(),
    "SigHSIC": lambda x, y : SigHSIC(x, y, normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "SigHSICAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y), normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "RBF HSIC": lambda x, y : RBFHSIC(x, y).test(m=perms)
}

with open("../speciale/rplots/fourier_basis_.csv", "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["test", "alfa", "run", "result"])

    for alf in alfas:
        print(f"rho: {alf}")
        torch.cuda.empty_cache()
        for i in range(reps):
            if i % 100 == 0:
                print(f"rep: {i}")
            X, Y = Fourier.sample(n, rho=alf, T=0.5, thin=thin)
            for t_name, test in tests.items():
                p_value = test(X, Y)
                writer.writerow([t_name, alf, i, p_value])
            

rho: 0
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 0.25
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 0.5
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 0.75
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 1.0
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900


In [7]:
# Linear-SDE experiment: for every `rho`, draw `reps` samples with
# LinearSDE.sample, run every test in `tests`, and write one CSV row per
# (test, rho, rep).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64
dyadic_order = 2
n = 30
perms = 1000
thin = 1
reps = 1000
rho = [0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]
normalize = False

torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# Test name -> callable(X, Y); its return value goes into the `result` column.
# `perms` and `normalize` are forwarded so the config above actually applies.
# NOTE(review): dCov/dCovMod .test() signatures are not visible here, so they
# keep their own default permutation count — confirm it matches `perms`.
tests = {
    "dCovHD": lambda x, y : dCovMod(x, y).test(),
    "dCov": lambda x, y : dCov(x, y).test(),
    "SigHSIC": lambda x, y : SigHSIC(x, y, normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "SigHSICAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y), normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "RBF HSIC": lambda x, y : RBFHSIC(x, y).test(m=perms)
}

with open("../speciale/rplots/hsic_linear_sde_.csv", "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["test", "rho", "run", "result"])

    for r in rho:
        print(f"rho: {r}")
        torch.cuda.empty_cache()
        for i in range(reps):
            if i % 200 == 0:
                print(f"rep: {i}")
            X, Y = LinearSDE.sample(n=n, rho=r, thin=thin)
            for t_name, test in tests.items():
                p_value = test(X, Y)
                writer.writerow([t_name, r, i, p_value])
            

rho: 0
rep: 0
rep: 200
rep: 400
rep: 600
rep: 800
rho: 0.25
rep: 0
rep: 200
rep: 400
rep: 600
rep: 800
rho: 0.5
rep: 0
rep: 200
rep: 400
rep: 600
rep: 800
rho: 0.75
rep: 0
rep: 200
rep: 400
rep: 600
rep: 800
rho: 1
rep: 0
rep: 200
rep: 400
rep: 600
rep: 800
rho: 1.25
rep: 0
rep: 200
rep: 400
rep: 600
rep: 800
rho: 1.5
rep: 0
rep: 200
rep: 400
rep: 600
rep: 800
rho: 1.75
rep: 0
rep: 200
rep: 400
rep: 600
rep: 800
rho: 2
rep: 0
rep: 200
rep: 400
rep: 600
rep: 800


In [9]:
# Harmonics experiment: for every `m` in `ms` and every (xi, gam, func)
# combination, draw `reps` samples with `sampler(...).sample`, run every test,
# and write one CSV row per (test, m, xi, gam, func, rep).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64
dyadic_order = 3
n = 30
p = 50
perms = 1000
l = 101  # NOTE(review): not used anywhere in this cell
reps = 1000
ms = [0, 2, 6, 8, 10]
normalize = True

torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

xi_dists = {
    "Normal": tdist.Normal(torch.tensor([0], dtype=dtype, device=device), torch.tensor([1], dtype=dtype, device=device))
}

gam_dists = {
    "Normal": tdist.Normal(torch.tensor([0], dtype=dtype, device=device), torch.tensor([1], dtype=dtype, device=device))
}

# Test name -> callable(X, Y); its return value goes into the `result` column.
# `perms` and `normalize` (True here, previously ignored) are forwarded so the
# config above actually applies.
# NOTE(review): dCov/dCovMod .test() signatures are not visible here, so they
# keep their own default permutation count — confirm it matches `perms`.
tests = {
    "dCovHD": lambda x, y : dCovMod(x, y).test(),
    "dCov": lambda x, y : dCov(x, y).test(),
    "SigHSIC": lambda x, y : SigHSIC(x, y, normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "SigHSICAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y), normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "RBF HSIC": lambda x, y : RBFHSIC(x, y).test(m=perms)
}

funcs = {
    "f(x)=x^3": lambda x : x.pow(3)
    #"f(x)=x^2": lambda x : x.pow(2)
    #,"f(x)=\sin(x)": lambda x : x.sin()
    #,"f(x)=\cos(x)": lambda x : x.cos()
}

with open("../speciale/rplots/hsic_harmonics-.csv", "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["test", "m", "xi_dist", "gam_dist", "f_name", "run", "result"])

    for m in ms:
        print(f"m: {m}")
        for xi_name, xi_d in xi_dists.items():
            for gam_name, gam_d in gam_dists.items():
                print(f"xi: {xi_name}, gam: {gam_name}")
                for f_name, func in funcs.items():
                    print(f"func: {f_name}")
                    new_sampler = sampler(xi_d, gam_d, func, p=p)
                    torch.cuda.empty_cache()
                    for i in range(reps):
                        if i % 100 == 0:
                            print(f"rep: {i}")
                        X, Y = new_sampler.sample(n, m, garbage=0)

                        for t_name, test in tests.items():
                            p_value = test(X, Y)
                            writer.writerow([t_name, m, xi_name, gam_name, f_name, i, p_value])

m: 0
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
m: 2
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
m: 6
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
m: 8
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
m: 10
xi: Normal, gam: Normal
func: f(x)=x^3
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900


In [4]:
# Linear-SDE (sample_sig) experiment: for every `rho`, draw `reps` samples with
# LinearSDE.sample_sig, run every test in `tests`, and write one CSV row per
# (test, rho, rep).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64
dyadic_order = 2
n = 30
perms = 1000
thin = 1
reps = 1000
rho = [0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]
normalize = False

torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# Test name -> callable(X, Y); its return value goes into the `result` column.
# `perms` and `normalize` are forwarded so the config above actually applies.
# NOTE(review): dCov/dCovMod .test() signatures are not visible here, so they
# keep their own default permutation count — confirm it matches `perms`.
tests = {
    "dCovHD": lambda x, y : dCovMod(x, y).test(),
    "dCov": lambda x, y : dCov(x, y).test(),
    "SigHSIC": lambda x, y : SigHSIC(x, y, normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "SigHSICAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y), normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    "RBF HSIC": lambda x, y : RBFHSIC(x, y).test(m=perms)
}

with open("../speciale/rplots/hsic_linear_sde2-.csv", "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["test", "rho", "run", "result"])

    for r in rho:
        print(f"rho: {r}")
        torch.cuda.empty_cache()
        for i in range(reps):
            if i % 200 == 0:
                print(f"rep: {i}")
            X, Y = LinearSDE.sample_sig(n=n, rho=r, thin=thin)
            for t_name, test in tests.items():
                p_value = test(X, Y)
                writer.writerow([t_name, r, i, p_value])
            

rho: 0
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 0.25
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 0.5
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 0.75
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 1
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 1.25
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 1.5
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 1.75
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 2
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900


In [54]:
# Nonlinear-SDE experiment: for every `theta`, draw `reps` samples with
# NonLinearSDE.sample, run every test in `tests`, and write one CSV row per
# (test, theta, rep).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64
dyadic_order = 1
n = 30
perms = 500
thin = 1
reps = 500
thetas = [0, 1.0, 2.0, 3.0, 4.0]
normalize = False

torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# Test name -> callable(X, Y); its return value goes into the `result` column.
# `perms` (500 here) is forwarded; previously the HSIC tests silently used
# their default of 1000 permutations.
# NOTE(review): dCov/dCovMod .test() signatures are not visible here, so they
# keep their own default permutation count — confirm it matches `perms`.
tests = {
    "dCovHD": lambda x, y : dCovMod(x, y).test(),
    "dCov": lambda x, y : dCov(x, y).test(),
    "SigHSICAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y), normalize=normalize).test(m=perms, dyadic_order=dyadic_order),
    #"SigHSICLinearAddT": lambda x, y : SigHSICLinear(LinearSDE.add_time(x), LinearSDE.add_time(y)).test(dyadic_order=dyadic_order),
    "RBF HSIC": lambda x, y : RBFHSIC(x, y).test(m=perms)
}

with open("../speciale/rplots/hsic_nonlinear_sde3.csv", "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["test", "theta", "run", "result"])

    for theta in thetas:
        print(f"theta: {theta}")
        torch.cuda.empty_cache()
        for i in range(reps):
            if i % 200 == 0:
                print(f"rep: {i}")
            X, Y = NonLinearSDE.sample(n=n, theta=theta, dt=0.005, thin=thin, device=device, dtype=dtype)
            for t_name, test in tests.items():
                p_value = test(X, Y)
                writer.writerow([t_name, theta, i, p_value])

theta: 0
rep: 0
rep: 200
rep: 400
theta: 1.0
rep: 0
rep: 200
rep: 400
theta: 2.0
rep: 0
rep: 200
rep: 400
theta: 3.0
rep: 0
rep: 200
rep: 400
theta: 4.0
rep: 0
rep: 200
rep: 400


In [6]:
# Fourier "extra" experiment: RBF kernel on x paired with a time-augmented
# signature kernel on y (SigHSICPlus); one CSV row per (test, alfa, rep).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64
dyadic_order = 2
n = 30
perms = 1000
thin = 1
reps = 1000
alfas = [0, 0.25, 0.5, 0.75, 1.0]
normalize = False  # NOTE(review): SigHSICPlus takes no normalize option; unused here

torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# Test name -> callable(X, Y); its return value goes into the `result` column.
# `perms` is forwarded so the config above actually applies.
tests = {
    "RBF HSIC + SigHSICAddT": lambda x, y : SigHSICPlus(x, LinearSDE.add_time(y)).test(m=perms, dyadic_order=dyadic_order),
}

with open("../speciale/rplots/fourier_basis_extra_2.csv", "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["test", "alfa", "run", "result"])

    for alf in alfas:
        print(f"rho: {alf}")
        torch.cuda.empty_cache()
        for i in range(reps):
            if i % 100 == 0:
                print(f"rep: {i}")
            X, Y = Fourier.sample(n, rho=alf, T=0.5, thin=thin)
            for t_name, test in tests.items():
                p_value = test(X, Y)
                writer.writerow([t_name, alf, i, p_value])
            

rho: 0
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 0.25
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 0.5
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 0.75
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900
rho: 1.0
rep: 0
rep: 100
rep: 200
rep: 300
rep: 400
rep: 500
rep: 600
rep: 700
rep: 800
rep: 900


In [6]:
# Robust nonlinear-SDE experiment: robust signature-kernel HSIC variants on
# NonLinearSDE.sample data; one CSV row per (test, theta, rep).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64
dyadic_order = 1
n = 30
perms = 500
thin = 1
reps = 500
thetas = [0, 1.0, 2.0, 3.0, 4.0]
normalize = False

torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# Test name -> callable(X, Y); its return value goes into the `result` column.
# `perms` (500 here) and `normalize` are forwarded; previously the tests
# silently used their default of 1000 permutations.
tests = {
    #"dCovHD": lambda x, y : dCovMod(x, y).test(),
    #"dCov": lambda x, y : dCov(x, y).test(),
    #"SigHSICAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y)).test(dyadic_order=dyadic_order),
    #"SigHSICLinearAddT": lambda x, y : SigHSICLinear(LinearSDE.add_time(x), LinearSDE.add_time(y)).test(dyadic_order=dyadic_order),
    #"RBF HSIC": lambda x, y : RBFHSIC(x, y).test()
    "RobustSigRBFAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y), normalize=normalize, robust=True).test(m=perms, dyadic_order=dyadic_order),
    "RobustSigLinAddT": lambda x, y : SigHSIC(LinearSDE.add_time(x), LinearSDE.add_time(y), normalize=normalize, robust=True, linear=True).test(m=perms, dyadic_order=dyadic_order),
}

with open("../speciale/rplots/hsic_nonlinear_sde_robust.csv", "w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["test", "theta", "run", "result"])

    for theta in thetas:
        print(f"theta: {theta}")
        torch.cuda.empty_cache()
        for i in range(reps):
            if i % 200 == 0:
                print(f"rep: {i}")
            X, Y = NonLinearSDE.sample(n=n, theta=theta, dt=0.005, thin=thin, device=device, dtype=dtype)
            for t_name, test in tests.items():
                p_value = test(X, Y)
                writer.writerow([t_name, theta, i, p_value])

theta: 0
rep: 0




rep: 200
rep: 400
theta: 1.0
rep: 0
rep: 200
rep: 400
theta: 2.0
rep: 0
rep: 200
rep: 400
theta: 3.0
rep: 0
rep: 200
rep: 400
theta: 4.0
rep: 0
rep: 200
rep: 400
