In [None]:
# notebooks/learnable/01_optimize_single_kernel.py

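# Optimize a Single Learnable Kernel (RBF)
- Dataset: Breast Cancer (Wisconsin diagnostic, via scikit-learn), labels mapped to {-1, +1}
- Model: LearnableRBF (an RBF kernel whose `gamma` is trained)
- Goal: Maximize kernel-target alignment (KTA) on the training split by optimizing `gamma` with Adam, and track how KTA and `gamma` evolve per epoch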

In [None]:

import matplotlib.pyplot as plt
import torch
from sklearn import datasets, model_selection, preprocessing

# `kta` provides `LearnableRBF` and the differentiable `kta_torch` used below.
# If it is missing (e.g. a fresh hosted runtime), install it from the project's
# GitHub repository into *this* kernel's interpreter and import again.
# TODO(review): pin a tag/commit (".../kernel-target-alignment.git@<ref>") so
# that re-runs are reproducible.
try:
    from kta import LearnableRBF, kta_torch
except ModuleNotFoundError:
    import importlib
    import subprocess
    import sys

    # `sys.executable -m pip` targets the running kernel's environment rather
    # than whatever `pip` happens to be first on PATH.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "--quiet",
            "git+https://github.com/whitham-powell/kernel-target-alignment.git",
        ],
    )
    # Import finders cache directory listings; clear them so the package that
    # was just installed is discoverable without restarting the kernel.
    importlib.invalidate_caches()
    from kta import LearnableRBF, kta_torch

## 1. Load and preprocess data

In [None]:
# Raw features and binary labels; the labels are remapped from {0, 1} to {-1.0, +1.0}.
features_raw, labels_raw = datasets.load_breast_cancer(return_X_y=True)
labels_pm = 2.0 * labels_raw - 1.0

# Stratified 70/30 split with a fixed seed, so both the split and the class
# ratio in each part are reproducible.
X_tr, X_te, y_tr, y_te = model_selection.train_test_split(
    features_raw, labels_pm, test_size=0.3, random_state=0, stratify=labels_pm
)

# Standardize using statistics estimated on the training split only, then apply
# the same transform to the held-out split (no test-set leakage).
scaler = preprocessing.StandardScaler()
X_tr = scaler.fit_transform(X_tr)
X_te = scaler.transform(X_te)

# float32 training tensors consumed by the optimization loop.
X = torch.tensor(X_tr, dtype=torch.float32)
y = torch.tensor(y_tr, dtype=torch.float32)

## 2. Initialize model and optimizer

In [None]:
# Tunable hyperparameters for the kernel and the optimizer.
GAMMA_INIT = 1.0  # starting value of the RBF gamma
LEARNING_RATE = 0.05  # Adam step size

model = LearnableRBF(gamma_init=GAMMA_INIT)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## 3. Train loop: maximize KTA

In [None]:
# Gradient ascent on KTA: the optimizer minimizes, so the loss is -KTA.
# Each epoch logs the alignment *and* the gamma that produced it, both taken
# before `optimizer.step()`. This keeps alignments[i] and gammas[i] describing
# the same kernel; reading gamma after the step would shift the series by one.
# Note: `model`/`optimizer` are created in an earlier cell, so re-running only
# this cell continues training from the current gamma.
N_EPOCHS = 100

alignments = []
gammas = []

for epoch in range(N_EPOCHS):
    K = model(X)
    alignment = kta_torch(K, y)
    loss = -alignment

    # Record the (KTA, gamma) pair for the current kernel before updating it.
    alignments.append(alignment.item())
    gammas.append(model.gamma.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

## 4. Plot results

In [None]:
# KTA and gamma live on different scales, so give gamma its own y-axis.
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
(line_kta,) = ax1.plot(alignments, label="KTA", color="tab:blue")
(line_gamma,) = ax2.plot(gammas, label="gamma", color="tab:orange")

ax1.set_xlabel("Epoch")
ax1.set_ylabel("Alignment", color="tab:blue")
ax2.set_ylabel("Gamma", color="tab:orange")
# The two curves belong to different Axes, so gather both handles into a single
# legend. Without this call the `label=` values above are never shown. The
# legend is drawn on the twin (top) axes so the gamma curve cannot cover it.
ax2.legend(handles=[line_kta, line_gamma], loc="best")
fig.suptitle("Learnable RBF: KTA vs Gamma")
fig.tight_layout()
plt.show()

## 5. Final gamma and alignment

In [None]:
# `alignments[-1]` was measured before the last optimizer step, i.e. for the
# previous gamma. Re-evaluate KTA at the final gamma (no gradients needed) so
# that the two reported numbers describe the same kernel.
with torch.no_grad():
    final_alignment = kta_torch(model(X), y).item()

print(f"Final gamma: {model.gamma.item():.4f}")
print(f"Final alignment: {final_alignment:.4f}")