In [3]:
import requests
import torch
from PIL import Image
from io import BytesIO
from transformers import AutoProcessor, AutoModelForImageTextToText


In [5]:
!gdown 1FMVcFM78XZE1KE1rIkGBpCdcdI58S1LB

Downloading...
From (original): https://drive.google.com/uc?id=1FMVcFM78XZE1KE1rIkGBpCdcdI58S1LB
From (redirected): https://drive.google.com/uc?id=1FMVcFM78XZE1KE1rIkGBpCdcdI58S1LB&confirm=t&uuid=57d5de12-677e-451c-a8d9-c5e221c83f55
To: /home/radahn/Sachish/try/DL-Ass-2/custom_captions_dataset.zip
100%|████████████████████████████████████████| 286M/286M [00:11<00:00, 24.8MB/s]


In [6]:
# unzip to Datasets
!mkdir -p Datasets
!unzip custom_captions_dataset.zip -d Datasets

Archive:  custom_captions_dataset.zip
   creating: Datasets/custom_captions_dataset/
   creating: Datasets/custom_captions_dataset/train/
  inflating: Datasets/custom_captions_dataset/train/train_3.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_10.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_11.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_12.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_15.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_35.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_38.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_46.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_49.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_51.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_67.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_68.jpg  
  inflating: Datasets/custom_captions_dataset/train/train_7

In [None]:
# create a dataloader to load the image path from the dataset along with its caption
from torch.utils.data import Dataset, DataLoader
import os
import random
from torchvision import transforms
from tqdm import tqdm
import pandas as pd

class CustomDataset(Dataset):

    def __init__(self, root_dir, transform=None, split='train'):

        self.split = split
        self.root_dir = root_dir
        self.transform = transform
        self.caption_file = os.path.join(root_dir, f'{split}.csv')
        self.image_dir = os.path.join(root_dir, f'{split}')
        self.image_paths = []
        self.captions = []

        df = pd.read_csv(self.caption_file)

        for index, row in tqdm(df.iterrows(), total=len(df)):
            image_path = os.path.join(self.image_dir, row['filename'])
            caption = row['caption']

            self.image_paths.append(image_path)
            self.captions.append(caption)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        caption = self.captions[idx]

        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        return image, caption

In [30]:
test_dataset = CustomDataset(
    root_dir='Datasets/custom_captions_dataset',
    split='test'
)

100%|██████████| 928/928 [00:00<00:00, 68554.42it/s]


In [37]:
from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM

In [63]:
class CustomModel(torch.nn.Module):

    def __init__(self, image_encoder, text_decoder):

        super(CustomModel, self).__init__()

        self.processor = AutoProcessor.from_pretrained(image_encoder)
        self.image_model = AutoModel.from_pretrained(image_encoder)

        self.tokenizer = AutoTokenizer.from_pretrained(text_decoder)

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.text_model = AutoModelForCausalLM.from_pretrained(text_decoder)

        self.projection = torch.nn.Linear(
            self.image_model.config.hidden_size,
            self.text_model.config.n_embd
        )
        
        # Additional adaptation layers for image features
        self.image_adapter = torch.nn.Sequential(
            torch.nn.Linear(self.image_model.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Linear(512, self.image_model.config.hidden_size)
        )
        
        # Prefix token embedding to signal the text model to generate a caption
        self.prefix_embedding = torch.nn.Parameter(
            torch.randn(1, 1, self.text_model.config.n_embd)
        )
        
        # Global image context processor
        self.global_context = torch.nn.Sequential(
            torch.nn.Linear(self.image_model.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Linear(512, self.text_model.config.n_embd)
        )

    def forward(self, images, labels=None):

        # Process images
        inputs = self.processor(images, return_tensors="pt").to(images[0].device)
        image_outputs = self.image_model(**inputs)
        patch_embeddings = image_outputs.last_hidden_state  # shape: (B, N+1, D)
        
        # Get CLS token for global context
        cls_token = patch_embeddings[:, 0:1, :]
        global_image_context = self.global_context(cls_token)
        
        # Apply image adapter to patch embeddings
        adapted_embeddings = self.image_adapter(patch_embeddings) + patch_embeddings
        # adapted_embeddings = patch_embeddings
        
        proj_patch_embeddings = self.projection(adapted_embeddings)  # shape: (B, N+1, D)
        
        # Setup as prefix to text model (using CLS token + other visual tokens)
        batch_size = proj_patch_embeddings.size(0)
        prefix_expanded = self.prefix_embedding.expand(batch_size, -1, -1)
        
        # Combine prompt token and image embeddings
        combined_input_embeds = torch.cat([
            prefix_expanded,                    # Caption start token
            global_image_context,               # Global image context
            proj_patch_embeddings[:, 1:, :]     # Visual tokens (excluding CLS)
        ], dim=1)
        
        # Create attention mask allowing attending to all prefix tokens
        extended_attention_mask = torch.ones(
            (combined_input_embeds.size(0), combined_input_embeds.size(1)),
            dtype=torch.long,
            device=images.device
        )
        
        # Training mode
        if labels is not None:

            label_tokens = self.tokenizer(
                labels,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=128
            ).to(images.device)
            
            outputs = self.text_model(
                inputs_embeds=combined_input_embeds,
                attention_mask=extended_attention_mask,
                labels=label_tokens.input_ids,
                return_dict=True
            )
            
            return outputs.loss, outputs.logits
        
        # Inference mode
        else:

            outputs = self.text_model.generate(
                inputs_embeds=combined_input_embeds,
                attention_mask=extended_attention_mask,
                max_new_tokens=128,  # Use max_new_tokens instead of max_length
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=2,
                temperature=0.7
            )
            
            # Decode the generated IDs
            generated_captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            
            return generated_captions


In [64]:
custom_model = CustomModel(
    image_encoder='WinKawaks/vit-small-patch16-224',
    text_decoder='openai-community/gpt2'
).to('cuda')

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [66]:

image = Image.open('/home/radahn/Sachish/try/DL-Ass-2/images/vgg16.png').convert("RGB")
transform = transforms.Compose([
    transforms.ToTensor(),
])
image = transform(image).unsqueeze(0).to('cuda')

dummy_image = Image.new('RGB', (224, 224), color='blue')
transform = transforms.Compose([
    transforms.ToTensor(),
])
dummy_image = transform(dummy_image).unsqueeze(0).to('cuda')

dummy_caption = "A blue square."

images = [dummy_image for _ in range(2)]
captions = [dummy_caption for _ in range(2)]

imag_list = []
imag_list.append(image)
custom_model.forward(
    images=images,
    labels = captions
)

ValueError: axes don't match array