In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
from transformers import VivitConfig, VivitModel, ViTModel, ViTConfig, BertModel, BertTokenizer, BertConfig, TransfoXLModel

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class MultiModelEncoder(nn.Module):
    def __init__(self, vit_model_name='google/vit-base-patch16-224-in21k',
                 vivit_model_name='google/vivit-b-16x2-kinetics400',
                 bert_model_name='bert-base-uncased'):
        super(MultiModelEncoder, self).__init__()

        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.vivit = VivitModel.from_pretrained(vivit_model_name)
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)

    def forward(self, image, video, text):
        vit_outputs = self.vit(image)
        vit_pooled_output = vit_outputs.pooler_output
        vit_last_hidden_state = vit_outputs.last_hidden_state

        vivit_outputs = self.vivit(video)
        vivit_pooled_output = vivit_outputs.pooler_output
        vivit_last_hidden_state = vivit_outputs.last_hidden_state

        text_encoding = self.bert_tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=128
        )
        input_ids = text_encoding['input_ids'].to(device)
        attention_mask = text_encoding['attention_mask'].to(device)
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        bert_pooled_output = bert_outputs.pooler_output
        bert_last_hidden_state = bert_outputs.last_hidden_state

        concatenated_output = torch.cat((vit_pooled_output, vivit_pooled_output, bert_pooled_output), dim=1)
        last_hidden_states = torch.cat((vit_last_hidden_state, vivit_last_hidden_state, bert_last_hidden_state), dim=1)

        return { 
            "concatenated_output": concatenated_output,
            "last_hidden_states": last_hidden_states
        }

In [5]:
class MultiModelDecoder(nn.Module):
    def __init__(self, d_model=2304,
                 nhead=8,
                 num_layers=12
                ):
        super(MultiModelDecoder, self).__init__()
        
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_model = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

    def forward(self, encoder_last_state, attension_state):
        transformer_outputs = self.transformer_model(encoder_last_state, attension_state)
        return transformer_outputs

In [6]:
class Policy(nn.Module):
    def __init__(self, d_model, hidden_size, num_actions):
        super(Policy, self).__init__()
        self.policy = nn.Sequential(
            nn.Linear(d_model, hidden_size),
            nn.Dropout(0.1, False),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(0.1, False),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, num_actions)
        )
        
    def forward(self, input):
        return self.policy(input)

In [7]:
class Agent(nn.Module):
    def __init__(self):
        super(Agent, self).__init__()
        self.encoder = MultiModelEncoder()
        self.decoder = MultiModelDecoder()
        self.policy = Policy(2304, 2048, 10)

    def forward(self, image, video, text):
        encoder_embed = self.encoder(image, video, text)
        decoder_embed = self.decoder(encoder_embed['concatenated_output'], encoder_embed['concatenated_output'])
        logits = self.policy(decoder_embed)
        return logits

    def get_model_size(self):
        param_size = 0
        buffer_size = 0
        for param in self.parameters():
            param_size += param.numel() * param.element_size()
        for buffer in self.buffers():
            buffer_size += buffer.numel() * buffer.element_size()

        size_all_mb = (param_size + buffer_size) / 1024**2
        return { 
            "size_all_mb" : size_all_mb,
            "parameter_size" : param_size
        }

In [8]:
agent = Agent()
agent.to(device)

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized: ['vivit.pooler.dense.bias', 'vivit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Agent(
  (encoder): MultiModelEncoder(
    (vit): ViTModel(
      (embeddings): ViTEmbeddings(
        (patch_embeddings): ViTPatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): ViTEncoder(
        (layer): ModuleList(
          (0-11): 12 x ViTLayer(
            (attention): ViTAttention(
              (attention): ViTSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): ViTSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
            )
            (intermediate): 

In [9]:
with torch.inference_mode():
    image = torch.randn(1, 3, 224, 224).to(device)
    video = torch.randn(1, 32, 3, 224, 224).to(device)
    text = ["This is a sample sentence for the text encoder."]
    logits = agent(image, video, text)

In [10]:
logits.shape

torch.Size([1, 10])

In [11]:
model_size = agent.get_model_size()
model_size

{'size_all_mb': 3499.4111709594727, 'parameter_size': 3669390376}

In [12]:
print(f'Model size: {model_size["size_all_mb"]:.2f} MB')
print(f'Model parameter: {model_size["parameter_size"]}')

Model size: 3499.41 MB
Model parameter: 3669390376


In [13]:
optimizer = torch.optim.Adam(agent.parameters())
criterion = torch.nn.CrossEntropyLoss()