In [3]:
import torch
import sigkernel2
import sigkernel
import csv
import scipy
import timeit
import math
device = torch.cuda.device('cuda')

In [265]:
def generate(batch_size, length, dimension, device = torch.device('cpu')):
  random_walks = torch.randn(batch_size, length, dimension, dtype = torch.double, device = device) / math.sqrt(length)
  start = torch.zeros([batch_size, 1, dimension], device=device, dtype=torch.double)
  random_walks = torch.cat((start, random_walks), dim=1)
  random_walks = torch.cumsum(random_walks, dim=1)
  return random_walks

def median_distance(X, Y):
    A = X.shape[0]
    M = X.shape[1]
    N = Y.shape[1]
    Xs = torch.sum(X**2, dim=2)
    Ys = torch.sum(Y**2, dim=2)
    dist = -2.*torch.bmm(X, Y.permute(0,2,1))
    dist += torch.reshape(Xs,(A,M,1)) + torch.reshape(Ys,(A,1,N))
    return dist.view(A, -1).median()

In [354]:
X = generate(512, 50, 10, device = torch.device('cuda:0'))
Y = generate(512, 50, 10, device = torch.device('cuda:0'))
Z = generate(512, 50, 10, device = torch.device('cuda:0'))

X = 1 * torch.sin(0.5 * Y) + X

In [355]:
sig_x = sigkernel2.SigKernel2(sigkernel2.RBFKernel(median_distance(X, X).cpu().item()), 1)
sig_y = sigkernel2.SigKernel2(sigkernel2.RBFKernel(median_distance(Y, Y).cpu().item()), 1)
sig_z = sigkernel2.SigKernel2(sigkernel2.RBFKernel(median_distance(Z, Z).cpu().item()), 1)
sig_w = sigkernel2.SigKernel2(sigkernel2.RBFKernel(median_distance(Z, Z).cpu().item()), 1)

In [356]:
KX = sig_x.gram(X, max_batch=64)
KY = sig_y.gram(Y, max_batch=64)
KZ = sig_z.gram(Z, max_batch=64)

In [357]:
test = HSIC(KX, KY)
test.test(perms=3000, max_batch=1000)


{'critical_value': 6456.121145056668,
 'gamma_p_value': 5.692797108986429e-109,
 'perm_p_value': 0.0,
 'mc_p_value': 0.0}

In [311]:
class HSIC:
    def __init__(self, X, Y):
        self.n = X.shape[0]
        H = -torch.ones((self.n, self.n), device=X.device, dtype=X.dtype) / self.n
        H[range(self.n), range(self.n)] = 1 - 1 / self.n
        
        self.X = X @ H
        self.Y = Y @ H
        
    def test(self, alpha=0.05, perms=1000, max_batch=None):
        critical_value = torch.einsum('ij,ji->', self.X, self.Y) / self.n
        return {
            "critical_value": critical_value.cpu().item(),
            "gamma_p_value": self._gamma_approx(critical_value.cpu().item(), alpha),
            "perm_p_value": self._perm(critical_value, alpha, perms, max_batch).cpu().item(),
            "mc_p_value": self._montecarlo(critical_value, alpha, perms, max_batch).cpu().item()
        }
    
    def _gamma_approx(self, value=None, alpha=0.05):
        mu, var = self._empirical_moments()
        loc = mu**2 / var
        scale = var / mu
        if value is None:
            return scipy.stats.distributions.gamma.ppf(1 - alpha, a = loc.cpu(), scale = scale.cpu())
        return scipy.stats.distributions.gamma.sf(value, a = loc.cpu(), scale = scale.cpu())
    
    def _empirical_moments(self):
        mu_X = torch.trace(self.X)
        mu_Y = torch.trace(self.Y)
        
        var_X = torch.einsum('ij,ji->', self.X, self.X)
        var_Y = torch.einsum('ij,ji->', self.Y, self.Y)
        
        return mu_X * mu_Y / self.n**2, 2 * var_X * var_Y / self.n**4
    
    def _perm(self, value=None, alpha=0.05, m=100, max_batch=None):
        max_batch = m if max_batch is None else max_batch
        
        mb_size = min(m, max_batch)
        bm = -(-m // mb_size)
        
        stats = torch.zeros([m], device=self.X.device, dtype=self.X.dtype)
        
        for i in range(bm):
            mb_size_i = m - mb_size * (bm - 1) if i == bm - 1 else mb_size
            start = i * mb_size
            stop = i * mb_size + mb_size_i
            perms = torch.stack([torch.randperm(self.n, device=self.X.device) for _ in range(mb_size_i)])
            Y_perm = torch.stack([self.Y[p][:, p] for p in perms])
            tmp = torch.einsum('ij,kji->k', self.X, Y_perm) / self.n
            stats[start:stop] = tmp
            
        if value is None:
            return stats.abs().quantile(1 - alpha, 0)
            
        return (stats.abs() > value).double().mean()
    
    def _montecarlo(self, value=None, alpha=0.05, m=100, max_batch=None):
        max_batch = m if max_batch is None else max_batch
        
        mb_size = min(m, max_batch)
        bm = -(-m // mb_size)
        
        stats = torch.zeros([m], device=self.X.device, dtype=self.X.dtype)
        eig_X = torch.linalg.eigvalsh(self.X).view(self.n, -1)
        eig_Y = torch.linalg.eigvalsh(self.Y).view(-1, self.n)
        
        for i in range(bm):
            mb_size_i = m - mb_size * (bm - 1) if i == bm - 1 else mb_size
            start = i * mb_size
            stop = i * mb_size + mb_size_i
            z = torch.randn([mb_size_i, self.n, self.n], device=self.X.device, dtype=self.X.dtype).pow(2)
            z = z * eig_X * eig_Y           
            stats[start:stop] = z.mean(dim=[1,2])
            
        if value is None:
            return stats.abs().quantile(1 - alpha, 0)
            
        return (stats.abs() > value).double().mean()        

In [296]:
class CHSIC:
    def __init__(self, X, Y, Z):
        self.n = X.shape[0]
        self.X = self.sym(self.center_gram(X * Z))
        self.Y = self.sym(self.center_gram(Y))
        self.Z = self.sym(self.center_gram(Z))
        
    def sym(self, x):
        return 0.5 * (x + x.T)
    
    def center_gram(self, x):
        n = x.shape[0]
        csums = x.sum(dim=0)
        tsums = csums.sum()  
        return x - (csums[None, :] + csums[:, None]) / n + (tsums / n ** 2)
    
    def eig_trunc(self, x):
        vals, vecs = torch.linalg.eigh(x)
        return vals, vecs

    def eigvals_trunc(self, x):
        vals = torch.linalg.eigvalsh(x)
        return vals
    
    def fit_eps(self, Kx, Kz, maxit=10000, reltol=1e-3, lr=0.1):
        vals, vecs = self.eig_trunc(Kx)
        vals[vals < 0] = 0
        vals = vals.sqrt()
        vecs = 2 * math.sqrt(self.n) * vecs * vals / vals.max()
        sig = torch.randn(1, requires_grad=True, device=Kx.device, dtype=Kx.dtype)
        scale = torch.randn(1, requires_grad=True, device=Kx.device, dtype=Kx.dtype)
        sig = torch.nn.Parameter(sig)
        scale = torch.nn.Parameter(scale)
        
        optimizer = torch.optim.Adam([sig, scale], lr=lr)
        
        prev_sig = sig.detach().exp()
        prev_scale = sig.detach().exp()
        
        for _ in range(maxit):
            optimizer.zero_grad()
            M = torch.exp(scale) * Kz + torch.exp(sig) * torch.eye(self.n, device=Kz.device, dtype=Kz.dtype)
            dist = torch.distributions.MultivariateNormal(loc=0 * vecs.mean(dim=0), covariance_matrix=M)
            loss = -dist.log_prob(vecs).mean()
            loss.backward()
            optimizer.step()
            
            with torch.no_grad():               
                if torch.abs(sig.exp() - prev_sig) / sig.exp() < reltol and torch.abs(scale.exp() - prev_scale) / scale.exp() < reltol:
                    break
                prev_sig = sig.detach().exp()
                prev_scale = scale.detach().exp()
                
        return sig.detach().exp(), scale.detach().exp()
        
    
    def test(self, m=1000, max_batch=None, alpha=0.05):
        I = torch.eye(self.n, device=self.X.device, dtype=self.X.dtype)
        x_eps, x_scale = self.fit_eps(self.X, self.Z, lr=2)
        y_eps, y_scale = self.fit_eps(self.Y, self.Z, lr=2)
        
        Rx = x_eps * torch.linalg.pinv((x_scale * self.Z + x_eps * I), hermitian=True)
        Ry = y_eps * torch.linalg.pinv((y_scale * self.Z + y_eps * I), hermitian=True)
        
        #R = I - self.Z @ torch.inverse(self.Z + eps * I)
              
        XZ = torch.einsum('ik,kl,jl->ij', Rx, self.X, Rx)
        YY = torch.einsum('ik,kl,jl->ij', Ry, self.Y, Ry)
        
        critical_value = (XZ * YY).sum() / self.n
        
        return { 
            "critical_value": critical_value.cpu().item(),
            "mc_p_value": self._null_dist(XZ, YY, critical_value, alpha, m, max_batch).cpu().item()
        }
        
    
    def _null_dist(self, KXZ, KYZ, value=None, alpha=0.05, m=1000, max_batch=None):
        try:
            val_x, eig_x = torch.linalg.eigh(KXZ)
            val_x[val_x < 0] = 0
        except:
            val_x, eig_x = torch.linalg.eig(KXZ)
            
        try:
            val_y, eig_y = torch.linalg.eigh(KYZ)
            val_y[val_y < 0] = 0
        except:
              val_y, eig_y = torch.linalg.eig(KYZ)      
        
        vec_x = eig_x * val_x.sqrt()
        vec_y = eig_y * val_y.sqrt()

        ww = (vec_x.unsqueeze(2) * vec_y.unsqueeze(1)).reshape(self.n, -1)

        ww_prod = ww @ ww.T if ww.shape[1] > self.n else ww.T @ ww
        
        try:
            eig_vals = torch.linalg.eigvalsh(ww_prod)
            eig_vals[eig_vals < 0] = 0
        except:
            eig_vals = torch.linalg.eigvals(ww_prod)
        
        max_batch = m if max_batch is None else max_batch
        
        mb_size = min(m, max_batch)
        bm = -(-m // mb_size)
        
        stats = torch.zeros([m], device=self.X.device, dtype=self.X.dtype)
        
        for i in range(bm):
            mb_size_i = m - mb_size * (bm - 1) if i == bm - 1 else mb_size
            start = i * mb_size
            stop = i * mb_size + mb_size_i
            z = torch.randn([mb_size_i, self.n], device=self.X.device, dtype=self.X.dtype).pow(2)
            z = z * eig_vals
            stats[start:stop] = z.mean(dim=1)
            
        if value is None:
            return stats.abs().quantile(1 - alpha, 0)
            
        return (stats.abs() > value).double().mean()

In [258]:
test = CHSIC(KX, KY, KZ)

In [259]:
test.test(m=2000)

{'critical_value': 2402.2684856673623, 'mc_p_value': 0.428}

In [154]:
CX = test.X
CY = test.Y
CZ = test.Z

In [57]:
class epsfinder():
    def __init__(self, x, z):
        self.eig_z = self.eigs(z)
        self.eig_x = self.eigs(x)
        self.K = x
        self.device = x.device
        self.dtype = x.dtype
        self.n = x.shape[0]
        self.sig = torch.nn.Parameter(torch.randn(1, requires_grad=True, device=self.device, dtype=self.dtype))
        
    def forward(self):
        M = self.K + torch.exp(self.sig) * torch.eye(self.n, device=self.device, dtype=self.dtype)
        dist = torch.distributions.MultivariateNormal(loc=self.eig_z.mean(dim=0), covariance_matrix=M)
        return -dist.log_prob(self.eig_z.T).mean()
        
    def eigs(self, x):
        val, vec = torch.linalg.eig(0.5 * (x + x.T))
        return vec.real * val.real

In [55]:
eps = epsfinder(CZ, CY)
optimizer = torch.optim.Adam([eps.sig], lr=0.1)

In [56]:
num_iterations = 100000

with torch.no_grad():
    sig_prev = eps.sig.clone()

for i in range(num_iterations):
    optimizer.zero_grad()
    loss = eps.forward()
    loss.backward()
    optimizer.step()

    if i % 25:
        print(f"Loss: {loss.item()}, Sig: {eps.sig.exp().cpu().item()}")

Loss: 505.2293297857577, Sig: 0.8823157194190312
Loss: 482.0843100448312, Sig: 0.7983830489429063
Loss: 459.07723777904215, Sig: 0.7224606618141663
Loss: 436.2179810352863, Sig: 0.6537897531367733
Loss: 413.5170949995638, Sig: 0.5916827461642138
Loss: 390.9858776784415, Sig: 0.5355166762080917
Loss: 368.6364299002704, Sig: 0.4847271842145996
Loss: 346.4817196305238, Sig: 0.43880306401581815
Loss: 324.53565049978613, Sig: 0.3972813124189331
Loss: 302.81313431394403, Sig: 0.35974263597765466
Loss: 281.3301671468446, Sig: 0.32580737253922
Loss: 260.1039083986085, Sig: 0.2951317895167745
Loss: 239.15276192983066, Sig: 0.2674047243367253
Loss: 218.49645804398995, Sig: 0.24234453568674283
Loss: 198.1561346770613, Sig: 0.2196963370723864
Loss: 178.1544156524829, Sig: 0.19922948680600455
Loss: 158.51548325761405, Sig: 0.18073531092519532
Loss: 139.26514167978212, Sig: 0.16402503769200746
Loss: 120.43086699101961, Sig: 0.14892792427840845
Loss: 102.041838377277, Sig: 0.1352895580166842
Loss: 84

KeyboardInterrupt: 

In [291]:
x = torch.randn([512, 4], device=torch.device('cuda:0'), dtype=torch.double).unsqueeze(0)
y = torch.randn([512, 4], device=torch.device('cuda:0'), dtype=torch.double).unsqueeze(0)
z = torch.randn([512, 4], device=torch.device('cuda:0'), dtype=torch.double).unsqueeze(0)

x = z**2 + x
y = z**3 + y


In [292]:
rbf_x = sigkernel2.RBFKernel(median_distance(x, x))
rbf_y = sigkernel2.RBFKernel(median_distance(y, y))
rbf_z = sigkernel2.RBFKernel(median_distance(z, z))

In [293]:
kx = rbf_x.batch_kernel(x, x).squeeze(0)
ky = rbf_y.batch_kernel(y, y).squeeze(0)
kz = rbf_z.batch_kernel(z, z).squeeze(0)

In [295]:
test = CHSIC(kx, kx, kz)
test.test(m=2000)

{'critical_value': 6.769860165464699, 'mc_p_value': 0.0}