* venv : torch2_0

In [2]:
import torch
import torch.nn.functional as F

In [3]:
class SampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=3,out_channels=16,kernel_size=3)
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(kernel_size=2)
        self.fc1 = torch.nn.Linear(in_features=16 * 8 * 8, out_features=256)
        self.fc2 = torch.nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size()[0], -1) # flatten
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = SampleModel()

### パラメーターの読み込み

#### model.parameters()利用

In [4]:
# パラメーターの読み込み(model.parameters()を利用)
for param in model.parameters():
    print(param)

Parameter containing:
tensor([[[[ 0.1151,  0.1480, -0.1310],
          [-0.1385,  0.1490,  0.0087],
          [ 0.1515,  0.1367,  0.0814]],

         [[-0.0109,  0.1216, -0.0365],
          [ 0.1378,  0.1595,  0.0226],
          [-0.0060, -0.1791,  0.1260]],

         [[-0.1365,  0.1229,  0.0680],
          [ 0.0974, -0.1142,  0.0961],
          [-0.1338, -0.1542,  0.1746]]],


        [[[ 0.1159,  0.0325, -0.0216],
          [-0.1664, -0.1423, -0.0899],
          [ 0.0043,  0.0394,  0.0672]],

         [[-0.1005,  0.0364, -0.0123],
          [-0.0456,  0.1285,  0.0041],
          [-0.1652,  0.0627,  0.0060]],

         [[-0.1881, -0.1097, -0.0504],
          [-0.0313,  0.0688,  0.1101],
          [-0.0601, -0.0881, -0.0526]]],


        [[[ 0.1182,  0.1060, -0.1399],
          [ 0.1677,  0.1281,  0.0727],
          [-0.1391,  0.0885,  0.1247]],

         [[ 0.1062, -0.1738,  0.0103],
          [-0.0919,  0.0554,  0.1916],
          [-0.1625,  0.0877, -0.0307]],

         [[-0.0750,  0

#### state_dict()利用

In [5]:
# パラメーターの読み込み(state_dict()利用)
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 0.1151,  0.1480, -0.1310],
                        [-0.1385,  0.1490,  0.0087],
                        [ 0.1515,  0.1367,  0.0814]],
              
                       [[-0.0109,  0.1216, -0.0365],
                        [ 0.1378,  0.1595,  0.0226],
                        [-0.0060, -0.1791,  0.1260]],
              
                       [[-0.1365,  0.1229,  0.0680],
                        [ 0.0974, -0.1142,  0.0961],
                        [-0.1338, -0.1542,  0.1746]]],
              
              
                      [[[ 0.1159,  0.0325, -0.0216],
                        [-0.1664, -0.1423, -0.0899],
                        [ 0.0043,  0.0394,  0.0672]],
              
                       [[-0.1005,  0.0364, -0.0123],
                        [-0.0456,  0.1285,  0.0041],
                        [-0.1652,  0.0627,  0.0060]],
              
                       [[-0.1881, -0.1097, -0.0504],
                        [-

In [6]:
model.state_dict()['fc2.bias']

tensor([-0.0265, -0.0351, -0.0555,  0.0268,  0.0219, -0.0179,  0.0335,  0.0577,
        -0.0559, -0.0622])

#### 変数を直接参照

In [7]:
# state_dict()に比べ、requires_gradの情報も見える
model.fc2.bias

Parameter containing:
tensor([-0.0265, -0.0351, -0.0555,  0.0268,  0.0219, -0.0179,  0.0335,  0.0577,
        -0.0559, -0.0622], requires_grad=True)

### パラメーターの書き換え

#### 変数に直接代入

In [20]:
# だめみたい(エラーとなる)
#  - スライスで無理やりできるようだが、データ一括処理にあまり対応していない
#  - エラーメッセージでも、以下のnn.Parameter()が紹介されていた
#model.fc2.bias = torch.randn(model.fc2.bias.shape)
#model.fc2.bias

#### nn.Parameter()を利用

In [21]:
model.fc2.bias = torch.nn.Parameter(torch.randn(model.fc2.bias.shape))
model.fc2.bias

Parameter containing:
tensor([ 0.7221,  1.4647,  0.7196, -0.3147, -0.2619,  0.7150,  0.6261, -0.6898,
        -0.1861, -1.2967], requires_grad=True)

In [23]:
# 補足：optimizerにパラメーターを渡した後は、.dataを用いて更新する
model.fc2.bias.data = torch.randn(model.fc2.bias.shape)
model.fc2.bias.data

tensor([-0.3397,  0.0599,  0.3497,  0.8991, -0.4590, -0.0663,  0.6867,  1.1968,
        -0.4934,  0.0698])

### パラメーターを固定

In [24]:
for param in model.parameters():
    param.requires_grad = False
    print(param)

Parameter containing:
tensor([[[[ 0.1151,  0.1480, -0.1310],
          [-0.1385,  0.1490,  0.0087],
          [ 0.1515,  0.1367,  0.0814]],

         [[-0.0109,  0.1216, -0.0365],
          [ 0.1378,  0.1595,  0.0226],
          [-0.0060, -0.1791,  0.1260]],

         [[-0.1365,  0.1229,  0.0680],
          [ 0.0974, -0.1142,  0.0961],
          [-0.1338, -0.1542,  0.1746]]],


        [[[ 0.1159,  0.0325, -0.0216],
          [-0.1664, -0.1423, -0.0899],
          [ 0.0043,  0.0394,  0.0672]],

         [[-0.1005,  0.0364, -0.0123],
          [-0.0456,  0.1285,  0.0041],
          [-0.1652,  0.0627,  0.0060]],

         [[-0.1881, -0.1097, -0.0504],
          [-0.0313,  0.0688,  0.1101],
          [-0.0601, -0.0881, -0.0526]]],


        [[[ 0.1182,  0.1060, -0.1399],
          [ 0.1677,  0.1281,  0.0727],
          [-0.1391,  0.0885,  0.1247]],

         [[ 0.1062, -0.1738,  0.0103],
          [-0.0919,  0.0554,  0.1916],
          [-0.1625,  0.0877, -0.0307]],

         [[-0.0750,  0