# Minkowski dot
> Does using Minkowski dot-products for classification do better?

In [1]:
%load_ext autoreload
%autoreload 2

In [11]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.hyperdt.toy_data import generate_gaussian_mixture_hyperboloid
from src.hyperdt.tree import HyperbolicDecisionTreeClassifier

from sklearn.model_selection import train_test_split


In [45]:
# Regular tree versus minkowski version

X, y = generate_gaussian_mixture_hyperboloid(
    num_points=1000, num_classes=2, n_dim=2
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
tree1 = HyperbolicDecisionTreeClassifier(max_depth=2)
tree1.fit(X_train, y_train)
print(tree1.score(X_test, y_test))


class HDT_Minkowski(HyperbolicDecisionTreeClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.metric = "minkowski"

    def _dot(self, X, dim, theta):
        if self.sparse_dot_product and X.ndim == 1:
            return (
                -np.cos(theta) * X[self.timelike_dim] + np.sin(theta) * X[dim]
            )
        elif self.sparse_dot_product:
            return (
                -np.cos(theta) * X[:, self.timelike_dim]
                + np.sin(theta) * X[:, dim]
            )


tree2 = HDT_Minkowski(max_depth=2)
tree2.fit(X_train, y_train)
print(tree2.score(X_test, y_test))


0.61
0.615


In [46]:
# Wow, OK! That worked better... let's try many times

results = pd.DataFrame(columns=["dims", "trial", "minkowski", "dot"])

for dims in range(2, 16):
    for _ in range(100):
        try:
            X, y = generate_gaussian_mixture_hyperboloid(
                num_points=1000, num_classes=2, n_dim=dims
            )
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
            tree1 = HyperbolicDecisionTreeClassifier(
                max_depth=2, timelike_dim=0
            )
            tree1.fit(X_train, y_train)
            tree2 = HDT_Minkowski(max_depth=2, timelike_dim=0)
            tree2.fit(X_train, y_train)
            results.loc[len(results)] = [
                dims,
                1,
                tree1.score(X_test, y_test),
                tree2.score(X_test, y_test),
            ]
        except Exception as e:
            print(str(e))
            continue

results


Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Cannot project a vector of norm 0. in the Minkowski space to the hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid
Points must lie on a hyperboloid

Unnamed: 0,dims,trial,minkowski,dot
0,2.0,1.0,0.800000,0.800000
1,2.0,1.0,0.720000,0.750000
2,2.0,1.0,0.770000,0.775000
3,2.0,1.0,0.940000,0.940000
4,2.0,1.0,0.710000,0.865000
...,...,...,...,...
389,6.0,1.0,0.805000,0.810000
390,7.0,1.0,0.710000,0.715000
391,7.0,1.0,0.874372,0.874372
392,7.0,1.0,0.681818,0.681818


In [47]:
# So here we can see that we always do better with regular dot product. This
# makes sense since the regular Euclidean dot product has the interpretation of
# "to one side of the hyperplane" and "to the other side of the hyperplane",
# and the Minkowski dot product does not have this interpretation.

results.groupby(["dims"]).mean()


Unnamed: 0_level_0,trial,minkowski,dot
dims,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2.0,1.0,0.8203,0.8203
3.0,1.0,0.848776,0.852602
4.0,1.0,0.829872,0.830594
5.0,1.0,0.829375,0.829238
6.0,1.0,0.801022,0.801696
7.0,1.0,0.745298,0.746548
