In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
from scipy.stats import multivariate_normal
from tabrel.benchmark.nw_regr import run_training
from tabrel.benchmark.nw_regr import generate_toy_regr_data


def make_symmetric_r(n: int, num_ones: int, seed: int) -> np.ndarray:
    """Build a random symmetric 0/1 matrix of shape ``(n, n)`` with a zero diagonal.

    Distinct positions strictly below the diagonal are drawn uniformly
    without replacement and set to 1; the matrix is then mirrored onto the
    upper triangle, so the result holds
    ``2 * min(num_ones, n * (n - 1) // 2)`` ones.

    Args:
        n: Side length of the square matrix.
        num_ones: Requested number of ones below the diagonal. Capped at the
            number of available below-diagonal positions.
        seed: Seed of the random generator; equal seeds give equal matrices.

    Returns:
        A symmetric float array of shape ``(n, n)`` containing only 0 and 1.
    """
    # A private RandomState produces the same draws as seeding the global
    # generator with ``seed``, but leaves ``np.random``'s process-wide state
    # untouched, so calling this helper cannot perturb unrelated sampling.
    rng = np.random.RandomState(seed)

    r = np.zeros((n, n), dtype=float)

    # All strictly-lower-triangular positions (k=-1 excludes the diagonal).
    rows, cols = np.tril_indices(n, k=-1)
    num_lower = rows.size

    # Pick which of those positions become 1, never more than exist.
    chosen = rng.choice(num_lower, size=min(num_ones, num_lower), replace=False)
    r[rows[chosen], cols[chosen]] = 1

    # Mirror the lower triangle; the diagonal stays 0.
    r += r.T
    return r

n_samples = 300
n_ones = 70  # only referenced by the disabled make_symmetric_r experiment below

# Index split, identical for every run:
#   [0, n_validate)                     -> validation
#   [n_validate, n_validate + n_query)  -> query
#   [n_validate + n_query, n_samples)   -> background
# The query slice has to end at n_validate + n_query: ending it at n_query
# (which equals n_validate) selects nothing, and starting the background at
# n_query would make it overlap the query range.
n_query = n_validate = n_samples // 3
samples_indices = np.arange(n_samples)
validate_indices = samples_indices[:n_validate]
query_indices = samples_indices[n_validate: n_validate + n_query]
back_indices = samples_indices[n_validate + n_query:]

for y_func in ("square", "sign"):
    # Metrics are collected per target function, so the summary printed under
    # each "=== y_func ===" header covers only that function's runs.
    r2_list = []
    mse_list = []
    for seed in range(20):
        x, y, c = generate_toy_regr_data(n_samples=n_samples, n_clusters=3, y_func=y_func, seed=seed)

        # Relation matrix: ones on the diagonal, and every pair (i, j) with
        # c[i] == c[j] is linked symmetrically with probability 1/2.
        # (c presumably holds the cluster label of each sample - confirm
        # against generate_toy_regr_data.)
        # The coin flips come from a generator tied to the loop's seed, so the
        # relation matrix is reproducible together with the generated data.
        rng = np.random.RandomState(seed)
        r = np.eye(n_samples)
        for i in range(n_samples):
            for j in range(i):
                if c[i] == c[j] and rng.choice((True, False)):
                    r[i, j] = 1
                    r[j, i] = 1
        x = x.numpy()
        y = y.numpy()

        # Disabled alternative experiment: random symmetric relations and
        # targets sampled from a multivariate normal whose covariance is
        # derived from r.
        # r = make_symmetric_r(n_samples, num_ones=n_ones, seed=seed)
        # sigma = r.copy()
        # while min(np.linalg.eigvalsh(sigma)) < 0:
        #     sigma += .5 * np.eye(n_samples)
        # # print(np.linalg.eigvalsh(sigma))

        # # sampling from multivariate normal distribution
        # mean = np.zeros(n_samples)
        # mvn_dist = multivariate_normal(mean, sigma, allow_singular=True)
        # x = np.random.uniform(-1, 1, (n_samples, 1))
        # y = mvn_dist.rvs(size=1) + 5 * x.T[0]

        results = run_training(
            x=x,
            y=y,
            r=r,
            backgnd_indices=back_indices,
            query_indices=query_indices,
            val_indices=validate_indices,
            lr=1e-3,
            n_epochs=100,
            rel_as_feats=r,
        )
        # Each value of results is indexed as (mse, r2); the keys are
        # presumably one label per evaluated model - confirm against
        # run_training.
        mse_list.append({k: v[0] for k, v in results.items()})
        r2_list.append({k: v[1] for k, v in results.items()})

    # Mean and standard deviation over the seeds, one row per results key.
    print(f"=== {y_func} ===")
    for label, lst in (("mse", mse_list), ("r2", r2_list)):
        results_df = pd.DataFrame(lst)
        print("\n")
        print(label)
        print(pd.DataFrame({"mean": results_df.mean().round(3), "std": results_df.std().round(3)}))