
RPNODE_FSS

Robust Prototypical Few-Shot Organ Segmentation with Regularized Neural-ODEs


Installation and setup

To install this repository and its dependencies, run the following.

git clone https://github.com/rpnode-fss/RPNODE_FSS.git
cd RPNODE_FSS
conda create --name RPNODE_FSS python  # (optional, for making a conda environment)
conda activate RPNODE_FSS               # (optional, activate it before installing packages)
pip install -r requirements.txt

The processed datasets can be downloaded from here.

Some relevant trained model weights can be downloaded from here.

Change the paths to the BCV, CT-ORG and Decathlon datasets in config.py and test_config.py to match their locations on your machine. Also set the path to the ImageNet-pretrained VGG model weights in both files.
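Before training, it can help to check that every configured path actually exists. The sketch below is a hypothetical helper: the PATHS dictionary, its keys and the example locations are assumptions for illustration and are not the variable names used in config.py or test_config.py.

```python
from pathlib import Path

# Hypothetical path settings; the real names and values in config.py /
# test_config.py will differ.
PATHS = {
    "BCV": "/data/BCV",
    "CTORG": "/data/CTORG",
    "Decathlon": "/data/Decathlon",
    "vgg_weights": "/weights/vgg16.pth",
}


def missing_paths(paths):
    """Return the names of all entries whose path does not exist on disk."""
    return [name for name, p in paths.items() if not Path(p).exists()]


if __name__ == "__main__":
    missing = missing_paths(PATHS)
    if missing:
        print("Update these paths before training:", ", ".join(missing))
```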

Training

To train R-PNODE, run

python3 train.py with dataset=BCV model_name=<save-name> target=<train-target> n_shot=<shot> ode_layers=3 ode_time=4
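The ode_time argument presumably sets how long the Neural-ODE is integrated. To illustrate what an integration time means, the sketch below integrates dh/dt = f(h) from t = 0 to t = ode_time with a fixed-step explicit Euler solver. It is a generic Neural-ODE forward pass under stated assumptions: the toy dynamics, the step count and the solver are hypothetical, the repository may use a different (for example adaptive) solver with a learned f, and the sketch does not model ode_layers at all.

```python
import math


def euler_integrate(f, h0, t_end, n_steps=100):
    """Integrate dh/dt = f(h) from t=0 to t=t_end with fixed-step explicit Euler.

    h0 is a list of floats (a stand-in for a feature vector); f maps a state
    list to its time derivative, a list of the same length.
    """
    dt = t_end / n_steps
    h = list(h0)
    for _ in range(n_steps):
        dh = f(h)
        h = [hi + dt * di for hi, di in zip(h, dh)]
    return h


# Hypothetical toy dynamics: exponential decay, exact solution h0 * exp(-t).
def decay(h):
    return [-x for x in h]


if __name__ == "__main__":
    ode_time = 4  # same value as ode_time=4 in the training command
    h = euler_integrate(decay, [1.0, 2.0], ode_time, n_steps=1000)
    print(h, [math.exp(-ode_time), 2 * math.exp(-ode_time)])
```

With more steps the Euler result converges to the exact solution; a learned f would replace decay in a real model.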

Further parameters, such as the standard deviation of the Gaussian perturbation, can be changed in the training config.
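Gaussian perturbation adds zero-mean noise with a chosen standard deviation. The sketch below shows that operation on a plain list of floats; the function name, the sigma argument and the choice of what is perturbed are assumptions, since this README does not say whether R-PNODE perturbs images, features or prototypes.

```python
import random


def gaussian_perturb(x, sigma, rng=None):
    """Return a copy of x with i.i.d. N(0, sigma^2) noise added to every element."""
    if rng is None:
        rng = random.Random()
    return [xi + rng.gauss(0.0, sigma) for xi in x]


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed for reproducibility
    noisy = gaussian_perturb([0.0] * 5, sigma=0.1, rng=rng)
    print(noisy)
```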

For example, to train the BCV 1-shot experiment with the spleen as the novel class, run

python3 train.py with dataset=BCV model_name=bcv_1shot_spleen target=1 n_shot=1 ode_layers=3 ode_time=4

This stores the model weights as bcv_1shot_spleen_tar1.pth in the model root directory. Refer to the class mapping below for the target index of each organ. A single trained model suffices to test the method in both the in-domain and cross-domain settings for a given shot and target, and all attacks are evaluated on that same model.
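The weights file name appears to follow the pattern model_name, then _tar plus the target index, then .pth. The hypothetical helper below reproduces that pattern; it is inferred from this single example, so verify it against the files train.py actually writes.

```python
from pathlib import Path


def weights_path(model_root, model_name, target):
    """Build the expected weights path, e.g. bcv_1shot_spleen_tar1.pth (inferred naming)."""
    return Path(model_root) / f"{model_name}_tar{target}.pth"


if __name__ == "__main__":
    # "models" is a placeholder for the model root directory set in the config.
    print(weights_path("models", "bcv_1shot_spleen", 1))
```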

Testing

To test a trained model, run

python3 test_attacked.py with snapshot=<weights-path> target=<test-target> dataset=<BCV/CTORG/Decathlon> n_shot=<shot> attack=<clean/fgsm/pgd/smia/bim/cw/dag/auto> attack_eps=<eps> to_attack=<q/s>
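attack_eps sets the perturbation budget ε. For reference, the standard FGSM step and the PGD iteration (projected back onto the L∞ ball of radius ε around the clean input x) are shown below, where L is the loss of model f on input x with label y and α is the PGD step size. How test_attacked.py defines the segmentation loss and the step size is not described in this README. The to_attack=q or s option presumably selects whether the query or the support image is perturbed; that reading is an assumption.

```latex
x_{\mathrm{adv}} = x + \epsilon \,\operatorname{sign}\!\big(\nabla_x \mathcal{L}(f(x), y)\big)

x^{(t+1)} = \Pi_{\{x' \,:\, \|x' - x\|_\infty \le \epsilon\}}\Big(x^{(t)} + \alpha \,\operatorname{sign}\big(\nabla_x \mathcal{L}(f(x^{(t)}), y)\big)\Big)
```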

Arguments for some particular settings are:

Setting                                        Arguments
BCV in-domain 1-shot Liver                     dataset=BCV n_shot=1 target=6
BCV in-domain 3-shot Spleen                    dataset=BCV n_shot=3 target=1
BCV -> CT-ORG cross-domain 1-shot Liver        dataset=CTORG n_shot=1 target=1
BCV -> Decathlon cross-domain 3-shot Liver     dataset=Decathlon n_shot=3 target=2
BCV -> Decathlon cross-domain 1-shot Spleen    dataset=Decathlon n_shot=1 target=6
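The settings table can be turned into full test_attacked.py commands programmatically. The sketch below only formats strings in the `with key=value` style used above; SETTINGS mirrors the table, while the function name, the weights path and the default attack, epsilon and to_attack values are placeholders.

```python
import shlex

# Settings copied from the table above.
SETTINGS = {
    "BCV in-domain 1-shot Liver": {"dataset": "BCV", "n_shot": 1, "target": 6},
    "BCV in-domain 3-shot Spleen": {"dataset": "BCV", "n_shot": 3, "target": 1},
    "BCV -> CT-ORG cross-domain 1-shot Liver": {"dataset": "CTORG", "n_shot": 1, "target": 1},
    "BCV -> Decathlon cross-domain 3-shot Liver": {"dataset": "Decathlon", "n_shot": 3, "target": 2},
    "BCV -> Decathlon cross-domain 1-shot Spleen": {"dataset": "Decathlon", "n_shot": 1, "target": 6},
}


def build_test_command(setting, snapshot, attack="clean", attack_eps=0.0, to_attack="q"):
    """Return the test_attacked.py command for one table row as a shell string."""
    args = dict(SETTINGS[setting], snapshot=snapshot, attack=attack,
                attack_eps=attack_eps, to_attack=to_attack)
    overrides = [f"{key}={value}" for key, value in args.items()]
    # shlex.join quotes any argument that would otherwise be split by the shell.
    return shlex.join(["python3", "test_attacked.py", "with", *overrides])


if __name__ == "__main__":
    # Hypothetical weights path for a BCV 1-shot liver model.
    print(build_test_command("BCV in-domain 1-shot Liver", "models/bcv_1shot_liver_tar6.pth"))
```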

Note, particularly for the cross-domain settings, that the target class index used during training may differ from the one used during testing, because the same organ has different indices in different datasets. Make sure that a model trained for a particular target organ is tested on that same organ; otherwise the results will be misleading. The exact target indices are listed in the class mapping at the end of this README.

The possible options for the attack argument are:

  • clean (standard FSS, without any attack)
  • fgsm
  • pgd
  • smia
  • bim
  • cw
  • dag
  • auto

These values are case-insensitive: variants such as FGSM, FGsm or fGsM all have the same effect as fgsm.
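Case-insensitive matching of this kind usually amounts to lower-casing the argument before comparing it with the known names. A minimal sketch using the options listed above; the function name and the error handling are assumptions, not the repository's code.

```python
ATTACKS = {"clean", "fgsm", "pgd", "smia", "bim", "cw", "dag", "auto"}


def normalize_attack(name):
    """Map any capitalization of an attack name to its canonical lower-case form."""
    key = name.strip().lower()
    if key not in ATTACKS:
        raise ValueError(f"unknown attack {name!r}; expected one of {sorted(ATTACKS)}")
    return key


if __name__ == "__main__":
    print([normalize_attack(a) for a in ("FGSM", "FGsm", "fGsM")])
```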

The same command covers all settings: 1-shot and 3-shot, liver and spleen, and the Clean, FGSM, PGD, SMIA, BIM, CW, DAG and Auto-Attack evaluations at different epsilons.
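Because every evaluation goes through the same script, a full in-domain BCV sweep can be generated with nested loops. In the sketch below the epsilon values are placeholders rather than the values used in the paper, snapshot_for is a hypothetical callback that returns the weights path for a given shot and target, and clean runs use a single epsilon of 0 since no perturbation is applied.

```python
SHOTS = [1, 3]
BCV_TARGETS = {"Liver": 6, "Spleen": 1}  # BCV indices from the class mapping
ATTACKS = ["clean", "fgsm", "pgd", "smia", "bim", "cw", "dag", "auto"]
EPSILONS = [0.01, 0.02]  # placeholder values, not the paper's settings


def sweep(snapshot_for, to_attack="q"):
    """Yield one test_attacked.py argument list per (shot, target, attack, eps)."""
    for shot in SHOTS:
        for target in BCV_TARGETS.values():
            for attack in ATTACKS:
                # A clean run has no perturbation, so one epsilon is enough.
                eps_values = [0.0] if attack == "clean" else EPSILONS
                for eps in eps_values:
                    yield ["python3", "test_attacked.py", "with",
                           f"snapshot={snapshot_for(shot, target)}",
                           f"target={target}", "dataset=BCV", f"n_shot={shot}",
                           f"attack={attack}", f"attack_eps={eps}",
                           f"to_attack={to_attack}"]


if __name__ == "__main__":
    for cmd in sweep(lambda shot, target: f"models/bcv_{shot}shot_tar{target}.pth"):
        print(" ".join(cmd))
```

Each argument list can be passed directly to subprocess.run(cmd, check=True) to execute the runs one after another.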

Visualization

Visualization can be enabled by setting save_vis to True. The path where the visualizations are saved can be modified here.


Acknowledgement

This repository is adapted from BiGRU-FSS. Four of the baselines are taken from here; the remaining baseline is the Neural-ODE-based SONet. Most attacks are adapted from well-known libraries such as CleverHans and TorchAttacks, and the others are adapted from their respective sources: SMIA, DAG and Auto-Attack. We thank the authors for their excellent work.

Class Mapping

BCV:
    Liver: 6
    Spleen: 1
CT-ORG: 
    Liver: 1
Decathlon: 
    Liver: 2
    Spleen: 6
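Since the same organ has different indices in different datasets, a small lookup helps avoid testing a model on the wrong class. The dictionary below is copied from the mapping above, with keys spelled as the dataset= argument values; the helper name is hypothetical.

```python
CLASS_MAPPING = {
    "BCV": {"Liver": 6, "Spleen": 1},
    "CTORG": {"Liver": 1},
    "Decathlon": {"Liver": 2, "Spleen": 6},
}


def translate_target(train_target, test_dataset, train_dataset="BCV"):
    """Translate a training target index into the index of the same organ in test_dataset."""
    organ = next((o for o, idx in CLASS_MAPPING[train_dataset].items()
                  if idx == train_target), None)
    if organ is None:
        raise KeyError(f"target {train_target} is not mapped for {train_dataset}")
    if organ not in CLASS_MAPPING[test_dataset]:
        raise KeyError(f"{organ} has no index in {test_dataset}")
    return CLASS_MAPPING[test_dataset][organ]


if __name__ == "__main__":
    # A BCV model trained with target=1 (Spleen) is tested on Decathlon with target=6.
    print(translate_target(1, "Decathlon"))
```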
