In [1]:
from torch.optim.sgd import SGD


class MetaSGD(SGD):
    """SGD variant that can take a differentiable ("meta") step on ``net``.

    ``meta_step`` applies the same update rule as ``SGD.step`` (weight decay,
    momentum, dampening, Nesterov, maximize). It does not modify parameters in
    place. Instead, it replaces each parameter of ``net`` with the new tensor
    ``param - lr * update``. If ``grads`` carries an autograd graph (presumably
    produced with ``torch.autograd.grad(..., create_graph=True)`` — confirm at
    the call site), the replaced parameters stay differentiable w.r.t. it.

    Args:
        net: module whose parameters ``meta_step`` replaces.
        *args, **kwargs: forwarded unchanged to ``SGD.__init__``.
    """

    def __init__(self, net, *args, **kwargs):
        super(MetaSGD, self).__init__(*args, **kwargs)
        self.net = net

    def set_parameter(self, current_module, name, parameters):
        """Replace the parameter at dotted path ``name`` under ``current_module``.

        ``name`` is a path of the form yielded by ``named_parameters()``
        (e.g. ``"0.weight"``). The tensor is written straight into the owning
        module's ``_parameters`` dict, so a plain tensor (not an
        ``nn.Parameter``) can take the parameter's place. If a path component
        does not name a child module, nothing is changed.
        """
        if '.' in name:
            name_split = name.split('.')
            module_name = name_split[0]
            rest_name = '.'.join(name_split[1:])
            for children_name, children in current_module.named_children():
                if module_name == children_name:
                    self.set_parameter(children, rest_name, parameters)
                    break
        else:
            current_module._parameters[name] = parameters

    def meta_step(self, grads):
        """Take one differentiable SGD step on ``self.net`` using ``grads``.

        Hyper-parameters are read from the first param group only.
        NOTE(review): other param groups are ignored — confirm this is intended
        if the optimizer is built with several groups.

        Args:
            grads: iterable of gradients aligned one-to-one with
                ``self.net.named_parameters()``. An entry may be ``None``; that
                parameter is then left untouched, matching ``SGD.step``, which
                skips parameters without a gradient.

        Raises:
            ValueError: if the number of gradients differs from the number of
                parameters of ``self.net``.
        """
        group = self.param_groups[0]
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']
        # Older torch versions have no 'maximize' key; default keeps old behavior.
        maximize = group.get('maximize', False)
        lr = group['lr']

        # Materialise both sequences so a length mismatch is reported instead
        # of zip() silently dropping the tail.
        named_params = list(self.net.named_parameters())
        grads = list(grads)
        if len(grads) != len(named_params):
            raise ValueError(
                'meta_step expected {} gradients (one per parameter of net), got {}'
                .format(len(named_params), len(grads)))

        # Apply weight_decay, momentum and nesterov to each gradient, then replace the parameter.
        for (name, parameter), grad in zip(named_params, grads):
            if grad is None:
                continue
            # Cut the parameter's own history so the new parameter's graph
            # only runs through ``grad``.
            parameter.detach_()
            if maximize:
                grad = -grad
            if weight_decay != 0:
                grad_wd = grad.add(parameter, alpha=weight_decay)
            else:
                grad_wd = grad
            # ``self.state`` is a defaultdict: indexing it would insert an empty
            # entry keyed by this tensor. For replacement tensors from an
            # earlier meta_step those keys are unknown to ``state_dict()``, which
            # then raises KeyError, and the entries keep the tensors alive.
            # ``.get`` looks up without inserting.
            param_state = self.state.get(parameter) if momentum != 0 else None
            buffer = param_state.get('momentum_buffer') if param_state else None
            if buffer is not None:
                grad_b = buffer.mul(momentum).add(grad_wd, alpha=1 - dampening)
            else:
                grad_b = grad_wd
            if nesterov:
                grad_n = grad_wd.add(grad_b, alpha=momentum)
            else:
                grad_n = grad_b
            self.set_parameter(self.net, name, parameter.add(grad_n, alpha=-lr))


In [14]:
import torch


# Toy check: turn integer class labels into a one-hot mask with scatter_.
logits = torch.randn((2, 4))
target = torch.tensor([0, 2])
target_index = target.unsqueeze(1)  # (2,) -> (2, 1): scatter_ takes one index per row
one_hot = torch.zeros_like(logits).scatter_(1, target_index, 1)
blank = torch.zeros_like(logits)  # all-zero float tensor shaped like logits, for comparison
one_hot, blank, target_index

(tensor([[1., 0., 0., 0.],
         [0., 0., 1., 0.]]),
 tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.]]),
 tensor([[0],
         [2]]))

In [27]:
logits = torch.randn((2, 4))
target = torch.tensor([0, 2])
delta = torch.tensor([[2], [2]])

# One-hot mask of each row's target class, built once and reused below.
target_mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1)
y_t_target = logits * target_mask  # keeps only the target logit of each row
y_t_delta = delta * target_mask    # per-row delta placed at the target column (promoted to float)
# NOTE(review): because y_t_target is masked, every non-target logit becomes 0
# here. If the intent was to subtract a margin from the target logit only while
# keeping the other logits, this should be `logits - y_t_delta` — confirm intent.
y_t = y_t_target - y_t_delta
y_t_delta, y_t

tensor([[2],
        [2]]) tensor([[1., 0., 0., 0.],
        [0., 0., 1., 0.]])


(tensor([[2., 0., 0., 0.],
         [0., 0., 2., 0.]]),
 tensor([[-1.9150, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000, -2.5433, -0.0000]]))

In [77]:
import copy

import torch
import torch.optim as optim


def _state_index_by_param(optimizer, state_dict):
    """Map ``id(param)`` to the integer key ``state_dict`` uses for that param.

    ``Optimizer.state_dict()`` replaces parameter objects by integer indices in
    ``state_dict['param_groups'][k]['params']``, aligned with the live
    ``optimizer.param_groups[k]['params']``. Zipping the two recovers the exact
    key for every parameter.
    """
    return {
        id(param): index
        for group, saved_group in zip(optimizer.param_groups, state_dict['param_groups'])
        for param, index in zip(group['params'], saved_group['params'])
    }


def transfer_optimizer_state(src_optimizer, dst_optimizer):
    """Copy per-parameter state (e.g. SGD momentum buffers) from src to dst.

    The integer keys inside an optimizer state dict are positions that only
    mean something within that one optimizer. Parameters are therefore matched
    across the two optimizers by object identity. Each parameter managed by
    ``dst_optimizer`` that has state in ``src_optimizer`` receives a deep copy
    of that state. ``dst_optimizer``'s own param-group hyper-parameters are kept.

    Returns:
        The state dict that was loaded into ``dst_optimizer``.
    """
    src_state_dict = src_optimizer.state_dict()
    dst_state_dict = dst_optimizer.state_dict()
    src_index_of = _state_index_by_param(src_optimizer, src_state_dict)
    dst_index_of = _state_index_by_param(dst_optimizer, dst_state_dict)

    for param_id, dst_index in dst_index_of.items():
        src_index = src_index_of.get(param_id)
        if src_index is None or src_index not in src_state_dict['state']:
            continue  # parameter unknown to src, or src has no state for it yet
        # state_dict() hands out the live state tensors, and load_state_dict()
        # may keep them as-is. Deep-copy so the two optimizers never share
        # (and both mutate) one buffer.
        dst_state_dict['state'][dst_index] = copy.deepcopy(src_state_dict['state'][src_index])

    dst_optimizer.load_state_dict(dst_state_dict)
    return dst_state_dict


m1 = torch.nn.Linear(1, 1)
m2 = torch.nn.Linear(1, 2)
m3 = torch.nn.Linear(1, 3)
# Both lists hold the *same* module objects, so the two optimizers below share
# Parameter objects; the state transfer relies on that identity.
trainable_list1 = torch.nn.ModuleList([m1, m2, m3])
trainable_list2 = torch.nn.ModuleList([m1, m2, m3])

optimizer1 = optim.SGD(trainable_list1.parameters(),
                       lr=0.05,
                       momentum=0.9,
                       weight_decay=5e-4)
# optimizer2 only manages m2 (the [1:-1] slice of the list).
optimizer2 = optim.SGD(trainable_list2[1:-1].parameters(),
                       lr=0.05,
                       momentum=0.9,
                       weight_decay=5e-4)

state_dict_optimizer1 = optimizer1.state_dict()
# Carry optimizer1's state for the shared parameters (m2) over to optimizer2.
# No step has been taken yet, so the state is still empty here.
state_dict_optimizer2 = transfer_optimizer_state(optimizer1, optimizer2)
