In [1]:
#|default_exp geodesics

%load_ext autoreload
%autoreload 2

# Geodesics
> A JAX Implementation

This is a literate-programming adaptation of Edward de Brouwer's `fm_geodesics` library, programmed in jax for speed. 

# Implementation

## Jax Training Utilities

A nice feature of jax flax: it infers the input dimensions on initialization. They needn't be hard coded.

In [2]:
from flax import linen as nn 

class MLP(nn.Module):
  """
  A general MLP in jax
  """
  num_hidden_layers: int
  hidden_dim: int
  output_dim: int
  
  def setup(self):
    if self.num_hidden_layers == 0:
       self.layers = [nn.Dense(self.output_dim)]
    else:
      input_layer = [nn.Dense(self.hidden_dim)]
      hidden_layers = [nn.Dense(self.hidden_dim) for _ in range(self.num_hidden_layers -1)]
      output_layer = [nn.Dense(self.output_dim)]
      self.layers = input_layer + hidden_layers + output_layer

  def __call__(self,x):
      for i, layer in enumerate(self.layers):
        x = layer(x)
        if i < len(self.layers) - 1: # no activation for the last layer
          x = nn.relu(x)
      return x

Logging utils

In [3]:
import logging
import wandb

from pytorch_lightning.utilities import rank_zero_only

def compute_and_log_metrics(state, metrics_history, step_num, prefix = "train_", logger = True, commit = True):
    """
    Compute metrics and log them with wandb

    commit: wether this will be the last logged value for this step or not.
    """
    metric_dict = state.metrics.compute().items()
    for i_, (metric,value) in enumerate(metric_dict): # compute metrics
        metrics_history[prefix + metric].append(value) # record metrics

        # Log Metrics to Weights & Biases
        if logger:
            if commit:
                if i_ == len(metric_dict)-1:  
                    wandb.log({prefix + metric:value.item()}, step = step_num, commit = True)
                    continue
            wandb.log({prefix + metric:value.item()}, step = step_num, commit = False)
    
    return metrics_history


def get_pylogger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python command line logger."""

    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger

In [4]:
import numpy as np
import jax.numpy as jnp

def mse_geodesic(x0,x1, t, ground_truth, state = None, pred = None):
    """
    ground_truth: function that returns the ground truth geodesic 

    Either state of pred have to be provided
    state: state for the model
    pred : prediction of the geodesic for the model
    """

    if pred is None:
        pred = state.apply_fn({'params': state.params}, x0, x1, t)
    oracle = jnp.stack([ground_truth(x0[i],x1[i],t)[1] for i in range(len(x0))],1)

    preds = []
    oracles = []
    for i in range(len(x0)):
        oracle_interp = jnp.interp(jnp.linspace(x0[i,0],x1[i,0]), oracle[:,i,0], oracle[:,i,1] )
        out_interp = jnp.interp(jnp.linspace(x0[i,0],x1[i,0]), pred[:,i,0], pred[:,i,1] ) 
        preds.append(out_interp)
        oracles.append(oracle_interp)
    preds = jnp.stack(preds,0)
    oracles = jnp.stack(oracles,0)
    return jnp.power(preds-oracles,2).mean()

A Pytorch Lightning like trainer

In [5]:

import tqdm
import jax.numpy as jnp
import jax
import wandb

from jax import grad, jit, vmap

from functools import partial

class JaxTrainer:
    """
    Functions to implement:
    - train_step

    Attributes to give:
    - metrics_names : the list of names of the different metrics that are recorded.
    - max_epochs : the number of epochs
    """
    def __init__(self, datamodule):
        self.train_dl = datamodule.train_dataloader()
        self.val_dl = datamodule.val_dataloader()
        self.test_dl = datamodule.test_dataloader()

    def fit(self):
        step = 0
        self.metrics_history = { **{"train_" + m : [] for m in self.metrics_names}, **{"val_" + m : [] for m in self.metrics_names}, **{"test_" + m : [] for m in self.metrics_names}} 
        
        for epoch in range(self.max_epochs):
            print(f"Epoch {epoch}")
            self.epoch = epoch

            for train_step,batch in tqdm.tqdm(enumerate(self.train_dl)):
                self.train_step(step, batch)
                self.train_step_end(step)
                step+=1
            self.train_epoch_end(step)

            self.val_state = self.state
            for val_step, val_batch in tqdm.tqdm(enumerate(self.val_dl)):
                self.val_step(step, val_batch)
                self.val_step_end(step)
            self.val_epoch_end(step)
            
            if self.logger:
                wandb.log({"epoch": epoch}, step = step)
            step = step + 1

            for callbacks in self.callbacks:
                callbacks.on_epoch_end(self)
            
        self.test_state = self.state
        for test_step, test_batch in tqdm.tqdm(enumerate(self.test_dl)):
            self.test_step(step, test_batch)
            self.test_step_end(step)
        self.test_epoch_end(step)

        for callbacks in self.callbacks:
            callbacks.on_training_end(self)
        
        if self.logger:
            wandb.log({"epoch": epoch + 1}, step = step)
    
    def train_step_end(self,step):

        if ((step % self.log_every_n_steps) == 0):
            self.metrics_history = compute_and_log_metrics(self.state, metrics_history= self.metrics_history, step_num = step, prefix = "train_", logger = self.logger, commit = True)

    def train_epoch_end(self,step):
        self.metrics_history = compute_and_log_metrics(self.state, metrics_history= self.metrics_history, step_num = step, prefix = "train_", logger = self.logger, commit = False)
        self.state = self.state.replace(metrics=self.state.metrics.empty()) # reset train_metrics for next training epoch

    def val_step_end(self,step):
        return
    
    def val_epoch_end(self,step):
        metrics_history = compute_and_log_metrics(self.val_state, metrics_history= self.metrics_history, step_num = step , prefix = "val_", logger = self.logger, commit = False)
        self.val_state = self.val_state.replace(metrics=self.val_state.metrics.empty()) # reset metrics

    def test_step_end(self,step):
        return
    
    def test_epoch_end(self,step):
        self.metrics_history = compute_and_log_metrics(self.test_state, metrics_history= self.metrics_history, step_num = step , prefix = "test_", logger = self.logger, commit = False)





A trainable curve, which takes a linear interpolation between the input points $x_0$ and $x_1$, and uses the MLP `mod_x0_x1` to send a vector containing a stacked representation of `[x0, x1, t]` to a point in latent space, for each time point. These points computed from teh stacked representation are then scaled by a window function $s(q - (2t - 1)^{2e})$ and added to the linear interpolation. The window guarantees that the endpoints aren't perturbed, by scaling the addition in this part of the time window to 0.

In [6]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

from clu import metrics 
from flax import struct  
from flax.training import train_state
import optax

from functools import partial

from flax import linen as nn 
import numpy as np

  

class CondCurve(nn.Module):
  """
  Conditional Grounded Curve
  """
  input_dim: int
  hidden_dim: int
  scale_factor: int
  symmetric: bool
  num_layers: int
  envelope_exponent: int = 1
  exponent_as_parameter: bool = False

  def setup(self):
    self.mod_x0_x1 = MLP( hidden_dim = self.hidden_dim, output_dim = self.input_dim, num_hidden_layers = self.num_layers)
      
    self.x0_emb = MLP( hidden_dim=self.hidden_dim, output_dim = self.hidden_dim, num_hidden_layers = self.num_layers)
    self.x1_emb = MLP( hidden_dim=self.hidden_dim, output_dim = self.hidden_dim, num_hidden_layers = self.num_layers)

    if self.exponent_as_parameter:
      exponent_init = nn.initializers.constant(self.envelope_exponent)
      self.var_envelope_exponent = self.param('var_envelope_exponent',
                          exponent_init, # Initialization function
                          (1,))  # shape info.
  
  def x0_emb_(self,x):
    ### Just for debugging
    return self.x0_emb(x)
  
  def x1_emb_(self,x):
    ### Just for debugging
    return self.x1_emb(x)
  
  def mod_x0_x1_(self,x):
    ### Just for debugging
    return self.mod_x0_x1(x)
  
  def __call__(self,x0, x1, t):

    x0_ = jnp.tile(x0, (t.shape[0],1))
    x1_ = jnp.tile(x1, (t.shape[0],1))
    t_ = jnp.repeat(t, x0.shape[0])[:,None]
  
    emb_x0 = self.x0_emb(x0_)
    emb_x1 = self.x1_emb(x1_)

    avg = t_ * x1_ + (1-t_) * x0_

    if self.exponent_as_parameter:
      env_exp = self.var_envelope_exponent
    else:
      env_exp = self.envelope_exponent
    
    envelope = self.scale_factor * (1- (t_*2-1)** (2 * env_exp)) 
    aug_state = jnp.concatenate([emb_x0, emb_x1,t_], axis = -1)
    outs =  self.mod_x0_x1(aug_state) * envelope + avg 

    return outs.reshape(t.shape[0],x0.shape[0], self.input_dim)


@struct.dataclass
class Metrics(metrics.Collection):
  #accuracy: metrics.Accuracy
  loss: metrics.Average.from_output('loss') # overall loss
  true_length: metrics.Average.from_output('true_length') # keeps track of the true length of each geodesic
  mse_geodesic: metrics.Average.from_output('mse_geodesic') # compute MSE between true and predicted geodesics
  loss_geo: metrics.Average.from_output('loss_geo') # loss for the geodesic network
  loss_density: metrics.Average.from_output('loss_density') # loss for the density network
  

class TrainState(train_state.TrainState):
  metrics: Metrics

def create_geodesic_train_state(module, rng, learning_rate):
  """Creates an initial `TrainState`."""
  params = module.init(rng, jnp.ones([1, module.input_dim]), jnp.ones([1, module.input_dim]), jnp.ones([3]))['params'] # initialize parameters by passing a template image
  tx = optax.adam(learning_rate)
  return TrainState.create(
      apply_fn=module.apply, params=params, tx=tx,
      metrics=Metrics.empty())

def jacdiag(f ):
    def _jacdiag(theta, x0,x1, t):
        def partial_grad_f_index(i):
            return jax.jacrev(f,argnums = 3)(theta, x0,x1,t[i][None])[0,...,0]
        return jax.vmap(partial_grad_f_index)(jax.numpy.arange(t.shape[0]))
    return _jacdiag

Geodesic Trainer, which takes a datamodule, parameterizable curve, and metric function.

It also takes an 'oracle', which returns the true lengths of geodesics between points in the given space - but this is used solely for validation.

In [7]:
import jax.numpy as jnp
import jax

from jax import grad, jit, vmap

from functools import partial


class GeodesicTrainer(JaxTrainer):

    def __init__(self, 
                 datamodule, # a class with functions train_dataloader, val_dataloader, test_dataloader which return those. Each batch should return two points, x_o, and x_1
                 cond_curve, # the curve to be learned
                 metric, # a function that, given (n, d) input, returns the metric for each of n points.
                 oracle, # 
                 max_epochs, 
                 seed, 
                 lr,
                 log_every_n_steps,
                 n_interp_times, 
                 density_lambda, k_density, logger, callbacks,  **kwargs):

        super().__init__(datamodule)
        self.model = cond_curve

        self.geo_metric = metric
        self.geo_oracle = oracle
        
        self.max_epochs = max_epochs

        self.density_lambda = density_lambda
        self.k_density = k_density

        if self.density_lambda != 0:
            x_train = datamodule.ds.tensors[0][datamodule.train_idx].double().numpy().copy() 
            self.X_train = jnp.array(x_train)
        else:
            self.X_train = None

        self.t = jnp.linspace(0,1, n_interp_times)

        init_rng = jax.random.key(seed)

        self.state = create_geodesic_train_state(self.model, init_rng, lr)

        self.log_every_n_steps = log_every_n_steps

        self.N_train = len(self.train_dl)

        self.logger = logger # wether to log (wandb ) or not.

        self.callbacks = callbacks
        self.metrics_names = [m for m in list(self.state.metrics.__dict__.keys()) if '_reduction_counter' not in m]

    def val_step(self,step, batch):
        x0 = batch[0].numpy() #x0_fixed
        x1 = batch[1].numpy() #x1_fixed
        val_batch = x0, x1, self.t
        

        self.val_state = self.compute_metrics(state=self.val_state, metric = self.geo_metric, oracle = self.geo_oracle, batch= val_batch, val = True)

    def train_step(self, step, batch):

        x0 = batch[0].numpy() #x0_fixed
        x1 = batch[1].numpy() #x1_fixed
        batch = x0, x1, self.t
        
        
        self.state = self.train_step_(self.state, self.geo_metric, batch) # get updated train state (which contains the updated parameters)
        self.state = self.compute_metrics(state=self.state, metric = self.geo_metric, oracle = self.geo_oracle, batch=batch) # aggregate batch metrics
    
    def test_step(self, step, batch):
    
        x0 = batch[0].numpy() #x0_fixed
        x1 = batch[1].numpy() #x1_fixed
        test_batch = x0, x1, self.t
        self.test_state = self.compute_metrics(state=self.test_state, metric = self.geo_metric, oracle = self.geo_oracle, batch= test_batch, val = True)

    def density_loss_fn(self, out):
        
        key = jax.random.key(421)
        out_ = out[1:-1].reshape(-1, out.shape[-1]) # remove the first and last points which are fixed.
        #breakpoint()
        #X_train = jax.random.normal(key,shape= (3000,out_.shape[-1]))
        #X_train = np.random.randn(3000, out_.shape[-1])
        
        def euclidean_norm(a,b):
            return jnp.sqrt(jnp.sum((a-b)**2, axis = -1))

        mv = vmap(euclidean_norm, in_axes = (0,None))   
        mm = vmap(mv, in_axes = (None,0), out_axes = 1)     
        #cdist = vmap(vmap(euclidean_norm, in_axes=(None,0)), in_axes=(0,None))
        dists = mm(out_, self.X_train)

        val, idx = jax.lax.top_k(-dists, self.k_density) 
        return -jnp.mean(val)
        dists = -jnp.transpose(dists,(1,2,0))
        val, idx = jax.lax.top_k(dists, self.k_density)

        return -jnp.mean(val)

    def loss_fn(self,state, params, metric, x0, x1, t):
        """
        Loss function for the geodesic network
        """

        jac = jacdiag(state.apply_fn)({'params':params},x0,x1,t)

        out = state.apply_fn({'params':params}, x0, x1, t)
        mu = metric(out)

        pre_prod = jnp.einsum('tbd,tbdj->tbj',jac,mu)
        prod = jnp.einsum('tbd,tbd->tb', pre_prod, jac)

        loss_geo = jnp.sqrt(prod).mean(0).sum()
        
        loss_geo = jnp.nan_to_num(loss_geo)

        if self.density_lambda != 0:
            loss_density =  self.density_loss_fn(out)
        else:
            loss_density = 0

        loss = loss_geo + self.density_lambda * loss_density

        return loss, (loss_geo, loss_density), out
    
    # @partial(jit, static_argnums=(0,2))
    def train_step_(self,state, metric, batch):
        """Train for a single step."""
        x0, x1, t = batch
        
        def loss_fn_(params):

            loss, _, out = self.loss_fn(state, params, metric, x0, x1, t)
            #if self.density_lambda != 0:
            #    loss_density =  self.density_loss_fn(out, X_train)
            #else:
            #    loss_density = 0
            
            #loss = loss_geo + loss_density
            return loss

        grad_fn = jax.grad(loss_fn_)
        grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        return state
    
    # @partial(jit, static_argnums=(0,2,3,5,))
    def compute_metrics(self,state, metric, oracle, batch, val = False):
        """
        val : if true, computes some extra metrics which are only computed in val step.
        """

        x0, x1, t = batch

        loss, (loss_geo, loss_density), preds = self.loss_fn(state, state.params, metric, x0, x1, t)

        geo_length = oracle.geo_length(x0,x1).sum() #true length

        if val:
            mse_geodesic_ = oracle.mse_geodesic(x0,x1,t, preds = preds)
        else:
            mse_geodesic_ = 0

        metric_updates = state.metrics.single_from_model_output(loss=loss, true_length = geo_length, mse_geodesic = mse_geodesic_, loss_geo = loss_geo, loss_density = loss_density)
        metrics = state.metrics.merge(metric_updates)
        state = state.replace(metrics=metrics)
        return state

## Data Utils

In [8]:
import lightning as pl
from torch.utils.data import Dataset
import torch

class pairwise_geodesics_dataset(Dataset):
    def __init__(self, X, Y) -> None:
        self.X = X
        self.Y = Y
    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, index):
        target = self.Y[index, :]
        sample = self.X[index, :]
        return sample, target

class GeodesicDataModule(pl.LightningDataModule):
    def __init__(self, X, Y, batch_size=32, split = [0.8,0.1,0.1]):
        self.X = X
        self.Y = Y
        self.n_obs = X.shape[0]
        self.dim = X.shape[1]
        self.batch_size = batch_size
        self.split = split
        self.prepare_data()
        super().__init__()

    def prepare_data(self) -> None:
        last_train_idx = int(len(self.X) * self.split[0])
        last_val_idx = int(len(self.X) * (self.split[0] + self.split[1]))
        self.train_data = pairwise_geodesics_dataset(self.X[:last_train_idx], self.Y[:last_train_idx])
        self.val_data = pairwise_geodesics_dataset(self.X[last_train_idx:last_val_idx], self.Y[last_train_idx:last_val_idx])
        self.test_data = pairwise_geodesics_dataset(self.X[-last_val_idx:], self.Y[-last_val_idx:])
        
    def setup(self, stage: str):
        return
        #self.train_loader = train_dataloader(self.name, self.n_obs, self.dim, self.emb_dim, self.batch_size, self.knn, self.PATH, self.indx)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(dataset = self.train_data, batch_size = self.batch_size, shuffle = True,
                                           num_workers = 0, pin_memory = True)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(dataset = self.val_data, batch_size = self.batch_size, shuffle = True,
                                           num_workers = 0, pin_memory = True)
    def test_dataloader(self):
        return torch.utils.data.DataLoader(dataset = self.test_data, batch_size = self.batch_size, shuffle = True,
                                           num_workers = 0, pin_memory = True)


In [9]:
#|export
class DummyOracle:
    def __init__(self):
        return
    
    def geo_length(self,x0,x1):
        """
        Computes the length between x0 and x1 on the sphere
        """
        return jnp.ones(len(x0))
    
    def mse_geodesic(self, x0,x1,t, preds):
        return 1

In [16]:
#|export
def wrap_torch_metric(x, metric):
    print(x)
    print(x.shape)
    x_np = np.array(x)
    x = torch.from_numpy(x)
    return jax.lax.stop_gradient(metric(x).detach().numpy())

class GeodesicQuicktrainer(GeodesicTrainer):
    def __init__(self, 
                 X:np.ndarray, # data coordinates, in ambient dimension
                 intrinsic_dim:int, # dimension (usually intrinsic dimension of manifold)
                 metric_fn, # a function that, given (n, d) input, returns the metric for each of n points.
                 max_epochs=1000,
                 batch_size = 32,
                 seed = 42,
                 layers_in_curve = 3,
                 hidden_dimension=16,
                 use_autometric_metric = True,
                 **kwargs
                 ):
        #split X in two
        X = np.array(X)
        np.random.seed(seed)
        idx = np.random.permutation(len(X))
        X = X[idx]
        X1 = X[:len(X)//2]
        X2 = X[len(X)//2:]
        datamodule = GeodesicDataModule(X1, X2, batch_size = batch_size)
        cond_curve = CondCurve(
            input_dim = intrinsic_dim,
            hidden_dim = hidden_dimension,
            scale_factor = 5,
            symmetric = False,
            num_layers=layers_in_curve,
        )
        oracle = DummyOracle()
        if use_autometric_metric:
            metric_fn = partial(wrap_torch_metric, metric = metric_fn)

        super().__init__(
            datamodule, # a class with functions train_dataloader, val_dataloader, test_dataloader which return those. Each batch should return two points, x_o, and x_1
            cond_curve, # the curve to be learned
            metric_fn, # a function that, given (n, d) input, returns the metric for each of n points.
            oracle, # 
            max_epochs = max_epochs,
            lr = 0.001,
            density_lambda=0,
            seed = 421,
            log_every_n_steps=10,
            n_interp_times=20,
            k_density=5, # default parameters in EB's hydra configs
            logger = False,
            callbacks = [],
            )
    

# Example Usage of Quicktrainer

In [17]:
from autometric.datasets import *

In [18]:
S = Sphere(2000)

In [19]:
GT = GeodesicQuicktrainer(
    X = S.intrinsic_coords,
    intrinsic_dim = 2,
    metric_fn = S.metric.metric_matrix,
)

In [21]:
GT.fit()

Epoch 0


0it [00:00, ?it/s]

Traced<ConcreteArray([[[0.8217955  0.7612241 ]
  [0.92124283 0.09495129]
  [0.16644812 0.917934  ]
  ...
  [0.9451909  0.9616518 ]
  [0.09142207 0.16165861]
  [0.35804468 0.2778272 ]]

 [[0.88024426 0.7410093 ]
  [0.9893618  0.11400148]
  [0.25630146 0.88964325]
  ...
  [0.99862874 0.9432589 ]
  [0.09985784 0.18859282]
  [0.40638933 0.26414135]]

 [[0.9342167  0.7225772 ]
  [1.0532163  0.13010012]
  [0.34755296 0.86168635]
  ...
  [1.0458964  0.9267856 ]
  [0.1136967  0.22144407]
  [0.4559548  0.25839856]]

 ...

 [[0.9237305  0.57578963]
  [1.1509829  0.5748307 ]
  [0.9322858  0.8955642 ]
  ...
  [0.72653055 0.85308266]
  [0.13518839 0.76102686]
  [0.65867406 0.21666998]]

 [[0.8292381  0.5159184 ]
  [1.063587   0.5683033 ]
  [0.9063536  0.841078  ]
  ...
  [0.586697   0.79223406]
  [0.10660225 0.7316562 ]
  [0.6335392  0.14631966]]

 [[0.71918875 0.44336426]
  [0.9646362  0.5503747 ]
  [0.87257534 0.7672638 ]
  ...
  [0.42501265 0.7158112 ]
  [0.07227828 0.685501  ]
  [0.60394645 0.0




TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[20,32,2].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [326]:
jax_metric = partial(wrap_torch_metric, S.metric.metric_matrix)

In [325]:
np.array(jax.device_get(jnp.ones((4,2))))

array([[1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)

In [327]:
jax_metric(jnp.ones((4,2)))

TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

# Tests

Let's start by testing this on our old friend, the Sphere.

In [22]:
def geodesic_sphere_jax(x0,x1, t):
    """
    x0 is the starting point
    x1 is the ending point
    the times at which to evaluate the geodesic.
    """
    vx = jnp.cos(x0[0]) * jnp.sin(x0[1])
    vy = jnp.sin(x0[0]) * jnp.sin(x0[1])
    vz = jnp.cos(x0[1])

    wx = jnp.cos(x1[0]) * jnp.sin(x1[1])
    wy = jnp.sin(x1[0]) * jnp.sin(x1[1])
    wz = jnp.cos(x1[1])

    v = jnp.array([vx,vy,vz])
    w = jnp.array([wx,wy,wz])

    u = w - jnp.dot(v,w) * v
    u = u / jnp.linalg.norm(u)

    c = jnp.arccos(jnp.dot(v,w))
    #t = np.linspace(0,c,100)
    #assert(t[-1] == 1) # geodesic goes from 0 to 1
    t = t * c

    alpha = jnp.cos(t)[:,None] * v + jnp.sin(t)[:,None] * u

    v = jnp.arccos(alpha[...,2])
    cosu = alpha[...,0] / jnp.sin(v)
    sinu = alpha[...,1] / jnp.sin(v)
    u = jnp.arctan2(sinu,cosu)

    return alpha, jnp.stack([u,v], -1)

In [23]:
class SphereMetric:
    """
    Jax implementation of the metric on the sphere
    """
    def __init__(self):
        return
    
    def __call__(self,x):
        m = jnp.tile(jnp.eye(2)[None,...],x.shape[:-1]+(1,1))
        m = m.at[...,0,0].set(jnp.sin( x[...,-1])**(2))
        return m
    
    def __hash__(self):
        return 0
    

class OracleSphere:
    def __init__(self):
        return
    
    def geo_length(self,x0,x1):
        """
        Computes the length between x0 and x1 on the sphere
        """
        delta_l = jnp.abs(x0[:,0]-x1[:,0])
        d = jnp.arccos(jnp.sin(x0[:,1]) * jnp.sin(x1[:,1]) + jnp.cos(x0[:,1])*jnp.cos(x1[:,1]) * jnp.cos(delta_l))
        return d
    
    def mse_geodesic(self, x0,x1,t, preds):
        return mse_geodesic(x0,x1,t, ground_truth = geodesic_sphere_jax, pred = preds)

In [24]:
SM = SphereMetric()
SM(jnp.zeros((3,2)))

Array([[[0., 0.],
        [0., 1.]],

       [[0., 0.],
        [0., 1.]],

       [[0., 0.],
        [0., 1.]]], dtype=float32)

Indeed, the sphere metric here is in latent space, not ambient space. Does this mean that the NeuralFIM is likewise?

In [25]:
from autometric.datasets import Sphere

In [26]:
sphere1 = Sphere(2000).intrinsic_coords.numpy()
sphere2 = Sphere(2000).intrinsic_coords.numpy()

In [27]:
datamodule = GeodesicDataModule(sphere1, sphere2)

In [28]:
line_module = CondCurve(
    input_dim = 2,
    hidden_dim = 16,
    scale_factor = 5,
    symmetric = False,
    num_layers=3,
)

In [32]:
trainer = GeodesicTrainer(
    datamodule,
    cond_curve = line_module,
    metric = SM,
    oracle = OracleSphere(),
    max_epochs = 100,
    lr = 0.001,
    density_lambda=0,
    seed = 421,
    log_every_n_steps=10,
    n_interp_times=20,
    k_density=5, # default parameters in EB's hydra configs
    logger = False,
    callbacks = [],
)

In [33]:
trainer.geo_metric(jax.numpy.array([0,1]))

Array([[[0.7080734, 0.       ],
        [0.       , 1.       ]]], dtype=float32)

In [34]:
trainer.fit()

Epoch 0


50it [00:11,  4.38it/s]
7it [00:00,  8.49it/s]


Epoch 1


50it [00:11,  4.42it/s]
7it [00:00,  8.25it/s]


Epoch 2


50it [00:11,  4.28it/s]
7it [00:00,  8.23it/s]


Epoch 3


50it [00:11,  4.36it/s]
7it [00:00,  8.38it/s]


Epoch 4


50it [00:11,  4.40it/s]
7it [00:00,  8.49it/s]


Epoch 5


50it [00:11,  4.49it/s]
7it [00:00,  8.33it/s]


Epoch 6


50it [00:11,  4.29it/s]
7it [00:00,  8.26it/s]


Epoch 7


50it [00:11,  4.30it/s]
7it [00:00,  8.44it/s]


Epoch 8


50it [00:11,  4.33it/s]
7it [00:00,  8.37it/s]


Epoch 9


50it [00:11,  4.40it/s]
7it [00:00,  8.43it/s]


Epoch 10


50it [00:11,  4.33it/s]
7it [00:00,  8.43it/s]


Epoch 11


50it [00:11,  4.43it/s]
7it [00:00,  8.22it/s]


Epoch 12


50it [00:11,  4.46it/s]
7it [00:00,  8.51it/s]


Epoch 13


50it [00:11,  4.54it/s]
7it [00:00,  8.33it/s]


Epoch 14


50it [00:11,  4.38it/s]
7it [00:00,  8.30it/s]


Epoch 15


50it [00:11,  4.42it/s]
7it [00:00,  8.48it/s]


Epoch 16


50it [00:11,  4.43it/s]
7it [00:00,  8.40it/s]


Epoch 17


50it [00:11,  4.48it/s]
7it [00:00,  8.43it/s]


Epoch 18


50it [00:11,  4.40it/s]
7it [00:00,  8.44it/s]


Epoch 19


50it [00:11,  4.41it/s]
7it [00:00,  8.48it/s]


Epoch 20


50it [00:11,  4.32it/s]
7it [00:00,  8.42it/s]


Epoch 21


50it [00:11,  4.36it/s]
7it [00:00,  8.21it/s]


Epoch 22


50it [00:11,  4.44it/s]
7it [00:00,  8.43it/s]


Epoch 23


50it [00:11,  4.29it/s]
7it [00:00,  8.17it/s]


Epoch 24


50it [00:11,  4.30it/s]
7it [00:00,  8.28it/s]


Epoch 25


50it [00:11,  4.35it/s]
7it [00:00,  8.07it/s]


Epoch 26


50it [00:11,  4.34it/s]
7it [00:00,  8.46it/s]


Epoch 27


50it [00:11,  4.44it/s]
7it [00:00,  8.33it/s]


Epoch 28


50it [00:11,  4.30it/s]
7it [00:00,  8.40it/s]


Epoch 29


50it [00:11,  4.34it/s]
7it [00:00,  8.39it/s]


Epoch 30


50it [00:11,  4.38it/s]
7it [00:00,  8.38it/s]


Epoch 31


50it [00:11,  4.36it/s]
7it [00:00,  8.41it/s]


Epoch 32


50it [00:11,  4.51it/s]
7it [00:00,  8.41it/s]


Epoch 33


50it [00:11,  4.41it/s]
7it [00:00,  8.34it/s]


Epoch 34


50it [00:11,  4.36it/s]
7it [00:00,  8.28it/s]


Epoch 35


50it [00:11,  4.31it/s]
7it [00:00,  8.22it/s]


Epoch 36


50it [00:11,  4.30it/s]
7it [00:00,  8.18it/s]


Epoch 37


50it [00:11,  4.42it/s]
7it [00:00,  8.22it/s]


Epoch 38


50it [00:11,  4.32it/s]
7it [00:00,  8.27it/s]


Epoch 39


50it [00:11,  4.34it/s]
7it [00:00,  8.29it/s]


Epoch 40


50it [00:11,  4.41it/s]
7it [00:00,  8.38it/s]


Epoch 41


50it [00:11,  4.39it/s]
7it [00:00,  8.37it/s]


Epoch 42


50it [00:11,  4.51it/s]
7it [00:00,  8.45it/s]


Epoch 43


50it [00:11,  4.38it/s]
7it [00:00,  8.25it/s]


Epoch 44


50it [00:11,  4.37it/s]
7it [00:00,  8.41it/s]


Epoch 45


50it [00:11,  4.43it/s]
7it [00:00,  8.60it/s]


Epoch 46


50it [00:11,  4.48it/s]
7it [00:00,  8.44it/s]


Epoch 47


50it [00:11,  4.51it/s]
7it [00:00,  8.49it/s]


Epoch 48


50it [00:11,  4.44it/s]
7it [00:00,  8.53it/s]


Epoch 49


50it [00:11,  4.40it/s]
7it [00:00,  8.22it/s]


Epoch 50


50it [00:11,  4.35it/s]
7it [00:00,  8.52it/s]


Epoch 51


50it [00:11,  4.45it/s]
2it [00:00,  8.10it/s]

# Evaluating Geodesics Visually

Geodesics must always be generated in the latent space, with the intrinsic coordinates. Of course, every model will learn its own intrinsic coordinates. Hence, to easily compare the performance of GD6, we need to cast them back into ambient space.

To automate this, here's a function that takes as input a pair of geodesic functions, combined with a pair of encoder and decoder functions, plus a set of ambient coordinates.

In [276]:
#|export
from autometric.utils import plot_3d

def sample_along_geodesic(
    start, end, 
    geodesic_func, encoder, decoder,
    num_times = 50,
):
    start_latent = encoder(start).detach().cpu().numpy()[0]
    end_latent = encoder(end).detach().cpu().numpy()[0]
    ts = jnp.linspace(0,1,num_times)[:,None]
    # partial_geodesic = partial(geodesic_func, x0=start_latent, x1=end_latent)
    # print(geodesic_func(start_latent, end_latent, 0.5))
    samples = geodesic_func(start_latent, end_latent, ts)
    samples = np.squeeze(np.array(jax.device_get(samples))) # convert back to numpy
    samples_decoded = decoder(samples)
    return samples_decoded

def plot_3d_with_geodesics(X, geodesics):
    # if geodesics is not a list, wrap it in one
    if isinstance(geodesics, np.ndarray):
        geodesics = [geodesics]
    combined_geodesics = np.concatenate(geodesics, axis=0)
    all_points = np.concatenate([X, combined_geodesics], axis=0)
    plot_colors = np.zeros(len(X) + len(combined_geodesics))
    running_length = len(X)
    for i, g in enumerate(geodesics):
        plot_colors[running_length:running_length + len(g)] = i + 1
        running_length += len(g)
    plot_3d(all_points, plot_colors, use_plotly=True)

def visualize_geodesics(
    X_ambient:np.ndarray, # ambient coordinates
    geodesic_func_1, # first geodesic function. Takes input x1, x2, t and returns the geodesic at time t.
    geodesic_func_2, 
    model1,
    model2,
    num_geodesics_to_sample:int = 1,
):
    encoder_1 = model1.encode
    decoder_1 = model1.decode
    encoder_2 = model2.encode
    decoder_2 = model2.decode
    endpoints_idx = np.random.choice(np.arange(len(X_ambient)), size=(num_geodesics_to_sample,2), replace=False)
    geodesics_1 = []
    geodesics_2 = []
    for i in range(num_geodesics_to_sample):
        start = np.array(X_ambient[endpoints_idx[i][0]][None,:])
        end = np.array(X_ambient[endpoints_idx[i][1]][None,:])
        geodesics_1.append(
            sample_along_geodesic(start, end, geodesic_func_1, encoder_1, decoder_1)
        )
        geodesics_2.append(
            sample_along_geodesic(start, end, geodesic_func_2, encoder_2, decoder_2)
        )
        
    combined_g1s = np.concatenate(geodesics_1, axis=0)
    combined_g2s = np.concatenate(geodesics_2, axis=0)
    plot_3d_with_geodesics(X_ambient, [combined_g1s, combined_g2s])
    return geodesics_1, geodesics_2
    

In [277]:
def get_geodesic_from_trainer(start, end, ts, trainer):
    start = jnp.array(start)
    if len(start.shape) == 1:
        start = start[None,:]
    end = jnp.array(end)
    if len(end.shape) == 1:
        end = end[None,:]
    return trainer.state.apply_fn({'params':trainer.state.params}, start, end, ts)

In [278]:
S1 = Sphere(2000)

In [285]:
g1, g2 = visualize_geodesics(
    S1.X,
    S1.latent_geodesic,
    partial(get_geodesic_from_trainer, trainer=trainer),
    S1, 
    S1
)

torch.Size([1, 3]) torch.Size([1, 3])



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [245]:
A = jnp.zeros((10,2))

In [250]:
np.array(jax.device_get(A))

array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)

In [161]:
# sync changes to the library
from IPython.display import display, Javascript
import time
display(Javascript('IPython.notebook.save_checkpoint();'))
time.sleep(2)
!pixi run nbsync

<IPython.core.display.Javascript object>

  pid, fd = os.forkpty()


[33m WARN[0m [2mpixi::project::manifest[0m[2m:[0m BETA feature `[pypi-dependencies]` enabled!

Please report any and all issues here:

	https://github.com/prefix-dev/pixi.


✨ [1mPixi task ([0m[35m[1mdefault[0m[1m): [0m[34m[1mnbdev_export[0m
[2K[32m⠁[0m activating environment                                                                 