In [None]:
!pip install -q scikit-learn matplotlib

import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from itertools import permutations
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold


def set_seed(seed=1, n_splits=5):
    """Seed the Python, NumPy and PyTorch RNGs and return a seeded CV splitter.

    Besides seeding, this forces deterministic cuDNN kernels (and disables the
    cuDNN autotuner). It then builds a shuffled ``StratifiedKFold`` that uses
    the same ``seed`` as its ``random_state``.

    NOTE(review): assigning ``PYTHONHASHSEED`` at runtime does not change the
    hash seed of the already-running interpreter; it only reaches processes
    started afterwards.

    Args:
        seed: integer seed shared by every RNG and by the splitter.
        n_splits: number of cross-validation folds.

    Returns:
        A ``StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return splitter

kf = set_seed(1)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load dataset
data = load_iris()

# Normalization: divide each feature column by its maximum, then map affinely
# with x -> 2x - 1. For non-negative features this places values in [-1, 1].
# NOTE(review): the maxima are taken over the full dataset (every CV fold), so
# test rows influence the scaling of training rows. This is mild leakage;
# confirm it is intended.
X = data['data'] / data['data'].max(axis=0)
X = 2 * X - 1

y = data['target']
# Single 80/20 hold-out split. The cross-validation loop later in this cell
# rebinds these four names per fold, so this split does not feed the CV scores.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Constants. Graph sizes are derived from the data instead of being hard-coded,
# so they stay consistent if the dataset changes.
N_INTERNAL = 4  # number of internal (hidden) nodes
N_FEATURES = X.shape[1]  # one input node per feature column (4 for Iris)
N_LABELS = len(np.unique(y))  # one label node per class (3 for Iris)
EDGE_EPSILON = 0.05  # scale of parameter init and of feature encoding

# Node indexing: one contiguous integer range laid out as
# [input nodes | label nodes | internal nodes].
input_nodes = list(range(N_FEATURES))
label_nodes = list(range(N_FEATURES, N_FEATURES + N_LABELS))
internal_nodes = list(range(N_FEATURES + N_LABELS, N_FEATURES + N_LABELS + N_INTERNAL))
all_nodes = input_nodes + label_nodes + internal_nodes

# Helper: map 4-parameter vectors (a, b, c, d) to 2x2 Hermitian edge matrices
def batch_edge_from_vector(vectors):
    """Turn trailing 4-vectors ``(a, b, c, d)`` into 2x2 complex matrices.

    Each output matrix is::

        [[1 + a,    b + i*c],
         [b - i*c,  1 + d  ]]

    This is the identity plus a Hermitian perturbation. Leading dimensions of
    ``vectors`` are kept, so the result has shape
    ``vectors.shape[:-1] + (2, 2)``. The result is moved to the module-level
    ``device``.
    """
    a, b, c, d = (vectors[..., k] for k in range(4))
    zero = torch.zeros_like(c)

    # The real part is symmetric and the imaginary part antisymmetric,
    # which together make the matrix Hermitian.
    real_rows = [torch.stack([1 + a, b], dim=-1), torch.stack([b, 1 + d], dim=-1)]
    imag_rows = [torch.stack([zero, c], dim=-1), torch.stack([-c, zero], dim=-1)]
    real = torch.stack(real_rows, dim=-2)
    imag = torch.stack(imag_rows, dim=-2)

    return (real + 1j * imag).to(device)

def approx_bures_curvature(H):
    """Return ``||H - I||_F**2 / 8`` for each 2x2 matrix in ``H``.

    Despite its name, this computes a squared-Frobenius-distance-to-identity
    proxy; it is not the Bures distance itself. It is zero exactly when a
    matrix equals the identity.

    Args:
        H: tensor whose last two dimensions are 2x2 (real or complex).
            Leading dimensions are treated as a batch.

    Returns:
        Real tensor of shape ``H.shape[:-2]``.
    """
    # Build the identity on H's own device, so inputs on any device work.
    # (Previously it was pinned to the global ``device``, which raised a
    # device-mismatch error for any other placement.)
    I = torch.eye(2, dtype=torch.cfloat, device=H.device)
    diff = H - I
    # |z|^2 computed as re^2 + im^2, summed over both matrix dimensions.
    return (diff.real ** 2 + diff.imag ** 2).sum(dim=[-2, -1]) / 8.0

# RQG model
class RQG(nn.Module):
    """Learnable edge parameters for a fully connected directed graph.

    Every ordered pair of distinct nodes ``(src, dst)`` gets an edge keyed
    ``"src_dst"``. Each edge holds a 4-vector initialised as
    ``EDGE_EPSILON * randn(4)``. ``edge_matrix`` maps that vector to a 2x2
    complex matrix through ``batch_edge_from_vector``.

    NOTE(review): the training/prediction loops in this file only read
    internal->label and label->input edges from the model. The other edges
    are created but receive no gradient.
    """

    def __init__(self, all_nodes):
        super().__init__()
        self.nodes = all_nodes

        keys = []
        for src in all_nodes:
            for dst in all_nodes:
                if src == dst:
                    continue  # no self-loops
                keys.append(f"{src}_{dst}")
        self.edge_keys = keys

        # Parameters are drawn in key order, so RNG consumption follows key order.
        initial = {name: nn.Parameter(EDGE_EPSILON * torch.randn(4)) for name in keys}
        self.edges = nn.ParameterDict(initial)

    def edge_matrix(self, key):
        """Return the 2x2 complex matrix for the edge named ``key`` (e.g. ``"5_8"``)."""
        vec = self.edges[key]
        return batch_edge_from_vector(vec[None])[0]

# Loop generation
def loops_with_input_label(input_nodes, label_node, internal_nodes):
    """List every 3-node loop ``[input, internal, label_node]``.

    Loops are ordered input-major: all internal nodes for the first input,
    then all for the second, and so on. Each loop is a fresh list.
    """
    return [
        [src, hidden, label_node]
        for src in input_nodes
        for hidden in internal_nodes
    ]

# Encode feature vector into edge parameters
def encode_input_edges(x_vec):
    """Encode one feature vector as fixed (non-learned) input<->internal edges.

    Feature ``x_vec[k]`` becomes a 4-vector whose entries all equal
    ``x_vec[k] * EDGE_EPSILON``. That vector is attached in both directions
    between input node ``input_nodes[k]`` and every internal node, with the
    same tensor object stored under the forward and the reverse key.

    Returns:
        dict mapping ``"src_dst"`` edge keys to 4-element tensors on ``device``.
    """
    encoded = {}
    for pos, feature in enumerate(x_vec):
        src = input_nodes[pos]
        for hidden in internal_nodes:
            vec = torch.full((4,), feature * EDGE_EPSILON, device=device)
            encoded[f"{src}_{hidden}"] = encoded[f"{hidden}_{src}"] = vec
    return encoded

# Training function
def train_rqg(model, X_train, y_train, lr=0.01, epochs=1):
    """Fit ``model`` with online Adam (one sample per step) on loop curvature.

    For each sample, the features are encoded as fixed input<->internal edges
    via ``encode_input_edges``. Every loop ``input -> internal -> true label ->
    input`` is then multiplied out into a 2x2 matrix. The loss is the mean
    ``approx_bures_curvature`` of those matrices, so training pushes the
    true-label loops towards the identity.

    With the current ``encode_input_edges``, only the model's internal->label
    and label->input edges enter these loops. The input<->internal model
    edges are shadowed by the encoded features.

    NOTE(review): samples are visited in the given order with no shuffling.
    If ``X_train`` is grouped by class, each pass sees the classes in blocks;
    confirm this is acceptable.

    Args:
        model: an ``RQG`` instance; moved to ``device`` in place.
        X_train: sequence of feature rows.
        y_train: integer class labels, used to index ``label_nodes``.
        lr: Adam learning rate.
        epochs: number of passes over the data. The default of 1 reproduces
            the original single-pass behaviour.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        for i, (x_vec, y_label) in enumerate(zip(X_train, y_train)):
            optimizer.zero_grad()
            edge_vectors = encode_input_edges(torch.tensor(x_vec, dtype=torch.float32))

            # Only loops through the sample's true label are penalised.
            loops = loops_with_input_label(input_nodes, label_nodes[y_label], internal_nodes)
            loop_matrices = []

            for loop in loops:
                # Ordered product of edge matrices around the closed loop
                # (the last node connects back to the first).
                H = torch.eye(2, dtype=torch.cfloat, device=device)
                for src, tgt in zip(loop, loop[1:] + loop[:1]):
                    key = f"{src}_{tgt}"
                    # Feature-encoded edges take precedence over learned ones.
                    if key in edge_vectors:
                        rho = batch_edge_from_vector(edge_vectors[key].unsqueeze(0))[0]
                    else:
                        rho = model.edge_matrix(key)
                    H = H @ rho
                loop_matrices.append(H)

            H_batch = torch.stack(loop_matrices, dim=0)
            loss = approx_bures_curvature(H_batch).mean()
            if i % 10 == 0:
                # Keep the original log line unchanged for single-pass training.
                prefix = f"Epoch {epoch + 1}/{epochs}, " if epochs > 1 else ""
                print(f"{prefix}After sample {i+1}/{len(X_train)}: Curvature = {loss.item():.6f}")
            loss.backward()
            optimizer.step()

# Prediction
def predict(model, x_vec):
    """Classify one feature row by the label with the lowest loop curvature.

    For each label node, build every loop ``input -> internal -> label ->
    input`` (with features encoded as in training), multiply out each loop,
    and average ``approx_bures_curvature`` over the loops.

    NOTE(review): the model is not moved to ``device`` here; this relies on
    ``train_rqg`` having done so.

    Args:
        model: a trained ``RQG``.
        x_vec: one feature row.

    Returns:
        Position in ``label_nodes`` of the label with the lowest mean
        curvature. Ties go to the earliest label. Returns ``None`` if no
        curvature compares below ``inf`` (e.g. every value is NaN).
    """
    min_curv = float('inf')
    best_label = None

    # Inference only: skip building an autograd graph for every loop product.
    with torch.no_grad():
        edge_vectors = encode_input_edges(torch.tensor(x_vec, dtype=torch.float32))

        for label_idx, label_node in enumerate(label_nodes):
            loops = loops_with_input_label(input_nodes, label_node, internal_nodes)
            loop_matrices = []

            for loop in loops:
                # Ordered product of edge matrices around the closed loop.
                H = torch.eye(2, dtype=torch.cfloat, device=device)
                for src, tgt in zip(loop, loop[1:] + loop[:1]):
                    key = f"{src}_{tgt}"
                    # Feature-encoded edges take precedence over learned ones.
                    if key in edge_vectors:
                        rho = batch_edge_from_vector(edge_vectors[key].unsqueeze(0))[0]
                    else:
                        rho = model.edge_matrix(key)
                    H = H @ rho
                loop_matrices.append(H)

            H_batch = torch.stack(loop_matrices, dim=0)
            curv = approx_bures_curvature(H_batch).mean().item()
            # Strict '<' keeps the first label on ties.
            if curv < min_curv:
                min_curv = curv
                best_label = label_idx

    return best_label

# Stratified k-fold evaluation: train a fresh model per fold, score on the held-out fold.
accuracies = []
for fold_idx, (train_idx, test_idx) in enumerate(kf.split(X, y)):
    print(f"\n--- Fold {fold_idx + 1} ---")

    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]

    # A fresh, untrained model per fold, so no fold reuses another's weights.
    model = RQG(all_nodes)
    train_rqg(model, X_train, y_train)

    correct = sum(
        1
        for features, target in zip(X_test, y_test)
        if predict(model, features) == target
    )
    acc = correct / len(X_test)
    accuracies.append(acc)
    print(f"Fold {fold_idx + 1} Accuracy: {acc:.2%}")

# np.std uses ddof=0, i.e. the population standard deviation across folds.
mean_acc, std_acc = np.mean(accuracies), np.std(accuracies)
print(f"\nCross-validated Accuracy: {mean_acc:.2%} ± {std_acc:.2%}")