In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sources.fallback_policy.model import QNetwork

In [2]:
# belief_base -- a1 --> belief_base a2
# belief_base -- a2 --> belief_base a3


embedding_dim = 768
num_belief = 2 + 1 # including cls

num_actions = 2
num_future_actions = 1

cls_belief = torch.randn(1, embedding_dim).unsqueeze(0)

belief_base = torch.randn(num_belief, embedding_dim).unsqueeze(0)
belief_base = torch.cat([cls_belief, belief_base], dim=1)

next_belief_base_a = torch.randn(num_belief, embedding_dim).unsqueeze(0)
next_belief_base_a = torch.cat([cls_belief, next_belief_base_a], dim=1)

next_belief_base_b = torch.randn(num_belief, embedding_dim).unsqueeze(0)
next_belief_base_b = torch.cat([cls_belief, next_belief_base_b], dim=1)


goal = torch.randn(1, embedding_dim)

action_a = torch.randn(1, embedding_dim)
action_b = torch.randn(1, embedding_dim)
action_c = torch.randn(1, embedding_dim)
action_d = torch.randn(1, embedding_dim)

transition_a = (belief_base, action_a, next_belief_base_a, action_c, 0.7, goal)
transition_b = (belief_base, action_b, next_belief_base_b, action_d, 0, goal)

batch_belief_base = torch.cat([transition_a[0], transition_b[0]]).to('cuda')
batch_action = torch.cat([transition_a[1], transition_b[1]]).to('cuda')
batch_next_belief_base = torch.cat([transition_a[2], transition_b[2]]).to('cuda')
batch_next_actions = torch.cat([transition_a[3], transition_b[3]]).to('cuda')
batch_reward = torch.tensor([transition_a[4], transition_b[4]], dtype=torch.float).to('cuda')
batch_goal = torch.cat([transition_a[5], transition_b[5]]).to('cuda')
batch_reward

tensor([0.7000, 0.0000], device='cuda:0')

In [3]:
network = QNetwork(embedding_dim, embedding_dim, embedding_dim, n_blocks=1)
network = network.to('cuda')
network

QNetwork(
  (belief_base_encoder): BeliefBaseEncoder(
    (blocks): ModuleList(
      (0): BeliefTransformerBlock(
        (attention_dropout): Dropout(p=0.0, inplace=False)
        (output_dropout): Dropout(p=0.0, inplace=False)
        (layer_norm_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (qkv_proj_layer): Linear(in_features=768, out_features=2304, bias=False)
        (mlp): PositionWiseFF(
          (c_fc): Linear(in_features=768, out_features=768, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=768, out_features=768, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layer_norm_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (hidden): Linear(in_features=1536, out_features=768, bias=False)
  (q_value_layer): Linear(in_features=768, out_features=1, bias=False)
)

In [None]:
q(belief_base, a)

In [4]:
# q-learning 
GAMMA = 0.99

num_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad)
print(f"Number of parameters: {num_parameters}")

optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)


for epoch in range(100):
    optimizer.zero_grad()
    # Q(s', a')
    next_q_values = network(belief_base=batch_next_belief_base, 
                            belief_base_sizes=[num_belief], 
                            action_tensors=batch_next_actions)
    best_next_q_values, _ = next_q_values.max(dim=0)
    targets = batch_reward  + (GAMMA * best_next_q_values)
    
    # Q(s, a)
    q_values = network(belief_base=batch_belief_base, belief_base_sizes=[num_belief], action_tensors=batch_action)
    if epoch % 10 == 0:
        print(f"[{epoch}] q-values {q_values.squeeze(-1)}, targets {targets.squeeze(-1)}")
    loss = F.smooth_l1_loss(q_values.squeeze(-1), targets.detach())
    loss.backward()
    nn.utils.clip_grad_norm_(network.parameters(), 5.)
    optimizer.step()
    #break

Number of parameters: 4131072
[0] q-values tensor([-0.1408, -0.0129], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([0.9366, 0.2366], device='cuda:0', grad_fn=<SqueezeBackward1>)
[10] q-values tensor([0.9142, 0.2249], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([0.9116, 0.2116], device='cuda:0', grad_fn=<SqueezeBackward1>)
[20] q-values tensor([0.9517, 0.1589], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([0.8993, 0.1993], device='cuda:0', grad_fn=<SqueezeBackward1>)
[30] q-values tensor([0.9421, 0.1588], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([0.8957, 0.1957], device='cuda:0', grad_fn=<SqueezeBackward1>)
[40] q-values tensor([0.9024, 0.1880], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([0.8950, 0.1950], device='cuda:0', grad_fn=<SqueezeBackward1>)
[50] q-values tensor([0.8773, 0.2068], device='cuda:0', grad_fn=<SqueezeBackward1>), targets tensor([0.8949, 0.1949], device='cuda:0', grad_fn=<SqueezeBac

In [5]:
torch.cuda.memory_allocated('cuda') / 1024.0 / 1024.0

79.56103515625