In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# Simple toy model
class TinyNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        in_features: Size of each input sample. Defaults to 3.
        hidden_features: Width of the hidden layer. Defaults to 4.
        out_features: Size of each output sample. Defaults to 1.

    The defaults reproduce the original fixed 3 -> 4 -> 1 architecture, so
    ``TinyNet()`` builds exactly the same model as before.
    """

    def __init__(self, in_features: int = 3, hidden_features: int = 4, out_features: int = 1):
        super().__init__()
        # Keep the attribute names lin1/lin2: they determine the state_dict keys
        # ("lin1.weight", "lin1.bias", ...).
        self.lin1 = nn.Linear(in_features, hidden_features)
        self.lin2 = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply lin1, a ReLU, then lin2 to ``x``.

        ``x`` must have ``in_features`` as its last dimension; the result has
        ``out_features`` as its last dimension.
        """
        return self.lin2(torch.relu(self.lin1(x)))

# Seed the global RNG before building the model so the randomly initialised
# weights (and any later torch.randn draws in this kernel) are identical on
# every Restart & Run All.
torch.manual_seed(0)
model = TinyNet()


In [2]:


# Optimizer: SGD with momentum, so every parameter gets a 'momentum_buffer'
# entry in optimizer.state once a step has been taken.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Runtime view: each param group is a dict of hyperparameters plus a "params"
# list holding the actual Parameter objects.
print("=== Param groups (runtime) ===")
for gi, group in enumerate(optimizer.param_groups):
    print(f"\nGroup {gi}:")
    print("keys:", list(group.keys()))
    print("lr:", group["lr"])
    print("num params:", len(group["params"]))
    for pi, p in enumerate(group["params"]):
        print(f"  param {pi} id={id(p)} shape={tuple(p.shape)} requires_grad={p.requires_grad}")

print("\n=== Raw state (before any step) ===")
print(optimizer.state)  # no step taken yet, so no per-parameter state exists

# Do one fake training step to populate state.
x = torch.randn(5, 3)
y = torch.randn(5, 1)
# Clear stale gradients first: the optimizer above is rebuilt every time this
# cell runs, but the model's .grad tensors persist in the kernel, so without
# this a re-run would accumulate gradients across executions.
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()

# Runtime state is keyed by the Parameter object itself.
print("\n=== Raw state (after one step) ===")
for p, s in optimizer.state.items():
    print(f"param id={id(p)} shape={tuple(p.shape)}")
    for k, v in s.items():
        print(f"  state[{k!r}] type={type(v)} shape={getattr(v, 'shape', None)}")

# Serialized view: state_dict() exposes "state" and "param_groups" entries;
# here "state" is keyed by the parameter IDs printed below rather than by
# tensor objects.
print("\n=== Serialized optimizer.state_dict() ===")
opt_sd = optimizer.state_dict()
for k in opt_sd.keys():
    print(k, "-> type:", type(opt_sd[k]))

print("\nparam_groups[0] in state_dict():")
print(opt_sd["param_groups"][0])

print("\nstate keys in state_dict():")
print("param ids:", list(opt_sd["state"].keys()))
for pid, st in opt_sd["state"].items():
    print(f"  pid={pid} -> keys={list(st.keys())}")


=== Param groups (runtime) ===

Group 0:
keys: ['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov', 'maximize', 'foreach', 'differentiable', 'fused']
lr: 0.1
num params: 4
  param 0 id=22936997753184 shape=(4, 3) requires_grad=True
  param 1 id=22936997747184 shape=(4,) requires_grad=True
  param 2 id=22936997756624 shape=(1, 4) requires_grad=True
  param 3 id=22936997756784 shape=(1,) requires_grad=True

=== Raw state (before any step) ===
defaultdict(<class 'dict'>, {})

=== Raw state (after one step) ===
param id=22936997753184 shape=(4, 3)
  state['momentum_buffer'] type=<class 'torch.Tensor'> shape=torch.Size([4, 3])
param id=22936997747184 shape=(4,)
  state['momentum_buffer'] type=<class 'torch.Tensor'> shape=torch.Size([4])
param id=22936997756624 shape=(1, 4)
  state['momentum_buffer'] type=<class 'torch.Tensor'> shape=torch.Size([1, 4])
param id=22936997756784 shape=(1,)
  state['momentum_buffer'] type=<class 'torch.Tensor'> shape=torch.Size([1])

=== Serialized optimizer.state_dict() ===
[remaining output truncated in this export]

In [3]:
# The runtime param_groups hold the Parameter objects themselves; this displays
# the first one in group 0 (the (4, 3) tensor listed as "param 0" in the
# previous cell's output).
optimizer.param_groups[0]["params"][0]

Parameter containing:
tensor([[-0.1969, -0.2550, -0.1314],
        [-0.0949,  0.1184,  0.2931],
        [-0.0740,  0.5392,  0.1214],
        [-0.4077,  0.0215, -0.3823]], requires_grad=True)

In [4]:
# optimizer.state is keyed by the Parameter object, so the same tensor serves as
# the lookup key; after the single step taken earlier its entry holds the
# 'momentum_buffer' tensor.
optimizer.state[optimizer.param_groups[0]["params"][0]]

{'momentum_buffer': tensor([[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0246, -0.0580,  0.0237],
         [ 0.0245, -0.0579,  0.0236]])}