In [6]:
from model import DQN
import torch
import time
from torchinfo import summary

In [7]:

# Synthetic input batch shared by the benchmark cells below.
# Seeding torch's global RNG makes this batch reproducible on Restart & Run All;
# later cells that draw from the same RNG (e.g. layer weight initialisation)
# are then reproducible too, provided cells run in order.
torch.manual_seed(0)
# randint's upper bound is exclusive, so values are integers 1..9 stored as float32.
state = torch.randint(1, 10, (600, 4, 8, 3), dtype=torch.float32)
# Merge the two trailing dims: (600, 4, 8, 3) -> (600, 4, 24).
# flatten(start_dim=2) keeps dim 1 untouched instead of hard-coding its size,
# so the reshape cannot silently regroup values if that size ever changes.
state = state.flatten(start_dim=2)
print(state.shape)

torch.Size([600, 4, 24])


In [8]:
# standard dqn
# Build the network, print its torchinfo layer summary, then time a single
# forward pass over the whole `state` batch.
test_network = DQN(4, 8, [4,4,4,4])
print(summary(test_network, input_data=state))

n_states = state.size(0)
# no_grad: this is an inference timing, so don't build an autograd graph.
# perf_counter: monotonic, high-resolution clock meant for measuring intervals.
# Only the forward pass is inside the timed region (no printing).
with torch.no_grad():
    start_time = time.perf_counter()
    q_values = test_network(state)
    total_time = time.perf_counter() - start_time
print(q_values.shape)

print(f"{total_time:.6f} seconds to process {n_states} states, average of {total_time/n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
DQN                                      [600, 4, 4]               129
├─Sequential: 1-1                        [600, 64, 24]             --
│    └─Conv1d: 2-1                       [600, 64, 24]             832
│    └─ReLU: 2-2                         [600, 64, 24]             --
│    └─Conv1d: 2-3                       [600, 64, 24]             12,352
│    └─ReLU: 2-4                         [600, 64, 24]             --
├─Sequential: 1-2                        [600, 128]                --
│    └─Linear: 2-5                       [600, 128]                196,736
│    └─ReLU: 2-6                         [600, 128]                --
├─ModuleList: 1-3                        --                        --
│    └─Linear: 2-7                       [600, 4]                  516
│    └─Linear: 2-8                       [600, 4]                  516
│    └─Linear: 2-9                       [600, 4]                  516
│

In [9]:
# standard noisy dqn
# Same benchmark as the standard DQN cell, with noisy=True.
test_network = DQN(4, 8, [4,4,4,4], noisy=True)
print(summary(test_network, input_data=state))

n_states = state.size(0)
# no_grad: this is an inference timing, so don't build an autograd graph.
# perf_counter: monotonic, high-resolution clock meant for measuring intervals.
# Only the forward pass is inside the timed region (no printing).
with torch.no_grad():
    start_time = time.perf_counter()
    q_values = test_network(state)
    total_time = time.perf_counter() - start_time
print(q_values.shape)

print(f"{total_time:.6f} seconds to process {n_states} states, average of {total_time/n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
DQN                                      [600, 4, 4]               258
├─Sequential: 1-1                        [600, 64, 24]             --
│    └─Conv1d: 2-1                       [600, 64, 24]             832
│    └─ReLU: 2-2                         [600, 64, 24]             --
│    └─Conv1d: 2-3                       [600, 64, 24]             12,352
│    └─ReLU: 2-4                         [600, 64, 24]             --
├─Sequential: 1-2                        [600, 128]                --
│    └─Noisy_Layer: 2-5                  [600, 128]                393,472
│    └─ReLU: 2-6                         [600, 128]                --
├─ModuleList: 1-3                        --                        --
│    └─Noisy_Layer: 2-7                  [600, 4]                  1,032
│    └─Noisy_Layer: 2-8                  [600, 4]                  1,032
│    └─Noisy_Layer: 2-9                  [600, 4]                  1

In [10]:
# dueling noisy dqn
# Same benchmark as the standard DQN cell, with noisy=True and dueling=True.
test_network = DQN(4, 8, [4,4,4,4], noisy=True, dueling=True)
print(summary(test_network, input_data=state))

n_states = state.size(0)
# no_grad: this is an inference timing, so don't build an autograd graph.
# perf_counter: monotonic, high-resolution clock meant for measuring intervals.
# Only the forward pass is inside the timed region (no printing).
with torch.no_grad():
    start_time = time.perf_counter()
    q_values = test_network(state)
    total_time = time.perf_counter() - start_time
print(q_values.shape)

print(f"{total_time:.6f} seconds to process {n_states} states, average of {total_time/n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
DQN                                      [600, 4, 4]               --
├─Sequential: 1-1                        [600, 64, 24]             --
│    └─Conv1d: 2-1                       [600, 64, 24]             832
│    └─ReLU: 2-2                         [600, 64, 24]             --
│    └─Conv1d: 2-3                       [600, 64, 24]             12,352
│    └─ReLU: 2-4                         [600, 64, 24]             --
├─Sequential: 1-2                        [600, 128]                --
│    └─Noisy_Layer: 2-5                  [600, 128]                393,472
│    └─ReLU: 2-6                         [600, 128]                --
├─Noisy_Layer: 1-3                       [600, 1]                  258
├─ModuleList: 1-4                        --                        --
│    └─Noisy_Layer: 2-7                  [600, 4]                  1,032
│    └─Noisy_Layer: 2-8                  [600, 4]                  1,03