In [1]:
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import numpy as np
from torch import optim
from torch.nn import CrossEntropyLoss

In [2]:
from dqn import DQN
from torch.nn.functional import smooth_l1_loss
from mlnn import MLNN
from replay_memory import ReplayMemory, Transition
from training import optimize_dqn
import random

In [3]:
train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dsets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

In [4]:
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
random.seed(0)

In [5]:
batch_size = 100
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [6]:
batch_size = 100
epochs = 30
input_dims = 784
hidden_dims = 333
output_dims = 10
gamma = 0.999
target_update = 10
len_train_dataset = len(train_loader.dataset)

In [7]:
n_layers = 3
n_width = 2

In [8]:
policy_net = DQN(hidden_dims, hidden_dims, n_width).cuda()
target_net = DQN(hidden_dims, hidden_dims, n_width).cuda()
model = MLNN(n_layers, n_width, input_dims, hidden_dims, output_dims)

In [9]:
memory_capacity = 100
replay_memory = ReplayMemory(memory_capacity)

In [10]:
optimizer = optim.Adam(model.get_params())
dqn_optimizer = optim.Adam(policy_net.parameters())
criterion = CrossEntropyLoss().cuda()

In [11]:
def optimize_dqn(policy_net, target_net, replay_memory, optimizer, batch_size, gamma):
    if len(replay_memory) < batch_size:
        return
    
    transitions = replay_memory.sample(batch_size)
    
    batch = Transition(*zip(*transitions))
    
    state = torch.stack(batch.state)
    action = torch.stack(batch.action).reshape([-1, 1])
    next_state = torch.stack(batch.next_state)
    reward = torch.stack(batch.reward)
    
    q_values = policy_net(state).gather(1, action.reshape([-1, 1])).squeeze()
    #print(batch.reward)
    expected_q_values = (target_net(next_state).max(1)[0].detach() * gamma) + reward

    loss = smooth_l1_loss(q_values, expected_q_values)

    optimizer.zero_grad()
    loss.backward()
    
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    
    optimizer.step()
    
    return loss

In [12]:
avg_mlnn_losses = []
avg_dqn_losses = []

for epoch in range(epochs):
    avg_mlnn_loss = 0.
    avg_dqn_loss = 0.
    
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images.view(-1, 28 * 28)).cuda()
        labels = Variable(labels).cuda()

        replays, loss, outputs = model.train(images, labels, policy_net, criterion)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_mlnn_loss += loss.detach()

        for replay in replays:
            replay_memory.push(replay.state, replay.action, replay.next_state, replay.reward)

        avg_dqn_loss += optimize_dqn(policy_net, target_net, replay_memory, dqn_optimizer, batch_size, gamma)

        if i % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())

    avg_mlnn_loss /= float(i)
    avg_dqn_loss /= float(i)
    
    avg_mlnn_losses.append(avg_mlnn_loss)
    avg_dqn_losses.append(avg_dqn_loss)
    
    print(epoch, avg_mlnn_loss.data, avg_dqn_loss.data)

0 tensor(1.0399, device='cuda:0') tensor(0.0745, device='cuda:0')
1 tensor(0.9155, device='cuda:0') tensor(0.0194, device='cuda:0')
2 tensor(0.8378, device='cuda:0') tensor(0.0292, device='cuda:0')
3 tensor(0.7209, device='cuda:0') tensor(0.0689, device='cuda:0')
4 tensor(0.5636, device='cuda:0') tensor(0.1381, device='cuda:0')
5 tensor(0.5467, device='cuda:0') tensor(0.1632, device='cuda:0')
6 tensor(0.5244, device='cuda:0') tensor(0.1831, device='cuda:0')
7 tensor(0.5066, device='cuda:0') tensor(0.1981, device='cuda:0')
8 tensor(0.4925, device='cuda:0') tensor(0.2147, device='cuda:0')
9 tensor(0.4917, device='cuda:0') tensor(0.2249, device='cuda:0')
10 tensor(0.5316, device='cuda:0') tensor(0.1999, device='cuda:0')
11 tensor(0.5124, device='cuda:0') tensor(0.2072, device='cuda:0')
12 tensor(0.5054, device='cuda:0') tensor(0.2153, device='cuda:0')
13 tensor(0.4916, device='cuda:0') tensor(0.2253, device='cuda:0')
14 tensor(0.4562, device='cuda:0') tensor(0.2550, device='cuda:0')
15 te