In [3]:
# Automatically re-import edited local modules (root, balanced_ot, unbalanced_ot,
# layerwise, barycenters) before each cell runs, so library edits show up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import os
from os.path import join

import numpy as np

import plotly.graph_objects as go

import sys
# plantbox is imported from the directory named by the PLANTBOX_PATH
# environment variable, which is added to sys.path below.
try:
    plantbox_path = os.environ["PLANTBOX_PATH"]
except KeyError:
    raise KeyError(
        "PLANTBOX_PATH is not set; export it to the plantbox directory before starting Jupyter"
    ) from None
# Append only once so re-running this cell does not keep growing sys.path with duplicates.
if plantbox_path not in sys.path:
    sys.path.append(plantbox_path)

from plantbox import RootSystem, SegmentAnalyser

from root import Root

ModuleNotFoundError: No module named 'root'

In [4]:
# Directory holding the XML root-system parameter files that live under plantbox_path.
param_dir = join(plantbox_path, "modelparameter", "rootsystem")
# Fail fast with a clear message instead of failing later inside readParameters/from_file.
if not os.path.isdir(param_dir):
    raise FileNotFoundError(f"plantbox parameter directory not found: {param_dir}")

In [5]:
# Parameter set used for the single root-system simulation that follows.
name = "Anagallis_femina_Leitner_2010"
param_path = join(param_dir, name + ".xml")

In [97]:
# Build a plantbox RootSystem from `param_path`, grow it, and wrap the result in a Root.
# Call order (read parameters -> initialize -> simulate) is kept as-is; presumably
# required by plantbox — TODO confirm.
rs = RootSystem()
rs.readParameters(param_path)
rs.initialize()
rs.simulate(14)  # simulated time; unit defined by plantbox's simulate() (presumably days — TODO confirm)
pyroot = Root(rs)
# fig = pyroot.plot(draw_nodes=False)
# fig

In [98]:
# Render the simulated root system and export it as a static PNG.
# The figure is built here so this cell never exports a stale or undefined `fig`
# left over from another cell.
fig = pyroot.plot(draw_nodes=False)
images_dir = join(os.path.expanduser("~"), "Images")  # portable alternative to os.environ["HOME"]
os.makedirs(images_dir, exist_ok=True)  # writing into a missing folder would fail
fig.write_image(join(images_dir, "root.png"))

# Optimal Transport

## Point clouds

In [7]:
import matplotlib.pyplot as plt
from balanced_ot import wasserstein, reg_wasserstein
from unbalanced_ot import unb_reg_wasserstein

In [8]:
# Parameter files (under `param_dir`) for the two species being compared.
anagallis_file = "Anagallis_femina_Leitner_2010"
anagallis_path = join(param_dir, f"{anagallis_file}.xml")

brassica_file = "Brassica_oleracea_Vansteenkiste_2014"
brassica_path = join(param_dir, f"{brassica_file}.xml")

# Two Anagallis snapshots (ages 4 and 6) for a within-species comparison,
# plus a Brassica of age 6 for a cross-species comparison.
anagallis = Root.from_file(anagallis_path, age=4)
anagallis_2 = Root.from_file(anagallis_path, age=6)
brassica = Root.from_file(brassica_path, age=6)

# Report every node count with the same, self-describing label.
for label, root_model in [
    ("anagallis (age 4)", anagallis),
    ("anagallis_2 (age 6)", anagallis_2),
    ("brassica (age 6)", brassica),
]:
    print(f"{label}: {root_model.n_nodes} nodes")

172 nodes
407
162 nodes


In [9]:
# Balanced OT distance (balanced_ot.wasserstein) between the age-4 and age-6 Anagallis roots.
wasserstein_dist, coupling = wasserstein(anagallis, anagallis_2)
print(wasserstein_dist)

3.2164835710586166


## Sinkhorn (regularized OT, balanced and unbalanced)

In [10]:
# Regularized balanced OT (balanced_ot.reg_wasserstein) on the same pair of roots;
# the distance is shown as the cell's output value.
wasserstein_dist, coupling = reg_wasserstein(anagallis, anagallis_2)
wasserstein_dist

array([3.26595633])

In [11]:
# Unbalanced regularized OT (unbalanced_ot.unb_reg_wasserstein) on the same pair of roots.
wasserstein_dist, coupling = unb_reg_wasserstein(anagallis, anagallis_2)
print(wasserstein_dist)

[3.93499218]


## Layerwise

### Distances

In [12]:
from layerwise.distances import lw_wasserstein

In [13]:
# Put both Anagallis snapshots into true scale (per Root.true_scale — presumably
# undoing a normalization; TODO confirm), then compare them layer by layer.
for snapshot in (anagallis, anagallis_2):
    snapshot.true_scale()
lw_wasserstein(anagallis, anagallis_2, 500)  # third argument as in the original call; meaning per layerwise.distances


Empty layer



2.2035927859716353

In [14]:
# Switch both Anagallis snapshots back with rescale(), then compare against Brassica.
# NOTE(review): no cell here calls brassica.rescale()/true_scale() — confirm brassica
# is already in the same scale state as the Anagallis roots.
for snapshot in (anagallis, anagallis_2):
    snapshot.rescale()
lw_wasserstein(anagallis, brassica, 500)

5.212406474511068

In [16]:
# Restore true scale on `anagallis` only; `anagallis_2` is left as the previous
# cell's rescale() set it.
anagallis.true_scale()
# anagallis.plot()

## Barycenters

In [85]:
from barycenters import point_cloud_barycenter
from layerwise.barycenters import lw_barycenter, plot_barycenter_nodes

### Point cloud barycenter

In [88]:
# Point-cloud barycenter of the two Anagallis snapshots.
# NOTE(review): when run in order, `anagallis` was last true_scale()d while
# `anagallis_2` was last rescale()d — confirm mixing the two scales is intended.
bar = point_cloud_barycenter(anagallis, anagallis_2)

In [2]:
# fig = plot_barycenter_nodes(bar)
# fig

In [94]:
# Build the point-cloud barycenter figure here. Reusing a global `fig` would export
# whatever was plotted last (e.g. the root system), or raise NameError on a fresh kernel.
fig = plot_barycenter_nodes(bar)
images_dir = join(os.path.expanduser("~"), "Images")  # portable alternative to os.environ["HOME"]
os.makedirs(images_dir, exist_ok=True)  # writing into a missing folder would fail
fig.write_image(join(images_dir, "point_cloud_barycenter.png"))

### Layerwise barycenter

In [75]:
# Layerwise barycenter of the two Anagallis snapshots. Positional arguments are the
# same as before; their meanings are presumed and should be confirmed against
# layerwise.barycenters.lw_barycenter.
LW_BAR_SIZE = 150   # presumably number of layers / support points — TODO confirm
LW_BAR_REG = 1e-1   # presumably a regularization strength — TODO confirm
roots_to_average = [anagallis, anagallis_2]
bar_mu_v, bar_lw, bar_nodes = lw_barycenter(roots_to_average, LW_BAR_SIZE, LW_BAR_REG)

In [1]:
# fig = plot_barycenter_nodes(bar_nodes)
# fig

In [96]:
# Build the layerwise barycenter figure here. Reusing a global `fig` would export
# whatever was plotted last, or raise NameError on a fresh kernel.
fig = plot_barycenter_nodes(bar_nodes)
images_dir = join(os.path.expanduser("~"), "Images")  # portable alternative to os.environ["HOME"]
os.makedirs(images_dir, exist_ok=True)  # writing into a missing folder would fail
fig.write_image(join(images_dir, "barycenter.png"))