In [None]:
def set_latex():
    """Apply the notebook's matplotlib text and font configuration.

    In order, each pass:

    1. requests LaTeX text rendering and a serif font via ``plt.rc``;
    2. switches to the ``"default"`` style sheet;
    3. sets the font size to 15, the font family to Times New Roman and the
       mathtext font set to STIX;
    4. makes a best-effort attempt to remove the ``'roman'`` entry from
       ``matplotlib.font_manager.weight_dict`` and rebuild the font cache.

    The whole sequence is executed twice. Nothing is returned; the function
    only mutates global matplotlib state.
    """
    # NOTE(review): the double pass is kept from the original notebook;
    # presumably a workaround for settings not sticking on the first call in
    # Jupyter -- confirm before removing.
    for _ in range(2):
        # Imported here so the function can be called from a cell that runs
        # before the notebook's main import cell.
        import matplotlib
        import matplotlib.pyplot as plt

        plt.rc('text', usetex=True)
        plt.rc('font', family='serif')

        # NOTE(review): style.use("default") runs after the two rc() calls
        # above; check whether the usetex/serif requests survive it.
        plt.style.use("default")
        plt.rcParams["font.size"] = 15

        plt.rcParams['font.family'] = 'Times New Roman'
        plt.rcParams['mathtext.fontset'] = 'stix'

        try:
            del matplotlib.font_manager.weight_dict['roman']
            matplotlib.font_manager._rebuild()
        except Exception:
            # Deliberately best effort: on the second pass the key is already
            # gone (KeyError), and the private ``_rebuild`` helper may not
            # exist in the installed matplotlib (AttributeError). Catching
            # ``Exception`` rather than using a bare ``except`` lets
            # KeyboardInterrupt / SystemExit propagate.
            pass

In [None]:
# Apply the plotting configuration up front, before any figure is drawn.
set_latex()

In [None]:
import copy
import math
import os
import warnings

import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import special
from sklearn.metrics import mean_squared_error
from tqdm.notebook import tqdm

# Silence every warning for the remainder of the session.
warnings.filterwarnings("ignore")

# Device selection:
#   * no CUDA available            -> "cpu"
#   * CUDA available, $GPU set     -> the value of $GPU
#   * CUDA available, $GPU not set -> "cuda"
if torch.cuda.is_available():
    device = os.environ.get("GPU") or "cuda"
else:
    device = "cpu"

# Tensors created without an explicit dtype are double precision from here on.
torch.set_default_dtype(torch.float64)

In [None]:
def calc_tau(alpha: float, S: np.array, diag_i: np.array, diag_j: np.array) -> np.array:
    """Evaluate ``1/4 + arcsin(rho) / (2*pi)`` elementwise.

    ``rho = alpha^2 * S / sqrt((alpha^2 * diag_i + 1/2) * (alpha^2 * diag_j + 1/2))``.

    Args:
        alpha: scalar multiplier; it enters only as ``alpha**2``.
        S: array of pairwise values.
        diag_i: array broadcastable against ``S``; in ``hardtree_viz`` every
            row holds the diagonal of ``S``.
        diag_j: array broadcastable against ``S``; in ``hardtree_viz`` it is
            the transpose of ``diag_i``.

    Returns:
        Array with the broadcast shape of the inputs.
    """
    alpha_sq = alpha**2
    scaled_S = alpha_sq * S
    normaliser = np.sqrt((alpha_sq * diag_i + 0.5) * (alpha_sq * diag_j + 0.5))
    return 1 / 4 + 1 / (2 * math.pi) * np.arcsin(scaled_S / normaliser)


def calc_tau_dot(
    alpha: float, S: np.array, diag_i: np.array, diag_j: np.array
) -> np.array:
    """Evaluate ``alpha^2 / (pi * sqrt(D))`` elementwise.

    ``D = (2*alpha^2*diag_i + 1) * (2*alpha^2*diag_j + 1) - 4*alpha^4*S^2``.

    Args:
        alpha: scalar multiplier.
        S: array of pairwise values.
        diag_i: array broadcastable against ``S`` (see ``hardtree_viz`` for
            how it is built from the diagonal of ``S``).
        diag_j: array broadcastable against ``S``.

    Returns:
        Array with the broadcast shape of the inputs.
    """
    alpha_sq = alpha**2
    radicand = (2 * alpha_sq * diag_i + 1) * (2 * alpha_sq * diag_j + 1) - (
        4 * (alpha**4) * (S**2)
    )
    return alpha_sq / math.pi * 1 / np.sqrt(radicand)

In [None]:
class InnerNode:
    """Internal (decision) node of a soft tree ensemble.

    Each node owns one ``nn.Linear(input_dim, n_tree)`` layer, i.e. one
    decision per tree, initialised from a standard normal. Children are built
    recursively by ``build_child``. The class is a plain Python object, not an
    ``nn.Module``; its layer is reached through the ``fc`` attribute.

    Args:
        config: dict providing ``input_dim``, ``n_tree``, ``max_depth``,
            ``scale`` and ``bias_scale``.
        depth: depth of this node; children are inner nodes while
            ``depth < config["max_depth"]`` and leaves otherwise.
        asym: when true, every right child is a leaf and only the left branch
            keeps growing.
    """

    def __init__(self, config, depth, asym=False):
        self.config = config
        self.leaf = False
        self.fc = nn.Linear(
            self.config["input_dim"], self.config["n_tree"], bias=True
        ).to(device)
        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0
        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0
        self.prob = None
        self.path_prob = None
        self.left = None
        self.right = None
        self.leaf_accumulator = []
        # Previously created only by reset(); defined here so the attribute
        # always exists.
        self.penalties = []
        self.asym = asym

        self.build_child(depth)

    def build_child(self, depth):
        """Create ``self.left`` / ``self.right`` for a node at ``depth``."""
        if depth < self.config["max_depth"]:
            self.left = InnerNode(self.config, depth + 1, asym=self.asym)
            if self.asym:
                self.right = LeafNode(self.config)
            else:
                self.right = InnerNode(self.config, depth + 1, asym=self.asym)
        else:
            self.left = LeafNode(self.config)
            self.right = LeafNode(self.config)

    def forward(self, x):  # decision function
        """Return ``0.5 * erf(scale * (x W^T + bias_scale * b)) + 0.5``.

        The result has one column per tree and lies in ``[0, 1]``.
        """
        pre_activation = (
            torch.matmul(x, self.fc.weight.t())
            + self.config["bias_scale"] * self.fc.bias
        )
        return 0.5 * torch.erf(self.config["scale"] * pre_activation) + 0.5

    def calc_prob(self, x, path_prob):
        """Propagate path probabilities to the leaves below this node.

        Args:
            x: input batch passed to every decision function.
            path_prob: probability of reaching this node, ``[batch_size, n_tree]``.

        Returns:
            List of ``[path_prob, Q]`` pairs, left subtree first, one pair per
            leaf under this node. The same list is stored in
            ``self.leaf_accumulator``.
        """
        self.prob = self.forward(x)  # probability of selecting right node
        path_prob = path_prob.to(device)  # path_prob: [batch_size, n_tree]
        self.path_prob = path_prob
        left_leaves = self.left.calc_prob(x, path_prob * (1 - self.prob))
        right_leaves = self.right.calc_prob(x, path_prob * self.prob)
        # Build the list from scratch on every call. Extending the stored list
        # meant that a second call without reset() returned every leaf twice.
        self.leaf_accumulator = [*left_leaves, *right_leaves]
        return self.leaf_accumulator

    def reset(self):
        """Clear the cached leaf list (and ``penalties``) on the whole subtree."""
        self.leaf_accumulator = []
        self.penalties = []
        self.left.reset()
        self.right.reset()


class SparseInnerNode(InnerNode):
    """Inner node whose decision looks at a single input column.

    The base constructor runs first (and, through the overridden
    ``build_child``, builds the children); afterwards the node's layer is
    replaced by an ``nn.Linear(1, n_tree)`` that is fed only column
    ``feature_index`` of the input.

    Args:
        config: same dict as for ``InnerNode``.
        depth: depth of this node.
        asym: forwarded to the base class and to the children.
        feature_index: input column to split on; drawn uniformly from
            ``range(input_dim)`` with ``np.random.randint`` when ``None``.
            Children created by ``build_child`` always use a random index.
    """

    def __init__(self, config, depth, asym=False, feature_index=None):
        super().__init__(config, depth, asym)
        self.feature_index = (
            np.random.randint(self.config["input_dim"])
            if feature_index is None
            else feature_index
        )

        # One input, n_tree outputs, standard-normal weights and biases.
        single_feature_fc = nn.Linear(1, self.config["n_tree"], bias=True).to(device)
        nn.init.normal_(single_feature_fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0
        nn.init.normal_(single_feature_fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0
        self.fc = single_feature_fc

    def build_child(self, depth):
        """Create the two children: leaves at the bottom, sparse nodes above."""
        if depth >= self.config["max_depth"]:
            self.left = LeafNode(self.config)
            self.right = LeafNode(self.config)
            return

        self.left = SparseInnerNode(self.config, depth + 1, asym=self.asym)
        self.right = (
            LeafNode(self.config)
            if self.asym
            else SparseInnerNode(self.config, depth + 1, asym=self.asym)
        )

    def forward(self, x):  # decision function
        """Return ``0.5 * erf(scale * (x_f w^T + bias_scale * b)) + 0.5``.

        ``x_f`` is column ``feature_index`` of ``x``; the output is
        ``[batch_size, n_tree]``.
        """
        selected_column = x[:, self.feature_index].unsqueeze(dim=1)
        pre_activation = (
            torch.matmul(selected_column, self.fc.weight.t())
            + self.config["bias_scale"] * self.fc.bias
        )
        return 0.5 * torch.erf(self.config["scale"] * pre_activation) + 0.5


class SparseFinetuneInnerNode(InnerNode):
    """Inner node with a full-width layer that starts out sparse.

    The layer is ``nn.Linear(input_dim, n_tree)`` as in ``InnerNode``, but at
    construction every weight column except ``feature_index`` is multiplied by
    zero. The zeroing happens once; nothing in this class keeps those weights
    at zero afterwards. The decision function is inherited from ``InnerNode``.

    Args:
        config: same dict as for ``InnerNode``.
        depth: depth of this node.
        asym: forwarded to the base class and to the children.
        feature_index: the only input column with a non-zero initial weight;
            drawn uniformly from ``range(input_dim)`` with
            ``np.random.randint`` when ``None``. Children created by
            ``build_child`` always use a random index.
    """

    def __init__(self, config, depth, asym=False, feature_index=None):
        super().__init__(config, depth, asym)
        if feature_index is None:
            self.feature_index = np.random.randint(self.config["input_dim"])
        else:
            self.feature_index = feature_index

        self.fc = nn.Linear(
            self.config["input_dim"], self.config["n_tree"], bias=True
        ).to(device)
        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0
        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0

        with torch.no_grad():
            # Mask on the RESOLVED index. The previous code compared against
            # the raw ``feature_index`` argument, so the default ``None``
            # (used for every child built by build_child) zeroed all weights.
            column_mask = torch.zeros(
                self.config["input_dim"],
                dtype=self.fc.weight.dtype,
                device=self.fc.weight.device,
            )
            column_mask[self.feature_index] = 1
            # weight is [n_tree, input_dim]; the mask broadcasts over trees.
            self.fc.weight.mul_(column_mask)

    def build_child(self, depth):
        """Create the two children: leaves at the bottom, finetune nodes above."""
        if depth < self.config["max_depth"]:
            self.left = SparseFinetuneInnerNode(self.config, depth + 1, asym=self.asym)
            if self.asym:
                self.right = LeafNode(self.config)
            else:
                self.right = SparseFinetuneInnerNode(
                    self.config, depth + 1, asym=self.asym
                )
        else:
            self.left = LeafNode(self.config)
            self.right = LeafNode(self.config)


class LeafNode:
    """Terminal node holding one trainable output vector per tree.

    ``param`` is an ``nn.Parameter`` of shape ``[output_dim, n_tree]`` drawn
    from a standard normal. The class is a plain Python object, not an
    ``nn.Module``.

    Args:
        config: dict providing ``output_dim`` and ``n_tree``.
    """

    def __init__(self, config):
        self.config = config
        self.leaf = True
        initial_values = torch.randn(
            self.config["output_dim"], self.config["n_tree"]
        )
        self.param = nn.Parameter(initial_values.to(device))  # [n_class, n_tree]

    def forward(self):
        """Return the leaf parameter itself."""
        return self.param

    def calc_prob(self, x, path_prob):
        """Pair the incoming path probability with the broadcast leaf values.

        Args:
            x: unused; accepted so leaves and inner nodes share one signature.
            path_prob: probability of reaching this leaf, ``[batch_size, n_tree]``.

        Returns:
            A one-element list ``[[path_prob, Q]]`` where ``Q`` is ``param``
            expanded to ``[batch_size, n_class, n_tree]``.
        """
        path_prob = path_prob.to(device)  # [batch_size, n_tree]
        batch_size = path_prob.size()[0]
        target_shape = (batch_size, self.config["output_dim"], self.config["n_tree"])
        Q = self.forward().expand(target_shape)  # -> [batch_size, n_class, n_tree]
        return [[path_prob, Q]]

    def reset(self):
        """Nothing to clear on a leaf."""
        pass


class SoftTree(nn.Module):
    """Ensemble of ``n_tree`` soft decision trees evaluated in parallel.

    The tree is a graph of plain node objects rooted at ``self.root``. Their
    layers and leaf parameters are registered on this module by
    ``collect_parameters`` so that ``parameters()`` sees them.

    Args:
        input_dim: number of input features.
        output_dim: number of outputs per sample.
        max_depth: depth passed to the node constructors.
        scale: multiplier inside the erf decision function.
        bias_scale: multiplier applied to every decision bias.
        n_tree: number of trees in the ensemble.
        asym: forwarded to the root node constructor.
        sparse: build the tree from ``SparseInnerNode`` instead of ``InnerNode``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        max_depth: int,
        scale: float,
        bias_scale: float,
        n_tree: int,
        asym: bool = False,
        sparse: bool = False,
    ):
        super().__init__()
        self.config = dict(
            input_dim=input_dim,
            output_dim=output_dim,
            max_depth=max_depth,
            scale=scale,
            bias_scale=bias_scale,
            n_tree=n_tree,
        )
        root_cls = SparseInnerNode if sparse else InnerNode
        self.root = root_cls(self.config, depth=1, asym=asym)

        self.collect_parameters()

    def collect_parameters(self):
        """Walk the tree breadth-first and register layers and leaf parameters.

        Inner-node layers go to ``self.module_list`` and leaf parameters to
        ``self.param_list``; at every inner node the right child is queued
        before the left one.
        """
        self.module_list = nn.ModuleList()
        self.param_list = nn.ParameterList()
        pending = [self.root]
        while pending:
            current = pending.pop(0)
            if current.leaf:
                self.param_list.append(current.param)
                continue
            pending.extend((current.right, current.left))
            self.module_list.append(current.fc)

    def forward(self, x):
        """Return the ensemble prediction, ``[batch_size, output_dim]``.

        Each leaf contributes ``sum_over_trees(path_prob * Q)``; the total is
        divided by ``sqrt(n_tree)``. The per-call caches on the nodes are
        reset before returning.
        """
        batch_size = x.shape[0]
        x = torch.squeeze(x, 1).reshape(batch_size, self.config["input_dim"])

        # Every sample reaches the root of every tree with probability one.
        root_path_prob = torch.ones(batch_size, self.config["n_tree"])

        leaves = self.root.calc_prob(x, root_path_prob)
        pred = torch.zeros(batch_size, self.config["output_dim"]).to(device)
        for path_prob, Q in leaves:  # one iteration per leaf
            pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)

        pred /= np.sqrt(self.config["n_tree"])  # NTK scaling

        self.root.reset()
        return pred

In [None]:
class SoftTreeExp(nn.Module):
    """Hand-wired depth-2 sparse tree ensemble used for the trajectory plots.

    The topology is fixed regardless of ``max_depth``: the root splits on
    feature 0, both of its children split on feature 1, and four leaves hang
    below them. ``max_depth`` is still stored in the config and read by the
    node constructors, but whatever children those constructors build are
    overwritten here.

    Args:
        input_dim: number of input features (features 0 and 1 are used).
        output_dim: number of outputs per sample.
        max_depth: stored in the config; see above.
        scale: multiplier inside the erf decision function.
        bias_scale: multiplier applied to every decision bias.
        n_tree: number of trees in the ensemble.
        asym: accepted for signature parity with ``SoftTree``; not used.
        sparse: must be true (asserted).
        finetune: use ``SparseFinetuneInnerNode`` (labelled "AAI" in the
            plots) instead of ``SparseInnerNode`` ("AAA").
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        max_depth: int,
        scale: float,
        bias_scale: float,
        n_tree: int,
        asym: bool=False,
        sparse: bool=True,
        finetune: bool=False
    ):
        super(SoftTreeExp, self).__init__()
        config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "scale": scale,
            "bias_scale": bias_scale,
            "n_tree": n_tree,
            "max_depth": max_depth
        }
        self.config = config

        assert sparse # only for sparse tree
        assert finetune <= sparse

        # finetune -> AAI, otherwise AAA; the wiring is identical for both.
        node_cls = SparseFinetuneInnerNode if finetune else SparseInnerNode

        # depth=1
        self.root = node_cls(config, depth=1, feature_index=0)
        # depth=2
        self.root.left = node_cls(config, depth=2, feature_index=1)
        self.root.right = node_cls(config, depth=2, feature_index=1)

        # depth=3
        self.root.left.left = LeafNode(config)
        self.root.left.right = LeafNode(config)
        self.root.right.left = LeafNode(config)
        self.root.right.right = LeafNode(config)

        self.collect_parameters()

    def collect_parameters(self):
        """Walk the tree breadth-first and register layers and leaf parameters.

        Inner-node layers go to ``self.module_list`` and leaf parameters to
        ``self.param_list``; at every inner node the right child is queued
        before the left one.
        """
        nodes = [self.root]
        self.module_list = nn.ModuleList()
        self.param_list = nn.ParameterList()
        while nodes:
            node = nodes.pop(0)
            if node.leaf:
                self.param_list.append(node.param)
            else:
                nodes.append(node.right)
                nodes.append(node.left)
                self.module_list.append(node.fc)

    def forward(self, x):
        """Return the ensemble prediction, ``[batch_size, output_dim]``.

        Each leaf contributes ``sum_over_trees(path_prob * Q)``; the total is
        divided by ``sqrt(n_tree)``. The per-call caches on the nodes are
        reset before returning.
        """
        batch_size = x.shape[0]
        x = torch.squeeze(x, 1).reshape(batch_size, self.config["input_dim"])

        path_prob_init = torch.ones(batch_size, self.config["n_tree"])

        leaf_accumulator = self.root.calc_prob(x, path_prob_init)
        # Allocate on ``device`` like SoftTree.forward does. The nodes move
        # path probabilities and leaf values there, so a CPU accumulator
        # cannot be added to them on any other device.
        pred = torch.zeros(batch_size, self.config["output_dim"]).to(device)
        for path_prob, Q in leaf_accumulator:  # one iteration per leaf
            pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)

        pred /= np.sqrt(self.config["n_tree"])  # NTK scaling

        self.root.reset()
        return pred

## Prepare dataset

In [None]:
import pandas as pd
from sklearn.datasets import load_diabetes
import torch

In [None]:
# Synthetic regression problem: standard-normal inputs and standard-normal
# targets. No seed is set, so every run draws a different data set.
n_features = 2
n_dataset = 10
train_data = torch.Tensor(np.random.randn(n_dataset, n_features))
target_data = torch.tensor(np.random.randn(train_data.shape[0]))
test_data = torch.Tensor(np.random.randn(10, n_features))  # 10 held-out points

In [None]:
# Sanity check: expected to be (n_dataset, n_features).
train_data.shape

In [None]:
# Sanity check: expected to be (10, n_features).
test_data.shape

## Tracking

In [None]:
def train_net(net, n_epochs, input_data, target, lr, initial_train):
    """Run ``n_epochs`` full-batch SGD steps on ``net`` in place.

    The quantity being fitted is the network output minus ``initial_train``,
    and the objective is half the mean squared error against ``target``.

    Args:
        net: module to optimise; its parameters are updated in place.
        n_epochs: number of gradient steps.
        input_data: input batch; converted to double before each forward pass.
        target: 1-D tensor of regression targets.
        lr: SGD learning rate.
        initial_train: 1-D tensor subtracted from the outputs after being
            given a trailing axis.

    Returns:
        None.
    """
    half_mse = nn.MSELoss(reduction='mean')
    sgd = optim.SGD(net.parameters(), lr=lr)
    for _ in range(n_epochs):
        sgd.zero_grad()
        shifted_output = net(input_data.double()) - initial_train.unsqueeze(1)
        objective = half_mse(shifted_output.view(-1), target) / 2
        objective.backward()
        sgd.step()

In [None]:
def analytical_evolution_MSE(t, lr, H_train, H_test, initial_train, initial_test, target_data):
    """Closed-form train/test predictions at time ``t`` for a fixed kernel.

    With ``H = H_train``, ``y = target_data``, ``f0 = initial_train`` and
    ``n = len(initial_train)`` this evaluates::

        pred_train = y + exp(-lr * t * H / n) @ (f0 - y)
        pred_test  = initial_test - H_test @ H^{-1} @ (I - exp(-lr * t * H / n)) @ (f0 - y)

    Args:
        t: time (number of steps) at which to evaluate.
        lr: learning rate.
        H_train: ``[n, n]`` NumPy kernel matrix on the training set. Must be
            symmetric and non-singular; ``np.linalg.eigh`` reads only one
            triangle of it.
        H_test: ``[m, n]`` NumPy kernel matrix between test and training points.
        initial_train: 1-D torch tensor, initial outputs on the training set.
        initial_test: torch tensor with ``m`` elements, initial test outputs.
        target_data: 1-D torch tensor of training targets.

    Returns:
        ``(pred_train, pred_test)`` as NumPy arrays of length ``n`` and ``m``.
    """
    n_train = len(initial_train)

    # H_train is treated as symmetric throughout (P @ diag @ P.T), so use the
    # symmetric solver: eigh returns real eigenvalues and orthonormal
    # eigenvectors. np.linalg.eig gives neither guarantee, which made
    # P.transpose() an invalid stand-in for the inverse of P.
    lam, P = np.linalg.eigh(H_train)
    lam = lam.astype(dtype='float64')

    # Time is rescaled by n_train because the two papers use different
    # conventions for the loss function. Everything below is NumPy, not torch.
    decay = np.exp(-lr * t * lam / n_train)

    residual = (initial_train - target_data).cpu().detach().numpy()
    # Coordinates of the initial residual in the eigenbasis.
    residual_eig = np.dot(P.transpose(), residual)

    # exp(-lr*t*H/n) @ residual == P @ (decay * P^T residual)
    pred_train = target_data.cpu().numpy() + np.dot(P, decay * residual_eig)

    # H^{-1} (I - exp(-lr*t*H/n)) @ residual, combined in the eigenbasis so
    # neither matrix has to be formed explicitly.
    tmp = np.dot(P, ((1.0 - decay) / lam) * residual_eig)

    pred_test = initial_test.detach().cpu().numpy().reshape(-1) - np.dot(H_test, tmp)

    return pred_train, pred_test

In [None]:
def hardtree_viz(X: np.array, alpha: float, beta: float, finetune: bool):
    """Accumulate an ``[n, n]`` kernel matrix for a fixed four-leaf tree.

    For every feature ``f`` the function forms
    ``S_f = outer(X[:, f], X[:, f]) + beta**2`` and evaluates ``calc_tau`` and
    ``calc_tau_dot`` on it. The rule list is hard-coded: four leaves, each
    reached through one decision on feature 0 and one on feature 1. Per leaf
    the kernel receives

    * one term per decision on the path: ``S * tau_dot`` of that decision's
      feature times ``tau`` of the other decision, and
    * one leaf term: the product of ``tau`` over the path.

    Args:
        X: ``[n, d]`` data matrix with ``d >= 2``.
        alpha: forwarded to ``calc_tau`` / ``calc_tau_dot``.
        beta: its square is added to every entry of ``S``.
        finetune: when true, the ``S`` factor of the decision terms is the
            full ``X @ X.T + beta**2``; ``tau`` and ``tau_dot`` are always
            computed from the single-feature ``S_f``.

    Returns:
        ``[n, n]`` NumPy array.
    """
    # NOTE(review): presumably the limiting kernel of SoftTreeExp -- confirm.
    S_list, tau_list, tau_dot_list = [], [], []

    for feature_index in range(len(X[0])):
        column = X[:, feature_index]
        S = np.outer(column, column.T) + beta**2
        S_list.append(np.matmul(X, X.T) + beta**2 if finetune else S)

        # diag_i[a, b] = S[b, b]; diag_j is its transpose.
        diag_i = np.tile(np.diag(S), (len(S), 1))
        diag_j = diag_i.transpose()
        tau_list.append(calc_tau(alpha, S, diag_i, diag_j))
        tau_dot_list.append(calc_tau_dot(alpha, S, diag_i, diag_j))

    K = np.zeros((X.shape[0], X.shape[0]))

    rulelist = [[0, 1], [0, 1], [0, 1], [0, 1]]
    for rules in rulelist:
        # Internal nodes
        for position, feature in enumerate(rules):
            node_term = S_list[feature] * tau_dot_list[feature]
            for other in rules[:position] + rules[position + 1:]:
                node_term *= tau_list[other]
            K += node_term

        # Leaves
        leaf_term = np.ones_like(K)
        for feature in rules:
            leaf_term *= tau_list[feature]
        K += leaf_term

    return K

In [None]:
def plot_trajectory(finetune: bool):
    """Plot analytical vs. empirical test-output trajectories on the current axes.

    For ``n_tree`` in ``(16, 1024)`` a fresh ``SoftTreeExp`` is trained with
    ``train_net`` in chunks of ``t_step`` steps, and the change of its output
    relative to initialisation is recorded after every chunk. The analytical
    curve comes from ``analytical_evolution_MSE`` with the kernel returned by
    ``hardtree_viz``. One colour is used per test point: a thick translucent
    line for the analytical curve, dotted for 16 trees, dashed for 1024.

    Reads the module-level ``train_data``, ``test_data`` and ``target_data``;
    draws with ``plt`` and returns nothing.

    Args:
        finetune: forwarded to ``hardtree_viz`` and ``SoftTreeExp``; also
            selects the panel title ("AAI" if true, "AAA" otherwise) and
            whether the x-axis label is drawn.
    """
    alpha = 2.0
    beta = 0.5
    # The node constructors build children only while depth < max_depth, so
    # -1 makes them stop immediately; SoftTreeExp wires its own topology.
    depth = -1

    # Training schedule. Defined before the loop because the plotting code
    # after it needs t_max and t_step as well.
    t_max = 1000
    t_step = 10
    lr = 0.1
    t_list = np.arange(t_step, t_max + t_step, t_step)

    H_analytical_train = hardtree_viz(
        train_data.numpy(), alpha=alpha, beta=beta, finetune=finetune
    )
    # Kernel on train+test stacked, then the test-rows / train-columns block.
    H_analytical_test = hardtree_viz(
        torch.cat([train_data, test_data]).numpy(),
        alpha=alpha,
        beta=beta,
        finetune=finetune,
    )[len(train_data) :, 0 : len(train_data)]

    ptrain_empiricals1, ptest_empiricals1 = [], []

    for n_tree in (16, 1024):
        ptrain_empirical1, ptest_empirical1 = [], []
        # Rebuilt for every n_tree; the values do not depend on the model and
        # the last copy is the one plotted below.
        ptrain_analytical1, ptest_analytical1 = [], []

        st1 = SoftTreeExp(
            input_dim=train_data.shape[1],
            output_dim=1,
            scale=alpha,
            bias_scale=beta,
            n_tree=n_tree,
            max_depth=depth,
            finetune=finetune,
        )

        initial_train1 = st1.forward(train_data).reshape(-1)
        initial_test1 = st1.forward(test_data).reshape(-1)

        # Every trajectory starts at zero: outputs are tracked relative to
        # their value at initialisation.
        ptrain_analytical1.append(
            torch.zeros_like(initial_train1).detach().cpu().numpy()
        )
        ptrain_empirical1.append(
            torch.zeros_like(initial_train1).detach().cpu().numpy()
        )
        ptest_analytical1.append(
            torch.zeros_like(initial_test1).detach().cpu().numpy()
        )
        ptest_empirical1.append(
            torch.zeros_like(initial_test1).detach().cpu().numpy()
        )

        for t in tqdm(t_list):
            train_net(st1, t_step, train_data, target_data, lr, initial_train1.detach())

            # .cpu() is applied on every tensor-to-NumPy conversion;
            # previously only the first operand of each difference had it.
            ptrain_empirical1.append(
                st1.forward(train_data).detach().cpu().numpy().reshape(-1)
                - initial_train1.detach().cpu().numpy()
            )
            ptest_empirical1.append(
                st1.forward(test_data).detach().cpu().numpy().reshape(-1)
                - initial_test1.detach().cpu().numpy()
            )

            pred_train, pred_test = analytical_evolution_MSE(
                t=t,
                lr=lr,
                H_train=H_analytical_train,
                H_test=H_analytical_test,
                initial_train=torch.zeros_like(initial_train1),
                initial_test=torch.zeros_like(initial_test1),
                target_data=target_data,
            )
            ptrain_analytical1.append(pred_train)
            ptest_analytical1.append(pred_test)

        ptrain_empiricals1.append(ptrain_empirical1)
        ptest_empiricals1.append(ptest_empirical1)

    cmap = plt.cm.nipy_spectral
    # Same grid as the training schedule plus the starting point t = 0.
    t_list = np.arange(0, t_max + t_step, t_step)

    n_test_points = len(ptest_analytical1[0])
    for i in range(n_test_points):
        color = cmap(i / n_test_points)
        plt.plot(
            t_list,
            np.array(ptest_analytical1)[:, i],
            color=color,
            alpha=0.3,
            linewidth=5,
        )
        plt.plot(
            t_list,
            np.array(ptest_empiricals1[0])[:, i],
            color=color,
            linestyle="dotted",
        )
        plt.plot(
            t_list,
            np.array(ptest_empiricals1[1])[:, i],
            color=color,
            linestyle="dashed",
        )
    # Only the finetune panel gets an x label; the notebook draws it as the
    # lower of two stacked subplots.
    if finetune:
        plt.xlabel("$\\tau$ (iteration)")
    plt.ylabel("Model output")
    plt.title(f"{'AAI' if finetune else 'AAA'}")
    plt.ylim(-2.0, 2.0)
    plt.grid(linestyle="dotted")

In [None]:
# Two stacked panels: the AAA variant on top, the AAI variant below.
plt.figure(figsize=(7, 7))
plt.subplot(2, 1, 1)
plot_trajectory(finetune=False)
plt.subplot(2, 1, 2)
plot_trajectory(finetune=True)

# Empty proxy lines on the lower panel, used only to build a shared legend.
plt.subplot(2, 1, 2)
plt.plot([], [], color="black", label="Analytical", linewidth=5, alpha=0.3)
plt.plot([], [], color="black", label="$M=16$", linestyle="dotted")
plt.plot([], [], color="black", label="$M=1024$", linestyle="dashed")
plt.legend(ncol=3, bbox_to_anchor=(0.5, -0.4), fontsize=15, loc="center", borderaxespad=0)

plt.tight_layout()

# savefig does not create missing directories; make sure the target exists so
# the run does not fail only after all the training has finished.
os.makedirs("./figures", exist_ok=True)
plt.savefig("./figures/trajectory.pdf", bbox_inches="tight", pad_inches=0.10)