## Demo: two-sample testing (`tst`) on text data

In [None]:
from adaptesting import tst # Load the main library to conduct tst

# Load the HC3 dataset as an example. The input text data can be either a list
# of strings ["text1", "text2", ...] or embeddings of the text as a PyTorch tensor.
import torch
import random
import time
from datasets import load_dataset
from transformers import logging

# Quiet transformers' logger down to errors only, so its warnings do not
# clutter the notebook output.
logging.set_verbosity_error()

# The timer starts here, i.e. before the dataset is loaded, so the
# "Time taken" printed at the end covers data loading as well as the tests.
start = time.time()

# Fix both RNGs for reproducibility: `random` drives the sub-sampling of
# answers; torch's seed presumably covers randomness inside tst (e.g. the
# permutations) -- not visible from here.
random.seed(0)
torch.manual_seed(0)

# Prefer the first CUDA GPU, then Apple's MPS backend, and fall back to CPU.
device = torch.device(
    "cuda:0" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# HC3 ("all" configuration) from the Hugging Face Hub.
dataset = load_dataset("Hello-SimpleAI/HC3", "all")

# Restrict the train split to the medicine-domain questions (1248 samples).
filtered_dataset = dataset["train"].filter(lambda row: row["source"] == "medicine")

def flatten(list_of_lists):
    """Concatenate the inner sequences of ``list_of_lists`` into one flat list.

    Order is preserved: every item of the first inner sequence comes first,
    then the items of the second, and so on. Inner sequences may hold any
    number of items (including none); exactly one level of nesting is removed.
    """
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat

# The HC3 answer columns are lists of lists of strings, e.g.
# [["text1"], ["text2"], ...] (one inner list per question); flatten each
# column into a single pool of answers.
Z1 = flatten(filtered_dataset["chatgpt_answers"])  # ChatGPT-written answers
Z2 = flatten(filtered_dataset["human_answers"])  # Test power: human-written answers
# Z2 = Z1  # Type-I error: draw both samples from the same pool

counter = 0  # number of trials in which the null hypothesis was rejected
n_trial = 100
n_samples = 100  # size of each of X and Y in every trial

# X and Y are slices of the shuffled pools, so each pool must hold at least
# n_samples answers; otherwise the slices would silently come out shorter.
if n_samples > min(len(Z1), len(Z2)):
    raise ValueError(
        f"n_samples={n_samples} exceeds the available data "
        f"({len(Z1)} texts in Z1, {len(Z2)} texts in Z2)"
    )

# With Z2 = Z1 both samples come from the same distribution, so the rejection
# rate estimates the Type-I error rather than the power; label it accordingly.
metric = "Type-I error" if Z2 is Z1 else "Power"

# Repeat the test n_trial times to estimate the rejection rate.
# For a single reject / fail-to-reject decision, run the loop body only once.
for _ in range(n_trial):

    # Draw X and Y by shuffling each pool in place and taking its first
    # n_samples answers (sampling without replacement within a trial).
    random.shuffle(Z1)
    X = Z1[:n_samples]

    random.shuffle(Z2)
    Y = Z2[:n_samples]

    # Five state-of-the-art TST methods are available; uncomment one to switch.
    # !!! data_type="text" is required for text input !!!
    h, _, _ = tst(X, Y, device=device, data_type="text")  # default method is the median heuristic
    # h, _, _ = tst(X, Y, device=device, method="fuse", data_type="text", kernel="laplace_gaussian", n_perm=2000)
    # h, _, _ = tst(X, Y, device=device, method="agg", data_type="text", n_perm=3000)
    # h, _, _ = tst(X, Y, device=device, method="clf", data_type="text", patience=150, n_perm=200)
    # h, _, _ = tst(X, Y, device=device, method="deep", data_type="text", patience=150, n_perm=200)

    # int() keeps the count a plain integer in case tst returns h as a bool,
    # NumPy scalar or 0-d tensor rather than an int (not visible from here).
    counter += int(h)

print(f"{metric}: {counter}/{n_trial}")
end = time.time()
print(f"Time taken: {end - start:.4f} seconds")

Fail to reject the null hypothesis with p-value: 0.29, the MMD value is 0.0018093585968017578.
Reject the null hypothesis with p-value: 0.0, the MMD value is 0.052556753158569336.
Reject the null hypothesis with p-value: 0.0, the MMD value is 0.019411563873291016.
Reject the null hypothesis with p-value: 0.01, the MMD value is 0.018266797065734863.
Reject the null hypothesis with p-value: 0.0, the MMD value is 0.00862109661102295.
Reject the null hypothesis with p-value: 0.0, the MMD value is 0.05629551410675049.
Fail to reject the null hypothesis with p-value: 0.5, the MMD value is -0.0006477832794189453.
Fail to reject the null hypothesis with p-value: 0.06, the MMD value is 0.00481569766998291.
Fail to reject the null hypothesis with p-value: 0.62, the MMD value is -0.002126932144165039.
Reject the null hypothesis with p-value: 0.0, the MMD value is 0.018361926078796387.
Reject the null hypothesis with p-value: 0.0, the MMD value is 0.03613781929016113.
Fail to reject the null hypot