In [1]:
from data.dataset import GameplayActionPairVideoDataset
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import torch
from tools.utils import custom_collate_fn
from model.agent import Agent, device, AgentConfig
from torch import nn, optim
from model.action_loss import ActionLoss
from model.cvivit import CvivitConfig
from model.encoder import MultiModelEncoderConfig
from model.decoder import MultiModelDecoderConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
root_dir = "output_logs"

In [3]:
dataset = GameplayActionPairVideoDataset(root_dir=root_dir, image_size=(224, 224))

In [4]:
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x79cebc5ea470>

In [5]:
config = AgentConfig(
    encoder_config=MultiModelEncoderConfig(
        vit_model_name='google/vit-base-patch16-224-in21k',
        language_model_name='bert-base-uncased',
        cvivit_config=CvivitConfig(
            image_size=224,
            color_channel=3,
            emb_size=768,
            d_model=768,
            patch_size=(2, 8, 8),
            num_layers_spatial=2,
            num_heads_spatial=4,
            dim_feedforward_spatial=512,
            dropout_spatial=0.1,
            num_layers_temporal=2,
            num_heads_temporal=4,
            dim_feedforward_temporal=512,
            dropout_temporal=0.1
        )
    ),
    decoder_config=MultiModelDecoderConfig(
        d_model=768,
        dim_feedforward=512,
        nhead=4,
        num_layers=2
    )
)

In [6]:
agent = Agent(config=config, debug=False).to(device)



In [7]:
criterion = ActionLoss()
optimizer = optim.Adam(agent.parameters(), lr=0.001)

In [8]:
agent.get_model_size() 

{'size_all_mb': 869.2070960998535, 'parameter_size': 911421508}

In [9]:
num_epochs = 1

In [10]:
for epoch in range(num_epochs):
    agent.train()
    for batch, (instruction, frames, action) in enumerate(dataloader):
        frames = frames.to(device)
        action = action.to(device)
        _, _, channel, height, width = frames.shape
        images = frames.reshape(-1, channel, height, width).to(device)
        logits = agent(images, frames, instruction)
        loss = criterion(logits, action)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/1], Loss: 155.0792


In [11]:
(instruction, frame, ground_truth) = dataset[1]
frame = frame.unsqueeze(0).to(device)
ground_truth = ground_truth.to(device)
action = agent.get_actions(frame, instruction)

In [12]:
INDEX = 0
print(f"action shape: {action.shape}")
print(f"action: {action[:, :, INDEX]}")
print(f"ground_truth: {ground_truth[:, INDEX]}")

action shape: torch.Size([1, 26, 12])
action: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
ground_truth: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.], device='cuda:0')
