In [None]:
# Automatically reload edited modules (e.g. the local `sav` package) before each cell executes.
%load_ext autoreload
%autoreload 2

# Imports

In [None]:
import wandb

# Weights & Biases logging is optional: the WandbLogger is not passed to the
# Trainer in this notebook, so don't force a (possibly interactive) login on
# every run. Set USE_WANDB = True when enabling the WandbLogger.
USE_WANDB = False
if USE_WANDB:
    wandb.login()

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress import RichProgressBar
from pytorch_lightning.loggers import WandbLogger
from sav.datamodule import DatamoduleSAV
from sav.module.fs_segmenter import FewShotSegmenter

# Set config

In [None]:
# Run configuration, read by the cells below via args['<key>'].
args = {
    # run / checkpointing / training precision
    'seed': 0,
    'num_epoch': 1,
    'checkpoint_path': 'results/test',
    'model_name': 'test',
    'version': '0',  # NOTE(review): not read by any cell in this notebook
    'precision_for_training': 16,

    # model / optimiser
    'backbone': 'vgg16',
    'optimizer': 'adam',
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,

    # datamodule / augmentation
    'datapath': 'demo_data/train',
    'nshot': 3,
    'nsamples': 500,
    'contrast': (0.5, 1.5),
    'rotation_degrees': 90.0,
    'scale': (0.25, 1.0),
    'crop_size': 256,
    'val_data_ratio': 0.15,
    'batch_size': 5,
    'n_cpu': 4,
}

# Initialise logger (optional) and checkpoint callback

In [None]:
# Optional W&B logging: create
#     logger = WandbLogger(save_dir="lightning_logs", project="slice-and-view_avgpool")
# and pass `logger=logger` to the Trainer.

# Keep the 5 best checkpoints by validation loss (the monitor key must match the
# metric name logged by FewShotSegmenter), and also write `last.ckpt` in the same
# directory (save_last=True) so the most recent weights can be reloaded later.
checkpoint_callback = ModelCheckpoint(
    save_top_k=5,
    monitor="val/val_loss",
    mode="min",
    dirpath=args['checkpoint_path'],
    filename=args['model_name'] + "-{epoch:02d}",
    save_last=True,
)

# Initialise the datamodule and the model

In [None]:
# Seed Python / NumPy / torch RNGs. `workers=True` additionally makes Lightning
# give every DataLoader worker its own reproducible seed. This matters if
# DatamoduleSAV runs its random augmentations in worker processes (n_cpu
# presumably sets num_workers — TODO confirm).
pl.seed_everything(args['seed'], workers=True)

datamodule = DatamoduleSAV(datapath=args['datapath'],
                           nshot=args['nshot'],
                           nsamples=args['nsamples'],
                           contrast=args['contrast'],
                           rotation_degrees=args['rotation_degrees'],
                           scale=args['scale'],
                           crop_size=args['crop_size'],
                           val_data_ratio=args['val_data_ratio'],
                           batch_size=args['batch_size'],
                           n_cpu=args['n_cpu'])

model = FewShotSegmenter(backbone=args['backbone'],
                         optimizer=args['optimizer'],
                         learning_rate=args['learning_rate'],
                         weight_decay=args['weight_decay'])

# Initialise trainer and start training

In [None]:
# Single-GPU training run. To log to Weights & Biases, pass
# `logger=WandbLogger(...)` here as well.
trainer = pl.Trainer(
    max_epochs=args['num_epoch'],
    precision=args['precision_for_training'],
    accelerator='gpu',
    devices=1,
    callbacks=[RichProgressBar(), checkpoint_callback],
)

trainer.fit(model, datamodule=datamodule)

# Load the model from a saved checkpoint (optional)

In [None]:
import os

# Prefer the checkpoint paths recorded by the ModelCheckpoint callback during
# this session. Otherwise fall back to `last.ckpt` in the configured checkpoint
# directory (only written when the callback uses save_last=True).
ckpt_path = (checkpoint_callback.last_model_path
             or checkpoint_callback.best_model_path
             or os.path.join(args['checkpoint_path'], 'last.ckpt'))

# `load_from_checkpoint` is a classmethod that returns a NEW model, so call it on
# the class. Calling it on a freshly built instance wastes that instance, and
# recent Lightning versions raise a TypeError. The keyword arguments are passed
# to __init__ and override any hyperparameters stored in the checkpoint.
model = FewShotSegmenter.load_from_checkpoint(ckpt_path,
                                              backbone=args['backbone'],
                                              optimizer=args['optimizer'],
                                              learning_rate=args['learning_rate'],
                                              weight_decay=args['weight_decay'])

# Segment a single image

In [None]:
from sav.utils.annotator import Annotator
from torchvision import transforms

# Move the model to the GPU once and switch it to inference mode, so dropout /
# batch-norm layers use eval behaviour. `trainer.fit` may leave the module in
# training mode; calling eval() is harmless if Annotator also does it.
model = model.to(device="cuda").eval()

annotator = Annotator(model=model,
                      down_sampling=4,
                      patch_width=224,
                      patch_height=224,
                      margin=32,
                      batch_size=1,
                      keep_dim=True)

# Segment one query image using the cauliflower support set (images + annotations).
out = annotator(query_img_path="demo_data/evaluation/query_set/A3D_cauliflower_0001-NLM0001.tiff",
                support_imgs_dir="demo_data/evaluation/support_set/cauliflower/image",
                support_annots_dir="demo_data/evaluation/support_set/cauliflower/annotation",
                )

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Binarise the predicted annotation map and mask out the background, so only
# foreground pixels are tinted. Otherwise the zeros are coloured too and the
# whole raw image is darkened.
MASK_THRESHOLD = 0.5
mask = np.where(out['annot'] > MASK_THRESHOLD, 1, 0)
foreground = np.ma.masked_where(mask == 0, mask)

fig, axs = plt.subplots(1, 1, figsize=(4, 3), dpi=300)
axs.imshow(out['raw'], cmap='gray')
# Pin the colour scale to [0, 1] so foreground pixels (value 1) keep the top
# colour of the colormap, as in the unmasked overlay.
axs.imshow(foreground, alpha=0.3, cmap='cividis', vmin=0, vmax=1)
axs.set_title("Predicted segmentation overlay")
axs.axis("off")
plt.show()  # render the figure instead of echoing the axis-limits tuple