In [31]:
from model import DQN, LSTM_DQN
import torch
import time
from torchinfo import summary
from torch.distributions.categorical import Categorical

In [32]:
# Dummy batch for the conv DQN: shape (10, 4, 8, 3), integer values in [1, 9] stored as float32.
# A local seeded generator makes the input reproducible without touching the global torch RNG.
state = torch.randint(
    1, 10, (10, 4, 8, 3),
    dtype=torch.float32,
    generator=torch.Generator().manual_seed(0),
)
# Merge the trailing (8, 3) dims into 24 features -> (10, 4, 24); dim 1 becomes the Conv1d
# channel axis (in_channels=4 per the summary below). Deriving the shape from the tensor avoids
# a hard-coded dim that could silently reshape across dim-1 slices if the two ever disagree.
state = state.flatten(start_dim=2)
print(state.shape)

torch.Size([10, 4, 24])


In [33]:
# Standard (Conv1d-based) DQN: print the layer summary, then time one batched forward pass.
test_network = DQN(4, 8, [4,4,4,4])
print(summary(test_network, input_data=state))

# Number of states actually processed: one per batch element (previously hard-coded as 600).
n_states = state.size(0)

# Time only the forward pass; perf_counter is monotonic and high-resolution, unlike time.time().
start_time = time.perf_counter()
q_values = test_network(state)
total_time = time.perf_counter() - start_time

print(q_values.shape)
print(f"{total_time:.6f} s to process {n_states} states, average of {total_time/n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
DQN                                      [10, 4, 4]                129
├─Sequential: 1-1                        [10, 64, 24]              --
│    └─Conv1d: 2-1                       [10, 64, 24]              832
│    └─ReLU: 2-2                         [10, 64, 24]              --
│    └─Conv1d: 2-3                       [10, 64, 24]              12,352
│    └─ReLU: 2-4                         [10, 64, 24]              --
├─Sequential: 1-2                        [10, 128]                 --
│    └─Linear: 2-5                       [10, 128]                 196,736
│    └─ReLU: 2-6                         [10, 128]                 --
├─ModuleList: 1-3                        --                        --
│    └─Linear: 2-7                       [10, 4]                   516
│    └─Linear: 2-8                       [10, 4]                   516
│    └─Linear: 2-9                       [10, 4]                   516
│

In [34]:
# Dummy input for the LSTM DQN: shape (1, 15, 8, 3), integer values in [1, 9] stored as float32.
# A local seeded generator makes the input reproducible without touching the global torch RNG.
state = torch.randint(
    1, 10, (1, 15, 8, 3),
    dtype=torch.float32,
    generator=torch.Generator().manual_seed(0),
)
# Merge the trailing (8, 3) dims into 24 features -> (1, 15, 24); dim 1 is the sequence axis
# the LSTM consumes (its summary output is [1, 15, 128]). Deriving the shape from the tensor
# avoids re-typing the sequence length (15) by hand.
state = state.flatten(start_dim=2)
print(state.shape)

torch.Size([1, 15, 24])


In [35]:
# Standard LSTM DQN: print the layer summary, then time one forward pass over a sequence batch.
test_network = LSTM_DQN(24, [4,4,4,4])
print(summary(test_network, input_data=state))

# One forward consumes a (batch, seq_len, features) tensor; report per batch element
# (previously hard-coded as 600, although only state.size(0) sequences are processed).
n_sequences, seq_len = state.size(0), state.size(1)

hidden = None  # start without a carried-over recurrent state
# Time only the forward pass; perf_counter is monotonic and high-resolution, unlike time.time().
start_time = time.perf_counter()
q_values, hidden = test_network(state, hidden)
total_time = time.perf_counter() - start_time

print(q_values.shape)
# NOTE(review): presumably (h_n, c_n) from the inner nn.LSTM — confirm in model.LSTM_DQN.
print(hidden[0].shape)
print(hidden[1].shape)

print(f"{total_time:.6f} s to process {n_sequences} sequence(s) of {seq_len} steps, "
      f"average of {total_time/n_sequences:.6f} seconds per sequence")

Layer (type:depth-idx)                   Output Shape              Param #
LSTM_DQN                                 [1, 4, 4]                 129
├─LSTM: 1-1                              [1, 15, 128]              78,848
├─Sequential: 1-2                        [1, 128]                  --
│    └─Linear: 2-1                       [1, 128]                  16,512
│    └─ReLU: 2-2                         [1, 128]                  --
├─ModuleList: 1-3                        --                        --
│    └─Linear: 2-3                       [1, 4]                    516
│    └─Linear: 2-4                       [1, 4]                    516
│    └─Linear: 2-5                       [1, 4]                    516
│    └─Linear: 2-6                       [1, 4]                    516
Total params: 97,553
Trainable params: 97,553
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 1.20
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.39
Estimated Total Siz

