In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sources.fallback_policy.model import QNetwork

In [2]:
# belief_base -- a1 --> belief_base a2
# belief_base -- a2 --> belief_base a3


embedding_dim = 768
num_belief = 2 + 1 # including cls

num_actions = 2
num_future_actions = 1

cls_belief = torch.randn(1, embedding_dim).unsqueeze(0)

belief_base = torch.randn(num_belief, embedding_dim).unsqueeze(0)
belief_base = torch.cat([cls_belief, belief_base], dim=1)

next_belief_base_a = torch.randn(num_belief, embedding_dim).unsqueeze(0)
next_belief_base_a = torch.cat([cls_belief, next_belief_base_a], dim=1)

next_belief_base_b = torch.randn(num_belief, embedding_dim).unsqueeze(0)
next_belief_base_b = torch.cat([cls_belief, next_belief_base_b], dim=1)

action_a = torch.randn(1, embedding_dim)
action_b = torch.randn(1, embedding_dim)
action_c = torch.randn(1, embedding_dim)
action_d = torch.randn(1, embedding_dim)

transition_a = (belief_base, action_a, next_belief_base_a, action_c, 10)
transition_b = (belief_base, action_b, next_belief_base_b, action_d, 0)

batch_belief_base = torch.cat([transition_a[0], transition_b[0]]).to('cuda')
batch_action = torch.cat([transition_a[1], transition_b[1]]).to('cuda')
batch_next_belief_base = torch.cat([transition_a[2], transition_b[2]]).to('cuda')
batch_next_actions = torch.cat([transition_a[3], transition_b[3]]).to('cuda')
batch_reward = torch.tensor([transition_a[4], transition_b[4]], dtype=torch.float).to('cuda')
batch_reward

tensor([10.,  0.], device='cuda:0')

In [3]:
network = QNetwork(embedding_dim, embedding_dim, n_blocks=1)
network = network.to('cuda')
network

QNetwork(
  (belief_base_encoder): BeliefBaseEncoder(
    (blocks): ModuleList(
      (0): BeliefTransformerBlock(
        (attention_dropout): Dropout(p=0.0, inplace=False)
        (output_dropout): Dropout(p=0.0, inplace=False)
        (layer_norm_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (qkv_proj_layer): Linear(in_features=768, out_features=2304, bias=False)
        (mlp): PositionWiseFF(
          (c_fc): Linear(in_features=768, out_features=3072, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layer_norm_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (hidden): Linear(in_features=1536, out_features=768, bias=False)
  (q_value_layer): Linear(in_features=768, out_features=1, bias=False)
)

In [4]:
# q-learning 
GAMMA = 0.99

num_parameters = sum(p.numel() for p in network.parameters())
print(num_parameters)

optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)


for epoch in range(1000):
    optimizer.zero_grad()
    # Q(s', a')
    next_q_values = network(belief_base=batch_next_belief_base, belief_base_sizes=[num_belief], action_tensors=batch_next_actions)
    #print(next_q_values.size())
    # best action of the next state
    #next_q_values = torch.tensor([v.max() for v in next_q_values])
    best_next_q_values, _ = next_q_values.max(dim=0)
    targets = batch_reward  + (GAMMA * best_next_q_values)
    
    # Q(s, a)
    q_values = network(belief_base=batch_belief_base, belief_base_sizes=[num_belief], action_tensors=batch_action)
    if epoch % 100 == 0:
        print(f"[{epoch}] q-values {q_values.squeeze(-1)}, targets {targets.squeeze(-1)}")
    loss = F.smooth_l1_loss(q_values.squeeze(-1), targets.detach())
    loss.backward()
    nn.utils.clip_grad_norm_(network.parameters(), 1.)
    optimizer.step()
    #break

7670016
[0] q-values tensor([-0.0533, -0.1886], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([10.1080,  0.1080], device='cuda:0', grad_fn=<SqueezeBackward1>)
[100] q-values tensor([10.1696,  0.0918], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([10.0299,  0.0299], device='cuda:0', grad_fn=<SqueezeBackward1>)
[200] q-values tensor([10.1762,  0.0989], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([10.0341,  0.0341], device='cuda:0', grad_fn=<SqueezeBackward1>)
[300] q-values tensor([10.0364,  0.0375], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([10.0379,  0.0379], device='cuda:0', grad_fn=<SqueezeBackward1>)
[400] q-values tensor([10.0378,  0.0378], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([10.0378,  0.0378], device='cuda:0', grad_fn=<SqueezeBackward1>)
[500] q-values tensor([10.1754,  0.1147], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([10.0413,  0.0413], device='cuda:0', grad_fn=<Squee