# Model IO

本代码实现：

1. 创建 config 数据类，并定义 json IO 方法
2. 创建模型 和 自定义IO的类方法，特别的会读写“优化器状态”（Optimizer） 用于端点训练状态恢复
3. 基于断点状态恢复优化器

In [1]:
import torch

## Config IO

观察 `optimizer_class` 可以写入类名,  `Transformers` 库中的类似 `AutoConfig` 可以通过配置文件字段初始化具体的 config 类

In [2]:
from dataclasses import dataclass, asdict, field
import json

@dataclass
class ModelConfig:
    input_size: int = -1
    hidden_size: int = -1
    output_size: int = -1
    optimizer_class: type = torch.optim.Adam
    learning_rate: float = 0.001
    dropout: float = 0.2
    custom_objects: dict = field(default_factory=dict)
    
    def save_json(self, path):
        """将配置保存为JSON（可读性更好）"""
        # 转换不可JSON序列化的对象
        config_dict = asdict(self)
        config_dict['optimizer_class'] = str(self.optimizer_class)
        with open(path, 'w') as f:
            json.dump(config_dict, f, indent=2)
    
    @classmethod
    def from_json(cls, path):
        """从JSON加载配置（注意：无法恢复类对象）"""
        with open(path, 'r') as f:
            data = json.load(f)
        return cls(**data)

In [3]:
config = ModelConfig(
    input_size=784,
    hidden_size=128,
    output_size=10,
    learning_rate=0.01,
)
print(config)

ModelConfig(input_size=784, hidden_size=128, output_size=10, optimizer_class=<class 'torch.optim.adam.Adam'>, learning_rate=0.01, dropout=0.2, custom_objects={})


In [4]:
config.save_json('./output/test_config.json')
!cat ./output/test_config.json

{
  "input_size": 784,
  "hidden_size": 128,
  "output_size": 10,
  "optimizer_class": "<class 'torch.optim.adam.Adam'>",
  "learning_rate": 0.01,
  "dropout": 0.2,
  "custom_objects": {}
}

In [5]:
new_config =  ModelConfig()
print(new_config)

ModelConfig(input_size=-1, hidden_size=-1, output_size=-1, optimizer_class=<class 'torch.optim.adam.Adam'>, learning_rate=0.001, dropout=0.2, custom_objects={})


In [6]:
new_config = new_config.from_json('./output/test_config.json')
print(new_config)

ModelConfig(input_size=784, hidden_size=128, output_size=10, optimizer_class="<class 'torch.optim.adam.Adam'>", learning_rate=0.01, dropout=0.2, custom_objects={})


## Model

In [7]:
class ToyModel(torch.nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        
        self.fc1 = torch.nn.Linear(config.input_size, config.hidden_size)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.fc2 = torch.nn.Linear(config.hidden_size, config.output_size)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
    
    def save(self, path, optimizer=None):
        """保存模型和配置"""
        save_data = {
            'model_state_dict': self.state_dict(),
            'config': self.config,
            'config_dict': asdict(self.config)  # 用于JSON序列化
        }
        
        if optimizer:
            save_data['optimizer_state_dict'] = optimizer.state_dict()
        
        torch.save(save_data, path)
        self.config.save_json(path.replace('.pth', '_config.json'))
    
    @classmethod
    def load(cls, path, device='cpu'):
        """加载模型和配置, 可先加载至 CPU 再搬运到 GPU 设备"""
        data = torch.load(path, map_location=device, weights_only = False)
        
        # 从保存的数据中重建配置
        config = data['config']
        
        model = cls(config)
        model.load_state_dict(data['model_state_dict'])
        model.to(device)
        
        optimizer_state = data.get('optimizer_state_dict')
        return model, config, optimizer_state

## 通过 config 初始化模型

In [8]:
model = ToyModel(config)
print(model)
optimizer = config.optimizer_class(model.parameters(), lr=config.learning_rate)

ToyModel(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


## 保存

In [9]:
model.save("./output/model_io.pth", optimizer=optimizer)

In [10]:
!cat ./output/model_io_config.json

{
  "input_size": 784,
  "hidden_size": 128,
  "output_size": 10,
  "optimizer_class": "<class 'torch.optim.adam.Adam'>",
  "learning_rate": 0.01,
  "dropout": 0.2,
  "custom_objects": {}
}

## 加载

In [11]:
loaded_model, loaded_config, optimizer_state = ToyModel.load("./output/model_io.pth", device='cpu')

## 优化器恢复

In [12]:
# 重建优化器
print(loaded_config.optimizer_class)
loaded_optimizer = loaded_config.optimizer_class(
    loaded_model.parameters(), 
    lr=loaded_config.learning_rate
)
if optimizer_state:
    loaded_optimizer.load_state_dict(optimizer_state)

<class 'torch.optim.adam.Adam'>


训练断点恢复, 还需考虑训练的 epochs、step（最后一次梯度更新）

# Reference

代码参考 DeepSeek-R1 所生成内容