In [1]:
import embedders
import torch

In [2]:
SIGS = [
    [(1, 2)],
    [(0, 2)],
    [(-1, 2)]
]

embeddings = []
labels = []

for i, sig in enumerate(SIGS):
    pm = embedders.manifolds.ProductManifold(sig)
    X, y = embedders.gaussian_mixture.gaussian_mixture(pm=pm, seed=i)
    embeddings.append(X)
    labels.append(y.unsqueeze(1))

X = torch.hstack(embeddings)
y = torch.hstack(labels)

print(X.shape)
print(y.shape)

torch.Size([1000, 8])
torch.Size([1000, 3])




In [9]:
# Get correlation over dim 1 of y

correlation = torch.corrcoef(y.T)
print(correlation)

tensor([[ 1.0000e+00, -4.7419e-02,  2.7763e-02],
        [-4.7419e-02,  1.0000e+00, -9.2908e-04],
        [ 2.7763e-02, -9.2908e-04,  1.0000e+00]])


In [73]:
# Try classification using each set of labels as a target
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
pm = embedders.manifolds.ProductManifold(signature=[sig for sublist in SIGS for sig in sublist])
trees = []
for i in range(len(SIGS)):
    pdt = embedders.tree_new.ProductSpaceDT(pm=pm, max_depth=3)

    pdt.fit(X_train, y_train[:, i])

    print(f"Accuracy for {i}: {pdt.score(X_test, y_test[:, i]).float().mean():.3f}")
    trees.append(pdt)


Accuracy for 0: 0.925
Accuracy for 1: 0.675
Accuracy for 2: 0.890


In [75]:
feats = []
for tree in trees:
    feats.append([x.feature for x in tree.nodes if x.feature])

# Which ones fall in the right dim:
for feat_set, allowed in zip(feats, [[0, 1], [2, 3], [4, 5]]):
    total = len(feat_set)
    correct = sum([f in allowed for f in feat_set])
    print(f"{correct/total} correct")

1.0 correct
0.8333333333333334 correct
0.8571428571428571 correct


In [15]:
# Try classification using each set of labels as a target
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# pm = embedders.manifolds.ProductManifold(signature=[sig for sublist in SIGS for sig in sublist])
trees = []
for i in range(len(SIGS)):
    # pdt = embedders.tree_new.ProductSpaceDT(pm=pm, max_depth=3)
    dt = DecisionTreeClassifier(max_depth=3)

    # pdt.fit(X_train, y_train[:, i])
    dt.fit(X_train, y_train[:, i])

    print(f"Accuracy for {i}: {dt.score(X_test, y_test[:, i]):.3f}")
    trees.append(dt)


Accuracy for 0: 0.995
Accuracy for 1: 0.680
Accuracy for 2: 0.940


In [16]:
feats = []
for tree in trees:
    feats.append([x.feature for x in tree.nodes if x.feature])

# Which ones fall in the right dim:
for feat_set, allowed in zip(feats, [[0, 1], [2, 3], [4, 5]]):
    total = len(feat_set)
    correct = sum([f in allowed for f in feat_set])
    print(f"{correct/total} correct")

AttributeError: 'DecisionTreeClassifier' object has no attribute 'nodes'

In [27]:
# Get features for each split in trained sklearn decision tree

# for tree in trees:
#     print(tree.tree_.feature)

feats = []
for tree in trees:
    feats.append([x for x in tree.tree_.feature if x != -2])

# Which ones fall in the right dim:
for feat_set, allowed in zip(feats, [[0, 1, 2], [3, 4, 5], [6, 7]]):
    total = len(feat_set)
    correct = sum([f in allowed for f in feat_set])
    print(f"{correct/total} correct")

1.0 correct
0.8333333333333334 correct
0.6666666666666666 correct
