# Good Initialization in Weight is much more important

- Not all 0's

# 2006 RBM
- Restricted : no connections within a layer
- Pre-training
- Fine-tuning

# Xavier2010 / He initialization 2015
- Xaiver Normal Initialization
- Xaiver Uniform Initialization
- He Normal Initialization
- He Uniform Initialization

In [1]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import random

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
learning_rate = 1e-3
training_epochs = 20
batch_size = 100

In [4]:
mnist_train = dsets.MNIST(root='MNIST_data', train=True, transform=transforms.ToTensor())
mnist_test = dsets.MNIST(root='MNIST_data', train=False, transform=transforms.ToTensor())

data_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

In [5]:
linear1 = torch.nn.Linear(784, 256, bias=True)
linear2 = torch.nn.Linear(256, 256, bias=True)
linear3 = torch.nn.Linear(256, 10, bias=True)
relu = torch.nn.ReLU()

In [6]:
torch.nn.init.normal_(linear1.weight)
torch.nn.init.normal_(linear2.weight)
torch.nn.init.normal_(linear3.weight)

Parameter containing:
tensor([[ 0.3371,  1.5778, -0.9435,  ..., -1.3917, -0.9079,  1.6330],
        [ 0.0306,  0.3587, -0.6631,  ...,  0.9687, -0.2138,  1.2170],
        [-1.7550, -0.1530,  0.2656,  ...,  1.6182,  0.3594,  1.4337],
        ...,
        [ 0.3382, -0.9864, -0.7690,  ...,  1.0798,  0.8634,  0.4950],
        [ 0.2922, -0.1261, -1.1612,  ..., -0.5381, -0.3303, -0.9882],
        [ 1.8342,  1.8050,  1.0897,  ...,  0.5521,  0.2920, -0.6831]],
       requires_grad=True)

In [7]:
model = torch.nn.Sequential(linear1, relu, linear2, relu, linear3).to(device)

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
len(data_loader)

600

In [9]:
total_batch = len(data_loader)
for epoch in range(training_epochs):
    avg_cost = 0
    for X,Y in data_loader:
        X = X.view(-1, 28*28).to(device)
        Y = Y.to(device)
        
        optimizer.zero_grad()
        hypothesis = model(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()
        
        avg_cost += cost/total_batch
    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))

Epoch: 0001 cost = 153.991210938
Epoch: 0002 cost = 35.972740173
Epoch: 0003 cost = 21.992288589
Epoch: 0004 cost = 15.176042557
Epoch: 0005 cost = 10.973302841
Epoch: 0006 cost = 8.108571053
Epoch: 0007 cost = 6.156758785
Epoch: 0008 cost = 4.559571743
Epoch: 0009 cost = 3.521098137
Epoch: 0010 cost = 2.690568447
Epoch: 0011 cost = 1.960561633
Epoch: 0012 cost = 1.510765195
Epoch: 0013 cost = 1.135576963
Epoch: 0014 cost = 0.928186357
Epoch: 0015 cost = 0.802532434
Epoch: 0016 cost = 0.552375376
Epoch: 0017 cost = 0.552641392
Epoch: 0018 cost = 0.463200450
Epoch: 0019 cost = 0.393307000
Epoch: 0020 cost = 0.349274218


# 다층 Layer & Xavier_uniform weight initialization 

In [10]:
data_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

linear1 = torch.nn.Linear(784, 512, bias=True)
linear2 = torch.nn.Linear(512, 512, bias=True)
linear3 = torch.nn.Linear(512, 512, bias=True)
linear4 = torch.nn.Linear(512, 512, bias=True)
linear5 = torch.nn.Linear(512, 10, bias=True)
relu = torch.nn.ReLU()

torch.nn.init.xavier_uniform_(linear1.weight)
torch.nn.init.xavier_uniform_(linear2.weight)
torch.nn.init.xavier_uniform_(linear3.weight)
torch.nn.init.xavier_uniform_(linear4.weight)
torch.nn.init.xavier_uniform_(linear5.weight)

Parameter containing:
tensor([[-0.0812, -0.0807, -0.0679,  ..., -0.0906, -0.0587,  0.0927],
        [-0.0613,  0.0023,  0.0298,  ..., -0.0426,  0.0027, -0.0952],
        [-0.0398,  0.0915, -0.0893,  ...,  0.0949,  0.1035,  0.0818],
        ...,
        [-0.0556, -0.0426, -0.0623,  ...,  0.0671, -0.0772, -0.0652],
        [-0.0812, -0.1014,  0.0155,  ..., -0.0153, -0.0307, -0.0996],
        [-0.0295, -0.1072,  0.1052,  ..., -0.0547,  0.0244, -0.0234]],
       requires_grad=True)

In [11]:
model = torch.nn.Sequential(linear1, relu, linear2, relu, linear3, relu, linear4, relu, linear5).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [12]:
total_batch = len(data_loader)
for epoch in range(training_epochs):
    avg_cost = 0
    for X,Y in data_loader:
        X = X.view(-1, 28*28).to(device)
        Y = Y.to(device)
        
        optimizer.zero_grad()
        hypothesis = model(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()
        
        avg_cost += cost/total_batch
    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))

Epoch: 0001 cost = 0.209069237
Epoch: 0002 cost = 0.087504402
Epoch: 0003 cost = 0.062247038
Epoch: 0004 cost = 0.050295848
Epoch: 0005 cost = 0.041956827
Epoch: 0006 cost = 0.033338562
Epoch: 0007 cost = 0.029913915
Epoch: 0008 cost = 0.025480697
Epoch: 0009 cost = 0.023643117
Epoch: 0010 cost = 0.021809930
Epoch: 0011 cost = 0.019750046
Epoch: 0012 cost = 0.020538975
Epoch: 0013 cost = 0.016350599
Epoch: 0014 cost = 0.015930574
Epoch: 0015 cost = 0.012464210
Epoch: 0016 cost = 0.016491905
Epoch: 0017 cost = 0.015366506
Epoch: 0018 cost = 0.010395721
Epoch: 0019 cost = 0.013211161
Epoch: 0020 cost = 0.012775788


In [13]:
with torch.no_grad():
    X_test = mnist_test.test_data.view(-1, 28 * 28).float().to(device)
    Y_test = mnist_test.test_labels.to(device)

    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy:', accuracy.item())

    r = random.randint(0, len(mnist_test) - 1)
    X_single_data = mnist_test.test_data[r:r + 1].view(-1, 28 * 28).float().to(device)
    Y_single_data = mnist_test.test_labels[r:r + 1].to(device)

    print('Label: ', Y_single_data.item())
    single_prediction = model(X_single_data)
    print('Prediction: ', torch.argmax(single_prediction, 1).item())

Accuracy: 0.9772999882698059
Label:  1
Prediction:  1


