## MULTI DAPI VAE in PYRO

### IMPORT NECESSARY MODULES

In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# FROM CLI: conda install -c nvidia -c rapidsai -c numba -c conda-forge -c defaults numba=0.48.0 cugraph cudatoolkit=10.1

In [5]:
%matplotlib inline
#%matplotlib notebook
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML, Image
from mpl_toolkits.mplot3d import Axes3D

from MODULES.utilities import *
from MODULES.vae_model import *

In [6]:


import torch
import pyro
#from pyro.infer import SVI, Trace_ELBO #,TraceEnum_ELBO, TraceGraph_ELBO, config_enumerate, JitTraceEnum_ELBO 

# Set up pyro environment
pyro.clear_param_store()
pyro.set_rng_seed(0)

# Check versions
from platform import python_version
print(python_version())
print("pyro.__version__  --> ",pyro.__version__)
print("torch.__version__ --> ",torch.__version__)

3.8.2
pyro.__version__  -->  1.3.0
torch.__version__ -->  1.4.0


In [7]:
# NOTE(review): absolute, user-specific path -- this cell only runs on the original author's machine.
mask_file = "/Users/ldalessi/DAPI_unsupervised/spacetx-research/merfish_june22_v2/BIG_edges_tiling.pkl"
# torch.load unpickles the file: only load files that come from a trusted source.
# map_location remaps the stored tensors onto the CPU.
tmp = torch.load(mask_file, map_location=torch.device('cpu'))
# The loaded object is unpacked into exactly two items.
edges_tiling, img_to_segment = tmp
print(mask_file)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
plt.imshow(edges_tiling[12].cpu())

In [None]:
plot_grid(edges_tiling.cpu()[:,1000:1300,1000:1300], figsize=(20,20))

# Cut the graph

In [None]:
from platform import python_version
print("python_version -->",python_version())
print("cuda_version -->")
!nvcc --version


In [None]:
import leidenalg as la

la.__version__

In [None]:
import functools 
import leidenalg as la
import igraph as ig
from typing import NamedTuple

def plot_segmentation(mask, raw_img, figsize=None):
    """Draw three panels side by side: the coloured labels, the labels over the raw image, the raw image.

    Args:
        mask: label image; label 0 is treated as background (bg_label=0).
        raw_img: image used as backdrop of the second panel and shown in gray in the third.
        figsize: forwarded to plt.subplots.
    """
    _, (ax_labels, ax_overlay, ax_raw) = plt.subplots(ncols=3, figsize=figsize)

    # Panel 1: labels painted at full opacity over an all-zero canvas
    blank_canvas = np.zeros_like(mask)
    labels_rgb = skimage.color.label2rgb(mask, blank_canvas, alpha=1.0, bg_label=0)
    ax_labels.imshow(labels_rgb)

    # Panel 2: semi-transparent labels blended with the raw image
    overlay_rgb = skimage.color.label2rgb(mask, raw_img, alpha=0.25, bg_label=0)
    ax_overlay.imshow(overlay_rgb)

    # Panel 3: the raw image alone
    ax_raw.imshow(raw_img, cmap='gray')
    
def plot_resolution(label_list):
    """Show every label mask of `label_list` on a grid with 4 columns.

    The number of rows is rounded up so that a partially filled last row is still drawn
    (unused panels are switched off). An empty list draws nothing.

    Args:
        label_list: sequence of label images; label 0 is treated as background.
    """
    n_cols = 4
    n_labels = len(label_list)
    if n_labels == 0:
        return

    n_rows = -(-n_labels // n_cols)  # ceiling division
    # squeeze=False keeps `axes` two-dimensional even when there is a single row
    figure, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(30, 30), squeeze=False)
    for n, ax in enumerate(axes.flat):
        if n < n_labels:
            labels = label_list[n]
            ax.imshow(skimage.color.label2rgb(labels, np.zeros_like(labels), alpha=1.0, bg_label=0))
        else:
            ax.axis('off')

# Sparse adjacency in coordinate form: entry k describes the edge between
# vertices i[k] and j[k], carrying weight v[k].
adjacency = NamedTuple("adjacency", [("v", list), ("i", list), ("j", list)])
        
class GraphSegmentation():
    """ Instance segmentation by community detection on a graph of foreground pixels.

        Every foreground pixel becomes a vertex. For each displacement (dx, dy) with
        |dx|, |dy| <= radius_nn, the corresponding channel of `edges` gives the weight of the
        edge between a pixel and the pixel found at that displacement. The communities found
        by leidenalg on this graph are converted back into an integer mask.

        Typical usage:
        g = GraphSegmentation(edges)
        partition = g.find_partition(resolution=0.02)
        mask = g.partition_to_mask(partition)
    """

    def __init__(self, edges):
        """
        Args:
            edges: tensor of shape ((2*r+1)*(2*r+1), nx, ny); one channel per displacement.
        """
        super().__init__()

        # size = (2*r+1)*(2*r+1), w, h. Each channel contains the edges between pixel_i and pixel_j
        assert len(edges.shape) == 3

        self.edges = edges
        self.device = self.edges.device
        N, self.nx, self.ny = self.edges.shape
        self.radius_nn = int((np.sqrt(N) -1) // 2)
        print("radius_nn ->",self.radius_nn)

        # Central channel: with the (dx outer, dy inner) channel order used in
        # _build_adjacency this is the zero-displacement channel.
        self.ch_edge_ii = (N -1)//2
        print("ch_e_ii -->",self.ch_edge_ii)

        # A pixel is foreground when its central-channel value exceeds 0.5
        self.fg_mask = self.edges[self.ch_edge_ii] > 0.5

        ix_matrix, iy_matrix = torch.meshgrid([torch.arange(self.nx, dtype=int, device=self.device),
                                               torch.arange(self.ny, dtype=int, device=self.device)])
        self.x_coordinate_fg_pixel = ix_matrix[self.fg_mask]
        self.y_coordinate_fg_pixel = iy_matrix[self.fg_mask]
        self.n_fg_pixel = self.x_coordinate_fg_pixel.shape[0]
        self.index_array = torch.arange(self.n_fg_pixel, dtype=int, device=self.device)
        # index_matrix[x, y] is the vertex index of pixel (x, y), or -1 for background
        self.index_matrix = -1*torch.ones_like(ix_matrix)
        self.index_matrix[self.x_coordinate_fg_pixel,
                          self.y_coordinate_fg_pixel] = self.index_array

        print("n_fg_pixel -->",self.n_fg_pixel)

        self.adj = self._build_adjacency()

        self.graph = self._build_graph(self.adj)


    def _build_adjacency(self):
        """Collect the weighted edges between foreground pixels, one displacement at a time.

        Returns:
            adjacency with parallel lists v (weights), i (source vertex), j (destination vertex).
        """
        w_list, i_list, j_list = [],[],[]
        # Source vertex of every foreground pixel; identical for all displacements
        i = self.index_matrix[self.fg_mask]

        # NOTE(review): assumes channel = (dx + r) * (2r + 1) + (dy + r), which is consistent
        # with ch_edge_ii being the (0, 0) channel -- confirm against the code producing `edges`.
        # NOTE(review): torch.roll wraps around the image border; confirm that the weights of
        # edges pointing outside the image are zero.
        ch = -1
        for dx in range(-self.radius_nn, self.radius_nn + 1):
            index_matrix_tmp = torch.roll(self.index_matrix, dx, dims=-2)
            for dy in range(-self.radius_nn, self.radius_nn + 1):
                ch += 1  # move to the channel of this displacement
                index_matrix_shifted = torch.roll(index_matrix_tmp, dy, dims=-1)

                w = self.edges[ch][self.fg_mask]
                j = index_matrix_shifted[self.fg_mask]

                # keep non-negligible weights whose destination is a foreground pixel
                keep = (w > 0.01) & (j >= 0)

                w_list += w[keep].tolist()
                i_list += i[keep].tolist()
                j_list += j[keep].tolist()

        return adjacency(v=w_list, i=i_list, j=j_list)


    def _build_graph(self, adj):
        """Build the undirected, weighted igraph graph described by `adj`."""
        vertex_list = list(range(self.n_fg_pixel))
        edgelist = list(zip(adj.i, adj.j))

        # n is given explicitly so that the vertex count always matches the "label" attribute,
        # even if the highest-index vertices have no edge.
        graph = ig.Graph(n=self.n_fg_pixel, vertex_attrs={"label":vertex_list},
                         edges=edgelist, directed=False)
        graph.es['weight'] = adj.v
        return graph

    # NOTE: lru_cache on a method keys the cache on (self, arguments): arguments must be
    # hashable (pass a tuple, not a list) and the cache keeps the instance alive.
    @functools.lru_cache(maxsize=10)
    def find_profile(self, resolution_range=(0.01,0.5)):
        """Return leidenalg's resolution profile (CPM partitions) over `resolution_range`."""
        optimiser = la.Optimiser()
        profile = optimiser.resolution_profile(self.graph, la.CPMVertexPartition,
                                               resolution_range=resolution_range,
                                               weights=self.graph.es['weight'])
        return profile

    @functools.lru_cache(maxsize=10)
    def find_partition(self, resolution):
        """Return the CPM partition of the graph for one resolution parameter."""
        partition = la.find_partition(self.graph, la.CPMVertexPartition,
                                      resolution_parameter = resolution,
                                      weights=self.graph.es['weight'])
        return partition


    def profile_to_list_of_masks(self,profile):
        """Convert every partition of `profile` into a mask (default size threshold)."""
        return [self.partition_to_mask(partition) for partition in profile]

    def partition_to_mask(self, partition, size_threshold=10):
        """Turn a partition of the foreground pixels into an integer label mask.

        Args:
            partition: object exposing `.membership` (community of each vertex) and `.sizes()`.
            size_threshold: communities with fewer vertices than this are set to background.

        Returns:
            numpy array with the shape of the image: 0 for background, 1, 2, ... for communities.
        """
        membership = torch.tensor(partition.membership, dtype=torch.long, device=self.device)
        sizes = torch.tensor(partition.sizes(), dtype=torch.long, device=self.device)

        instance_IDs = membership + 1 # +1 b/c label_bg=0, label_fg=1,2,...
        # small communities are set to bg value
        instance_IDs[sizes[membership] < size_threshold] = 0

        mask = torch.zeros_like(self.index_matrix)
        mask[self.x_coordinate_fg_pixel, self.y_coordinate_fg_pixel] = instance_IDs
        return mask.cpu().numpy()

In [None]:
# import cugraph
# import cudf
# 
# #Create dataframe on GPU
# df = cudf.DataFrame()
# df['v'] = g.adj.v
# df['i'] = g.adj.i
# df['j'] = g.adj.j
# 
# # Crete graph on GPU
# G = cugraph.Graph()
# G.from_cudf_edgelist(df, source='i', destination='j', edge_attr='v')
# 
# # Run Louvain on the graph
# df_partition, mod = cugraph.louvain(G)
# print('Modularity was {}'.format(mod))
# df_partition.head()
# 
# # TAkes the partition
# ainstance_IDs = np.array(df_partition["partition"]) # use torch.tensor so that things stay on GPU

In [None]:
plt.imshow(img_to_segment[0,1000:1250,2200:2450].cpu())

In [None]:
g_small = GraphSegmentation(edges_tiling[:,1000:1250,2200:2450])
g_large = GraphSegmentation(edges_tiling)
plt.imshow(img_to_segment[0].cpu())

In [None]:
# Partition the small crop at three resolutions, then turn each partition into a mask.
partition_015 = g_small.find_partition(resolution=0.015)
partition_02 = g_small.find_partition(resolution=0.02)
partition_01 = g_small.find_partition(resolution=0.01)

# mask_015 is built from partition_015 (the original referenced an undefined `partition_03`)
mask_015 = g_small.partition_to_mask(partition_015, size_threshold=10)
mask_02 = g_small.partition_to_mask(partition_02, size_threshold=10)
mask_01 = g_small.partition_to_mask(partition_01, size_threshold=10)

In [None]:
plot_segmentation(mask_01, img_to_segment[0, 1000:1250,2200:2450].cpu(), figsize=(20,20))

In [None]:
plot_segmentation(mask_02, img_to_segment[0, 1000:1250,2200:2450].cpu(), figsize=(20,20))

In [None]:
plot_segmentation(mask_015, img_to_segment[0, 1000:1250,2200:2450].cpu(), figsize=(20,20))

In [None]:
plot_resolution(labels_list[:16])

In [None]:
plot_grid(labels)

# Check the segmentation results

In [None]:
#seg_mask = vae.segment_with_tiling(train_loader.x[...,2000:2400,2000:2400], 
#                                   crop_w=80, crop_h=80, 
#                                   stride_w=60, stride_h=60, n_objects_max_per_patch=10)

x,y,index = test_loader.load(batch_size=8)
seg_mask = vae.segment(x)

vae.eval()
output_test = vae.forward(x,
                          draw_image=True,
                          draw_boxes=True,
                          verbose=False)

print(x.shape, seg_mask.shape, output_test.imgs.shape)

In [None]:
output_test.inference.prob[...,0]

In [None]:
show_batch(output_test.imgs)

In [None]:
chosen=6
figure, axes = plt.subplots(ncols=3, figsize=(24, 24))
axes[0].imshow(x[chosen,0].cpu(), cmap='gray')
axes[1].imshow(skimage.color.label2rgb(skimage.img_as_ubyte(seg_mask[chosen,0].cpu()), x[chosen,0].cpu(), alpha=0.25, bg_label=0))
axes[2].imshow(output_test.imgs[chosen,0].cpu(), cmap='hot')

# Check the results

In [None]:
train_metrics

In [None]:
# Print the last three recorded values of every metric whose name starts with "train".
for k, v in history_dict.items():
    if not k.startswith("train"):
        continue
    print(k," -->", v[-3:])

In [None]:
#plt.yscale('log')
y_shift=0
x_shift=0
sign=1
plt.plot(np.arange(x_shift, x_shift+len(history_dict["train_loss"])), 
         sign*np.array(history_dict["train_loss"])+y_shift,'-')
plt.plot(np.arange(x_shift, x_shift+len(history_dict["test_loss"])*TEST_FREQUENCY,TEST_FREQUENCY), 
         sign*np.array(history_dict["test_loss"])+y_shift, '.--')
plt.xlabel('epoch')
plt.ylabel('LOSS = - ELBO')
plt.title('Training procedure')
#plt.ylim(ymax=4, ymin=0)
plt.grid(True)
plt.legend(['train', 'test'])
#plt.show()

fig_file = os.path.join(dir_output, "train.png")
plt.savefig(fig_file)

In [None]:
plt.plot(np.arange(0,len(history_dict["train_length_GP"])), history_dict["train_length_GP"], '-', label="train")
plt.plot(np.arange(0,len(history_dict["test_length_GP"])*TEST_FREQUENCY,TEST_FREQUENCY), history_dict["test_length_GP"], 'x', label="test")
plt.title('LENGTH GP')
plt.xlabel('epoch')
plt.ylabel('lenght_GP')
plt.legend()
plt.grid(True)

fig_file = os.path.join(dir_output, "lenght_GP.png")
plt.savefig(fig_file)

In [None]:
fig, ax1 = plt.subplots()

color = 'tab:red'
ax1.set_xlabel('epochs')
ax1.set_ylabel('fg fraction av', color=color)
ax1.plot(np.arange(0, len(history_dict["train_fg_fraction"])),
         history_dict["train_fg_fraction"], 'o', color=color, label="train")
ax1.plot(np.arange(0, len(history_dict["test_fg_fraction"])*TEST_FREQUENCY, TEST_FREQUENCY),
         history_dict["test_fg_fraction"], 'x-', color=color, label="test")

ymin=min(params["GECO"]["target_fg_fraction"])
ymax=max(params["GECO"]["target_fg_fraction"])
ax1.plot(ymin*np.ones(len(history_dict["train_fg_fraction"])), '-', color='black', label="y_min")
ax1.plot(ymax*np.ones(len(history_dict["train_fg_fraction"])), '-', color='black', label="y_max")

ax1.tick_params(axis='y', labelcolor=color)
ax1.grid()
#ax1.set_ylim([1000,1870])
plt.legend()
ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

color = 'tab:blue'
ax2.set_ylabel('accuracy', color=color)  # we already handled the x-label with ax1
ax2.plot(np.arange(0, len(history_dict["train_accuracy"])),
         history_dict["train_accuracy"],'x', color=color)
ax2.plot(np.arange(0, len(history_dict["test_accuracy"])*TEST_FREQUENCY, TEST_FREQUENCY),
         history_dict["test_accuracy"],'-', color=color)
ax2.tick_params(axis='y', labelcolor=color)
ax2.grid()
#ax2.set_ylim([0.97,1.0])

fig.tight_layout()  # otherwise the right y-label is slightly clipped
fig_file = os.path.join(dir_output, "accuracy.png")
plt.savefig(fig_file)

In [None]:
# Six-panel overview of the training metrics, saved to <dir_output>/metrics.png.
f = plt.figure(figsize=(20,20))
ax = f.add_subplot(321)
ax2 = f.add_subplot(322)
ax3 = f.add_subplot(323)
ax4 = f.add_subplot(324)
ax5 = f.add_subplot(325)
ax6 = f.add_subplot(326)
# Only epochs in [epoch_min, epoch_max) are used to set the axis limits
epoch_min, epoch_max = 200, None


loss = np.array(history_dict["train_loss"])
kl_instance = np.array(history_dict["train_kl_instance"])
kl_where = np.array(history_dict["train_kl_where"])
kl_logit = np.array(history_dict["train_kl_logit"])
kl_raw = np.array(history_dict["train_kl_tot"])
nll_raw = np.array(history_dict["train_nll"])
reg_raw = np.array(history_dict["train_reg"])
sparsity_raw = np.array(history_dict["train_sparsity"])
overlap_raw = np.array(history_dict["train_cost_overlap"])
f_geco_sparsity = np.array(history_dict["train_geco_sparsity"])
f_geco_balance = np.array(history_dict["train_geco_balance"])


ax.plot(sparsity_raw,'-',label='sparsity_raw')
ax.plot(overlap_raw,'-',label='cost overlap raw')
ax.set_xlim([epoch_min, epoch_max])
ax.set_ylim([None, 1.01*max(max(sparsity_raw[epoch_min:epoch_max]), max(overlap_raw[epoch_min:epoch_max]))])
ax.grid()
ax.legend()

ax2.plot(kl_instance,'-',label='kl instance raw')
ax2.plot(kl_where,'-',label='kl zwhere raw')
ax2.set_xlim([epoch_min, epoch_max])
ax2.set_ylim([0, 1.01*max(max(kl_instance[epoch_min:epoch_max]),max(kl_where[epoch_min:epoch_max]))])
ax2.grid()
ax2.legend()

ax3.plot(kl_logit,'-',label='kl logit raw')
#ax3.set_ylim([0,1.1])
ax3.set_xlim([epoch_min, epoch_max])
ax3.set_ylim([None, 1.01*max(kl_logit[epoch_min:epoch_max])])
ax3.grid()
ax3.legend()


ax4.plot(f_geco_sparsity ,'x-',label='geco_sparsity')
#ax4.set_ylim([0,100])
ax4.set_xlim([epoch_min, epoch_max])
ax4.set_ylim([None, 1.01*max(f_geco_sparsity[epoch_min:epoch_max])])
ax4.grid()
ax4.legend()

ax5.plot(f_geco_balance ,'x-',label='geco_balance')
ax5.set_xlim([epoch_min, epoch_max])
ax5.set_ylim([None, 1.01*max(f_geco_balance[epoch_min:epoch_max])])
ax5.grid()
ax5.legend()

# Loss together with its individual terms, each scaled by its GECO factor
ax6.plot(loss,'-',label='loss')
ax6.plot(f_geco_sparsity * sparsity_raw,'x-',label='scaled_sparsity')
ax6.plot(f_geco_balance * reg_raw,'x-',label='scaled_reg')
ax6.plot(f_geco_balance * nll_raw,'x-',label='scaled_nll')
ax6.plot((1-f_geco_balance) * kl_raw,'x-',label='scaled_kl')
ax6.set_ylim([0, 1.01*max(loss[epoch_min:epoch_max])])
ax6.set_xlim([epoch_min, epoch_max])
ax6.grid()
ax6.legend()

# Lay out and save the figure created in THIS cell (`f`); the original called
# `fig.tight_layout()`, which acted on a figure left over from a previous cell.
f.tight_layout()
fig_file = os.path.join(dir_output, "metrics.png")
f.savefig(fig_file)

In [None]:
params["GECO"]

In [None]:
# Three stacked panels, each with a twin y-axis, saved to <dir_output>/geco.png:
#   1. fg_fraction with its target range, and geco_sparsity
#   2. nll with its target range, and geco_balance
#   3. delta_1 and delta_2
fontsize=20
labelsize=20
f = plt.figure(figsize=(20,20))
ax1 = f.add_subplot(311)
ax2 = f.add_subplot(312)
ax3 = f.add_subplot(313)
epoch_min, epoch_max = 0, None


#-----------------------------------

color = 'tab:red'
ax1.set_xlabel('epochs', fontsize=fontsize)
ax1.set_ylabel('fg_fraction', fontsize=fontsize, color=color)
ax1.tick_params(axis='both', which='major', labelsize=labelsize)
ax1.plot(history_dict["train_fg_fraction"], '.--', color=color, label="n_object")
ax1.set_xlim([epoch_min, epoch_max])
ymin=min(params["GECO"]['target_fg_fraction'])
ymax=max(params["GECO"]['target_fg_fraction'])
ax1.plot(ymin*np.ones(len(history_dict["train_fg_fraction"])), '-', color='black', label="y_min")
ax1.plot(ymax*np.ones(len(history_dict["train_fg_fraction"])), '-', color='black', label="y_max")
ax1.tick_params(axis='y', labelcolor=color)
ax1.grid()

ax1b = ax1.twinx()  # instantiate a second axes that shares the same x-axis
color = 'tab:green'
ax1b.set_xlabel('epochs', fontsize=fontsize)
ax1b.set_ylabel('geco_sparsity', color=color, fontsize=fontsize)
ax1b.tick_params(axis='both', which='major', labelsize=labelsize)
# draw on the twin axes explicitly instead of relying on pyplot's current axes
ax1b.plot(history_dict["train_geco_sparsity"],'-',label="geco_sparsity",color=color)
ax1b.tick_params(axis='y', labelcolor=color)
ax1b.grid()

##------------------------------------

color = 'tab:red'
ax2.set_xlabel('epochs', fontsize=fontsize)
ax2.set_ylabel('nll av', fontsize=fontsize, color=color)
ax2.tick_params(axis='both', which='major', labelsize=labelsize)
ax2.plot(history_dict["train_nll"], '.--', color=color, label="nll av")
ax2.set_xlim([epoch_min, epoch_max])

ymin=min(params["GECO"]["target_nll"])
ymax=max(params["GECO"]["target_nll"])
ax2.plot(ymin*np.ones(len(history_dict["train_nll"])), '-', color='black', label="y_min")
ax2.plot(ymax*np.ones(len(history_dict["train_nll"])), '-', color='black', label="y_max")
ax2.tick_params(axis='y', labelcolor=color)

ax2.grid()

ax2b = ax2.twinx()  # instantiate a second axes that shares the same x-axis
color = 'tab:green'
ax2b.set_xlabel('epochs', fontsize=fontsize)
ax2b.set_ylabel('geco_balance', fontsize=fontsize, color=color)
ax2b.plot(history_dict["train_geco_balance"],'-',label="geco_balance",color=color)
ax2b.tick_params(axis='both', which='major', labelsize=labelsize)
ax2b.tick_params(axis='y', labelcolor=color)
ax2b.grid()

##------------------------------------

color = 'tab:red'
ax3.set_xlabel('epochs', fontsize=fontsize)
ax3.set_ylabel('delta_1', fontsize=fontsize, color=color)
ax3.tick_params(axis='both', which='major', labelsize=labelsize)
ax3.plot(history_dict["train_delta_1"], '.--', color=color, label="delta_1")
ax3.set_xlim([epoch_min, epoch_max])


ax3b = ax3.twinx()  # instantiate a second axes that shares the same x-axis
color = 'tab:green'
ax3b.set_xlabel('epochs', fontsize=fontsize)
ax3b.set_ylabel('delta_2', fontsize=fontsize, color=color)
ax3b.plot(history_dict["train_delta_2"],'-',label="delta_2",color=color)
ax3b.tick_params(axis='y', labelcolor=color)
ax3b.tick_params(axis='both', which='major', labelsize=labelsize)
ax3b.grid()

#-----------------------------------

# Lay out and save the figure created in THIS cell (`f`); the original called
# `fig.tight_layout()`, which acted on a figure left over from a previous cell.
f.tight_layout()
fig_file = os.path.join(dir_output, "geco.png")
f.savefig(fig_file)

In [None]:
# Plot of KL vs evidence
# Pairwise (and 3-D) trajectories of NLL, KL and sparsity during training; the points are
# coloured by their position inside the selected epoch window.
fontsize=20
labelsize=20

epoch_min, epoch_max = 0, 2500
scale= 1
N = len(history_dict["train_nll"][epoch_min:epoch_max])
colors = np.arange(0.0,N,1.0)/N

f = plt.figure(figsize=(20,10))
ax1 = f.add_subplot(221)
ax2 = f.add_subplot(222)
ax3 = f.add_subplot(223)
ax4 = f.add_subplot(224, projection='3d')

ax1.set_xlabel('NLL',fontsize=fontsize)
ax1.set_ylabel('KL',fontsize=fontsize)
ax1.tick_params(axis='both', which='major', labelsize=labelsize)
ax1.scatter(history_dict["train_nll"][epoch_min:epoch_max], history_dict["train_kl_tot"][epoch_min:epoch_max],c=colors)
ax1.plot(history_dict["train_nll"][epoch_min:epoch_max], history_dict["train_kl_tot"][epoch_min:epoch_max], '--')
ax1.grid()
#ax1.set_xlim(xmax=2.5)

ax2.set_xlabel('SPARSITY',fontsize=fontsize)
ax2.set_ylabel('NLL',fontsize=fontsize)
ax2.tick_params(axis='both', which='major', labelsize=labelsize)
ax2.scatter(history_dict["train_sparsity"][epoch_min:epoch_max], history_dict["train_nll"][epoch_min:epoch_max], c=colors)
ax2.plot(history_dict["train_sparsity"][epoch_min:epoch_max], history_dict["train_nll"][epoch_min:epoch_max], '--')
ax2.grid()
#ax2.set_xlim(xmax=2.5)

ax3.set_xlabel('SPARSITY',fontsize=fontsize)
ax3.set_ylabel('KL',fontsize=fontsize)
ax3.tick_params(axis='both', which='major', labelsize=labelsize)
ax3.scatter(history_dict["train_sparsity"][epoch_min:epoch_max], history_dict["train_kl_tot"][epoch_min:epoch_max], c=colors)
ax3.plot(history_dict["train_sparsity"][epoch_min:epoch_max], history_dict["train_kl_tot"][epoch_min:epoch_max], '--')
ax3.grid()
#ax3.set_xlim(xmax=2.5)


ax4.scatter(history_dict["train_kl_tot"][epoch_min:epoch_max],
         history_dict["train_sparsity"][epoch_min:epoch_max],
         history_dict["train_nll"][epoch_min:epoch_max], c=colors )

ax4.plot(history_dict["train_kl_tot"][epoch_min:epoch_max],
         history_dict["train_sparsity"][epoch_min:epoch_max],
         history_dict["train_nll"][epoch_min:epoch_max], '--', label='training')
ax4.set_xlabel('kl_tot', fontsize=fontsize)
ax4.set_ylabel('sparsity', fontsize=fontsize)
ax4.set_zlabel('nll', fontsize=fontsize)
ax4.legend(prop={'size':fontsize})

# Lay out and save the figure created in THIS cell (`f`); the original called
# `fig.tight_layout()`, which acted on a figure left over from a previous cell.
f.tight_layout()
fig_file = os.path.join(dir_output, "nll_vs_kll_vs_sparsity.png")
f.savefig(fig_file)

# Run one epoch in eval mode

In [None]:
#epoch=100
#load_model_optimizer(path=os.path.join(dir_output, "ckp_"+str(epoch)+".pkl"), model=vae)

vae.eval()
with torch.no_grad():
    test_metrics = process_one_epoch(model=vae, 
                                     dataloader=test_loader)
    print(test_metrics)

# Check the error

In [None]:
ref_img_pkl = os.path.join(dir_output, "reference.pkl")
tmp_list = [0, 1, 2,3,4,5,6,7,8,9]
#tmp_list = [255, 148, 291, 310, 2,3,4,5,6,7,8,9,10]
#tmp_list = [425, 411, 61, 194, 91, 384, 339, 54, 336]

reference_imgs, labels, index =test_loader.load(index=torch.tensor(tmp_list[:9]))
save_obj(reference_imgs, ref_img_pkl)

reference_imgs = load_obj(ref_img_pkl)
b = show_batch(reference_imgs[:],n_col=3,n_padding=4,title="REFERENCE")

ref_img_png = os.path.join(dir_output, "reference.png")
b.savefig(ref_img_png)
display(b)

In [None]:
vae.geco_dict

In [None]:
chosen=0
with torch.no_grad():
    print("")
    print("")
    print("--- eval mode ---")
    vae.eval()
    output_test = vae.forward(reference_imgs[:],
                              draw_image=True,
                              draw_boxes=True,
                              verbose=True)
    
    print("")
    print("")
    print("--- train mode ---")
    vae.train()
    output_train = vae.forward(reference_imgs[:],
                               draw_image=True,
                               draw_boxes=True,
                               verbose=True)

In [None]:
pmap_train = show_batch(output_train.inference.p_map, n_col=3,n_padding=4,title="Train Prob MAP")
pmap_test = show_batch(output_test.inference.p_map, n_col=3,n_padding=4,title="Test Prob MAP")

counts_train = torch.sum(output_train.inference.prob>0.5,dim=0).view(-1).cpu().numpy().tolist()
rec_train = show_batch(output_train.imgs[:],n_col=3,n_padding=4,title="# rec train "+str(counts_train))

counts_test = torch.sum(output_test.inference.prob>0.5,dim=0).view(-1).cpu().numpy().tolist()
rec_test = show_batch(output_test.imgs[:],n_col=3,n_padding=4,title="# rec test "+str(counts_test))

background = show_batch(output_train.inference.big_bg,n_col=3,n_padding=4,title="BACKGROUND")
reference = show_batch(reference_imgs[:],n_col=3,n_padding=4,title="REFERENCE")

background.savefig(os.path.join(dir_output, "background.png"))
reference.savefig(os.path.join(dir_output, "reference.png"))
rec_test.savefig(os.path.join(dir_output, "rec_test.png"))
rec_train.savefig(os.path.join(dir_output, "rec_train.png"))
pmap_test.savefig(os.path.join(dir_output, "pmap_test.png"))
pmap_train.savefig(os.path.join(dir_output, "pmap_train.png"))

In [None]:
display(background, reference)

In [None]:
print(output_train.inference.p_map.sum(dim=(-1,-2,-3)).cpu())
print(output_test.inference.p_map.sum(dim=(-1,-2,-3)).cpu())

In [None]:
display(rec_train,reference)
display(rec_test,reference)

In [None]:
display(pmap_train,reference)
display(pmap_test,reference)

In [None]:
plt.imshow(output_train.inference.p_map[chosen,0].cpu().numpy())
_ = plt.colorbar()
print(torch.topk(output_train.inference.p_map[chosen,0].view(-1), k=10, largest=True, sorted=True)[0])

In [None]:
plt.imshow(output_test.inference.p_map[chosen,0].cpu().numpy())
_ = plt.colorbar()
print(torch.topk(output_test.inference.p_map[chosen,0].view(-1), k=10, largest=True, sorted=True)[0])

In [None]:
_ = plt.hist(output_train.inference.p_map[0,0].view(-1).cpu().numpy(), density=True, bins=50, label="pmap_train")
_ = plt.hist(output_test.inference.p_map[0,0].view(-1).cpu().numpy(), density=True, bins=50, label="pmap_test")
plt.legend()
plt.savefig(os.path.join(dir_output, "hist_pmap.png"))

# Visualize one chosen image in detail

In [None]:
output = output_train
how_many_to_show=20
counts = torch.sum(output.inference.prob>0.5,dim=0).view(-1).cpu().numpy().tolist()
prob_tmp = np.round(output.inference.prob[:how_many_to_show,chosen].view(-1).cpu().numpy(),decimals=4)*10000
prob_title = (prob_tmp.astype(int)/10000).tolist()
print("counts ->",counts[chosen]," prob ->",prob_title)

In [None]:
tmp1 = reference_imgs[chosen]
tmp2 = torch.sum(output.inference.big_img[:how_many_to_show,chosen],dim=0)
tmp3 = torch.sum(output.inference.big_mask[:how_many_to_show,chosen],dim=0)
mask_times_imgs = output.inference.big_mask * output.inference.big_img
tmp4 = torch.sum(mask_times_imgs[:how_many_to_show,chosen],dim=0)
print("sum big_masks", torch.max(tmp3))
print("sum big_masks * big_imgs", torch.max(tmp4))
combined = torch.stack((tmp1,tmp2,tmp3,tmp4),dim=0)
print(combined.shape)
b = show_batch(combined, n_col=2, title="# ref, IMGS, MASKS, IMGS*MASKS", figsize=(24,24))
b.savefig(os.path.join(dir_output, "ref_img_mask.png"))
display(b)

In [None]:
print(torch.min(output.inference.big_mask[:how_many_to_show,chosen]), torch.max(output.inference.big_mask[:how_many_to_show,chosen]))
show_batch(output.inference.big_mask[:how_many_to_show,chosen], n_col=4, title="# MASKS", figsize=(24,24))

In [None]:
b = show_batch(reference_imgs[chosen]+output.inference.big_mask[:how_many_to_show,chosen], 
               n_col=3, n_padding=4,title="# MASKS over REF, p="+str(prob_title), figsize=(24,24))
b.savefig(os.path.join(dir_output, "mask_over_ref.png"))
display(b)

In [None]:
b = show_batch(reference_imgs[chosen]+10*output.inference.big_img[:how_many_to_show,chosen], 
               n_col=4, n_padding=4,title="# IMGS over REF, p="+str(prob_title), figsize=(24,24), normalize_range=(0,1))
b.savefig(os.path.join(dir_output, "imgs_over_ref.png"))
display(b)

In [None]:
output.inference.prob.shape

In [None]:
prob =  output.inference.prob[:,chosen, None, None, None]
b_mask = output.inference.big_mask[:,chosen]
b_img = output.inference.big_img[:,chosen]
b_combined = b_img * b_mask * prob
tmp = torch.cat((b_mask, b_img, b_combined), dim=0)
b = show_batch(tmp, n_col=tmp.shape[0]//3, n_padding=4, title="# mask, imgs, product. p="+str(prob_title), figsize=(24,24))
b.savefig(os.path.join(dir_output, "mask_imgs_product.png"))
display(b)

### Show the probability map

In [None]:
_ = plt.imshow(output.inference.p_map[chosen,0].cpu().numpy())
_ = plt.colorbar()
plt.savefig(os.path.join(dir_output, "pmap_chosen.png"))

# MAKE MOVIE

### Test

In [None]:
epoch="xxx"
a = show_batch(reference_imgs[:9],n_col=3,n_padding=4,title="REFERENCE")
b = show_batch(output.inference.p_map[:9],n_col=3,n_padding=4,title="EPOCH = "+str(epoch))
c = show_batch(output.inference.big_bg[:9],n_col=3,n_padding=4,title="EPOCH = "+str(epoch))
d = show_batch(output.imgs[:9],n_col=3,n_padding=4,title="EPOCH = "+str(epoch))

display(a,b,c,d)

# actual loop

In [None]:
# For every saved checkpoint, reload the model, run it on the reference images and save
# one frame per movie (reconstruction, probability map, background).
for epoch in range(0,50000,5):
    # Zero-padded to 5 digits so that the frame files sort chronologically by name
    label = f"{epoch:05d}"

    ckpt_file = os.path.join(dir_output, "ckp_"+str(epoch)+".pkl")
    if not os.path.isfile(ckpt_file):
        continue  # no checkpoint was saved at this epoch

    try:
        ckpt = load_ckpt(path=ckpt_file, device=None)

        load_model_optimizer(ckpt=ckpt,
                             model=vae,
                             optimizer=None)

        print("epoch, label, prob_cor_factor ->",epoch,label,vae.prob_corr_factor)
        vae.train()
        with torch.no_grad():
            output = vae.forward(reference_imgs,
                                 draw_image=True,
                                 draw_boxes=True,
                                 verbose=False)

        b=show_batch(output.imgs[:8],n_col=4,n_padding=4,title="EPOCH = "+str(epoch))
        b.savefig(os.path.join(dir_output, 'movie_rec_'+label+'.png'), bbox_inches='tight')

        b=show_batch(output.inference.p_map[:8],n_col=4,n_padding=4,title="EPOCH = "+str(epoch), normalize_range=None)
        b.savefig(os.path.join(dir_output, 'movie_map_'+label+'.png'), bbox_inches='tight')

        b=show_batch(output.inference.big_bg[:8],n_col=4,n_padding=4,title="EPOCH = "+str(epoch))
        b.savefig(os.path.join(dir_output, 'movie_bg_'+label+'.png'), bbox_inches='tight')

    except Exception as err:
        # Best effort: a broken checkpoint must not stop the loop, but the failure is
        # reported instead of being silently swallowed (KeyboardInterrupt still propagates).
        print("skipping epoch", epoch, "->", repr(err))

## Make a sorted list of the image files used to create the movie

In [None]:
dir_output

In [None]:
# Gather the frame files of each movie; sorting by name puts the zero-padded
# epoch labels in chronological order.
rec_filenames = sorted(glob.glob(dir_output+"/movie_rec*.png"))
map_filenames = sorted(glob.glob(dir_output+"/movie_map*.png"))
bg_filenames = sorted(glob.glob(dir_output+"/movie_bg*.png"))

print(rec_filenames)
print(map_filenames)
print(bg_filenames)

In [None]:
def show_frame_rec(n):
    """Return the n-th reconstruction frame as an Image built from its file name."""
    frame_path = rec_filenames[n]
    return Image(filename=frame_path)

def show_frame_map(n):
    """Return the n-th probability-map frame as an Image built from its file name."""
    frame_path = map_filenames[n]
    return Image(filename=frame_path)

def show_frame_bg(n):
    """Return the n-th background frame as an Image built from its file name.

    Uses the imported `Image` class like the sibling helpers; the original wrote
    `display.Image`, but `display` is imported as a function and has no such attribute.
    """
    return Image(filename=bg_filenames[n])

def show_frame_all(n):
    """Display the n-th background, probability-map and reconstruction frames together.

    When n is out of range a message is printed and None is returned.
    """
    try:
        frames = [Image(filename=paths[n])
                  for paths in (bg_filenames, map_filenames, rec_filenames)]
        return display(*frames)
    except IndexError:
        print("list index out of range")

In [None]:
# make a gif file
# One gif per movie, assembled from the sorted frame lists.
frame_per_second = 2

movie_rec = os.path.join(dir_output, "movie_rec.gif")
movie_map = os.path.join(dir_output, "movie_map.gif")
movie_bg = os.path.join(dir_output, "movie_bg.gif")

for frame_files, gif_path in ((rec_filenames, movie_rec),
                              (map_filenames, movie_map),
                              (bg_filenames, movie_bg)):
    im = mpy.ImageSequenceClip(frame_files, fps=frame_per_second)
    im.write_gif(gif_path, fps=frame_per_second)

### Show the movies

In [None]:
show_batch(reference_imgs[:8],n_col=4,n_padding=4,title="REFERENCE")

In [None]:
HTML("<img src="+movie_rec+"></img>")

In [None]:
HTML("<img src="+movie_map+"></img>")

In [None]:
show_batch(reference_imgs[:8],n_col=4,n_padding=4,title="REFERENCE")

In [None]:
HTML("<img src="+movie_bg+"></img>")

### Look at a few frames

In [None]:
show_frame_all(0)

In [None]:
show_frame_all(1)

In [None]:
show_batch(reference_imgs[:8],n_col=4,n_padding=4,title="REFERENCE")

In [None]:
show_frame_all(10)

In [None]:
show_frame_all(11)

In [None]:
show_frame_all(15)

In [None]:
show_frame_all(20)

In [None]:
show_frame_all(21)

In [None]:
show_frame_all(22)

In [None]:
show_frame_all(23)

In [None]:
show_frame_all(24)

In [None]:
show_batch(reference_imgs[:8],n_col=4,n_padding=4,title="REFERENCE")

# FINAL CHECK 1

In [None]:
imgs_in_tmp, labels, index = train_loader.load(batch_size=8)
auch = vae.generate(imgs_in=imgs_in_tmp[:1], draw_bounding_box=False)

pmap_gen = show_batch(auch.inference.p_map[:8], title="generated p_map")
imgs_gen = show_batch(auch.imgs[:8], title="generated imgs")
display(pmap_gen, imgs_gen)

In [None]:
big_mask = auch.inference.big_mask[:,0]
big_img = auch.inference.big_img[:,0]
tmp = torch.cat((big_mask, big_img),dim=0)
print(auch.inference.prob)
show_batch(tmp, n_col=tmp.shape[0]//2, title="masks and imgs", figsize=(24,24))

In [None]:
print(auch.inference.big_mask.shape)
fg_mask = torch.sum(auch.inference.big_mask, dim=0)
img = torch.sum(auch.inference.big_img, dim=0)
print(fg_mask.shape)

figure, axes = plt.subplots(ncols=3, figsize=(24, 24))
axes[0].imshow(fg_mask[0,0].cpu(), cmap='gray')
axes[1].imshow(img[0,0].cpu(), cmap='gray')
axes[2].imshow((fg_mask[0,0]*img[0,0]).cpu(), cmap='gray')

# FINAL CHECK

In [None]:
#seg_mask = vae.segment_with_tiling(train_loader.x[...,2000:2400,2000:2400], 
#                                   crop_w=80, crop_h=80, 
#                                   stride_w=60, stride_h=60, n_objects_max_per_patch=10)

# Load a batch of 8 test images and segment them.
# NOTE(review): vae.segment(x) runs before vae.eval() -- confirm that segment does not
# depend on the train/eval mode of the model.
x,y,index = test_loader.load(batch_size=8)
seg_mask = vae.segment(x)

vae.eval()
# NOTE(review): the equivalent cell under "Check the segmentation results" passes
# draw_boxes=True where this one passes draw_bounding_box=True -- confirm which keyword
# vae.forward accepts.
output_test = vae.forward(x,
                          draw_image=True,
                          draw_bounding_box=True,
                          verbose=False)

print(x.shape, seg_mask.shape, output_test.imgs.shape)

In [None]:
show_batch(x, figsize=(24,24))

In [None]:
show_batch(output_test.imgs, figsize=(24,24))

In [None]:
figure, axes = plt.subplots(ncols=3, figsize=(24, 24))
axes[0].imshow(x[0,0].cpu(), cmap='gray')
axes[1].imshow(skimage.color.label2rgb(skimage.img_as_ubyte(seg_mask[0,0].cpu()), x[0,0].cpu(), alpha=0.25, bg_label=0))
axes[2].imshow(output_test.imgs[0,0].cpu(), cmap='hot')

In [None]:
vae.segment_with_tile()