# Training a new model
This Jupyter notebook shows you how to train a new model. The only requirement is that you have already sampled your ground truth masks using the "format_dataset" notebook tutorial.

## Downloading the package

Make sure the notebook is running Python >= 3.10 with PyTorch >= 1.13 installed and CUDA available (CUDA is necessary for training).

To check that PyTorch is installed and CUDA is available, run the following cell.

In [None]:
import torch
print(f"Current version of Pytorch: {torch.__version__}")
print(f"Cuda working properly: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"

Current version of Pytorch: 2.5.1+cu124
Cuda working properly: True
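If you prefer a check that fails loudly instead of reading the printed values, the version requirements stated above can be tested programmatically. The sketch below is our own helper, not part of nagini3D. It compares only the major and minor version numbers (so "1.9" is correctly treated as older than "1.13") and reports a missing PyTorch installation as a problem to fix.

```python
import re
import sys


def parse_major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from a version string such as '2.5.1+cu124'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def check_requirements(torch_version: str,
                       python_version=tuple(sys.version_info[:2])) -> list[str]:
    """Return human-readable problems (an empty list means the requirements are met)."""
    problems = []
    if tuple(python_version) < (3, 10):
        problems.append(f"Python >= 3.10 required, found {python_version[0]}.{python_version[1]}")
    if parse_major_minor(torch_version) < (1, 13):
        problems.append(f"PyTorch >= 1.13 required, found {torch_version}")
    return problems


if __name__ == "__main__":
    try:
        import torch
        issues = check_requirements(torch.__version__)
        if not torch.cuda.is_available():
            issues.append("CUDA is not available (it is necessary for training)")
    except ImportError:
        issues = ["PyTorch is not installed"]
    print(issues or "Environment OK")
```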


If you have a compatible version of PyTorch and CUDA is working, run the following cell to install our package. Otherwise, fix your Python environment before proceeding.

In [None]:
!pip install nagini3D



In [None]:
from nagini3D.data.training_set import TrainingSet, custom_collate
from nagini3D.data.th_optim_set import OptimSet
from nagini3D.models.model import Nagini3D

## Configuration setting

Your training and validation datasets should be organized as follows:


```
-set name
 |-images (dir storing the images)
 |-masks  (dir storing the corresponding masks)
 |-samplings (dir storing the output of the format_dataset code: samplings and spots maps)
 ```
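Before pointing the notebook at a dataset, it can help to check this layout. The sketch below assumes, as a hypothetical convention not confirmed by this notebook, that each image in `images/` has a mask with the same file-name stem in `masks/`; adapt the pairing rule to your own naming scheme. The helper name `check_dataset_layout` is ours, not part of nagini3D.

```python
from pathlib import Path

REQUIRED_SUBDIRS = ("images", "masks", "samplings")


def check_dataset_layout(set_dir) -> list[str]:
    """Return a list of layout problems found in `set_dir` (empty if none).

    Hypothetical convention: an image and its mask share the same file-name stem.
    """
    root = Path(set_dir)
    problems = [f"missing directory: {root / name}"
                for name in REQUIRED_SUBDIRS if not (root / name).is_dir()]
    if problems:
        return problems

    # Pair images and masks by file-name stem (assumed convention)
    image_stems = {p.stem for p in (root / "images").iterdir() if p.is_file()}
    mask_stems = {p.stem for p in (root / "masks").iterdir() if p.is_file()}
    problems += [f"no mask for image: {stem}" for stem in sorted(image_stems - mask_stems)]
    problems += [f"no image for mask: {stem}" for stem in sorted(mask_stems - image_stems)]

    # The samplings are produced by the format_dataset notebook
    if not any((root / "samplings").iterdir()):
        problems.append("samplings/ is empty: run the format_dataset notebook first")
    return problems
```

Once `train_set_dir` and `val_set_dir` are set in the configuration cell below, `print(check_dataset_layout(train_set_dir) or "Layout OK")` reports any mismatch before training starts.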

In [None]:
P = 101                  # number of points on each predicted surface (should be odd to match the Fibonacci lattice requirements, and <= the number of points sampled on the GT masks)
M1 = 4                   # complexity parameter (see article for details)
M2 = 2                   # complexity parameter (see article for details)
batch_size = 1           # training is very memory-intensive: keep the batch small to avoid CUDA out-of-memory errors, especially for large values of P (>>100)
nb_epochs = 1000         # number of training epochs

patch_size = 64          # size of the patches cropped from your images and fed to the network
train_set_dir = ""       # directory storing your training dataset (organized as described above)
val_set_dir = ""         # directory storing your validation dataset (organized as described above)
data_aug = True          # if True, apply random flips along the x, y and z axes
cell_ratio_th = 0.02     # if > 0.0, patches whose proportion of object voxels is below this value are discarded
anisotropy = [1,1,1]     # voxel anisotropy ratio, one value per axis ([1,1,1] for isotropic data)
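The comment on `P` mentions Fibonacci lattice requirements. A widely used spherical Fibonacci lattice places 2N+1 points at latitudes arcsin(2i/(2N+1)) and longitudes 2πi/φ for i = -N, ..., N, where φ is the golden ratio, which is why the point count is odd. The sketch below builds that lattice with NumPy. Whether nagini3D uses exactly this construction is an assumption, so treat it as an illustration of the odd-count constraint rather than the package's own sampling code.

```python
import numpy as np


def fibonacci_sphere(nb_points: int) -> np.ndarray:
    """Return `nb_points` unit vectors on a spherical Fibonacci lattice (shape: nb_points x 3).

    The construction indexes points by i = -N, ..., N, so it needs an odd count 2N + 1.
    """
    if nb_points < 1 or nb_points % 2 == 0:
        raise ValueError(f"expected a positive odd number of points, got {nb_points}")
    n = (nb_points - 1) // 2
    i = np.arange(-n, n + 1)
    golden_ratio = (1 + np.sqrt(5)) / 2
    latitude = np.arcsin(2 * i / nb_points)     # evenly spaced in sin(latitude)
    longitude = 2 * np.pi * i / golden_ratio    # rotate by the golden angle at each step
    return np.stack(
        [np.cos(latitude) * np.cos(longitude),
         np.cos(latitude) * np.sin(longitude),
         np.sin(latitude)],
        axis=1,
    )


points = fibonacci_sphere(101)   # same value as P above
print(points.shape)              # (101, 3)
```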

In [None]:
settings_cfg = {
    "experiment_name": "jupyter-nagini",
    "M1": M1,
    "M2": M2,
    "nb_points_on_surface": P,
    "nb_epochs": nb_epochs
}

data_cfg = {
    "patch_size": patch_size,
    "data_aug": data_aug,
    "train": train_set_dir,
    "val": val_set_dir
}

In [None]:
from torch.utils.data import DataLoader

In [None]:
train_set = TrainingSet(nb_points=P, patch_size=patch_size, dataset_dir=train_set_dir, data_aug=data_aug, cell_ratio_th=cell_ratio_th, anisotropy_ratio=anisotropy)
val_set = TrainingSet(nb_points=P, r_mean=train_set.r_mean,  patch_size=patch_size, data_aug=data_aug, dataset_dir=val_set_dir, cell_ratio_th=cell_ratio_th, anisotropy_ratio=anisotropy)

train_loader = DataLoader(train_set, collate_fn=custom_collate, batch_size=batch_size)
val_loader = DataLoader(val_set, collate_fn=custom_collate, batch_size=batch_size)
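`TrainingSet` receives `patch_size`, `data_aug` and `cell_ratio_th`, but the notebook does not show how it uses them. The sketch below illustrates one plausible reading of the comments in the configuration cell: crop a random cubic patch, reject it if the fraction of object voxels (`mask > 0`) is below `cell_ratio_th`, and apply the same random flips to image and mask. The function `sample_patch` and its retry limit `max_tries` are hypothetical and are not nagini3D API.

```python
import numpy as np


def sample_patch(image, mask, patch_size, cell_ratio_th=0.0, data_aug=False,
                 rng=None, max_tries=100):
    """Hypothetical patch sampler: random crop, object-ratio filter, random flips.

    Returns an (image_patch, mask_patch) pair, or None if no patch passed the
    filter within `max_tries` attempts.
    """
    rng = np.random.default_rng() if rng is None else rng
    for _ in range(max_tries):
        # Random corner such that a cube of side patch_size fits in the volume
        corner = [rng.integers(0, size - patch_size + 1) for size in image.shape]
        window = tuple(slice(c, c + patch_size) for c in corner)
        img_patch, mask_patch = image[window], mask[window]

        # Discard patches that contain too few object voxels
        if cell_ratio_th > 0.0 and (mask_patch > 0).mean() < cell_ratio_th:
            continue

        # Flip image and mask together so that they stay aligned
        if data_aug:
            for axis in range(image.ndim):
                if rng.random() < 0.5:
                    img_patch = np.flip(img_patch, axis=axis)
                    mask_patch = np.flip(mask_patch, axis=axis)
        return img_patch, mask_patch
    return None
```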

In [None]:
model_cfg = {
    "input_channels": 1,                      # number of input channels of your image
    "features_start": 32,                     # number of feature maps in the first layer of the U-Net
    "num_layers": 3,                          # number of layers of the U-Net
    "inner_normalisation": "BatchNorm",
    "train_bn": True,
    "padding_mode": "zeros",
    "bilinear": False
}

save_path = ""        # directory where the model weights, the configuration file and the optimal thresholds file are stored at the end of training

model = Nagini3D(unet_cfg=model_cfg, P=P, M1=M1, M2=M2, device=device, save_path=save_path)
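A common U-Net convention, following the original U-Net design, doubles the number of feature maps at each level and halves the spatial size `num_layers - 1` times. Assuming nagini3D's U-Net follows that convention (this notebook does not state it), the sketch below lists the width of each level and checks that `patch_size` is divisible by the total downsampling factor, so that every pooling step divides the patch evenly. Both helper names are hypothetical.

```python
def unet_feature_widths(features_start: int, num_layers: int) -> list[int]:
    """Feature maps per level, assuming the width doubles at each downsampling."""
    return [features_start * 2 ** level for level in range(num_layers)]


def patch_size_is_compatible(patch_size: int, num_layers: int) -> bool:
    """True if patch_size survives num_layers - 1 halvings without remainder."""
    return patch_size % (2 ** (num_layers - 1)) == 0


print(unet_feature_widths(32, 3))        # [32, 64, 128]
print(patch_size_is_compatible(64, 3))   # True
```

Doubling the width as resolution halves keeps the per-level cost roughly balanced, which is why `features_start` has a large effect on memory use.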

In [None]:
optim_cfg = {
    "name": "adam",
    "parameters": {
        "lr": 0.0001,
        "weight_decay": 0.0001
    }
}

# You can also use the SGD optimizer by replacing "adam" with "sgd" and specifying the corresponding parameters

model.init_optimizer(optimizer_cfg=optim_cfg)
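As the comment above says, the optimizer can be switched to SGD. A hypothetical SGD variant of `optim_cfg` might look like the following. The values are illustrative rather than tuned, and the sketch assumes that the entries under `"parameters"` are forwarded as keyword arguments to `torch.optim.SGD` (whose constructor does accept `lr`, `momentum` and `weight_decay`).

```python
# Hypothetical SGD configuration (illustrative values, not recommendations).
# Assumption: the "parameters" entries are passed to torch.optim.SGD as keyword arguments.
optim_cfg_sgd = {
    "name": "sgd",
    "parameters": {
        "lr": 0.01,             # SGD often needs a larger learning rate than Adam
        "momentum": 0.9,
        "weight_decay": 0.0001,
    },
}

# model.init_optimizer(optimizer_cfg=optim_cfg_sgd)
```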

In [None]:
loss_cfg = {
  "epoch_reg_max": 200,  # int: number of epochs during which the regularization preventing surface twists is applied
  "reg_part": 0.0,       # float in [0,1]: weight of the regularization preventing surface twists; keep it at zero for small values of M1 and M2 (typically 4 and 2) and increase it if you increase them
  "lambdas":  {
      "snakes": 1.0,             # float: weight of the surface loss
      "spots": 1.0,              # float: weight of the probability loss used for center detection
      "regularization" : 0.001   # float: weight of the regularization loss
  }
}

model.init_loss(loss_lambdas=loss_cfg["lambdas"], reg_part=loss_cfg["reg_part"], epoch_reg_max=loss_cfg["epoch_reg_max"])
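The `lambdas` dictionary weights the individual loss terms. Assuming they are combined as a plain weighted sum (the notebook does not spell this out), the total loss would be computed as in the sketch below. The helper `weighted_total_loss` is hypothetical, and the term values in the usage lines are made up for illustration.

```python
def weighted_total_loss(loss_terms: dict[str, float], lambdas: dict[str, float]) -> float:
    """Weighted sum of loss terms; every term must have a weight in `lambdas`."""
    missing = set(loss_terms) - set(lambdas)
    if missing:
        raise KeyError(f"no weight given for loss terms: {sorted(missing)}")
    return sum(lambdas[name] * value for name, value in loss_terms.items())


example_terms = {"snakes": 0.5, "spots": 0.25, "regularization": 2.0}   # made-up values
example_lambdas = {"snakes": 1.0, "spots": 1.0, "regularization": 0.001}
print(weighted_total_loss(example_terms, example_lambdas))
```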

In [None]:
save_cfg = {"path": save_path}
settings_cfg["r_mean"] = train_set.r_mean
cfg = {"settings": settings_cfg, "optimizer": optim_cfg, "model": model_cfg, "data": data_cfg, "loss": loss_cfg, "save": save_cfg}

model.save_config(cfg_dict=cfg)

## Training

In [None]:
best_val_score = float("inf")
epoch_best_val = -1

for epoch_idx in range(nb_epochs):
    print(f"Epoch {epoch_idx+1} / {nb_epochs}\nTraining ...")

    # If you enabled the twist regularization ("reg_part" > 0), this slowly decreases it until it reaches zero at "epoch_reg_max"
    snake_ratio = model.update_snake_loss(epoch_idx)

    loss = model.epoch(data_loader=train_loader)

    print(f"\nTraining loss: {loss['loss']}\nValidating ...")

    validation, _ = model.val(data_loader=val_loader, nb_cells_to_plot=4)

    print(f"\nLoss on validation set: {validation['loss']}")

    # Keep the weights that reach the lowest validation loss
    if validation["loss"] < best_val_score:
        model.save_model("best")
        best_val_score = validation["loss"]
        epoch_best_val = epoch_idx

cfg["save"] = {**cfg.get("save", {}), "best_epoch": epoch_best_val}
model.save_config(cfg_dict=cfg)
model.save_model("final")
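The comment on `update_snake_loss` says the twist regularization is slowly decreased until it reaches zero at `epoch_reg_max`. The exact schedule is not given in this notebook; assuming a linear decay, the share applied at each epoch would follow the sketch below. The helper `regularization_share` is hypothetical.

```python
def regularization_share(epoch_idx: int, reg_part: float, epoch_reg_max: int) -> float:
    """Linearly decay reg_part from its initial value to 0 at epoch_reg_max (assumed schedule)."""
    if epoch_reg_max <= 0:
        return 0.0
    remaining = max(0.0, 1.0 - epoch_idx / epoch_reg_max)
    return reg_part * remaining


for epoch in (0, 50, 100, 200, 300):
    print(epoch, regularization_share(epoch, reg_part=0.5, epoch_reg_max=200))
```

With the default `"reg_part": 0.0` from `loss_cfg`, the share is zero at every epoch, so the schedule only matters once you raise `reg_part` for larger M1 and M2.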

## Computing optimal inference thresholds on the validation set

In [None]:
from os.path import join

model.load_weights(join(model.save_dir, "best.pkl"))
optim_set = OptimSet(cfg["data"]["val"])
opti_th = model.optimize_thresholds(optim_set, r_mean=train_set.r_mean, nb_tiles=(1,1,1))
model.save_th(th_dict=opti_th)
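The notebook does not describe how `optimize_thresholds` searches for its thresholds. One simple way to tune post-processing thresholds on a validation set is an exhaustive grid search that keeps the best-scoring combination. The sketch below assumes two hypothetical thresholds, a detection probability threshold and a non-maximum-suppression overlap threshold, plus a user-supplied `score_fn` (for example, a matching F1 score on the validation set); none of these names come from nagini3D.

```python
import itertools
from typing import Callable, Iterable


def grid_search_thresholds(score_fn: Callable[[float, float], float],
                           prob_grid: Iterable[float],
                           nms_grid: Iterable[float]) -> dict:
    """Try every (prob_th, nms_th) pair and return the one with the highest score."""
    best = None
    for prob_th, nms_th in itertools.product(prob_grid, nms_grid):
        score = score_fn(prob_th, nms_th)
        if best is None or score > best["score"]:
            best = {"prob": prob_th, "nms": nms_th, "score": score}
    return best


def toy_score(prob_th: float, nms_th: float) -> float:
    """Stand-in for a real validation metric, peaking at prob = 0.5, nms = 0.3."""
    return -((prob_th - 0.5) ** 2) - (nms_th - 0.3) ** 2


best_thresholds = grid_search_thresholds(toy_score, [0.3, 0.5, 0.7], [0.1, 0.3, 0.5])
print(best_thresholds)
```

A grid search costs one full validation pass per combination, so coarse grids are usually refined around the best point rather than made dense from the start.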