In [1]:
import torch
import sigkernel2
import sigkernel
import csv
import scipy
import timeit
import math
device = torch.cuda.device('cuda')

In [2]:
def generate(batch_size, length, dimension, device = torch.device('cpu')):
  random_walks = torch.randn(batch_size, length, dimension, dtype = torch.double, device = device) / math.sqrt(length)
  start = torch.zeros([batch_size, 1, dimension], device=device, dtype=torch.double)
  random_walks = torch.cat((start, random_walks), dim=1)
  random_walks = torch.cumsum(random_walks, dim=1)
  return random_walks

In [18]:
X = generate(512, 10, 7, device = torch.device('cuda:0'))
Y = generate(512, 10, 7, device = torch.device('cuda:0'))
Z = generate(512, 10, 7, device = torch.device('cuda:0'))

W = Z
W[:,:,0:1] = X[:,:,0:1]

In [37]:
def median_distance(X, Y):
    A = X.shape[0]
    M = X.shape[1]
    N = Y.shape[1]
    Xs = torch.sum(X**2, dim=2)
    Ys = torch.sum(Y**2, dim=2)
    dist = -2.*torch.bmm(X, Y.permute(0,2,1))
    dist += torch.reshape(Xs,(A,M,1)) + torch.reshape(Ys,(A,1,N))
    return dist.view(A, -1).median()

sig_x = sigkernel2.SigKernel2(sigkernel2.RBFKernel(median_distance(X, X).cpu().item()), 1)
sig_y = sigkernel2.SigKernel2(sigkernel2.RBFKernel(median_distance(Y, Y).cpu().item()), 1)
sig_z = sigkernel2.SigKernel2(sigkernel2.RBFKernel(median_distance(Z, Z).cpu().item()), 1)
sig_w = sigkernel2.SigKernel2(sigkernel2.RBFKernel(median_distance(Z, Z).cpu().item()), 1)

In [20]:
KX = sig_x.gram(X, max_batch=64)
KY = sig_y.gram(Y, max_batch=64)
KZ = sig_z.gram(Z, max_batch=64)
KW = sig_w.gram(W, max_batch=64)

In [6]:
class HSIC:
    def __init__(self, X, Y):
        self.n = X.shape[0]
        H = -torch.ones((self.n, self.n), device=X.device, dtype=X.dtype) / self.n
        H[range(self.n), range(self.n)] = 1 - 1 / self.n
        
        self.X = X @ H
        self.Y = Y @ H
        
    def test(self, alpha=0.05, perms=1000, max_batch=None):
        critical_value = torch.einsum('ij,ji->', self.X, self.Y) / self.n
        return {
            "critical_value": critical_value.cpu().item(),
            "gamma_p_value": self._gamma_approx(critical_value.cpu().item(), alpha),
            "perm_p_value": self._perm(critical_value, alpha, perms, max_batch).cpu().item(),
            "mc_p_value": self._montecarlo(critical_value, alpha, perms, max_batch).cpu().item()
        }
    
    def _gamma_approx(self, value=None, alpha=0.05):
        mu, var = self._empirical_moments()
        loc = mu**2 / var
        scale = var / mu
        if value is None:
            return scipy.stats.distributions.gamma.ppf(1 - alpha, a = loc.cpu(), scale = scale.cpu())
        return scipy.stats.distributions.gamma.sf(value, a = loc.cpu(), scale = scale.cpu())
    
    def _empirical_moments(self):
        mu_X = torch.trace(self.X)
        mu_Y = torch.trace(self.Y)
        
        var_X = torch.einsum('ij,ji->', self.X, self.X)
        var_Y = torch.einsum('ij,ji->', self.Y, self.Y)
        
        return mu_X * mu_Y / self.n**2, 2 * var_X * var_Y / self.n**4
    
    def _perm(self, value=None, alpha=0.05, m=100, max_batch=None):
        max_batch = m if max_batch is None else max_batch
        
        mb_size = min(m, max_batch)
        bm = -(-m // mb_size)
        
        stats = torch.zeros([m], device=self.X.device, dtype=self.X.dtype)
        
        for i in range(bm):
            mb_size_i = m - mb_size * (bm - 1) if i == bm - 1 else mb_size
            start = i * mb_size
            stop = i * mb_size + mb_size_i
            perms = torch.stack([torch.randperm(self.n, device=self.X.device) for _ in range(mb_size_i)])
            Y_perm = torch.stack([self.Y[p][:, p] for p in perms])
            tmp = torch.einsum('ij,kji->k', self.X, Y_perm) / self.n
            stats[start:stop] = tmp
            
        if value is None:
            return stats.abs().quantile(1 - alpha, 0)
            
        return (stats.abs() > value).double().mean()
    
    def _montecarlo(self, value=None, alpha=0.05, m=100, max_batch=None):
        max_batch = m if max_batch is None else max_batch
        
        mb_size = min(m, max_batch)
        bm = -(-m // mb_size)
        
        stats = torch.zeros([m], device=self.X.device, dtype=self.X.dtype)
        eig_X = torch.linalg.eigvalsh(self.X).view(self.n, -1)
        eig_Y = torch.linalg.eigvalsh(self.Y).view(-1, self.n)
        
        for i in range(bm):
            mb_size_i = m - mb_size * (bm - 1) if i == bm - 1 else mb_size
            start = i * mb_size
            stop = i * mb_size + mb_size_i
            z = torch.randn([mb_size_i, self.n, self.n], device=self.X.device, dtype=self.X.dtype).pow(2)
            z = z * eig_X * eig_Y           
            stats[start:stop] = z.mean(dim=[1,2])
            
        if value is None:
            return stats.abs().quantile(1 - alpha, 0)
            
        return (stats.abs() > value).double().mean()        

In [7]:
class CHSIC:
    def __init__(self, X, Y, Z):
        self.n = X.shape[0]
        H = -torch.ones((self.n, self.n), device=X.device, dtype=X.dtype) / self.n
        H[range(self.n), range(self.n)] = 1 - 1 / self.n
        self.X2 = H @ X @ H
        self.X = H @ (X * Z) @ H
        self.Y = H @ Y @ H
        self.Z = H @ Z @ H
    
    def test(self, eps=0.001, m=1000, max_batch=None, alpha=0.05):
        I = torch.eye(self.n, device=self.X.device, dtype=self.X.dtype)
        R = I - self.Z @ torch.inverse(self.Z + eps * I)
              
        XZ = torch.einsum('ik,kl,lj->ij', R, self.X, R.T)
        YY = torch.einsum('ik,kl,lj->ij', R, self.Y, R.T)
        
        critical_value = torch.einsum('ij,ji->', XZ, YY) / self.n
        
        return { 
            "critical_value": critical_value.cpu().item(),
            "mc_p_value": self._null_dist(XZ, YY, critical_value, alpha, m, max_batch).cpu().item()
        }
        
    
    def _null_dist(self, KXZ, KYZ, value=None, alpha=0.05, m=1000, max_batch=None):
        try:
            val_x, eig_x = torch.linalg.eigh(KXZ)
            if (val_x < 0).any():
                raise ValueError("Negative eigenvalues")
        except:
            val_x, eig_x = torch.linalg.eig(KXZ)
            
        try:
            val_y, eig_y = torch.linalg.eigh(KYZ)
            if (val_y < 0).any():
                raise ValueError("Negative eigenvalues")
        except:
              val_y, eig_y = torch.linalg.eig(KYZ)      
        
        vec_x = eig_x * val_x.sqrt()
        vec_y = eig_y * val_y.sqrt()

        ww = (vec_x.unsqueeze(2) * vec_y.unsqueeze(1)).reshape(self.n, -1)

        ww_prod = ww @ ww.T if ww.shape[1] > self.n else ww.T @ ww
        
        try:
            eig_vals = torch.linalg.eigvalsh(ww_prod)
        except:
            eig_vals = torch.linalg.eigvals(ww_prod)
        
        max_batch = m if max_batch is None else max_batch
        
        mb_size = min(m, max_batch)
        bm = -(-m // mb_size)
        
        stats = torch.zeros([m], device=self.X.device, dtype=self.X.dtype)
        
        for i in range(bm):
            mb_size_i = m - mb_size * (bm - 1) if i == bm - 1 else mb_size
            start = i * mb_size
            stop = i * mb_size + mb_size_i
            z = torch.randn([mb_size_i, self.n], device=self.X.device, dtype=self.X.dtype).pow(2)
            z = z * eig_vals
            stats[start:stop] = z.mean(dim=1)
            
        if value is None:
            return stats.abs().quantile(1 - alpha, 0)
            
        return (stats.abs() > value).double().mean()

In [29]:
test = CHSIC(KX, KX, KZ)

In [32]:
test.test(eps=0.001, m=1000)

{'critical_value': 1.8805678261424607e-13, 'mc_p_value': 0.357}

In [25]:
test = HSIC(KX, KY)

In [26]:
test.test(perms=2000)

{'critical_value': 1748.738515468685,
 'gamma_p_value': 0.21396013931400976,
 'perm_p_value': 0.4805,
 'mc_p_value': 0.234}

In [87]:
x = torch.randn([512, 10], device=torch.device('cuda:0'), dtype=torch.double).unsqueeze(0)
y = torch.randn([512, 10], device=torch.device('cuda:0'), dtype=torch.double).unsqueeze(0)
z = torch.randn([512, 10], device=torch.device('cuda:0'), dtype=torch.double).unsqueeze(0)
w = z

w[:,0:5] = x[:,0:5]

In [88]:
rbf_x = sigkernel2.RBFKernel(median_distance(x, x))
rbf_y = sigkernel2.RBFKernel(median_distance(y, y))
rbf_z = sigkernel2.RBFKernel(median_distance(z, z))
rbf_w = sigkernel2.RBFKernel(median_distance(w, w))

In [89]:
kx = rbf_x.batch_kernel(x, x).squeeze(0)
ky = rbf_y.batch_kernel(y, y).squeeze(0)
kz = rbf_z.batch_kernel(z, z).squeeze(0)
kw = rbf_w.batch_kernel(w, w).squeeze(0)

In [98]:
test = CHSIC(kx, kw, kz)

In [99]:
test.test(eps=10, m=1000)

{'critical_value': 0.21022597516140074, 'mc_p_value': 0.0}