This notebook walks through some basic analysis of the SEMBAS results. Specifically,
it is used to determine which models' outputs can be trusted, and thereby to reduce
the ensemble's error.

In [None]:
import numpy as np
import torch.nn as nn
import torch
import json
import os

from rtree.index import Index, Property

from numpy import ndarray

from network import *
from data import FutData, f as fut
from main import classify_validity

# Setup
Specify the paths you chose for your models. Note: these paths are relative to the
location of this notebook.

In [None]:
# Directory that load_boundary() reads the boundary_{i}.json files from.
BOUNDARY_PATH = "../../../.data/boundaries"
# Root directory of the trained models.
MODEL_PATH = "../../../.models/bnn_expl" # Same loc as --model-path arg


# Directory that load() reads the individual network_{i}.model files from.
NETWORK_PATH = f"{MODEL_PATH}/ensemble/"

In [None]:
def load_boundary(i: int) -> tuple[ndarray, ndarray]:
    """
    Loads the boundary data stored for model @i.

    Reads "boundary_{i}.json" from BOUNDARY_PATH and returns the pair
    (boundary_points, boundary_surface), each converted to an array from the JSON
    entry of the same name.
    """
    with open(f"{BOUNDARY_PATH}/boundary_{i}.json", encoding="utf-8") as fh:
        data = json.load(fh)
    return np.array(data["boundary_points"]), np.array(data["boundary_surface"])


def load_boundary_into_rtree(bpoints: ndarray, surface) -> Index:
    """
    Builds an RTree holding every boundary point together with its surface entry.

    Each row of @bpoints is inserted under its position in the array, and the pair
    (point, surface entry) is attached as the stored object, which is what
    pred_perf reads back from its nearest-neighbour query.
    """
    props = Property()
    # The index dimensionality has to match the number of columns of @bpoints.
    props.set_dimension(bpoints.shape[1])

    tree = Index(properties=props)
    for point_id, pair in enumerate(zip(bpoints, surface)):
        # pair is (boundary point, surface entry); the point doubles as coordinates.
        tree.insert(point_id, pair[0], pair)

    return tree
    

In [None]:
def pred_perf(p: ndarray, index: Index) -> tuple[bool, float]:
    """
    Predicts the performance mode of @p given RTree @index.

    The nearest boundary point b and the surface vector n stored with it are looked
    up in @index. The returned flag is True when the offset (p - b) has a negative
    projection onto n; the second value is the Euclidean distance from @p to b.

    If @p coincides with b the projection is zero, so the flag is False and the
    distance is 0.0.
    """
    # 'raw' makes the query yield the stored (point, surface vector) object itself.
    b, n = next(index.nearest(p, 1, 'raw'))

    s = p - b
    dist = np.linalg.norm(s)

    # Only the sign of the projection is used, and dividing by the non-negative
    # distance cannot change it, so the offset is left un-normalised. This avoids
    # a division by zero (and the resulting NaNs) when @p lies exactly on b.
    return s.dot(n) < 0.0, dist
        

# Specifying SEMBAS Selected Models
`main.rs` generates the boundary data and also reports the indices of the models that
are redundant (i.e. the "skip-list"). This skip-list can be used to trim down the
number of models needed to produce a reasonable ensemble.

These tools are both early in development and mostly act as a proof-of-concept, so
optimization is necessary to get them performing well enough to be viable.

Below is the skip_list for the notebook. Paste in the indices that were skipped,
which can be found in the main.rs standard output at the end of exploring the models.
Alternatively, leave it empty, specify manually which models you wish to skip, or
edit model_indices directly, which specifies the models to include in the
ensemble.

In [None]:
# Indices of the models to leave out of the ensemble; empty keeps every model.
skip_list: list[int] = []

In [None]:
# Count only the saved networks ("network_{i}.model", the naming load() relies on)
# so that stray directory entries such as hidden files or checkpoint folders do
# not inflate the number of models.
total_models = sum(
    1
    for name in os.listdir(NETWORK_PATH)
    if name.startswith("network_") and name.endswith(".model")
)
print(f"found {total_models} number of models")

# A set gives constant-time membership tests while filtering the indices.
skipped = set(skip_list)
model_indices = [x for x in range(total_models) if x not in skipped]
print("Number of selected models:", len(model_indices))

def load(i: int) -> nn.Module:
    """
    Restores the @i-th ensemble member from NETWORK_PATH.

    The architecture (ConcreteLinear(2, 50) -> ReLU -> ConcreteLinear(50, 1)) is
    rebuilt first, then the state dict saved as "network_{i}.model" is loaded
    into it.
    """
    layers = [ConcreteLinear(2, 50), nn.ReLU(), ConcreteLinear(50, 1)]
    model = nn.Sequential(*layers)
    model.load_state_dict(torch.load(f"{NETWORK_PATH}/network_{i}.model"))
    return model

dataset = FutData(2**14)

# One RTree and one network per SEMBAS-selected model, in matching order.
boundary_rtrees = [load_boundary_into_rtree(*load_boundary(i)) for i in model_indices]
networks = [load(i) for i in model_indices]
# Every model that was found, rather than a hard-coded count, so that this works
# for any ensemble size.
all_networks = [load(i) for i in range(total_models)]


def _mean_prediction(models, x):
    "Per-sample mean of the outputs of @models evaluated on @x"
    # axis=0 averages across the models only; without it the mean would also
    # collapse the sample axis and yield a single scalar for the whole batch.
    return np.array([model(x).detach() for model in models]).mean(axis=0)


# Doesn't use the boundary data, but only uses the mean of the model results
# (traditional solution)
def ensemble_mean_model(x):
    return _mean_prediction(networks, x)


# Similar to ensemble_mean, but instead of using the SEMBAS selected models it uses
# all of them.
def full_ensemble_mean_model(x):
    return _mean_prediction(all_networks, x)



In [None]:
def ensemble_sembas_model(x):
    """
    Ensemble model that applies SEMBAS boundary data for determining which model's
    output can be trusted the most.

    For every row of @x, each model for which pred_perf returns True on its
    boundary RTree contributes to the prediction, which is the mean of those
    models' outputs. When no model qualifies, the output of the model whose
    nearest boundary point is closest to the sample is used instead.

    Returns an ndarray of shape (x.shape[0], 1).
    """
    result = np.zeros((x.shape[0], 1))

    for i, xi in enumerate(x):
        # The inverse-transformed sample does not depend on the model, so it is
        # computed once per sample instead of once per (sample, model) pair.
        sembas_p = dataset.inverse_transform_request(xi).detach().numpy()

        trusted = []
        nearest = None  # (model, distance) with the smallest distance seen so far
        for tree, model in zip(boundary_rtrees, networks):
            cls, dist = pred_perf(sembas_p, tree)
            if cls:
                trusted.append(model)

            if nearest is None or dist < nearest[1]:
                nearest = (model, dist)

        # The networks are called on a batch of one sample.
        xi = xi.reshape(1, -1)
        if trusted:
            y_hat = np.array([m(xi).detach().item() for m in trusted])
            result[i] = y_hat.mean()
        else:
            # No model is trusted for this sample: fall back to the closest one.
            result[i] = nearest[0](xi).detach()

    return result

In [None]:
def create_random_ensemble(n: int):
    """
    Generates an ensemble model from a random sub-population of models.
    @n is the number of models to include.

    The models are drawn without replacement from the total_models available
    networks, so @n must not exceed that count. The returned callable maps an
    input batch to the per-sample median of the selected models' outputs.
    """
    rng = np.random.default_rng()
    picks = rng.choice(np.arange(total_models), n, replace=False)
    rand_net = [load(i) for i in picks]

    def ensemble(x):
        # axis=0 takes the median across the models only; without it the median
        # would also collapse the sample axis into a single scalar.
        return np.median([model(x).detach() for model in rand_net], axis=0)

    return ensemble

In [None]:
def test(model, dataset: FutData):
    """
    Returns the MSE of the model over the @dataset.

    @dataset is unpacked into inputs and targets, and @model is called once on
    all of the inputs.
    """
    inputs, targets = dataset
    predicted = model(inputs).squeeze()

    residual: ndarray = targets.squeeze() - predicted
    squared = np.power(residual, 2.0)
    return squared.mean()


In [None]:
# Unpack the dataset into the two parts that test() uses as inputs (x) and targets (y).
x, y = dataset


# Evaluating
First, we show the average performance of a randomly selected ensemble, followed by
the performance of the SEMBAS-selected ensemble using the mean of the outputs, and
then the same ensemble using the SEMBAS boundary data to select trusted outputs.

In [None]:
# Average MSE over 100 independently drawn random ensembles, each with as many
# models as the SEMBAS selection.
np.mean([
    test(create_random_ensemble(len(model_indices)), dataset).item()
    for _ in range(100)
])

In [None]:
# MSE of the plain mean over the SEMBAS-selected models (boundary data not used).
test(model=ensemble_mean_model, dataset=dataset)

In [None]:
# MSE of the SEMBAS ensemble, which uses the boundary data to pick trusted outputs.
test(model=ensemble_sembas_model, dataset=dataset)

In [None]:
# Average MSE over 100 evaluations of the mean over all models.
np.mean([test(full_ensemble_mean_model, dataset).item() for _ in range(100)])

In [None]:
import matplotlib.pyplot as plt
from data import f as fut

In [None]:
i = 0
model_i = model_indices[i]
# Draw from every available model rather than a hard-coded 1000: the sampling in
# create_random_ensemble() is without replacement, so asking for more models than
# exist raises an error.
ensemble = create_random_ensemble(total_models)
model = lambda x: torch.tensor(ensemble(x), dtype=torch.float64)
bpoints, surface = load_boundary(model_i)
index = boundary_rtrees[i]

fig, axes = plt.subplots(ncols=2)
axl, axr = axes
# Side length used to reshape the per-sample results into a square image.
# NOTE(review): assumes the samples lie on an n x n grid -- confirm against FutData.
n = int(dataset.data_size**0.5)

x, y = dataset

# Left image: samples whose squared error under the SEMBAS ensemble is below 0.5.
pred = ensemble_sembas_model(x)

err: ndarray = y.squeeze() - pred.squeeze()
y_cls = np.power(err, 2.0) < 0.5

# Right image: the same criterion for a random ensemble of the same size.
pred_rand = create_random_ensemble(len(model_indices))(x)
err_rand: ndarray = y.squeeze() - pred_rand.squeeze()
y_cls_rand = np.power(err_rand, 2.0) < 0.5


axl.imshow(y_cls.reshape(n, n))
axr.imshow(y_cls_rand.reshape(n, n))