In [1]:
pip install transformers datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTModel, ViTConfig, BertTokenizer, ViTImageProcessor
from datasets import load_dataset
from PIL import Image
from datasets import DatasetDict

import numpy as np
import random

In [3]:
# Load the dataset
dataset = load_dataset("tomytjandra/h-and-m-fashion-caption-12k")  # Replace with your dataset path or identifier

# Initialize the tokenizer (you can choose a different tokenizer if preferred)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Define image transformations
# image_transform = transforms.Compose([
#     transforms.Resize((224, 224)),  # ViT typically expects 224x224 images
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
#                          std=[0.229, 0.224, 0.225]),
# ])

processor = ViTImageProcessor().from_pretrained('google/vit-base-patch16-224-in21k')

# Preprocessing function
def preprocess_function(examples):
    # Process images
    images = [image.convert("RGB") for image in examples['image']]
    encoding = processor(images=images)
    examples['pixel_values'] = encoding['pixel_values']

    # Tokenize captions
    captions = examples['text']
    encoding = tokenizer(captions, padding='max_length', truncation=True, max_length=224)
    examples['input_ids'] = encoding['input_ids']
    examples['attention_mask'] = encoding['attention_mask']

    return examples

# Step 1: Split into train_val and test
train_val_split = dataset['train'].train_test_split(test_size=1250, seed=42)  # 10% for test
train_val = train_val_split['train']
test = train_val_split['test']

# Step 2: Split train_val into train and validation
train_validation_split = train_val.train_test_split(test_size=1250, seed=42)
train = train_validation_split['train']
validation = train_validation_split['test']

# Step 3: Create a new DatasetDict with the splits
new_dataset = DatasetDict({
    'train': train,
    'validation': validation,
    'test': test
})

# Optional: Verify the splits
processed_ds_train = new_dataset['train'].map(preprocess_function, batched=True, batch_size=100, remove_columns=['text', 'image'])
processed_ds_val = new_dataset['validation'].map(preprocess_function, batched=True, batch_size=100, remove_columns=['text', 'image'])
processed_ds_test = new_dataset['test'].map(preprocess_function, batched=True, batch_size=100, remove_columns=['text', 'image'])

processed_ds_train.set_format(
    type='torch',
    columns=['pixel_values', 'input_ids', 'attention_mask']
)

processed_ds_val.set_format(
    type='torch',
    columns=['pixel_values', 'input_ids', 'attention_mask']
)

processed_ds_test.set_format(
    type='torch',
    columns=['pixel_values', 'input_ids', 'attention_mask']
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/323 [00:00<?, ?B/s]

train-00000-of-00011.parquet:   0%|          | 0.00/478M [00:00<?, ?B/s]

train-00001-of-00011.parquet:   0%|          | 0.00/465M [00:00<?, ?B/s]

train-00002-of-00011.parquet:   0%|          | 0.00/418M [00:00<?, ?B/s]

train-00003-of-00011.parquet:   0%|          | 0.00/335M [00:00<?, ?B/s]

train-00004-of-00011.parquet:   0%|          | 0.00/331M [00:00<?, ?B/s]

train-00005-of-00011.parquet:   0%|          | 0.00/321M [00:00<?, ?B/s]

train-00006-of-00011.parquet:   0%|          | 0.00/307M [00:00<?, ?B/s]

train-00007-of-00011.parquet:   0%|          | 0.00/307M [00:00<?, ?B/s]

train-00008-of-00011.parquet:   0%|          | 0.00/319M [00:00<?, ?B/s]

train-00009-of-00011.parquet:   0%|          | 0.00/304M [00:00<?, ?B/s]

train-00010-of-00011.parquet:   0%|          | 0.00/297M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12437 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Map:   0%|          | 0/9937 [00:00<?, ? examples/s]

Map:   0%|          | 0/1250 [00:00<?, ? examples/s]

Map:   0%|          | 0/1250 [00:00<?, ? examples/s]

In [None]:
new_dataset['train'][4]['text']

'all over pattern black straight-cut kimono in woven fabric with a pattern at the hem cut-out sections under the arms and no fasteners'

In [None]:
processed_ds_train[4]['input_ids']

tensor([  101,  2035,  2058,  5418,  2304,  3442,  1011,  3013,  5035, 17175,
         1999, 17374,  8313,  2007,  1037,  5418,  2012,  1996, 19610,  3013,
         1011,  2041,  5433,  2104,  1996,  2608,  1998,  2053,  3435, 24454,
         2015,   102,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [None]:
processed_ds_train[4]['attention_mask']

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

In [4]:
# Define training parameters
batch_size = 32

# Create DataLoader for training
train_loader = DataLoader(processed_ds_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(processed_ds_val, batch_size=batch_size, shuffle=True)

In [13]:
class ImageCaptioningModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, decoder_layers, decoder_heads, decoder_ffn_dim, max_seq_length=224):
        super(ImageCaptioningModel, self).__init__() #initialize from parent .init()

        # Encoder: Vision Transformer (ViT)
        vit_config = ViTConfig.from_pretrained('google/vit-base-patch16-224-in21k')
        vit_config.num_hidden_layers = 6  # Reduce the number of layers to 4-6
        self.encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k', config=vit_config)

        # Decoder: Transformer
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=decoder_heads, dim_feedforward=decoder_ffn_dim)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=decoder_layers)

        # Embedding for input tokens
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, embed_dim)) # ??????????????

        # Final linear layer to generate vocabulary scores
        self.output_linear = nn.Linear(embed_dim, vocab_size)

        # Projection to match dimensions
        self.encoder_proj = nn.Linear(vit_config.hidden_size, embed_dim)

    def forward(self, pixel_values, input_ids, attention_mask):
        # Encoder
        encoder_outputs = self.encoder(pixel_values=pixel_values)
        encoder_hidden_states = encoder_outputs.last_hidden_state  # (batch_size, num_patches + 1, hidden_size)
        # Project encoder outputs to embed_dim
        encoder_proj = self.encoder_proj(encoder_hidden_states)  # (batch_size, seq_len, embed_dim)
        encoder_proj = encoder_proj.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)

        # Decoder
        embeddings = self.token_embedding(input_ids) + self.positional_encoding[:, :input_ids.size(1), :]
        embeddings = embeddings.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)

        # Create a causal mask for the decoder to prevent attending to future tokens
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(input_ids.size(1)).to(pixel_values.device)

        decoder_outputs = self.decoder(embeddings, encoder_proj, tgt_mask=tgt_mask, tgt_key_padding_mask=~attention_mask.bool())
        decoder_outputs = decoder_outputs.permute(1, 0, 2)  # (batch_size, seq_len, embed_dim)

        outputs = self.output_linear(decoder_outputs)  # (batch_size, seq_len, vocab_size)

        return outputs

In [16]:
# Define vocabulary size and other hyperparameters
vocab_size = tokenizer.vocab_size
embed_dim = 512
decoder_layers = 6  # 4-6 layers as per requirement
decoder_heads = 8
decoder_ffn_dim = 2048
max_seq_length = 224

# Initialize the model
model = ImageCaptioningModel(vocab_size, embed_dim, decoder_layers, decoder_heads, decoder_ffn_dim, max_seq_length)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Define separate learning rates
learning_rate_encoder = 1e-5  # Lower learning rate for pre-trained encoder
learning_rate_decoder = 1e-4  # Higher learning rate for decoder
# Create parameter groups
optimizer = optim.AdamW([
    {'params': model.encoder.parameters(), 'lr': learning_rate_encoder},
    {'params': model.decoder.parameters(), 'lr': learning_rate_decoder},
    {'params': model.token_embedding.parameters(), 'lr': learning_rate_decoder},
    {'params': model.encoder_proj.parameters(), 'lr': learning_rate_decoder},
    {'params': model.output_linear.parameters(), 'lr': learning_rate_decoder}
], betas=(0.9, 0.98), eps=1e-9)

# Initialize the scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                 factor=0.5, patience=2,
                                                 verbose=True, min_lr=1e-6)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTModel: ['encoder.layer.10.attention.attention.key.bias', 'encoder.layer.10.attention.attention.key.weight', 'encoder.layer.10.attention.attention.query.bias', 'encoder.layer.10.attention.attention.query.weight', 'encoder.layer.10.attention.attention.value.bias', 'encoder.layer.10.attention.attention.value.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.layernorm_after.bias', 'encoder.layer.10.layernorm_after.weight', 'encoder.layer.10.layernorm_before.bias', 'encoder.layer.10.layernorm_before.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.11.attention.attention.key.bias', 'encoder.layer.11.attention.attention.key.weight', 'encoder.layer.11.attention.attention.q

In [17]:
def generate_caption(model, image, processor ,tokenizer, device, max_length=224):
    model.eval()
    with torch.no_grad():
        # Encode image
        encoder_outputs = model.encoder(pixel_values=image.unsqueeze(0))
        encoder_hidden_states = model.encoder_proj(encoder_outputs.last_hidden_state).permute(1, 0, 2)

        # Initialize decoder input with [CLS] or <start> token
        # Assuming you have a <start> token, else use tokenizer.bos_token_id
        input_ids = torch.tensor([[tokenizer.cls_token_id]], device=device)  # (1, 1)
        generated = []

        for _ in range(max_length):
            embeddings = model.token_embedding(input_ids) + model.positional_encoding[:, :input_ids.size(1), :]
            embeddings = embeddings.permute(1, 0, 2)

            decoder_outputs = model.decoder(embeddings, encoder_hidden_states)
            decoder_outputs = decoder_outputs.permute(1, 0, 2)
            outputs = model.output_linear(decoder_outputs)

            # Get the last token
            next_token_logits = outputs[:, -1, :]  # (1, vocab_size)

            next_token = next_token_logits.argmax(dim=-1).unsqueeze(0)  # (1, 1)

            if next_token.item() == tokenizer.sep_token_id:
                break
            generated.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=1)

        caption = tokenizer.decode(generated, skip_special_tokens=True)
        return caption

In [18]:
num_epochs = 2

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    epoch_val_loss = 0
    for batch in train_loader:
        pixel_values = batch['pixel_values'].to(device)  # (batch_size, 3, 224, 224)
        input_ids = batch['input_ids'].to(device)        # (batch_size, seq_length)
        attention_mask = batch['attention_mask'].to(device)  # (batch_size, seq_length)

        # Shift input_ids and create labels
        # Typically, input_ids are shifted right for the decoder input
        # Labels are the actual tokens to predict
        labels = input_ids[:, 1:].contiguous()
        decoder_input_ids = input_ids[:, :-1].contiguous()
        decoder_attention_mask = attention_mask[:, :-1].contiguous()

        optimizer.zero_grad()

        outputs = model(pixel_values, decoder_input_ids, decoder_attention_mask)
        # outputs: (batch_size, seq_length -1, vocab_size)

        loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))
        loss.backward()

        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)

    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            pixel_values = batch['pixel_values'].to(device)  # (batch_size, 3, 224, 224)
            input_ids = batch['input_ids'].to(device)        # (batch_size, seq_length)
            attention_mask = batch['attention_mask'].to(device)  # (batch_size, seq_length)

            labels = input_ids[:, 1:].contiguous()
            decoder_input_ids = input_ids[:, :-1].contiguous()
            decoder_attention_mask = attention_mask[:, :-1].contiguous()

            outputs = model(pixel_values, decoder_input_ids, decoder_attention_mask)

            loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))

            epoch_val_loss += loss.item()

        for _ in range(3):
            index = random.randint(0, len(processed_ds_test)-1)
            image = processed_ds_test[index]['pixel_values'].to(device)
            caption = generate_caption(model, image, processor ,tokenizer, device)
            print(caption)

    image = processed_ds_test[0]['pixel_values'].to(device)
    caption = generate_caption(model, image, processor ,tokenizer, device)
    print(caption)


    avg_val_loss = epoch_val_loss / len(val_loader)

    scheduler.step(avg_val_loss)

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, val_loss: {avg_val_loss:.4f}, lr: {scheduler.get_last_lr()}")

solid white blouse in a cotton weave with a small stand - up collar and long sleeves with buttoned cuffs
solid white blouse in a cotton weave with a v - neck and short sleeves with a rounded hem
solid dark blue short dress in a crepe weave with a v - neck and long sleeves with a rounded hem
solid black calf - length skirt in a viscose weave with a concealed zip in the side and a rounded hem
Epoch 1/2, Loss: 3.2077, val_loss: 1.6956, lr: [1e-05, 0.0001, 0.0001, 0.0001, 0.0001]
all over pattern light blue blouse in a viscose weave with a small stand - up collar and buttons down the front long sleeves with buttoned cuffs and a rounded hem
all over pattern light beige short dress in a viscose weave with a v - neck and short puff sleeves with a seam at the waist with a concealed zip in one side unlined
solid black trousers in a relaxed fit with a high waist zip fly and straight wide legs with creases and button and straight wide legs
solid black calf - length skirt in a viscose weave with a

In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/VIT+decoder/VIT_decoder.pth')

In [None]:
model.load_state_dict(torch.load('/content/drive/MyDrive/VIT+decoder/VIT_decoder.pth'))
model.eval()  # Set the model to evaluation mode

  model.load_state_dict(torch.load('/content/drive/MyDrive/VIT+decoder/VIT_decoder.pth'))


ImageCaptioningModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-5): 6 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_featu

In [None]:
processed_ds_test[:5]['pixel_values'].shape

torch.Size([5, 3, 224, 224])

In [None]:
def generate_caption(model, image, processor ,tokenizer, device, max_length=224):
    model.eval()
    with torch.no_grad():
        # Encode image
        encoder_outputs = model.encoder(pixel_values=image.unsqueeze(0))
        encoder_hidden_states = model.encoder_proj(encoder_outputs.last_hidden_state).permute(1, 0, 2)

        # Initialize decoder input with [CLS] or <start> token
        # Assuming you have a <start> token, else use tokenizer.bos_token_id
        input_ids = torch.tensor([[tokenizer.cls_token_id]], device=device)  # (1, 1)
        generated = []

        for _ in range(max_length):
            embeddings = model.token_embedding(input_ids) + model.positional_encoding[:, :input_ids.size(1), :]
            embeddings = embeddings.permute(1, 0, 2)

            decoder_outputs = model.decoder(embeddings, encoder_hidden_states)
            decoder_outputs = decoder_outputs.permute(1, 0, 2)
            outputs = model.output_linear(decoder_outputs)

            # Get the last token
            next_token_logits = outputs[:, -1, :]  # (1, vocab_size)

            next_token = next_token_logits.argmax(dim=-1).unsqueeze(0)  # (1, 1)

            if next_token.item() == tokenizer.sep_token_id:
                break
            generated.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=1)

        caption = tokenizer.decode(generated, skip_special_tokens=True)
        return caption

# Example usage:
image = processed_ds_test[0]['pixel_values'].to(device)
caption = generate_caption(model, image, processor ,tokenizer, device)
print(caption)

RuntimeError: shape '[197, 8, 64]' is invalid for input of size 201728

In [None]:
#inference
#tgt_mask nad what it mean to attent only to previous tokens (look from MLCV lectures)
#save on each epoch or each N epochs