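# Evolution of the training process

How the training loop in `main` moves from its initial training stage to fine-tuning with progressively longer multi-step rollouts, followed by a dry run of that schedule using the actual config values.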

In [1]:
%run review/__common.py
%load_ext autoreload
%autoreload 2



```python
@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Staged training loop (excerpt; `...` marks elided code).

    Until iteration ``num_iters_step1 + num_iters_step2`` the loop trains on the
    datapipe the trainer was created with. From that iteration on it is in
    "finetune" mode: every ``step_change_freq`` iterations it rebuilds
    ``trainer.datapipe`` with one more rollout step (2, 3, 4, ...). Training
    stops once ``num_iters_step1 + num_iters_step2 + num_iters_step3``
    iterations have been completed.
    """
    ...
    rank_zero_logger.info("Training started...")
    
    # NOTE(review): `iter` shadows Python's built-in `iter()` inside `main`.
    iter = trainer.iter_init
    # Iteration at which the rollout length was last changed. Needed because the
    # `break` below restarts the epoch without advancing `iter`, so the same
    # `iter` is seen again on the new datapipe and must not trigger a second
    # rebuild. The initial value 1 is only a sentinel (NOTE(review): it would
    # suppress a switch if a switch point ever fell on iteration 1).
    tagged_iter = 1
    num_rollout_steps = 1

    terminate_training = False
    finetune = False
    update_dataloader = False

    # training loop
    while True:
        # One pass over `trainer.datapipe` is one epoch; a rollout change ends the
        # pass early so the while-loop re-enters with the rebuilt datapipe.
        for _, data in enumerate(trainer.datapipe):

            # Enter fine-tuning mode once the first two stages are done.
            if iter >= cfg.num_iters_step1 + cfg.num_iters_step2 and not finetune:
                finetune = True
                        
            # In fine-tuning mode, request a longer rollout at every multiple of
            # step_change_freq past the switch point (at most once per iteration).
            # NOTE(review): the rollout length is recomputed only at these exact
            # boundaries. If trainer.iter_init falls strictly between two of them,
            # the iterations up to the next boundary run on the trainer's initial
            # datapipe. TODO confirm whether that datapipe already matches iter_init.
            if (finetune and ((iter - (cfg.num_iters_step1 + cfg.num_iters_step2)) % cfg.step_change_freq == 0) and (iter != tagged_iter)):
                update_dataloader = True
                tagged_iter = iter

            # update the dataloader for finetuning
            if update_dataloader:
                # 0 completed intervals -> 2 steps, 1 -> 3, 2 -> 4, ...
                num_rollout_steps = (iter - (cfg.num_iters_step1 + cfg.num_iters_step2)) // cfg.step_change_freq + 2
                trainer.datapipe = DataPipe(num_steps=num_rollout_steps)
                update_dataloader = False
                rank_zero_logger.info(f"Switching to {num_rollout_steps}-step rollout!")
                # The current batch is dropped (no training step, `iter` not
                # incremented); this iteration is retried on the new datapipe.
                break

            # Prepare the input & output
            ...
            # training step
            loss = trainer.train(invar_cat, outvar)

            # validation
            if iter % cfg.val_freq == 0:
                ...
                error = trainer.validation.step(channels=list(np.arange(cfg.num_channels_val)), iter=iter)
                ...

            # checkpointing
            if iter % cfg.save_freq == 0:
                save_checkpoint(..., epoch=iter)
            
            iter += 1

            # terminate
            # Checked after the increment, so the last iteration that reaches the
            # training step is num_iters_step1 + num_iters_step2 + num_iters_step3 - 1.
            if (iter >= cfg.num_iters_step1 + cfg.num_iters_step2 + cfg.num_iters_step3):
                terminate_training = True
                break

        if terminate_training:
            break
```

In [5]:
# Show the config values that drive the training schedule: stage lengths first,
# then the validation / checkpoint / rollout-change frequencies.
print(
    *(f"cfg.{key}: {getattr(cfg, key)}" for key in ("num_iters_step1", "num_iters_step2", "num_iters_step3")),
    sep="\n",
)
print()
print(
    *(f"cfg.{key}: {getattr(cfg, key)}" for key in ("val_freq", "save_freq", "step_change_freq")),
    sep="\n",
)

cfg.num_iters_step1: 1000
cfg.num_iters_step2: 299000
cfg.num_iters_step3: 11000

cfg.val_freq: 5
cfg.save_freq: 1
cfg.step_change_freq: 1000


In [7]:
# Dry run of the training-loop control flow from `main`, with data loading,
# training, validation and checkpointing stubbed out. It shows at which
# iterations fine-tuning starts, when the rollout length grows, and when
# training stops.
#
# `iteration` plays the role of `iter` in the training script. It is renamed here
# so this cell does not shadow Python's built-in `iter()` for the rest of the
# kernel session.
ITER_INIT = 0  # stands in for `trainer.iter_init`; change to simulate another starting iteration
EPOCH_LEN = 10 * 1460  # simulated batches per pass over the datapipe (assumed: 10 x 1460 samples — TODO confirm vs. real DataPipe length)

finetune_start = cfg.num_iters_step1 + cfg.num_iters_step2
total_iters = finetune_start + cfg.num_iters_step3

iteration = ITER_INIT
tagged_iter = 1
num_rollout_steps = 1

terminate_training = False
finetune = False
update_dataloader = False

# training loop
while True:
    print("New epoch...")
    for _ in range(EPOCH_LEN):
        if iteration >= finetune_start and not finetune:
            print(iteration, "Setting finetune to True")
            finetune = True

        if (
            finetune
            and (iteration - finetune_start) % cfg.step_change_freq == 0
            and iteration != tagged_iter
        ):
            print(iteration, "Setting update_dataloader to True")
            update_dataloader = True
            tagged_iter = iteration

        # update the dataloader for finetuning
        if update_dataloader:
            num_rollout_steps = (iteration - finetune_start) // cfg.step_change_freq + 2
            print(iteration, "Updating dataloader with num_rollout_steps =", num_rollout_steps)
            # (real loop: trainer.datapipe = DataPipe(num_steps=num_rollout_steps))
            update_dataloader = False
            # End the epoch early, as the real loop does, without advancing
            # `iteration`; the next pass re-checks the same iteration.
            break

        # (stubbed) prepare input/output, trainer.train(...), validation every
        # cfg.val_freq iterations, checkpointing every cfg.save_freq iterations.

        iteration += 1

        # terminate
        if iteration >= total_iters:
            print(iteration, "Terminating training")
            terminate_training = True
            break

    if terminate_training:
        break

New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
New epoch...
300000 Setting finetune to True
300000 Setting update_dataloader to True
300000 Updating dataloader with num_rollout_steps = 2
New epoch...
301000 Setting update_dataloader to True
301000 Updating dataloader with num_rollout_steps = 3
New epoch...
302000 Setting update_dataloader to True
302000 Updating dataloader with num_rollout_steps = 4
New epoch...
303000 Setting update_dataloader to True
303000 Updating dataloader with num_rollout_steps = 5
New epoch...
304000 Setting update_dataloader to True
304000 Updating dataloader with num_rollout_steps = 6
New epoch...
305000 Setting update_dataloader to True
305000 Updating dataloader with num_rollout_steps = 7
New epoch...
306000 Setting update_dataloader to True
306000