In [1]:
import json
import os

import numpy as np
import torch
from EmbedSeg.train import begin_training
from EmbedSeg.utils.create_dicts import (
    create_configs,
    create_dataset_dict,
    create_loss_dict,
    create_model_dict,
)
from matplotlib.colors import ListedColormap

### Specify the path to `train`, `val` crops and the type of `center` embedding which we would like to train the network for:

The train-val images, masks and center-images will be accessed from the path specified by `data_dir` and `project_name`.
<a id='center'></a>

In [2]:
# Root folder holding all projects, the project to train on, and the run whose
# crops should be used.
data_dir = "/cluster/project/treutlein/DATA/imaging/EmbedSeg_test/data/"
project_name = "3D_Brain_organoids_with_meta"
run_name = "all_03_03_2023"

# Crops for this run live in <data_dir>/<project_name>/crops_<run_name>.
# os.path.join is used instead of string concatenation so that the trailing
# "/" of the root folder does not produce a doubled "//" in the path.
data_dir = os.path.join(data_dir, project_name, f"crops_{run_name}")

# Type of spatial-embedding center to train for.
center = "medoid"  # 'centroid', 'medoid'

print(
    "Project Name chosen as : {}. \nTrain-Val images-masks-center-images will be accessed from : {}".format(
        project_name, data_dir
    )
)

Project Name chosen as : 3D_Brain_organoids_with_meta. 
Train-Val images-masks-center-images will be accessed from : /cluster/project/treutlein/DATA/imaging/EmbedSeg_test/data//3D_Brain_organoids_with_meta/crops_all_03_03_2023


In [3]:
# Validate the chosen center type with an explicit check instead of `assert`:
# assertions are removed when Python runs with -O, which would silently skip
# the validation.
if center not in {"medoid", "centroid"}:
    raise ValueError(
        'Please specify center as one of : {"medoid", "centroid"}'
        + " (got {!r})".format(center)
    )
print("Spatial Embedding Location chosen as : {}".format(center))

Spatial Embedding Location chosen as : medoid


### Obtain properties of the dataset 

Here, we read the `data_properties_{run_name}.json` file prepared previously in the `01-data` notebook.

In [4]:
# Load the dataset properties for this run from `data_properties_<run_name>.json`
# in the current working directory.
properties_path = f"data_properties_{run_name}.json"

# Fail early with a clear message: the values read here are needed by later
# cells, which would otherwise stop with an unrelated NameError.
if not os.path.isfile(properties_path):
    raise FileNotFoundError(
        "Dataset properties file not found: {}".format(properties_path)
    )

with open(properties_path) as json_file:
    data = json.load(json_file)

data_type = data["data_type"]
foreground_weight = float(data["foreground_weight"])
n_z, n_y, n_x = (int(data[key]) for key in ("n_z", "n_y", "n_x"))
pixel_size_z_microns = float(data["pixel_size_z_microns"])
pixel_size_x_microns = float(data["pixel_size_x_microns"])

In [5]:
# Override the `n_x` value loaded from the data-properties JSON above.
# NOTE(review): 600 is hard-coded — confirm it is the intended size for this run.
n_x = 600

In [6]:
# Override the `n_y` value loaded from the data-properties JSON above.
# NOTE(review): 600 is hard-coded — confirm it is the intended size for this run.
n_y = 600

In [7]:
# Override the `n_z` value loaded from the data-properties JSON above.
# NOTE(review): 80 is hard-coded — confirm it is the intended size for this run.
n_z = 80

### Specify training dataset-related parameters

Some hints: 
* The `train_size` attribute indicates the number of image-mask paired examples which the network would see in one complete epoch. Ideally this should be the number of `train` image crops. 

In the cell after this one, a `train_dataset_dict` dictionary is generated from the parameters specified here!

In [8]:
# Use the number of entries in the `train/images` folder as the number of
# examples seen per epoch.
train_images_dir = os.path.join(data_dir, project_name, "train", "images")
train_size = len(os.listdir(train_images_dir))

train_batch_size = 8
# virtual_batch_multiplier = 16

### Create the `train_dataset_dict` dictionary  

In [9]:
# Collect the training-set settings chosen above into one dictionary.
# All arguments are passed by keyword, so their order does not matter.
train_dataset_dict = create_dataset_dict(
    name="3d",
    type="train",
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=train_size,
    batch_size=train_batch_size,
    # virtual_batch_multiplier=virtual_batch_multiplier,
)

`train_dataset_dict` dictionary successfully created with: 
 -- train images accessed from /cluster/project/treutlein/DATA/imaging/EmbedSeg_test/data//3D_Brain_organoids_with_meta/crops_all_03_03_2023/3D_Brain_organoids_with_meta/train/images, 
 -- number of images per epoch equal to 5512, 
 -- batch size set at 8, 


### Specify validation dataset-related parameters

Some hints:
* The size attribute indicates the number of image-mask paired examples which the network would see in one complete epoch. Here, it is recommended to set `val_size` equal to the total number of validation image crops.

In the cell after this one, a `val_dataset_dict` dictionary is generated from the parameters specified here!

In [10]:
# Use the number of entries in the `val/images` folder as the number of
# validation examples seen per epoch.
val_images_dir = os.path.join(data_dir, project_name, "val", "images")
val_size = len(os.listdir(val_images_dir))

val_batch_size = 16

### Create the `val_dataset_dict` dictionary

In [11]:
# Collect the validation-set settings chosen above into one dictionary.
# All arguments are passed by keyword, so their order does not matter.
val_dataset_dict = create_dataset_dict(
    name="3d",
    type="val",
    data_dir=data_dir,
    project_name=project_name,
    center=center,
    size=val_size,
    batch_size=val_batch_size,
)

`val_dataset_dict` dictionary successfully created with: 
 -- val images accessed from /cluster/project/treutlein/DATA/imaging/EmbedSeg_test/data//3D_Brain_organoids_with_meta/crops_all_03_03_2023/3D_Brain_organoids_with_meta/val/images, 
 -- number of images per epoch equal to 338, 
 -- batch size set at 16, 


### Specify model-related parameters

Some hints:
* Set the `input_channels` attribute equal to the number of channels in the input images. 
* Set the `num_classes = [6, 1]` for `3d` training and `num_classes = [4, 1]` for `2d` training
<br>(here, 6 implies the offsets and bandwidths in x, y and z dimensions and 1 implies the `seediness` value per pixel)

In the cell after this one, a `model_dict` dictionary is generated from the parameters specified here!

In [12]:
# Number of channels in the input images.
input_channels = 1
# 3d training: 6 = offsets and bandwidths in x, y and z; 1 = per-pixel `seediness` value.
num_classes = [6, 1]

### Create the `model_dict` dictionary

In [13]:
# Build the model settings dictionary from the parameters chosen above.
model_dict = create_model_dict(
    name="3d",
    num_classes=num_classes,
    input_channels=input_channels,
)

`model_dict` dictionary successfully created with: 
 -- num of classes equal to 1, 
 -- input channels equal to [6, 1], 
 -- name equal to branched_erfnet_3d


### Create the `loss_dict` dictionary

In [14]:
# Build the loss settings; the foreground weight comes from the dataset
# properties loaded earlier in the notebook.
loss_dict = create_loss_dict(
    foreground_weight=foreground_weight,
    n_sigma=3,
)

`loss_dict` dictionary successfully created with: 
 -- foreground weight equal to 48.009, 
 -- w_inst equal to 1, 
 -- w_var equal to 10, 
 -- w_seed equal to 1


### Specify additional parameters 

Some hints:
* The `n_epochs` attribute determines how long the training should proceed. In general, for reasonable results you should train for more than 50 epochs.
* The `save_dir` attribute identifies the location where the checkpoints and loss curve details are saved. 
* If one wishes to **resume training** from a previous checkpoint, they could point `resume_path` attribute appropriately. For example, one could set `resume_path = './experiment/Mouse-Organoid-Cells-CBG-demo/checkpoint.pth'` to resume training from the last checkpoint.


In [15]:
# Set `resume_path` to a checkpoint file to continue an earlier run;
# None starts training from scratch.
resume_path = None
n_epochs = 200

# Checkpoints and loss-curve details go to experiment/<project_name>-<run_name>.
save_dir = os.path.join("experiment", f"{project_name}-{run_name}")

In the cell after this one, a `configs` dictionary is generated from the parameters specified here!
<a id='resume'></a>

### Create the  `configs` dictionary 

In [16]:
# Assemble the run configuration from the settings chosen above.
# All arguments are passed by keyword, so their order does not matter.
configs = create_configs(
    save_dir=save_dir,
    resume_path=resume_path,
    n_epochs=n_epochs,
    n_x=n_x,
    n_y=n_y,
    n_z=n_z,
    # Ratio of the z pixel size to the x pixel size.
    anisotropy_factor=pixel_size_z_microns / pixel_size_x_microns,
    # train_lr=5e-4,
)

`configs` dictionary successfully created with: 
 -- n_epochs equal to 210, 
 -- save_dir equal to experiment/3D_Brain_organoids_with_meta-all_03_03_2023_resumed_v3, 
 -- n_z equal to 80, 
 -- n_y equal to 600, 
 -- n_x equal to 600, 


### Begin training!

Executing the next cell would begin the training. 

In [None]:
# Start training with the dataset, model, loss and run-configuration dictionaries built above.
begin_training(train_dataset_dict, val_dataset_dict, model_dict, loss_dict, configs)

3-D `train` dataloader created! Accessing data from /cluster/project/treutlein/DATA/imaging/EmbedSeg_test/data//3D_Brain_organoids_with_meta/crops_all_03_03_2023/3D_Brain_organoids_with_meta/train/
Number of images in `train` directory is 5512
Number of instances in `train` directory is 5512
Number of center images in `train` directory is 5512
*************************
3-D `val` dataloader created! Accessing data from /cluster/project/treutlein/DATA/imaging/EmbedSeg_test/data//3D_Brain_organoids_with_meta/crops_all_03_03_2023/3D_Brain_organoids_with_meta/val/
Number of images in `val` directory is 338
Number of instances in `val` directory is 338
Number of center images in `val` directory is 338
*************************
Creating Branched Erfnet 3D with [6, 1] outputs
initialize last layer with size:  torch.Size([16, 6, 2, 2, 2])
Created spatial emb loss function with: n_sigma: 3, foreground_weight: 48.00866531190462
*************************
Created logger with keys:  ('train', 'val',

100%|██████████| 689/689 [11:43<00:00,  1.02s/it]
100%|██████████| 22/22 [00:18<00:00,  1.19it/s]


===> train loss: 0.96
===> val loss: 0.85, val iou: 0.66
=> saving checkpoint
Starting epoch 123
learning rate: 0.00022622467159583027


100%|██████████| 689/689 [11:57<00:00,  1.04s/it]
100%|██████████| 22/22 [00:17<00:00,  1.29it/s]


===> train loss: 0.87
===> val loss: 0.84, val iou: 0.67
=> saving checkpoint
Starting epoch 124
learning rate: 0.0002238830656952329


100%|██████████| 689/689 [11:53<00:00,  1.04s/it]
100%|██████████| 22/22 [00:16<00:00,  1.30it/s]


===> train loss: 0.83
===> val loss: 0.84, val iou: 0.67
=> saving checkpoint
Starting epoch 125
learning rate: 0.00022153873534893717


 21%|██        | 142/689 [02:23<08:52,  1.03it/s]

<div class="alert alert-block alert-warning"> 
  Common causes of errors during training include: <br>
    1. Not having <b>center images</b> for  <b>both</b> train and val directories  <br>
    2. <b>Mismatch</b> between type of center-images saved in <b>01-data.ipynb</b> and the type of center chosen in this notebook (see the <b><a href="#center"> center</a></b> parameter in the third code cell in this notebook)   <br>
    3. In case of resuming training from a previous checkpoint, please ensure that the model weights are read from the correct directory, using the <b><a href="#resume"> resume_path</a></b> parameter. Additionally, please ensure that the <b>save_dir</b> parameter for saving the model weights points to a relevant directory. 
</div>

In [None]:
# Checkpoint to evaluate, located in the same experiment/<project_name>-<run_name>
# folder that was used as `save_dir` for training.
# NOTE(review): this must run while `project_name` still holds the training
# project; a later cell rebinds `project_name` for evaluation.
checkpoint_path = os.path.join(
    "experiment", project_name + "-" + run_name, "best_iou_model.pth"
)

# Load the dataset properties of this run. The file name is built once and
# used for both the existence check and the open() call, so the two cannot
# disagree.
properties_path = f"data_properties_{run_name}.json"
if not os.path.isfile(properties_path):
    raise FileNotFoundError(
        "Dataset properties file not found: {}".format(properties_path)
    )
with open(properties_path) as json_file:
    data = json.load(json_file)

one_hot = data["one_hot"]
data_type = data["data_type"]
min_object_size = int(data["min_object_size"])
# foreground_weight and n_z / n_y / n_x are intentionally not re-read here, so
# the values set earlier in the notebook stay in effect.
pixel_size_z_microns = float(data["pixel_size_z_microns"])
pixel_size_y_microns = float(data["pixel_size_y_microns"])
pixel_size_x_microns = float(data["pixel_size_x_microns"])

# Load the normalization setting of this run; `norm` is required by the
# evaluation cells below.
normalization_path = f"normalization_{run_name}.json"
if not os.path.isfile(normalization_path):
    raise FileNotFoundError(
        "Normalization file not found: {}".format(normalization_path)
    )
with open(normalization_path) as json_file:
    data = json.load(json_file)
norm = data["norm"]

In [None]:
# Report whether the checkpoint is present before starting evaluation.
if not os.path.exists(checkpoint_path):
    print("Trained model weights were not found at the specified location!")
else:
    print("Trained model weights found at : {}".format(checkpoint_path))

In [None]:
# Evaluation settings passed to `create_test_configs_dict` in the cell below.
tta = True
ap_val = 0.5

# NOTE: `data_dir` and `project_name` are rebound here for evaluation and no
# longer hold the training values set at the top of the notebook.
data_dir = "/cluster/project/treutlein/DATA/imaging/EmbedSeg_test/data/"
project_name = "3D_Brain_organoids"

# os.path.join avoids the doubled "//" that string concatenation produced
# after the trailing "/" of `data_dir`.
save_dir = os.path.join(data_dir, project_name, f"3D_one_image_per_day_{run_name}")

In [None]:
# Display the output directory so it can be checked before evaluation starts.
save_dir

In [None]:
from EmbedSeg.test import begin_evaluating, test_3d
from EmbedSeg.utils.create_dicts import create_test_configs_dict

# Evaluate the `3D_one_image_per_day` data with the checkpoint selected above.
# All arguments are passed by keyword, so their order does not matter.
test_configs = create_test_configs_dict(
    # What to evaluate and where to write the results.
    name="3d",
    type="3D_one_image_per_day",
    data_dir=os.path.join(data_dir, project_name),
    save_dir=save_dir,
    checkpoint_path=checkpoint_path,
    # Dataset properties loaded earlier in the notebook.
    data_type=data_type,
    norm=norm,
    min_object_size=min_object_size,
    n_x=n_x,
    n_y=n_y,
    n_z=n_z,
    anisotropy_factor=pixel_size_z_microns / pixel_size_x_microns,
    # Evaluation settings.
    tta=tta,
    ap_val=ap_val,
    seed_thresh=0.7,
    fg_thresh=0.4,
    expand_grid=False,
)
begin_evaluating(test_configs)

In [None]:
# Evaluation settings for the `3D_one_image_per_day_AGAR` run in the cell below.
tta = True
ap_val = 0.5
data_dir = "/cluster/project/treutlein/DATA/imaging/EmbedSeg_test/data/"

# `project_name` is reused from the previous evaluation cells.
# os.path.join avoids the doubled "//" that string concatenation produced
# after the trailing "/" of `data_dir`.
save_dir = os.path.join(
    data_dir, project_name, f"3D_one_image_per_day_AGAR_{run_name}"
)

In [None]:
from EmbedSeg.test import begin_evaluating, test_3d
from EmbedSeg.utils.create_dicts import create_test_configs_dict

# Evaluate the `3D_one_image_per_day_AGAR` data with the same checkpoint.
# All arguments are passed by keyword, so their order does not matter.
test_configs = create_test_configs_dict(
    # What to evaluate and where to write the results.
    name="3d",
    type="3D_one_image_per_day_AGAR",
    data_dir=os.path.join(data_dir, project_name),
    save_dir=save_dir,
    checkpoint_path=checkpoint_path,
    # Dataset properties loaded earlier in the notebook.
    data_type=data_type,
    norm=norm,
    min_object_size=min_object_size,
    n_x=n_x,
    n_y=n_y,
    n_z=n_z,
    anisotropy_factor=pixel_size_z_microns / pixel_size_x_microns,
    # Evaluation settings.
    tta=tta,
    ap_val=ap_val,
    seed_thresh=0.7,
    fg_thresh=0.4,
    expand_grid=False,
)
begin_evaluating(test_configs)