In [1]:
from typing import Any, Union, List

import torch
import torch.nn as nn

from coca_transformer import VisionTransformer

In [2]:
_visual_input = torch.randn(2, 3, 32, 32)

## CoCa Image Encoder

In [3]:
coca_visual = VisionTransformer(
    image_size = 32,
    in_channels = 3,
    patch_size = 4,
    width = 128,
    layers = 4,
    heads = 8,
    mlp_ratio = 4,
    attentional_pool = True
)
output = coca_visual(_visual_input)

In [4]:
for _out in output:
    print(_out.shape)

torch.Size([2, 512])
torch.Size([2, 255, 512])


In [5]:
from coca_transformer import Attention

attn = Attention(dim = 128)
attn(torch.randn(2, 4, 128)).shape

torch.Size([2, 4, 128])

## CoCa Text Encoder

In [6]:
from coca_transformer import TextTransformer

text_trn = TextTransformer()
txt_output = text_trn(torch.randint(0, 128, (2, 77)))

In [7]:
txt_output[0].shape, txt_output[1].shape

(torch.Size([2, 512]), torch.Size([2, 77, 512]))

In [8]:
from coca_tokenzier import SimpleTokenizer, decode, tokenize

tokenzier = SimpleTokenizer()
result = tokenize(
    tokenzier,
    ["This image is nuisance, with low signal.",
     "This image is defect, with low signal.",
     "abda haos aoej"]
)

In [9]:
result[2]

tensor([49406,   596,  1140,   560,  1299,  7313, 25009, 49407,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0], dtype=torch.int32)

## CoCa Final

In [10]:
from coca_cfg import CLIPTextCfg, CLIPVisionCfg, MultimodalCfg
from coca_model import _build_vision_tower, _build_text_tower, _build_text_decoder_tower
from coca_model import CoCa

In [11]:
clip_visual_cfg = CLIPVisionCfg(in_channels = 4)
coca_visual = _build_vision_tower(512, clip_visual_cfg)
visual_output = coca_visual(torch.randn(3, 4, 32, 32))
for _out in visual_output:
    print(_out.shape)

torch.Size([3, 512])
torch.Size([3, 255, 512])


In [12]:
clip_text_cfg = CLIPTextCfg()
coca_text = _build_text_tower(512, clip_text_cfg)
text_output = coca_text(torch.randint(0, 128, (3, 77)))
for _out in text_output:
    print(_out.shape)

torch.Size([3, 512])
torch.Size([3, 77, 512])


In [13]:
clip_multi_cfg = MultimodalCfg()
coca_decoder = _build_text_decoder_tower(512, clip_multi_cfg)
coca_output = coca_decoder(visual_output[1], text_output[1])
for _out in coca_output:
    print(_out.shape)

torch.Size([77, 512])
torch.Size([77, 512])
torch.Size([77, 512])


In [14]:
coca_main_model = CoCa(
    embed_dim = 512,
    multimodal_cfg = MultimodalCfg(),
    text_cfg = CLIPTextCfg(),
    vision_cfg = CLIPVisionCfg(in_channels = 4)
)

In [15]:
coca_result = coca_main_model(
    image = torch.randn(2, 4, 32, 32),
    text = torch.randint(0, 40000, (2, 77))
)

In [16]:
for k, v in coca_result.items():
    print(f"{k}: {v.shape}")

image_features: torch.Size([2, 512])
text_features: torch.Size([2, 512])
logits: torch.Size([2, 76, 49408])
labels: torch.Size([2, 76])
logit_scale: torch.Size([])


In [17]:
coca_txt_result = coca_main_model.generate(torch.randn(1, 4, 32, 32) * 0.2)
for txt_result in coca_txt_result:
    print(decode(txt_result).split("<|endoftext|>")[0].replace("<|startoftext|>", ""))

############################!


## CoCa Loss

In [18]:
from coca_loss import CoCaLoss

In [19]:
coca_criterion = CoCaLoss(caption_loss_weight = 0.5, clip_loss_weight = 0.5)
losses = coca_criterion(
    image_features = coca_result['image_features'], 
    text_features = coca_result['text_features'], 
    logits = coca_result['logits'], 
    labels = coca_result['labels'], 
    logit_scale = coca_result['logit_scale'],
    output_dict = True
)
total_loss = sum(losses.values())
total_loss

tensor(5.6957, grad_fn=<AddBackward0>)

## CoCa Training

In [20]:
from torch.utils.data import Dataset, DataLoader

class CoCaDataset(Dataset):
    def __init__(self):
        self.image = torch.randn(50, 4, 32, 32)
        self.text = torch.randint(0, 40000, (200, 77))

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):
        return self.image[idx], self.text[idx]

In [21]:
from coca_train import train_one_epoch


coca_main_model = CoCa(
    embed_dim = 512,
    multimodal_cfg = MultimodalCfg(),
    text_cfg = CLIPTextCfg(),
    vision_cfg = CLIPVisionCfg(in_channels = 4)
)
caca_dataloader = DataLoader(
    CoCaDataset(),
    batch_size = 10
)
optimizer = torch.optim.SGD(coca_main_model.parameters(), lr = 1e-3, momentum = 0.95)

train_one_epoch(
    model = coca_main_model, 
    dataloader = caca_dataloader, 
    loss = coca_criterion, 
    epoch = 1, 
    optimizer = optimizer, 
    scaler = None
)


Train Epoch: 1, Data (t): 0.001
contrastive_loss: 1.181,caption_loss: 5.404,loss: 6.585,
Train Epoch: 1, Data (t): 0.020
contrastive_loss: 1.234,caption_loss: 5.404,loss: 6.638,
Train Epoch: 1, Data (t): 0.029
contrastive_loss: 1.250,caption_loss: 5.404,loss: 6.654,
Train Epoch: 1, Data (t): 0.032
contrastive_loss: 1.238,caption_loss: 5.404,loss: 6.642,
Train Epoch: 1, Data (t): 0.033
contrastive_loss: 1.241,caption_loss: 5.404,loss: 6.645,
