The code in this notebook replicates the CIFAR-10 vs CIFAR-10.1 experiment of
Liu et al.
(Learning Deep Kernels for Non-Parametric Two-Sample Tests, 
ICML 2020). 
We utilize their code which is under the MIT license:
https://github.com/fengliu90/DK-for-TST/blob/master/Deep_Baselines_CIFAR10.py

Not multiplying max{n_0, n_1} to denominator of the objective function 

# Environment mmdfuse-env

In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "7"

In [3]:
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
from jax import random
import jax.numpy as jnp
from tqdm.auto import tqdm
from pathlib import Path
from models import MyModel
Path("results").mkdir(exist_ok=True)

  warn(


In [4]:
from all_tests import mmdfuse_test
from all_tests import mmd_median_test, mmd_split_test
from all_tests import mmdagg_test, mmdagginc_test, deep_mmd_test
from all_tests import met_test, scf_test
from all_tests import ctt_test, actt_test
from tests import c2st_tst



In [5]:
# parameters
N1 = 50
img_size = 64
batch_size = 100
K = 3
N = 10

In [6]:
# Load the CIFAR 10 data and CIFAR 10.1

# Configure data loader
dataset_test = datasets.CIFAR10(root='./cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000, shuffle=True, num_workers=1)

# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('./cifar_data/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [7]:
# Setting for C2ST
device = 'cuda:0'
dtype = torch.float 
lr = 0.001 
n_epoch = 25
alpha = 0.05

, mmd_median_test, mmdagg_test, mmdagginc_test, ctt_test, actt_test, )

In [8]:
# Run experiment
save = True
seed = 0
key = random.PRNGKey(42)
imbalance_ratios = (1, 10, 30, 50, 70, 100)
tests = [mmdagg_test]

outputs = jnp.zeros((len(tests), len(imbalance_ratios), K, N))
outputs = outputs.tolist()

for ir in tqdm(range(len(imbalance_ratios)), desc="Imbalance Ratio Loop"):
    N2 = N1 * imbalance_ratios[ir]
    for kk in tqdm(range(K), desc="Trail Loop(K)", leave=False):
        torch.manual_seed(kk * 19 + N1 + N2)
        torch.cuda.manual_seed(kk * 19 + N1 + N2)
        np.random.seed(seed=1102 * (kk + 10) + N1 + N2)

        # Collect CIFAR10 images
        Ind_tr = np.random.choice(len(data_all), N2, replace=False)
        Ind_te = np.delete(Ind_all, Ind_tr)
        train_data = []
        for i in Ind_tr:
            train_data.append([data_all[i], label_all[i]])

        dataloader = torch.utils.data.DataLoader(
            train_data,
            batch_size=batch_size,
            shuffle=True,
        )

        # Collect CIFAR10.1 images
        np.random.seed(seed=819 * (kk + 9) + N1)
        Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
        Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
        New_CIFAR_tr = data_trans[Ind_tr_v4]
        New_CIFAR_te = data_trans[Ind_te_v4]
        
        # Run two-sample test on the training set
        # Fetch training data
        s1_tr = data_all[Ind_tr]
        s2_tr = data_trans[Ind_tr_v4]
        model = MyModel(in_channels=3, img_size = 64, device=device, dtype=dtype)
        
        for k in tqdm(range(N), desc="Test Loop(N)", leave=False):
            # Fetch test data
            np.random.seed(seed=1102 * (k + 1) + N1 + N2)
            data_all_te = data_all[Ind_te]
            # N_te = len(data_trans) - N1
            N_te = N2
            Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
            s1_te = data_all_te[Ind_N_te]
            s2_te = data_trans[Ind_te_v4]
  
            # concatenate the split data
            X = jnp.array(torch.cat((s1_tr, s1_te)))
            Y = jnp.array(torch.cat((s2_tr, s2_te)))
            
            seed += 1
            key, subkey = random.split(key)
            for t in range(len(tests)):
                test = tests[t]
                print(str(test))
                if test == c2st_tst:
                    outputs[t][ir][kk][k] = c2st_tst(s1_tr, s2_tr, s1_te, s2_te, model, alpha, lr, n_epoch, seed, 
                                    loss_fn=model.smooth_objective, device=device)
                    
                elif test == mmdagginc_test:
                    X = X.reshape(X.shape[0], -1)  # This will reshape X to have two dimensions
                    Y = Y.reshape(Y.shape[0], -1)  # Similarly for Y
                    outputs[t][ir][kk][k] = test(X, Y, subkey, seed)
                    
                else:
                    outputs[t][ir][kk][k] = test(X, Y, subkey, seed)


output = jnp.mean(jnp.array(outputs), -1)

if save:
    jnp.save(f"results/cifar_mmdagg{N1}.npy", output)

for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])

Imbalance Ratio Loop:   0%|          | 0/6 [00:00<?, ?it/s]

Trail Loop(K):   0%|          | 0/3 [00:00<?, ?it/s]

Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Trail Loop(K):   0%|          | 0/3 [00:00<?, ?it/s]

Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Trail Loop(K):   0%|          | 0/3 [00:00<?, ?it/s]

Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Trail Loop(K):   0%|          | 0/3 [00:00<?, ?it/s]

Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Trail Loop(K):   0%|          | 0/3 [00:00<?, ?it/s]

Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Trail Loop(K):   0%|          | 0/3 [00:00<?, ?it/s]

Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>


Test Loop(N):   0%|          | 0/10 [00:00<?, ?it/s]

<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
<function mmdagg_test at 0x7f33306d45e0>
 
<function mmdagg_test at 0x7f33306d45e0>
[[0.1        0.2        0.        ]
 [0.90000004 0.3        0.3       ]
 [1.         1.         1.        ]
 [1.         1.         1.        ]
 [1.         1.         1.        ]
 [1.         1.         1.        ]]


# Environment autogluon-env

In [9]:
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
from jax import random
import jax.numpy as jnp
from tqdm.auto import tqdm
from pathlib import Path
Path("results").mkdir(exist_ok=True)

In [10]:
import autotst
from utils import HiddenPrints

def autotst_test(X, Y, key, seed, time=60):
    with HiddenPrints():
        tst = autotst.AutoTST(X, Y, split_ratio=0.5, model=autotst.model.AutoGluonTabularPredictor)
        tst.split_data()
        tst.fit_witness(time_limit=time)  # time limit adjustable to your needs (in seconds)
        p_value = tst.p_value_evaluate(permutations=10000)  # control number of permutations in the estimation
    return int(p_value <= 0.05)

ModuleNotFoundError: No module named 'autotst'

In [None]:
# parameters
N1 = 1000
img_size = 64
batch_size = 100
K = 3
N = 10

In [None]:
# Load the CIFAR 10 data and CIFAR 10.1

# Configure data loader
dataset_test = datasets.CIFAR10(root='./cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000, shuffle=True, num_workers=1)

# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('./cifar_data/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

In [None]:
# Run experiment

seed = 0
key = random.PRNGKey(42)

tests = (autotst_test, )

outputs = [[] for _ in range(len(tests))]
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)

    # Collect CIFAR10 images
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)
    train_data = []
    for i in Ind_tr:
        train_data.append([data_all[i], label_all[i]])

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]
    
    # Run two-sample test on the training set
    # Fetch training data
    s1_tr = data_all[Ind_tr]
    s2_tr = data_trans[Ind_tr_v4]
    
    for k in tqdm(range(N)):
        # Fetch test data
        np.random.seed(seed=1102 * (k + 1) + N1)
        data_all_te = data_all[Ind_te]
        N_te = len(data_trans) - N1
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1_te = data_all_te[Ind_N_te]
        s2_te = data_trans[Ind_te_v4]
        
        # MMD-FUSE & MMDAgg do not split the data
        s1_tr = jnp.array(s1_tr)
        s1_te = jnp.array(s1_te)
        s2_tr = jnp.array(s2_tr)
        s2_te = jnp.array(s2_te)
        
        # concatenate the split data
        X = jnp.concatenate((s1_tr, s1_te))
        Y = jnp.concatenate((s2_tr, s2_te))
        
        seed += 1
        key, subkey = random.split(key)
        for t in range(len(tests)):
            test = tests[t]
            outputs[t].append(client.submit(test, X, Y, subkey, seed))

output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/cifar_autotst.npy", output)

for t in range(len(tests)):
    print(" ")
    print(tests[t])
    print(output[t])