<a href="https://colab.research.google.com/github/zzqnot996/learn_deeplearning/blob/main/register_buffer_vs_register_parameter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
import torch
import torch.nn as nn

class MyModule(nn.Module):
  def __init__(self):
    super(MyModule,self).__init__()
    self.conv1 = nn.Conv2d(in_channels = 3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
    self.conv2 = nn.Conv2d(in_channels = 6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)
    # Plain tensor attributes: nn.Module does not register these as parameters or buffers
    self.weight = torch.ones(10,10)
    self.bias = torch.zeros(10)

  def forward(self,x):
    x=self.conv1(x)
    x=self.conv2(x)
    x = x*self.weight + self.bias
    return x


net = MyModule()

for name, param in net.named_parameters():

  print(name,param.shape)

print("\n","#"*40,"\n")

for key, val in net.state_dict().items():
  print(key,val.shape)

conv1.weight torch.Size([6, 3, 3, 3])
conv2.weight torch.Size([9, 6, 3, 3])

 ######################################## 

conv1.weight torch.Size([6, 3, 3, 3])
conv2.weight torch.Size([9, 6, 3, 3])
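
The output above lists only the convolution weights: the plain tensors `self.weight` and `self.bias` are absent from both `named_parameters()` and `state_dict()`. The following minimal sketch isolates that behavior. The class name `PlainAttr` and attribute `scale` are hypothetical, and the default dtype is assumed to be `float32`.

```python
import torch
import torch.nn as nn


class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute: stored on the object, but not registered
        self.scale = torch.ones(3)


m = PlainAttr()
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

# Module-wide conversions only touch registered parameters and buffers,
# so the plain attribute keeps its original dtype.
m.double()
scale_dtype = m.scale.dtype

print(param_names, buffer_names, state_keys, scale_dtype)
```

Because the tensor is invisible to the module machinery, it would also be skipped by `.to(device)` and would not be saved with the model.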


# 1. `register_parameter()`

`register_parameter()` is a method of the `torch.nn.Module` class.

# 1.1 Purpose

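- It registers `self.weight` and `self.bias` as learnable parameters: they are stored among the module's parameters, returned by `parameters()`, and updated by the optimizer.

- `self.weight` and `self.bias` are included in `state_dict`, so they are written out when the model or its weights are saved to a file.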
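
As a minimal sketch of the first point, the snippet below registers one parameter, runs a single SGD step, and checks that the value changed. The module name `Scale`, the learning rate, and the input are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn


class Scale(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(3)))

    def forward(self, x):
        return x * self.weight


m = Scale()
opt = torch.optim.SGD(m.parameters(), lr=0.1)

# loss = sum(x * weight) with x = 1, so d(loss)/d(weight) = 1 per element
loss = m(torch.ones(3)).sum()
loss.backward()
opt.step()  # weight <- 1 - 0.1 * 1 = 0.9

updated_weight = m.weight.detach().clone()
in_state_dict = 'weight' in m.state_dict()
print(updated_weight, in_state_dict)
```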
    

# 1.2 Usage

  `register_parameter(name, param)`

- `name`: the name of the parameter

- `param`: the parameter tensor. It must be a `torch.nn.Parameter` object or `None`; anything else raises the following error:
    

> TypeError: cannot assign 'torch.FloatTensor' object to parameter 'xx' (torch.nn.Parameter or None required)
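
A short sketch of this rule, using a bare `nn.Module()` instance and hypothetical names (`w`, `b`, `unused`): a plain tensor is rejected, an `nn.Parameter` is accepted, and `None` reserves the name without creating a parameter. It also shows that assigning an `nn.Parameter` as an attribute registers it in the same way.

```python
import torch
import torch.nn as nn

m = nn.Module()

# A plain tensor is rejected with a TypeError.
try:
    m.register_parameter('w', torch.ones(3))
    raised = False
except TypeError as err:
    raised = True
    print(err)

# Wrapping the tensor in nn.Parameter is accepted.
m.register_parameter('w', nn.Parameter(torch.ones(3)))

# Attribute assignment of an nn.Parameter registers it as well.
m.b = nn.Parameter(torch.zeros(3))

# None is allowed; the entry is skipped by named_parameters() and state_dict().
m.register_parameter('unused', None)

names = [name for name, _ in m.named_parameters()]
state_keys = list(m.state_dict().keys())
print(names, state_keys)
```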

In [19]:
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        # register_parameter: the value must be a torch.nn.Parameter (or None)
        self.register_parameter('weight', torch.nn.Parameter(torch.ones(10, 10)))
        self.register_parameter('bias', torch.nn.Parameter(torch.zeros(10)))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

for name, param in net.named_parameters():
    print(name, param.shape)

print('\n', '*'*40, '\n')

for key, val in net.state_dict().items():
    print(key, val.shape)


weight torch.Size([10, 10])
bias torch.Size([10])
conv1.weight torch.Size([6, 3, 3, 3])
conv2.weight torch.Size([9, 6, 3, 3])

 **************************************** 

weight torch.Size([10, 10])
bias torch.Size([10])
conv1.weight torch.Size([6, 3, 3, 3])
conv2.weight torch.Size([9, 6, 3, 3])


# 2. `register_buffer()`

`register_buffer()` is a method of the `torch.nn.Module` class.

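# 2.1 Purpose

- It registers `self.weight` and `self.bias` as non-learnable tensors: they are not part of the module's parameters, are not returned by `parameters()`, and are not updated by the optimizer.

- `self.weight` and `self.bias` are still included in `state_dict` by default, so they are written out when the model or its weights are saved to a file.

In other words, it registers a buffer on the module instance. Data stored in a buffer behaves much like a parameter, but it is not one:

- Parameters: handed to the optimizer, and updated when `requires_grad=True` (with `requires_grad=False` they receive no gradient and stay unchanged)

- Buffer data: never updated by the optimizer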
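
The contrast can be sketched with one parameter and one buffer in the same module. The names `Shift` and `offset` are hypothetical. After one optimizer step the parameter has moved and the buffer has not, yet the buffer still follows module-wide conversions such as `.double()`.

```python
import torch
import torch.nn as nn


class Shift(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))        # learnable
        self.register_buffer('offset', torch.zeros(3))   # state, not learnable

    def forward(self, x):
        return x * self.weight + self.offset


m = Shift()
opt = torch.optim.SGD(m.parameters(), lr=0.1)
m(torch.ones(3)).sum().backward()
opt.step()

weight_changed = not torch.equal(m.weight.detach(), torch.ones(3))
offset_unchanged = torch.equal(m.offset, torch.zeros(3))
offset_requires_grad = m.offset.requires_grad

# Buffers are converted together with the parameters.
m.double()
offset_dtype = m.offset.dtype

print(weight_changed, offset_unchanged, offset_requires_grad, offset_dtype)
```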
    

# 2.2 Usage

  `register_buffer(name, tensor)`

- `name`: the name of the buffer

- `tensor`: the tensor to register
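
Beyond the two arguments listed above, recent PyTorch versions also accept an optional `persistent` keyword (default `True`). The sketch below, with hypothetical buffer names, suggests how `persistent=False` keeps a buffer on the module while leaving it out of `state_dict`.

```python
import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer('running_mean', torch.zeros(4))               # persistent by default
m.register_buffer('scratch', torch.zeros(4), persistent=False)  # not saved

buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

print(buffer_names)  # both buffers are registered on the module
print(state_keys)    # only the persistent one is saved
```

A non-persistent buffer is a reasonable fit for cached values that can be recomputed and should not end up in checkpoints.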

In [20]:
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)

        self.register_buffer('weight', torch.ones(10, 10))
        self.register_buffer('bias', torch.zeros(10))


    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x * self.weight + self.bias
        return x


net = MyModule()

for name, param in net.named_parameters():
    print(name, param.shape)

print('\n', '*'*40, '\n')

for key, val in net.state_dict().items():
    print(key, val.shape)

conv1.weight torch.Size([6, 3, 3, 3])
conv2.weight torch.Size([9, 6, 3, 3])

 **************************************** 

weight torch.Size([10, 10])
bias torch.Size([10])
conv1.weight torch.Size([6, 3, 3, 3])
conv2.weight torch.Size([9, 6, 3, 3])
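
Since buffers appear in `state_dict()`, they survive a save and load cycle. The following is a minimal sketch of that round trip, assuming an in-memory `io.BytesIO` in place of a real file and a hypothetical module `WithBuffer`.

```python
import io

import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('bias', torch.zeros(10))


src = WithBuffer()
src.bias.fill_(3.0)  # change the buffer so the round trip is observable

# Save the state_dict to an in-memory file, then load it into a fresh module.
f = io.BytesIO()
torch.save(src.state_dict(), f)
f.seek(0)

dst = WithBuffer()
dst.load_state_dict(torch.load(f))

restored = dst.bias.clone()
print(restored)
```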
