In [1]:
from dataset import GameplayActionPairVideoDataset
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import torch
from utils import custom_collate_fn
from model.agent import Agent, device
from torch import nn, optim

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory handed to GameplayActionPairVideoDataset as its `root_dir`.
# Relative path: it is resolved against the kernel's working directory.
root_dir: str = "output_logs"

In [3]:
# The same tokenizer instance is given to the dataset here and used again by the
# training / evaluation cells to decode instructions back into text.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

dataset = GameplayActionPairVideoDataset(
    root_dir=root_dir,
    tokenizer=tokenizer,
)



In [4]:
# Shuffled loader that yields one sample per step, batched by the project's collate function.
# NOTE(review): the training loop star-unpacks the instruction batch into
# tokenizer.decode(...), which looks like it relies on exactly one sample per
# batch - confirm before raising batch_size.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=custom_collate_fn,
)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7d9f9c28ae30>

In [5]:
# Build the agent with its debug flag off, then place it on the `device`
# exported by model.agent (the same device the input tensors are moved to later).
agent = Agent(debug=False)
agent = agent.to(device)

In [6]:
# BCEWithLogitsLoss applies the sigmoid itself, so it must be fed the agent's raw
# logits; the evaluation cell applies torch.sigmoid separately to get probabilities.
criterion = nn.BCEWithLogitsLoss()

# Adam over every parameter the agent exposes.
optimizer = optim.Adam(
    agent.parameters(),
    lr=1e-3,
)

In [7]:
# Number of full passes over `dataloader` performed by the training loop.
num_epochs: int = 1

In [None]:
# Training loop: one optimisation step per batch, reporting the mean loss of each epoch.
#
# Put the model in training mode explicitly. The evaluation cell further down calls
# agent.eval(), so re-running this cell afterwards would otherwise keep training
# with eval-mode behaviour.
# NOTE(review): this also switches any sub-module that Agent deliberately keeps in
# eval mode (e.g. a pretrained encoder) - confirm against model/agent.py.
agent.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0  # stays 0 if the dataloader yields nothing

    for num_batches, (instruction, frame, action) in enumerate(dataloader, start=1):
        frame = frame.to(device)
        action = action.to(device)

        # `frame` is 5-D; merge its two leading axes so the agent also receives a
        # flat stack of (channel, height, width) images. The reshape result is
        # already on `device` because `frame` is.
        _, _, channel, height, width = frame.shape
        images = frame.reshape(-1, channel, height, width)

        # NOTE(review): star-unpacking passes each element of the instruction batch
        # as a separate positional argument, so this presumably only works with
        # one sample per batch (the loader uses batch_size=1) - confirm.
        instruction = tokenizer.decode(*instruction)

        optimizer.zero_grad()
        logits = agent(images, frame, instruction)
        loss = criterion(logits, action)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the epoch average rather than the last batch only: with shuffled
    # single-sample batches the final loss says little about the epoch.
    # An empty dataloader prints nan instead of raising NameError or showing a
    # stale `loss` left in the kernel by an earlier run.
    mean_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")

In [9]:
# Sanity check on the first dataset sample: compare the thresholded prediction for
# its first step with the ground-truth action vector.
# `frame` and the decoded `instruction` are left bound for the following cell.
agent.eval()
with torch.no_grad():
    instruction, frame, action = dataset[0]
    frame, action = frame.to(device), action.to(device)

    # A single sample has no leading batch axis (4-D), unlike the batches seen in training.
    _, channel, height, width = frame.shape
    images = frame.reshape(-1, channel, height, width).to(device)
    instruction = tokenizer.decode(instruction)

    # The agent is given the frames with a size-1 axis prepended.
    batched_frame = frame.unsqueeze(dim=0)
    print(f"frame shape: {batched_frame.shape}")
    print(f"images shape: {images.shape}")
    print(f"instruction: {instruction}")

    logits = agent(images, batched_frame, instruction)
    probs = logits.sigmoid()
    actions = probs.gt(0.5).float()  # 0.5 decision threshold per action slot

    print(f"probabillity: {probs[0, 0]}")
    print(f"action prediction: {actions[0, 0]}")
    print(f"ground truth: {action[0]}")

frame shape: torch.Size([1, 144, 3, 64, 64])
images shape: torch.Size([144, 3, 64, 64])
instruction: [CLS] charged attack [SEP]
probabillity: tensor([0.0190, 0.0090, 0.0090, 0.0021, 0.0055, 0.0060, 0.0066, 0.0067, 0.0101,
        1.0000, 1.0000, 0.0096], device='cuda:0')
action prediction: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.], device='cuda:0')
ground truth: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')


In [17]:
# Re-uses `frame` and the decoded `instruction` left in the kernel by the evaluation cell.
# Indexing with None prepends the same size-1 leading axis that cell passes to the agent.
agent.get_actions(frame[None], instruction)

tensor([[[0., 0., 0.,  ..., 1., 1., 0.],
         [0., 0., 0.,  ..., 1., 1., 0.],
         [0., 0., 0.,  ..., 1., 1., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 1., 0.],
         [0., 0., 0.,  ..., 1., 1., 0.],
         [0., 0., 0.,  ..., 1., 1., 0.]]], device='cuda:0')