All figures generated by the code.
All processed data and trained model data.
Trained models are saved in this folder after training.
All source code for data processing, model, loss function, etc.
The `Fig_2_4.ipynb` notebook plots Fig. 2 to Fig. 4 of the report; the notebooks for the other figures follow the same naming scheme.
This script trains a U-Net model for MRI image segmentation based on image reconstruction.
python train_mri.py --mri_inchannels 0 3
where `--mri_inchannels` takes the indices of the MRI modalities to use for training. Here, 0 means T1 and 3 means FLAIR.
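A flag that accepts several indices, like `--mri_inchannels 0 3`, can be expressed with Python's standard `argparse` module using `nargs="+"`. The sketch below is a minimal illustration under that assumption, not the actual parsing code in `train_mri.py`. Only the indices 0 (T1) and 3 (FLAIR) come from this README; the `MODALITY_NAMES` mapping and the `parse_args` helper are hypothetical.

```python
import argparse

# Hypothetical mapping: only 0 -> T1 and 3 -> FLAIR are documented in this README.
MODALITY_NAMES = {0: "T1", 3: "FLAIR"}


def parse_args(argv=None):
    """Parse the modality indices passed via --mri_inchannels (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Train a U-Net on selected MRI modalities.")
    # nargs="+" collects one or more integers, e.g. `--mri_inchannels 0 3`.
    parser.add_argument(
        "--mri_inchannels",
        type=int,
        nargs="+",
        required=True,
        help="indices of the MRI modalities used as input channels",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args(["--mri_inchannels", "0", "3"])
    # Fall back to a generic label for indices this README does not name.
    names = [MODALITY_NAMES.get(i, f"modality {i}") for i in args.mri_inchannels]
    print(args.mri_inchannels, names)  # [0, 3] ['T1', 'FLAIR']
```

With `nargs="+"`, passing a single index (for example `--mri_inchannels 0`) still yields a list, so downstream code can treat one or several modalities uniformly.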
The steps:
- Environment: The environment file is `environment.yaml`. Create and activate the conda environment, then install PyTorch:
  - `conda env create -f environment.yaml`
  - `conda activate CS525_Project2`
  - `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126` (PyTorch wheels built for CUDA 12.6)
  - Please contact us for further questions.
- Load configuration: The script loads hyperparameters from a `config.json` file. These include file paths, batch size, training and validation data details, and model parameters.
- Prepare the dataset:
  - It collects training and validation data from the specified directory (`data_dir`), which contains MRI volume slices stored as `.h5` files.
  - The data is loaded with a custom `UNetDataset` or `UNetDataset_Mem` class, as selected in the config file, and wrapped in PyTorch's `DataLoader` for batching and shuffling.
- Model setup:
  - A U-Net model is initialized with the specified numbers of input and output channels.
- Loss function & optimizer:
  - The `DiceLoss` function, which measures the overlap between the predicted and ground-truth segmentation masks, is used as the training loss.
  - The Adam optimizer is initialized with the learning rate and other hyperparameters.
- Training:
  - The model is trained with the `train_model` function, which takes the training and validation data loaders, the model, the loss function, the optimizer, and hyperparameters such as the number of epochs and the learning rate scheduler settings.
- Model saving:
  - After training, the model is saved to the path given by the `"saved_model_path"` key of `config.json`. The training loss/accuracy log is saved to the path given by the `"log_file"` key of `config.json`.
In short, the script trains the MRI reconstruction model and monitors its performance during training.
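As a rough illustration of the "Load configuration" step, the sketch below reads a JSON configuration with Python's standard `json` module and fails early if a key named in this README is missing. Only `data_dir`, `saved_model_path`, and `log_file` appear in this README; `batch_size`, the example values, the `REQUIRED_KEYS` tuple, and the `load_config` helper are hypothetical, and the real `config.json` may be structured differently.

```python
import json
import tempfile
from pathlib import Path

# Keys mentioned in this README; other keys (e.g. "batch_size") are not checked.
REQUIRED_KEYS = ("data_dir", "saved_model_path", "log_file")


def load_config(path):
    """Load a JSON config file and raise KeyError if a required key is missing."""
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise KeyError(f"config is missing required keys: {missing}")
    return config


if __name__ == "__main__":
    # Placeholder values for illustration only, not the project's settings.
    example = {
        "data_dir": "path/to/h5_slices",
        "saved_model_path": "path/to/unet.pth",
        "log_file": "path/to/train_log.json",
        "batch_size": 8,
    }
    with tempfile.TemporaryDirectory() as tmp:
        cfg_path = Path(tmp) / "config.json"
        cfg_path.write_text(json.dumps(example), encoding="utf-8")
        cfg = load_config(cfg_path)
        print(cfg["batch_size"])  # 8
```

Checking for required keys up front means a typo in `config.json` surfaces before any data is loaded, rather than after a long training run when the model is about to be saved.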
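The "Prepare the dataset" step begins by collecting the `.h5` slice files under `data_dir`. The standard-library sketch below shows one way to gather them and split them into training and validation lists. The `collect_h5_files` helper, its `val_fraction` and `seed` parameters, and the random split itself are assumptions: the project may instead take its training and validation lists from `config.json`, and loading the files into `UNetDataset`/`UNetDataset_Mem` and PyTorch's `DataLoader` is not shown.

```python
import random
import tempfile
from pathlib import Path


def collect_h5_files(data_dir, val_fraction=0.2, seed=0):
    """Return (train_files, val_files) from the .h5 files directly under data_dir.

    Hypothetical helper: sorting before a seeded shuffle makes the split
    reproducible regardless of the order the filesystem lists files in.
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError("val_fraction must be in [0, 1)")
    files = sorted(Path(data_dir).glob("*.h5"))
    random.Random(seed).shuffle(files)
    n_val = int(len(files) * val_fraction)
    return files[n_val:], files[:n_val]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for i in range(10):
            (Path(tmp) / f"slice_{i:03d}.h5").touch()  # empty placeholder files
        train_files, val_files = collect_h5_files(tmp)
        print(len(train_files), len(val_files))  # 8 2
```

Note that splitting at the slice level can place slices from the same MRI volume in both sets; if the files encode a volume ID, splitting by volume would avoid that leakage.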
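The Dice loss used in the "Loss function & optimizer" step is commonly defined as one minus the soft Dice coefficient, 1 - (2 * sum(p * g) + eps) / (sum(p) + sum(g) + eps), where p are predicted foreground probabilities, g is the binary ground-truth mask, and eps is a small smoothing constant. The NumPy sketch below implements that common form. The project's `DiceLoss` is a PyTorch loss whose reduction and smoothing may differ; the `soft_dice_loss` name and the `eps` default are assumptions.

```python
import numpy as np


def soft_dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss: 1 - (2*sum(p*g) + eps) / (sum(p) + sum(g) + eps).

    pred:   predicted foreground probabilities in [0, 1]
    target: binary ground-truth mask with the same shape as pred
    """
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    intersection = np.sum(pred * target)        # overlap between prediction and mask
    denominator = np.sum(pred) + np.sum(target)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return 1.0 - dice                           # 0 for a perfect match, near 1 for no overlap
```

A perfect prediction gives a loss of 0 and a prediction with no overlap gives a loss close to 1. Because the loss depends on overlap relative to the mask size rather than on per-pixel accuracy, it is less dominated by the background when the segmented region is small.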
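The `train_model` function takes the number of epochs and learning rate scheduler settings as hyperparameters. This README does not say which scheduler is used; as an assumption, the sketch below computes a step-decay schedule of the kind PyTorch's `torch.optim.lr_scheduler.StepLR` applies, lr = base_lr * gamma ** (epoch // step_size). The `step_lr` and `lr_schedule` helpers and the example hyperparameters are hypothetical, not values from this project's `config.json`.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    """Learning rate during `epoch` (0-indexed) under step decay, as in StepLR.

    The rate is multiplied by `gamma` once every `step_size` epochs.
    """
    if step_size <= 0:
        raise ValueError("step_size must be positive")
    return base_lr * gamma ** (epoch // step_size)


def lr_schedule(base_lr, num_epochs, step_size, gamma):
    """Per-epoch learning rates for a training run of `num_epochs` epochs."""
    return [step_lr(base_lr, epoch, step_size, gamma) for epoch in range(num_epochs)]


if __name__ == "__main__":
    # Hypothetical hyperparameters for illustration only.
    for epoch, lr in enumerate(lr_schedule(1e-3, num_epochs=6, step_size=2, gamma=0.5)):
        print(f"epoch {epoch}: lr = {lr:g}")
```

In PyTorch, the equivalent behavior comes from constructing `StepLR(optimizer, step_size=..., gamma=...)` around the Adam optimizer and calling `scheduler.step()` once at the end of each epoch.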
The pipeline of image reconstruction and training is shown in the figure.