In [None]:
def set_latex():
    """Configure matplotlib's global rcParams for the figures in this notebook.

    Resets matplotlib to its default style, then sets a 15 pt base font,
    'Times New Roman' as the font family and STIX for mathtext. It also tries
    (best effort) to delete the 'roman' entry from
    ``matplotlib.font_manager.weight_dict`` and rebuild the font cache.
    Presumably this keeps "Roman" in "Times New Roman" from being read as a
    font weight. TODO confirm.

    Mutates global matplotlib state and returns None. Calling it more than
    once is harmless.
    """
    import matplotlib
    import matplotlib.pyplot as plt

    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')

    # NOTE(review): style.use("default") restores matplotlib's default rcParams.
    # That appears to undo the two plt.rc(...) calls above (usetex / serif).
    # They are kept so the rendered output is unchanged. Move them below this
    # line if real LaTeX text rendering is actually wanted.
    plt.style.use("default")
    plt.rcParams["font.size"] = 15

    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['mathtext.fontset'] = 'stix'

    try:
        del matplotlib.font_manager.weight_dict['roman']
        matplotlib.font_manager._rebuild()
    except Exception:
        # Deliberately best effort. Expected failures are:
        #   KeyError: 'roman' was already removed by an earlier call.
        #   AttributeError: the private helper (e.g. _rebuild) is missing in
        #     this matplotlib version.
        # Unlike the former bare `except:`, this no longer swallows
        # KeyboardInterrupt or SystemExit.
        pass

In [None]:
# Apply the global matplotlib style (15 pt font, Times New Roman, STIX mathtext)
# before any figure is created.
set_latex()

In [None]:
import math
import warnings

import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy import special
from sklearn.metrics import mean_squared_error
from tqdm.notebook import tqdm

# Ignore every Python warning for the rest of the kernel session.
# NOTE(review): this also hides numpy RuntimeWarnings (e.g. invalid arguments
# reaching np.arcsin / np.sqrt in calc_tau / calc_tau_dot). Consider narrowing
# the filter to the specific category/module that is actually noisy.
warnings.filterwarnings("ignore")

In [None]:
def rotation_o(u, t, deg=False):
    """Rotate a 2-D vector about the origin by angle ``t``.

    Applies ``R(t) = [[cos t, -sin t], [sin t, cos t]]``, a counter-clockwise
    rotation for positive ``t``.

    Parameters
    ----------
    u : array_like
        Vector to rotate. It is multiplied as ``np.dot(R, u)``, so its first
        axis must have length 2.
    t : float
        Rotation angle. It is in radians unless ``deg`` is truthy.
    deg : bool, optional
        If truthy, ``t`` is given in degrees and converted with ``np.deg2rad``.

    Returns
    -------
    numpy.ndarray
        The rotated vector ``R(t) @ u``.
    """
    if deg:  # PEP 8: test truthiness rather than comparing to True with ==
        t = np.deg2rad(t)

    cos_t, sin_t = np.cos(t), np.sin(t)
    R = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    return np.dot(R, u)

In [None]:
def calc_tau(alpha: float, S: np.ndarray, diag_i: np.ndarray, diag_j: np.ndarray) -> np.ndarray:
    """Elementwise tau term used by the tree kernels.

    Computes::

        tau = 1/4 + 1/(2*pi) * arcsin( alpha^2 * S
                                       / sqrt((alpha^2 * diag_i + 1/2) * (alpha^2 * diag_j + 1/2)) )

    All array arguments follow numpy broadcasting. In this notebook ``S`` is a
    (shifted) Gram matrix, and ``diag_i`` / ``diag_j`` hold its diagonal
    repeated along rows / columns.

    Args:
        alpha: Scale factor. Only ``alpha**2`` enters the formula.
        S: Pairwise similarity values.
        diag_i: Diagonal entries broadcast to pair with ``S``, first operand.
        diag_j: Diagonal entries broadcast to pair with ``S``, second operand.

    Returns:
        Array of tau values with the broadcast shape of the inputs.
    """
    scale = alpha**2
    ratio = (scale * S) / np.sqrt((scale * diag_i + 0.5) * (scale * diag_j + 0.5))
    return 1 / 4 + 1 / (2 * math.pi) * np.arcsin(ratio)


def calc_tau_dot(
    alpha: float, S: np.ndarray, diag_i: np.ndarray, diag_j: np.ndarray
) -> np.ndarray:
    """Elementwise tau_dot term used by the tree kernels.

    Computes::

        tau_dot = (alpha^2 / pi) / sqrt( (2*alpha^2*diag_i + 1) * (2*alpha^2*diag_j + 1)
                                         - 4 * alpha^4 * S^2 )

    This equals the derivative with respect to ``S`` of the expression in
    ``calc_tau`` (same ``alpha``, ``diag_i``, ``diag_j``). All array arguments
    follow numpy broadcasting.

    Args:
        alpha: Scale factor.
        S: Pairwise similarity values.
        diag_i: Diagonal entries broadcast to pair with ``S``, first operand.
        diag_j: Diagonal entries broadcast to pair with ``S``, second operand.

    Returns:
        Array of tau_dot values with the broadcast shape of the inputs.
    """
    scale = alpha**2
    denom = np.sqrt(
        (2 * scale * diag_i + 1) * (2 * scale * diag_j + 1)
        - 4 * alpha**4 * S**2
    )
    return scale / math.pi / denom

In [None]:
def hardtree_viz(X: np.ndarray, alpha: float, beta: float, finetune: bool, arch: int):
    """Kernel matrix of the fixed depth-2, four-rule hard-tree model plotted by this notebook.

    For every feature ``f`` this builds ``S_f = x_f x_f^T + beta**2`` and the
    matching ``tau_f`` / ``tau_dot_f`` matrices (via ``calc_tau`` /
    ``calc_tau_dot``). Each of the four rules is a list of feature indices, one
    per internal node. For each rule, the kernel accumulates:

    * internal-node terms: for every node ``s``,
      ``S_list[s] * tau_dot_s * prod(tau_t for the other nodes t)``;
    * a leaf term: ``prod(tau_t for t in rule)``.

    Args:
        X: 2-D array (rows are samples, columns are features).
        alpha: Scale passed to ``calc_tau`` / ``calc_tau_dot``.
        beta: Offset added (squared) to every inner product.
        finetune: If True, the ``S`` factor in the internal-node terms is the
            full Gram matrix ``X X^T + beta**2``. ``tau`` / ``tau_dot`` are
            still per-feature. ``save_pdf`` titles finetune=False panels "AAA"
            and finetune=True panels "AAI".
        arch: 0 means every rule splits on feature 0 at both nodes. 1 means
            feature 0 then feature 1, which requires ``X`` to have at least
            2 columns.

    Returns:
        (n_samples, n_samples) kernel matrix.
    """
    assert arch in {0, 1}
    n_samples, n_features = X.shape

    # Full Gram matrix: identical for every feature, so compute it once.
    S_all = np.matmul(X, X.T) + beta**2

    S_list = []
    tau_list = []
    tau_dot_list = []
    for feature_index in range(n_features):
        x_f = X[:, feature_index]
        S = np.outer(x_f, x_f) + beta**2
        S_list.append(S_all if finetune else S)

        # diag_i[r, c] = S[c, c] and diag_j[r, c] = S[r, r].
        diag = np.diag(S)
        diag_i = np.tile(diag, (len(diag), 1))
        diag_j = diag_i.transpose()
        tau_list.append(calc_tau(alpha, S, diag_i, diag_j))
        tau_dot_list.append(calc_tau_dot(alpha, S, diag_i, diag_j))

    # Feature index used at each of the two internal nodes, for each of the
    # four rules (presumably one rule per leaf; TODO confirm).
    rules_by_arch = {
        0: [[0, 0] for _ in range(4)],
        1: [[0, 1] for _ in range(4)],
    }
    rulelist = rules_by_arch[arch]

    K = np.zeros((n_samples, n_samples))
    for rules in rulelist:
        # Internal nodes: derivative term at node s times tau at the remaining nodes.
        for i, s in enumerate(rules):
            _H_nodes = S_list[s] * tau_dot_list[s]  # new array; S_list is not mutated
            for t in rules[:i] + rules[i + 1:]:
                _H_nodes *= tau_list[t]
            K += _H_nodes

        # Leaves: product of tau over every node of the rule.
        _H_leaves = np.ones_like(K)
        for t in rules:
            _H_leaves *= tau_list[t]
        K += _H_leaves

    return K

In [None]:
def softtree_viz(X: np.ndarray, alpha: float, beta: float, max_depth: int):
    """Kernel matrix of the soft ("Oblique") tree model at depth ``max_depth``.

    With ``S = X X^T + beta**2`` and ``tau`` / ``tau_dot`` from ``calc_tau`` /
    ``calc_tau_dot``, returns elementwise::

        2 * S * 2**(d-1) * d * tau_dot * tau**(d-1) + 2**d * tau**d,   d = max_depth

    Only the requested depth is evaluated. The previous version built the
    kernel for every depth 1..max_depth and then discarded all but the last.

    Args:
        X: 2-D array (rows are samples).
        alpha: Scale passed to ``calc_tau`` / ``calc_tau_dot``.
        beta: Offset added (squared) to every inner product.
        max_depth: Tree depth ``d``. Must be >= 1.

    Returns:
        (n_samples, n_samples) float64 kernel matrix.

    Raises:
        ValueError: if ``max_depth < 1``.
    """
    if max_depth < 1:
        raise ValueError(f"max_depth must be >= 1, got {max_depth}")

    S = np.matmul(X, X.T) + beta**2
    # diag_i[r, c] = S[c, c] and diag_j[r, c] = S[r, r].
    diag = np.diag(S)
    diag_i = np.tile(diag, (len(diag), 1))
    diag_j = diag_i.transpose()

    tau = calc_tau(alpha, S, diag_i, diag_j)
    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)

    depth = max_depth
    H = (2 * S * (2**(depth - 1)) * depth * tau_dot * tau**(depth - 1)) + (
        (2**depth) * (tau**depth))
    # The previous version stored H into a float64 buffer; keep that dtype.
    return np.asarray(H, dtype=np.float64)

In [None]:
def get_kernel_hard(alpha, beta, u, finetune, arch):
    """Hard-tree kernel between ``u`` and ``u`` rotated by 0..359 degrees.

    For each integer angle ``i`` in [0, 360), rotates ``u`` by ``i`` degrees
    with ``rotation_o`` and evaluates ``hardtree_viz`` on the 2-row matrix
    ``[u, Ru]``.

    Args:
        alpha, beta, finetune, arch: Forwarded to ``hardtree_viz``.
        u: 2-D input vector.

    Returns:
        Tuple ``(kernel, inner_product)`` of two 360-element lists:
        ``kernel[i]`` is the off-diagonal entry ``H[1, 0]`` and
        ``inner_product[i]`` is ``u . Ru``.
    """
    kernel = []
    inner_product = []
    for i in range(360):
        Ru = rotation_o(u, i * np.pi / 180)
        H = hardtree_viz(np.vstack([u, Ru]), alpha=alpha, beta=beta, finetune=finetune, arch=arch)
        kernel.append(H[1, 0])
        inner_product.append(np.dot(u, Ru))
    return kernel, inner_product

In [None]:
def get_kernel_soft(alpha, beta, u, depth=2):
    """Soft-tree kernel between ``u`` and ``u`` rotated by 0..359 degrees.

    For each integer angle ``i`` in [0, 360), rotates ``u`` by ``i`` degrees
    with ``rotation_o`` and evaluates ``softtree_viz`` (at ``max_depth=depth``)
    on the 2-row matrix ``[u, Ru]``.

    Args:
        alpha, beta: Forwarded to ``softtree_viz``.
        u: 2-D input vector.
        depth: Tree depth passed as ``max_depth``.

    Returns:
        A 360-element list whose ``i``-th entry is the off-diagonal kernel
        value ``H[1, 0]``.
    """
    kernel = []
    for i in range(360):
        Ru = rotation_o(u, i * np.pi / 180)
        H = softtree_viz(np.vstack([u, Ru]), alpha=alpha, beta=beta, max_depth=depth)
        kernel.append(H[1, 0])
    return kernel

In [None]:
# Number of rotated axis-aligned reference inputs drawn per panel.
_N_ROTATIONS = 15
# Rotation step between consecutive reference inputs, in degrees
# (so the inputs span 0 .. (_N_ROTATIONS - 1) * _STEP_DEG degrees).
_STEP_DEG = 6
# Keep the first 180 one-degree samples from get_kernel_* (rotations 0..179 deg),
# i.e. inner products sweeping from 1 down towards -1.
_HALF_TURN = 180


def _plot_kernel_panel(position, alpha, beta, finetune, arch, title,
                       ylim_min, ylim_max, soft_kernel):
    """Draw one panel of the 2x2 figure: hard-tree curves plus the oblique reference.

    Panel ``position`` (1..4) is selected with ``plt.subplot(2, 2, position)``.
    One hard-tree curve is drawn per rotated reference input, coloured by
    ``cm.jet(i / _N_ROTATIONS)``. ``soft_kernel`` is overlaid as a dotted black
    line against the inner products of the last rotated input, which is the
    same input ``soft_kernel`` was computed for. Left-column panels get the
    y label; bottom-row panels get the x label.
    """
    plt.subplot(2, 2, position)
    inner_product_list = []
    for i in tqdm(range(_N_ROTATIONS), leave=False):
        Ru = rotation_o((1, 0), _STEP_DEG * i * np.pi / 180)
        kernel_list, inner_product_list = get_kernel_hard(
            alpha=alpha, beta=beta, u=Ru, finetune=finetune, arch=arch
        )
        plt.plot(
            inner_product_list[:_HALF_TURN],
            kernel_list[:_HALF_TURN],
            linewidth=1,
            color=cm.jet(i / _N_ROTATIONS),
        )

    plt.plot(
        inner_product_list[:_HALF_TURN],
        soft_kernel[:_HALF_TURN],
        linewidth=3,
        color="black",
        linestyle="dotted",
        label="Oblique",
    )
    plt.grid(linestyle="dotted")
    plt.title(title)
    plt.ylim(ylim_min, ylim_max)
    plt.xlim(-1.0, 1.0)
    if position in (1, 3):
        plt.ylabel("Kernel value")
    if position in (3, 4):
        plt.xlabel("Inner product of the inputs")
    plt.legend(loc="upper left")


def save_pdf(alpha, beta, ylim_min, ylim_max, out_dir="./figures"):
    """Plot hard-tree vs. oblique kernels in a 2x2 grid and save it as a PDF.

    Panels are (finetune, arch) = (False, 0), (False, 1), (True, 0), (True, 1),
    titled "AAA"/"AAI" x "Tree architecture=(A)"/"(B)". A horizontal colorbar
    maps curve colour to the rotation angle of the reference input.

    Args:
        alpha, beta: Kernel hyperparameters, also used in the suptitle and file name.
        ylim_min, ylim_max: y-axis limits shared by all panels.
        out_dir: Output directory, created if missing. Defaults to "./figures".

    Side effects:
        Creates a new current matplotlib figure and writes
        ``{out_dir}/kernels_{alpha}_{beta}.pdf``.
    """
    import os

    plt.figure(figsize=(15, 7))

    # The oblique reference curve is the same in all four panels, so compute
    # it once. Use the last rotated reference input, as the per-panel loop
    # leaves it, so its x-values match the panel's final inner-product list.
    reference_u = rotation_o((1, 0), _STEP_DEG * (_N_ROTATIONS - 1) * np.pi / 180)
    soft_kernel = get_kernel_soft(alpha=alpha, beta=beta, u=reference_u, depth=2)

    panels = [
        (1, False, 0, "AAA, Tree architecture=(A)"),
        (2, False, 1, "AAA, Tree architecture=(B)"),
        (3, True, 0, "AAI, Tree architecture=(A)"),
        (4, True, 1, "AAI, Tree architecture=(B)"),
    ]
    for position, finetune, arch, title in panels:
        _plot_kernel_panel(position, alpha, beta, finetune, arch, title,
                           ylim_min, ylim_max, soft_kernel)

    # Colour i / _N_ROTATIONS corresponds to angle _STEP_DEG * i, so the
    # colorbar spans 0 .. _N_ROTATIONS * _STEP_DEG (= 90) degrees.
    cax = plt.axes([0.145, -0.02, 0.75, 0.02])
    norm = matplotlib.colors.Normalize(vmin=0, vmax=_N_ROTATIONS * _STEP_DEG)
    plt.colorbar(
        matplotlib.cm.ScalarMappable(norm=norm, cmap=cm.jet),
        cax=cax,
        orientation="horizontal",
        label="Rotation angle (degree)",
        ticks=[15, 30, 45, 60, 75],
    )

    plt.tight_layout()
    plt.suptitle(f"$\\alpha$={alpha}, $\\beta$={beta}", y=1.0)
    os.makedirs(out_dir, exist_ok=True)  # savefig fails if the directory is missing
    plt.savefig(
        os.path.join(out_dir, f"kernels_{alpha}_{beta}.pdf"),
        bbox_inches="tight",
        pad_inches=0.10,
    )

In [None]:
# Baseline setting alpha=2.0, beta=0.5; writes ./figures/kernels_2.0_0.5.pdf.
save_pdf(alpha=2.0, beta=0.5, ylim_max=2.5, ylim_min=-0.5)

In [None]:
# Smaller alpha, beta kept at 0.5.
save_pdf(alpha=1.0, beta=0.5, ylim_max=2.5, ylim_min=-0.5)

In [None]:
# Larger alpha, beta kept at 0.5.
save_pdf(alpha=4.0, beta=0.5, ylim_max=2.5, ylim_min=-0.5)

In [None]:
# Smaller beta, alpha kept at 2.0.
save_pdf(alpha=2.0, beta=0.1, ylim_max=2.5, ylim_min=-0.5)

In [None]:
# Larger beta, alpha kept at 2.0.
save_pdf(alpha=2.0, beta=1.0, ylim_max=2.5, ylim_min=-0.5)