In [None]:
def set_latex():
    """Apply the notebook's global matplotlib font/text settings.

    Mutates matplotlib's global rc state; returns nothing. The whole sequence
    is applied twice, as in the original code.

    NOTE(review): `plt.style.use("default")` is called *after* the two
    `plt.rc(...)` calls and appears to reset them (so `text.usetex` would end
    up at its default) — confirm whether that ordering is intentional.
    """
    import matplotlib
    import matplotlib.pyplot as plt

    for _ in range(2):
        plt.rc('text', usetex=True)
        plt.rc('font', family='serif')

        plt.style.use("default")
        plt.rcParams["font.size"] = 15

        plt.rcParams['font.family'] = 'Times New Roman'
        plt.rcParams['mathtext.fontset'] = 'stix'

        # Best-effort font-manager tweak. Only the two failures that are
        # expected here are ignored, instead of a bare `except:` that would
        # also swallow KeyboardInterrupt/SystemExit:
        #   * KeyError       - 'roman' was already removed (e.g. second pass)
        #   * AttributeError - `weight_dict` / the private `_rebuild` helper is
        #                      not available in the installed matplotlib
        try:
            del matplotlib.font_manager.weight_dict['roman']
            matplotlib.font_manager._rebuild()
        except (AttributeError, KeyError):
            pass

In [None]:
# Apply the global matplotlib rc/font configuration defined by set_latex().
set_latex()

In [None]:
import math
import os
import warnings

import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy import special
from sklearn.metrics import mean_squared_error
from tqdm.notebook import tqdm

# Suppress every warning for the rest of the session.
# NOTE(review): this also hides deprecation/numerical warnings — consider
# narrowing the filter to specific categories.
warnings.filterwarnings("ignore")

In [None]:
def gradient(outputs,
             inputs,
             grad_outputs=None,
             retain_graph=None,
             create_graph=False):
    """Return d(outputs)/d(inputs) flattened into a single 1-D tensor.

    Args:
        outputs: tensor(s) to differentiate (forwarded to torch.autograd.grad).
        inputs: a single tensor or an iterable of tensors to differentiate
            with respect to.
        grad_outputs, retain_graph, create_graph: forwarded unchanged to
            torch.autograd.grad.

    Returns:
        The per-input gradients, each flattened and concatenated in the order
        of `inputs`. Inputs that do not take part in the graph contribute
        zeros of their own size.
    """
    wrt = [inputs] if torch.is_tensor(inputs) else list(inputs)

    raw_grads = torch.autograd.grad(
        outputs,
        wrt,
        grad_outputs,
        allow_unused=True,
        retain_graph=retain_graph,
        create_graph=create_graph,
    )

    flat_pieces = []
    for grad, tensor in zip(raw_grads, wrt):
        # allow_unused=True yields None for inputs not reached by the graph.
        if grad is None:
            grad = torch.zeros_like(tensor)
        flat_pieces.append(grad.contiguous().view(-1))
    return torch.cat(flat_pieces)


def compute_kernels(f, xtr, parameters=None):
    """Gram matrix of per-sample parameter gradients of `f` on `xtr`.

    For every sample x in `xtr` the gradient of `f(x[None])` with respect to
    the parameters is computed, and the returned matrix accumulates the inner
    products of those gradient vectors: K[a, b] = <grad_a, grad_b>.

    Args:
        f: callable model; `f.parameters()` is used when `parameters` is None.
        xtr: tensor whose first dimension indexes the samples.
        parameters: optional iterable of tensors to differentiate against.

    Returns:
        Tensor of shape (len(xtr), len(xtr)) created with `xtr.new_zeros`.
        An empty `xtr` yields an empty (0, 0) matrix.
    """
    if parameters is None:
        parameters = list(f.parameters())

    n_samples = len(xtr)
    ktrtr = xtr.new_zeros(n_samples, n_samples)
    if n_samples == 0:
        # Nothing to differentiate; also avoids dividing by zero below.
        return ktrtr

    # Maximum number of parameter elements handled per chunk.
    # Presumably sized so one Jacobian chunk (n_samples x elements x 8 bytes)
    # stays around 2e9 bytes — TODO confirm the intended memory budget.
    max_elements = 2e9 // (8 * n_samples)

    # Greedily group parameters (largest first) into chunks under the budget.
    params = []
    current = []
    for p in sorted(parameters, key=lambda p: p.numel(), reverse=True):
        current.append(p)
        if sum(q.numel() for q in current) > max_elements:
            if len(current) > 1:
                # The newest parameter overflowed the chunk: flush the others
                # and start the next chunk with it.
                params.append(current[:-1])
                current = current[-1:]
            else:
                # A single parameter exceeds the budget; it gets its own chunk.
                params.append(current)
                current = []
    if len(current) > 0:
        params.append(current)

    for chunk in params:
        # One row per sample, one column per parameter element in the chunk.
        jtr = xtr.new_empty(n_samples, sum(u.numel() for u in chunk))

        for j, x in enumerate(xtr):
            jtr[j] = gradient(f(x[None]), chunk)

        ktrtr.add_(jtr @ jtr.t())
        del jtr

    return ktrtr


def rotation_o(u, t, deg=False):
    """Rotate the 2-D vector `u` about the origin by angle `t`.

    Args:
        u: length-2 vector (any sequence accepted by np.dot).
        t: rotation angle, in radians unless `deg` is truthy.
        deg: when truthy, `t` is interpreted as degrees.

    Returns:
        np.ndarray holding R @ u with R = [[cos t, -sin t], [sin t, cos t]].
    """
    if deg:
        t = np.deg2rad(t)
    cos_t, sin_t = np.cos(t), np.sin(t)
    R = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    return np.dot(R, u)

In [None]:
def get_kernel(alpha, beta, u, finetune, arch):
    """Evaluate `hardtree_viz` between `u` and its rotations in 1-degree steps.

    For each integer angle i in [0, 360), `u` is rotated by i degrees and
    `hardtree_viz` is called on the pair (u, rotated u).

    Args:
        alpha, beta, finetune, arch: forwarded to `hardtree_viz`.
        u: 2-D reference vector.

    Returns:
        (kernel, inner_product): two lists of length 360 holding, per angle,
        the off-diagonal entry H[1, 0] and the dot product <u, rotated u>.
    """
    kernel = []
    inner_product = []
    for i in range(360):
        Ru = rotation_o(u, i * np.pi / 180)
        H = hardtree_viz(np.vstack([u, Ru]),
                         alpha=alpha,
                         beta=beta,
                         finetune=finetune,
                         arch=arch)
        kernel.append(H[1, 0])
        inner_product.append(np.dot(u, Ru))
    return kernel, inner_product

In [None]:
# Pick the torch device used by every node/model below:
#   * no CUDA available      -> "cpu"
#   * CUDA + non-empty $GPU  -> the value of $GPU
#   * CUDA, $GPU unset/empty -> "cuda"
if not torch.cuda.is_available():
    device = "cpu"
else:
    device = os.environ.get("GPU") or "cuda"

# All tensors created without an explicit dtype are float64 from here on.
torch.set_default_dtype(torch.float64)


class InnerNode:
    """Splitting node of a soft-tree ensemble.

    Holds one linear splitter (`fc`) mapping `input_dim` features to `n_tree`
    outputs, and recursively builds its children on construction. With
    `asym=True` every right child is a leaf.
    """

    def __init__(self, config, depth, asym=False):
        self.config = config
        self.asym = asym
        self.leaf = False

        # Splitter weights and biases are drawn from N(0, 1).
        splitter = nn.Linear(config["input_dim"], config["n_tree"], bias=True)
        self.fc = splitter.to(device)
        nn.init.normal_(self.fc.weight, 0.0, 1.0)
        nn.init.normal_(self.fc.bias, 0.0, 1.0)

        # Filled in by calc_prob().
        self.prob = None
        self.path_prob = None
        self.leaf_accumulator = []

        self.left = None
        self.right = None
        self.build_child(depth)

    def build_child(self, depth):
        """Create both children; leaves once `depth` reaches config['max_depth']."""
        if depth >= self.config["max_depth"]:
            self.left = LeafNode(self.config)
            self.right = LeafNode(self.config)
            return

        child_depth = depth + 1
        self.left = InnerNode(self.config, child_depth, asym=self.asym)
        if self.asym:
            self.right = LeafNode(self.config)
        else:
            self.right = InnerNode(self.config, child_depth, asym=self.asym)

    def forward(self, x):
        """Decision function: 0.5 * erf(scale * (x W^T + bias_scale * b)) + 0.5."""
        pre_activation = (
            torch.matmul(x, self.fc.weight.t())
            + self.config["bias_scale"] * self.fc.bias
        )
        return 0.5 * torch.erf(self.config["scale"] * pre_activation) + 0.5

    def calc_prob(self, x, path_prob):
        """Propagate `path_prob` to both subtrees and gather their leaf entries.

        The left child receives path_prob * (1 - prob) and the right child
        path_prob * prob, where prob is this node's decision output.
        Returns this node's `leaf_accumulator` (left entries, then right).
        """
        self.prob = self.forward(x)  # probability of selecting right node
        path_prob = path_prob.to(device)  # path_prob: [batch_size, n_tree]
        self.path_prob = path_prob

        from_left = self.left.calc_prob(x, path_prob * (1 - self.prob))
        from_right = self.right.calc_prob(x, path_prob * self.prob)
        for leaf_entries in (from_left, from_right):
            self.leaf_accumulator.extend(leaf_entries)
        return self.leaf_accumulator

    def reset(self):
        """Clear accumulated state on this node and, recursively, its children."""
        self.leaf_accumulator = []
        self.penalties = []
        for child in (self.left, self.right):
            child.reset()


class SparseInnerNode(InnerNode):
    """Inner node whose split looks at a single input feature.

    The splitter inherited from InnerNode is replaced by a 1 -> n_tree linear
    map applied to column `feature_index` of the input. When no index is
    given, one is drawn uniformly with np.random.randint.
    """

    def __init__(self, config, depth, asym=False, feature_index=None):
        # The parent constructor also builds the children (via build_child).
        super().__init__(config, depth, asym)

        if feature_index is not None:
            self.feature_index = feature_index
        else:
            self.feature_index = np.random.randint(self.config["input_dim"])

        # Replace the dense splitter with a single-feature one, N(0, 1) init.
        single_feature_fc = nn.Linear(1, self.config["n_tree"], bias=True)
        self.fc = single_feature_fc.to(device)
        nn.init.normal_(self.fc.weight, 0.0, 1.0)
        nn.init.normal_(self.fc.bias, 0.0, 1.0)

    def build_child(self, depth):
        """Same layout as InnerNode.build_child, with SparseInnerNode children."""
        if depth >= self.config["max_depth"]:
            self.left = LeafNode(self.config)
            self.right = LeafNode(self.config)
            return

        child_depth = depth + 1
        self.left = SparseInnerNode(self.config, child_depth, asym=self.asym)
        if self.asym:
            self.right = LeafNode(self.config)
        else:
            self.right = SparseInnerNode(self.config, child_depth, asym=self.asym)

    def forward(self, x):
        """Decision function on the selected feature only -> [batch_size, n_tree]."""
        selected = x[:, self.feature_index].unsqueeze(dim=1)
        pre_activation = (
            torch.matmul(selected, self.fc.weight.t())
            + self.config["bias_scale"] * self.fc.bias
        )
        return 0.5 * torch.erf(self.config["scale"] * pre_activation) + 0.5


class SparseFinetuneInnerNode(InnerNode):
    """Inner node with a dense splitter initialised to use a single feature.

    The splitter keeps the full `input_dim` -> `n_tree` shape (so the
    inherited InnerNode.forward is used unchanged), but every weight column
    except `feature_index` is zeroed at construction time. When no index is
    given, one is drawn uniformly with np.random.randint.
    """

    def __init__(self, config, depth, asym=False, feature_index=None):
        # The parent constructor also builds the children (via build_child).
        super().__init__(config, depth, asym)
        if feature_index is None:
            self.feature_index = np.random.randint(self.config["input_dim"])
        else:
            self.feature_index = feature_index

        self.fc = nn.Linear(
            self.config["input_dim"], self.config["n_tree"], bias=True
        ).to(device)
        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0
        nn.init.normal_(self.fc.bias, 0.0, 1.0)  # mean: 0.0, std: 1.0

        # Zero every weight column except the selected feature.
        # The mask is built from self.feature_index (the resolved index), not
        # from the raw argument: comparing against a None argument would zero
        # the whole weight matrix for nodes created without an explicit index.
        with torch.no_grad():
            keep = torch.zeros_like(self.fc.weight[0])
            keep[self.feature_index] = 1
            self.fc.weight *= keep  # broadcasts over the n_tree rows

    def build_child(self, depth):
        """Same layout as InnerNode.build_child, with SparseFinetuneInnerNode children."""
        if depth < self.config["max_depth"]:
            self.left = SparseFinetuneInnerNode(self.config, depth + 1, asym=self.asym)
            if self.asym:
                self.right = LeafNode(self.config)
            else:
                self.right = SparseFinetuneInnerNode(
                    self.config, depth + 1, asym=self.asym
                )
        else:
            self.left = LeafNode(self.config)
            self.right = LeafNode(self.config)


class LeafNode:
    """Terminal node holding one learnable output vector per tree."""

    def __init__(self, config):
        self.config = config
        self.leaf = True
        # Leaf values drawn from N(0, 1), shape [n_class, n_tree].
        initial_values = torch.randn(
            self.config["output_dim"], self.config["n_tree"]
        )
        self.param = nn.Parameter(initial_values.to(device))

    def forward(self):
        """Return the leaf parameter tensor."""
        return self.param

    def calc_prob(self, x, path_prob):
        """Pair the incoming path probability with the batch-expanded leaf values.

        Returns a one-element list [[path_prob, Q]] where Q is the leaf
        parameter expanded to [batch_size, n_class, n_tree]. `x` is unused.
        """
        path_prob = path_prob.to(device)  # [batch_size, n_tree]
        batch_size = path_prob.size()[0]
        target_shape = (
            batch_size,
            self.config["output_dim"],
            self.config["n_tree"],
        )
        Q = self.forward().expand(target_shape)
        return [[path_prob, Q]]

    def reset(self):
        """Leaves keep no per-forward state; nothing to clear."""
        pass


class SoftTree(nn.Module):
    """Ensemble of `n_tree` soft decision trees sharing one tree structure.

    The tree is built from InnerNode (or SparseInnerNode when `sparse=True`)
    objects; their splitters and leaf parameters are registered on this
    module through `module_list` / `param_list`.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        max_depth: int,
        scale: float,
        bias_scale: float,
        n_tree: int,
        asym: bool = False,
        sparse: bool = False,
    ):
        super(SoftTree, self).__init__()
        self.config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "max_depth": max_depth,
            "scale": scale,
            "bias_scale": bias_scale,
            "n_tree": n_tree,
        }

        root_cls = SparseInnerNode if sparse else InnerNode
        self.root = root_cls(self.config, depth=1, asym=asym)

        self.collect_parameters()

    def collect_parameters(self):
        """Walk the tree breadth-first and register splitters and leaf parameters."""
        self.module_list = nn.ModuleList()
        self.param_list = nn.ParameterList()

        pending = [self.root]
        while pending:
            node = pending.pop(0)
            if node.leaf:
                self.param_list.append(node.param)
                continue
            # Right child is queued before the left one.
            pending.extend((node.right, node.left))
            self.module_list.append(node.fc)

    def forward(self, x):
        """Return predictions of shape [batch_size, output_dim]."""
        batch_size = x.shape[0]
        x = torch.squeeze(x, 1).reshape(batch_size, self.config["input_dim"])

        # Every sample starts at the root with path probability 1 in each tree.
        path_prob_init = torch.Tensor(torch.ones(batch_size, self.config["n_tree"]))
        leaf_accumulator = self.root.calc_prob(x, path_prob_init)

        pred = torch.zeros(batch_size, self.config["output_dim"]).to(device)
        for path_prob, Q in leaf_accumulator:  # one entry per leaf
            # Weight each leaf value by its path probability, sum over trees.
            pred += (path_prob.unsqueeze(1) * Q).sum(dim=2)

        pred /= np.sqrt(self.config["n_tree"])  # NTK scaling

        self.root.reset()
        return pred

In [None]:
class SoftTreeExp(nn.Module):
    """Hand-built two-level sparse soft-tree ensemble used for the kernel plots.

    The root always splits on feature 0. Both depth-2 nodes split on feature 0
    when `arch == 0` and on feature 1 otherwise. All four grandchildren are
    leaves. With `finetune=True` the nodes are SparseFinetuneInnerNode
    (dense splitter, single non-zero column), otherwise SparseInnerNode.

    Only `sparse=True` is supported. `asym` is accepted but not used.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        max_depth: int,
        scale: float,
        bias_scale: float,
        n_tree: int,
        asym: bool = False,
        sparse: bool = True,
        finetune: bool = False,
        arch: int = 0,
    ):
        super(SoftTreeExp, self).__init__()
        config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "scale": scale,
            "bias_scale": bias_scale,
            "n_tree": n_tree,
            "max_depth": max_depth
        }
        self.config = config

        assert sparse  # only for sparse tree
        assert finetune <= sparse

        # finetune -> AAI, otherwise AAA
        node_cls = SparseFinetuneInnerNode if finetune else SparseInnerNode
        child_feature = 0 if arch == 0 else 1

        # depth=1
        self.root = node_cls(config, depth=1, feature_index=0)
        # depth=2 (replaces the children the constructor built automatically)
        self.root.left = node_cls(config, depth=2, feature_index=child_feature)
        self.root.right = node_cls(config, depth=2, feature_index=child_feature)

        # depth=3
        self.root.left.left = LeafNode(config)
        self.root.left.right = LeafNode(config)
        self.root.right.left = LeafNode(config)
        self.root.right.right = LeafNode(config)

        self.collect_parameters()

    def collect_parameters(self):
        """Walk the tree breadth-first and register splitters and leaf parameters."""
        nodes = [self.root]
        self.module_list = nn.ModuleList()
        self.param_list = nn.ParameterList()
        while nodes:
            node = nodes.pop(0)
            if node.leaf:
                self.param_list.append(node.param)
            else:
                nodes.append(node.right)
                nodes.append(node.left)
                self.module_list.append(node.fc)

    def forward(self, x):
        """Return predictions of shape [batch_size, output_dim]."""
        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config["input_dim"])

        path_prob_init = torch.Tensor(
            torch.ones(x.shape[0], self.config["n_tree"]))

        leaf_accumulator = self.root.calc_prob(x, path_prob_init)
        # Allocate the accumulator on the same device as the node tensors
        # (as SoftTree.forward does); a CPU accumulator breaks the in-place
        # add below when the nodes live on CUDA.
        pred = torch.zeros(x.shape[0], self.config["output_dim"]).to(device)
        for path_prob, Q in leaf_accumulator:  # one entry per leaf
            pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)

        pred /= np.sqrt(self.config["n_tree"])  # NTK scaling

        self.root.reset()
        return pred

In [None]:
def plot_kernel(
    alpha, beta, n_tree, color, u: tuple = (1, 0), finetune: bool = False, arch: int = 0
):
    """Plot the empirical kernel of one freshly initialised SoftTreeExp.

    A new SoftTreeExp (input_dim=2, output_dim=1, max_depth=2) is created,
    and for each integer angle i in [0, 180) the kernel entry between `u` and
    `u` rotated by i degrees is computed with `compute_kernels`. The curve is
    drawn on the current matplotlib axes.

    Args:
        alpha: passed to the model as `scale`.
        beta: passed to the model as `bias_scale`.
        n_tree: number of trees in the ensemble.
        color: line colour.
        u: 2-D reference input.
        finetune, arch: forwarded to SoftTreeExp.

    Returns:
        (inner_product, res): lists of length 180 with <u, rotated u> and the
        corresponding kernel values K[1, 0].
    """
    st = SoftTreeExp(
        input_dim=2,
        output_dim=1,
        max_depth=2,
        scale=alpha,
        bias_scale=beta,
        n_tree=n_tree,
        sparse=True,
        finetune=finetune,
        arch=arch,
    )
    res = []
    inner_product = []
    for i in tqdm(range(180), leave=False):
        Ru = rotation_o(u, i * np.pi / 180)
        # Use the module-level `device` the model's tensors were placed on,
        # rather than re-deriving "cuda"/"cpu" here, so inputs and weights
        # always end up on the same device.
        x = torch.Tensor([u, Ru]).to(device)
        K = compute_kernels(st, x)
        res.append(K[1, 0].tolist())
        inner_product.append(np.dot(u, Ru))

    plt.plot(inner_product, res, color=color, linewidth=1, zorder=1)
    return inner_product, res

In [None]:
def plot_kernel_finite(
    alpha: float,
    beta: float,
    u: list,
    n_seeds: int = 10,
    finetune: bool = False,
    colormap=cm.bwr,
    arch: int = 0,
) -> None:
    """Draw finite-ensemble kernel curves and the `get_kernel` reference curve.

    For each ensemble size in [16, 64, 256, 1024, 4096], `n_seeds` independent
    models are plotted with `plot_kernel` (one colour per size). The curve
    returned by `get_kernel` is overlaid as a dotted black line. Everything is
    drawn on the current matplotlib axes, which is then titled and formatted.

    Args:
        alpha, beta, finetune, arch: forwarded to `plot_kernel` / `get_kernel`.
        u: 2-D reference input; (1, 0) selects the first title variant.
        n_seeds: number of random models per ensemble size.
        colormap: callable mapping a value in [0, 1) to a colour.
    """
    n_tree_combinations = [16, 64, 256, 1024, 4096]

    # The reference curve does not depend on the ensemble size, so it is
    # computed once instead of once per size.
    kernel_list, inner_product_list = get_kernel(alpha, beta, u, finetune, arch)

    for j, n_tree in enumerate(tqdm(n_tree_combinations, leave=False)):
        curve_color = colormap(j / len(n_tree_combinations))
        for _ in tqdm(range(n_seeds), leave=False):
            plot_kernel(
                alpha,
                beta,
                n_tree,
                color=curve_color,
                u=u,
                finetune=finetune,
                arch=arch,
            )

    # get_kernel covers 360 degrees; only the first 180 are shown, matching
    # the range swept by plot_kernel.
    plt.plot(
        inner_product_list[0:180],
        kernel_list[0:180],
        color="black",
        linewidth=3,
        linestyle="dotted",
        zorder=2,
    )

    # tuple(u) keeps the comparison well-defined for list / array inputs too.
    if tuple(u) == (1, 0):
        plt.title(
            "$x_i =(1, 0)$" + f", Tree architecture={'(A)' if arch==0 else '(B)'}"
        )
    else:
        plt.title(
            "$x_i =({1}/{\\sqrt{2}}, {1}/{\\sqrt{2}})$"
            + f", Architecture={'(A)' if arch==0 else '(B)'}"
        )
    plt.xlabel("Inner product of the inputs")

    ylim_min = -0.5
    ylim_max = 2.5

    plt.ylim(ylim_min, ylim_max)
    plt.grid(linestyle="dotted")

    plt.tight_layout()

In [None]:
def calc_tau(alpha: float, S: np.array, diag_i: np.array, diag_j: np.array) -> np.array:
    """Evaluate 1/4 + arcsin(rho) / (2*pi) element-wise.

    rho = alpha^2 * S / sqrt((alpha^2 * diag_i + 0.5) * (alpha^2 * diag_j + 0.5)).
    All array arguments are combined element-wise.
    """
    alpha_sq = alpha**2
    normaliser = np.sqrt((alpha_sq * diag_i + 0.5) * (alpha_sq * diag_j + 0.5))
    rho = (alpha_sq * S) / normaliser
    return 1 / 4 + 1 / (2 * math.pi) * np.arcsin(rho)


def calc_tau_dot(
    alpha: float, S: np.array, diag_i: np.array, diag_j: np.array
) -> np.array:
    """Evaluate (alpha^2 / pi) / sqrt(D) element-wise.

    D = (2*alpha^2*diag_i + 1) * (2*alpha^2*diag_j + 1) - 4*alpha^4*S^2.
    All array arguments are combined element-wise.
    """
    prefactor = (alpha**2) / (math.pi) * 1
    radicand = (2 * (alpha**2) * diag_i + 1) * (2 * (alpha**2) * diag_j + 1) - (
        4 * (alpha**4) * (S**2)
    )
    return prefactor / np.sqrt(radicand)

In [None]:
def hardtree_viz(X: np.array, alpha: float, beta: float, finetune: bool, arch: int):
    """Closed-form kernel matrix for the two-level tree used in this notebook.

    Args:
        X: 2-D array, one sample per row and one feature per column.
        alpha, beta: scalars; `beta**2` is added to every Gram matrix and
            `alpha` is forwarded to `calc_tau` / `calc_tau_dot`.
        finetune: when True the full Gram matrix X X^T + beta^2 is used as the
            multiplicative factor for every feature; otherwise the
            per-feature Gram matrix is used. tau / tau_dot are always
            computed from the per-feature matrix.
        arch: 0 -> both levels use feature 0; 1 -> features 0 then 1.

    Returns:
        (n_samples, n_samples) float array K.
    """
    assert arch in {0, 1}

    n_samples, n_features = X.shape

    # Independent of the feature index, so computed once.
    S_all = np.matmul(X, X.T) + beta**2

    S_list = []
    tau_list = []
    tau_dot_list = []
    for feature_index in range(n_features):
        column = X[:, feature_index]
        S = np.outer(column, column) + beta**2
        S_list.append(S_all if finetune else S)

        # diag_i[r, c] = S[c, c] and diag_j[r, c] = S[r, r].
        diag = np.diag(S)
        diag_i = np.tile(diag, (len(diag), 1))
        diag_j = diag_i.transpose()
        tau_list.append(calc_tau(alpha, S, diag_i, diag_j))
        tau_dot_list.append(calc_tau_dot(alpha, S, diag_i, diag_j))

    # One rule (sequence of feature indices along the path) per leaf.
    if arch == 0:
        rulelist = [[0, 0], [0, 0], [0, 0], [0, 0]]
    elif arch == 1:
        rulelist = [[0, 1], [0, 1], [0, 1], [0, 1]]

    K = np.zeros((n_samples, n_samples))
    for rules in rulelist:

        # Internal nodes: S * tau_dot for one node on the path, times tau for
        # every other node on the path.
        for i, s in enumerate(rules):
            others = rules[0:i] + rules[i + 1:]
            node_term = S_list[s] * tau_dot_list[s]  # fresh array, safe to edit
            for t in others:
                node_term *= tau_list[t]
            K += node_term

        # Leaves: product of tau over every node on the path.
        leaf_term = np.ones_like(K)
        for feature in rules:
            leaf_term *= tau_list[feature]
        K += leaf_term

    return K

In [None]:
n_seeds = 10
alpha = 2.0
beta = 0.5

# One figure row per training regime: (finetune flag, colormap, y-axis label).
row_specs = [
    (False, cm.PRGn, "Kernel value (AAA)"),
    (True, cm.bwr, "Kernel value (AAI)"),
]
# Reference inputs shown in alternating columns.
reference_inputs = [(1, 0), (1 / np.sqrt(2), 1 / np.sqrt(2))]

plt.figure(figsize=(18, 10))
panel = 0
for finetune, colormap, ylabel in row_specs:
    for arch in (0, 1):
        for u in reference_inputs:
            panel += 1
            plt.subplot(2, 4, panel)
            plot_kernel_finite(alpha=alpha, beta=beta, u=u, n_seeds=n_seeds,
                               finetune=finetune, colormap=colormap, arch=arch)
            if panel % 4 == 1:
                # Only the first panel of each row carries the y-axis label.
                plt.ylabel(ylabel)

# Legend
# Ensemble sizes; mirrors the list used inside plot_kernel_finite.
patterns = [16, 64, 256, 1024, 4096]
legend_specs = [(6, cm.PRGn, "AAA"), (7, cm.bwr, "AAI")]
for legend_panel, colormap, legend_title in legend_specs:
    plt.subplot(2, 4, legend_panel)
    # Empty proxy lines: they add legend entries without drawing anything.
    for i in range(len(patterns)):
        plt.plot([], [], color=colormap(i / len(patterns)), linewidth=1, label=f"$M={patterns[i]}$")
    plt.plot([], [], color="black", linewidth=3, linestyle="dotted", label="$M={\\infty}$")
    plt.legend(ncol=3, bbox_to_anchor=(0.5, -0.35), fontsize=12, title=legend_title, loc="center", borderaxespad=0)

# Make sure the output directory exists so savefig works on a fresh checkout.
os.makedirs("./figures", exist_ok=True)
plt.savefig("./figures/kernels_asymptotic.pdf", bbox_inches="tight", pad_inches=0.10)