# Train the model

This notebook walks through the training of a model step by step.
The model configuration is stored in the `config` dictionary.

In [1]:
%load_ext autoreload
%autoreload 2

## Import libraries

In [2]:
print("Importing can be slow due to large packages (`torch`). Please wait...")

# Added `noqa: E402` to suppress warning about the import order when linting
import torch
import os
import yaml
import pandas as pd
from time import time
from argparse import ArgumentParser     # noqa: E402
from pathlib import Path                # noqa: E402
from tqdm.notebook import tqdm   # displays a progress bar
from typing import Dict, Any, Optional
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# import ipywidgets as widgets
# from IPython.display import display
from pprint import pprint

from config.config_loader import load_config            # noqa: E402
from utils.makepath import makepath as mkp               # noqa: E402
# from data_lib.turtle_dataset_constructor \
#     import TurtleDatasetConstructor                     # noqa: E402
from scripts.mri.pdhg_net_trainer import PdhgNetTrainer     # noqa: E402
# from scripts.model_loader import ModelLoader            # noqa: E402

print("Importing done.")

Importing can be slow due to large packages (`torch`). Please wait...
Importing done.


## Config

First, make sure we know the relative path to the root of the project.

In [3]:
# Root directory of the project is two levels up from this notebook.
# Change this if the notebook is moved.
num_levels_up = 2

root_dir = mkp(".")
for _ in range(num_levels_up):
    root_dir = mkp(root_dir, "..")

os.listdir(root_dir)

['scripts',
 'requirements.txt',
 'mri.egg-info',
 'venv',
 'README.md',
 'figures',
 'config',
 'LICENSE',
 'utils',
 'networks',
 'tmp',
 'dyn_mri_test.py',
 '.gitignore',
 'gradops',
 'pyproject.toml',
 'gifs',
 'data',
 'pdhg',
 'data_lib',
 'wandb',
 'encoding_objects',
 '.git']
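
For reference, the same walk up the directory tree can be sketched with the standard library alone. The helper name `find_root` and the example path are hypothetical and not part of this project; `mkp` is only assumed to behave like a path join.

```python
from pathlib import Path


def find_root(start: Path, num_levels_up: int) -> Path:
    """Return the directory `num_levels_up` levels above `start`."""
    root = start.resolve()
    for _ in range(num_levels_up):
        root = root.parent
    return root


# Hypothetical layout: a notebook directory two levels below the project root.
notebook_dir = Path("/project/a/b")
root_dir_example = find_root(notebook_dir, 2)
print(root_dir_example)
```

Resolving to an absolute path first makes later log messages independent of the current working directory, at the cost of longer printed paths.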

Next, let's choose the configuration to use for training.
We can pass the path to a configuration file, which is either a YAML or a JSON file.
An example configuration file named `example_model_config` is provided in the `config` directory.
It contains the configuration for a tiny version of the U-TGV type-2 model.

Another way is to pass a dictionary directly to the `config` parameter.

Finally, we can simply name the model type.
There are three options: `u_tv`, `u_tgv_type_1`, and `u_tgv_type_2`.
They correspond to the U-TV, U-TGV type-1, and U-TGV type-2 models, respectively.
These models are described in detail in the report.
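
The three kinds of config choice can be told apart with a few checks. The sketch below is only an illustration under assumptions: `describe_config_choice` is a hypothetical helper, and the project's own `load_config` may resolve choices differently.

```python
from pathlib import Path
from typing import Any, Dict, Union

# Model names listed above; how the project maps them to settings is not shown here.
KNOWN_MODEL_TYPES = ("u_tv", "u_tgv_type_1", "u_tgv_type_2")


def describe_config_choice(choice: Union[str, Path, Dict[str, Any]]) -> str:
    """Classify a config choice as 'dict', 'model_type', or 'file'."""
    if isinstance(choice, dict):
        return "dict"
    if str(choice) in KNOWN_MODEL_TYPES:
        return "model_type"
    suffix = Path(choice).suffix.lower()
    if suffix in (".yaml", ".yml", ".json"):
        return "file"
    raise ValueError(f"Unrecognized config choice: {choice!r}")


print(describe_config_choice({"train": {"num_epochs": 10}}))       # dict
print(describe_config_choice("u_tgv_type_2"))                      # model_type
print(describe_config_choice("config/example_model_config.yaml"))  # file
```

Raising on anything unrecognized means a misspelled model name fails immediately instead of being treated as a file path.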


In [4]:
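class MyArgs:
    """Simple container for the options used in this notebook."""

    def __init__(
            self,
            config,
            output_dir=None,
            device="cpu",
            logs_local: bool = True,
            uses_wandb: bool = False
    ):
        self.config = config            # path, dict, or model type name
        self.output_dir = output_dir    # None keeps the directory from the config
        self.device = device
        self.logs_local = logs_local
        self.uses_wandb = uses_wandb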

# TODO: Set the model configuration here.
args = MyArgs(
    # TODO: Make sure the path is correct.
    # Modify the config file or pass a different one if needed
    config=mkp(root_dir, "config", "example_model_config.yaml"),
    # config=mkp(root_dir, "config", "example_model_config.yml"),

    # # Actual U-TV model in the report
    # config="u_tv",

    # # Actual U-TGV type-1 model in the report
    # config="u_tgv_type_1",

    # # Actual U-TGV type-2 model in the report
    # config="u_tgv_type_2",

    # # Direct the output to a different directory if needed
    # output_dir=mkp(root_dir, "your_output_directory"),
    # output_dir=mkp(root_dir, "tmp", "example_model")

    # Change the device if needed
    # device="cuda" if torch.cuda.is_available() else "cpu",
    # device="cpu",
    # device="mps",  # Apple GPU
)

print(f"Config choice: {args.config}")

Config choice: ../../config/example_model_config.yaml


Adjust the config if needed. Here we use a small number of samples and epochs to keep the demonstration short. The U-Net size can also be reduced by uncommenting the `init_filters` line.

In [5]:
# NOTE: This cell loads `example_mri_tgv_config.yaml` directly and does not use `args.config`.
# To use the choice made above instead, uncomment the next line and remove the two lines after it.
# config = load_config(args.config, root_dir=root_dir, is_training=True)
config_file_path = mkp(root_dir, "config", "example_mri_tgv_config.yaml")
config = load_config(config_file_path, root_dir=root_dir, is_training=True)

config["data"]["train_num_samples"] = 10
config["data"]["val_num_samples"] = 10

# config["unet"]["init_filters"] = 32

config["train"]["num_epochs"] = 10

# config["device"] = "cpu"
# config["device"] = "cuda"   # Nvidia
# config["device"] = "mps"    # Apple

Config loaded from file ../../config/example_mri_tgv_config.yaml
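
When several values need to change, the overrides can be collected in one place and applied with a small helper. This is a minimal sketch: `apply_overrides` is a hypothetical function, not part of the project, and the base values are placeholders rather than the defaults from the YAML file.

```python
from copy import deepcopy
from typing import Any, Dict


def apply_overrides(config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `config` with dotted-key overrides applied.

    Raises KeyError if a key does not already exist, which catches typos.
    """
    result = deepcopy(config)
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for key in parents:
            node = node[key]
        if leaf not in node:
            raise KeyError(dotted_key)
        node[leaf] = value
    return result


# Placeholder values; the real defaults come from the config file.
base = {
    "data": {"train_num_samples": 100, "val_num_samples": 100},
    "train": {"num_epochs": 50},
}
small = apply_overrides(base, {
    "data.train_num_samples": 10,
    "data.val_num_samples": 10,
    "train.num_epochs": 10,
})
print(small["train"]["num_epochs"])  # 10
```

Working on a copy leaves the loaded config untouched, so the original values remain available for comparison.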


## Train

Let's instantiate the trainer with the adjusted configuration and the chosen device.

In [6]:
trainer = PdhgNetTrainer(
    config_choice=config, tqdm=tqdm, device=args.device)
if args.output_dir is not None:
    trainer.config["log"]["save_dir"] = args.output_dir
print(f"Output directory: {trainer.config['log']['save_dir']}")
print(f"Device: {trainer.config['device']}")

Config loaded from dict
Trainer initialized.
Output directory: ../../tmp/example_mri_tgv
Device: cpu


Make sure the path to the data is correct.

In [7]:
print(f"Data path: {trainer.config['data']['data_path']}")

Data path: ../../tmp/mri


### Log the progress

Create the log files that record the training and validation losses. With `force_overwrite=True`, existing log files are overwritten.

In [8]:
trainer.init_logger(force_overwrite=True)

File '../../tmp/example_mri_tgv/train_epoch_metrics.csv' already exists.
Overwriting the file...
Creating file '../../tmp/example_mri_tgv/train_epoch_metrics.csv'...
File '../../tmp/example_mri_tgv/train_intermediate_metrics.csv' already exists.
Overwriting the file...
Creating file '../../tmp/example_mri_tgv/train_intermediate_metrics.csv'...
File '../../tmp/example_mri_tgv/val_epoch_metrics.csv' already exists.
Overwriting the file...
Creating file '../../tmp/example_mri_tgv/val_epoch_metrics.csv'...
File '../../tmp/example_mri_tgv/val_intermediate_metrics.csv' already exists.
Overwriting the file...
Creating file '../../tmp/example_mri_tgv/val_intermediate_metrics.csv'...
Logging options initialized.
Logger initialized.
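
The overwrite behavior seen in the output can be sketched as follows. This is not the project's logger: `init_metrics_file` is a hypothetical helper, and the column names are illustrative because the actual metric columns are not shown here.

```python
import csv
import tempfile
from pathlib import Path
from typing import Sequence


def init_metrics_file(path: Path, columns: Sequence[str], force_overwrite: bool = False) -> bool:
    """Create a CSV file containing only a header row.

    Returns True if the file was (re)created, False if an existing file was kept.
    """
    if path.exists() and not force_overwrite:
        return False
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as f:
        csv.writer(f).writerow(columns)
    return True


# A temporary directory stands in for the real save directory.
save_dir = Path(tempfile.mkdtemp())
metrics_path = save_dir / "train_epoch_metrics.csv"
print(init_metrics_file(metrics_path, ["epoch", "loss"]))                        # True
print(init_metrics_file(metrics_path, ["epoch", "loss"]))                        # False
print(init_metrics_file(metrics_path, ["epoch", "loss"], force_overwrite=True))  # True
```

Defaulting `force_overwrite` to `False` protects metrics from an earlier run unless overwriting is requested explicitly.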


### Initialize the model

In [9]:
trainer.init_pdhg_net()

Creating a new model...
Norm of operator A: 1
Norm of gradient operator nabla: 2.8284270763397217
L: 3.0
Model initialized. Model device: cpu.
Number of trainable parameters: 517477.
Model size: 1.97 MB.
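
The printed numbers can be sanity-checked with a little arithmetic. The interpretation below is an assumption, not something confirmed by the project code: `L` is taken to be the norm bound of the stacked operator built from `A` and `nabla`, and the model size is taken to count 4-byte `float32` parameters with 1 MB = 1024² bytes.

```python
import math

norm_A = 1.0
norm_nabla = 2.8284270763397217   # printed value, approximately sqrt(8)
num_params = 517_477

# Assumption: L = sqrt(||A||^2 + ||nabla||^2) for the stacked operator.
L = math.sqrt(norm_A**2 + norm_nabla**2)
print(f"L: {L:.4f}")   # L: 3.0000

# Assumption: 4 bytes per parameter (float32), 1 MB = 1024**2 bytes.
size_mb = num_params * 4 / 1024**2
print(f"Model size: {size_mb:.2f} MB")   # Model size: 1.97 MB
```

Both results agree with the output above, which suggests (but does not prove) that these are the formulas in use.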


### Set up the loss function and training time

In [10]:
trainer.init_training_options()

# # Re-adjust if needed
# trainer.num_epochs = 2
# print(f"Number of epochs adjusted to: {trainer.num_epochs}")

Using loss function: MSELoss
Training for 10 epochs, starting from epoch 0.
Training options initialized.


### Load the data

In [11]:
trainer.load_data()



imgs_true_complex.shape: torch.Size([3452, 320, 320])

min_abs_val: 4.898726047031232e-07
max_abs_val: 2.58732533454895
Training dataset size: 10
Validation dataset size: 10
Test dataset size: 302
type of training_dataset: <class 'torch.utils.data.dataset.Subset'>
type of validation_dataset: <class 'torch.utils.data.dataset.Subset'>
type of test_dataset: <class 'torch.utils.data.dataset.Subset'>
Data loaded.


### [Optional] Log the configuration

Save the configuration alongside the logs for easy future reference.

In [12]:
# Store config and other logs if specified.
if args.logs_local:
    trainer.logger.log_config_local(trainer.pdhg_net)

Saving config in ../../tmp/example_mri_tgv...
Config saved


### [Optional] Log to WandB

In [13]:
# Initialize WandB for logging if specified.
if args.uses_wandb:
    trainer.logger.init_wandb()

### Start training

In [14]:
# Start training and measure the elapsed wall-clock time.
start_time = time()
trainer.start_training()
end_time = time()
print(f"Training took {end_time - start_time:.2f} seconds.")

Model will be saved in ../../tmp/example_mri_tgv.
Training started for 10 epochs.


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

acceleration_factor_R: 5
standard_deviation_sigma: 0.1454569458961487


  return func(*args, **kwargs)


TypeError: only integer tensors of a single element can be converted to an index