In [10]:
from model import DQN, LSTM_DQN
import torch
import time
from torchinfo import summary

In [11]:
# Seed torch so the random test input below is reproducible across Restart & Run All.
torch.manual_seed(0)

# Random integer-valued test input of shape (1, 4, 8, 3). Flatten the last two dims
# (8 * 3 = 24) while keeping the batch dim and dim 1, giving (1, 4, 24).
# Dim 1 is read back from the tensor rather than repeating the literal 4.
state = torch.randint(1, 10, (1, 4, 8, 3), dtype=torch.float32)
state = state.view(state.size(0), state.size(1), -1)
print(state.shape)

torch.Size([1, 4, 24])


In [12]:
# standard dqn
test_network = DQN(4, 8, [4,4,4,4])
print(summary(test_network, input_data=state))

# Time a single forward pass over the batch in `state`.
# - The state count comes from the batch dim. The previous message hard-coded
#   "600 states" although the batch holds state.size(0) states.
# - perf_counter is the monotonic clock intended for measuring intervals.
# - no_grad: pure inference, so no autograd graph is built.
# - Prints happen after the timer stops so they are not included in the measurement.
n_states = state.size(0)
start_time = time.perf_counter()
with torch.no_grad():
    q_values = test_network(state)
total_time = time.perf_counter() - start_time

print(q_values.shape)
print(f"{total_time:.6f} seconds to process {n_states} state(s), average of {total_time / n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
DQN                                      [1, 4, 4]                 129
├─Sequential: 1-1                        [1, 64, 24]               --
│    └─Conv1d: 2-1                       [1, 64, 24]               832
│    └─ReLU: 2-2                         [1, 64, 24]               --
│    └─Conv1d: 2-3                       [1, 64, 24]               12,352
│    └─ReLU: 2-4                         [1, 64, 24]               --
├─Sequential: 1-2                        [1, 128]                  --
│    └─Linear: 2-5                       [1, 128]                  196,736
│    └─ReLU: 2-6                         [1, 128]                  --
├─ModuleList: 1-3                        --                        --
│    └─Linear: 2-7                       [1, 4]                    516
│    └─Linear: 2-8                       [1, 4]                    516
│    └─Linear: 2-9                       [1, 4]                    516
│

In [13]:
# noisy dqn
test_network = DQN(4, 8, [4,4,4,4], noisy=True)
print(summary(test_network, input_data=state))

# Time a single forward pass over the batch in `state`.
# - The state count comes from the batch dim. The previous message hard-coded
#   "600 states" although the batch holds state.size(0) states.
# - perf_counter is the monotonic clock intended for measuring intervals.
# - no_grad: pure inference, so no autograd graph is built.
# - Prints happen after the timer stops so they are not included in the measurement.
n_states = state.size(0)
start_time = time.perf_counter()
with torch.no_grad():
    q_values = test_network(state)
total_time = time.perf_counter() - start_time

print(q_values.shape)
print(f"{total_time:.6f} seconds to process {n_states} state(s), average of {total_time / n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
DQN                                      [1, 4, 4]                 258
├─Sequential: 1-1                        [1, 64, 24]               --
│    └─Conv1d: 2-1                       [1, 64, 24]               832
│    └─ReLU: 2-2                         [1, 64, 24]               --
│    └─Conv1d: 2-3                       [1, 64, 24]               12,352
│    └─ReLU: 2-4                         [1, 64, 24]               --
├─Sequential: 1-2                        [1, 128]                  --
│    └─Noisy_Layer: 2-5                  [1, 128]                  393,472
│    └─ReLU: 2-6                         [1, 128]                  --
├─ModuleList: 1-3                        --                        --
│    └─Noisy_Layer: 2-7                  [1, 4]                    1,032
│    └─Noisy_Layer: 2-8                  [1, 4]                    1,032
│    └─Noisy_Layer: 2-9                  [1, 4]                    1

In [14]:
# dueling noisy dqn
test_network = DQN(4, 8, [4,4,4,4], noisy=True, dueling=True)
print(summary(test_network, input_data=state))

# Time a single forward pass over the batch in `state`.
# - The state count comes from the batch dim. The previous message hard-coded
#   "600 states" although the batch holds state.size(0) states.
# - perf_counter is the monotonic clock intended for measuring intervals.
# - no_grad: pure inference, so no autograd graph is built.
# - Prints happen after the timer stops so they are not included in the measurement.
n_states = state.size(0)
start_time = time.perf_counter()
with torch.no_grad():
    q_values = test_network(state)
total_time = time.perf_counter() - start_time

print(q_values.shape)
print(f"{total_time:.6f} seconds to process {n_states} state(s), average of {total_time / n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
DQN                                      [1, 4, 4]                 --
├─Sequential: 1-1                        [1, 64, 24]               --
│    └─Conv1d: 2-1                       [1, 64, 24]               832
│    └─ReLU: 2-2                         [1, 64, 24]               --
│    └─Conv1d: 2-3                       [1, 64, 24]               12,352
│    └─ReLU: 2-4                         [1, 64, 24]               --
├─Sequential: 1-2                        [1, 128]                  --
│    └─Noisy_Layer: 2-5                  [1, 128]                  393,472
│    └─ReLU: 2-6                         [1, 128]                  --
├─Noisy_Layer: 1-3                       [1, 1]                    258
├─ModuleList: 1-4                        --                        --
│    └─Noisy_Layer: 2-7                  [1, 4]                    1,032
│    └─Noisy_Layer: 2-8                  [1, 4]                    1,03

In [15]:
# Reseed so this input is reproducible on its own, independent of earlier cells.
torch.manual_seed(0)

# Random integer-valued test input of shape (1, 15, 8, 3). Flatten the last two dims
# (8 * 3 = 24) while keeping the batch dim and dim 1, giving (1, 15, 24).
# Dim 1 is read back from the tensor rather than repeating the literal 15.
state = torch.randint(1, 10, (1, 15, 8, 3), dtype=torch.float32)
state = state.view(state.size(0), state.size(1), -1)
print(state.shape)

torch.Size([1, 15, 24])


In [16]:
# standard lstm dqn
test_network = LSTM_DQN(24, [4,4,4,4])
print(summary(test_network, input_data=state))

# Time a single forward pass over the batch in `state`.
# - The state count comes from the batch dim (each batch element is one input
#   sequence). The previous message hard-coded "600 states".
# - perf_counter is the monotonic clock intended for measuring intervals.
# - no_grad: pure inference, so no autograd graph is built.
# - Prints happen after the timer stops so they are not included in the measurement.
n_states = state.size(0)
hidden = None  # no prior recurrent state; presumably LSTM_DQN initialises it when None — confirm
start_time = time.perf_counter()
with torch.no_grad():
    q_values, hidden = test_network(state, hidden)
total_time = time.perf_counter() - start_time

print(q_values.shape)
# hidden is indexed as a pair; presumably the LSTM (h, c) tuple — TODO confirm in model.py
print(hidden[0].shape)
print(hidden[1].shape)
print(f"{total_time:.6f} seconds to process {n_states} state(s), average of {total_time / n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
LSTM_DQN                                 [1, 4, 4]                 129
├─LSTM: 1-1                              [1, 15, 128]              78,848
├─Sequential: 1-2                        [1, 128]                  --
│    └─Linear: 2-1                       [1, 128]                  16,512
│    └─ReLU: 2-2                         [1, 128]                  --
├─ModuleList: 1-3                        --                        --
│    └─Linear: 2-3                       [1, 4]                    516
│    └─Linear: 2-4                       [1, 4]                    516
│    └─Linear: 2-5                       [1, 4]                    516
│    └─Linear: 2-6                       [1, 4]                    516
Total params: 97,553
Trainable params: 97,553
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 1.20
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.39
Estimated Total Siz

In [17]:
# noisy lstm dqn
test_network = LSTM_DQN(24, [4,4,4,4], noisy=True)
print(summary(test_network, input_data=state))

# Time a single forward pass over the batch in `state`.
# - The state count comes from the batch dim (each batch element is one input
#   sequence). The previous message hard-coded "600 states".
# - perf_counter is the monotonic clock intended for measuring intervals.
# - no_grad: pure inference, so no autograd graph is built.
# - Prints happen after the timer stops so they are not included in the measurement.
n_states = state.size(0)
hidden = None  # no prior recurrent state; presumably LSTM_DQN initialises it when None — confirm
start_time = time.perf_counter()
with torch.no_grad():
    q_values, hidden = test_network(state, hidden)
total_time = time.perf_counter() - start_time

print(q_values.shape)
# hidden is indexed as a pair; presumably the LSTM (h, c) tuple — TODO confirm in model.py
print(hidden[0].shape)
print(hidden[1].shape)
print(f"{total_time:.6f} seconds to process {n_states} state(s), average of {total_time / n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
LSTM_DQN                                 [1, 4, 4]                 258
├─LSTM: 1-1                              [1, 15, 128]              78,848
├─Sequential: 1-2                        [1, 128]                  --
│    └─Noisy_Layer: 2-1                  [1, 128]                  33,024
│    └─ReLU: 2-2                         [1, 128]                  --
├─ModuleList: 1-3                        --                        --
│    └─Noisy_Layer: 2-3                  [1, 4]                    1,032
│    └─Noisy_Layer: 2-4                  [1, 4]                    1,032
│    └─Noisy_Layer: 2-5                  [1, 4]                    1,032
│    └─Noisy_Layer: 2-6                  [1, 4]                    1,032
Total params: 116,258
Trainable params: 116,258
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 5.43
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.46
Estimated

In [18]:
# dueling noisy lstm dqn
test_network = LSTM_DQN(24, [4,4,4,4], noisy=True, dueling=True)
print(summary(test_network, input_data=state))

# Time a single forward pass over the batch in `state`.
# - The state count comes from the batch dim (each batch element is one input
#   sequence). The previous message hard-coded "600 states".
# - perf_counter is the monotonic clock intended for measuring intervals.
# - no_grad: pure inference, so no autograd graph is built.
# - Prints happen after the timer stops so they are not included in the measurement.
n_states = state.size(0)
hidden = None  # no prior recurrent state; presumably LSTM_DQN initialises it when None — confirm
start_time = time.perf_counter()
with torch.no_grad():
    q_values, hidden = test_network(state, hidden)
total_time = time.perf_counter() - start_time

print(q_values.shape)
# hidden is indexed as a pair; presumably the LSTM (h, c) tuple — TODO confirm in model.py
print(hidden[0].shape)
print(hidden[1].shape)
print(f"{total_time:.6f} seconds to process {n_states} state(s), average of {total_time / n_states:.6f} seconds per state")

Layer (type:depth-idx)                   Output Shape              Param #
LSTM_DQN                                 [1, 4, 4]                 --
├─LSTM: 1-1                              [1, 15, 128]              78,848
├─Sequential: 1-2                        [1, 128]                  --
│    └─Noisy_Layer: 2-1                  [1, 128]                  33,024
│    └─ReLU: 2-2                         [1, 128]                  --
├─Noisy_Layer: 1-3                       [1, 1]                    258
├─ModuleList: 1-4                        --                        --
│    └─Noisy_Layer: 2-3                  [1, 4]                    1,032
│    └─Noisy_Layer: 2-4                  [1, 4]                    1,032
│    └─Noisy_Layer: 2-5                  [1, 4]                    1,032
│    └─Noisy_Layer: 2-6                  [1, 4]                    1,032
Total params: 116,258
Trainable params: 116,258
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 5.43
Input size (MB): 0.00
