# Plant skeletonization

This Jupyter notebook replicates the plant skeletonization experiment, specifically focusing on a tomato plant at various growth stages (day 5, day 8, day 13). The notebook calculates Branched Central Spanning Trees (BCST) for different $\alpha$ (alpha) values.

The BCST computation is conducted in two scenarios:

- **With Prior Knowledge:** Prior knowledge about the location of the plant's root is used. The terminals corresponding to the root are duplicated (i.e. given a proportionally larger mass), which simulates a higher point density at that location.
- **Without Prior Knowledge:** The BCST is computed without any prior information about the root.

In [None]:
# Enable IPython's autoreload extension in mode 2, so that edited project
# modules are reloaded automatically before each cell runs.
_ipython_shell = get_ipython()
for _magic_args in (('load_ext', 'autoreload'), ('autoreload', '2')):
	_ipython_shell.run_line_magic(*_magic_args)


In [None]:
import os
import sys

# If the notebook is started from inside its own experiment folder, move two
# levels up and put that directory first on sys.path. This is presumably the
# repository root, so that the `lib` / `Experiments` imports below resolve.
_started_in_experiment_dir = 'Plant_skeletonization' in os.getcwd()
if _started_in_experiment_dir:
	os.chdir('../..')
	_repo_root = os.getcwd()
	sys.path.insert(0, _repo_root)

# Show where we ended up, for the reader of the notebook.
print(os.getcwd())
print(sys.path[0])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from lib.CST.T_datacls import T_data
from scipy.sparse import coo_matrix
from lib.CST.utils.utils import factor_number
import tempfile
import os
from PIL import Image
# import scipy.sparse as sp

from Experiments.Plant_skeletonization.visualization3D import visualize_data_3d,plot_graph_3d,generate_rotation_gif
from Experiments.Plant_skeletonization.load_plantdata import load_plantdata

#### Functions to generate the folder names and filenames of the figures

In [None]:
def define_filename(key,alpha,n,prior=False):
	"""Return the figure file stem (without extension) for one BCST plot.

	The stem has the form ``<key>_n=<n>_BCST_<alpha with 2 decimals>``.
	When ``prior`` is truthy, ``_prior`` is appended.
	"""
	stem = '%s_n=%i_BCST_%0.2f' % (key, n, alpha)
	return stem + ('_prior' if prior else '')

def define_foldername(plant_type,plant_num,n):
	"""Return the figure output folder, including a trailing '/', for one plant and sample size."""
	plant_folder = '%s_plant%i' % (plant_type, plant_num)
	return 'Figures/plant_skeleton/' + plant_folder + '/n=%i/' % n


#### Parameters

In [None]:
tdata_dict = dict()  # T_data objects per experiment, keyed by `key` and `key + '_prior'`


In [None]:
# --- plant instance and sampling ---------------------------------------------
plant_type = 'tomato'
plant_num = 2        # which plant instance to load
seed = 12            # seed of the global NumPy RNG, reset before every subsampling
n = 5000             # number of points subsampled from each scan

# --- BCST ----------------------------------------------------------------------
alpha_ls = [0, 0.5, 0.7, 1]    # alpha values for which a BCST is computed
maxfreq_mSTreg = 3
maxiter_mSTreg = 20

# --- edge-width rendering ------------------------------------------------------
power_width = 0.7    # edge flows are raised to this power to improve visualization
max_width = 15       # widest edge after rescaling

# --- root prior ----------------------------------------------------------------
factor_rep = 5                 # root mass = factor_rep * number of subsampled points
offset_base_root_prior = 10    # z-offset above the lowest point within which points count as root

# --- output --------------------------------------------------------------------
save = False           # if True, figures are saved
generate_gif = False   # if True, a rotating GIF is generated per tree
num_frames_gif = 20

## DAY5

### Load data

In [None]:
# Day 5: reseed the global NumPy RNG so the subsample is reproducible, load the
# scan, and keep n randomly chosen rows (points).
np.random.seed(seed)
day = 5

# Camera angles used for the 3D figures of this day.
elevation_angle, azimuthal_angle = 30, 30

data = load_plantdata(plant_type, plant_num, day)
key = '%s_plant%i_day%i' % (plant_type, plant_num, day)

# n distinct row indices, sampled without replacement. Passing the integer
# population size draws exactly the same indices as passing np.arange of it.
sample_rows = np.random.choice(data.shape[0], n, replace=False)
data_sub = data[sample_rows]

visualize_data_3d(data_sub, title='Subsampled data with %i points' % n, show_plot=False,
				  elevation_angle=elevation_angle, azimuthal_angle=azimuthal_angle)

#### Compute BCST

In [None]:


# One BCST per alpha on the subsampled cloud, without any root prior.
# The plotting cell reads the results from tdata.trees['BCST_%0.2f' % alpha].
tdata = T_data(data_sub)
tdata_dict[key] = tdata

bcst_options = dict(return_topo_CST=False, maxfreq_mSTreg=maxfreq_mSTreg, maxiter_mSTreg=maxiter_mSTreg)
for alpha in alpha_ls:
	tdata.compute_BCST(alpha=alpha, **bcst_options)

In [None]:
# Plot the BCSTs computed without prior, one per alpha.
# Edge widths are the edge flows raised to power_width and rescaled so that the
# widest edge is max_width. With save=True, every alpha gets its own figure,
# which is written to folder_figures as PNG.
# Fixed: this cell used to bind the non-prior tree to `tdata_prior` and then
# read the stale global `tdata`.
tdata = tdata_dict[key]

k1, k2 = factor_number(len(alpha_ls))

# Saved figures are drawn one per alpha instead of as a grid of subplots.
separated = bool(save)

folder_figures = define_foldername(plant_type, plant_num, n)

if save or generate_gif:
	# exist_ok must be passed by keyword: the second positional parameter of
	# os.makedirs is `mode`. `os.makedirs(path, True)` therefore used
	# permission mode 1 for the folder and raised FileExistsError on re-runs.
	os.makedirs(folder_figures, exist_ok=True)

for counter_plot, alpha in enumerate(alpha_ls):
	tree = tdata.trees['BCST_%0.2f' % alpha]

	T_flows = tree.get_T_weighted_by_flows()
	if power_width != 1:
		T_flows.data **= power_width
	T_flows *= max_width / T_flows.max()

	coords = tree.coords

	if separated:
		k1 = k2 = 1
		counter_plot = 0

	title = '' if save else 'BCST %0.2f' % alpha
	# Both the PNG and the GIF need the filename, so define it unconditionally.
	# It used to be a NameError when generate_gif=True and save=False.
	filename = define_filename(key, alpha, n, prior=False)

	fig, ax = plot_graph_3d(T_flows, coords, node_size=1, node_colors=None, linewidth_multiplier=1,
							title=title, show_plot=False, k1=k1, k2=k2, counter_plot=counter_plot,
							figsize=(15, 12), elevation_angle=elevation_angle,
							azimuthal_angle=azimuthal_angle)
	if save:
		plt.tight_layout()
		plt.savefig(folder_figures + filename + '.png')

	if generate_gif:
		generate_rotation_gif(T_flows, coords, node_colors=None, node_size=1,
							  title='BCST %0.2f' % alpha, edge_color='red',
							  figsize=(15, 12), num_frames=num_frames_gif,
							  output_filename=folder_figures + filename + '.gif')

plt.tight_layout()
plt.show()

#### Compute BCST with prior root

In [None]:
# Compute the BCSTs with a prior on the root location.
# The root region receives the mass of root_reps duplicated terminals, which
# simulates a higher point density there.
root_reps = factor_rep * len(data_sub)

# The root is taken to be the terminal with the lowest z-coordinate.
root = np.argmin(data_sub[:, 2])

# Pseudo-roots: all terminals whose z-coordinate is at most
# offset_base_root_prior above the lowest terminal. This includes the root
# itself whenever offset_base_root_prior >= 0.
pseudo_roots = np.where(data_sub[:, 2] <= data_sub[root, 2] + offset_base_root_prior)[0]

# Demands are normalised by denom = (#terminals + root_reps - 1). With
# P = len(pseudo_roots):
#   ordinary terminal   : -1 / denom
#   root                : +root_reps / P / denom
#   other pseudo-roots  : -root_reps / P / denom
# (The previous loop first stored 1/denom for every pseudo-root, a dead store
# that both branches overwrote.)
# NOTE(review): these demands do not sum to zero, and the non-root pseudo-roots
# get a large *negative* demand. Confirm this matches the sign convention that
# T_data.compute_BCST expects.
num_terminals = data_sub.shape[0]
denom = num_terminals + root_reps - 1
demands = -np.ones(num_terminals) / denom
if len(pseudo_roots) > 0:
	root_share = root_reps / len(pseudo_roots)
	demands[pseudo_roots] = -root_share / denom
	if root in pseudo_roots:
		demands[root] = root_share / denom

tdata_dict[key + '_prior'] = T_data(data_sub)
tdata_prior = tdata_dict[key + '_prior']

for alpha in alpha_ls:
	tdata_prior.compute_BCST(alpha=alpha, return_topo_CST=False, maxfreq_mSTreg=maxfreq_mSTreg,
							 maxiter_mSTreg=maxiter_mSTreg, demands=demands)

#### Plot trees

In [None]:
# Plot the BCSTs computed with the root prior, one per alpha.
# Edge widths are the edge flows raised to power_width and rescaled so that the
# widest edge is max_width. With save=True, every alpha gets its own figure,
# which is written to folder_figures as PNG.
tdata_prior = tdata_dict[key + '_prior']

k1, k2 = factor_number(len(alpha_ls))

# Saved figures are drawn one per alpha instead of as a grid of subplots.
separated = bool(save)

folder_figures = define_foldername(plant_type, plant_num, n)

if save or generate_gif:
	# exist_ok must be passed by keyword: the second positional parameter of
	# os.makedirs is `mode`. `os.makedirs(path, True)` therefore used
	# permission mode 1 for the folder and raised FileExistsError on re-runs.
	os.makedirs(folder_figures, exist_ok=True)

for counter_plot, alpha in enumerate(alpha_ls):
	tree = tdata_prior.trees['BCST_%0.2f' % alpha]

	T_flows = tree.get_T_weighted_by_flows()
	if power_width != 1:
		T_flows.data **= power_width
	T_flows *= max_width / T_flows.max()

	coords = tree.coords

	if separated:
		k1 = k2 = 1
		counter_plot = 0

	title = '' if save else 'BCST %0.2f' % alpha
	# Both the PNG and the GIF need the filename, so define it unconditionally.
	# It used to be a NameError when generate_gif=True and save=False.
	filename = define_filename(key, alpha, n, prior=True)

	fig, ax = plot_graph_3d(T_flows, coords, node_size=1, node_colors=None, linewidth_multiplier=1,
							title=title, show_plot=False, k1=k1, k2=k2, counter_plot=counter_plot,
							figsize=(15, 12), elevation_angle=elevation_angle,
							azimuthal_angle=azimuthal_angle)
	if save:
		plt.tight_layout()
		plt.savefig(folder_figures + filename + '.png')

	if generate_gif:
		generate_rotation_gif(T_flows, coords, node_colors=None, node_size=1,
							  title='BCST %0.2f' % alpha, edge_color='red',
							  figsize=(15, 12), num_frames=num_frames_gif,
							  output_filename=folder_figures + filename + '.gif')

plt.tight_layout()
plt.show()

## DAY8

### Load data

In [None]:
# Day 8: reseed the global NumPy RNG so the subsample is reproducible, load the
# scan, and keep n randomly chosen rows (points).
np.random.seed(seed)
day = 8

# Camera angles used for the 3D figures of this day.
elevation_angle, azimuthal_angle = 30, -90

data = load_plantdata(plant_type, plant_num, day)
key = '%s_plant%i_day%i' % (plant_type, plant_num, day)

# n distinct row indices, sampled without replacement. Passing the integer
# population size draws exactly the same indices as passing np.arange of it.
sample_rows = np.random.choice(data.shape[0], n, replace=False)
data_sub = data[sample_rows]

visualize_data_3d(data_sub, title='Subsampled data with %i points' % n, show_plot=False,
				  elevation_angle=elevation_angle, azimuthal_angle=azimuthal_angle)

#### Compute BCST

In [None]:


# One BCST per alpha on the subsampled cloud, without any root prior.
# The plotting cell reads the results from tdata.trees['BCST_%0.2f' % alpha].
tdata = T_data(data_sub)
tdata_dict[key] = tdata

bcst_options = dict(return_topo_CST=False, maxfreq_mSTreg=maxfreq_mSTreg, maxiter_mSTreg=maxiter_mSTreg)
for alpha in alpha_ls:
	tdata.compute_BCST(alpha=alpha, **bcst_options)

#### Plot trees

In [None]:
# Plot the BCSTs computed without prior, one per alpha.
# Edge widths are the edge flows raised to power_width and rescaled so that the
# widest edge is max_width. With save=True, every alpha gets its own figure,
# which is written to folder_figures as PNG.
tdata = tdata_dict[key]

k1, k2 = factor_number(len(alpha_ls))

# Saved figures are drawn one per alpha instead of as a grid of subplots.
separated = bool(save)

folder_figures = define_foldername(plant_type, plant_num, n)

if save or generate_gif:
	# exist_ok must be passed by keyword: the second positional parameter of
	# os.makedirs is `mode`. `os.makedirs(path, True)` therefore used
	# permission mode 1 for the folder and raised FileExistsError on re-runs.
	os.makedirs(folder_figures, exist_ok=True)

for counter_plot, alpha in enumerate(alpha_ls):
	tree = tdata.trees['BCST_%0.2f' % alpha]

	T_flows = tree.get_T_weighted_by_flows()
	if power_width != 1:
		T_flows.data **= power_width
	T_flows *= max_width / T_flows.max()

	coords = tree.coords

	if separated:
		k1 = k2 = 1
		counter_plot = 0

	title = '' if save else 'BCST %0.2f' % alpha
	# Both the PNG and the GIF need the filename, so define it unconditionally.
	# It used to be a NameError when generate_gif=True and save=False.
	filename = define_filename(key, alpha, n, prior=False)

	fig, ax = plot_graph_3d(T_flows, coords, node_size=1, node_colors=None, linewidth_multiplier=1,
							title=title, show_plot=False, k1=k1, k2=k2, counter_plot=counter_plot,
							figsize=(15, 12), elevation_angle=elevation_angle,
							azimuthal_angle=azimuthal_angle)
	if save:
		plt.tight_layout()
		plt.savefig(folder_figures + filename + '.png')

	if generate_gif:
		generate_rotation_gif(T_flows, coords, node_colors=None, node_size=1,
							  title='BCST %0.2f' % alpha, edge_color='red',
							  figsize=(15, 12), num_frames=num_frames_gif,
							  output_filename=folder_figures + filename + '.gif')

plt.tight_layout()
plt.show()

#### Compute BCST with prior root

In [None]:
# Compute the BCSTs with a prior on the root location.
# The root region receives the mass of root_reps duplicated terminals, which
# simulates a higher point density there.
root_reps = factor_rep * len(data_sub)

# The root is taken to be the terminal with the lowest z-coordinate.
root = np.argmin(data_sub[:, 2])

# Pseudo-roots: all terminals whose z-coordinate is at most
# offset_base_root_prior above the lowest terminal. This includes the root
# itself whenever offset_base_root_prior >= 0.
pseudo_roots = np.where(data_sub[:, 2] <= data_sub[root, 2] + offset_base_root_prior)[0]

# Demands are normalised by denom = (#terminals + root_reps - 1). With
# P = len(pseudo_roots):
#   ordinary terminal   : -1 / denom
#   root                : +root_reps / P / denom
#   other pseudo-roots  : -root_reps / P / denom
# (The previous loop first stored 1/denom for every pseudo-root, a dead store
# that both branches overwrote.)
# NOTE(review): these demands do not sum to zero, and the non-root pseudo-roots
# get a large *negative* demand. Confirm this matches the sign convention that
# T_data.compute_BCST expects.
num_terminals = data_sub.shape[0]
denom = num_terminals + root_reps - 1
demands = -np.ones(num_terminals) / denom
if len(pseudo_roots) > 0:
	root_share = root_reps / len(pseudo_roots)
	demands[pseudo_roots] = -root_share / denom
	if root in pseudo_roots:
		demands[root] = root_share / denom

tdata_dict[key + '_prior'] = T_data(data_sub)
tdata_prior = tdata_dict[key + '_prior']

for alpha in alpha_ls:
	tdata_prior.compute_BCST(alpha=alpha, return_topo_CST=False, maxfreq_mSTreg=maxfreq_mSTreg,
							 maxiter_mSTreg=maxiter_mSTreg, demands=demands)

In [None]:
# Plot the BCSTs computed with the root prior, one per alpha.
# Edge widths are the edge flows raised to power_width and rescaled so that the
# widest edge is max_width. With save=True, every alpha gets its own figure,
# which is written to folder_figures as PNG.
tdata_prior = tdata_dict[key + '_prior']

k1, k2 = factor_number(len(alpha_ls))

# Saved figures are drawn one per alpha instead of as a grid of subplots.
separated = bool(save)

folder_figures = define_foldername(plant_type, plant_num, n)

if save or generate_gif:
	# exist_ok must be passed by keyword: the second positional parameter of
	# os.makedirs is `mode`. `os.makedirs(path, True)` therefore used
	# permission mode 1 for the folder and raised FileExistsError on re-runs.
	os.makedirs(folder_figures, exist_ok=True)

for counter_plot, alpha in enumerate(alpha_ls):
	tree = tdata_prior.trees['BCST_%0.2f' % alpha]

	T_flows = tree.get_T_weighted_by_flows()
	if power_width != 1:
		T_flows.data **= power_width
	T_flows *= max_width / T_flows.max()

	coords = tree.coords

	if separated:
		k1 = k2 = 1
		counter_plot = 0

	title = '' if save else 'BCST %0.2f' % alpha
	# Both the PNG and the GIF need the filename, so define it unconditionally.
	# It used to be a NameError when generate_gif=True and save=False.
	filename = define_filename(key, alpha, n, prior=True)

	fig, ax = plot_graph_3d(T_flows, coords, node_size=1, node_colors=None, linewidth_multiplier=1,
							title=title, show_plot=False, k1=k1, k2=k2, counter_plot=counter_plot,
							figsize=(15, 12), elevation_angle=elevation_angle,
							azimuthal_angle=azimuthal_angle)
	if save:
		plt.tight_layout()
		plt.savefig(folder_figures + filename + '.png')

	if generate_gif:
		generate_rotation_gif(T_flows, coords, node_colors=None, node_size=1,
							  title='BCST %0.2f' % alpha, edge_color='red',
							  figsize=(15, 12), num_frames=num_frames_gif,
							  output_filename=folder_figures + filename + '.gif')

plt.tight_layout()
plt.show()

## DAY13

### Load data

In [None]:
# Day 13: reseed the global NumPy RNG so the subsample is reproducible, load the
# scan, and keep n randomly chosen rows (points).
np.random.seed(seed)
day = 13

# No explicit camera angles for this day; None is passed through to the plots.
elevation_angle, azimuthal_angle = None, None

data = load_plantdata(plant_type, plant_num, day)
key = '%s_plant%i_day%i' % (plant_type, plant_num, day)

# n distinct row indices, sampled without replacement. Passing the integer
# population size draws exactly the same indices as passing np.arange of it.
sample_rows = np.random.choice(data.shape[0], n, replace=False)
data_sub = data[sample_rows]

visualize_data_3d(data_sub, title='Subsampled data with %i points' % n, show_plot=False,
				  elevation_angle=elevation_angle, azimuthal_angle=azimuthal_angle)

#### Compute BCST

In [None]:
# One BCST per alpha on the subsampled cloud, without any root prior.
# The plotting cell reads the results from tdata.trees['BCST_%0.2f' % alpha].
tdata = T_data(data_sub)
tdata_dict[key] = tdata

bcst_options = dict(return_topo_CST=False, maxfreq_mSTreg=maxfreq_mSTreg, maxiter_mSTreg=maxiter_mSTreg)
for alpha in alpha_ls:
	tdata.compute_BCST(alpha=alpha, **bcst_options)

#### Plot trees

In [None]:
# Plot the BCSTs computed without prior, one per alpha.
# Edge widths are the edge flows raised to power_width and rescaled so that the
# widest edge is max_width. With save=True, every alpha gets its own figure,
# which is written to folder_figures as PNG.
# Fixed: this cell used to bind the non-prior tree to `tdata_prior` and then
# read the stale global `tdata`.
tdata = tdata_dict[key]

k1, k2 = factor_number(len(alpha_ls))

# Saved figures are drawn one per alpha instead of as a grid of subplots.
separated = bool(save)

folder_figures = define_foldername(plant_type, plant_num, n)

if save or generate_gif:
	# exist_ok must be passed by keyword: the second positional parameter of
	# os.makedirs is `mode`. `os.makedirs(path, True)` therefore used
	# permission mode 1 for the folder and raised FileExistsError on re-runs.
	os.makedirs(folder_figures, exist_ok=True)

for counter_plot, alpha in enumerate(alpha_ls):
	tree = tdata.trees['BCST_%0.2f' % alpha]

	T_flows = tree.get_T_weighted_by_flows()
	if power_width != 1:
		T_flows.data **= power_width
	T_flows *= max_width / T_flows.max()

	coords = tree.coords

	if separated:
		k1 = k2 = 1
		counter_plot = 0

	title = '' if save else 'BCST %0.2f' % alpha
	# Both the PNG and the GIF need the filename, so define it unconditionally.
	# It used to be a NameError when generate_gif=True and save=False.
	filename = define_filename(key, alpha, n, prior=False)

	fig, ax = plot_graph_3d(T_flows, coords, node_size=1, node_colors=None, linewidth_multiplier=1,
							title=title, show_plot=False, k1=k1, k2=k2, counter_plot=counter_plot,
							figsize=(15, 12), elevation_angle=elevation_angle,
							azimuthal_angle=azimuthal_angle)
	if save:
		plt.tight_layout()
		plt.savefig(folder_figures + filename + '.png')

	if generate_gif:
		generate_rotation_gif(T_flows, coords, node_colors=None, node_size=1,
							  title='BCST %0.2f' % alpha, edge_color='red',
							  figsize=(15, 12), num_frames=num_frames_gif,
							  output_filename=folder_figures + filename + '.gif')

plt.tight_layout()
plt.show()

#### Compute BCST with prior root

In [None]:
# Compute the BCSTs with a prior on the root location.
# The root region receives the mass of root_reps duplicated terminals, which
# simulates a higher point density there.
root_reps = factor_rep * len(data_sub)

# The root is taken to be the terminal with the lowest z-coordinate.
root = np.argmin(data_sub[:, 2])

# Pseudo-roots: all terminals whose z-coordinate is at most
# offset_base_root_prior above the lowest terminal. This includes the root
# itself whenever offset_base_root_prior >= 0.
pseudo_roots = np.where(data_sub[:, 2] <= data_sub[root, 2] + offset_base_root_prior)[0]

# Demands are normalised by denom = (#terminals + root_reps - 1). With
# P = len(pseudo_roots):
#   ordinary terminal   : -1 / denom
#   root                : +root_reps / P / denom
#   other pseudo-roots  : -root_reps / P / denom
# (The previous loop first stored 1/denom for every pseudo-root, a dead store
# that both branches overwrote.)
# NOTE(review): these demands do not sum to zero, and the non-root pseudo-roots
# get a large *negative* demand. Confirm this matches the sign convention that
# T_data.compute_BCST expects.
num_terminals = data_sub.shape[0]
denom = num_terminals + root_reps - 1
demands = -np.ones(num_terminals) / denom
if len(pseudo_roots) > 0:
	root_share = root_reps / len(pseudo_roots)
	demands[pseudo_roots] = -root_share / denom
	if root in pseudo_roots:
		demands[root] = root_share / denom

tdata_dict[key + '_prior'] = T_data(data_sub)
tdata_prior = tdata_dict[key + '_prior']

for alpha in alpha_ls:
	tdata_prior.compute_BCST(alpha=alpha, return_topo_CST=False, maxfreq_mSTreg=maxfreq_mSTreg,
							 maxiter_mSTreg=maxiter_mSTreg, demands=demands)

#### Plot trees

In [None]:

# Plot the BCSTs computed with the root prior, one per alpha.
# Edge widths are the edge flows raised to power_width and rescaled so that the
# widest edge is max_width. With save=True, every alpha gets its own figure,
# which is written to folder_figures as PNG.
k1, k2 = factor_number(len(alpha_ls))

# Saved figures are drawn one per alpha instead of as a grid of subplots.
separated = bool(save)
tdata_prior = tdata_dict[key + '_prior']

folder_figures = define_foldername(plant_type, plant_num, n)
if save or generate_gif:
	# exist_ok must be passed by keyword: the second positional parameter of
	# os.makedirs is `mode`. `os.makedirs(path, True)` therefore used
	# permission mode 1 for the folder and raised FileExistsError on re-runs.
	os.makedirs(folder_figures, exist_ok=True)

for counter_plot, alpha in enumerate(alpha_ls):
	tree = tdata_prior.trees['BCST_%0.2f' % alpha]

	T_flows = tree.get_T_weighted_by_flows()
	if power_width != 1:
		T_flows.data **= power_width
	T_flows *= max_width / T_flows.max()

	coords = tree.coords

	if separated:
		k1 = k2 = 1
		counter_plot = 0

	title = '' if save else 'BCST %0.2f' % alpha
	# Both the PNG and the GIF need the filename, so define it unconditionally.
	# It used to be a NameError when generate_gif=True and save=False.
	filename = define_filename(key, alpha, n, prior=True)

	fig, ax = plot_graph_3d(T_flows, coords, node_size=1, node_colors=None, linewidth_multiplier=1,
							title=title, show_plot=False, k1=k1, k2=k2, counter_plot=counter_plot,
							figsize=(15, 12), elevation_angle=elevation_angle,
							azimuthal_angle=azimuthal_angle)
	if save:
		plt.tight_layout()
		plt.savefig(folder_figures + filename + '.png')

	if generate_gif:
		generate_rotation_gif(T_flows, coords, node_colors=None, node_size=1,
							  title='BCST %0.2f' % alpha, edge_color='red',
							  figsize=(15, 12), num_frames=num_frames_gif,
							  output_filename=folder_figures + filename + '.gif')

plt.tight_layout()
plt.show()