In [1]:
from cellseg_models_pytorch.datasets.folder_dataset_train import SegmentationFolderDataset
import pytorch_lightning as pl
import cellseg_models_pytorch as csmp
from cellseg_models_pytorch.training.lit import SegmentationExperiment
import warnings
from torch.utils.data import DataLoader

In [2]:
warnings.filterwarnings('ignore')

In [3]:
train_dir = './data/benchmarks/NuCLS-split/train/'
val_dir = './data/benchmarks/NuCLS-split/val/'
test_dir = './data/benchmarks/NuCLS-split/test/'

trainset = SegmentationFolderDataset(
    path=train_dir + 'rgb/',
    mask_path=train_dir + 'mask_mat/',
    img_transforms=["blur", "hue_sat"],
    inst_transforms=["cellpose"]
)
valset = SegmentationFolderDataset(
    path=val_dir + 'rgb/',
    mask_path=val_dir + 'mask_mat/',
    img_transforms=["blur", "hue_sat"],
    inst_transforms=["cellpose"]
)
testset = SegmentationFolderDataset(
    path=test_dir + 'rgb/',
    mask_path=test_dir + 'mask_mat/',
    img_transforms=["blur", "hue_sat"],
    inst_transforms=["cellpose"]
)

def od_collate_fn(batch):
    '''Stack images and targets in batches of consistant size and shape for object detection.

    Args:
        batch: List of (image, target) tuples.

    Returns:
        Tuple of stacked images and targets.
    '''
    return tuple(zip(*batch))

trainloader = DataLoader(trainset, batch_size=8, shuffle=True, num_workers=8, collate_fn=od_collate_fn)
valloader = DataLoader(valset, batch_size=8, shuffle=False, num_workers=8, collate_fn=od_collate_fn)

In [4]:
def get_cellpose_model(num_classes, enc_name="tf_efficientnetv2_s", max_epochs=10):
    """Get cellpose model
    enc_name (str): name of encoder. e.g. -> "tf_efficientnetv2_s"
    num_classes (int): number of classes. e.g. -> len(lizard_module.type_classes)
    """
    model = csmp.models.cellpose_base(
        enc_name=enc_name,
        type_classes=num_classes,
    )
    experiment = SegmentationExperiment(
        model=model,
        branch_losses={"cellpose": "ssim_mse", "type": "tversky_focal"},
        branch_metrics={"cellpose": [None], "type": ["miou"]},
        optimizer="adamw",
    )
    trainer = pl.Trainer(
        accelerator="gpu",
        max_epochs=max_epochs,
        move_metrics_to_cpu=True,
    )
    return experiment, trainer

def train_model(experiment, trainer, trainloader, valloader, ckpt_path=None):
    if ckpt_path is None:
        trainer.fit(experiment, trainloader, valloader)
        return None
    trainer.fit(experiment, 
                trainloader,
                valloader,
                ckpt_path=ckpt_path)
    return None

model, trainer = get_cellpose_model(num_classes=15, enc_name="tf_efficientnetv2_s", max_epochs=10)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
train_model(model, trainer, trainloader, valloader)

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type          | Params
------------------------------------------------
0 | model         | CellPoseUnet  | 23.5 M
1 | criterio

Sanity Checking: 0it [00:00, ?it/s]

ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 139, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run
    results = self._run_stage()
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage
    self._run_train()
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1204, in _run_train
    self._run_sanity_check()
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1276, in _run_sanity_check
    val_loop.run()
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1494, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/strategies/ddp_spawn.py", line 288, in validation_step
    return self.model(*args, **kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/pytorch_lightning/overrides/base.py", line 110, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/cellseg_models_pytorch/training/lit/lightning_experiment.py", line 215, in validation_step
    return self.log_step(batch, batch_idx, "val")
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/cellseg_models_pytorch/training/lit/lightning_experiment.py", line 182, in log_step
    res = self.step(batch, batch_idx, phase)
  File "/home/zy45/anaconda3/envs/epathologist/lib/python3.9/site-packages/cellseg_models_pytorch/training/lit/lightning_experiment.py", line 154, in step
    soft_masks = self.model(batch["image"])
TypeError: tuple indices must be integers or slices, not str
