# Optimizer checkpointing

In [1]:
import torch

import torch.nn as nn
import torch.optim as optim
import apex.optimizers

from modulus.distributed import DistributedManager
from modulus.models_new.graphcast.graph_cast_net_newest import GraphCastNetNew
from modulus.utils_new.caching import Cache
from modulus.utils.graphcast.loss import GraphCastLossFunction
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR, LambdaLR


DistributedManager.initialize()
dist = DistributedManager()
Cache.initialize(dir="/iopsstor/scratch/cscs/stefschu/DSM500/cache")

cp_path = "/iopsstor/scratch/cscs/stefschu/DSM500/github/analysis/_model/optimizer_checkpoint_test.checkpoint"

def instantiate(dtype):
    # Instantiate the model
    model = GraphCastNetNew(                    # Typical values
        sample_height=721,        # 721
        sample_width=1440,          # 1440
        sample_channels=21,    # 21
        
        include_static_data=True,         # True
        include_spatial_info=True,         # True
        include_temporal_info=True,         # True
        include_solar_radiation=True,         # True

        batch_size=1,      # 1
        mesh_level=6,               # 6
        activation_fn="silu",         # "silu",
        hidden_dim=128,               # 512
        hidden_layers=1,         # 1
        aggregation_op="sum",       # "sum"
        processor_layers=3,   # 16
    )
    model = model.to(dtype).to(dist.device)

    # Define a loss function
    criterion = GraphCastLossFunction()

    # Define an optimizer
    optimizer = apex.optimizers.FusedAdam(
        model.parameters(),
        lr=0.01,
        betas=(0.9, 0.95),
        adam_w_mode=True,
        weight_decay=0.1,
    )

    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(
                optimizer,
                start_factor=1e-3,
                end_factor=1.0,
                total_iters=1000,
            ),
            CosineAnnealingLR(
                optimizer,
                T_max=10000,
                eta_min=0.0,
            ),
            LambdaLR(
                optimizer,
                lr_lambda=lambda epoch: 1e-3,
            ),
        ],
        milestones=[
            1000,
            10000
        ],
    )

    input_data = torch.ones((1, 1, 31, 721, 1440)).to(dtype).to(dist.device)
    target = torch.ones((1, 1, 21, 721, 1440)).to(dtype).cuda(dist.device)

    return model, criterion, optimizer, scheduler, input_data, target, dtype

def train(model, criterion, optimizer, scheduler, input_data, target, dtype):
    # Forward pass
    output = model(input_data)
    loss = criterion(output, target)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

  warn(


In [3]:
model, criterion, optimizer, scheduler, input_data, target, dtype = instantiate(torch.bfloat16)

In [4]:
train(model, criterion, optimizer, scheduler, input_data, target, dtype)
print(scheduler.get_last_lr()[0])

1.9990000000000003e-05


In [5]:
train(model, criterion, optimizer, scheduler, input_data, target, dtype)
print(scheduler.get_last_lr()[0])

2.9980000000000004e-05


In [8]:
scheduler.state_dict()

{'_milestones': [1000, 10000],
 'last_epoch': 2,
 '_last_lr': [2.9980000000000004e-05],
 '_schedulers': [{'start_factor': 0.001,
   'end_factor': 1.0,
   'total_iters': 1000,
   'base_lrs': [0.01],
   'last_epoch': 2,
   'verbose': False,
   '_step_count': 3,
   '_get_lr_called_within_step': False,
   '_last_lr': [2.9980000000000004e-05]},
  {'T_max': 10000,
   'eta_min': 0.0,
   'base_lrs': [0.01],
   'last_epoch': -1,
   'verbose': False,
   '_step_count': 1,
   '_get_lr_called_within_step': False,
   '_last_lr': [1e-05]},
  {'base_lrs': [0.01],
   'last_epoch': -1,
   'verbose': False,
   '_step_count': 1,
   '_get_lr_called_within_step': False,
   '_last_lr': [1e-05],
   'lr_lambdas': [None]}]}

In [11]:
cp = torch.load("/iopsstor/scratch/cscs/stefschu/DSM500/github/analysis/checkpoints/model.iter001000.pth", map_location=dist.device, weights_only=True)

In [12]:
cp["scheduler"]

{'_milestones': [1000, 310000],
 'last_epoch': 1000,
 '_last_lr': [0.001],
 '_schedulers': [{'start_factor': 0.001,
   'end_factor': 1.0,
   'total_iters': 1000,
   'base_lrs': [0.001],
   'last_epoch': 999,
   'verbose': False,
   '_step_count': 1000,
   '_get_lr_called_within_step': False,
   '_last_lr': [0.0009990009999999955]},
  {'T_max': 299000,
   'eta_min': 0.0,
   'base_lrs': [0.001],
   'last_epoch': 0,
   'verbose': False,
   '_step_count': 2,
   '_get_lr_called_within_step': False,
   '_last_lr': [0.001]},
  {'base_lrs': [0.001],
   'last_epoch': -1,
   'verbose': False,
   '_step_count': 1,
   '_get_lr_called_within_step': False,
   '_last_lr': [0.0001],
   'lr_lambdas': [None]}]}

In [13]:
scheduler.load_state_dict(cp["scheduler"])
print(scheduler.get_last_lr()[0])

0.001


In [15]:
scheduler.state_dict()

{'_milestones': [1000, 310000],
 'last_epoch': 1000,
 '_last_lr': [0.001],
 '_schedulers': [{'start_factor': 0.001,
   'end_factor': 1.0,
   'total_iters': 1000,
   'base_lrs': [0.001],
   'last_epoch': 999,
   'verbose': False,
   '_step_count': 1000,
   '_get_lr_called_within_step': False,
   '_last_lr': [0.0009990009999999955]},
  {'T_max': 299000,
   'eta_min': 0.0,
   'base_lrs': [0.001],
   'last_epoch': 0,
   'verbose': False,
   '_step_count': 2,
   '_get_lr_called_within_step': False,
   '_last_lr': [0.001]},
  {'base_lrs': [0.001],
   'last_epoch': -1,
   'verbose': False,
   '_step_count': 1,
   '_get_lr_called_within_step': False,
   '_last_lr': [0.0001],
   'lr_lambdas': [None]}]}

In [None]:
##########################################
print("-" * 80)
print("> First training")
model, criterion, optimizer, dtype = instantiate(torch.bfloat16)
train(model, criterion, optimizer, dtype)

p = next(model.parameters())
print(f"Model parameter:              [{p.min():28.25f}, {p.max():28.25f}] | {p.dtype} | {p.device} | {p.shape}")
p = optimizer.param_groups[0]["params"][0]
print(f"Optimizer parameter:          [{p.min():28.25f}, {p.max():28.25f}] | {p.dtype} | {p.device} | {p.shape}")
state = next(iter(optimizer.state.values()))
_iter = iter(state.items())
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")

##########################################
print("-" * 80)
print("> Review state dict")

state = next(iter(optimizer.state_dict()["state"].values()))
_iter = iter(state.items())
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")

##########################################
print("-" * 80)
print("> save + load, then review state dict again")

torch.save({"optimizer": optimizer.state_dict()}, cp_path)
cp = torch.load(cp_path, map_location="cuda:0", weights_only=True)["optimizer"]
state = next(iter(cp["state"].values()))
_iter = iter(state.items())
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")

##########################################
print("-" * 80)
print("> Initialize model")
model, criterion, optimizer = instantiate()

p = next(model.parameters())
print(f"Model parameter:              [{p.min():28.25f}, {p.max():28.25f}] | {p.dtype} | {p.device} | {p.shape}")
p = optimizer.param_groups[0]["params"][0]
print(f"Optimizer parameter:          [{p.min():28.25f}, {p.max():28.25f}] | {p.dtype} | {p.device} | {p.shape}")

##########################################
print("-" * 80)
print("> Reload optimizer (WITHOUT FIX)")

optimizer.load_state_dict(cp)

state = next(iter(optimizer.state.values()))
_iter = iter(state.items())
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")


##########################################
print("-" * 80)
print("> Reload optimizer (WITH FIX)")



optimizer.load_state_dict(cp)

state = next(iter(optimizer.state.values()))
_iter = iter(state.items())
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")


--------------------------------------------------------------------------------
> First training


ValueError: too many values to unpack (expected 3)

In [None]:
_cp = torch.load(cp_path, map_location="cuda:0", weights_only=True)["optimizer"]

print("In the checkpoint:")
for state in _cp["state"].values():
    for k, v in state.items():
        print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")
    break

model, criterion, optimizer, dtype = instantiate(dtype=torch.bfloat16)

print()
print("After loading from checkpoint:")
optimizer.load_state_dict(_cp)

# Convert state tensors to desired dtype
for op_state, cp_state in zip(optimizer.state.values(), _cp["state"].values()):
    for (op_k, op_v), (cp_k, cp_v) in zip(op_state.items(), cp_state.items()):
        print(f"Optimizer state:  ({op_k:10}): [{op_v.min():28.25f}, {op_v.max():28.25f}] | {op_v.dtype}  | {op_v.device} | {op_v.shape}")
        print(f"Checkpoint state: ({cp_k:10}): [{cp_v.min():28.25f}, {cp_v.max():28.25f}] | {cp_v.dtype}  | {cp_v.device} | {cp_v.shape}")
        op_state[op_k] = op_state[op_k].to(torch.float32)
        op_state[op_k] = cp_v
    break

print()
print("After fix:")
state = next(iter(optimizer.state.values()))
_iter = iter(state.items())
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")
k, v = next(_iter)
print(f"Optimizer state ({k:10}): [{v.min():28.25f}, {v.max():28.25f}] | {v.dtype}  | {v.device} | {v.shape}")

In the checkpoint:
Optimizer state (exp_avg   ): [-0.0072753923013806343078613,  0.0070312516763806343078613] | torch.float32  | cuda:0 | torch.Size([128, 31])
Optimizer state (exp_avg_sq): [ 0.0000000017291771348126872,  0.0002646566135808825492859] | torch.float32  | cuda:0 | torch.Size([128, 31])

After loading from checkpoint:
Optimizer state:  (exp_avg   ): [-0.0072631835937500000000000,  0.0070190429687500000000000] | torch.bfloat16  | cuda:0 | torch.Size([128, 31])
Checkpoint state: (exp_avg   ): [-0.0072753923013806343078613,  0.0070312516763806343078613] | torch.float32  | cuda:0 | torch.Size([128, 31])
Optimizer state:  (exp_avg_sq): [ 0.0000000017316779121756554,  0.0002651214599609375000000] | torch.bfloat16  | cuda:0 | torch.Size([128, 31])
Checkpoint state: (exp_avg_sq): [ 0.0000000017291771348126872,  0.0002646566135808825492859] | torch.float32  | cuda:0 | torch.Size([128, 31])

After fix:
Optimizer state (exp_avg   ): [-0.0072753923013806343078613,  0.00703125167638063