In [1]:
%load_ext autoreload
%autoreload 2

In [43]:
import torch
import numpy as np
import embedders
import anndata

In [236]:
# Load the scRNA embeddings

data = torch.tensor(np.load("/teamspace/studios/this_studio/embedders/data/blood_cell_scrna/embeddings_s2_e2_h2_3.npy"))
data.shape
idx = np.random.choice(data.shape[0], 10_000, replace=False)
data = data[idx]  # Take it easy

# Also, let's add that dummy dimension for E2
data = torch.hstack([data[:, :3], torch.ones(data.shape[0], 1), data[:, 3:]])
data[0]

tensor([-0.9942,  0.0777,  0.0749,  1.0000, -0.0303, -0.0044,  1.0001,  0.0133,
         0.0052,  1.0000,  0.0060,  0.0045,  1.0000, -0.0023, -0.0032])

In [237]:
classes = torch.tensor(
    [
        int(x)
        for x in anndata.read_h5ad("/teamspace/studios/this_studio/embedders/data/blood_cell_scrna/adata.h5ad").obs[
            "cell_type"
        ]
    ]
)
classes = classes[idx]  # Take it easy
classes.shape

torch.Size([10000])

In [238]:
# Initialize the appropriate product manifold, which we'll use for indexing

pm = embedders.manifolds.ProductManifold(signature=[(1, 2), (0, 2), (-1, 2), (-1, 2), (-1, 2)])

In [239]:
# What if we expressed everything in terms of angles in a 2-d projection to (timelike_dim, some_spatial_dim)?

timelike_dims = [pm.man2dim[i][0] for i in range(pm.n_manifolds)]
spacelike_dims = [pm.man2dim[i][1:] for i in range(pm.n_manifolds)]

for i in range(pm.n_manifolds):
    print(f"M{i}\tspecial {timelike_dims[i]}\t spacelike {spacelike_dims[i]}")

M0	special 0	 spacelike [1, 2]
M1	special 3	 spacelike [4]
M2	special 5	 spacelike [6, 7]
M3	special 8	 spacelike [9, 10]
M4	special 11	 spacelike [12, 13]


In [240]:
# First, we can compute the angles of all 2-d projections

angle_vals = torch.zeros(data.shape[0], pm.dim)

for i, M in enumerate(pm.P):
    dims = pm.man2dim[i]
    dims_target = pm.man2intrinsic[i]
    if M.type in ["H", "S"]:
        angle_vals[:, dims_target] = torch.atan2(data[:, dims[0]].view(-1, 1), data[:, dims[1:]])
    elif M.type == "E":
        angle_vals[:, dims_target] = torch.atan2(torch.tensor(1), data[:, dims])

angle_vals.shape  # Note that we have gone from (1000, 14) to (1000, 10), the number of intrinsic dimensions

torch.Size([10000, 10])

In [242]:
def circular_greater(angles, threshold):
    """
    Check if angles are in the half-circle (threshold, threshold + pi)
    """
    return ((angles - threshold + torch.pi) % (2 * torch.pi)) - torch.pi > 0


def calculate_info_gain(values, labels):
    batch_size, n_dim = values.shape
    n_classes = labels.max().item() + 1

    # Calculate total Gini impurity
    class_counts = torch.bincount(labels, minlength=n_classes).float()
    total_gini = 1 - ((class_counts / batch_size) ** 2).sum()

    # Initialize arrays for left and right counts
    left_counts = torch.zeros((batch_size, n_dim, n_classes), device=values.device)
    right_counts = torch.zeros((batch_size, n_dim, n_classes), device=values.device)

    # Calculate left and right counts for each potential split
    for i in range(batch_size):
        mask = circular_greater(values, values[i].unsqueeze(0))
        for j in range(n_dim):
            left_counts[i, j] = torch.nn.functional.one_hot(labels[~mask[:, j]], n_classes).sum(dim=0)
            right_counts[i, j] = torch.nn.functional.one_hot(labels[mask[:, j]], n_classes).sum(dim=0)

    # Calculate Gini impurities for left and right partitions
    left_gini = 1 - ((left_counts / left_counts.sum(dim=-1, keepdim=True).clamp(min=1)) ** 2).sum(dim=-1)
    right_gini = 1 - ((right_counts / right_counts.sum(dim=-1, keepdim=True).clamp(min=1)) ** 2).sum(dim=-1)

    # Calculate weighted Gini impurity
    weighted_gini = (left_counts.sum(dim=-1) * left_gini + right_counts.sum(dim=-1) * right_gini) / batch_size

    # Calculate information gain
    info_gain = total_gini - weighted_gini

    return info_gain


ig = calculate_info_gain(angle_vals, classes)

# What's the index?
best_idx = torch.argmax(ig)

KeyboardInterrupt: 

In [269]:
from hyperdt.torch.product_space_DT import ProductSpaceDT
from hyperdt.torch.tree import DecisionNode
from hyperdt.torch.hyperbolic_trig import _hyperbolic_midpoint


class TorchProductSpaceDT(ProductSpaceDT):
    def __init__(self, signature):
        sig_r = [(x[1], x[0]) for x in signature]
        super().__init__(signature=sig_r)
        self.pm = embedders.manifolds.ProductManifold(signature=signature)

    def _get_angle_vals(self, X):
        angle_vals = torch.zeros(data.shape[0], pm.dim)

        for i, M in enumerate(pm.P):
            dims = pm.man2dim[i]
            dims_target = pm.man2intrinsic[i]
            if M.type in ["H", "S"]:
                angle_vals[:, dims_target] = torch.atan2(data[:, dims[0]].view(-1, 1), data[:, dims[1:]])
            elif M.type == "E":
                angle_vals[:, dims_target] = torch.atan2(torch.tensor(1), data[:, dims])

        return angle_vals

    def _fit(self, X, y):
        """Fit a decision tree to the data. Modified from HyperbolicDecisionTreeClassifier
        to remove multiple timelike dimensions in product space."""
        # Find all dimensions in product space (including timelike dimensions)
        self.all_dims = list(range(sum([space[0] + 1 for space in self.signature])))

        # Find indices of timelike dimensions in product space
        self.timelike_dims = [0]
        for i in range(len(self.signature) - 1):
            self.timelike_dims.append(sum([space[0] + 1 for space in self.signature[: i + 1]]))

        # Remove timelike dimensions from list of dimensions
        # self.dims_ex_time = list(np.delete(np.array(self.all_dims), self.timelike_dims))
        self.dims_ex_time = [dim for dim in self.all_dims if dim not in self.timelike_dims]

        # Get array of classes
        self.classes_ = torch.unique(y)

        # Call recursive fitting function
        timelike_dims = [self.pm.man2dim[i][0] for i in range(pm.n_manifolds)]
        spacelike_dims = [self.pm.man2dim[i][1:] for i in range(pm.n_manifolds)]

        # First, we can compute the angles of all 2-d projections
        angle_vals = self._get_angle_vals(X)
        self.tree = self._fit_node(X=angle_vals, y=y, depth=0)

    def _fit_node(self, X, y, depth):
        print(f"Depth {depth} with {len(X)} samples")
        # Base case
        if depth == self.max_depth or len(X) < self.min_samples_split or len(torch.unique(y)) == 1:
            value, probs = self._leaf_values(y)
            return hyperdt.torch.tree.DecisionNode(value=value, probs=probs)

        # Recursively find the best split:
        ig = calculate_info_gain(X, y)
        best_idx = torch.argmax(ig)
        best_row, best_dim = best_idx // X.shape[1], best_idx % X.shape[1]
        best_ig = ig[best_row, best_dim]

        # Since we're evaluating greater than, we need to also find the next-largest value and take the midpoint
        next_largest = torch.max(X[~circular_greater(X[:, best_dim], X[best_row, best_dim]), best_dim])

        # Midpoint computation will depend on manifold; TODO: actually do this
        # best_theta = (X[best_row, best_dim] + next_largest) / 2
        best_manifold = self.pm.P[self.pm.intrinsic2man[best_dim]]
        if best_manifold.type == "H":
            best_theta = _hyperbolic_midpoint(X[best_row, best_dim], next_largest)
        elif best_manifold.type == "S":
            best_theta = (X[best_row, best_dim] + next_largest) / 2
        else:
            

        # Fallback case:
        if best_ig <= 0:
            print(f"Fallback triggered at depth {depth}")
            value, probs = self._leaf_values(y)
            return DecisionNode(value=value, probs=probs)

        # Populate:
        node = DecisionNode(feature=best_dim, theta=best_theta)
        node.score = best_ig
        left, right = circular_greater(X[:, best_dim], best_theta), ~circular_greater(X[:, best_dim], best_theta)
        node.left = self._fit_node(X=X[left], y=y[left], depth=depth + 1)
        node.right = self._fit_node(X=X[right], y=y[right], depth=depth + 1)
        return node

    def predict(self, X):
        angle_vals = self._get_angle_vals(X)
        return torch.Tensor([self._traverse(x).value for x in X])

    def _left(self, x, node):
        """Boolean: Go left?"""
        return circular_greater(x[node.feature], node.theta)

In [273]:
# Let's test it out

tpsdt = TorchProductSpaceDT(signature=[(1, 2), (0, 2), (-1, 2), (-1, 2), (-1, 2)])
tpsdt.fit(data, classes)

Depth 0 with 10000 samples
Depth 1 with 2246 samples
Depth 2 with 1130 samples
Depth 3 with 420 samples
tensor(3) tensor([0.0571, 0.1190, 0.0190, 0.2357, 0.1095, 0.0500, 0.0952, 0.0476, 0.0714,
        0.1952])
Depth 3 with 710 samples
tensor(4) tensor([0.0577, 0.0549, 0.0507, 0.1676, 0.2944, 0.0366, 0.0507, 0.0352, 0.0493,
        0.2028])
Depth 2 with 1116 samples
Depth 3 with 110 samples
tensor(2) tensor([0.0273, 0.0091, 0.6364, 0.0909, 0.1273, 0.0091, 0.0000, 0.0273, 0.0000,
        0.0727])
Depth 3 with 1006 samples
tensor(4) tensor([0.0358, 0.0139, 0.0765, 0.1889, 0.5318, 0.0189, 0.0149, 0.0169, 0.0119,
        0.0905])
Depth 1 with 7754 samples
Depth 2 with 2953 samples
Depth 3 with 1554 samples
tensor(9) tensor([0.1094, 0.1023, 0.0225, 0.1351, 0.0611, 0.1004, 0.0920, 0.1042, 0.1075,
        0.1654])
Depth 3 with 1399 samples
tensor(3) tensor([0.1165, 0.1551, 0.0122, 0.1565, 0.0279, 0.1072, 0.1565, 0.0443, 0.1408,
        0.0829])
Depth 2 with 4801 samples
Depth 3 with 1599 samp

In [274]:
tpsdt.score(data, classes).sum() / data.shape[0]

tensor(0.2435)