# Profiling `train_multistep.py` and `trainerERA5_multistep_grad_accum.py`

## `train_multistep.py`

```python
Timer unit: 1e-09 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   241                                           def main(rank, world_size, conf, backend, trial=False):
   242                                               """
   243                                               Main function to set up training and validation processes.
   244
   245                                               Args:
   246                                                   rank (int): Rank of the current process.
   247                                                   world_size (int): Number of processes participating in the job.
   248                                                   conf (dict): Configuration dictionary containing model, data, and training parameters.
   249                                                   backend (str): Backend to be used for distributed training.
   250                                                   trial (bool, optional): Flag for whether this is an Optuna trial. Defaults to False.
   251
   252                                               Returns:
   253                                                   Any: The result of the training process.
   254                                               """
   255
   256                                               # convert $USER to the actual user name
   257         1      16301.0  16301.0      0.0      conf['save_loc'] = os.path.expandvars(conf['save_loc'])
   258
   259         1       1322.0   1322.0      0.0      if conf["trainer"]["mode"] in ["fsdp", "ddp"]:
   260         1  313010857.0    3e+08      0.0          setup(rank, world_size, conf["trainer"]["mode"], backend)
   261
   262                                               # infer device id from rank
   263
   264         1      26461.0  26461.0      0.0      device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") if torch.cuda.is_available() else torch.device("cpu")
   265         1 1184547466.0    1e+09      0.0      torch.cuda.set_device(rank % torch.cuda.device_count())
   266
   267                                               # Config settings
   268         1       3076.0   3076.0      0.0      seed = 1000 if "seed" not in conf else conf["seed"]
   269         1   32462283.0    3e+07      0.0      seed_everything(seed)
   270
   271         1       1232.0   1232.0      0.0      train_batch_size = conf['trainer']['train_batch_size']
   272         1        501.0    501.0      0.0      valid_batch_size = conf['trainer']['valid_batch_size']
   273
   274                                               # get file names
   275         1   45995972.0    5e+07      0.0      all_ERA_files = sorted(glob(conf["data"]["save_loc"]))
   276
   277                                               # <------------------------------------------ std_new or 'std_cached'
   278         1       1844.0   1844.0      0.0      if conf['data']['scaler_type'] == 'std_new' or 'std_cached':
   279
   280                                                   # check and glob surface files
   281         1       2815.0   2815.0      0.0          if ('surface_variables' in conf['data']) and (len(conf['data']['surface_variables']) > 0):
   282         1     857703.0 857703.0      0.0              surface_files = sorted(glob(conf["data"]["save_loc_surface"]))
   283
   284                                                   else:
   285                                                       surface_files = None
   286
   287                                                   # check and glob dyn forcing files
   288         1       1773.0   1773.0      0.0          if ('dynamic_forcing_variables' in conf['data']) and (len(conf['data']['dynamic_forcing_variables']) > 0):
   289         1     434206.0 434206.0      0.0              dyn_forcing_files = sorted(glob(conf["data"]["save_loc_dynamic_forcing"]))
   290
   291                                                   else:
   292                                                       dyn_forcing_files = None
   293
   294                                                   # check and glob diagnostic files
   295         1       1563.0   1563.0      0.0          if ('diagnostic_variables' in conf['data']) and (len(conf['data']['diagnostic_variables']) > 0):
   296                                                       diagnostic_files = sorted(glob(conf["data"]["save_loc_diagnostic"]))
   297
   298                                                   else:
   299         1        221.0    221.0      0.0              diagnostic_files = None
   300
   301                                               # -------------------------------------------------- #
   302                                               # import training / validation years from conf
   303
   304         1        672.0    672.0      0.0      if 'train_years' in conf['data']:
   305         1        511.0    511.0      0.0          train_years_range = conf['data']['train_years']
   306                                               else:
   307                                                   train_years_range = [1979, 2014]
   308
   309         1        521.0    521.0      0.0      if 'valid_years' in conf['data']:
   310         1        250.0    250.0      0.0          valid_years_range = conf['data']['valid_years']
   311                                               else:
   312                                                   valid_years_range = [2014, 2018]
   313
   314                                               # convert year info to str for file name search
   315         1      13856.0  13856.0      0.0      train_years = [str(year) for year in range(train_years_range[0], train_years_range[1])]
   316         1       2545.0   2545.0      0.0      valid_years = [str(year) for year in range(valid_years_range[0], valid_years_range[1])]
   317
   318                                               # Filter the files for training / validation
   319         1     218851.0 218851.0      0.0      train_files = [file for file in all_ERA_files if any(year in file for year in train_years)]
   320         1      33184.0  33184.0      0.0      valid_files = [file for file in all_ERA_files if any(year in file for year in valid_years)]
   321
   322                                               # <----------------------------------- std_new or 'std_cached'
   323         1        561.0    561.0      0.0      if conf['data']['scaler_type'] == 'std_new' or 'std_cached':
   324
   325         1        161.0    161.0      0.0          if surface_files is not None:
   326
   327         1     214624.0 214624.0      0.0              train_surface_files = [file for file in surface_files if any(year in file for year in train_years)]
   328         1      32082.0  32082.0      0.0              valid_surface_files = [file for file in surface_files if any(year in file for year in valid_years)]
   329
   330                                                   else:
   331                                                       train_surface_files = None
   332                                                       valid_surface_files = None
   333
   334         1        391.0    391.0      0.0          if dyn_forcing_files is not None:
   335
   336         1     239712.0 239712.0      0.0              train_dyn_forcing_files = [file for file in dyn_forcing_files if any(year in file for year in train_years)]
   337         1      33745.0  33745.0      0.0              valid_dyn_forcing_files = [file for file in dyn_forcing_files if any(year in file for year in valid_years)]
   338
   339                                                   else:
   340                                                       train_dyn_forcing_files = None
   341                                                       valid_dyn_forcing_files = None
   342
   343         1        241.0    241.0      0.0          if diagnostic_files is not None:
   344
   345                                                       train_diagnostic_files = [file for file in diagnostic_files if any(year in file for year in train_years)]
   346                                                       valid_diagnostic_files = [file for file in diagnostic_files if any(year in file for year in valid_years)]
   347
   348                                                   else:
   349         1        161.0    161.0      0.0              train_diagnostic_files = None
   350         1        131.0    131.0      0.0              valid_diagnostic_files = None
   351
   352                                               # load dataset and sampler  [Timer unit: 1e-09 s]
   353         2        3e+10    1e+10      0.8      train_dataset, train_sampler = load_dataset_and_sampler(conf,
   354         1        131.0    131.0      0.0                                                              train_files,
   355         1        100.0    100.0      0.0                                                              train_surface_files,
   356         1        100.0    100.0      0.0                                                              train_dyn_forcing_files,
   357         1        100.0    100.0      0.0                                                              train_diagnostic_files,
   358         1        231.0    231.0      0.0                                                              world_size, rank, is_train=True)
   359                                               # validation set and sampler
   360         2  223009470.0    1e+08      0.0      valid_dataset, valid_sampler = load_dataset_and_sampler(conf,
   361         1        180.0    180.0      0.0                                                              valid_files,
   362         1        241.0    241.0      0.0                                                              valid_surface_files,
   363         1        251.0    251.0      0.0                                                              valid_dyn_forcing_files,
   364         1        120.0    120.0      0.0                                                              valid_diagnostic_files,
   365         1        381.0    381.0      0.0                                                              world_size, rank, is_train=False)
   366
   367                                               # setup the dataloader for this process
   368
   369         2     587722.0 293861.0      0.0      train_loader = torch.utils.data.DataLoader(
   370         1        401.0    401.0      0.0          train_dataset,
   371         1        171.0    171.0      0.0          batch_size=train_batch_size,
   372         1        391.0    391.0      0.0          shuffle=False,
   373         1        250.0    250.0      0.0          sampler=train_sampler,
   374         1        151.0    151.0      0.0          pin_memory=True,
   375         1        151.0    151.0      0.0          persistent_workers=False,
   376         1        221.0    221.0      0.0          num_workers=1,  # multiprocessing is handled in the dataset
   377         1        160.0    160.0      0.0          drop_last=True,
   378         1        140.0    140.0      0.0          prefetch_factor=4
   379                                               )
   380
   381         2      33534.0  16767.0      0.0      valid_loader = torch.utils.data.DataLoader(
   382         1        130.0    130.0      0.0          valid_dataset,
   383         1        180.0    180.0      0.0          batch_size=valid_batch_size,
   384         1        230.0    230.0      0.0          shuffle=False,
   385         1        140.0    140.0      0.0          sampler=valid_sampler,
   386         1        140.0    140.0      0.0          pin_memory=False,
   387         1        140.0    140.0      0.0          num_workers=1,  # multiprocessing is handled in the dataset
   388         1         90.0     90.0      0.0          drop_last=True,
   389         1         90.0     90.0      0.0          prefetch_factor=4
   390                                               )
   391
   392                                               # model
   393
   394         1 2828381656.0    3e+09      0.1      m = load_model(conf)
   395
   396                                               # have to send the module to the correct device first
   397
   398         1  496232366.0    5e+08      0.0      m.to(device)
   399
   400                                               # move out of eager-mode
   401         1       3106.0   3106.0      0.0      if conf["trainer"].get("compile", False):
   402                                                   m = torch.compile(m)
   403
   404                                               # Wrap in DDP or FSDP module, or none
   405
   406         1 7956198312.0    8e+09      0.2      model = distributed_model_wrapper(conf, m, device)
   407
   408                                               # Load model weights (if any), an optimizer, scheduler, and gradient scaler
   409
   410         1        1e+10    1e+10      0.4      conf, model, optimizer, scheduler, scaler = load_model_states_and_optimizer(conf, model, device)
   411
   412                                               # Train and validation losses
   413
   414         1  831897166.0    8e+08      0.0      train_criterion = VariableTotalLoss2D(conf)
   415         1   12467354.0    1e+07      0.0      valid_criterion = VariableTotalLoss2D(conf, validation=True)
   416
   417                                               # Optional load stopping probability annealer
   418
   419                                               # Set up some metrics
   420
   421         1   11567129.0    1e+07      0.0      metrics = LatWeightedMetrics(conf)
   422
   423                                               # Initialize a trainer object
   424         1     655323.0 655323.0      0.0      trainer_cls = load_trainer(conf)
   425         1      93931.0  93931.0      0.0      trainer = trainer_cls(model, rank, module=(conf["trainer"]["mode"] == "ddp"))
   426
   427                                               # Fit the model
   428
   429         2        4e+12    2e+12     98.4      result = trainer.fit(
   430         1        121.0    121.0      0.0          conf,
   431         1        341.0    341.0      0.0          train_loader=train_loader,
   432         1        251.0    251.0      0.0          valid_loader=valid_loader,
   433         1        220.0    220.0      0.0          optimizer=optimizer,
   434         1        250.0    250.0      0.0          train_criterion=train_criterion,
   435         1        140.0    140.0      0.0          valid_criterion=valid_criterion,
   436         1        141.0    141.0      0.0          scaler=scaler,
   437         1        100.0    100.0      0.0          scheduler=scheduler,
   438         1        141.0    141.0      0.0          metrics=metrics,
   439         1        120.0    120.0      0.0          trial=trial  # Optional
   440                                               )
   441
   442         1       1563.0   1563.0      0.0      return result
```

## `trainer.fit`

```python
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   169                                               def fit(
   170                                                   self,
   171                                                   conf: Dict[str, Any],
   172                                                   train_loader: DataLoader,
   173                                                   valid_loader: DataLoader,
   174                                                   optimizer: Optimizer,
   175                                                   train_criterion: torch.nn.Module,
   176                                                   valid_criterion: torch.nn.Module,
   177                                                   scaler: GradScaler,
   178                                                   scheduler: _LRScheduler,
   179                                                   metrics: Dict[str, Any],
   180                                                   rollout_scheduler: Optional[callable] = None,
   181                                                   trial: bool = False
   182                                               ) -> Dict[str, Any]:
   183
   184                                                   """
   185                                                   Fit the model to the data.
   186
   187                                                   Args:
   188                                                       conf (Dict[str, Any]): Configuration dictionary.
   189                                                       train_loader (DataLoader): DataLoader for training data.
   190                                                       valid_loader (DataLoader): DataLoader for validation data.
   191                                                       optimizer (Optimizer): The optimizer to use for training.
   192                                                       train_criterion (torch.nn.Module): Loss function for training.
   193                                                       valid_criterion (torch.nn.Module): Loss function for validation.
   194                                                       scaler (GradScaler): Gradient scaler for mixed precision training.
   195                                                       scheduler (_LRScheduler): Learning rate scheduler.
   196                                                       metrics (Dict[str, Any]): Dictionary of metrics to track during training.
   197                                                       rollout_scheduler (Optional[callable]): Function to schedule rollout probability, if applicable.
   198                                                       trial (bool): Whether this is a trial run (e.g., for hyperparameter tuning).
   199
   200                                                   Returns:
   201                                                       Dict[str, Any]: Dictionary containing the best results from training.
   202                                                   """
   203
   204                                                   # convert $USER to the actual user name
   205         1       3607.0   3607.0      0.0          conf['save_loc'] = save_loc = os.path.expandvars(conf['save_loc'])
   206
   207                                                   # training hyperparameters
   208         1        451.0    451.0      0.0          start_epoch = conf['trainer']['start_epoch']
   209         1        311.0    311.0      0.0          epochs = conf['trainer']['epochs']
   210         1        592.0    592.0      0.0          skip_validation = conf['trainer']['skip_validation'] if 'skip_validation' in conf['trainer'] else False
   211         1        371.0    371.0      0.0          flag_load_weights = conf['trainer']['load_weights']
   212
   213                                                   # Check if 'training_metric' and 'training_metric_direction' exist in the config
   214         1        621.0    621.0      0.0          training_metric = conf['trainer'].get('training_metric', "train_loss" if skip_validation else "valid_loss")
   215         1        531.0    531.0      0.0          direction = conf['trainer'].get('training_metric_direction', "min")
   216         1      58523.0  58523.0      0.0          logger.info(f"The training metric being used is {training_metric} which has direction {direction}")
   217         1        611.0    611.0      0.0          direction = min if direction == "min" else max
   218
   219                                                   # Check if we are saving user-defined variable metrics
   220         1        701.0    701.0      0.0          save_metric_vars = conf['trainer'].get('save_metric_vars', [])
   221
   222                                                   # =========================================== #
   223                                                   # user can specify to run a fixed number of epochs
   224         1        421.0    421.0      0.0          if 'num_epoch' in conf['trainer']:
   225         1      40538.0  40538.0      0.0              logger.info('The current job will run {} epochs max'.format(conf['trainer']['num_epoch']))
   226                                                   else:
   227                                                       conf['trainer']['num_epoch'] = 1e8
   228                                                   # =========================================== #
   229
   230                                                   # Reload the results saved in the training csv if continuing to train
   231         1        340.0    340.0      0.0          if (start_epoch == 0) or (flag_load_weights is False):
   232                                                       results_dict = defaultdict(list)
   233                                                       # Set start_epoch to the length of the training log and train for one epoch
   234                                                       # This is a manual override, you must use train_one_epoch = True
   235                                                       if "train_one_epoch" in conf["trainer"] and conf["trainer"]["train_one_epoch"]:
   236                                                           epochs = 1
   237                                                   else:
   238         1        901.0    901.0      0.0              results_dict = defaultdict(list)
   239         1  116185865.0    1e+08      0.0              saved_results = pd.read_csv(os.path.join(save_loc, "training_log.csv"))
   240
   241                                                       # Set start_epoch to the length of the training log and train for one epoch
   242                                                       # This is a manual override, you must use train_one_epoch = True
   243         1       1603.0   1603.0      0.0              if "train_one_epoch" in conf["trainer"] and conf["trainer"]["train_one_epoch"]:
   244                                                           start_epoch = len(saved_results)
   245                                                           epochs = start_epoch + 1
   246
   247        12      16802.0   1400.2      0.0              for key in saved_results.columns:
   248        11       4580.0    416.4      0.0                  if key == "index":
   249         1        191.0    191.0      0.0                      continue
   250        10  141702870.0    1e+07      0.0                  results_dict[key] = list(saved_results[key])
   251
   252         1        340.0    340.0      0.0          count = 0
   253         2       2054.0   1027.0      0.0          for epoch in range(start_epoch, epochs):
   254
   255         2       2173.0   1086.5      0.0              if count >= conf['trainer']['num_epoch']:
   256         1     159267.0 159267.0      0.0                  logger.info('Completed {} epochs, exiting'.format(conf['trainer']['num_epoch']))
   257         1       1122.0   1122.0      0.0                  break
   258
   259                                                       # ========================= #
   260                                                       # backup the previous epoch
   261                                                       # ========================= #
   262         1        631.0    631.0      0.0              if count > 0 and conf['trainer']['save_backup_weights']:
   263                                                           if self.rank == 0:
   264                                                               # checkpoint.pt
   265                                                               shutil.copyfile(os.path.join(save_loc, "checkpoint.pt"),
   266                                                                               os.path.join(save_loc, "backup_checkpoint.pt"))
   267
   268                                                               # model_checkpoint.pt and optimizer_checkpoint.pt
   269                                                               if conf["trainer"]["mode"] == "fsdp":
   270                                                                   shutil.copyfile(os.path.join(save_loc, "model_checkpoint.pt"),
   271                                                                                   os.path.join(save_loc, "backup_model_checkpoint.pt"))
   272
   273                                                                   shutil.copyfile(os.path.join(save_loc, "optimizer_checkpoint.pt"),
   274                                                                                   os.path.join(save_loc, "backup_optimizer_checkpoint.pt"))
   275
   276         1     125903.0 125903.0      0.0              logger.info(f"Beginning epoch {epoch}")
   277
   278                                                       # set the epoch in the dataset and sampler to ensure distributed randomness is handled correctly
   279         1       4519.0   4519.0      0.0              if hasattr(train_loader, 'sampler') and hasattr(train_loader.sampler, 'set_epoch'):
   280         1       3336.0   3336.0      0.0                  train_loader.sampler.set_epoch(epoch)  # Start a new forecast
   281
   282         1       1894.0   1894.0      0.0              if hasattr(train_loader.dataset, 'set_epoch'):
   283         1       5100.0   5100.0      0.0                  train_loader.dataset.set_epoch(epoch)  # Ensure we don't start in the middle of a forecast epoch-over-epoch
   284
   285                                                       ############
   286                                                       #
   287                                                       # Train
   288                                                       #
   289                                                       ############
   290
   291         2        2e+12    9e+11     97.1              train_results = self.train_one_epoch(
   292         1        251.0    251.0      0.0                  epoch,
   293         1        240.0    240.0      0.0                  conf,
   294         1        210.0    210.0      0.0                  train_loader,
   295         1        361.0    361.0      0.0                  optimizer,
   296         1        301.0    301.0      0.0                  train_criterion,
   297         1        721.0    721.0      0.0                  scaler,
   298         1        241.0    241.0      0.0                  scheduler,
   299         1        211.0    211.0      0.0                  metrics
   300                                                       )
   301
   302                                                       ############
   303                                                       #
   304                                                       # Validation
   305                                                       #
   306                                                       ############
   307
   308         1        681.0    681.0      0.0              if skip_validation:
   309
   310                                                           valid_results = train_results
   311
   312                                                       else:
   313
   314         2        5e+10    2e+10      2.7                  valid_results = self.validate(
   315         1        581.0    581.0      0.0                      epoch,
   316         1        251.0    251.0      0.0                      conf,
   317         1        421.0    421.0      0.0                      valid_loader,
   318         1        851.0    851.0      0.0                      valid_criterion,
   319         1        210.0    210.0      0.0                      metrics
   320                                                           )
   321
   322                                                       #################
   323                                                       #
   324                                                       # Save results
   325                                                       #
   326                                                       #################
   327
   328         1       9508.0   9508.0      0.0              results_dict["epoch"].append(epoch)
   329
   330                                                       # Save metrics for select variables
   331         1       1102.0   1102.0      0.0              required_metrics = ["loss", "acc", "mae", "forecast_len"]  # Base required metrics
   332         1       4679.0   4679.0      0.0              if isinstance(save_metric_vars, list) and len(save_metric_vars) > 0:
   333                                                           names = [key.replace("train_", "") for key in train_results.keys() if any(var in key for var in save_metric_vars)]
   334         1       1974.0   1974.0      0.0              elif isinstance(save_metric_vars, bool) and save_metric_vars:
   335                                                           names = [key.replace("train_", "") for key in train_results.keys()]
   336                                                       else:
   337         1        351.0    351.0      0.0                  names = []
   338         1       4349.0   4349.0      0.0              names = list(set(names + required_metrics))
   339
   340         5       2294.0    458.8      0.0              for name in names:
   341         4     168176.0  42044.0      0.0                  results_dict[f"train_{name}"].append(np.mean(train_results[f"train_{name}"]))
   342         4       1182.0    295.5      0.0                  if skip_validation:
   343                                                               continue
   344         4     524352.0 131088.0      0.0                  results_dict[f"valid_{name}"].append(np.mean(valid_results[f"valid_{name}"]))
   345         1       7454.0   7454.0      0.0              results_dict["lr"].append(optimizer.param_groups[0]["lr"])
   346
   347                                                       # update the learning rate if epoch-by-epoch updates
   348
   349         1       1653.0   1653.0      0.0              if conf['trainer']['use_scheduler'] and conf['trainer']['scheduler']['scheduler_type'] in update_on_epoch:
   350                                                           if conf['trainer']['scheduler']['scheduler_type'] == 'plateau':
   351                                                               scheduler.step(results_dict[training_metric][-1])
   352                                                           else:
   353                                                               scheduler.step()
   354
   355                                                       # Create pandas df
   356
   357                                                       # Find the maximum length among all lists
   358         1      11663.0  11663.0      0.0              max_len = max(len(lst) for lst in results_dict.values())
   359
   360                                                       # Prepend NaNs to lists that are shorter than max_len
   361         1       1132.0   1132.0      0.0              padded_dict = OrderedDict()
   362        11       5462.0    496.5      0.0              for key, lst in results_dict.items():
   363        10       2903.0    290.3      0.0                  if len(lst) < max_len:
   364                                                               padded_dict[key] = [np.nan] * (max_len - len(lst)) + lst
   365                                                           else:
   366        10       3947.0    394.7      0.0                      padded_dict[key] = lst
   367
   368         1    2262224.0    2e+06      0.0              df = pd.DataFrame.from_dict(padded_dict).reset_index()
   369
   370                                                       # Save the dataframe to disk
   371
   372         1        411.0    411.0      0.0              if trial:  # If using ECHO-opt, save to the trial_results directory
   373                                                           df.to_csv(
   374                                                               os.path.join(f"{save_loc}", "trial_results", f"training_log_{trial.number}.csv"),
   375                                                               index=False,
   376                                                           )
   377                                                       else:
   378         1   14437259.0    1e+07      0.0                  df.to_csv(os.path.join(f"{save_loc}", "training_log.csv"), index=False)
   379
   380                                                       ############
   381                                                       #
   382                                                       # Checkpoint
   383                                                       #
   384                                                       ############
   385
   386         1       1413.0   1413.0      0.0              if not trial:
   387
   388         1       3617.0   3617.0      0.0                  if conf["trainer"]["mode"] != "fsdp":
   389
   390                                                               if self.rank == 0:
   391
   392                                                                   # Save the current model
   393
   394                                                                   logger.info(f"Saving model, optimizer, grad scaler, and learning rate scheduler states to {save_loc}")
   395
   396                                                                   state_dict = {
   397                                                                       "epoch": epoch,
   398                                                                       "model_state_dict": self.model.state_dict(),
   399                                                                       "optimizer_state_dict": optimizer.state_dict(),
   400                                                                       'scheduler_state_dict': scheduler.state_dict() if conf["trainer"]["use_scheduler"] else None,
   401                                                                       'scaler_state_dict': scaler.state_dict()
   402                                                                   }
   403                                                                   torch.save(state_dict, f"{save_loc}/checkpoint.pt")
   404
   405                                                           else:
   406
   407         1     236507.0 236507.0      0.0                      logger.info(f"Saving FSDP model, optimizer, grad scaler, and learning rate scheduler states to {save_loc}")
   408
   409                                                               # Initialize the checkpoint I/O handler
   410         1      12294.0  12294.0      0.0                      checkpoint_io = TorchFSDPCheckpointIO()
   411
   412                                                               # Save model and optimizer checkpoints
   413         2  521726705.0    3e+08      0.0                      checkpoint_io.save_unsharded_model(
   414         1       1172.0   1172.0      0.0                          self.model,
   415         1      10400.0  10400.0      0.0                          os.path.join(save_loc, "model_checkpoint.pt"),
   416         1        401.0    401.0      0.0                          gather_dtensor=True,
   417         1        300.0    300.0      0.0                          use_safetensors=False,
   418         1       1243.0   1243.0      0.0                          rank=self.rank
   419                                                               )
   420         2 2424341721.0    1e+09      0.1                      checkpoint_io.save_unsharded_optimizer(
   421         1        291.0    291.0      0.0                          optimizer,
   422         1       7645.0   7645.0      0.0                          os.path.join(save_loc, "optimizer_checkpoint.pt"),
   423         1        572.0    572.0      0.0                          gather_dtensor=True,
   424         1        400.0    400.0      0.0                          rank=self.rank
   425                                                               )
   426
   427                                                               # Still need to save the scheduler and scaler states, just in another file for FSDP
   428         1       2685.0   2685.0      0.0                      state_dict = {
   429         1        211.0    211.0      0.0                          "epoch": epoch,
   430         1       1973.0   1973.0      0.0                          'scheduler_state_dict': scheduler.state_dict() if conf["trainer"]["use_scheduler"] else None,
   431         1      12625.0  12625.0      0.0                          'scaler_state_dict': scaler.state_dict()
   432                                                               }
   433
   434         1   22825879.0    2e+07      0.0                      torch.save(state_dict, os.path.join(save_loc, "checkpoint.pt"))
   435
   436                                                       # clear the cached memory from the gpu
   437         1    9895808.0    1e+07      0.0              torch.cuda.empty_cache()
   438         1  198765830.0    2e+08      0.0              gc.collect()
   439         1       2504.0   2504.0      0.0              count += 1
   440
   441         1        510.0    510.0      0.0              if skip_validation:
   442                                                           pass
   443                                                       else:
   444                                                           # Stop training if we have not improved after X epochs (stopping patience)
   445         2      42734.0  21367.0      0.0                  best_epoch = [i for i, j in enumerate(results_dict[training_metric])
   446         1        130.0    130.0      0.0                                if j == direction(results_dict[training_metric])][0]
   447         1       1423.0   1423.0      0.0                  offset = epoch - best_epoch
   448
   449                                                           # ==================== #
   450                                                           # backup the best epoch
   451                                                           # ==================== #
   452         1        781.0    781.0      0.0                  if offset == 0 and conf['trainer']['save_best_weights']:
   453                                                               if self.rank == 0:
   454                                                                   # checkpoint.pt
   455                                                                   shutil.copyfile(
   456                                                                       os.path.join(save_loc, "checkpoint.pt"),
   457                                                                       os.path.join(save_loc, "best_checkpoint.pt")
   458                                                                   )
   459
   460                                                                   # model_checkpoint.pt and optimizer_checkpoint.pt
   461                                                                   if conf["trainer"]["mode"] == "fsdp":
   462                                                                       shutil.copyfile(
   463                                                                           os.path.join(save_loc, "model_checkpoint.pt"),
   464                                                                           os.path.join(save_loc, "best_model_checkpoint.pt")
   465                                                                       )
   466
   467                                                                       shutil.copyfile(
   468                                                                           os.path.join(save_loc, "optimizer_checkpoint.pt"),
   469                                                                           os.path.join(save_loc, "best_optimizer_checkpoint.pt")
   470                                                                       )
   471
   472                                                           # ==================== #
   473                                                           # early stopping block
   474                                                           # ==================== #
   475         1       1844.0   1844.0      0.0                  if offset >= conf['trainer']['stopping_patience']:
   476                                                               logger.info("Best {} were in epoch {}; current epoch is {}; early stopping.".format(
   477                                                                   training_metric, best_epoch, epoch))
   478                                                               break
   479
   480                                                       # ==================== #
   481                                                       # stop after one epoch
   482                                                       # ==================== #
   483         1       1263.0   1263.0      0.0              if 'stop_after_epoch' in conf['trainer']:
   484                                                           if conf['trainer']['stop_after_epoch']:
   485                                                               break
   486
   487         3       7885.0   2628.3      0.0          best_epoch = [
   488         1        832.0    832.0      0.0              i for i, j in enumerate(results_dict[training_metric]) if j == direction(results_dict[training_metric])
   489         1        140.0    140.0      0.0          ][0]
   490
   491         1       9288.0   9288.0      0.0          result = {k: v[best_epoch] for k, v in results_dict.items()}
   492
   493         1       1262.0   1262.0      0.0          if conf["trainer"]["mode"] in ["fsdp", "ddp"]:
   494         1     259852.0 259852.0      0.0              cleanup()
   495
   496         1        191.0    191.0      0.0          return result
```

## `trainerERA5_multistep_grad_accum.py`: `train_one_epoch`

```python
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    61                                               def train_one_epoch(
    62                                                   self,
    63                                                   epoch,
    64                                                   conf,
    65                                                   trainloader,
    66                                                   optimizer,
    67                                                   criterion,
    68                                                   scaler,
    69                                                   scheduler,
    70                                                   metrics
    71                                               ):
    72
    73                                                   """
    74                                                   Trains the model for one epoch.
    75
    76                                                   Args:
    77                                                       epoch (int): Current epoch number.
    78                                                       conf (dict): Configuration dictionary containing training settings.
    79                                                       trainloader (DataLoader): DataLoader for the training dataset.
    80                                                       optimizer (torch.optim.Optimizer): Optimizer used for training.
    81                                                       criterion (callable): Loss function used for training.
    82                                                       scaler (torch.cuda.amp.GradScaler): Gradient scaler for mixed precision training.
    83                                                       scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
    84                                                       metrics (callable): Function to compute metrics for evaluation.
    85
    86                                                   Returns:
    87                                                       dict: Dictionary containing training metrics and loss for the epoch.
    88                                                   """
    89
    90         1        892.0    892.0      0.0          batches_per_epoch = conf['trainer']['batches_per_epoch']
    91         1        370.0    370.0      0.0          amp = conf['trainer']['amp']
    92         1        652.0    652.0      0.0          distributed = True if conf["trainer"]["mode"] in ["fsdp", "ddp"] else False
    93         1        691.0    691.0      0.0          forecast_length = conf["data"]["forecast_len"]
    94
    95                                                   # number of diagnostic variables
    96         1       1142.0   1142.0      0.0          varnum_diag = len(conf["data"]['diagnostic_variables'])
    97
    98                                                   # number of dynamic forcing + forcing + static
    99         3       1131.0    377.0      0.0          static_dim_size = len(conf['data']['dynamic_forcing_variables']) + \
   100         1        290.0    290.0      0.0                            len(conf['data']['forcing_variables']) + \
   101         1        290.0    290.0      0.0                            len(conf['data']['static_variables'])
   102
   103                                                   # [Optional] Use the config option to set when to backprop
   104         1        340.0    340.0      0.0          if 'backprop_on_timestep' in conf['data']:
   105                                                       backprop_on_timestep = conf['data']['backprop_on_timestep']
   106                                                   else:
   107                                                       # If not specified in config, use the range 1 to forecast_len
   108         1       1974.0   1974.0      0.0              backprop_on_timestep = list(range(0, conf['data']['forecast_len']+1+1))
   109
   110         1        551.0    551.0      0.0          assert forecast_length <= backprop_on_timestep[-1], (
   111                                                       f"forecast_length ({forecast_length + 1}) must not exceed the max value in backprop_on_timestep {backprop_on_timestep}"
   112                                                   )
   113
   114                                                   # update the learning rate if epoch-by-epoch updates that dont depend on a metric
   115         1        341.0    341.0      0.0          if conf['trainer']['use_scheduler'] and conf['trainer']['scheduler']['scheduler_type'] == "lambda":
   116                                                       scheduler.step()
   117
   118                                                   # set up a custom tqdm
   119         1       3547.0   3547.0      0.0          if not isinstance(trainloader.dataset, IterableDataset):
   120         1        180.0    180.0      0.0              batches_per_epoch = (
   121         1      15410.0  15410.0      0.0                  batches_per_epoch if 0 < batches_per_epoch < len(trainloader) else len(trainloader)
   122                                                       )
   123
   124         2     859685.0 429842.5      0.0          batch_group_generator = tqdm.tqdm(
   125         1        401.0    401.0      0.0              range(batches_per_epoch), total=batches_per_epoch, leave=True
   126                                                   )
   127
   128         1    6359749.0    6e+06      0.0          self.model.train()
   129
   130         1        842.0    842.0      0.0          dl = cycle(trainloader)
   131
   132         1       1142.0   1142.0      0.0          results_dict = defaultdict(list)
   133
   134        11       4652.0    422.9      0.0          for steps in range(batches_per_epoch):
   135
   136        10       4718.0    471.8      0.0              logs = {}
   137        10    2160803.0 216080.3      0.0              loss = 0
   138        10      11973.0   1197.3      0.0              stop_forecast = False
   139        10   20803236.0    2e+06      0.0              y_pred = None  # Place holder that gets updated after first roll-out
   140
   141        20     576044.0  28802.2      0.0              with autocast(enabled=amp):
   142
   143       240     276366.0   1151.5      0.0                  while not stop_forecast:
   144
   145       240        2e+11    1e+09     14.6                      batch = next(dl)
   146
   147       480   12347402.0  25723.8      0.0                      for i, forecast_step in enumerate(batch["forecast_step"]):
   148                                                                   # if self.rank == 0:
   149                                                                   #     logger.info(f"i: {i}, forecast_step: {forecast_step}")
   150       240   12683168.0  52846.5      0.0                          if forecast_step == 1:
   151                                                                       # Initialize x and x_surf with the first time step
   152        10      17019.0   1701.9      0.0                              if "x_surf" in batch:
   153                                                                           # combine x and x_surf
   154                                                                           # input: (batch_num, time, var, level, lat, lon), (batch_num, time, var, lat, lon)
   155                                                                           # output: (batch_num, var, time, lat, lon), 'x' first and then 'x_surf'
   156        10  854855476.0    9e+07      0.1                                  x = concat_and_reshape(batch["x"], batch["x_surf"]).to(self.device)#.float()
   157                                                                       else:
   158                                                                           # no x_surf
   159                                                                           x = reshape_only(batch["x"]).to(self.device)#.float()
   160
   161                                                                   # add forcing and static variables (regardless of fcst hours)
   162       240     294712.0   1228.0      0.0                          if 'x_forcing_static' in batch:
   163
   164                                                                       # (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon)
   165       240  143398150.0 597492.3      0.0                              x_forcing_batch = batch['x_forcing_static'].to(self.device).permute(0, 2, 1, 3, 4)#.float()
   166
   167                                                                       # concat on var dimension
   168       240   25076797.0 104486.7      0.0                              x = torch.cat((x, x_forcing_batch), dim=1)
   169
   170                                                                   # predict with the model
   171       240        1e+12    5e+09     69.2                          y_pred = self.model(x)
   172
   173                                                                   # only load y-truth data if we intend to backprop (default is every step gets grads computed
   174       240   16387437.0  68281.0      0.0                          if forecast_step in backprop_on_timestep:
   175
   176                                                                       # calculate rolling loss
   177       240     235393.0    980.8      0.0                              if "y_surf" in batch:
   178       240        2e+10    8e+07      1.2                                  y = concat_and_reshape(batch["y"], batch["y_surf"]).to(self.device)
   179                                                                       else:
   180                                                                           y = reshape_only(batch["y"]).to(self.device)
   181
   182       240    1053344.0   4388.9      0.0                              if 'y_diag' in batch:
   183
   184                                                                           # (batch_num, time, var, lat, lon) --> (batch_num, var, time, lat, lon)
   185                                                                           y_diag_batch = batch['y_diag'].to(self.device).permute(0, 2, 1, 3, 4)#.float()
   186
   187                                                                           # concat on var dimension
   188                                                                           y = torch.cat((y, y_diag_batch), dim=1)
   189
   190       240 2488403762.0    1e+07      0.1                              loss = criterion(y.to(y_pred.dtype), y_pred).mean()
   191
   192                                                                       # track the loss
   193       240   20136823.0  83903.4      0.0                              accum_log(logs, {'loss': loss.item()})
   194
   195                                                                       # compute gradients
   196       240        2e+11    1e+09     13.8                              scaler.scale(loss).backward()
   197
   198       240     260156.0   1084.0      0.0                          if distributed:
   199       240        1e+10    6e+07      0.8                              torch.distributed.barrier()
   200
   201                                                                   # stop after X steps
   202       240    9798455.0  40826.9      0.0                          stop_forecast = batch['stop_forecast'][i]
   203
   204                                                                   # step-in-step-out
   205       240    1534845.0   6395.2      0.0                          if x.shape[2] == 1:
   206
   207                                                                       # cut diagnostic vars from y_pred, they are not inputs
   208       240     330691.0   1377.9      0.0                              if 'y_diag' in batch:
   209                                                                           x = y_pred[:, :-varnum_diag, ...].detach()
   210                                                                       else:
   211       240    4160670.0  17336.1      0.0                                  x = y_pred.detach()
   212
   213                                                                   # multi-step in
   214                                                                   else:
   215                                                                       # static channels will get updated on next pass
   216
   217                                                                       if static_dim_size == 0:
   218                                                                           x_detach = x[:, :, 1:, ...].detach()
   219                                                                       else:
   220                                                                           x_detach = x[:, :-static_dim_size, 1:, ...].detach()
   221
   222                                                                       # cut diagnostic vars from y_pred, they are not inputs
   223                                                                       if 'y_diag' in batch:
   224                                                                           x = torch.cat([x_detach,
   225                                                                                          y_pred[:, :-varnum_diag, ...].detach()], dim=2)
   226                                                                       else:
   227                                                                           x = torch.cat([x_detach,
   228                                                                                          y_pred.detach()], dim=2)
   229
   230       240    2552801.0  10636.7      0.0                      if stop_forecast:
   231        10       2195.0    219.5      0.0                          break
   232
   233                                                           # scale, accumulate, backward
   234
   235        10       2286.0    228.6      0.0                  if distributed:
   236        10    4020154.0 402015.4      0.0                      torch.distributed.barrier()
   237
   238        10 1225184451.0    1e+08      0.1                  scaler.step(optimizer)
   239        10      67271.0   6727.1      0.0                  scaler.update()
   240        10    7495549.0 749554.9      0.0                  optimizer.zero_grad()
   241
   242                                                       # Metrics
   243                                                       # metrics_dict = metrics(y_pred.float(), y.float())
   244        10 1077346971.0    1e+08      0.1              metrics_dict = metrics(y_pred, y)
   245      2890    9220389.0   3190.4      0.0              for name, value in metrics_dict.items():
   246      2880  205919969.0  71500.0      0.0                  value = torch.Tensor([value]).cuda(self.device, non_blocking=True)
   247      2880    1160097.0    402.8      0.0                  if distributed:
   248      2880  303978599.0 105548.1      0.0                      dist.all_reduce(value, dist.ReduceOp.AVG, async_op=False)
   249      2880  192668871.0  66898.9      0.0                  results_dict[f"train_{name}"].append(value[0].item())
   250
   251        10     427673.0  42767.3      0.0              batch_loss = torch.Tensor([logs["loss"]]).cuda(self.device)
   252        10       3418.0    341.8      0.0              if distributed:
   253        10     639531.0  63953.1      0.0                  dist.all_reduce(batch_loss, dist.ReduceOp.AVG, async_op=False)
   254        10     305730.0  30573.0      0.0              results_dict["train_loss"].append(batch_loss[0].item())
   255        10      12675.0   1267.5      0.0              results_dict["train_forecast_len"].append(forecast_length+1)
   256
   257        10     597321.0  59732.1      0.0              if not np.isfinite(np.mean(results_dict["train_loss"])):
   258                                                           print(results_dict["train_loss"], batch["x"].shape, batch["y"].shape, batch["index"])
   259                                                           try:
   260                                                               raise optuna.TrialPruned()
   261                                                           except Exception as E:
   262                                                               raise E
   263
   264                                                       # agg the results
   265        20     125398.0   6269.9      0.0              to_print = "Epoch: {} train_loss: {:.6f} train_acc: {:.6f} train_mae: {:.6f} forecast_len: {:.6f}".format(
   266        10       4217.0    421.7      0.0                  epoch,
   267        10     182013.0  18201.3      0.0                  np.mean(results_dict["train_loss"]),
   268        10     131071.0  13107.1      0.0                  np.mean(results_dict["train_acc"]),
   269        10     115521.0  11552.1      0.0                  np.mean(results_dict["train_mae"]),
   270        10       3480.0    348.0      0.0                  forecast_length+1
   271                                                       )
   272        10      56299.0   5629.9      0.0              to_print += " lr: {:.12f}".format(optimizer.param_groups[0]["lr"])
   273        10      15078.0   1507.8      0.0              if self.rank == 0:
   274                                                           batch_group_generator.update(1)
   275                                                           batch_group_generator.set_description(to_print)
   276
   277        10      14726.0   1472.6      0.0              if conf['trainer']['use_scheduler'] and conf['trainer']['scheduler']['scheduler_type'] in update_on_batch:
   278                                                           scheduler.step()
   279
   280                                                   #  Shutdown the progbar
   281         1     363450.0 363450.0      0.0          batch_group_generator.close()
   282
   283                                                   # clear the cached memory from the gpu
   284         1  110189503.0    1e+08      0.0          torch.cuda.empty_cache()
   285         1  226151814.0    2e+08      0.0          gc.collect()
   286
   287         1        902.0    902.0      0.0          return results_dict
```



