# Applying Trajectory Inference Pipeline

### Packages you need

In [None]:
import os, sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
plt.style.use('default')

from plotly import graph_objects as go
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LinearRegression

### Adjust your I/O configuration

In [None]:
# Project root, taken as the parent of the current working directory.
# Adjust the path to your project root directory if needed.
root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Put the project root at the front of the import path (only once) so that
# project-local modules can be imported from this notebook.
if root not in sys.path:
    sys.path.insert(0, root)

# Directory holding the input datasets; the bare expression on the last line
# echoes the path as the cell output.
data_io = os.path.join(root, "Trajectory_inference", "datasets")
data_io

### Introduction

Our dataset is derived from **single-cell immune profiling**, where **14 protein markers** were experimentally measured for individual immune cells. Using a **Variational Autoencoder (VAE)**, these high-dimensional measurements are compressed into a **3D latent state space** that captures the underlying structure of the immune system.

Guided by **biological domain knowledge**, we hypothesize that one latent dimension — or a derived indicator such as **GFP** — can serve as a meaningful proxy for the **dynamical progression of immune cells**. Our objective is to perform **trajectory inference** within this latent space to reveal the **hidden dynamical patterns and developmental pathways** that govern immune cell differentiation and state transitions.

### Load the dataset

In [None]:
# Load the table of cells; later cells read the X, Y, Z, GFP and cell_type columns from it.
data = pd.read_csv(os.path.join(data_io, "latent_data_with_GFP.csv"))  # data loading

# Reproducible random subsample. The requested size is capped at the number of
# available rows, because DataFrame.sample raises a ValueError when n exceeds
# the population size and sampling is done without replacement.
data_sampled = data.sample(n=min(50000, len(data)), random_state=42)

In [None]:
# Show how many cells fall into each annotated cell type.
data["cell_type"].value_counts()


### Estimate the potential function

In this step, we try to describe how cells "move" within the 3D latent space. Instead of observing motion directly, we estimate an **energy-like potential function** that explains where cells are more likely to appear and how they might flow between states.

Intuitively, you can think of this potential as a **landscape**: cells tend to "roll down" from high-potential regions (early states) toward low-potential regions (mature or stable states).  

By estimating this potential function, we can visualize and analyze the **direction and strength of cell state transitions**, helping us understand the hidden dynamics of the immune system in a simple, physical way.

To estimate the potential function $ P(x) $ from observed single-cell data, we can treat it as a **regression problem** that maps cell coordinates $ x \in \mathbb{R}^3 $ to a scalar GFP value $ y \in [0,1] $:

$$
y_i \;=\; P(x_i) \;+\; \varepsilon_i, \qquad i = 1, \dots, N,
$$

where  
- $x_i = (X_i, Y_i, Z_i)$ is the latent coordinate of the $i$-th cell,  
- $y_i$ is its observed or estimated GFP level,  
- $\varepsilon_i$ is a noise term capturing measurement or model uncertainty.

Our goal is to learn a smooth function $ \hat{P}(x) $ that minimizes the prediction error:

$$
\hat{P} \;=\; \arg\min_{f \in \mathcal{F}} 
\sum_{i=1}^{N} \big( y_i - f(x_i) \big)^2 ,
$$

where $ \mathcal{F} $ can be any family of differentiable regression models, such as polynomial regression, kernel ridge regression, Gaussian process regression, or neural networks.

### Calculate the gradient-based vector field

Once we have estimated the smooth potential function $\hat{P}(x)$, we can derive a **vector field** to describe how cells move within the 3D latent space.

We interpret $\hat{P}(x)$ as an energy-like potential landscape. Cells tend to move “downhill” from regions of high potential (early or unstable states) to regions of low potential (mature or stable states). The corresponding velocity field is therefore defined by the negative gradient:

$$
v(x) = -\nabla \hat{P}(x)
$$

To stabilize and normalize the flow magnitude, we apply a smooth scaling factor:

$$
v(x) = -\frac{\nabla \hat{P}(x)}{\lambda + \| \nabla \hat{P}(x) \| + \varepsilon}
$$

where  
- $\lambda > 0$ controls the decay rate (flow scale),  
- $\varepsilon > 0$ prevents division by zero,  
- and the direction of $v(x)$ always points toward decreasing potential.

This gradient-based vector field $v(x)$ encodes the **direction and speed** of cell state transitions.  
It allows us to visualize the inferred dynamics, identify attractors (low-potential basins), and trace developmental trajectories in the latent space.


#### Here is the pre-defined pipeline

In [None]:
class SklearnPolynomialScalarRegressor:
    """
    Polynomial least-squares model of a scalar field g(z), usable from numpy and torch.

    - Fit:    PolynomialFeatures expands the inputs, LinearRegression fits
              g(z) ≈ w^T * Poly(z) + b on the expanded features.
    - Use:    g_np = reg(X)                   # numpy, shape (N,)
              g_torch = reg.torch_forward(z)  # torch tensor, shape (N,)

    After fitting, the exponent table, the coefficients and the intercept are
    cached on the instance so the same polynomial can be re-evaluated with
    plain torch operations (and therefore differentiated with autograd).
    """

    def __init__(self, degree: int = 2, include_bias: bool = True):
        """Store the polynomial settings; no estimator is created until fit()."""
        self.degree = int(degree)
        self.include_bias = bool(include_bias)

        # sklearn estimators, created inside fit()
        self.poly = None          # type: PolynomialFeatures | None
        self.lin = None           # type: LinearRegression | None

        # Fitted parameters reused by torch_forward()
        self.powers_ = None       # (K, d) exponent of each input per feature
        self.coef_ = None         # (K,) weight of each feature
        self.intercept_ = 0.0
        self.is_fitted = False

    def fit(self, X: np.ndarray, y: np.ndarray):
        """
        Fit the polynomial regression on inputs X and scalar targets y.

        X is converted to float64; y is converted to float64 and flattened.
        Returns self so calls can be chained.
        """
        features = np.asarray(X, dtype=np.float64)
        targets = np.asarray(y, dtype=np.float64).reshape(-1)

        self.poly = PolynomialFeatures(
            degree=self.degree,
            include_bias=self.include_bias,
        )
        expanded = self.poly.fit_transform(features)

        self.lin = LinearRegression()
        self.lin.fit(expanded, targets)

        # Keep a copy of the fitted polynomial for the torch evaluation path
        self.powers_ = self.poly.powers_.astype(np.int64)   # (K, d)
        self.coef_ = self.lin.coef_.astype(np.float64)      # (K,)
        self.intercept_ = float(self.lin.intercept_)
        self.is_fitted = True
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Evaluate the fitted model with sklearn and return a flat numpy array.

        A 1-D input is treated as a single sample.
        Raises RuntimeError when called before fit().
        """
        if not self.is_fitted:
            raise RuntimeError("Call fit() before predict().")
        samples = np.asarray(X, dtype=np.float64)
        if samples.ndim == 1:
            samples = samples.reshape(1, -1)
        # Go through the sklearn transform + linear model rather than the cached copy
        return self.lin.predict(self.poly.transform(samples)).reshape(-1)

    def __call__(self, X: np.ndarray) -> np.ndarray:
        """Alias for predict()."""
        return self.predict(X)

    def torch_forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Differentiable evaluation of the cached polynomial:

            g(z) = sum_k coef[k] * prod_j z_j ** powers[k, j] + intercept

        z: (N, d) tensor (a 1-D tensor is treated as a single sample)
        return: (N,) tensor with the dtype and device of z

        Raises RuntimeError when the model has not been fitted.
        """
        if not self.is_fitted:
            raise RuntimeError("Model not fitted.")
        batch = z.unsqueeze(0) if z.ndim == 1 else z

        # Bring the cached parameters onto the dtype/device of the input
        like = dict(dtype=batch.dtype, device=batch.device)
        exponents = torch.as_tensor(self.powers_, **like)    # (K, d)
        weights = torch.as_tensor(self.coef_, **like)        # (K,)
        bias = torch.as_tensor(self.intercept_, **like)

        # (N, 1, d) ** (1, K, d) -> (N, K, d); product over inputs -> (N, K)
        monomials = (batch[:, None, :] ** exponents[None, :, :]).prod(dim=-1)

        # Weighted sum over the K features -> (N,)
        return (monomials * weights[None, :]).sum(dim=1) + bias


class PotentialFieldFromScalar:
    """
    Build a vector field from the gradient of a fitted scalar regressor g(z).

    Default (GFP-inspired) mapping:

        base(z) = -∇g(z) / (lambda * g(z) + eps)

    With normalize=True the result is rescaled by its squared norm:

        v(z) = base(z) / ( ||base(z)||^2 + eps )

    With normalize=False:

        v(z) = base(z)

    The gradient is obtained with torch.autograd from
    scalar_regressor.torch_forward, evaluated in float32 on `device`.
    """

    def __init__(
        self,
        scalar_regressor: SklearnPolynomialScalarRegressor,
        decay_rate: float = np.log(2) / 3.0,
        eps: float = 1e-8,
        normalize: bool = True,
        device: str = "cpu",
    ):
        """Store the wrapped regressor and the field hyper-parameters."""
        self.scalar = scalar_regressor
        self.lambda_decay = float(decay_rate)   # lambda in the formula above
        self.eps = float(eps)                   # added to both denominators
        self.normalize = bool(normalize)
        self.device = torch.device(device)

    def fit(self, X: np.ndarray, V: np.ndarray | None = None):
        """
        No-op kept for API compatibility; returns self.
        The scalar regressor is expected to be fitted separately before wrapping.
        """
        return self

    def __call__(self, X: np.ndarray) -> np.ndarray:
        """
        Evaluate v(z) on a batch of points.

            X: (N, d) numpy array (a 1-D array is treated as a single point)
            return: (N, d) numpy array

        Raises RuntimeError when the wrapped regressor has not been fitted.
        """
        if not self.scalar.is_fitted:
            raise RuntimeError("Scalar regressor must be fitted before using PotentialFieldFromScalar.")

        coords = np.asarray(X, dtype=np.float32)
        if coords.ndim == 1:
            coords = coords.reshape(1, -1)

        # Leaf tensor that autograd differentiates with respect to
        z = torch.tensor(coords, dtype=torch.float32, device=self.device, requires_grad=True)

        # g(z): (N,)
        potential = self.scalar.torch_forward(z)

        # ∇g(z): (N, d). A vector of ones as grad_outputs yields the per-sample
        # gradients in a single backward pass.
        (grad,) = torch.autograd.grad(
            outputs=potential,
            inputs=z,
            grad_outputs=torch.ones_like(potential),
            create_graph=False,
            retain_graph=False,
            only_inputs=True,
        )

        # base(z) = -∇g / (lambda * g + eps)
        scale = self.lambda_decay * potential.unsqueeze(1) + self.eps
        field = -grad / scale

        if self.normalize:
            # v(z) = base / (||base||^2 + eps)
            field = field / ((field ** 2).sum(dim=1, keepdim=True) + self.eps)

        # Convert to numpy - use np.array() to avoid torch-numpy ABI mismatch
        return np.array(field.detach().cpu())


#### Apply the pre-defined pipeline

In [None]:
# Fit a degree-7 polynomial regression of GFP on the (X, Y, Z) coordinates of
# every row in `data`; fit() returns the regressor itself.
approximate_with_poly = SklearnPolynomialScalarRegressor(degree=7)
approximate_with_poly.fit(
    data[["X", "Y", "Z"]].to_numpy(),
    data["GFP"].to_numpy(),
)

# Wrap the fitted scalar model as a gradient-based vector field (default settings).
vec_fn = PotentialFieldFromScalar(scalar_regressor=approximate_with_poly)

### Plotting the results

In [None]:
# Number of cells on which the vector field is evaluated (capped by the dataset size).
N_arrows = min(80000, len(data))

# label encoding for cell types
label_encoder = LabelEncoder()
data['cell_type_encoded'] = label_encoder.fit_transform(data['cell_type'])

# Draw the subsample once and slice the column sets from it, so the three
# arrays are row-aligned by construction instead of relying on three separate
# sample() calls that share a seed.
sampled_cells = data.sample(n=N_arrows, random_state=0)
points_with_gfp = sampled_cells[["X", "Y", "Z", "GFP"]].to_numpy()
points_with_cluster = sampled_cells[["X", "Y", "Z", "cell_type_encoded"]].to_numpy()
points = sampled_cells[["X", "Y", "Z"]].to_numpy()

V = vec_fn(points)  # Expected shape: (N_arrows, 3)

v_norm = np.linalg.norm(V, axis=1)
# Clip extreme vectors to avoid a “spiky” appearance:
# vectors with 0 < norm < 95th percentile are kept unchanged, all others are
# rescaled to the threshold length (zero vectors stay zero).
clip_thr = np.percentile(v_norm, 95)
mask = v_norm > 0
V_clipped = np.where(
    (mask & (v_norm < clip_thr))[:, None],
    V,
    V * (clip_thr / (v_norm[:, None] + 1e-9))
)
v_norm = np.linalg.norm(V_clipped, axis=1) + 1e-9  # avoid div by zero

# Normalize directions and scale arrow length (every arrow is drawn with the
# same length; the magnitude is carried by `v_norm` / `colors` instead)
V_dir = V_clipped / v_norm[:, None]
length_scale = 0.2
V_plot = V_dir * length_scale

# Color arrows by speed magnitude, min-max scaled to [0, 1]
normed = (v_norm - v_norm.min()) / (v_norm.max() - v_norm.min() + 1e-12)
colors = plt.cm.viridis(normed)

In [None]:
# with cell type clustering
# Using V_plot and points from previous code cell
data_sampled = pd.DataFrame({
    'X': points[:, 0],
    'Y': points[:, 1],
    'Z': points[:, 2],
    'Vx': V_plot[:, 0],
    'Vy': V_plot[:, 1],
    'Vz': V_plot[:, 2],
    'speed': np.linalg.norm(V_clipped, axis=1),
    'cell_type': points_with_cluster[:, 3]
})

# Inverse transform to original cell type labels
# (the encoded labels were stored in a float array, hence the cast back to int)
data_sampled['cell_type'] = label_encoder.inverse_transform(data_sampled['cell_type'].astype(int).to_numpy())

# no FM B Cells go to MZ
# plot quiver plot using plotly
fig = px.scatter_3d(
    data_sampled,
    x='X',
    y='Y',
    z='Z',
    color='cell_type',
    color_discrete_map={
        'T1': 'red',
        'T2': 'orange',
        'FM': 'green',
        'MZP': 'blue'
    },
    title='3D Velocity Field Derived from Polynomial Potential',
    range_x=[-0.5, 0.5],
    range_y=[-0.5, 0.5],
    range_z=[-0.5, 0.5],
    opacity=0.1,
    size='speed',
    size_max=10
)
# NOTE(review): this sets a fixed marker size on the scatter traces, which
# presumably overrides the size='speed' mapping requested above — confirm
# which of the two is intended.
fig.update_traces(marker=dict(size=2))

# Draw cones for at most `max_arrows` cells to keep the figure readable
max_arrows = 1000
if len(data_sampled) > max_arrows:
    data_arrow = data_sampled.sample(max_arrows, random_state=0)
else:
    data_arrow = data_sampled


# NOTE(review): cmin/cmax are taken from 'speed' (clipped magnitudes), while
# the cone vectors (Vx, Vy, Vz) were rescaled to a common length upstream; if
# the cone colour follows the (u, v, w) norm, the colour range does not match
# the plotted vectors — confirm against the plotly Cone documentation.
fig.add_trace(
    go.Cone(
        x=data_arrow['X'],
        y=data_arrow['Y'],
        z=data_arrow['Z'],
        u=data_arrow['Vx'],
        v=data_arrow['Vy'],
        w=data_arrow['Vz'],
        anchor="tail",
        sizemode="absolute",
        sizeref=0.5,
        colorscale="blues",
        cmin=data_sampled['speed'].min(),
        cmax=data_sampled['speed'].max(),
        showscale=False
    )
)

fig.update_layout(
    scene=dict(
        aspectmode="data"
    )
)

fig.show()

# save fig
# Create the output directory first so saving does not fail on a fresh
# checkout where the results folder does not exist yet.
results_dir = os.path.join(root, "Trajectory_inference", "results")
os.makedirs(results_dir, exist_ok=True)
fig.write_html(os.path.join(results_dir, "gfp_potential_trajectory.html"))