In [2]:
from dataset import GameplayActionPairVideoDataset
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import torch
from utils import custom_collate_fn
from model.agent import Agent, device
from torch import nn, optim
from model.action_loss import ActionLoss

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
root_dir = "output_logs"

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = GameplayActionPairVideoDataset(root_dir=root_dir, tokenizer=tokenizer)



In [4]:
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f44c5f972b0>

In [5]:
agent = Agent(debug=False).to(device)

In [6]:
criterion = ActionLoss()
optimizer = optim.Adam(agent.parameters(), lr=0.001)

In [28]:
num_epochs = 100

In [29]:
for epoch in range(num_epochs):
    for batch, (instruction, frame, action) in enumerate(dataloader):
        frame = frame.to(device)
        action = action.to(device)
        _, _, channel, height, width = frame.shape
        images = frame.reshape(-1, channel, height, width).to(device)
        instruction = tokenizer.decode(*instruction)
        optimizer.zero_grad()
        logits = agent(images, frame, instruction)
        # print(f"logits: {logits}")
        # print(f"action: {action}")
        loss = criterion(logits, action)
        loss.backward()
        optimizer.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/100], Loss: 78.3218
Epoch [2/100], Loss: 79.8710
Epoch [3/100], Loss: 78.2609
Epoch [4/100], Loss: 87.6718
Epoch [5/100], Loss: 78.4673
Epoch [6/100], Loss: 79.8512
Epoch [7/100], Loss: 1375.9509
Epoch [8/100], Loss: 215.3219
Epoch [9/100], Loss: 87.6747
Epoch [10/100], Loss: 87.3105
Epoch [11/100], Loss: 77.1046
Epoch [12/100], Loss: 215.3221
Epoch [13/100], Loss: 77.1018
Epoch [14/100], Loss: 77.1186
Epoch [15/100], Loss: 162.0769
Epoch [16/100], Loss: 79.8704
Epoch [17/100], Loss: 77.1018
Epoch [18/100], Loss: 215.3273
Epoch [19/100], Loss: 78.4712
Epoch [20/100], Loss: 78.4694
Epoch [21/100], Loss: 204.5791
Epoch [22/100], Loss: 1375.9490
Epoch [23/100], Loss: 77.1107
Epoch [24/100], Loss: 78.4658
Epoch [25/100], Loss: 1375.9666
Epoch [26/100], Loss: 87.1038
Epoch [27/100], Loss: 78.2594
Epoch [28/100], Loss: 204.5778
Epoch [29/100], Loss: 204.5791
Epoch [30/100], Loss: 77.1107
Epoch [31/100], Loss: 77.1064
Epoch [32/100], Loss: 87.3092
Epoch [33/100], Loss: 79.8815
Epoch 

In [5]:
(instruction, frame, ground_truth) = dataset[0]
frame = frame.unsqueeze(0).to(device)
instruction = instruction.to(device)
instruction = tokenizer.decode(instruction)
ground_truth = ground_truth.to(device)
action = agent.get_actions(frame, instruction)

NameError: name 'agent' is not defined

In [6]:
model = Agent().to(device)
model.load_state_dict(torch.load('model_weights.pth'))

<All keys matched successfully>

In [7]:
model.eval()
(instruction, frame, ground_truth) = dataset[10]
frame = frame.unsqueeze(0).to(device)
instruction = instruction.to(device)
instruction = tokenizer.decode(instruction)
ground_truth = ground_truth.to(device)
action = model.get_actions(frame, instruction)

In [12]:
INDEX = 2
print(f"action: {action[:, :, INDEX]}")
print(f"ground_truth: {ground_truth[:, INDEX]}")

action: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
ground_truth: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0