In [1]:
from data.dataset import GameplayActionPairVideoDataset
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import torch
from tools.utils import custom_collate_fn
from model.agent import Agent, device, AgentConfig
from torch import nn, optim
from model.action_loss import ActionLoss
from model.cvivit import CvivitConfig
from model.encoder import MultiModelEncoderConfig
from model.decoder import MultiModelDecoderConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
root_dir = "output_logs"

In [3]:
dataset = GameplayActionPairVideoDataset(root_dir=root_dir, image_size=(224, 224))

In [4]:
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x79274ff267a0>

In [5]:
config = AgentConfig(
    encoder_config=MultiModelEncoderConfig(
        vit_model_name='google/vit-base-patch16-224-in21k',
        language_model_name='bert-base-uncased',
        cvivit_config=CvivitConfig(
            image_size=224,
            color_channel=3,
            emb_size=768,
            d_model=768,
            patch_size=(2, 8, 8),
            num_layers_spatial=2,
            num_heads_spatial=4,
            dim_feedforward_spatial=512,
            dropout_spatial=0.1,
            num_layers_temporal=2,
            num_heads_temporal=4,
            dim_feedforward_temporal=512,
            dropout_temporal=0.1
        )
    ),
    decoder_config=MultiModelDecoderConfig(
        d_model=768,
        dim_feedforward=512,
        nhead=4,
        num_layers=2
    )
)

In [6]:
agent = Agent(config=config, debug=False).to(device)



In [7]:
criterion = ActionLoss()
optimizer = optim.Adam(agent.parameters(), lr=0.001)

In [8]:
agent.get_model_size() 

{'size_all_mb': 869.2451820373535, 'parameter_size': 911461444}

In [9]:
num_epochs = 100

In [10]:
for epoch in range(num_epochs):
    agent.train()
    for batch, (instruction, frames, action) in enumerate(dataloader):
        optimizer.zero_grad()
        frames = frames.to(device)
        action = action.to(device)
        # print(f"batch: {batch}")
        # print(f"instruction: {instruction}")
        # print(f"frames shape: {frames.shape}")
        # print(f"action shape: {action.shape}")
        logits = agent(instruction, frames, action)
        loss = criterion(logits, action)
        loss.backward()
        optimizer.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/100], Loss: 112.6025
Epoch [2/100], Loss: 109.6787
Epoch [3/100], Loss: 109.4236
Epoch [4/100], Loss: 107.5775
Epoch [5/100], Loss: 108.2580
Epoch [6/100], Loss: 106.5706
Epoch [7/100], Loss: 104.3582
Epoch [8/100], Loss: 104.3898
Epoch [9/100], Loss: 106.3629
Epoch [10/100], Loss: 106.5839
Epoch [11/100], Loss: 108.6181
Epoch [12/100], Loss: 109.2675
Epoch [13/100], Loss: 107.6633
Epoch [14/100], Loss: 103.5963
Epoch [15/100], Loss: 99.3512
Epoch [16/100], Loss: 106.6611
Epoch [17/100], Loss: 108.3289
Epoch [18/100], Loss: 114.1079
Epoch [19/100], Loss: 106.7798
Epoch [20/100], Loss: 104.5087
Epoch [21/100], Loss: 112.2035
Epoch [22/100], Loss: 103.0787
Epoch [23/100], Loss: 104.2693
Epoch [24/100], Loss: 106.4524
Epoch [25/100], Loss: 106.1944
Epoch [26/100], Loss: 102.4339
Epoch [27/100], Loss: 105.6391
Epoch [28/100], Loss: 106.7806
Epoch [29/100], Loss: 109.8615
Epoch [30/100], Loss: 107.0872
Epoch [31/100], Loss: 105.4555
Epoch [32/100], Loss: 107.0690
Epoch [33/100], Lo

In [35]:
(instruction, frames, ground_truth) = dataset[17]
frames = frames.unsqueeze(0).to(device)
ground_truth = ground_truth.to(device)
action = agent.get_actions(frames, instruction)

In [51]:
INDEX = 0
print(f"action shape: {action.shape}")
print(f"action: {action[:, :, INDEX]}")
print(f"ground_truth: {ground_truth[:, INDEX]}")

action shape: torch.Size([1, 45, 12])
action: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
ground_truth: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
