## Demo: two-sample testing (tst) on tabular data

In [1]:
from adaptesting import tst # Load the main library to conduct tst

# Generate example data; the input data must be PyTorch tensors
import torch
import random
import time
from torch.distributions import MultivariateNormal

# Start the wall-clock timer, then fix the seeds of both random number
# generators used below (torch for sampling, random for shuffling).
start = time.time()
torch.manual_seed(0)
random.seed(0)

# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
device = (
    torch.device("cuda:0")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)

# Two bivariate Gaussians sharing one mean; they differ only in the
# off-diagonal covariance entry (0.5 for mvn1, 0 for mvn2).
mean = torch.tensor([0.5, 0.5])
cov1 = torch.tensor([[1.0, 0.5], [0.5, 1.0]])
cov2 = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
mvn1 = MultivariateNormal(mean, cov1)
mvn2 = MultivariateNormal(mean, cov2)

# Experiment settings: running count of rejections, number of trials,
# and how many points each of X and Y keeps per trial.
counter = 0
n_trial = 100
n_samples = 250
pool_size = 1000  # points drawn from each distribution before subsampling

# Repeat the test n_trial times to estimate a rejection rate;
# drop the loop if a single reject / fail-to-reject decision is enough.
for _ in range(n_trial):

    # Draw a fresh pool from each distribution.
    # To measure the Type-I error instead of the test power, draw Z2 from
    # mvn1 as well (swap in the commented line below).
    Z1 = mvn1.sample((pool_size,))
    Z2 = mvn2.sample((pool_size,))  # Test power
    # Z2 = mvn1.sample((pool_size,))  # Type-I error

    # Random subset of n_samples row positions out of the pool.
    shuffled = list(range(pool_size))
    random.shuffle(shuffled)
    keep = shuffled[:n_samples]

    # The same row positions index both pools,
    # so X.size() == Y.size() == (n_samples, 2).
    X, Y = Z1[keep], Z2[keep]

    # Choose one of the five two-sample tests; leave exactly one call active.
    h, _, _ = tst(X, Y, device=device)  # default method is median heuristic
    # h, _, _ = tst(X, Y, device=device, method="fuse", kernel="laplace_gaussian", n_perm=2000)
    # h, _, _ = tst(X, Y, device=device, method="agg", n_perm=3000)
    # h, _, _ = tst(X, Y, device=device, method="clf", data_type="tabular", patience=150, n_perm=200)
    # h, _, _ = tst(X, Y, device=device, method="deep", data_type="tabular", patience=150, n_perm=200)

    # Accumulate the first value returned by tst across trials.
    counter += h

# Report the accumulated count over all trials, then the elapsed wall-clock
# time measured from `start`.
print(f"Power: {counter}/{n_trial}")
end = time.time()
print(f"Time taken: {end - start:.4f} seconds")

Reject the null hypothesis with p-value: 0.02, the MMD value is 0.0060651302337646484.
Reject the null hypothesis with p-value: 0.0, the MMD value is 0.008074045181274414.
Reject the null hypothesis with p-value: 0.01, the MMD value is 0.0069904327392578125.
Fail to reject the null hypothesis with p-value: 0.1, the MMD value is 0.0027581453323364258.
Fail to reject the null hypothesis with p-value: 0.05, the MMD value is 0.00699925422668457.
Reject the null hypothesis with p-value: 0.0, the MMD value is 0.009712934494018555.
Fail to reject the null hypothesis with p-value: 0.11, the MMD value is 0.004103302955627441.
Fail to reject the null hypothesis with p-value: 0.07, the MMD value is 0.0034340620040893555.
Reject the null hypothesis with p-value: 0.02, the MMD value is 0.006949901580810547.
Reject the null hypothesis with p-value: 0.04, the MMD value is 0.005875945091247559.
Reject the null hypothesis with p-value: 0.0, the MMD value is 0.013964414596557617.
Reject the null hypothe

### Performance over 100 trials, each with different input samples

| Method       | Median | MMD-FUSE | MMD-Agg | MMD-Deep | C2ST-MMD |
| ------------ | ------ | -------- | ------- | -------- | -------- |
| Test Power   | 0.69   | 0.56     | 0.72    | 0.72     | 0.71     |
| Type-I Error | 0.01   | 0.03     | 0.04    | 0.05     | 0.06     |
| Runtime (s)  | 4.49   | 10.14    | 3.48    | 486.81   | 570.94   |