In [1]:
import torch
import torch.nn as nn

In [2]:
# Allocate a 3x5 tensor without initializing its memory; the values
# printed below are whatever happened to be in the buffer.
w = torch.empty(size=(3, 5))
print("before param init: \n{}".format(w))

before param init: 
tensor([[-2.7450e+03,  4.5856e-41, -2.4681e+23, -4.3239e-13, -2.7434e+03],
        [ 4.5856e-41, -2.7450e+03,  4.5856e-41, -5.1497e+35,  3.2302e-13],
        [-2.7434e+03,  4.5856e-41, -2.7450e+03,  4.5856e-41, -3.6253e-36]])


# 1. Tensor init

# 1.1 constant init

In [3]:
# Overwrite every entry of w in place with the constant 0.3.
nn.init.constant_(w, val=0.3)
print("after param init: \n{}".format(w))

after param init: 
tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
        [0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
        [0.3000, 0.3000, 0.3000, 0.3000, 0.3000]])


# 1.2 ones init

In [4]:
# Fill w in place with ones; the init call returns the same tensor,
# so it can be formatted directly.
print("after param init: \n{}".format(nn.init.ones_(w)))

after param init: 
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])


# 1.3 zeros init

In [5]:
# Fill w in place with zeros; the init call returns the same tensor,
# so it can be formatted directly.
print("after param init: \n{}".format(nn.init.zeros_(w)))

after param init: 
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])


# 1.4 uniform init

In [6]:
# Resample every entry of w in place from a uniform distribution with
# lower bound 0.0 and upper bound 0.2. No seed is set, so the values
# differ on every run.
print("after param init: \n{}".format(nn.init.uniform_(w, a=0.0, b=0.2)))

after param init: 
tensor([[0.0897, 0.1913, 0.0653, 0.0816, 0.0497],
        [0.0549, 0.1356, 0.1205, 0.0653, 0.0551],
        [0.1471, 0.0642, 0.1996, 0.1910, 0.1545]])


# 1.5 normal init

In [7]:
# Resample every entry of w in place from a normal distribution with
# mean 0.0 and standard deviation 1. No seed is set, so the values
# differ on every run.
print("after param init: \n{}".format(nn.init.normal_(w, mean=0.0, std=1)))

after param init: 
tensor([[-0.0056, -0.7587,  0.1510, -0.7427,  0.1151],
        [ 0.1302,  0.5895, -0.6391,  0.0750, -1.7525],
        [ 1.6899,  0.4278,  0.0482, -1.4283, -1.5449]])


# 2. Module init

In [8]:
@torch.no_grad()
def init_weights(m):
    """Set an ``nn.Linear`` layer's weight to ones and its bias to zeros.

    Written to be passed to ``nn.Module.apply``. Every module handed to
    this function is printed; only modules whose type is exactly
    ``nn.Linear`` are modified, in place.

    Args:
        m: The module to inspect and, if it is an ``nn.Linear``, initialize.
    """
    print(m)
    # Exact type match on purpose: subclasses of nn.Linear are left untouched.
    if type(m) is nn.Linear:
        nn.init.ones_(m.weight)
        # A layer built with bias=False has m.bias set to None, and
        # nn.init.zeros_(None) would raise, so skip it in that case.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [9]:
# A single fully connected layer mapping 2 input features to 3 outputs.
net = nn.Linear(in_features=2, out_features=3)

In [10]:
# apply() runs init_weights on the module and returns the module itself,
# so the result can be printed directly.
print(net.apply(init_weights))

Linear(in_features=2, out_features=3, bias=True)
Linear(in_features=2, out_features=3, bias=True)


In [14]:
# Materialize the (name, parameter) pairs so the cell displays them.
[*net.named_parameters()]

[('weight',
  Parameter containing:
  tensor([[1., 1.],
          [1., 1.],
          [1., 1.]], requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([0., 0., 0.], requires_grad=True))]