# Please read [this README](https://github.com/stardist/stardist/tree/conic-2022/examples/conic-2022/README.md) first

This notebook demonstrates how we trained models for the 2022 [*Colon Nuclei Identification and Counting (CoNIC)* challenge](https://conic-challenge.grand-challenge.org).

Please see [our paper](https://arxiv.org/abs/2203.02284) for more details.

In [1]:
import os
from csbdeep.utils.tf import limit_gpu_memory

# you may need to adjust this to your GPU needs and memory capacity

# Alternative: pick a specific GPU and cap TensorFlow at a fixed fraction of its memory, e.g.
# os.environ['CUDA_VISIBLE_DEVICES'] = ...
# limit_gpu_memory(0.8, total_memory=24000)

# Default used here: no fixed fraction; allow_growth presumably lets TensorFlow
# allocate GPU memory on demand instead of reserving it all up front.
limit_gpu_memory(None, allow_growth=True)

In [2]:
import numpy as np
from types import SimpleNamespace
from sklearn.model_selection import train_test_split

from stardist import gputools_available
from stardist.models import Config2D, StarDist2D

from conic import get_data, oversample_classes, CLASS_NAMES

from conic import HEStaining, HueBrightnessSaturation
from augmend import (
    Augmend,
    AdditiveNoise,
    Augmend,
    Elastic,
    FlipRot90,
    GaussianBlur,
    Identity,
)

In [3]:
def get_class_count(Y0):
    """Count pixels per class in ``Y0`` and, if pandas is installed, display a table.

    Only every 4th pixel along both spatial axes is counted (16x fewer pixels),
    so the counts are a sub-sampled estimate of the class balance.

    Parameters
    ----------
    Y0 : numpy.ndarray
        Label stack indexed as ``Y0[n, y, x, c]`` whose channel 1 holds
        non-negative integer class ids (assumed to be the per-pixel class map
        returned by ``conic.get_data`` -- TODO confirm).

    Returns
    -------
    numpy.ndarray
        ``class_count[k]`` is the number of counted pixels with class id ``k``.
        It has at least ``len(CLASS_NAMES)`` entries, one per class name
        (including the 'BACKGROUND' entry).
    """
    # minlength: classes that do not occur at all (in particular the highest ids)
    # still get a zero entry. Without it the array would be shorter than
    # CLASS_NAMES, the DataFrame below would fail on a length mismatch, and
    # class weights derived from the result would not line up with the classes.
    class_count = np.bincount(Y0[:, ::4, ::4, 1].ravel(), minlength=len(CLASS_NAMES))
    try:
        import pandas as pd
        df = pd.DataFrame(class_count, index=CLASS_NAMES.values(), columns=["counts"])
        df = df.drop("BACKGROUND")
        # percentages are relative to foreground pixels only (background dropped above)
        df["%"] = (100 * (df["counts"] / df["counts"].sum())).round(2)
        display(df)  # IPython/Jupyter built-in rich display
    except ModuleNotFoundError:
        print("install 'pandas' to show class counts")
    return class_count

## Configuration

In [4]:
# All experiment settings live in one namespace so they can be displayed and overridden below.
args = SimpleNamespace(
    # --- data in ---
    datadir     = "./data",   # path to 'Patch-level Lizard Dataset' as provided by CoNIC organizers
    oversample  = True,       # oversample training patches with rare classes
    frac_val    = 0.1,        # fraction of data used for validation during training
    seed        = None,       # set to an int for reproducible train/val data sets (None: random split)
    # --- model out (parameters as used for our challenge submissions) ---
    modeldir    = "./models",
    epochs      = 1000,
    batchsize   = 4,
    n_depth     = 4,
    lr          = 3e-4,
    patch       = 256,
    n_rays      = 64,
    grid        = (1, 1),
    head_blocks = 2,
    augment     = True,
    cls_weights = False,
    workers     = 1,
)

# GPU data generation is deliberately disabled: the leading `False and` short-circuits,
# so gputools_available() is never evaluated. Remove `False and` to enable it.
# note: ignore potential scikit-tensor error
args.gpu_datagen = False and args.workers == 1 and gputools_available()

vars(args)

{'datadir': './data',
 'oversample': True,
 'frac_val': 0.1,
 'seed': None,
 'modeldir': './models',
 'epochs': 1000,
 'batchsize': 4,
 'n_depth': 4,
 'lr': 0.0003,
 'patch': 256,
 'n_rays': 64,
 'grid': (1, 1),
 'head_blocks': 2,
 'augment': True,
 'cls_weights': False,
 'workers': 1,
 'gpu_datagen': False}

In [6]:
# FOR DEMO PURPOSES ONLY: shrink the model and shorten training so the notebook finishes quickly.
# Skip this cell to keep the challenge-submission settings from the configuration cell.
vars(args).update(
    epochs  = 20,
    n_depth = 3,
    n_rays  = 32,
    grid    = (2, 2),
)

## Data

We only use the [Patch-level Lizard Dataset](https://drive.google.com/drive/folders/1il9jG7uA4-ebQ_lNmXbbF2eOK9uNwheb) provided by the [CoNIC challenge](https://conic-challenge.grand-challenge.org) organizers.

In [7]:
# list the dataset files with their sizes (IPython expands `$args.datadir` to the Python value)
%ls -sh1 $args.datadir

total 2.2G
 20K by-nc-sa.md
 72K counts.csv
4.0K dl.txt
934M images.npy
1.3G labels.npy
 68K patch_info.csv
4.0K README.txt


In [8]:
# Load the full patch-level dataset, then hold out `args.frac_val` of the patches for validation.
# D and Y0 are split alongside the images/labels (presumably per-patch class info and raw
# label maps -- see conic.get_data); idx tracks which original patches ended up where.
X, Y, D, Y0, idx = get_data(args.datadir, seed=args.seed)
(X,   Xv,
 Y,   Yv,
 D,   Dv,
 Y0,  Y0v,
 idx, idxv) = train_test_split(
    X, Y, D, Y0, idx,
    test_size=args.frac_val,
    random_state=args.seed,
)
class_count = get_class_count(Y0)

  0%|          | 0/4981 [00:00<?, ?it/s]

Unnamed: 0,counts,%
Neutrophil,22572,0.77
Epithelial,1895959,64.44
Lymphocyte,340118,11.56
Plasma,101651,3.45
Eosinophil,18598,0.63
Connective,563428,19.15


In [9]:
# Oversample training patches that contain rare classes (the validation split is not touched),
# then recount classes so that `class_count` reflects the new training balance.
if args.oversample:
    X, Y, D, Y0, idx = oversample_classes(X, Y, D, Y0, idx, seed=args.seed)
    class_count = get_class_count(Y0)

oversample classes [5 1 4 3]
adding 4482 images of class 5 (Eosinophil)
adding 4068 images of class 1 (Neutrophil)
adding 1917 images of class 4 (Plasma)
adding 1048 images of class 3 (Lymphocyte)


Unnamed: 0,counts,%
Neutrophil,547807,4.4
Epithelial,5480641,44.02
Lymphocyte,2459795,19.76
Plasma,730444,5.87
Eosinophil,311316,2.5
Connective,2920271,23.46


In [10]:
# Per-class loss weights: one entry per element of `class_count` (background at index 0 included).
if args.cls_weights:
    # Inverse frequency relative to the median count (median taken over all entries,
    # background included), softened by a square root.
    # A class with zero counted pixels would otherwise get an infinite weight, and
    # inf * 0 produces NaN in a weighted loss. Such classes get a neutral weight of 1.
    counts = np.asarray(class_count, dtype=float)
    inv_freq = np.divide(np.median(counts), counts, out=np.ones_like(counts), where=counts > 0)
    inv_freq = inv_freq ** 0.5
    class_weights = inv_freq.round(4)
else:
    class_weights = np.ones(len(CLASS_NAMES))
print(f"class weights = {class_weights.tolist()}")

class weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]


In [11]:
# number of patches per split (the training count includes any oversampled copies)
print(f'training images: {len(X)}, validation images: {len(Xv)}')

training images: 15997, validation images: 499


## Augmentation

In [12]:
# Without augmentation, StarDist2D.train receives augmenter=None.
augmenter = None

if args.augment:
    aug = Augmend()

    # Each add() receives one transform per array in [image, label]
    # (presumably applied jointly); Identity() leaves that array unchanged.

    # H&E stain jitter on the image only
    aug.add([HEStaining(amount_matrix=0.15, amount_stains=0.4), Identity()],
            probability=0.9)

    # flips / 90-degree rotations in the spatial plane for both image and label
    aug.add([FlipRot90(axis=(0, 1)), FlipRot90(axis=(0, 1))])

    # elastic deformation: order=1 for the image, order=0 for the label map
    # (presumably nearest-neighbour so label ids are not interpolated)
    aug.add([Elastic(grid=5, amount=10, order=1, axis=(0, 1), use_gpu=False),
             Elastic(grid=5, amount=10, order=0, axis=(0, 1), use_gpu=False)],
            probability=0.8)

    # image-only perturbations: blur, additive noise, brightness jitter
    aug.add([GaussianBlur(amount=(0, 2), axis=(0, 1), use_gpu=False), Identity()],
            probability=0.1)
    aug.add([AdditiveNoise(0.01), Identity()],
            probability=0.8)
    aug.add([HueBrightnessSaturation(hue=0, brightness=0.1, saturation=(1, 1)), Identity()],
            probability=0.9)

    def augmenter(x, y):
        """Augment one (image, label) pair with the pipeline `aug`."""
        return aug([x, y])

## StarDist setup

In [13]:
# StarDist model/training configuration. Values taken from `args` can be tuned in the
# configuration cell; the literal constants are fixed choices for this notebook.
conf = Config2D(
    n_rays                = args.n_rays,
    grid                  = args.grid,
    n_channel_in          = X.shape[-1],          # last axis of X holds the image channels
    n_classes             = len(CLASS_NAMES)-1,   # CLASS_NAMES also contains 'BACKGROUND', not counted here
    use_gpu               = args.gpu_datagen,     # GPU for data generation, controlled by args.gpu_datagen

    backbone              = 'unet',
    unet_n_filter_base    = 64,
    unet_n_depth          = args.n_depth,
    head_blocks           = args.head_blocks, 
    net_conv_after_unet   = 256,

    train_batch_size      = args.batchsize,
    train_patch_size      = (args.patch, args.patch),
    train_epochs          = args.epochs,
    train_steps_per_epoch = 1024 // args.batchsize,   # ~1024 training patches per epoch regardless of batch size
    train_learning_rate   = args.lr, 
    train_loss_weights    = (1.0, 0.2, 1.0),          # presumably (prob, dist, class) loss weights -- confirm in Config2D docs
    train_class_weights   = class_weights.tolist(),   # expected: n_classes + 1 entries (background included)
    train_background_reg  = 0.01,
    # presumably forwarded to a reduce-LR-on-plateau callback: halve the LR after 80 epochs without improvement
    train_reduce_lr       = {'factor': 0.5, 'patience': 80, 'min_delta': 0},
)

vars(conf)

{'n_dim': 2,
 'axes': 'YXC',
 'n_channel_in': 3,
 'n_channel_out': 33,
 'train_checkpoint': 'weights_best.h5',
 'train_checkpoint_last': 'weights_last.h5',
 'train_checkpoint_epoch': 'weights_now.h5',
 'n_rays': 32,
 'grid': (2, 2),
 'backbone': 'unet',
 'n_classes': 6,
 'unet_n_depth': 3,
 'unet_kernel_size': (3, 3),
 'unet_n_filter_base': 64,
 'unet_n_conv_per_depth': 2,
 'unet_pool': (2, 2),
 'unet_activation': 'relu',
 'unet_last_activation': 'relu',
 'unet_batch_norm': False,
 'unet_dropout': 0.0,
 'unet_prefix': '',
 'net_conv_after_unet': 256,
 'head_blocks': 2,
 'net_input_shape': (None, None, 3),
 'net_mask_shape': (None, None, 1),
 'train_shape_completion': False,
 'train_completion_crop': 32,
 'train_patch_size': (256, 256),
 'train_background_reg': 0.01,
 'train_foreground_only': 0.9,
 'train_sample_cache': True,
 'train_dist_loss': 'mae',
 'train_loss_weights': (1.0, 0.2, 1.0),
 'train_class_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
 'train_epochs': 20,
 'train_steps_

### Create model

In [14]:
# Create a new model named 'conic' under args.modeldir (presumably saved to args.modeldir/conic).
model = StarDist2D(conf, name='conic', basedir=args.modeldir)

Using default values: prob_thresh=0.5, nms_thresh=0.4.


## Training

In [15]:
# Train on the (possibly oversampled) training split and monitor the held-out split.
# D / Dv are passed as `classes` (presumably per-object class labels for each patch -- confirm).
model.train(X, Y, classes=D, validation_data=(Xv, Yv, Dv), augmenter=augmenter, workers=args.workers)

Epoch 1/20
Cause: Unable to locate the source code of <function _gcd_import at 0x7fb1dc72f430>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code


Cause: Unable to locate the source code of <function _gcd_import at 0x7fb1dc72f430>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code


Cause: Unable to locate the source code of <function _gcd_import at 0x7fb1dc72f430>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20


Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20

Loading network weights from 'weights_best.h5'.


<keras.callbacks.History at 0x7fb11c7bef10>

In [16]:
# Tune the probability threshold on the validation split for each candidate NMS threshold
# and keep the best pair (the log below shows it being saved to 'thresholds.json').
model.optimize_thresholds(Xv, Yv, nms_threshs=[0.1, 0.2, 0.3])

NMS threshold = 0.1:  80%|████████  | 16/20 [00:55<00:13,  3.45s/it, 0.499 -> 0.572]
NMS threshold = 0.2:  80%|████████  | 16/20 [00:54<00:13,  3.41s/it, 0.499 -> 0.573]
NMS threshold = 0.3:  80%|████████  | 16/20 [00:55<00:13,  3.46s/it, 0.499 -> 0.573]


Using optimized values: prob_thresh=0.497882, nms_thresh=0.3.
Saving to 'thresholds.json'.


{'prob': 0.49788182973861694, 'nms': 0.3}