### 간단한 nn 모듈 사용 예제

In [1]:
import torch
from torch.autograd import Variable

torch.manual_seed(7)

# N: batch size; D_in: input dim; H: hidden dim; D_out: output dim.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and targets. Since PyTorch 0.4 plain tensors are autograd-aware,
# so the deprecated `Variable` wrapper is unnecessary; neither requires grad.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Define the model as a sequence of layers.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# torch.nn also ships commonly used loss functions. reduction="sum" is the
# non-deprecated equivalent of size_average=False (sum of squared errors).
loss_fn = torch.nn.MSELoss(reduction="sum")
learning_rate = 1e-4
for t in range(500):
    # Forward pass.
    y_pred = model(x)

    # Compute and print the loss. detach() prints the same `tensor(...)` value
    # as the legacy `.data` access, without exposing autograd internals.
    loss = loss_fn(y_pred, y)
    print(t, loss.detach())

    # Zero the gradients before the backward pass; they accumulate otherwise.
    model.zero_grad()

    # Backward pass: gradients of the loss w.r.t. every model parameter.
    loss.backward()

    # Gradient-descent update. no_grad() keeps the in-place update out of the
    # autograd graph, instead of silently bypassing it through `.data`.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad



0 tensor(654.8659)
1 tensor(607.4591)
2 tensor(566.1823)
3 tensor(529.6280)
4 tensor(496.9473)
5 tensor(467.1930)
6 tensor(439.8441)
7 tensor(414.8025)
8 tensor(391.7271)
9 tensor(370.2265)
10 tensor(350.0983)
11 tensor(331.1125)
12 tensor(313.1535)
13 tensor(296.1174)
14 tensor(279.9932)
15 tensor(264.7670)
16 tensor(250.3797)
17 tensor(236.7228)
18 tensor(223.7385)
19 tensor(211.4389)
20 tensor(199.7336)
21 tensor(188.5886)
22 tensor(178.0149)
23 tensor(167.9607)
24 tensor(158.4340)
25 tensor(149.3776)
26 tensor(140.7773)
27 tensor(132.6402)
28 tensor(124.9441)
29 tensor(117.6373)
30 tensor(110.7234)
31 tensor(104.1849)
32 tensor(98.0163)
33 tensor(92.1954)
34 tensor(86.6986)
35 tensor(81.5274)
36 tensor(76.6598)
37 tensor(72.1005)
38 tensor(67.8230)
39 tensor(63.7976)
40 tensor(60.0105)
41 tensor(56.4614)
42 tensor(53.1309)
43 tensor(50.0055)
44 tensor(47.0783)
45 tensor(44.3301)
46 tensor(41.7508)
47 tensor(39.3320)
48 tensor(37.0640)
49 tensor(34.9383)
50 tensor(32.9418)
51 tensor

432 tensor(9.4821e-05)
433 tensor(9.2256e-05)
434 tensor(8.9760e-05)
435 tensor(8.7331e-05)
436 tensor(8.4973e-05)
437 tensor(8.2672e-05)
438 tensor(8.0439e-05)
439 tensor(7.8261e-05)
440 tensor(7.6150e-05)
441 tensor(7.4100e-05)
442 tensor(7.2102e-05)
443 tensor(7.0157e-05)
444 tensor(6.8266e-05)
445 tensor(6.6426e-05)
446 tensor(6.4634e-05)
447 tensor(6.2896e-05)
448 tensor(6.1202e-05)
449 tensor(5.9550e-05)
450 tensor(5.7952e-05)
451 tensor(5.6393e-05)
452 tensor(5.4878e-05)
453 tensor(5.3401e-05)
454 tensor(5.1967e-05)
455 tensor(5.0574e-05)
456 tensor(4.9216e-05)
457 tensor(4.7896e-05)
458 tensor(4.6611e-05)
459 tensor(4.5359e-05)
460 tensor(4.4144e-05)
461 tensor(4.2960e-05)
462 tensor(4.1809e-05)
463 tensor(4.0690e-05)
464 tensor(3.9601e-05)
465 tensor(3.8540e-05)
466 tensor(3.7512e-05)
467 tensor(3.6508e-05)
468 tensor(3.5531e-05)
469 tensor(3.4582e-05)
470 tensor(3.3662e-05)
471 tensor(3.2764e-05)
472 tensor(3.1887e-05)
473 tensor(3.1036e-05)
474 tensor(3.0208e-05)
475 tensor(