
wenhui0206/Robust-TransBTS


Semi-supervised Learning using Robust Loss

This is the implementation of the paper: Semi-supervised Learning using Robust Loss

This repo is forked from the official implementation of TransBTS: Multimodal Brain Tumor Segmentation Using Transformer. The multimodal brain tumor datasets (BraTS 2019 & BraTS 2020) can be obtained from here.

Python Scripts

  • python -m torch.distributed.launch --nproc_per_node=2 train_cv.py --corrupt_r=0.5 --train_partial=False --beta=0.001 --experiment='test_run' --fold=0
  • python /scratch1/wenhuicu/robust_seg/TransBTS/validation.py --test_file='model_epoch_last.pth' --valid_file='test_list.txt' --submission='' --experiment='test_run_f0' --csv_name='test_run_f0.csv'
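The first command passes --train_partial=False as a string. The README does not show how train_cv.py parses its flags; if it used argparse with type=bool, the string 'False' would be truthy and the option would silently become True. The sketch below shows a safer parser. The str2bool helper, the build_parser function, the types (inferred from the example values) and the defaults are all assumptions for illustration, not the repo's code, and the meanings of --corrupt_r and --beta are not described here.

```python
import argparse


def str2bool(value):
    """Convert command-line strings such as 'False' or 'true' into real booleans."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Hypothetical parser mirroring the flags in the example training command;
    # types are inferred from the example values, defaults are placeholders.
    parser = argparse.ArgumentParser()
    parser.add_argument("--corrupt_r", type=float, default=0.0)
    parser.add_argument("--train_partial", type=str2bool, default=False)
    parser.add_argument("--beta", type=float, default=0.0)
    parser.add_argument("--experiment", type=str, default="")
    parser.add_argument("--fold", type=int, default=0)
    # torch.distributed.launch appends --local_rank=<i> to every worker's arguments.
    parser.add_argument("--local_rank", type=int, default=0)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--corrupt_r=0.5", "--train_partial=False", "--beta=0.001",
         "--experiment=test_run", "--fold=0"]
    )
    print(args.train_partial)  # False, whereas type=bool would have produced True
    print(bool("False"))       # True: any non-empty string is truthy
```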

TransBTS

Figure: Architecture of 3D TransBTS.

Requirements

  • python 3.7
  • pytorch 1.6.0
  • torchvision 0.7.0
  • pickle
  • nibabel

Data preprocess

After downloading the dataset from here, preprocess it to convert the .nii files to .pkl files and normalize the data:

python3 preprocess.py
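A minimal sketch of that preprocessing step, assuming (not confirmed by this README) that the four MRI modalities are stacked into one array, each modality is z-score normalized over the voxels where any modality is non-zero, and the (images, label) pair is pickled per case. The MODALITIES order, the `<case>_<modality>.nii.gz` file naming and the helper names are assumptions for illustration; the repo's preprocess.py may differ.

```python
import pickle

import numpy as np

# Assumed modality order and file naming (<case>_<modality>.nii.gz).
MODALITIES = ("flair", "t1ce", "t1", "t2")


def normalize_nonzero(images):
    """Z-score each modality over voxels where any modality is non-zero.

    images: array of shape (H, W, D, C); background voxels stay 0.
    """
    images = images.astype(np.float32, copy=True)
    mask = images.sum(axis=-1) > 0  # brain mask shared by all modalities
    for c in range(images.shape[-1]):
        channel = images[..., c]  # a view, so writes go back into images
        values = channel[mask]
        std = values.std()
        channel[mask] = (values - values.mean()) / (std if std > 0 else 1.0)
    return images


def convert_case(case_dir, case_id, out_path):
    """Load one case with nibabel, normalize it, and pickle (images, label)."""
    import nibabel as nib  # imported lazily so normalize_nonzero works without it

    def load(suffix):
        return np.asarray(nib.load(f"{case_dir}/{case_id}_{suffix}.nii.gz").dataobj)

    images = np.stack([load(m) for m in MODALITIES], axis=-1)
    label = load("seg").astype(np.uint8)
    with open(out_path, "wb") as f:
        pickle.dump((normalize_nonzero(images), label), f)
```

Restricting the statistics to the non-zero mask keeps the large empty background from dominating the mean and standard deviation.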

Training

Run the training script on the BraTS dataset. Distributed training is supported for the proposed TransBTS: --nproc_per_node sets the number of GPUs and --master_port sets the port number.
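For orientation, here is what torch.distributed.launch hands to each worker process. In PyTorch 1.6 (without --use_env) the launcher appends --local_rank=<i> to the script's arguments and exports environment variables such as MASTER_ADDR, MASTER_PORT, WORLD_SIZE and RANK. The sketch below reads them; read_launch_config and init_distributed are hypothetical helpers, not functions from this repo, and the NCCL backend with one GPU per process is an assumption.

```python
import argparse
import os


def read_launch_config(argv=None, environ=None):
    """Collect the values torch.distributed.launch passes to one worker.

    --local_rank arrives as a command-line argument; RANK, WORLD_SIZE and
    MASTER_PORT arrive as environment variables.
    """
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    known, _ = parser.parse_known_args(argv)  # ignore the script's other flags
    return {
        "local_rank": known.local_rank,
        "rank": int(environ.get("RANK", 0)),
        "world_size": int(environ.get("WORLD_SIZE", 1)),
        "master_port": environ.get("MASTER_PORT", "29500"),  # launcher default
    }


def init_distributed(cfg):
    """Join the process group; call once at the start of training."""
    import torch  # lazy import so read_launch_config can be used without torch
    import torch.distributed as dist

    torch.cuda.set_device(cfg["local_rank"])  # one GPU per process
    dist.init_process_group(backend="nccl", init_method="env://")
```

For example, `python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 train_cv.py` on one machine starts two workers that see WORLD_SIZE=2 and local ranks 0 and 1. Newer launchers such as torchrun pass the local rank through the LOCAL_RANK environment variable instead of --local_rank.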

  • train_cv.py is the main training script, used for baseline training, CE-loss training, and robust-loss training; it performs 3-fold cross-validation. To train on a different fold, pass the fold number with --fold.

  • train_main.py is the same as train_cv.py except that it has no cross-validation.

  • train_cps.py implements cross pseudo supervision for comparison with the robust loss; it is based on train_cv.py.
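Cross pseudo supervision (CPS) trains two differently initialized segmentation networks and uses each network's argmax prediction as a pseudo-label for the other, via cross-entropy. The sketch below shows only that cross term for the general technique; it is not the code in train_cps.py. It uses NumPy instead of PyTorch to stay self-contained, and omits the supervised loss and the weight that combines the two.

```python
import numpy as np


def log_softmax(logits, axis=1):
    """Numerically stable log-softmax over the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def cross_entropy(logits, target, axis=1):
    """Mean voxel-wise cross-entropy; logits (N, C, ...), integer target (N, ...)."""
    log_probs = log_softmax(logits, axis=axis)
    picked = np.take_along_axis(log_probs, np.expand_dims(target, axis), axis=axis)
    return -picked.mean()


def cps_loss(logits_a, logits_b):
    """Each branch is trained on the other branch's argmax pseudo-labels.

    In a real training loop the pseudo-labels carry no gradient (argmax is not
    differentiable), so only the two cross-entropy terms are back-propagated.
    """
    pseudo_a = logits_a.argmax(axis=1)
    pseudo_b = logits_b.argmax(axis=1)
    return cross_entropy(logits_a, pseudo_b) + cross_entropy(logits_b, pseudo_a)
```

The loss is small when the two branches agree confidently and large when they disagree, which pushes the networks toward consistent predictions on unlabeled scans.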

python -m torch.distributed.launch --nproc_per_node=2 train_cv.py
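The README does not say how train_cv.py forms its 3 folds. As a rough illustration of what --fold selects, the sketch below assumes a fixed-seed shuffle followed by a round-robin assignment of subjects to folds; the kfold_split helper, the seed and the subject IDs are hypothetical.

```python
import random


def kfold_split(subjects, fold, n_folds=3, seed=0):
    """Return (train, val) subject lists for one fold of k-fold cross-validation.

    Subjects are shuffled with a fixed seed so every run, and every --fold
    value, sees the same partition; fold i holds out every n_folds-th subject.
    """
    if not 0 <= fold < n_folds:
        raise ValueError(f"fold must be in [0, {n_folds}), got {fold}")
    order = sorted(subjects)  # sort first so the input order does not matter
    random.Random(seed).shuffle(order)
    val = order[fold::n_folds]
    held_out = set(val)
    train = [s for s in order if s not in held_out]
    return train, val
```

Fixing the seed matters here: the three folds are trained in separate runs, and they only form a valid cross-validation if every run reconstructs the same partition.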

Testing

Run python validation.py

  • validation.py is used for performance evaluation. It calculates Dice scores and the Hausdorff distance and saves the results in a csv file; use the calc_mean_var() function in plot.py to calculate the mean Dice across the 3 folds.

  • predict.py contains the code for the actual model evaluation and metric calculation. The validate_performance() function calculates Dice scores and Hausdorff distances and saves their means to a separate csv file for each fold. compare_performance() generates the predicted segmentation maps and saves them if the --submission argument is specified; it also saves the Dice score of every subject to a txt file for later analysis.
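For reference, here is how per-region Dice is usually computed on BraTS label maps, where 1 is the necrotic/non-enhancing core, 2 is edema and 4 is enhancing tumor. This is a sketch, not validate_performance() or calc_mean_var(). Treating two empty masks as Dice 1.0 and reporting the population variance across folds are assumptions, and the Hausdorff distance is not covered.

```python
import statistics

import numpy as np

# BraTS evaluation regions built from the label convention 1 / 2 / 4.
REGIONS = {
    "WT": (1, 2, 4),  # whole tumor
    "TC": (1, 4),     # tumor core
    "ET": (4,),       # enhancing tumor
}


def dice(pred_mask, gt_mask):
    """Dice = 2 * |pred AND gt| / (|pred| + |gt|); assumed 1.0 if both are empty."""
    denom = pred_mask.sum() + gt_mask.sum()
    if denom == 0:
        return 1.0
    return 2.0 * np.logical_and(pred_mask, gt_mask).sum() / denom


def region_dice(pred, gt):
    """Dice for each BraTS region from two integer label volumes."""
    return {
        name: float(dice(np.isin(pred, labels), np.isin(gt, labels)))
        for name, labels in REGIONS.items()
    }


def mean_var_across_folds(fold_means):
    """Mean and population variance of the per-fold mean Dice scores."""
    return statistics.mean(fold_means), statistics.pvariance(fold_means)
```

Use statistics.variance instead of pvariance if a sample variance across the 3 folds is preferred.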

Model Architecture

  • TransBTS_downsample8x_skipconnection_lw.py is the variant that uses a single layer in the transformer module and a half-sized hidden layer (the last flatten layer).

Reference

1. setr-pytorch

2. BraTS2017

About

Performing semi-supervised learning for brain lesion segmentation by introducing a robust loss
