# MNIST and SST-2 Conformal Changepoint Localization

This notebook localizes changepoints with conformal prediction in two settings: a sequence of MNIST digit images whose digit changes, and a sequence of SST-2 sentences whose sentiment changes.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import ks_2samp, ks_1samp, uniform
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.models import ResNet18_Weights
from tqdm import tqdm

In [None]:
import matplotlib_inline.backend_inline

# Render inline figures as SVG and apply the project's plotting style sheet.
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
plt.style.use("math.mplstyle")

## MNIST Dataset with Changepoint

Create a sequence of MNIST images showing digit 3 at indices `0..changepoint` (inclusive) and digit 7 at every later index.

In [None]:
def get_pretrained_model(device="cpu"):
    """Return an ImageNet-pretrained ResNet-18 adapted to 1-channel, 10-class input.

    The stem convolution is replaced by a freshly initialised 1-input-channel
    ``Conv2d`` and the classifier head by a freshly initialised 10-way
    ``Linear``; the remaining layers keep their ImageNet weights.  The model is
    moved to ``device`` and switched to eval mode.

    Loading goes through the installed torchvision's ``resnet18`` builder:
    ``torch.hub.load("pytorch/vision:v0.10.0", "resnet18", weights=...)``
    resolves to the torchvision 0.10 entry point, which only understands
    ``pretrained=`` and forwards ``weights=`` to ``ResNet.__init__`` (a
    ``TypeError``).  The ``ResNet18_Weights`` enum already used here belongs to
    the installed torchvision, so the builder from the same package is used.
    """
    from torchvision.models import resnet18

    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

    # Grayscale input: swap the 3-channel stem for a new 1-channel conv with the
    # same geometry (its weights are freshly initialised, not pretrained).
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=7, stride=2, padding=3, bias=False
    )
    # New (untrained) 10-way head for the ten digit classes.
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.to(device)
    model.eval()
    return model

In [None]:
def generate_mnist_dataset(length, changepoint, digit1=3, digit2=7):
    """Generate MNIST dataset with a changepoint.

    Positions ``0..changepoint`` (inclusive) hold randomly chosen MNIST
    training images of ``digit1``; positions ``changepoint + 1..length - 1``
    hold images of ``digit2``.  Images are drawn without replacement using the
    global NumPy RNG.

    Args:
        length: Total number of images in the sequence.
        changepoint: Index of the last ``digit1`` image; must satisfy
            ``0 <= changepoint < length``.
        digit1: Digit shown up to and including ``changepoint``.
        digit2: Digit shown after ``changepoint``.

    Returns:
        ``float32`` array with ``length`` rows, one flattened image per row,
        pixel values scaled to [0, 1].

    Raises:
        ValueError: If ``changepoint`` is out of range, or MNIST has too few
            training images of either digit.
    """
    # Without this check a negative `n2` silently slices from the end of the
    # digit2 pool and only fails later inside `reshape`.
    if not 0 <= changepoint < length:
        raise ValueError(
            "changepoint must satisfy 0 <= changepoint < length; "
            f"got changepoint={changepoint}, length={length}."
        )

    # Only the raw uint8 `.data` / `.targets` tensors are read, so no
    # transform is needed.
    mnist_data = MNIST(root="./data", train=True, download=True)
    data = mnist_data.data.numpy()
    targets = mnist_data.targets.numpy()

    # Boolean indexing returns copies, so shuffling leaves `data` untouched.
    images_digit1 = data[targets == digit1]
    images_digit2 = data[targets == digit2]
    np.random.shuffle(images_digit1)
    np.random.shuffle(images_digit2)

    n1 = changepoint + 1
    n2 = length - n1
    if n1 > len(images_digit1) or n2 > len(images_digit2):
        raise ValueError("Insufficient images for the specified digits and length.")

    x = np.concatenate([images_digit1[:n1], images_digit2[:n2]], axis=0)
    return x.reshape(length, -1).astype(np.float32) / 255.0

In [None]:
def predict_digit(model, image, device="cpu", mean=0.1307, std=0.3081):
    """Classify a single MNIST image.

    The pixels are normalised as ``(pixel - mean) / std`` before the forward
    pass.  The defaults are the constants ``get_mnist_trained_model`` trains
    with (``transforms.Normalize((0.1307,), (0.3081,))``).  Inference therefore
    sees the same input distribution as training, whereas the sequences built by
    ``generate_mnist_dataset`` are only scaled to [0, 1].  Pass ``mean=0.0,
    std=1.0`` to feed raw pixels.

    Args:
        model: Callable mapping a ``(1, 1, 28, 28)`` float32 tensor to a
            ``(1, num_classes)`` tensor of logits.
        image: 28 * 28 pixel values (any shape that reshapes to 28x28).
        device: Device for the input tensor; must match the model's device.
        mean: Normalisation mean.
        std: Normalisation standard deviation.

    Returns:
        ``(predicted, outputs)``: the arg-max class index and the softmax
        probabilities on CPU, shape ``(1, num_classes)``.
    """
    # Cast to float32 so float64 inputs do not produce a double tensor that the
    # float32 model weights would reject.
    pixels = np.asarray(image, dtype=np.float32).reshape(1, 1, 28, 28)
    image_tensor = torch.as_tensor(pixels, device=device)
    image_tensor = (image_tensor - mean) / std

    with torch.no_grad():
        outputs = torch.softmax(model(image_tensor), dim=1).cpu()
    predicted = outputs.argmax(dim=1).item()
    return (predicted, outputs)

In [None]:
def get_mnist_trained_model(device="cpu", epochs=1, batch_size=64):
    """Train a small CNN on the MNIST training split and return it in eval mode.

    Inputs are converted with ``ToTensor`` and normalised with MNIST mean/std
    ``(0.1307, 0.3081)``.  After training, accuracy on the MNIST test split is
    printed.

    Args:
        device: Torch device used for training and for the returned model.
        epochs: Number of passes over the training split (default 1).
        batch_size: Training mini-batch size (default 64).

    Returns:
        The trained ``MNISTModel`` in eval mode, producing 10 logits per image.
    """

    class MNISTModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
            self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = torch.nn.Dropout(0.25)
            self.dropout2 = torch.nn.Dropout(0.5)
            # 28x28 input -> two unpadded 3x3 convs (24x24) -> 2x2 max-pool
            # (12x12): 64 * 12 * 12 = 9216 flattened features.
            self.fc1 = torch.nn.Linear(9216, 128)
            self.fc2 = torch.nn.Linear(128, 10)

        def forward(self, x):
            x = torch.nn.functional.relu(self.conv1(x))
            x = torch.nn.functional.relu(self.conv2(x))
            x = torch.nn.functional.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = torch.nn.functional.relu(self.fc1(x))
            x = self.dropout2(x)
            return self.fc2(x)

    model = MNISTModel().to(device)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    train_dataset = MNIST(root="./data", train=True, download=True, transform=transform)
    # Reshuffle every epoch so mini-batches are not in a fixed dataset order.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    print("Training MNIST model...")
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(
                    f"Epoch: {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)} "
                    f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
                )

    test_dataset = MNIST(root="./data", train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000)

    # Eval mode disables dropout; it also stays set for the returned model.
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100.0 * correct / len(test_loader.dataset)
    print(f"Test accuracy: {accuracy:.2f}%")

    return model

In [None]:
# Train the small MNIST CNN used by every MNIST experiment below (on CPU).
model = get_mnist_trained_model(device="cpu")

In [None]:
# Sequence configuration: `digit1` at indices 0..changepoint, `digit2` after.
length, changepoint = 1000, 400
digit1, digit2 = 3, 7

x = generate_mnist_dataset(length, changepoint, digit1, digit2)

# Classify every image, keeping (predicted digit, softmax row) pairs.
predicted_digits = [predict_digit(model, image) for image in tqdm(x)]
probabilities = torch.vstack([softmax_row for _, softmax_row in predicted_digits])

## Compute Left and Right Scores

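For each candidate split $t$, the left scores rate $x_0, \dots, x_t$ and the right scores rate $x_t, \dots, x_{n-1}$. Each side uses the odds $p/(1-p)$ of the digit the classifier predicts most often on that side. These scores feed the conformal p-values computed below.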

In [None]:
# Row t scores x_0..x_t by the odds p / (1 - p) of the digit predicted most
# often among x_0..x_t (the digit seen first wins ties).
left_score = np.zeros((length, length))
seen_digits = {}
for t, (predicted, _) in enumerate(predicted_digits):
    seen_digits[predicted] = seen_digits.get(predicted, 0) + 1
    curr_digit = max(seen_digits, key=seen_digits.get)
    p_curr = probabilities[: t + 1, curr_digit].cpu()
    left_score[t, : t + 1] = p_curr / (1 - p_curr)

In [None]:
# Row t scores x_t..x_{length-1} by the odds p / (1 - p) of the digit predicted
# most often among those positions, scanning from the end backwards.
right_score = np.zeros((length, length))
seen_digits = {}
for step, (predicted, _) in enumerate(reversed(predicted_digits)):
    t = length - 1 - step
    seen_digits[predicted] = seen_digits.get(predicted, 0) + 1
    curr_digit = max(seen_digits, key=seen_digits.get)
    p_curr = probabilities[t:, curr_digit].cpu()
    right_score[t, t:] = p_curr / (1 - p_curr)

## Calculate Discrepancy Scores

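For every split $t$, compute conformal p-values on each side and test them for uniformity with one-sample Kolmogorov–Smirnov tests. The discrepancy score is the sum of the two KS statistics, each weighted by $\sqrt{\text{segment size}}$.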

In [None]:
def get_discrepancy_scores(x, scores_left, scores_right):
    """Compute a KS-based discrepancy score for every candidate split.

    For split ``t`` (left segment ``0..t``, right segment ``t+1..n-1``):

    * each left position ``r`` gets a smoothed p-value: the count of
      ``scores_left[t, :r+1]`` strictly greater than ``scores_left[t, r]``,
      plus a Uniform(0, 1) draw times the number of exact ties, over ``r + 1``;
    * each right position ``r`` is ranked the same way within
      ``scores_right[t, r:]``, over ``n - r``.

    Each segment's p-values are KS-tested against Uniform(0, 1).

    Returns:
        ``(discrepancy_scores, statistics)``.  ``discrepancy_scores[t]`` is
        ``sqrt(t + 1) * D_left + sqrt(n - t - 1) * D_right``, and
        ``statistics[t]`` is the ``(left, right)`` pair of ``ks_1samp`` results.
    """
    n = len(x)
    discrepancy_scores = np.empty(n - 1)
    statistics = []

    for t in tqdm(range(n - 1)):
        pvals = np.empty(n)

        row = scores_left[t]
        for r in range(t + 1):
            window = row[: r + 1]
            greater = np.count_nonzero(window > row[r])
            tau = np.random.uniform(0, 1)
            ties = np.count_nonzero(window == row[r])
            pvals[r] = (greater + tau * ties) / (r + 1)

        row = scores_right[t]
        for r in reversed(range(t + 1, n)):
            window = row[r:]
            greater = np.count_nonzero(window > row[r])
            tau = np.random.uniform(0, 1)
            ties = np.count_nonzero(window == row[r])
            pvals[r] = (greater + tau * ties) / (n - r)

        left_test = ks_1samp(pvals[: t + 1], uniform.cdf)
        right_test = ks_1samp(pvals[t + 1 :], uniform.cdf)
        statistics.append((left_test, right_test))
        discrepancy_scores[t] = (
            left_test.statistic * np.sqrt(t + 1)
            + right_test.statistic * np.sqrt(n - t - 1)
        )

    return discrepancy_scores, statistics

In [None]:
# Per-split discrepancy scores and (left, right) KS results for the MNIST sequence.
discrepancy_scores, statistics = get_discrepancy_scores(
    x, scores_left=left_score, scores_right=right_score
)

In [None]:
# Discrepancy score per split index t, with the true changepoint marked.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(discrepancy_scores)
ax.axvline(x=changepoint, color="red", linestyle="--", label="True Changepoint")
ax.set_xlabel("Position t")
ax.set_ylabel("Discrepancy Score")
ax.set_title("MNIST Changepoint Detection")
ax.legend()
plt.show()

## Calculate p-values

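Combine the left and right KS p-values with a Bonferroni correction, $p_t = \min\{1,\ 2\min(p_t^{\mathrm{L}}, p_t^{\mathrm{R}})\}$. The confidence set for the changepoint is $\{t : p_t > \alpha\}$.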

In [None]:
from scipy.stats import chi2

# KS p-values of the left and right segments for every split t.
p_values_left = np.array([s[0].pvalue for s in statistics])
p_values_right = np.array([s[1].pvalue for s in statistics])

# Bonferroni combination capped at 1.  np.minimum's third positional argument
# is the *output buffer*, not a third operand, so the cap needs its own call.
# The result is already a p-value; no chi-square transform applies.
min_method = np.minimum(1.0, 2.0 * np.minimum(p_values_left, p_values_right))
threshold = 0.05

# Entry t is the split after x[t] (0-based), the same index convention as
# `changepoint` and the discrepancy plot.
plt.plot(np.arange(length - 1), min_method)
plt.axvline(
    changepoint, color="red", linestyle="--", label=f"Changepoint ($\\xi = {changepoint}$)"
)
plt.axhline(
    threshold, color="green", linestyle=":", label=f"Threshold ($\\alpha = {threshold}$)"
)
plt.xlabel("$t$")
plt.ylabel("p-value ($p_t$)")
plt.title("p-values for MNIST digit change (pre-trained classifier)")
plt.legend()
plt.savefig("images/mnist-pvalues.pdf")
plt.show()

In [None]:
# Confidence set: every 0-based split index t whose combined p-value exceeds
# the threshold; report its hull and the most plausible split.
confidence_set = np.flatnonzero(min_method > threshold)
confidence_interval = (
    (confidence_set[0], confidence_set[-1]) if confidence_set.size > 0 else None
)

print(f"True changepoint: {changepoint}")
print(f"Confidence interval: {confidence_interval}")
print(f"Maximum p-value at t={np.argmax(min_method)}")

In [None]:
def _bonferroni_pvalues(statistics):
    """Bonferroni-combine per-split ``(left, right)`` KS results.

    ``statistics[t]`` is the ``(left, right)`` pair of ``ks_1samp`` results for
    split ``t``; the combined p-value is ``min(1, 2 * min(p_left, p_right))``.
    """
    p_left = np.array([s[0].pvalue for s in statistics])
    p_right = np.array([s[1].pvalue for s in statistics])
    # np.minimum(a, b, c) would treat `c` as the output buffer rather than a
    # third operand, so the cap at 1 is a separate call.
    return np.minimum(1.0, 2.0 * np.minimum(p_left, p_right))


def _basic_scores(probabilities, predicted_digits, length):
    """Left/right odds-score matrices without calibration data.

    Row ``t`` of the left matrix scores positions ``0..t`` by the odds
    ``p / (1 - p)`` of the digit predicted most often among them.  Row ``t`` of
    the right matrix scores positions ``t..length-1`` the same way.  Vote ties
    go to the digit encountered first in scan order.
    """
    left_score = np.zeros((length, length))
    votes = {}
    for t, (predicted, _) in enumerate(predicted_digits):
        votes[predicted] = votes.get(predicted, 0) + 1
        digit = max(votes, key=votes.get)
        p = probabilities[: t + 1, digit].cpu()
        left_score[t, : t + 1] = p / (1 - p)

    right_score = np.zeros((length, length))
    votes = {}
    for i, (predicted, _) in enumerate(reversed(predicted_digits)):
        t = length - i - 1
        votes[predicted] = votes.get(predicted, 0) + 1
        digit = max(votes, key=votes.get)
        p = probabilities[t:, digit].cpu()
        right_score[t, t:] = p / (1 - p)

    return left_score, right_score


def evaluate_confidence_intervals(
    n_trials=50,
    length=1000,
    changepoint=400,
    alpha=0.05,
    digit1=3,
    digit2=7,
    calibration_size=100,
    method="basic",
):
    """Estimate coverage and width of the conformal changepoint confidence set.

    Each trial works as follows:

    * draw a fresh MNIST sequence (``digit1`` at indices ``0..changepoint``,
      ``digit2`` afterwards);
    * classify it with the notebook-level ``model`` (read as a global);
    * compute per-split KS statistics with the chosen ``method``;
    * form the confidence set ``{t : p_t > alpha}``, where ``p_t`` is the
      Bonferroni-combined p-value for the split after index ``t``.

    A trial covers when ``changepoint`` (the index of the last ``digit1``
    image) lies in that set.

    Args:
        n_trials: Number of independent sequences to simulate.
        length: Length of each main sequence.
        changepoint: Index of the last pre-change image.
        alpha: Miscoverage level; splits with ``p_t > alpha`` are retained.
        digit1: Digit before the change.
        digit2: Digit after the change.
        calibration_size: Calibration images per side for ``"both_cal"`` and
            ``"left_cal"`` (ignored by ``"basic"``).
        method: ``"basic"`` (no calibration), ``"both_cal"`` (calibration on
            both sides) or ``"left_cal"`` (pre-change calibration only).

    Returns:
        ``(coverage_rate, avg_width)``.  Width is ``last - first`` index of the
        confidence set; an empty set counts as a miss with width 0.

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    if method not in ("basic", "both_cal", "left_cal"):
        raise ValueError(f"Unknown method: {method!r}")

    coverages = []
    widths = []

    for _ in tqdm(range(n_trials), desc=f"Running {method} method trials"):
        if method == "basic":
            x = generate_mnist_dataset(length, changepoint, digit1, digit2)
            predicted_digits = [predict_digit(model, x[i]) for i in range(length)]
            probabilities = torch.vstack([prob for _, prob in predicted_digits])

            left_score, right_score = _basic_scores(
                probabilities, predicted_digits, length
            )
            _, statistics = get_discrepancy_scores(x, left_score, right_score)

        elif method == "both_cal":
            x_main, x_cal_pre, x_cal_post = generate_extended_mnist_dataset(
                length, changepoint, calibration_size, digit1, digit2
            )

            predicted_digits = [predict_digit(model, x_main[i]) for i in range(length)]
            probabilities = torch.vstack([prob for _, prob in predicted_digits])

            predicted_cal_pre = [
                predict_digit(model, x_cal_pre[i]) for i in range(calibration_size)
            ]
            probabilities_cal_pre = torch.vstack(
                [prob for _, prob in predicted_cal_pre]
            )

            predicted_cal_post = [
                predict_digit(model, x_cal_post[i]) for i in range(calibration_size)
            ]
            probabilities_cal_post = torch.vstack(
                [prob for _, prob in predicted_cal_post]
            )

            left_scores, left_scores_cal = compute_left_scores_with_calibration(
                probabilities,
                predicted_digits,
                probabilities_cal_pre,
                predicted_cal_pre,
                length,
            )
            right_scores, right_scores_cal = compute_right_scores_with_calibration(
                probabilities,
                predicted_digits,
                probabilities_cal_post,
                predicted_cal_post,
                length,
            )

            _, statistics = get_discrepancy_scores_with_calibration(
                x_main, left_scores, right_scores, left_scores_cal, right_scores_cal
            )

        else:  # method == "left_cal"
            x_main, x_cal_pre = generate_left_calibration_mnist_dataset(
                length, changepoint, calibration_size, digit1, digit2
            )

            predicted_digits = [predict_digit(model, x_main[i]) for i in range(length)]
            probabilities = torch.vstack([prob for _, prob in predicted_digits])

            predicted_cal_pre = [
                predict_digit(model, x_cal_pre[i]) for i in range(calibration_size)
            ]
            probabilities_cal_pre = torch.vstack(
                [prob for _, prob in predicted_cal_pre]
            )

            left_scores, left_scores_cal = compute_left_scores_with_left_calibration(
                probabilities,
                predicted_digits,
                probabilities_cal_pre,
                predicted_cal_pre,
                length,
            )
            right_scores = compute_right_scores_without_calibration(
                probabilities, predicted_digits, length
            )

            _, statistics = get_discrepancy_scores_with_left_calibration(
                x_main, left_scores, right_scores, left_scores_cal
            )

        # Keep the splits that are NOT rejected: p_t > alpha.
        combined = _bonferroni_pvalues(statistics)
        confidence_set = np.flatnonzero(combined > alpha)

        if confidence_set.size > 0:
            widths.append(confidence_set[-1] - confidence_set[0])
            # Index t of `combined` is the split after position t, and the
            # generators put the last pre-change image at index `changepoint`.
            coverages.append(changepoint in confidence_set)
        else:
            # Empty confidence set: counted as a miss with zero width.
            coverages.append(False)
            widths.append(0)

    return np.mean(coverages), np.mean(widths)


# Function to compare all methods
def compare_methods(
    n_trials=20,
    length=200,
    changepoint=80,
    alpha=0.05,
    digit1=3,
    digit2=7,
    calibration_size=100,
    methods=("basic",),
):
    """
    Compare different changepoint detection methods.

    Runs ``evaluate_confidence_intervals`` once per entry of ``methods`` with
    the shared settings and prints each method's coverage and average width.

    Args:
        n_trials, length, changepoint, alpha, digit1, digit2, calibration_size:
            Forwarded unchanged to ``evaluate_confidence_intervals``.
        methods: Method names to evaluate (``"basic"``, ``"both_cal"``,
            ``"left_cal"``).  Defaults to ``("basic",)``, the previously
            hard-coded behaviour.

    Returns:
        dict mapping each method name to
        ``{"coverage": coverage_rate, "avg_width": avg_width}``.
    """
    results = {}

    for method in methods:
        print(f"Evaluating {method} method...")
        coverage, avg_width = evaluate_confidence_intervals(
            n_trials=n_trials,
            length=length,
            changepoint=changepoint,
            alpha=alpha,
            digit1=digit1,
            digit2=digit2,
            calibration_size=calibration_size,
            method=method,
        )
        results[method] = {"coverage": coverage, "avg_width": avg_width}
        print(f"{method} - Coverage: {coverage:.4f}, Avg Width: {avg_width:.2f}")

    return results

In [None]:
# Train the classifier only if no earlier cell has already defined `model`.
if "model" not in locals():
    model = get_mnist_trained_model()

results = compare_methods(n_trials=200)

methods = list(results.keys())
coverages = [results[m]["coverage"] for m in methods]
widths = [results[m]["avg_width"] for m in methods]

fig, ax1 = plt.subplots(figsize=(10, 6))

# Dedicated names so this plot does not overwrite the dataset variable `x`.
bar_positions = np.arange(len(methods))
bar_width = 0.35

ax1.bar(
    bar_positions - bar_width / 2,
    coverages,
    bar_width,
    label="Coverage",
    color="blue",
    alpha=0.7,
)
ax1.set_ylabel("Coverage", color="blue")
ax1.set_ylim([0, 1.1])
ax1.tick_params(axis="y", labelcolor="blue")

# Widths live on a different scale than coverage, so give them a second y-axis.
ax2 = ax1.twinx()
ax2.bar(
    bar_positions + bar_width / 2,
    widths,
    bar_width,
    label="Average Width",
    color="red",
    alpha=0.7,
)
ax2.set_ylabel("Average Width", color="red")
ax2.tick_params(axis="y", labelcolor="red")

plt.xticks(bar_positions, methods)
plt.title("Coverage and Average Width Comparison")

# Merge the legend entries of both axes into a single legend.
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right")

plt.tight_layout()
plt.show()

## Visualize the Data

Display some example images before and after the changepoint.

In [None]:
x = generate_mnist_dataset(length, changepoint, digit1, digit2)

# Show two images before the changepoint, the changepoint itself, and two after.
fig, axes = plt.subplots(1, 5, figsize=(6, 5))
for offset, ax in enumerate(axes):
    idx = changepoint + offset - 2
    ax.imshow(x[idx].reshape(28, 28), cmap="gray")
    title = f"$t = \\xi = {idx}$" if offset == 2 else f"$t = {idx}$"
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.savefig("images/mnist-sample.pdf")
plt.show()

# Augment with calibration data on both sides

In [None]:
def generate_extended_mnist_dataset(
    length, changepoint, calibration_size=100, digit1=3, digit2=7
):
    """Generate MNIST dataset with a changepoint and calibration data.

    The main sequence has ``digit1`` at indices ``0..changepoint`` (inclusive)
    and ``digit2`` afterwards.  In addition, ``calibration_size`` extra images
    of each digit are drawn, disjoint from the main sequence.  Sampling uses
    the global NumPy RNG.

    Returns:
        ``(x_main, x_calibration_pre, x_calibration_post)``: ``float32``
        arrays with one flattened image per row, pixels scaled to [0, 1], and
        ``length``, ``calibration_size`` and ``calibration_size`` rows.

    Raises:
        ValueError: If ``changepoint`` is outside ``[0, length)``,
            ``calibration_size`` is not positive, or MNIST has too few images
            of either digit.
    """
    if not 0 <= changepoint < length:
        raise ValueError(
            "changepoint must satisfy 0 <= changepoint < length; "
            f"got changepoint={changepoint}, length={length}."
        )
    if calibration_size < 1:
        raise ValueError(f"calibration_size must be positive; got {calibration_size}.")

    # Only raw uint8 `.data` / `.targets` are read, so no transform is needed.
    mnist_data = MNIST(root="./data", train=True, download=True)
    data = mnist_data.data.numpy()
    targets = mnist_data.targets.numpy()

    images_digit1 = data[targets == digit1]
    images_digit2 = data[targets == digit2]
    np.random.shuffle(images_digit1)
    np.random.shuffle(images_digit2)

    n1 = changepoint + 1 + calibration_size
    n2 = (length - changepoint - 1) + calibration_size
    if n1 > len(images_digit1) or n2 > len(images_digit2):
        raise ValueError("Insufficient images for the specified digits and length.")

    # The first `calibration_size` shuffled images of each digit are held out
    # for calibration; the following ones form the main sequence.
    calibration_pre = images_digit1[:calibration_size]
    main_pre = images_digit1[calibration_size:n1]
    calibration_post = images_digit2[:calibration_size]
    main_post = images_digit2[calibration_size:n2]

    def to_float(images, rows):
        # Flatten each image and scale uint8 pixels to [0, 1].
        return images.reshape(rows, -1).astype(np.float32) / 255.0

    x_main = to_float(np.concatenate([main_pre, main_post], axis=0), length)
    x_calibration_pre = to_float(calibration_pre, calibration_size)
    x_calibration_post = to_float(calibration_post, calibration_size)

    return x_main, x_calibration_pre, x_calibration_post

In [None]:
length, changepoint = 1000, 400
calibration_size = 100
digit1, digit2 = 3, 7

x_main, x_cal_pre, x_cal_post = generate_extended_mnist_dataset(
    length, changepoint, calibration_size, digit1, digit2
)

# Classify the main sequence and both held-out calibration sets.
predicted_digits = [predict_digit(model, image) for image in tqdm(x_main)]
probabilities = torch.vstack([row for _, row in predicted_digits])

predicted_cal_pre = [predict_digit(model, image) for image in tqdm(x_cal_pre)]
probabilities_cal_pre = torch.vstack([row for _, row in predicted_cal_pre])

predicted_cal_post = [predict_digit(model, image) for image in tqdm(x_cal_post)]
probabilities_cal_post = torch.vstack([row for _, row in predicted_cal_post])

In [None]:
def compute_left_scores_with_calibration(
    probabilities, predicted_digits, probabilities_cal_pre, predicted_cal_pre, length
):
    """Left-side odds scores for the main sequence and the pre-change calibration set.

    Digit votes are seeded with the calibration predictions.  Row ``t`` then
    adds the main predictions ``0..t`` and takes the most-voted digit ``d``
    (the earliest-inserted digit wins ties).  It scores main positions
    ``0..t`` and every calibration image by ``P(d) / (1 - P(d))``.

    The calibration width is taken from ``predicted_cal_pre`` itself, not from
    a notebook-level ``calibration_size`` global.  Relying on that global broke
    (shape mismatch or ``NameError``) whenever it was absent or differed from
    the calibration set actually passed in.

    Returns:
        ``(left_scores, left_scores_cal)`` with shapes ``(length, length)``
        (entries right of the diagonal stay 0) and
        ``(length, len(predicted_cal_pre))``.
    """
    calibration_size = len(predicted_cal_pre)
    left_scores = np.zeros((length, length))
    left_scores_cal = np.zeros((length, calibration_size))

    seen_digits = {}
    for predicted, _ in predicted_cal_pre:
        seen_digits[predicted] = seen_digits.get(predicted, 0) + 1

    for t, (predicted, _) in enumerate(predicted_digits):
        seen_digits[predicted] = seen_digits.get(predicted, 0) + 1
        curr_digit = max(seen_digits, key=seen_digits.get)

        p_main = probabilities[: t + 1, curr_digit].cpu()
        left_scores[t, : t + 1] = p_main / (1 - p_main)

        p_cal = probabilities_cal_pre[:, curr_digit].cpu()
        left_scores_cal[t, :] = p_cal / (1 - p_cal)

    return left_scores, left_scores_cal


def compute_right_scores_with_calibration(
    probabilities, predicted_digits, probabilities_cal_post, predicted_cal_post, length
):
    """Right-side odds scores for the main sequence and the post-change calibration set.

    Digit votes are seeded with the calibration predictions.  The main
    predictions are then added from the end of the sequence backwards.  At
    row ``t`` the most-voted digit ``d`` (the earliest-inserted digit wins
    ties) scores main positions ``t..length-1`` and every calibration image by
    ``P(d) / (1 - P(d))``.

    Returns:
        ``(right_scores, right_scores_cal)`` with shapes ``(length, length)``
        and ``(length, len(predicted_cal_post))``.
    """
    right_scores = np.zeros((length, length))
    right_scores_cal = np.zeros((length, len(predicted_cal_post)))

    votes = {}
    for cal_digit, _ in predicted_cal_post:
        votes[cal_digit] = votes.get(cal_digit, 0) + 1

    for step, (main_digit, _) in enumerate(reversed(predicted_digits)):
        t = length - 1 - step
        votes[main_digit] = votes.get(main_digit, 0) + 1
        top = max(votes, key=votes.get)

        tail_p = probabilities[t:, top].cpu()
        cal_p = probabilities_cal_post[:, top].cpu()
        right_scores[t, t:] = tail_p / (1 - tail_p)
        right_scores_cal[t, :] = cal_p / (1 - cal_p)

    return right_scores, right_scores_cal

In [None]:
def get_discrepancy_scores_with_calibration(
    x, scores_left, scores_right, scores_left_cal, scores_right_cal
):
    """KS discrepancy scores using calibration scores on both sides.

    For split ``t``, each left position ``r <= t`` gets a smoothed conformal
    p-value.  It pools ``scores_left[t, :r+1]`` with ``scores_left_cal[t, :]``
    and counts the pooled scores strictly below ``scores_left[t, r]``, plus a
    *single* Uniform(0, 1) draw times the number of exact ties in the pool
    (including the score itself), all divided by the pool size.  Right
    positions ``r > t`` pool ``scores_right[t, r:]`` with
    ``scores_right_cal[t, :]`` the same way.

    Using one draw for all pooled ties keeps the smoothed p-value exactly
    uniform under exchangeability.  Breaking main-set and calibration-set ties
    with two independent draws does not.

    Returns:
        ``(discrepancy_scores, statistics)``.  ``discrepancy_scores[t]`` is
        ``sqrt(t + 1) * D_left + sqrt(n - t - 1) * D_right``, and
        ``statistics[t]`` is the ``(left, right)`` pair of ``ks_1samp`` results.
    """
    n = len(x)
    discrepancy_scores = np.empty(n - 1)
    statistics = []

    for t in tqdm(range(n - 1)):
        p = np.empty(n)

        cal_row = scores_left_cal[t, :]
        for r in range(t + 1):
            score_r = scores_left[t, r]
            main_row = scores_left[t, : r + 1]
            below = np.count_nonzero(main_row < score_r) + np.count_nonzero(
                cal_row < score_r
            )
            ties = np.count_nonzero(main_row == score_r) + np.count_nonzero(
                cal_row == score_r
            )
            p[r] = (below + np.random.uniform(0, 1) * ties) / (
                main_row.size + cal_row.size
            )

        cal_row = scores_right_cal[t, :]
        for r in range(n - 1, t, -1):
            score_r = scores_right[t, r]
            main_row = scores_right[t, r:]
            below = np.count_nonzero(main_row < score_r) + np.count_nonzero(
                cal_row < score_r
            )
            ties = np.count_nonzero(main_row == score_r) + np.count_nonzero(
                cal_row == score_r
            )
            p[r] = (below + np.random.uniform(0, 1) * ties) / (
                main_row.size + cal_row.size
            )

        left_test = ks_1samp(p[: t + 1], uniform.cdf)
        right_test = ks_1samp(p[t + 1 :], uniform.cdf)
        statistics.append((left_test, right_test))
        discrepancy_scores[t] = left_test.statistic * np.sqrt(
            t + 1
        ) + right_test.statistic * np.sqrt(n - t - 1)

    return discrepancy_scores, statistics

In [None]:
# Scores for the two-sided calibration experiment.
left_scores, left_scores_cal = compute_left_scores_with_calibration(
    probabilities,
    predicted_digits,
    probabilities_cal_pre,
    predicted_cal_pre,
    length,
)
right_scores, right_scores_cal = compute_right_scores_with_calibration(
    probabilities,
    predicted_digits,
    probabilities_cal_post,
    predicted_cal_post,
    length,
)

discrepancy_scores_cal, statistics_cal = get_discrepancy_scores_with_calibration(
    x_main,
    left_scores,
    right_scores,
    left_scores_cal,
    right_scores_cal,
)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(discrepancy_scores_cal)
ax.axvline(x=changepoint, color="red", linestyle="--", label="True Changepoint")
ax.set_xlabel("Position t")
ax.set_ylabel("Discrepancy Score")
ax.set_title("MNIST Changepoint Detection with Calibration")
ax.legend()
plt.show()

In [None]:
from scipy.stats import chi2

# KS p-values of the left and right segments for every split t.
p_values_left = np.array([s[0].pvalue for s in statistics_cal])
p_values_right = np.array([s[1].pvalue for s in statistics_cal])

# Bonferroni combination capped at 1.  np.minimum's third positional argument
# is the output buffer, not a third operand, so the cap needs its own call.
min_method = np.minimum(1.0, 2.0 * np.minimum(p_values_left, p_values_right))
threshold = 0.05

# Entry t is the split after x_main[t] (0-based), matching `changepoint`.
plt.plot(np.arange(length - 1), min_method)
plt.axvline(
    changepoint, color="red", linestyle="--", label=f"Changepoint ($\\xi = {changepoint}$)"
)
plt.axhline(
    threshold, color="green", linestyle=":", label=f"Threshold ($\\alpha = {threshold}$)"
)
plt.xlabel("$t$")
plt.ylabel("p-value ($p_t$)")
plt.title("p-values for MNIST digit change (two-sided calibration)")
plt.legend()
plt.savefig("images/mnist-pvalues-calibrated.pdf")
plt.show()

In [None]:
# Confidence set: 0-based split indices whose combined p-value exceeds the threshold.
ci = np.flatnonzero(min_method > threshold)
confidence_interval = (ci[0], ci[-1]) if ci.size > 0 else None
print(f"True changepoint: {changepoint}")
print(f"Confidence interval: {confidence_interval}")
print(f"Maximum p-value at t={np.argmax(min_method)}")

# Calibration data on the left side only

In [None]:
def generate_left_calibration_mnist_dataset(
    length, changepoint, calibration_size=100, digit1=3, digit2=7
):
    """Generate an MNIST changepoint sequence plus pre-change calibration images.

    The main sequence has ``digit1`` at indices ``0..changepoint`` (inclusive)
    and ``digit2`` afterwards.  In addition, ``calibration_size`` extra
    ``digit1`` images are drawn, disjoint from the main sequence.  Sampling
    uses the global NumPy RNG.

    Returns:
        ``(x_main, x_calibration_pre)``: ``float32`` arrays with one flattened
        image per row, pixels scaled to [0, 1], and ``length`` and
        ``calibration_size`` rows.

    Raises:
        ValueError: If ``changepoint`` is outside ``[0, length)``,
            ``calibration_size`` is not positive, or MNIST has too few images
            of either digit.
    """
    if not 0 <= changepoint < length:
        raise ValueError(
            "changepoint must satisfy 0 <= changepoint < length; "
            f"got changepoint={changepoint}, length={length}."
        )
    if calibration_size < 1:
        raise ValueError(f"calibration_size must be positive; got {calibration_size}.")

    # Only raw uint8 `.data` / `.targets` are read, so no transform is needed.
    mnist_data = MNIST(root="./data", train=True, download=True)
    data = mnist_data.data.numpy()
    targets = mnist_data.targets.numpy()

    images_digit1 = data[targets == digit1]
    images_digit2 = data[targets == digit2]
    np.random.shuffle(images_digit1)
    np.random.shuffle(images_digit2)

    n1 = changepoint + 1 + calibration_size
    n2 = length - changepoint - 1
    if n1 > len(images_digit1) or n2 > len(images_digit2):
        raise ValueError("Insufficient images for the specified digits and length.")

    # The first `calibration_size` shuffled digit1 images are held out.
    calibration_pre = images_digit1[:calibration_size]
    main_pre = images_digit1[calibration_size:n1]
    main_post = images_digit2[:n2]

    x_main = np.concatenate([main_pre, main_post], axis=0)
    x_main = x_main.reshape(length, -1).astype(np.float32) / 255.0
    x_calibration_pre = (
        calibration_pre.reshape(calibration_size, -1).astype(np.float32) / 255.0
    )

    return x_main, x_calibration_pre

In [None]:
def compute_left_scores_with_left_calibration(
    probabilities, predicted_digits, probabilities_cal_pre, predicted_cal_pre, length
):
    """Left-side odds scores for the main sequence and the pre-change calibration set.

    Digit votes start from the calibration predictions.  Row ``t`` then adds
    the main predictions ``0..t`` and uses the most-voted digit ``d`` (the
    earliest-inserted digit wins ties) to score main positions ``0..t`` and
    every calibration image by ``P(d) / (1 - P(d))``.

    Returns:
        ``(left_scores, left_scores_cal)`` with shapes ``(length, length)``
        and ``(length, len(predicted_cal_pre))``.
    """
    left_scores = np.zeros((length, length))
    left_scores_cal = np.zeros((length, len(predicted_cal_pre)))

    votes = {}
    for cal_digit, _ in predicted_cal_pre:
        votes[cal_digit] = votes.get(cal_digit, 0) + 1

    for t, (main_digit, _) in enumerate(predicted_digits):
        votes[main_digit] = votes.get(main_digit, 0) + 1
        top = max(votes, key=votes.get)

        main_p = probabilities[: t + 1, top].cpu()
        cal_p = probabilities_cal_pre[:, top].cpu()
        left_scores[t, : t + 1] = main_p / (1 - main_p)
        left_scores_cal[t, :] = cal_p / (1 - cal_p)

    return left_scores, left_scores_cal


def compute_right_scores_without_calibration(probabilities, predicted_digits, length):
    """Right-side odds scores for the main sequence, without calibration data.

    Predictions are tallied from the end of the sequence backwards.  At row
    ``t`` the most-voted digit ``d`` among positions ``t..length-1`` (the
    earliest-inserted digit wins ties) scores those positions by
    ``P(d) / (1 - P(d))``.

    Returns:
        ``(length, length)`` array; entries left of the diagonal stay 0.
    """
    right_scores = np.zeros((length, length))
    votes = {}

    for step, (digit, _) in enumerate(reversed(predicted_digits)):
        t = length - 1 - step
        votes[digit] = votes.get(digit, 0) + 1
        top = max(votes, key=votes.get)
        tail_p = probabilities[t:, top].cpu()
        right_scores[t, t:] = tail_p / (1 - tail_p)

    return right_scores

In [None]:
def get_discrepancy_scores_with_left_calibration(
    x, scores_left, scores_right, scores_left_cal
):
    """KS discrepancy scores using calibration scores on the left side only.

    For split ``t``:

    * Each left position ``r <= t`` pools ``scores_left[t, :r+1]`` with
      ``scores_left_cal[t, :]``.  Its p-value is the count of pooled scores
      strictly below ``scores_left[t, r]``, plus a *single* Uniform(0, 1) draw
      times the number of exact ties in the pool, divided by the pool size.
      One draw for all pooled ties keeps the p-value exactly uniform under
      exchangeability; two independent draws do not.
    * Each right position ``r > t`` gets the uncalibrated p-value: the count of
      ``scores_right[t, r:]`` strictly above ``scores_right[t, r]``, plus a
      uniform draw times the ties, over ``n - r``.

    Returns:
        ``(discrepancy_scores, statistics)``.  ``discrepancy_scores[t]`` is
        ``sqrt(t + 1) * D_left + sqrt(n - t - 1) * D_right``, and
        ``statistics[t]`` is the ``(left, right)`` pair of ``ks_1samp`` results.
    """
    n = len(x)
    discrepancy_scores = np.empty(n - 1)
    statistics = []

    for t in tqdm(range(n - 1)):
        p = np.empty(n)

        cal_row = scores_left_cal[t, :]
        for r in range(t + 1):
            score_r = scores_left[t, r]
            main_row = scores_left[t, : r + 1]
            below = np.count_nonzero(main_row < score_r) + np.count_nonzero(
                cal_row < score_r
            )
            ties = np.count_nonzero(main_row == score_r) + np.count_nonzero(
                cal_row == score_r
            )
            p[r] = (below + np.random.uniform(0, 1) * ties) / (
                main_row.size + cal_row.size
            )

        for r in range(n - 1, t, -1):
            p[r] = (
                np.count_nonzero(scores_right[t, r:] > scores_right[t, r])
                + np.random.uniform(0, 1)
                * np.count_nonzero(scores_right[t, r:] == scores_right[t, r])
            ) / (n - r)

        left_test = ks_1samp(p[: t + 1], uniform.cdf)
        right_test = ks_1samp(p[t + 1 :], uniform.cdf)
        statistics.append((left_test, right_test))
        discrepancy_scores[t] = left_test.statistic * np.sqrt(
            t + 1
        ) + right_test.statistic * np.sqrt(n - t - 1)

    return discrepancy_scores, statistics

In [None]:
length = 1000
changepoint = 400
calibration_size = 100
digit1 = 3
digit2 = 7

x_main, x_cal_pre = generate_left_calibration_mnist_dataset(
    length, changepoint, calibration_size, digit1, digit2
)

predicted_digits = [predict_digit(model, x_main[i]) for i in tqdm(range(length))]
probabilities = torch.vstack([prob for _, prob in predicted_digits])

predicted_cal_pre = [
    predict_digit(model, x_cal_pre[i]) for i in tqdm(range(calibration_size))
]
probabilities_cal_pre = torch.vstack([prob for _, prob in predicted_cal_pre])

left_scores, left_scores_cal = compute_left_scores_with_left_calibration(
    probabilities, predicted_digits, probabilities_cal_pre, predicted_cal_pre, length
)

right_scores = compute_right_scores_without_calibration(
    probabilities, predicted_digits, length
)

discrepancy_scores_left_cal, statistics_left_cal = (
    get_discrepancy_scores_with_left_calibration(
        x_main, left_scores, right_scores, left_scores_cal
    )
)

plt.figure(figsize=(10, 6))
plt.plot(discrepancy_scores_left_cal)
plt.axvline(x=changepoint, color="red", linestyle="--", label="True Changepoint")
plt.xlabel("Position t")
plt.ylabel("Discrepancy Score")
plt.title("MNIST Changepoint Detection with Left Calibration Only")
plt.legend()
plt.show()

p_values_left = np.array([s[0].pvalue for s in statistics_left_cal])
p_values_right = np.array([s[1].pvalue for s in statistics_left_cal])

# Bonferroni combination capped at 1.  np.minimum's third positional argument
# is the output buffer, not a third operand, so the cap needs its own call.
min_method = np.minimum(1.0, 2.0 * np.minimum(p_values_left, p_values_right))
threshold = 0.05

# Entry t is the split after x_main[t] (0-based), matching `changepoint`.
plt.plot(np.arange(length - 1), min_method)
plt.axhline(
    threshold, color="green", linestyle=":", label=f"Threshold ($\\alpha = {threshold}$)"
)
plt.axvline(
    changepoint, color="red", linestyle="--", label=f"Changepoint ($\\xi = {changepoint}$)"
)
plt.xlabel("$t$")
plt.ylabel("p-value ($p_t$)")
plt.title("p-values for MNIST digit change (left-side calibration)")
plt.legend()
plt.savefig("images/mnist-pvalues-left-calibrated.pdf")
plt.show()

# Confidence set: 0-based split indices whose combined p-value exceeds the threshold.
confidence_set = np.flatnonzero(min_method > threshold)
confidence_interval = (
    (confidence_set[0], confidence_set[-1]) if confidence_set.size > 0 else None
)

print(f"True changepoint: {changepoint}")
print(f"Confidence interval: {confidence_interval}")
print(f"Maximum p-value at t={np.argmax(min_method)}")

# Language-model simulation: SST-2 sentiment change

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
from scipy.stats import ks_2samp, ks_1samp, uniform
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

import matplotlib_inline.backend_inline

# Render inline figures as SVG and apply the "math.mplstyle" style sheet
# (presumably a local file shipped alongside this notebook — confirm).
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
plt.style.use("math.mplstyle")

In [None]:
# NOTE(review): this cell repeats the import and Matplotlib configuration from
# the previous cell verbatim; it appears redundant and could be removed.
import matplotlib_inline.backend_inline

matplotlib_inline.backend_inline.set_matplotlib_formats("svg")
plt.style.use("math.mplstyle")

In [None]:
def get_pretrained_sentiment_model(device="cpu"):
    """Load the DistilBERT SST-2 sentiment classifier together with its tokenizer.

    Args:
        device: Device the model is moved to (for example "cpu" or "cuda").

    Returns:
        A ``(model, tokenizer)`` pair; the model has been switched to eval mode.
    """
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint)

    # Move to the target device, then disable dropout etc. for inference.
    classifier.to(device)
    classifier.eval()
    return classifier, tok

In [None]:
def generate_sentiment_dataset(length, changepoint, dataset_name="sst2"):
    """Build a sequence of sentences whose sentiment flips after ``changepoint``.

    Positions ``0..changepoint`` hold positive (label 1) sentences and positions
    ``changepoint + 1..length - 1`` hold negative (label 0) sentences, each
    drawn in random order from the ``train`` split of ``dataset_name``.

    Args:
        length: Total number of sentences to return.
        changepoint: Index of the last positive sentence; must satisfy
            ``0 <= changepoint < length``.
        dataset_name: Name passed to ``load_dataset``; its ``train`` split must
            provide ``sentence`` and ``label`` fields.

    Returns:
        ``(texts, true_labels)``, two lists of length ``length``.

    Raises:
        ValueError: If ``changepoint`` is out of range, or either sentiment has
            too few sentences.
    """
    # Validate before downloading anything. With changepoint >= length, n2
    # below is negative and negative_texts[:n2] would silently slice from the
    # end, producing a sequence of the wrong length.
    if not 0 <= changepoint < length:
        raise ValueError(
            "changepoint must satisfy 0 <= changepoint < length, got "
            f"changepoint={changepoint}, length={length}."
        )

    train_data = load_dataset(dataset_name)["train"]

    # Single pass over the split; the relative order within each list matches
    # the original two comprehensions, so the shuffles below draw identically.
    positive_texts, negative_texts = [], []
    for item in train_data:
        if item["label"] == 1:
            positive_texts.append(item["sentence"])
        elif item["label"] == 0:
            negative_texts.append(item["sentence"])

    random.shuffle(positive_texts)
    random.shuffle(negative_texts)

    n1 = changepoint + 1
    n2 = length - n1

    if n1 > len(positive_texts) or n2 > len(negative_texts):
        raise ValueError("Insufficient texts for the specified length and changepoint.")

    texts = positive_texts[:n1] + negative_texts[:n2]
    true_labels = [1] * n1 + [0] * n2

    return texts, true_labels

In [None]:
def predict_sentiment(model, tokenizer, text, device="cpu"):
    """Classify a single text.

    Args:
        model: Sequence-classification model whose output exposes ``logits``.
        tokenizer: Tokenizer matching ``model``.
        text: Input text; tokenization truncates it to at most 512 tokens.
        device: Device the tokenized inputs are moved to.

    Returns:
        ``(predicted, probs)``: the argmax class index as a Python int, and the
        softmax class probabilities on the CPU with size-1 dimensions squeezed.
    """
    encoded = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    encoded = encoded.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoded).logits
        class_probs = torch.softmax(logits, dim=1).cpu()
        label = class_probs.argmax(dim=1).item()

    return label, class_probs.squeeze()

In [None]:
# Experiment setup: `length` SST-2 sentences, positive up to index
# `changepoint` and negative afterwards (see generate_sentiment_dataset).
length = 500
changepoint = 200
device = "cuda" if torch.cuda.is_available() else "cpu"

model, tokenizer = get_pretrained_sentiment_model(device)

print("Generating sentiment dataset...")
texts, true_labels = generate_sentiment_dataset(length, changepoint)

print("Getting predictions...")
# Classify each sentence individually, keeping the predicted label and the
# per-class probability vector for every position.
sentiment_outputs = [
    predict_sentiment(model, tokenizer, text, device) for text in tqdm(texts)
]
predictions = [label for label, _ in sentiment_outputs]
probabilities = torch.stack([class_probs for _, class_probs in sentiment_outputs])

In [None]:
# Odds-ratio conformity scores p / (1 - p) for the class currently in the
# majority among the predicted labels. Computed once, element-wise in float32,
# exactly as the per-element expressions were.
# Row t of left_score scores positions 0..t using the majority of
# predictions[:t + 1]; row t of right_score scores positions t..length-1 using
# the majority of predictions[t:].
odds = (probabilities / (1 - probabilities)).numpy()

left_score = np.zeros((length, length))
seen_sentiments = {0: 0, 1: 0}

for t in range(length):
    seen_sentiments[predictions[t]] += 1
    curr_sentiment = max(seen_sentiments, key=seen_sentiments.get)
    left_score[t, : t + 1] = odds[: t + 1, curr_sentiment]

right_score = np.zeros((length, length))
seen_sentiments = {0: 0, 1: 0}

# Walk backwards so that, at row t, seen_sentiments counts predictions[t:].
# The previous loop wrote row length - i - 1 using counts of predictions[i:],
# i.e. the majority label of the mirrored suffix, not of the row's own segment.
for t in range(length - 1, -1, -1):
    seen_sentiments[predictions[t]] += 1
    curr_sentiment = max(seen_sentiments, key=seen_sentiments.get)
    right_score[t, t:] = odds[t:, curr_sentiment]

In [None]:
def get_discrepancy_scores(scores_left, scores_right, length):
    """Compute KS-based discrepancy scores for every candidate split.

    For each split ``t`` (left segment ``0..t``, right segment
    ``t + 1..length - 1``), a randomized conformal p-value is built for every
    position. On the left, position ``r`` is ranked against
    ``scores_left[t, :r + 1]``; on the right, it is ranked against
    ``scores_right[t, r:]``. Ties are split with a ``U(0, 1)`` draw from
    ``np.random``. Each segment's p-values are then compared to the uniform
    distribution with a one-sample KS test.

    Args:
        scores_left: 2-D array indexed ``[t, r]`` with left-segment scores.
        scores_right: 2-D array indexed ``[t, r]`` with right-segment scores.
        length: Sequence length.

    Returns:
        ``(discrepancy_scores, statistics)``.
        ``discrepancy_scores`` has ``length - 1`` entries, each equal to
        ``D_left * sqrt(t + 1) + D_right * sqrt(length - t - 1)``, where ``D``
        denotes a KS statistic. ``statistics`` is a list of
        ``(left_ks_result, right_ks_result)`` pairs.
    """

    def _smoothed_pvalue(window, value, denom):
        # Count of window entries strictly above `value`, plus a random share
        # of the ties (the window contains `value` itself).
        above = np.count_nonzero(window > value)
        tie_weight = np.random.uniform(0, 1)
        ties = np.count_nonzero(window == value)
        return (above + tie_weight * ties) / denom

    n_splits = length - 1
    discrepancy_scores = np.empty(n_splits)
    statistics = []

    for t in tqdm(range(n_splits)):
        pvals = np.empty(length)
        row_left = scores_left[t]
        row_right = scores_right[t]

        # Left segment, scanned forward: rank r against positions 0..r.
        for r in range(t + 1):
            pvals[r] = _smoothed_pvalue(row_left[: r + 1], row_left[r], r + 1)

        # Right segment, scanned backward: rank r against positions r..end.
        for r in reversed(range(t + 1, length)):
            pvals[r] = _smoothed_pvalue(row_right[r:], row_right[r], length - r)

        left_ks = ks_1samp(pvals[: t + 1], uniform.cdf)
        right_ks = ks_1samp(pvals[t + 1 :], uniform.cdf)
        statistics.append((left_ks, right_ks))

        discrepancy_scores[t] = left_ks.statistic * np.sqrt(
            t + 1
        ) + right_ks.statistic * np.sqrt(length - t - 1)

    return discrepancy_scores, statistics

In [None]:
print("Calculating discrepancy scores...")
discrepancy_scores, statistics = get_discrepancy_scores(left_score, right_score, length)

# Discrepancy score for every candidate split, with the true changepoint marked.
ax = plt.figure(figsize=(10, 6)).gca()
ax.plot(discrepancy_scores)
ax.axvline(x=changepoint, color="red", linestyle="--", label="True Changepoint")
ax.set_xlabel("Position t")
ax.set_ylabel("Discrepancy Score")
ax.set_title("Sentiment Analysis Changepoint Detection")
ax.legend()
plt.show()

In [None]:
from scipy.stats import chi2

p_values_left = np.array([s[0].pvalue for s in statistics])
p_values_right = np.array([s[1].pvalue for s in statistics])

# Bonferroni ("min") combination of the left/right KS p-values:
#     p_t = min(1, 2 * min(p_left, p_right))
# Passing a third positional argument to np.minimum makes it the ufunc's `out`
# buffer (no clamping at 1), and chi2.cdf(., 4) of a p-value is not itself a
# p-value, so the alpha threshold below would not be calibrated.
min_method = np.minimum(1.0, 2 * np.minimum(p_values_left, p_values_right))
threshold = 0.05

plt.plot(np.arange(1, length), min_method)
plt.axvline(
    changepoint,
    color="red",
    linestyle="--",
    label=f"Changepoint ($\\xi = {changepoint}$)",
)
plt.axhline(
    threshold,
    color="green",
    linestyle=":",
    label=f"Threshold ($\\alpha = {threshold}$)",
)
plt.xlabel("$t$")
plt.ylabel("p-value ($p_t$)")
plt.title("p-values for SST-2 sentiment change")
plt.legend()
plt.savefig("images/sentiment-pvalues.pdf")
plt.show()

# Confidence set: all split indices whose combined p-value exceeds alpha.
# NOTE(review): min_method[t] comes from the split with left segment 0..t
# (get_discrepancy_scores), but the plot x-axis and the argmax printout use
# t + 1 while confidence_interval reports t; confirm the intended convention.
confidence_set = np.argwhere(min_method > threshold).flatten()
confidence_interval = (
    (confidence_set[0], confidence_set[-1]) if len(confidence_set) > 0 else None
)

print(f"True changepoint: {changepoint}")
print(f"Confidence interval: {confidence_interval}")
print(f"Minimum of Fisher's statistic at t={np.argmax(min_method)+1}")

In [None]:
# Print three randomly chosen sentences from each segment with the model's
# prediction. The pre-change segment is indices 0..changepoint inclusive
# (generate_sentiment_dataset uses n1 = changepoint + 1), and np.random.randint
# excludes its upper bound, so the pre-change draw uses changepoint + 1.
print("\nExamples before changepoint (positive):")
for i in range(3):
    idx = np.random.randint(0, changepoint + 1)
    print(f'Text {i+1}: "{texts[idx]}"')
    print(
        f"True label: Positive, Predicted: {'Positive' if predictions[idx] == 1 else 'Negative'}"
    )
    print(f"Confidence: {probabilities[idx][predictions[idx]]:.4f}\n")

print("\nExamples after changepoint (negative):")
for i in range(3):
    idx = np.random.randint(changepoint + 1, length)
    print(f'Text {i+1}: "{texts[idx]}"')
    print(
        f"True label: Negative, Predicted: {'Positive' if predictions[idx] == 1 else 'Negative'}"
    )
    print(f"Confidence: {probabilities[idx][predictions[idx]]:.4f}\n")

In [None]:
def generate_mixed_sentiment_dataset(
    length,
    changepoint,
    dataset_name="sst2",
    pre_positive_frac=0.6,
    post_positive_frac=0.4,
):
    """Build a sentence sequence whose positive/negative mix changes at a changepoint.

    Positions ``0..changepoint`` contain a shuffled mix in which a fraction
    ``pre_positive_frac`` of the sentences are positive (label 1). Positions
    ``changepoint + 1..length - 1`` use ``post_positive_frac`` instead. Positive
    counts are rounded down with ``int``. Sentences come from the ``train``
    split of ``dataset_name``, and no sentence is used in both segments.

    Args:
        length: Total number of sentences to return.
        changepoint: Index of the last pre-change sentence; must satisfy
            ``0 <= changepoint < length``.
        dataset_name: Name passed to ``load_dataset``; its ``train`` split must
            provide ``sentence`` and ``label`` fields.
        pre_positive_frac: Fraction of positive sentences before the change,
            in ``[0, 1]``.
        post_positive_frac: Fraction of positive sentences after the change,
            in ``[0, 1]``.

    Returns:
        ``(texts, true_labels)``, two lists of length ``length``.

    Raises:
        ValueError: If ``changepoint`` or a fraction is out of range, or the
            split has too few sentences of either sentiment.
    """
    # Validate before downloading. An out-of-range changepoint produces
    # negative segment sizes; changepoint == length - 1 used to crash when
    # unpacking zip(*[]) for the empty post-change segment.
    if not 0 <= changepoint < length:
        raise ValueError(
            "changepoint must satisfy 0 <= changepoint < length, got "
            f"changepoint={changepoint}, length={length}."
        )
    if not (0 <= pre_positive_frac <= 1 and 0 <= post_positive_frac <= 1):
        raise ValueError("pre_positive_frac and post_positive_frac must lie in [0, 1].")

    train_data = load_dataset(dataset_name)["train"]

    # Single pass over the split. Order within each list is the same as two
    # filtered comprehensions would give, so the shuffles draw identically.
    positive_texts, negative_texts = [], []
    for item in train_data:
        if item["label"] == 1:
            positive_texts.append(item["sentence"])
        elif item["label"] == 0:
            negative_texts.append(item["sentence"])

    random.shuffle(positive_texts)
    random.shuffle(negative_texts)

    n_pre = changepoint + 1
    n_post = length - n_pre

    n_pos_pre = int(n_pre * pre_positive_frac)
    n_neg_pre = n_pre - n_pos_pre

    n_pos_post = int(n_post * post_positive_frac)
    n_neg_post = n_post - n_pos_post

    if n_pos_pre + n_pos_post > len(positive_texts) or n_neg_pre + n_neg_post > len(
        negative_texts
    ):
        raise ValueError(
            "Insufficient texts for the specified distribution and length."
        )

    def _shuffled_segment(pos_texts, neg_texts):
        # Pair each text with its label (positives first), then shuffle the pairs.
        pairs = [(text, 1) for text in pos_texts] + [(text, 0) for text in neg_texts]
        random.shuffle(pairs)
        return pairs

    # The pre-change segment takes the first texts of each pool and the
    # post-change segment takes the next ones, so nothing is reused.
    pre_pairs = _shuffled_segment(positive_texts[:n_pos_pre], negative_texts[:n_neg_pre])
    post_pairs = _shuffled_segment(
        positive_texts[n_pos_pre : n_pos_pre + n_pos_post],
        negative_texts[n_neg_pre : n_neg_pre + n_neg_post],
    )

    all_pairs = pre_pairs + post_pairs
    texts = [text for text, _ in all_pairs]
    true_labels = [label for _, label in all_pairs]

    return texts, true_labels

In [None]:
length = 1000
changepoint = 400
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reuse the classifier loaded by an earlier cell when it is still defined.
if "model" not in locals() or "tokenizer" not in locals():
    model, tokenizer = get_pretrained_sentiment_model(device)

print("Generating mixed sentiment dataset (60%/40% to 40%/60%)...")
texts, true_labels = generate_mixed_sentiment_dataset(length, changepoint)

print("Getting predictions...")
predictions = []
probabilities = []

for text in tqdm(texts):
    pred, prob = predict_sentiment(model, tokenizer, text, device)
    predictions.append(pred)
    probabilities.append(prob)

probabilities = torch.stack(probabilities)

# Odds-ratio conformity scores p / (1 - p) for the majority predicted class,
# computed once element-wise in float32.
# Row t of left_score scores positions 0..t using the majority of
# predictions[:t + 1]; row t of right_score scores positions t..length-1 using
# the majority of predictions[t:].
odds = (probabilities / (1 - probabilities)).numpy()

left_score = np.zeros((length, length))
seen_sentiments = {0: 0, 1: 0}

for t in range(length):
    seen_sentiments[predictions[t]] += 1
    curr_sentiment = max(seen_sentiments, key=seen_sentiments.get)
    left_score[t, : t + 1] = odds[: t + 1, curr_sentiment]

right_score = np.zeros((length, length))
seen_sentiments = {0: 0, 1: 0}

# Walk backwards so that, at row t, seen_sentiments counts predictions[t:].
# The previous loop wrote row length - i - 1 using counts of predictions[i:],
# i.e. the majority label of the mirrored suffix, not of the row's own segment.
for t in range(length - 1, -1, -1):
    seen_sentiments[predictions[t]] += 1
    curr_sentiment = max(seen_sentiments, key=seen_sentiments.get)
    right_score[t, t:] = odds[t:, curr_sentiment]

print("Calculating discrepancy scores...")
discrepancy_scores, statistics = get_discrepancy_scores(left_score, right_score, length)

plt.figure(figsize=(10, 6))
plt.plot(discrepancy_scores)
plt.axvline(x=changepoint, color="red", linestyle="--", label="True Changepoint")
plt.xlabel("Position t")
plt.ylabel("Discrepancy Score")
plt.title(
    "Mixed Sentiment Analysis Changepoint Detection\n(Pre: 60% pos/40% neg, Post: 40% pos/60% neg)"
)
plt.legend()
plt.show()

p_values_left = np.array([s[0].pvalue for s in statistics])
p_values_right = np.array([s[1].pvalue for s in statistics])

# Bonferroni ("min") combination of the left/right KS p-values:
#     p_t = min(1, 2 * min(p_left, p_right))
# Passing a third positional argument to np.minimum makes it the ufunc's `out`
# buffer (no clamping at 1), and chi2.cdf(., 4) of a p-value is not itself a
# p-value, so the alpha threshold below would not be calibrated.
min_method = np.minimum(1.0, 2 * np.minimum(p_values_left, p_values_right))
threshold = 0.05

plt.plot(np.arange(1, length), min_method)
plt.axvline(
    changepoint,
    color="red",
    linestyle="--",
    label=f"Changepoint ($\\xi = {changepoint}$)",
)
plt.axhline(
    threshold,
    color="green",
    linestyle=":",
    label=f"Threshold ($\\alpha = {threshold}$)",
)
plt.xlabel("$t$")
plt.ylabel("p-value ($p_t$)")
plt.title(
    "p-values for SST-2 Mixed Sentiment Change\n(Pre: 60% pos/40% neg, Post: 40% pos/60% neg)"
)
plt.legend()
plt.show()

# Confidence set: all split indices whose combined p-value exceeds alpha.
# NOTE(review): the plot x-axis and the argmax printout use t + 1 while
# confidence_interval reports t; confirm the intended convention.
confidence_set = np.argwhere(min_method > threshold).flatten()
confidence_interval = (
    (confidence_set[0], confidence_set[-1]) if len(confidence_set) > 0 else None
)

print(f"True changepoint: {changepoint}")
print(f"Confidence interval: {confidence_interval}")
print(f"Minimum of Fisher's statistic at t={np.argmax(min_method)+1}")

pre_positives = sum(1 for i in range(changepoint + 1) if true_labels[i] == 1)
pre_negatives = (changepoint + 1) - pre_positives
post_positives = sum(1 for i in range(changepoint + 1, length) if true_labels[i] == 1)
post_negatives = (length - changepoint - 1) - post_positives

print("\nSentiment distribution in data:")
print(
    f"Pre-change: {pre_positives/(changepoint+1)*100:.1f}% positive, {pre_negatives/(changepoint+1)*100:.1f}% negative"
)
print(
    f"Post-change: {post_positives/(length-changepoint-1)*100:.1f}% positive, {post_negatives/(length-changepoint-1)*100:.1f}% negative"
)

# The pre-change segment is indices 0..changepoint inclusive; np.random.randint
# excludes its upper bound, hence changepoint + 1.
print("\nExamples before changepoint (60% positive, 40% negative):")
for i in range(3):
    idx = np.random.randint(0, changepoint + 1)
    print(f'Text {i+1}: "{texts[idx]}"')
    print(
        f"True label: {'Positive' if true_labels[idx] == 1 else 'Negative'}, Predicted: {'Positive' if predictions[idx] == 1 else 'Negative'}"
    )
    print(f"Confidence: {probabilities[idx][predictions[idx]]:.4f}\n")

In [None]:
# Display the indices whose combined p-value exceeds the threshold (the confidence set).
np.flatnonzero(min_method > threshold)

In [None]:
# Re-plot the mixed-sentiment combined p-values with a shorter title and save
# the figure for the paper.
ax = plt.gca()
ax.plot(np.arange(1, length), min_method)
ax.axvline(
    changepoint,
    color="red",
    linestyle="--",
    label=f"Changepoint ($\\xi = {changepoint}$)",
)
ax.axhline(
    threshold,
    color="green",
    linestyle=":",
    label=f"Threshold ($\\alpha = {threshold}$)",
)
ax.set_xlabel("$t$")
ax.set_ylabel("p-value ($p_t$)")
ax.set_title("p-values for SST-2 mixed sentiment change")
ax.legend()
plt.savefig("images/sentiment-pvalues-mixed.pdf")
plt.show()