In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [9]:
from transformers import ViTModel, ViTConfig, BertModel, BertTokenizer, BertConfig, TransfoXLModel

In [11]:
class CvivitConfig:
    def __init__(
        self,
        image_size=128,
        color_channel=3,
        num_frames=32,
        emb_size=768,
        d_model=768,
        patch_size=(2, 8, 8),
        num_layers_spatial=4,
        num_heads_spatial=8,
        dim_feedforward_spatial=2048,
        dropout_spatial=0.1,
        num_layers_temporal=4,
        num_heads_temporal=8,
        dim_feedforward_temporal=2048,
        dropout_temporal=0.1,
    ):
        self.image_size = image_size
        self.color_channel = color_channel
        self.num_frames = num_frames
        self.emb_size = emb_size
        self.d_model = d_model
        self.patch_size = patch_size
        self.num_layers_spatial = num_layers_spatial
        self.num_heads_spatial = num_heads_spatial
        self.dim_feedforward_spatial = dim_feedforward_spatial
        self.dropout_spatial = dropout_spatial
        self.num_layers_temporal = num_layers_temporal
        self.num_heads_temporal = num_heads_temporal
        self.dim_feedforward_temporal = dim_feedforward_temporal
        self.dropout_temporal = dropout_temporal

In [12]:
class PolicyConfig:
    def __init__(
        self,
        d_model=768,
        hidden_size=2048,
        num_actions=10
    ):
        self.d_model = d_model
        self.hidden_size = hidden_size
        self.num_actions = num_actions

In [13]:
class MultiModelEncoder(nn.Module):
    def __init__(self, vit_model_name='google/vit-base-patch16-224-in21k',
                 bert_model_name='bert-base-uncased',
                 cvivit_config: CvivitConfig=CvivitConfig()):
        super(MultiModelEncoder, self).__init__()

        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.cvivit = VideoTransformerModel(
            video_dimension=(cvivit_config.color_channel, cvivit_config.image_size, cvivit_config.image_size),
            emb_size=cvivit_config.emb_size,
            d_model=cvivit_config.d_model,
            patch_size=cvivit_config.patch_size,
            num_layers_spatial=cvivit_config.num_layers_spatial,
            num_heads_spatial=cvivit_config.num_heads_spatial,
            dim_feedforward_spatial=cvivit_config.dim_feedforward_spatial,
            dropout_spatial=cvivit_config.dropout_spatial,
            num_layers_temporal=cvivit_config.num_layers_temporal, 
            num_heads_temporal=cvivit_config.num_heads_temporal,
            dim_feedforward_temporal=cvivit_config.dim_feedforward_temporal,
            dropout_temporal=cvivit_config.dropout_temporal
        )
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)

    def forward(self, image, video, text):
        vit_outputs = self.vit(image)
        vit_last_hidden_state = vit_outputs.last_hidden_state

        cvivit = self.cvivit(video)
        
        text_encoding = self.bert_tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=128
        )
        input_ids = text_encoding['input_ids'].to(device)
        attention_mask = text_encoding['attention_mask'].to(device)
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        instruction_last_hidden_state = bert_outputs.last_hidden_state

        
        print(f"Vit last_hidden_states shape: {vit_last_hidden_state.shape}")
        print(f"Video last_hidden_states shape: {cvivit.shape}")
        print(f"Text last_hidden_states shape: {instruction_last_hidden_state.shape}")
        vision_last_hidden_states = torch.cat((vit_last_hidden_state, cvivit), dim=1)
        print(f"Token last_hidden_states shape: {vision_last_hidden_states.shape}")
        return {
            "instruction_last_hidden_state": instruction_last_hidden_state,
            "vision_last_hidden_states": vision_last_hidden_states
        }

In [14]:
class MultiModelDecoder(nn.Module):
    def __init__(
        self,
        d_model=768,
        nhead=8,
        num_layers=4
    ):
        super(MultiModelDecoder, self).__init__()
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer=self.decoder_layer, num_layers=num_layers)

    def forward(self, instruction, memory):
        x = self.decoder(instruction, memory)
        print(f"Decoder outputs shape: {x.shape}")
        return x

In [15]:
class Policy(nn.Module):
    def __init__(self, policy_config=PolicyConfig()):
        super(Policy, self).__init__()
        self.policy = nn.Sequential(
            nn.Linear(policy_config.d_model, policy_config.hidden_size),
            nn.Dropout(0.1, False),
            nn.LayerNorm(policy_config.hidden_size),
            nn.Linear(policy_config.hidden_size, policy_config.hidden_size),
            nn.Dropout(0.1, False),
            nn.LayerNorm(policy_config.hidden_size),
            nn.Linear(policy_config.hidden_size, policy_config.num_actions)
        )
        
    def forward(self, input):
        return self.policy(input)

In [19]:
class Agent(nn.Module):
    def __init__(self):
        super(Agent, self).__init__()
        self.encoder = MultiModelEncoder()
        self.decoder = MultiModelDecoder()
        self.policy = Policy()

    def forward(self, image, video, text):
        encoder_outputs = self.encoder(image, video, text)
        decoder_outputs = self.decoder(encoder_outputs['instruction_last_hidden_state'], encoder_outputs['vision_last_hidden_states'])
        logits = self.policy(decoder_outputs)
        return logits

    def get_model_size(self):
        param_size = 0
        buffer_size = 0
        for param in self.parameters():
            param_size += param.numel() * param.element_size()
        for buffer in self.buffers():
            buffer_size += buffer.numel() * buffer.element_size()

        size_all_mb = (param_size + buffer_size) / 1024**2
        return { 
            "size_all_mb" : size_all_mb,
            "parameter_size" : param_size
        }

In [20]:
with torch.inference_mode():
    agent = Agent()
    agent.to(device)
    image = torch.randn(2, 3, 224, 224).to(device)
    video = torch.randn(2, 32, 3, 128, 128).to(device)
    text = ["This is a sample sentence for the text encoder.", "This is a sample sentence for the text encoder."]
    logits = agent(image, video, text)

NameError: name 'TransformerEncoder3D' is not defined

In [11]:
logits, logits.shape

(tensor([[[ 0.2085, -1.3850,  0.0412,  0.0765,  0.3234, -0.7414,  0.7565,
            1.0119,  0.5515,  0.8427],
          [ 0.2577, -1.6940,  0.4986,  0.5504,  0.5008, -0.8456,  0.0920,
            0.9436,  0.9753,  1.0063],
          [-0.7221, -0.8581,  0.4347,  0.0464,  0.5096, -0.5257,  0.0034,
            0.2765,  1.2155,  0.5176],
          [ 0.3543, -1.2693,  0.1108,  0.8030,  1.0233, -0.6481,  0.8123,
            0.5214,  0.1852,  1.5182],
          [ 0.1190, -1.3624,  0.0703,  0.5090,  0.5268, -1.0735,  0.7735,
           -0.0508,  0.5317,  1.0983],
          [ 0.3523, -0.7707,  0.0971,  0.3184,  0.8897, -0.7063,  0.0821,
            0.4055,  0.4633,  0.8298],
          [ 0.2188, -1.1613,  1.3514,  0.0692,  0.0376, -1.0174,  0.1084,
            0.6880,  1.5006,  1.2779],
          [-0.6426, -1.0456,  0.6534,  0.6634,  0.8785, -0.4172,  0.5121,
            0.1043,  0.7147,  0.7445],
          [ 0.0431, -1.9833,  0.2333,  0.0348,  0.8297, -0.8826,  0.4583,
            0.3087,  0

In [12]:
model_size = agent.get_model_size()
model_size

{'size_all_mb': 1098.5605850219727, 'parameter_size': 1151916072}

In [13]:
print(f'Model size: {model_size["size_all_mb"]:.2f} MB')
print(f'Model parameter: {model_size["parameter_size"]}')

Model size: 1098.56 MB
Model parameter: 1151916072
