# Development notebook

In [1]:
# Enable the IPython autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
# Reference: https://ipython.org/ipython-doc/3/config/extensions/autoreload.html
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Put the CIFAR-10 example directory first on the import path; the
# `small_vgg16_bn` and `train` modules imported in later cells presumably
# live there -- TODO confirm.
sys.path.insert(0, "../examples/train_on_cifar10/")

In [12]:
import torch
from torch.optim import SGD
from torch.nn import CrossEntropyLoss

In [7]:
from small_vgg16_bn import get_small_vgg16_bn
# Build the network used by the training step below.
# NOTE(review): 10 is presumably the number of output classes (CIFAR-10) -- confirm
# against get_small_vgg16_bn's signature.
model = get_small_vgg16_bn(10)

In [5]:
from train import get_data_loaders, train_data_transform, val_data_transform

In [10]:
# Create the train/validation data loaders rooted at the example directory
# (the cell output shows the CIFAR-10 archive being downloaded there).
# NOTE(review): the two positional 16s are presumably the train and validation
# batch sizes -- confirm against get_data_loaders' signature.
train_loader, val_loader = get_data_loaders("../examples/train_on_cifar10/", train_data_transform, val_data_transform,
                                            16, 16, num_workers=8, cuda=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../examples/train_on_cifar10/cifar-10-python.tar.gz


In [16]:
# Cross-entropy loss and SGD (learning rate 0.1) over all parameters of `model`;
# both are read as globals by `_update`.
loss_fn = CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.1)

In [19]:
from ignite._utils import to_variable


def _prepare_batch(batch):
    """Split ``batch`` into its two parts and wrap each with ``to_variable``.

    Args:
        batch: a two-element iterable ``(x, y)``.

    Returns:
        Tuple ``(x, y)`` where each element has been passed through
        ``to_variable(..., cuda=False)``.
    """
    inputs, targets = batch
    return to_variable(inputs, cuda=False), to_variable(targets, cuda=False)

def _update(engine, batch):
    """Perform one optimisation step on ``batch`` and return its loss.

    Relies on the notebook-level ``model``, ``optimizer`` and ``loss_fn``.

    Args:
        engine: not used by this function (presumably supplied by the
            caller that drives the training loop -- TODO confirm).
        batch: a two-element ``(x, y)`` batch, forwarded to ``_prepare_batch``.

    Returns:
        The value produced by ``loss_fn`` for this batch, returned as-is
        (i.e. not converted to a Python number).
    """
    # Switch to training mode and clear stale gradients before the forward pass.
    model.train()
    optimizer.zero_grad()

    inputs, targets = _prepare_batch(batch)
    batch_loss = loss_fn(model(inputs), targets)

    # Back-propagate, then apply the parameter update.
    batch_loss.backward()
    optimizer.step()
    return batch_loss

In [18]:
# Pull a single batch from the training loader to inspect its structure:
# number of elements in the batch and the shape of the first element.
train_loader_iter = iter(train_loader)
batch = next(train_loader_iter)
len(batch), batch[0].shape

(2, torch.Size([16, 3, 32, 32]))

In [27]:
# Inspect the optimizer's first parameter group: how many parameter tensors it
# holds and the shape of the first one.
len(optimizer.param_groups[0]['params']), optimizer.param_groups[0]['params'][0].shape

(58, torch.Size([64, 3, 3, 3]))

In [29]:
# List the keys of the first parameter group (the optimizer settings plus
# the 'params' entry).
for k in optimizer.param_groups[0]:
    print(k)

momentum
dampening
params
nesterov
weight_decay
lr


In [39]:
from torch.nn import Module, LSTM, Linear, Sequential, ReLU


class ActorNetwork(Module):
    """Actor network (policy), implemented as a stacked LSTM.

    The recurrent core has 2 layers with 20 hidden units each.

    Args:
        input_dim: number of features per input step.
        output_dim: accepted for interface purposes but currently unused.
    """

    def __init__(self, input_dim, output_dim):
        super(ActorNetwork, self).__init__()
        self.lstm = LSTM(input_size=input_dim, hidden_size=20, num_layers=2)

    def forward(self, state, hidden=None):
        """Run ``state`` through the LSTM.

        Args:
            state: input passed to the LSTM.
            hidden: optional initial hidden state; ``None`` lets the LSTM
                use its default.

        Returns:
            The ``(output, hidden)`` pair produced by the LSTM.
        """
        return self.lstm(state, hidden)
    

class CriticNetwork(Module):
    """Critic network or Q-function.

    A small MLP (``Linear -> ReLU -> Linear``) that maps the concatenation of
    an action and a state to a single scalar per sample.

    Args:
        input_dim: width of the concatenated ``[action, state]`` input, i.e.
            ``action.shape[1] + state.shape[1]``.
    """

    def __init__(self, input_dim):
        super(CriticNetwork, self).__init__()
        self.q_fn = Sequential(
            Linear(input_dim, 10),
            # The comma after ReLU was missing, which made the whole cell a SyntaxError.
            ReLU(inplace=True),
            Linear(10, 1)
        )

    def forward(self, action, state):
        """Score an ``(action, state)`` pair.

        Args:
            action: 2-D tensor; concatenated with ``state`` along dim 1.
            state: 2-D tensor with the same first dimension as ``action``.

        Returns:
            Output of the MLP for the concatenated input, with one value per row.
        """
        x = torch.cat([action, state], dim=1)
        reward = self.q_fn(x)
        return reward

In [None]:
# Placeholder: no actor network instance is created yet.
actor_network = None

In [37]:
from torch.autograd import Variable

In [38]:
# Smoke test of a bare LSTM: 10 input features, 20 hidden units, 2 layers.
rnn = LSTM(10, 20, 2)
# `torch.from(...)` is a SyntaxError (`from` is a reserved keyword); build a
# random input of shape (5, 3, 10) instead -- with the LSTM's default layout
# that is (seq_len, batch, input_size).
x = Variable(torch.randn(5, 3, 10))
# Passing None as the hidden state lets the LSTM use its default initial state.
output, hn = rnn(x, None)