In [1]:
import jax
import jax.numpy as jnp
import jax.random as random
import os
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

gpu


In [2]:
Kxx = jnp.array([[4,1,6],[1,4,5,],[6,5,4]])
Kxx

Array([[4, 1, 6],
       [1, 4, 5],
       [6, 5, 4]], dtype=int32)

In [3]:
Kyy = jnp.array([[1,2],[2,1]])
Kyy

Array([[1, 2],
       [2, 1]], dtype=int32)

In [4]:
Kxy = jnp.array([[1,2],[3,4],[5,6]])
Kxy

Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

In [5]:
tKxx = Kxx -jnp.diag(jnp.diag(Kxx))
tKxx

Array([[0, 1, 6],
       [1, 0, 5],
       [6, 5, 0]], dtype=int32)

In [6]:
tKyy = Kyy -jnp.diag(jnp.diag(Kyy))
tKyy

Array([[0, 2],
       [2, 0]], dtype=int32)

In [7]:
m = Kxx.shape[0]
n = Kyy.shape[0]

In [8]:
def compute_mmd_sq(Kxx, Kyy, Kxy, m, n):
    term1 = jnp.sum(Kxx - jnp.diag(jnp.diag(Kxx))) / (m * (m - 1))
    term2 = jnp.sum(Kyy - jnp.diag(jnp.diag(Kyy))) / (n * (n - 1))
    term3 = -2 * jnp.sum(Kxy) / (m * n)

    return term1 + term2 + term3

In [9]:
compute_mmd_sq(Kxx, Kyy, Kxy, m, n)

Array(-1., dtype=float32)

In [10]:
def H(i,j,k,l) :
    return Kxx[i,j] + Kyy[k,l] - (Kxy[i,k]+Kxy[i,l]+Kxy[j,k]+Kxy[j,l])/2

In [11]:
def compute_moments(Kxx, Kyy, Kxy):
    m = Kxx.shape[0]
    n = Kyy.shape[0]
    one_m = jnp.ones(m)
    one_n = jnp.ones(n)
    
    tKxx = Kxx - jnp.diag(jnp.diag(Kxx))
    tKyy = Kyy - jnp.diag(jnp.diag(Kyy))
    return [
        0,
        jnp.trace(tKxx.T @ tKxx), # C1
        jnp.sum(tKxx.T @ tKxx), # C2
        jnp.sum(tKxx) * jnp.sum(tKxx), # C3 
        jnp.sum(tKxx) * jnp.sum(tKyy), # C4
        jnp.sum(tKxx @ Kxy), # C5
        (jnp.sum(Kxx) * jnp.sum(Kxy))
        -jnp.sum(jnp.diag(Kxx) * jnp.sum(Kxy))-jnp.sum(Kxx@Kxy)
        +jnp.sum(jnp.diag(Kxx)@Kxy@one_n), # C6
        jnp.sum(Kxy @ tKyy), # C7
        jnp.sum(Kxy @ Kyy), # C8 
        jnp.trace(Kxy.T @ Kxy), # C9
        jnp.sum((Kxy.T @ Kxy) -jnp.diag(jnp.diag((Kxy.T @ Kxy)))), # C10
        jnp.sum((Kxy @ Kxy.T) -jnp.diag(jnp.diag((Kxy @ Kxy.T)))), # C11
        (jnp.sum(Kxy) * jnp.sum(Kxy)) - jnp.sum((one_m.T @ Kxy)**2) 
        - jnp.sum((Kxy @  one_n)**2) + jnp.sum(Kxy ** 2), # C12
        jnp.trace(tKyy.T @ tKyy), # C13
        jnp.sum(tKyy @ tKyy), # C14
        jnp.sum(tKyy) * jnp.sum(tKyy) # C15 
    ]

In [12]:
C = compute_moments(Kxx, Kyy, Kxy)
C

[0,
 Array(124, dtype=int32),
 Array(206, dtype=int32),
 Array(576, dtype=int32),
 Array(96, dtype=int32),
 Array(184, dtype=int32),
 Array(320., dtype=float32),
 Array(42, dtype=int32),
 Array(63, dtype=int32),
 Array(91, dtype=int32),
 Array(88, dtype=int32),
 Array(134, dtype=int32),
 Array(128., dtype=float32),
 Array(8, dtype=int32),
 Array(8, dtype=int32),
 Array(16, dtype=int32)]

In [13]:
# 
result = 0.0

for i in range(m):
    for j in range(m):
        if i != j:
            for v in range(m):
                if i != v:
                    for k in range(n):
                        for l in range(n):
                            if k != l:
                                for u in range(n):
                                    for t in range(n):
                                        if u != t:
                                            result += Kxx[i, j] * Kxx[i, v]

print("Result:", result)
print(C[2] * (n**2) * ((n-1)**2))

Result: 824.0
824


In [16]:
result = 0.0
for i in range(m):
    for j in range(m):
        if i != j:
            for v in range(m):
                if i != v:
                    for k in range(n):
                        for l in range(n):
                            if k != l:
                                for u in range(n):
                                    for t in range(n):
                                        if u != t:
                                            result += Kxx[i, j] * Kxy[v, k]

print(result)
print(C[6] * n * (n-1) ** 2)

640.0
640.0


In [20]:
result = 0.0
for i in range(m):
    for j in range(m):
        if i != j:
            for v in range(m):
                for q in range(m):
                    if v != q:
                        for k in range(n):
                            for l in range(n):
                                if k != l:
                                    for u in range(n):
                                        if u != k:
                                            result += Kxx[i, j] * Kxx[v, q]

print("Result:", result)
print(C[3] * n * (n-1)**2)

Result: 1152.0
1152


In [21]:
result = 0.0
for i in range(m):
    for j in range(m):
        if i != j:
            for v in range(m):
                for q in range(m):
                    if v != q:
                        for k in range(n):
                            for l in range(n):
                                if k != l:
                                    for u in range(n):
                                        if u != k:
                                            result += Kxx[i, j] * Kyy[k, l]

print("Result:", result)
print(C[4] * m * (m-1))

Result: 576.0
576


In [43]:
result = 0.0
for i in range(m):
    for j in range(m):
        for v in range(m):
            for q in range(m):
                if i != j and v != q:
                    for k in range(n):
                        for l in range(n):
                            if (k != l) & (k != u):
                                for u in range(n):
                                     result += Kxx[v, q] * Kxy[j, k]

print("Result:", result)
print(C[6] * (m-1) * (n-1)**2)

Result: 864.0
640.0


In [25]:
result = 0.0
for i in range(m):
    for j in range(m):
        if i != j:
            for v in range(m):
                for q in range(m):
                    if v != q:
                        for k in range(n):
                            for l in range(n):
                                if k != l:
                                    for u in range(n):
                                        if u != k:
                                            result += Kxy[i, l] * Kyy[k, l]

print("Result:", result)
print(C[7] * m * (m-1) **2 * (n-1))

Result: 504.0
504


In [34]:
result = 0.0
for i in range(m):
    for j in range(m):
        for v in range(m):
            if i != j and i != v:
                for k in range(n):
                    for l in range(n):
                        for u in range(n):
                            if k != l and k != u:
                                result += Kxx[i, j] * Kxy[v, k]

print("Result:", result)
print(C[6] * (n-1)**2)

Result: 320.0
320.0


In [35]:
result = 0.0
for i in range(m):
    for j in range(m):
        for v in range(m):
            for q in range(m):
                if i != j and v != q:
                    for k in range(n):  
                        for l in range(n):
                            if k != l:
                                result += Kxx[i, j] * Kxy[v, k]

print("Result:", result)
print(C[6] * (m-1) * (n-1))

Result: 1008.0
640.0


In [37]:
1008 / 640

1.575

In [36]:
result = 0.0
for i in range(m):
    for j in range(m):
        for v in range(m):
            if i != j and i != v:
                for k in range(n):  
                    for l in range(n):
                        if k != l:
                            result += Kxx[i, j] * Kxy[v, k]

print("Result:", result)
print(C[6] * (n-1))

Result: 320.0
320.0


In [None]:
# C4 = 96
result = 0.0
for i in range(m) :
    for j in range(m) :
        if i != j :
            for k in range(n) :
                for l in range(n) :
                    if k != l :
                        result += Kxx[i,j] * Kyy[k,l]
result

In [None]:
# c5 184
result = 0.0
for i in range(m) :
    for j in range(m) :
        if i != j :
            for k in range(n) :
                        result += Kxx[i,j] * Kxy[i,k]
result

In [None]:
# c6 504
result = 0.0
for i in range(m) :
    for v in range(m) :
        for j in range(m) :
            if i != j :
                for k in range(n) :
                            result += Kxx[i,j] * Kxy[v,k]
result

In [None]:
# c7 42
result = 0.0
for i in range(m) :
    for k in range(n) :
        for l in range(n) :
            if k != l:
                result += Kxy[i,l] * Kyy[k,l]
result

In [None]:
# C8 63
result = 0.0
for i in range(m) :
    for k in range(n) :
        for u in range(n) :
            for l in range(n) :
                if k != l :
                    result += Kxy[i,k] * Kyy[l,u]
result

In [None]:
# c9 91
result = 0.0
for i in range(m) :
    for k in range(n) :
        result += Kxy[i,k] * Kxy[i,k]
result

In [None]:
# c10 88
result = 0.0
for i in range(m) :
    for k in range(n) :
        for l in range(n) :
            if k != l :
                result += Kxy[i,k] * Kxy[i,l]
result

In [None]:
# c11 134
result = 0.0
for i in range(m) :
    for j in range(m) :
        if i != j :
            for k in range(n) :
                result += Kxy[i,k] * Kxy[j,k]
result

In [None]:
# c12 128
result = 0.0
for i in range(m) :
    for j in range(m) :
        if i != j :
            for k in range(n) :
                for l in range(n) :
                    if k != l :
                        result += Kxy[i,k] * Kxy[j,l]
result

In [None]:
# c13 8
result = 0.0
for k in range(n) :
    for l in range(n) :
        if k != l :
            result += Kyy[k,l] * Kyy[k,l]
result

In [None]:
# c14 8
result = 0.0
for k in range(n) :
    for u in range(n) :
        for l in range(n) :
            if k != l and k != u :
                result += Kyy[k,l] * Kyy[k,u]
result

In [None]:
# c15 16
result = 0.0
for k in range(n) :
    for l in range(n) :
        if k != l :
            for u in range(n) :
                for t in range(n) :
                    if u != t  :
                        result += Kyy[k,l] * Kyy[u,t]
result

In [None]:
def calc_xi(coefficients, mmd2):
    xi_value = sum(coefficients)
    return xi_value - mmd2 

In [None]:
mm = m * (m-1) 
nn = n * (n-1)
mn = m * (n-1)
nm = n * (m-1)

In [None]:
[C[2]/(mm*(m-1)), 2*C[4]/(mm * nn), -2*C[5]/(mm * n), -2*C[6]/(mm*(m-1)*n), 
        -4*C[8]/(m*n*nn), C[10]/(m*nn), 3*C[12]/(mm * nn), C[15]/(nn * nn)]

In [None]:
result = 0.0
for i in range(m) :
    for j in range(m) :
        for v in range(m) :
            if i != j and i != v :
                for k in range(n) :
                    for u in range(n) :
                        for l in range(n) :
                            for t in range(n) :
                                if k != l and u != t :
                                    result += Kyy[k,l] * Kyy[u,t]
result/(mm*(m-1)*nn*nn)

In [None]:
C[3] * nn * nn

In [None]:
def compute_Xi_values(C, m, n, mmd_sq, complete=True):
    mmd2 = mmd_sq ** 2 
    
    mm = m * (m-1) 
    nn = n * (n-1)
    mn = m * (n-1)
    nm = n * (m-1)

    Xi = [
        # Xi_01
        calc_xi([C[3]/(mm**2), 2*C[4]/(mm * nn * (n-1)), -4*C[6]/(m*mm*n), -2*C[7]/(m * nn), 
                -2*C[8]/(m*nn*(n-1)), C[11]/(mm * n), 3*C[12]/(mm * nn), C[14]/(nn * (n-1))], mmd2),
        
        # Xi_02
        calc_xi([C[3]/(mm**2), 2*C[4]/(mm*nn), -4*C[6]/(m * mm * nn), -4*C[7]/(m * nn), 
                2*C[11]/(mm * n), 2*C[12]/(mm * nn), C[13]/nn], mmd2),
        
        # Xi_10 
        calc_xi([C[2]/(mm*(m-1)), 2*C[4]/(mm * nn), -2*C[5]/(mm * n), -2*C[6]/(mm*(m-1)*n), 
                -4*C[8]/(m*n*nn), C[10]/(m*nn), 3*C[12]/(mm * nn), C[15]/(nn * nn)], mmd2), 
        
        # Xi_11 
        calc_xi([C[2]/(mm*(m-1)), 2*C[4]/(mm * nn), -2*C[5]/(mm * n), -2*C[6]/(mm * nm), 
                 -2*C[7]/(m * nn), -2*C[8]/(mn * nn), 0.25*C[9]/mn, 0.75*C[10]/(m*nn), 0.75*C[11]/(mm * n), 
                 2.25*C[12]/(mm * nn), C[14] / (nn * (n-1))], mmd2), 
        
        # Xi_12 
        calc_xi([C[1]/(mm * (m-1)), 2*C[4]/(mm*nn), -2*C[5]/(mm * n), -2*C[6]/(mm * nm), 
                 -4*C[7]/(m*nn), 0.5*C[9]/(mn), 0.5*C[10]/(m * nn), 1.5*C[11]/(mm*n), 1.5*C[12]/(mm * nn), 
                 C[13] / (nn)], mmd2), 
        
        # Xi_20 
        calc_xi([C[1]/mm , 2*C[4]/(mm * nn), -4*C[5]/(mm * n), -4*C[8]/(m*n*n), 2*C[10]/(m * nn),
                 2*C[12]/(mm * nn), C[15]/(nn * nn)], mmd2), 
        
        # Xi_21 
        calc_xi([C[1]/mm, 2*C[4]/(mm * nn), -4*C[5]/(mm * n), -2*C[7]/(nn * (n-1)), -2*C[8]/(mn * nn), 
                 0.5*C[9]/(mn), 1.5*C[10]/(m * nn), 0.5*C[11]/(mm * n), 1.5*C[12]/(mm * nn), C[14]/(nn * (n-1))], mmd2), 
        
        # Xi_22 
        calc_xi([C[1]/mm, 2*C[4]/(mm * nn), -4*C[5]/(mm * n), -4*C[7]/(m * nn), C[9]/(mn), C[10]/(m * nn),
                 C[11]/(mm * n), C[12]/(mm * nn), C[13]/nn], mmd2)
        ]

    if complete == False:
        Xi = [Xi[0], Xi[2]]


    return Xi

In [None]:
calc_xi([C[3]/(mm**2), 2*C[4]/(mm * nn * (n-1)), -4*C[6]/(m*mm*n), -2*C[7]/(m * nn), 
                -2*C[8]/(m*nn*(n-1)), C[11]/(mm * n), 3*C[12]/(mm * nn), C[14]/(nn * (n-1))], -1)

In [None]:
calc_xi([C[2]/(mm*(m-1)), 2*C[4]/(mm * nn), -2*C[5]/(mm * n), -2*C[6]/(mm*(m-1)*n), 
        -4*C[8]/(m*n*nn), C[10]/(m*nn), 3*C[12]/(mm * nn), C[15]/(nn * nn)], -1), 


In [None]:
C[3]/(mm**2)

In [None]:
result = 0.0
for i in range(m) :
    for j in range(m) :
        for v in range(m) :
            if i != j and i != v :
                for k in range(n) :
                    for u in range(n) :
                        for l in range(n) :
                            for t in range(n) :
                                if k != l and u != t :
                                    result += Kyy[k,l] * Kyy[u,t]
result/(mm*(m-1)*nn*nn)

1. C1

In [None]:
result = 0.0
for i in range(m) :
    for j in range(m) :
        if j != i :
            result += Kxx[i,j] * Kxx[i,j]
result

In [None]:
jnp.trace(tKxx.T @ tKxx)

2. C2

In [None]:
result = 0.0
for i in range(m) :
    for j in range(m) :
        for v in range(m) :
            if j != i and v!= i :
                print("i,j,v : ",i,j,v)
                result += Kxx[i,j] * Kxx[i,v]
result

In [None]:
jnp.sum(tKxx.T @ tKxx)

In [None]:
one_3 = jnp.ones(3)

In [None]:
one_3.T @ tKxx.T @ tKxx @ one_3