# Training the Models

This notebook trains the nnU-Net models on our ground-truth annotations for the IXI dataset, so that you can reproduce them.

With the dataset complete, you can train your model using the nnU-Net framework (v2), which can be cloned from its [GitHub](https://github.com/MIC-DKFZ/nnUNet) repository or installed from [PyPI](https://pypi.org/project/nnunetv2/).
**IMPORTANT:** Set up PyTorch and CUDA correctly for your system before installing nnU-Net.
Otherwise, the default PyTorch package will be installed from PyPI, and it may not be compatible with your local GPU.
Please refer to the nnU-Net [installation instructions](https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/installation_instructions.md).

The following cell installs PyTorch 2.1.0 with CUDA 11.8 and nnU-Net 2.2.2 into the Conda environment of your Jupyter kernel. You may need to adapt the versions to your system.

In [None]:
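# %conda / %pip install into the kernel's own environment; -y skips the confirmation prompt a notebook cannot answer
%conda install -y pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
%pip install nnunetv2==2.2.2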
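Before moving on, it can help to confirm that the installed PyTorch build actually sees your GPU. The cell below is a minimal sketch using standard PyTorch calls (`torch.cuda.is_available()`, `torch.version.cuda`, `torch.cuda.get_device_name()`); the helper name `describe_torch_setup` is hypothetical and not part of nnU-Net.

```python
import importlib.util


def describe_torch_setup():
    """Summarize the PyTorch/CUDA setup of this kernel (hypothetical helper)."""
    if importlib.util.find_spec("torch") is None:
        # PyTorch is not installed in this kernel's environment
        return {"torch_installed": False}
    import torch
    info = {
        "torch_installed": True,
        "torch_version": torch.__version__,
        "cuda_build": torch.version.cuda,             # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),  # True only if a usable GPU is found
    }
    if info["cuda_available"]:
        info["gpu"] = torch.cuda.get_device_name(0)
    return info


for key, value in describe_torch_setup().items():
    print(f"{key}: {value}")
```

If `cuda_build` is `None` or `cuda_available` is `False`, the GPU will not be used, which usually means a CPU-only PyTorch package was installed.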

The following cell sets up the paths for nnU-Net:

In [None]:
import os
ixi_raw = '<path-to/IXI>'                           # set to the folder containing the extracted IXI data
ds_raw = '<path-to/Dataset600_IXI>'                 # set to the Dataset600_IXI folder checked out with this repository
nnu_preprocessed = '<path-to/preprocessed>'          # folder to write preprocessed data to; needs a lot of space and should be on a fast drive (we recommend an SSD)
nnu_results = '<path-to/results>'                    # folder for training results; access speed and capacity are less critical (an HDD is fine)
nnu_raw = os.path.dirname(os.path.normpath(ds_raw))  # nnU-Net's raw-data root is the parent of Dataset600_IXI (normpath drops a trailing slash)
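Since a wrong placeholder path is an easy mistake, a quick check of the dataset folder can save a failed preprocessing run. This sketch assumes the standard nnU-Net v2 raw-data layout (a `dataset.json` file plus `imagesTr` and `labelsTr` folders inside `Dataset600_IXI`); `missing_dataset_parts` is a hypothetical helper, not part of nnU-Net.

```python
import os


def missing_dataset_parts(dataset_dir):
    """Return expected nnU-Net v2 raw-dataset entries that are missing (hypothetical helper)."""
    # Assumed minimal layout of an nnU-Net v2 raw dataset folder
    expected = {
        "dataset.json": os.path.isfile,
        "imagesTr": os.path.isdir,
        "labelsTr": os.path.isdir,
    }
    return [name for name, exists in expected.items()
            if not exists(os.path.join(dataset_dir, name))]
```

With the paths above, `missing_dataset_parts(ds_raw)` should return an empty list.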

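Next, the following cell exports these paths as the environment variables that nnU-Net reads. The cells after it run preprocessing and training as subprocesses.
Each of those cells prints its command line before running it. If you are unable to configure nnU-Net in your Jupyter environment, you can copy the printed commands and run them in any terminal after setting the same environment variables there.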

In [None]:
env = {
    "nnUNet_raw": nnu_raw,
    "nnUNet_preprocessed": nnu_preprocessed,
    "nnUNet_results": nnu_results
}
print("-- environment:")
for name, value in env.items():
    print(f"{name}: {value}")
os.environ.update(env)  # processes launched from this kernel inherit these variables
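If you run the commands from a terminal instead, the same variables must be set there first. A minimal sketch, assuming a POSIX shell such as bash; `shell_exports` is a hypothetical helper that quotes each value with `shlex.quote` so paths containing spaces survive.

```python
import shlex


def shell_exports(env):
    """Format environment variables as POSIX `export` lines (hypothetical helper)."""
    return "\n".join(f"export {name}={shlex.quote(value)}" for name, value in env.items())


# Example with made-up paths; in the notebook, use print(shell_exports(env)) instead
print(shell_exports({"nnUNet_raw": "/data/my raw", "nnUNet_results": "/data/results"}))
```

Paste the printed lines into the terminal before running the `nnUNetv2_*` commands.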

### nnU-Net: Preprocessing the dataset

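Before training can start, nnU-Net needs to preprocess the dataset: it extracts a dataset fingerprint, plans the network configurations, and writes the preprocessed data to `nnUNet_preprocessed`.
This only has to be done once, so run the following cell a single time before your first training run.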

In [None]:
import subprocess

cmd = ["nnUNetv2_plan_and_preprocess", "--verify_dataset_integrity",
       '-d', '600',          # dataset ID (Dataset600_IXI)
       '-c', '3d_fullres',   # only preprocess the 3d_fullres configuration
       '--verbose'           # print detailed progress output
       ]

print(f"-- running command: \n> {' '.join(cmd)}")
subprocess.run(cmd, check=True)  # raises CalledProcessError if the command fails
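Once the command finishes, you can confirm that preprocessing produced output before launching the long training runs. This sketch assumes nnU-Net v2's default naming, a plans file `nnUNetPlans.json` and a data folder `nnUNetPlans_3d_fullres` inside `<nnUNet_preprocessed>/Dataset600_IXI`; please verify these names against your installed version. `preprocessing_done` is a hypothetical helper.

```python
import os


def preprocessing_done(preprocessed_root, dataset="Dataset600_IXI",
                       plans="nnUNetPlans", configuration="3d_fullres"):
    """Check for nnU-Net v2's default preprocessing outputs (hypothetical helper; names assumed)."""
    ds_dir = os.path.join(preprocessed_root, dataset)
    plans_file = os.path.join(ds_dir, f"{plans}.json")
    data_dir = os.path.join(ds_dir, f"{plans}_{configuration}")
    # The plans file must exist and the data folder must hold at least one preprocessed file
    return (os.path.isfile(plans_file)
            and os.path.isdir(data_dir)
            and len(os.listdir(data_dir)) > 0)
```

With the paths above, `preprocessing_done(nnu_preprocessed)` should return `True`.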

### nnU-Net: Running the Training

After the dataset has been preprocessed, run the following cell to train the five cross-validation models (folds 0 to 4) one after another.

In [None]:
import subprocess

for fold in range(5):                           # nnU-Net's 5-fold cross-validation: folds 0-4
    cmd = ["nnUNetv2_train",
           '600',                               # dataset ID (Dataset600_IXI)
           '3d_fullres',                        # model configuration to train
           str(fold),                           # fold to train (0-4)
           '-tr', 'nnUNetTrainerNoMirroring',   # trainer without mirroring augmentations
           ]

    print(f"-- running command: \n> {' '.join(cmd)}")
    subprocess.run(cmd, check=True)             # stops the loop if a fold fails
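Training all five folds takes a long time, so it can be useful to check which folds have already finished, for example to skip them after an interruption. This sketch assumes nnU-Net v2's default results layout, `<nnUNet_results>/Dataset600_IXI/nnUNetTrainerNoMirroring__nnUNetPlans__3d_fullres/fold_<k>/checkpoint_final.pth`; `finished_folds` is a hypothetical helper.

```python
import os


def finished_folds(results_root, dataset="Dataset600_IXI",
                   trainer="nnUNetTrainerNoMirroring", plans="nnUNetPlans",
                   configuration="3d_fullres", folds=range(5)):
    """Return the folds that have a final checkpoint (hypothetical helper; folder names assumed)."""
    model_dir = os.path.join(results_root, dataset, f"{trainer}__{plans}__{configuration}")
    return [fold for fold in folds
            if os.path.isfile(os.path.join(model_dir, f"fold_{fold}", "checkpoint_final.pth"))]
```

For example, `[f for f in range(5) if f not in finished_folds(nnu_results)]` lists the folds that still need training. An interrupted fold can be resumed from its latest checkpoint by adding the `--c` flag to its `nnUNetv2_train` command.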