# 레이어 가중치 임의로 초기화 하기

In [1]:
import torch
import torch.nn as nn

## weight 초기화

In [3]:
n_class = 28
n_dim = 512
embedding = nn.Embedding(n_class, n_dim)
embedding

Embedding(28, 512)

In [10]:
init_range = 0.1
# 범위 내의 uniform distribution으로 가중치 초기화
embedding.weight.data.uniform_(-init_range, init_range)

tensor([[ 0.0274,  0.0130, -0.0493,  ...,  0.0409, -0.0268,  0.0271],
        [-0.0086, -0.0320,  0.0854,  ...,  0.0927, -0.0167, -0.0975],
        [-0.0279,  0.0317,  0.0057,  ...,  0.0772, -0.0510, -0.0687],
        ...,
        [-0.0599,  0.0246, -0.0251,  ..., -0.0921,  0.0598,  0.0282],
        [ 0.0171, -0.0007,  0.0151,  ...,  0.0915, -0.0467,  0.0437],
        [ 0.0012, -0.0283, -0.0981,  ...,  0.0369, -0.0250, -0.0390]])

In [11]:
embedding.weight

Parameter containing:
tensor([[ 0.0274,  0.0130, -0.0493,  ...,  0.0409, -0.0268,  0.0271],
        [-0.0086, -0.0320,  0.0854,  ...,  0.0927, -0.0167, -0.0975],
        [-0.0279,  0.0317,  0.0057,  ...,  0.0772, -0.0510, -0.0687],
        ...,
        [-0.0599,  0.0246, -0.0251,  ..., -0.0921,  0.0598,  0.0282],
        [ 0.0171, -0.0007,  0.0151,  ...,  0.0915, -0.0467,  0.0437],
        [ 0.0012, -0.0283, -0.0981,  ...,  0.0369, -0.0250, -0.0390]],
       requires_grad=True)

## Bias 초기화

In [15]:
input_dim = 7
output_dim = 12

fc = nn.Linear(input_dim, output_dim)
fc

Linear(in_features=7, out_features=12, bias=True)

In [16]:
fc.weight

Parameter containing:
tensor([[ 0.1163,  0.3046, -0.1006,  0.2548,  0.1786,  0.1081,  0.3656],
        [ 0.3279,  0.2114, -0.0131, -0.3168, -0.1481, -0.1592,  0.0870],
        [ 0.0794,  0.1264,  0.1920,  0.0260, -0.3559, -0.2213, -0.1615],
        [-0.1733, -0.1112, -0.0250,  0.3229,  0.1303,  0.0174, -0.0056],
        [-0.2000, -0.2292, -0.3580,  0.1053, -0.2243,  0.0454, -0.1451],
        [-0.2339,  0.0474, -0.0082, -0.1234,  0.0384, -0.0079,  0.0814],
        [ 0.2550,  0.3175,  0.3719,  0.1353,  0.0593,  0.1334,  0.0893],
        [ 0.2902,  0.0825,  0.3769,  0.2871,  0.3054,  0.0654,  0.1672],
        [-0.0549,  0.2630, -0.0087, -0.0521, -0.2486, -0.0694,  0.0192],
        [-0.1748, -0.1113, -0.0085,  0.2610,  0.1529, -0.0268, -0.2390],
        [ 0.2100, -0.3631,  0.2365, -0.1804,  0.2873,  0.3607,  0.3773],
        [ 0.1979, -0.1182, -0.1451,  0.2423,  0.0529, -0.1993,  0.2752]],
       requires_grad=True)

In [17]:
fc.bias # n_bias == output_dim

Parameter containing:
tensor([ 0.0366, -0.1610,  0.0843, -0.0288,  0.3416,  0.1482, -0.0821, -0.0371,
        -0.3617,  0.3198, -0.0633, -0.0912], requires_grad=True)

In [19]:
# bias 0으로 초기화
fc.bias.data.zero_()

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [20]:
fc.bias.data

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

# kaiming_normal_
(https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_)
Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015) 에서 제안한 방식으로 값을 채운다.

In [27]:
conv_input_dim = 2
conv_output_dim = 4
kernel_size = 1
encoder_projection = nn.Conv2d(conv_input_dim, conv_output_dim, kernel_size=kernel_size)

print(encoder_projection.weight.data.shape)
print(encoder_projection.bias.data.shape) # conv_output_dim

torch.Size([4, 2, 1, 1])
torch.Size([4])


In [30]:
nn.init.kaiming_normal_(encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu")
# a: leaky_relu에서 0이하 부분의 음의 기울기, mode = "fan_out" or "fan_in", nonlinearity = "relu" or "leaky_relu"

tensor([[[[ 0.5511]],

         [[ 0.5276]]],


        [[[-0.1093]],

         [[ 0.2632]]],


        [[[-0.1005]],

         [[-0.5003]]],


        [[[ 0.0409]],

         [[-0.5255]]]])

# weight의 fan_in, fan_out 반환하여 편향(Bias) 초기화 하기

In [38]:
import math

In [36]:
_fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(encoder_projection.weight.data)

In [40]:
bound = 1 / math.sqrt(fan_out)
nn.init.normal_(encoder_projection.bias, -bound, bound)

Parameter containing:
tensor([-0.1903, -0.5756, -0.4199,  0.1565], requires_grad=True)

# argparse

In [8]:
import argparse
import importlib
import pytorch_lightning as pl

In [None]:
parser = argparse.