#Data

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!mkdir -p /content/drive/MyDrive/IC/ViT

In [3]:
!pip install transformers > /dev/null

In [4]:
# Tiny dataset
# !gdown https://drive.google.com/uc?id=1qYPCnXXxjEcHEg3tLGt3fDkd2ialAgS4

# Full dataset with jpeg
!gdown https://drive.google.com/uc?id=1-xJoBvzwQKgJjPzHb3fq1sFicwyIisx7

# Full dataset without jpeg
# !gdown https://drive.google.com/uc?id=1gFSdm8K9SXNPXG9tQWS4bmO_nappN2AL
!unzip data_v1.zip -d /content/data > /dev/null

Downloading...
From: https://drive.google.com/uc?id=1-xJoBvzwQKgJjPzHb3fq1sFicwyIisx7
To: /content/data_v1.zip
100% 640M/640M [00:03<00:00, 206MB/s]


In [5]:
import json
data = json.load(open("/content/data/train_data.json", "r"))

In [6]:
data['annotations'][0]

{'id': 0,
 'image_id': 2,
 'caption': 'ba chiếc thuyền đang di chuyển ở trên con sông',
 'segment_caption': 'ba chiếc thuyền đang di_chuyển ở trên con sông'}

#Vocab

In [7]:
import random
import numpy as np
import torch
import os
def set_random_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
set_random_seed(10)

In [8]:
from collections import Counter
import itertools
from itertools import count

class IMCP_Vocab():
  def __init__(self, texts) -> None:
    words = list(itertools.chain(*[text.split(" ") for text in texts]))
    counter = Counter(words)
    self.vocab = {key: i for i, key in zip(count(start = 4), counter.keys())}
    self.special_ids = [0, 1, 2, 3]
    self.max_seq_len = 256
    self.counter = counter
    self.special_tokens = ["<unk>", "<pad>", "<bos>", "<eos>"]
    for id, token in zip(self.special_ids, self.special_tokens):
      self.vocab[token] = id
    self.vocab = {k: v for k, v in sorted(self.vocab.items(), key=lambda x:x[1])}
    self.id2word = {v: k for k, v in self.vocab.items()}
    
    self.bos_token = "<bos>"
    self.eos_token = "<eos>"
    self.pad_token = "<pad>"
    self.unk_token = "<unk>"
    
  def get_vocab(self):
    return self.vocab

  def get_vocab_dump(self):
    vocab = dict()
    vocab['itos'] = list(vocab.keys())
    vocab['stoi'] = self.vocab
    vocab['freqs'] = dict(self.counter)
    return vocab
  
  def batch_decode(self, predictions_ids):
    preds = []
    for seq in predictions_ids:
        preds.append(" ".join([self.id2word[id] for id in seq if id not in [0,1,2,3,4,5]]))
    return preds

# Dataset

In [9]:
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import json
import os

class IMCP_Dataset(Dataset):
  def __init__(self, image_path = "/content/data/train-images", summary_path = "/content/data/train_data.json"):
    super().__init__()
    self.data = json.load(open(summary_path, "r"))
    self.image_path = image_path
    self.vocab = IMCP_Vocab(texts = [ann['segment_caption'] for ann in self.data['annotations']])
    self.imgid2imgname = {entry['id']: entry['filename'] for entry in self.data['images']}

  def __len__(self):
    return len(self.data['annotations'])

  def __getitem__(self, index):
    annotation = self.data['annotations'][index]
    image_id = annotation['image_id']
    # images = self.data['images'][index]
    # image_id = images['id']
    image_name = self.imgid2imgname[image_id]
    image = Image.open(os.path.join(self.image_path, image_name)).convert('RGB')
    caption = annotation['segment_caption']
    return image, caption

In [10]:
# class IMCP_Test_Dataset(Dataset):
#   def __init__(self, image_path = "/content/data/public-test-images", summary_path = "/content/data/test_data.json"):
#     super().__init__()
#     self.data = json.load(open(summary_path, "r"))
#     self.image_path = image_path
#     self.imgid2imgname = {entry['id']: entry['filename'] for entry in self.data['images']}

#   def __len__(self):
#     return len(self.data['images'])

#   def __getitem__(self, index):
#     entry = self.data['images'][index]
#     image_id = entry['id']
#     image_name = entry['filename']
#     image = Image.open(os.path.join(self.image_path, image_name)).convert('RGB')
#     caption = [self.data['annotations'][i]['segment_caption'] for i in range(len(self.data['annotations'])) if self.data['annotations'][i]['image_id'] == image_id]
#     return image, caption, image_id

In [11]:
train_dataset = IMCP_Dataset()
vocab = train_dataset.vocab
train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [0.9, 0.1])
# test_dataset = IMCP_Test_Dataset()

In [12]:
# Save vocab to file
with open("/content/drive/MyDrive/IC/ViT/vocab.json", 'w+') as file:
  json.dump(vocab.get_vocab_dump(), file, ensure_ascii = False)

In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Collator

In [14]:
from torch.nn.utils.rnn import pad_sequence
import torchvision.transforms as transforms

class IMCP_Collator:
  def __init__(self, vocab, train = True):
    self.vocab = vocab
    self.bos_id = self.vocab.get_vocab()['<bos>']
    self.eos_id = self.vocab.get_vocab()['<eos>']
    self.pad_id = self.vocab.get_vocab()['<pad>']
    self.train = train

  def tokenize_texts(self, captions):
    raw_captions = [caption.split(" ") for caption in captions]
    truncated_captions = [s[:self.vocab.max_seq_len] for s in raw_captions]
    max_len = max([len(c) for c in truncated_captions])

    padded_captions = []
    for c in truncated_captions:
        c = [self.vocab.get_vocab()[word] for word in c]
        seq = [self.bos_id] + c + [self.eos_id] + [self.pad_id] * (max_len - len(c))
        padded_captions.append(seq)

    padded_captions = [torch.Tensor(seq).long() for seq in padded_captions]
    padded_captions = pad_sequence(padded_captions, batch_first=True)
    return padded_captions
  
  def resize_and_stack(self, images):
    image_tensors = []
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    
    for image in images:
      img_tensor = transform(image)
      image.close()
      image_tensors.append(img_tensor)
      
    stacked = torch.stack(image_tensors)
    return stacked

  def __call__(self, batch):
    if self.train:
      images = [example[0] for example in batch]
      captions = [example[1] for example in batch]
      return self.resize_and_stack(images), self.tokenize_texts(captions)
    else:
      images = [example[0] for example in batch]
      captions = [example[1] for example in batch]
      image_ids = [example[2] for example in batch]
      return self.resize_and_stack(images), captions, image_ids
    

In [15]:
collator = IMCP_Collator(vocab, train = True)
# collatorTest = IMCP_Collator(vocab, train = False, model = "resnet101")

# DataLoader

In [16]:
train_dataloader = DataLoader(train_dataset, batch_size = 16, collate_fn = collator)
valid_dataloader = DataLoader(valid_dataset, batch_size = 16, collate_fn = collator, shuffle = False)
# test_dataloader = DataLoader(test_dataset, batch_size = 16, collate_fn = collatorTest, shuffle = False)

In [17]:
for images, captions in train_dataloader:
  print(images.shape)
  print(captions)
  break

torch.Size([16, 3, 224, 224])
tensor([[   2,   13,   33,   21,    7,  441, 1713,  830,  799,   37,  482,    3,
            1,    1,    1],
        [   2,   13,    4,  675,   11,  853,   40,    9,   37,  240,  362,    3,
            1,    1,    1],
        [   2,   13,   16,  171,  341,  610,   98,  211,    9,   15,  155,   28,
          311,  239,    3],
        [   2,   13,   21,   22,  162,  163,  214,    7,  409,   89,    3,    1,
            1,    1,    1],
        [   2,    9,   10,  105,   13,   46,   40,   28,   16,  952,  239,  643,
            3,    1,    1],
        [   2,   16,   21,   58,    7,  407,  408,    9,   37,  103,  223,    3,
            1,    1,    1],
        [   2,   15,  130,    5,  586,   98,  181,   77,    5,  586,   98,  218,
          181,    3,    1],
        [   2,   13,   33,  377,  320,  101,  255,   15,   37,  166,   51,    3,
            1,    1,    1],
        [   2,  381,  352,  125,  126,   99,  303,  139,   80,    4,   21,    3,
            1,   

#Model

In [18]:
config = {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "architectures": [
    "MBartModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "d_model": 512,
  "decoder_attention_heads": 8,
  "decoder_ffn_dim": 2048,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 2,
  "decoder_start_token_id": 1,
  "dropout": 0.1,
  "encoder_attention_heads": 8,
  "encoder_ffn_dim": 2048,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 2,
  "eos_token_id": 1,
  "forced_eos_token_id": 1,
  "gradient_checkpointing": False,
  "init_std": 0.02,
  "is_encoder_decoder": True,
  "max_position_embeddings": 256,
  "model_type": "mbart",
  "num_hidden_layers": 2,
  "pad_token_id": 2,
  "scale_embedding": False,
  "torch_dtype": "float32",
  "transformers_version": "4.10.2",
  "use_cache": True,
  "vocab_size": len(vocab.get_vocab()) + 5
}

In [19]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn.utils.rnn import pack_padded_sequence

from transformers import ViTModel, get_cosine_schedule_with_warmup
from transformers import BartForConditionalGeneration, BartConfig

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define Vision Transformer encoder
class Encoder(nn.Module):
    def __init__(self, embed_size):
        super(Encoder, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.linear = nn.Linear(self.vit.config.hidden_size, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, images):
        with torch.no_grad():
            features = self.vit(images)
        features = self.linear(features.last_hidden_state)
        features = self.relu(features)
        features = self.dropout(features)
        return features

# Define Transformer decoder
class Decoder(nn.Module):
    def __init__(self, max_seq_length):
        super(Decoder, self).__init__()
        self.config = BartConfig(**config)
        self.decoder = BartForConditionalGeneration(self.config)
        self.max_seq_length = max_seq_length
    
    def forward(self, features, captions):
        outputs = self.decoder(
            encoder_outputs = [features], ## Use [] instead of () Important!
            labels = captions)
        return outputs

    def generate(self, features, max_length = 50):
        model_kwargs = {
            "encoder_outputs": [features], ## Use [] instead of () Important!
        }
        input_ids = torch.ones((len(features), 1), device=device, dtype=torch.long)
        input_ids = input_ids * 1
        output = self.decoder.generate(input_ids, num_beams=1, max_new_tokens = max_length, **model_kwargs)
        return output


# Define hyperparameters
vocab_size = len(collator.vocab.get_vocab()) + 5
embed_size = 512
max_seq_length = 30
learning_rate = 0.0004
num_epochs = 8
total_step = len(train_dataloader) * num_epochs

# Initialize encoder and decoder
encoder = Encoder(embed_size).to(device)
decoder = Decoder(max_seq_length).to(device)

# Define loss function and optimizer
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = Adam(params, lr=learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps = 0.1 * total_step, num_training_steps = total_step)

# Train the model
for epoch in range(num_epochs):
    print(f"Start Epoch {epoch}")
    for i, (images, captions) in enumerate(train_dataloader):
        # Move data to GPU
        images = images.to(device)
        captions = captions.to(device)
        
        # Forward pass
        features = encoder(images)
        outputs = decoder(features, captions)
        loss = outputs.loss
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        
        # Print loss
        if i % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], LR [{:.8f}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_dataloader), lr_scheduler.get_last_lr()[0] ,loss.item()))
    # Validate
    with torch.no_grad():
        valid_loss = []
        for i, (images, captions) in enumerate(valid_dataloader):
        # Move data to GPU
            images = images.to(device)
            captions = captions.to(device)
            # Forward pass
            features = encoder(images)
            outputs = decoder(features, captions)
            loss = outputs.loss
            valid_loss.append(loss.item())
        
        print('Epoch [{}/{}], Valid Loss: {:.4f}'.format(epoch+1, num_epochs, np.mean(valid_loss)))

Downloading (…)lve/main/config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Start Epoch 0
Epoch [1/8], Step [1/1061], LR [0.00000047], Loss: 7.6281
Epoch [1/8], Step [101/1061], LR [0.00004760], Loss: 4.4379
Epoch [1/8], Step [201/1061], LR [0.00009472], Loss: 3.0580
Epoch [1/8], Step [301/1061], LR [0.00014185], Loss: 2.1143
Epoch [1/8], Step [401/1061], LR [0.00018897], Loss: 1.9737
Epoch [1/8], Step [501/1061], LR [0.00023610], Loss: 2.0824
Epoch [1/8], Step [601/1061], LR [0.00028322], Loss: 1.6413
Epoch [1/8], Step [701/1061], LR [0.00033035], Loss: 1.6595
Epoch [1/8], Step [801/1061], LR [0.00037747], Loss: 1.7640
Epoch [1/8], Step [901/1061], LR [0.00039995], Loss: 1.4516
Epoch [1/8], Step [1001/1061], LR [0.00039961], Loss: 1.1891
Epoch [1/8], Valid Loss: 1.5061
Start Epoch 1
Epoch [2/8], Step [1/1061], LR [0.00039923], Loss: 2.0127
Epoch [2/8], Step [101/1061], LR [0.00039834], Loss: 1.5428
Epoch [2/8], Step [201/1061], LR [0.00039712], Loss: 1.4295
Epoch [2/8], Step [301/1061], LR [0.00039556], Loss: 1.2808
Epoch [2/8], Step [401/1061], LR [0.0003936

In [20]:
torch.save(encoder.state_dict(), '/content/drive/MyDrive/IC/ViT/encoder.pth')
torch.save(decoder.state_dict(), '/content/drive/MyDrive/IC/ViT/decoder.pth')