In [None]:
def set_latex(usetex: bool = False):
    """Apply the global matplotlib style used by every figure in this notebook.

    Resets rcParams to matplotlib's defaults, then sets 15 pt "Times New Roman"
    text with STIX math glyphs.

    Args:
        usetex: Value written to ``text.usetex``. The previous version set
            ``usetex=True`` *before* calling ``plt.style.use("default")``, which
            resets it, so LaTeX rendering was never actually enabled. The default
            ``False`` reproduces the figures that version produced; pass ``True``
            to render text through an external LaTeX installation.
    """
    import matplotlib.font_manager as font_manager
    import matplotlib.pyplot as plt

    # Reset first: style.use("default") overwrites every (non-blacklisted)
    # rcParam, so anything configured before it would be silently discarded.
    plt.style.use("default")
    plt.rcParams["text.usetex"] = usetex
    plt.rcParams["font.size"] = 15
    plt.rcParams["font.family"] = "Times New Roman"
    plt.rcParams["mathtext.fontset"] = "stix"

    # Drop the "roman" alias from matplotlib's font-weight table (presumably the
    # common work-around for Times New Roman resolving to a bold face — TODO
    # confirm). Rebuild the font cache only when an entry was actually removed
    # and this matplotlib still provides the private ``_rebuild`` helper.
    # Explicit checks replace the previous bare ``except: pass``.
    weight_dict = getattr(font_manager, "weight_dict", {})
    removed = weight_dict.pop("roman", None)
    rebuild = getattr(font_manager, "_rebuild", None)
    if removed is not None and rebuild is not None:
        rebuild()

In [None]:
# Apply the notebook-wide matplotlib style (fonts, size, math glyphs) before any figure is drawn.
set_latex()

In [None]:
import copy
import math
import os
import warnings

import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import special
from sklearn.metrics import mean_squared_error
from tqdm.notebook import tqdm

# NOTE(review): this silences *every* warning for the whole session, including
# PyTorch/NumPy deprecation and dtype warnings; narrow it if results look off.
warnings.filterwarnings("ignore")

# Seed every RNG the notebook draws from (numpy: the synthetic dataset and any
# randomly chosen split feature; torch: parameter initialisation) so that
# Restart & Run All reproduces the figure.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# The optional `GPU` environment variable overrides the CUDA device string
# (presumably something like "cuda:1" — confirm); it is ignored without CUDA.
gpu_override = os.environ.get("GPU")
if torch.cuda.is_available():
    device = gpu_override or "cuda"
else:
    device = "cpu"

# Every tensor created without an explicit dtype (data and parameters) is float64.
torch.set_default_dtype(torch.float64)

In [None]:
class InnerNode:
    """Internal (splitting) node of a soft decision tree, batched over ``n_tree`` trees.

    The node owns one ``nn.Linear(input_dim, n_tree)`` so that the same node
    position is evaluated for all ``n_tree`` trees of the ensemble at once.
    Children are created recursively while ``depth < config["max_depth"]``;
    below that, both children are ``LeafNode`` objects.

    Args:
        config: dict read for "input_dim", "n_tree", "max_depth", "scale" and
            "bias_scale" (children may read more keys).
        depth: depth of this node.
        asym: if True, every right child is a leaf (left-deep asymmetric tree).
    """

    def __init__(self, config, depth, asym=False):
        self.config = config
        self.leaf = False
        self.fc = nn.Linear(
            self.config["input_dim"], self.config["n_tree"], bias=True
        ).to(device)
        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0
        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0
        self.prob = None  # last forward() output: P(route right), [batch_size, n_tree]
        self.path_prob = None  # probability of reaching this node, [batch_size, n_tree]
        self.left = None
        self.right = None
        self.leaf_accumulator = []  # [[path_prob, Q], ...] gathered from the leaves below
        self.asym = asym

        self.build_child(depth)

    def build_child(self, depth):
        """Create the two children of a node at ``depth``."""
        if depth < self.config["max_depth"]:
            self.left = InnerNode(self.config, depth + 1, asym=self.asym)
            if self.asym:
                self.right = LeafNode(self.config)
            else:
                self.right = InnerNode(self.config, depth + 1, asym=self.asym)
        else:
            self.left = LeafNode(self.config)
            self.right = LeafNode(self.config)

    def forward(self, x):  # decision function
        """Return P(route right) = 0.5 * erf(scale * (x W^T + bias_scale * b)) + 0.5."""
        return (
            0.5
            * torch.erf(
                self.config["scale"]
                * (
                    torch.matmul(x, self.fc.weight.t())
                    + self.config["bias_scale"] * self.fc.bias
                )
            )
            + 0.5
        )

    def calc_prob(self, x, path_prob):
        """Propagate routing probabilities down to the leaves of this subtree.

        Args:
            x: input batch; assumed [batch_size, input_dim] — TODO confirm.
            path_prob: probability of reaching this node, [batch_size, n_tree].

        Returns:
            List of ``[path_prob, Q]`` pairs, one per leaf below this node,
            left subtree first.
        """
        self.prob = self.forward(x)  # probability of selecting right node
        path_prob = path_prob.to(device)  # path_prob: [batch_size, n_tree]
        self.path_prob = path_prob
        # Start from a fresh list on every call. Previously results were appended
        # to whatever an earlier call left behind, so a forward pass that was not
        # followed by reset() (e.g. one that raised) silently duplicated leaves
        # in every later prediction.
        self.leaf_accumulator = []
        self.leaf_accumulator.extend(self.left.calc_prob(x, path_prob * (1 - self.prob)))
        self.leaf_accumulator.extend(self.right.calc_prob(x, path_prob * self.prob))
        return self.leaf_accumulator

    def reset(self):
        """Clear the cached leaf results of this whole subtree."""
        self.leaf_accumulator = []
        self.penalties = []  # NOTE(review): not read anywhere visible in this notebook
        self.left.reset()
        self.right.reset()


class SparseInnerNode(InnerNode):
    """Soft split node whose decision depends on a single input feature.

    Unlike ``InnerNode`` (which mixes all ``input_dim`` features), the decision
    function only reads column ``feature_index`` of ``x``, through an
    ``nn.Linear(1, n_tree)``. When ``feature_index`` is omitted, the column is
    drawn with ``np.random.randint(input_dim)``.
    """

    def __init__(self, config, depth, asym=False, feature_index=None):
        # The base constructor also creates (and randomly initialises) a
        # full-width fc layer plus the child subtree; the fc is replaced below.
        super().__init__(config, depth, asym)
        self.feature_index = (
            np.random.randint(self.config["input_dim"])
            if feature_index is None
            else feature_index
        )

        n_tree = self.config["n_tree"]
        self.fc = nn.Linear(1, n_tree, bias=True).to(device)
        for tensor in (self.fc.weight, self.fc.bias):
            nn.init.normal_(tensor, 0.0, 1.0)  # N(0, 1)

    def build_child(self, depth):
        """Grow sparse children below ``depth``, or two leaves at the depth limit."""
        if depth < self.config["max_depth"]:
            child_depth = depth + 1
            self.left = SparseInnerNode(self.config, child_depth, asym=self.asym)
            self.right = (
                LeafNode(self.config)
                if self.asym
                else SparseInnerNode(self.config, child_depth, asym=self.asym)
            )
            return
        self.left = LeafNode(self.config)
        self.right = LeafNode(self.config)

    def forward(self, x):  # decision function
        """Return P(route right) for every sample and tree, shape [batch_size, n_tree]."""
        feature = x[:, self.feature_index].unsqueeze(dim=1)  # -> [batch_size, 1]
        pre_activation = (
            torch.matmul(feature, self.fc.weight.t())
            + self.config["bias_scale"] * self.fc.bias
        )
        return 0.5 * torch.erf(self.config["scale"] * pre_activation) + 0.5


class SparseFinetuneInnerNode(InnerNode):
    """Soft split node that starts on one feature but keeps weights for all of them.

    The decision layer is a full ``nn.Linear(input_dim, n_tree)``. At
    construction, every weight column except ``feature_index`` is zeroed, so
    the initial split depends only on that feature. The zeroed weights remain
    ordinary trainable parameters.
    """

    def __init__(self, config, depth, asym=False, feature_index=None):
        super().__init__(config, depth, asym)
        if feature_index is None:
            self.feature_index = np.random.randint(self.config["input_dim"])
        else:
            self.feature_index = feature_index

        self.fc = nn.Linear(
            self.config["input_dim"], self.config["n_tree"], bias=True
        ).to(device)
        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0
        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0

        # Zero every input weight except the selected feature's column. This
        # must use self.feature_index: the constructor argument is None when the
        # index was drawn at random (e.g. for children built by build_child),
        # and comparing columns against None used to zero *all* weights.
        # Multiplying by a 0/1 mask matches the previous element-wise `*= 0`.
        # An out-of-range index now raises IndexError instead of silently
        # zeroing everything.
        with torch.no_grad():
            keep = torch.zeros_like(self.fc.weight)
            keep[:, self.feature_index] = 1.0
            self.fc.weight.mul_(keep)

    def build_child(self, depth):
        """Grow fine-tunable children below ``depth``, or two leaves at the depth limit."""
        if depth < self.config["max_depth"]:
            self.left = SparseFinetuneInnerNode(self.config, depth + 1, asym=self.asym)
            if self.asym:
                self.right = LeafNode(self.config)
            else:
                self.right = SparseFinetuneInnerNode(
                    self.config, depth + 1, asym=self.asym
                )
        else:
            self.left = LeafNode(self.config)
            self.right = LeafNode(self.config)


class LeafNode:
    """Leaf of a soft tree holding one learnable value vector per tree.

    ``param`` has shape ``[output_dim, n_tree]`` and is initialised from a
    standard normal distribution.
    """

    def __init__(self, config):
        self.config = config
        self.leaf = True
        out_dim, n_tree = self.config["output_dim"], self.config["n_tree"]
        init_values = torch.randn(out_dim, n_tree).to(device)
        self.param = nn.Parameter(init_values)  # [n_class, n_tree]

    def forward(self):
        """Return the leaf values, [n_class, n_tree]."""
        return self.param

    def calc_prob(self, x, path_prob):
        """Pair the probability of reaching this leaf with its values.

        ``x`` is not used; it is accepted so that leaves and inner nodes share
        one interface. Returns ``[[path_prob, Q]]`` where ``Q`` is the leaf
        value broadcast to ``[batch_size, n_class, n_tree]``.
        """
        reach = path_prob.to(device)  # [batch_size, n_tree]
        batch_size = reach.shape[0]
        values = self.forward().expand(
            (batch_size, self.config["output_dim"], self.config["n_tree"])
        )
        return [[reach, values]]

    def reset(self):
        """Nothing is cached on a leaf."""
        pass

In [None]:
class SoftTreeExp(nn.Module):
    """Ensemble of ``n_tree`` sparse soft trees sharing one hand-wired architecture.

    All trees have the same shape, and each node position is evaluated for all
    trees at once. Supported shapes:

    * ``arch=0`` - asymmetric depth-2 tree: the root splits on feature 0, its
      left child is a leaf and its right child splits on feature 1.
    * ``arch=1`` - depth-1 stump splitting on feature 0.
    * ``arch=2`` - perfect depth-2 tree: the root splits on feature 0 and both
      children on feature 1. With ``oblivious=True``, the two depth-2 nodes
      share the same weight and bias parameters.

    Args:
        input_dim: number of input features.
        output_dim: number of outputs per sample.
        max_depth: forwarded to the node constructors. Nodes auto-create
            children while ``depth < max_depth``; whatever they create is
            replaced by the hand-wired shape.
        scale: multiplier inside ``erf`` of every decision function.
        bias_scale: multiplier applied to every split bias.
        n_tree: number of trees in the ensemble.
        arch: 0, 1 or 2 (see above); any other value raises ValueError.
        oblivious: only used by ``arch=2``.
        asym: accepted for API compatibility; not used.
        sparse: must be True (only sparse trees are implemented).
        finetune: if True ("AAI"), use ``SparseFinetuneInnerNode`` splits;
            otherwise ("AAA") use ``SparseInnerNode`` splits.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        max_depth: int,
        scale: float,
        bias_scale: float,
        n_tree: int,
        arch: int,
        oblivious: bool,
        asym: bool = False,
        sparse: bool = True,
        finetune: bool = False,
    ):
        super(SoftTreeExp, self).__init__()
        config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "scale": scale,
            "bias_scale": bias_scale,
            "n_tree": n_tree,
            "max_depth": max_depth,
        }
        self.config = config

        assert sparse  # only for sparse tree
        assert finetune <= sparse
        if arch not in (0, 1, 2):
            raise ValueError(f"unsupported arch: {arch!r} (expected 0, 1 or 2)")

        # AAI and AAA differ only in the inner-node class. The construction
        # order below is the same for both, so random draws happen in the same
        # sequence as before.
        node_cls = SparseFinetuneInnerNode if finetune else SparseInnerNode

        # depth=1
        self.root = node_cls(config, depth=1, feature_index=0)
        if arch == 0:
            # depth=2: leaf on the left, split on the right (asymmetric)
            self.root.left = LeafNode(config)
            self.root.right = node_cls(config, depth=2, feature_index=1)
            # depth=3
            self.root.right.left = LeafNode(config)
            self.root.right.right = LeafNode(config)
        elif arch == 1:
            # depth=2: two leaves (a stump)
            self.root.left = LeafNode(config)
            self.root.right = LeafNode(config)
        else:  # arch == 2
            # depth=2
            self.root.left = node_cls(config, depth=2, feature_index=1)
            self.root.right = node_cls(config, depth=2, feature_index=1)
            if oblivious:
                # Share the split parameters so both depth-2 nodes apply the same decision.
                self.root.right.fc.weight = self.root.left.fc.weight
                self.root.right.fc.bias = self.root.left.fc.bias
            # depth=3
            for child in (self.root.left, self.root.right):
                child.left = LeafNode(config)
                child.right = LeafNode(config)

        self.collect_parameters()

    def collect_parameters(self):
        """Register every split layer and leaf parameter with this Module.

        Tree nodes are plain Python objects, not ``nn.Module``s, so their
        parameters would otherwise be invisible to ``.parameters()`` and
        optimizers. Parameters shared by oblivious nodes appear in two Linear
        modules but are yielded once by ``.parameters()``.
        """
        nodes = [self.root]
        self.module_list = nn.ModuleList()
        self.param_list = nn.ParameterList()
        while nodes:
            node = nodes.pop(0)
            if node.leaf:
                self.param_list.append(node.param)
            else:
                nodes.append(node.right)
                nodes.append(node.left)
                self.module_list.append(node.fc)

    def forward(self, x):
        """Return the ensemble output, [batch_size, output_dim], scaled by 1/sqrt(n_tree).

        ``x`` is squeezed on dim 1 and reshaped to [batch_size, input_dim].
        """
        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config["input_dim"])
        # Keep inputs and intermediate tensors on the same device as the parameters.
        x = x.to(device)
        path_prob_init = torch.ones(x.shape[0], self.config["n_tree"], device=device)

        try:
            leaf_accumulator = self.root.calc_prob(x, path_prob_init)
            pred = torch.zeros(x.shape[0], self.config["output_dim"], device=device)
            for path_prob, Q in leaf_accumulator:  # one entry per leaf
                # [batch, 1, n_tree] * [batch, n_class, n_tree] -> sum over trees
                pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)
        finally:
            # Always clear cached leaf results, even if the pass above raised.
            self.root.reset()

        pred /= np.sqrt(self.config["n_tree"])  # NTK scaling
        return pred

In [None]:
class SoftTreeMerge(nn.Module):
    """Model whose output is the sum of two sub-models evaluated on the same input."""

    def __init__(
        self,
        st1: nn.Module,
        st2: nn.Module
    ):
        super(SoftTreeMerge, self).__init__()
        # Assigning as attributes registers both as sub-modules, so
        # .parameters() yields the parameters of both models.
        self.st1 = st1
        self.st2 = st2

    def forward(self, x):
        # Call the sub-modules rather than .forward() so registered hooks run.
        return self.st1(x) + self.st2(x)

## Prepare a synthetic dataset (Gaussian random inputs and targets)

In [None]:
import pandas as pd
from sklearn.datasets import load_diabetes
import torch

In [None]:
# Synthetic regression problem: i.i.d. standard-normal inputs and targets.
n_features = 2
n_dataset = 10  # number of training samples
n_test = 10  # number of test samples
# One randn call per matrix gives the same values, drawn from the legacy global
# RNG in the same order, as the previous row-by-row list. It also avoids the
# slow tensor-from-list-of-ndarrays construction.
train_data = torch.tensor(
    np.random.randn(n_dataset, n_features), dtype=torch.get_default_dtype()
)
target_data = torch.tensor(np.random.randn(train_data.shape[0]))
test_data = torch.tensor(
    np.random.randn(n_test, n_features), dtype=torch.get_default_dtype()
)

In [None]:
train_data.size()  # expected: (n_dataset, n_features)

In [None]:
test_data.size()  # expected: (number of test samples, n_features)

## Track model outputs during training

In [None]:
def train_net(net, n_epochs, input_data, target, lr, initial_train):
    """Run ``n_epochs`` full-batch SGD steps on half the mean squared error.

    The model output is shifted by ``initial_train`` before the loss, i.e. the
    loss is ``0.5 * mean((net(x) - initial_train - target) ** 2)``. When
    ``initial_train`` is the net's current output, the shifted prediction
    therefore starts at zero.

    Args:
        net: model mapping ``input_data`` to ``[n_samples, 1]``.
        n_epochs: number of SGD steps, each using the whole batch.
        input_data: training inputs; cast to float64 before the forward pass.
        target: training targets, ``[n_samples]``; moved to the output's device.
        lr: SGD learning rate.
        initial_train: reference output, ``[n_samples]``; pass a detached tensor.

    Returns:
        The loss of the last step as a Python float (measured before that
        step's update), or None when ``n_epochs`` is 0.
    """
    criterion = nn.MSELoss(reduction="mean")
    # Plain SGD keeps no per-parameter state, so creating it per call is harmless.
    optimizer = optim.SGD(net.parameters(), lr=lr)
    inputs = input_data.double()  # loop-invariant
    offset = initial_train.unsqueeze(1)  # [n_samples] -> [n_samples, 1]
    last_loss = None
    for _ in range(n_epochs):
        optimizer.zero_grad()
        outputs = net(inputs) - offset
        loss = criterion(outputs.view(-1), target.to(outputs.device)) / 2
        loss.backward()
        optimizer.step()
        last_loss = loss.item()
    return last_loss

In [None]:
def plot_trajectory(st: nn.Module, linestyle: str, alpha: float, linewidth: float,
                    t_max: int = 1000, t_step: int = 10, lr: float = 0.1):
    """Train ``st`` on the notebook's training set and plot its test-output trajectory.

    Every ``t_step`` SGD steps (up to ``t_max``), the change of ``st``'s output
    on ``test_data`` relative to initialisation is recorded. One curve per test
    sample is then drawn on the current axes, coloured with ``nipy_spectral``.

    Args:
        st: model to train; its parameters are updated in place.
        linestyle, alpha, linewidth: matplotlib style applied to every curve.
        t_max: total number of training steps.
        t_step: training steps between two recorded points.
        lr: SGD learning rate passed to ``train_net``.

    Reads the notebook globals ``train_data``, ``target_data`` and ``test_data``.
    """
    # Outputs at initialisation; no graph is needed for these evaluations.
    with torch.no_grad():
        initial_train = st(train_data).reshape(-1)
        initial_test = st(test_data).reshape(-1)

    # Output change relative to initialisation; the point at step 0 is zero.
    test_offsets = [torch.zeros_like(initial_test).cpu().numpy()]

    checkpoints = np.arange(t_step, t_max + t_step, t_step)
    for _ in tqdm(checkpoints):
        train_net(st, t_step, train_data, target_data, lr, initial_train)
        with torch.no_grad():
            current = st(test_data).reshape(-1)
        test_offsets.append((current - initial_test).cpu().numpy())

    steps = np.arange(0, t_max + t_step, t_step)
    trajectory = np.asarray(test_offsets)  # [len(steps), n_test]
    n_curves = trajectory.shape[1]
    cmap = plt.cm.nipy_spectral
    for i in range(n_curves):
        plt.plot(steps,
                 trajectory[:, i],
                 color=cmap(i / n_curves),
                 linestyle=linestyle,
                 alpha=alpha,
                 linewidth=linewidth)

In [None]:
# Hyper-parameters shared by every soft-tree ensemble in the figure.
alpha = 1.0  # SoftTreeExp `scale` inside erf (not the matplotlib line alpha)
beta = 0.5  # SoftTreeExp `bias_scale`
depth = -1  # max_depth: nodes only auto-create placeholder leaves; `arch` wires the shape

# Line styles shared by the curves and the legend, so the labels cannot drift
# from what is plotted. The pair of arch=0 (asymmetric) ensembles is the
# non-oblivious model; the stump (arch=1) + oblivious depth-2 tree (arch=2)
# pair is the oblivious one. (The legend previously had these labels swapped.)
NON_OBLIVIOUS_STYLE = dict(linestyle="solid", linewidth=1, alpha=1.0)
OBLIVIOUS_STYLE = dict(linestyle="dashed", linewidth=3, alpha=0.5)


def make_merged_ensemble(arch1, arch2, n_tree, oblivious, finetune):
    """Sum of two SoftTreeExp ensembles (built in order arch1, arch2) sharing the settings above."""
    kwargs = dict(
        input_dim=train_data.shape[1], output_dim=1, scale=alpha, bias_scale=beta,
        n_tree=n_tree, oblivious=oblivious, max_depth=depth, finetune=finetune,
    )
    return SoftTreeMerge(SoftTreeExp(arch=arch1, **kwargs),
                         SoftTreeExp(arch=arch2, **kwargs))


# (subplot index, setting, finetune flag, trees per ensemble). Each panel sums
# two ensembles, so the total number of trees is M = 2 * n_tree.
PANELS = [
    (1, "AAA", False, 8),
    (2, "AAA", False, 512),
    (3, "AAI", True, 8),
    (4, "AAI", True, 512),
]

plt.figure(figsize=(9, 7))

for panel, setting, finetune, n_tree in PANELS:
    ax = plt.subplot(2, 2, panel)

    st = make_merged_ensemble(0, 0, n_tree, oblivious=None, finetune=finetune)
    plot_trajectory(st, **NON_OBLIVIOUS_STYLE)

    st = make_merged_ensemble(1, 2, n_tree, oblivious=True, finetune=finetune)
    plot_trajectory(st, **OBLIVIOUS_STYLE)

    plt.xlabel("$\\tau$ (iteration)")
    if panel % 2 == 1:  # y label on the left column only
        plt.ylabel("Model output")
    plt.title(f"{setting} $(M={2 * n_tree})$")
    plt.ylim(-2.0, 2.0)
    plt.grid(linestyle="dotted")

plt.tight_layout()

# Legend below the bottom-right panel (`ax` is the last panel drawn).
plt.sca(ax)
plt.plot([], [], color="black", label="Non-Oblivious", **NON_OBLIVIOUS_STYLE)
plt.plot([], [], color="black", label="Oblivious", **OBLIVIOUS_STYLE)

plt.legend(ncol=2, bbox_to_anchor=(-0.05, -0.45), fontsize=15, loc="center", borderaxespad=0)

os.makedirs("./figures", exist_ok=True)
plt.savefig("./figures/trajectory_oblivious_conversion.pdf", bbox_inches="tight", pad_inches=0.10)