## vit-gpt2 (pre-trained)

In [1]:
import os
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor

class ChartCaptioner:
    def __init__(
        self, 
        model_name="nlpconnect/vit-gpt2-image-captioning", 
        max_length=128
    ):
        self.model = VisionEncoderDecoderModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
        
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'left' 
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        
        self.max_length = max_length
    
    def generate_caption(self, image_path):
        """
        Generate caption with explicit attention mask handling
        """
        image = Image.open(image_path).convert('RGB')
        
        # Prepare image inputs
        pixel_values = self.feature_extractor(
            images=image, 
            return_tensors="pt"
        ).pixel_values.to(self.device)
        

        decoder_input_ids = torch.zeros((1, 1), dtype=torch.long, device=self.device)
        decoder_input_ids[0, 0] = self.tokenizer.bos_token_id
        
        decoder_attention_mask = torch.ones_like(decoder_input_ids)
        
        with torch.no_grad():
            outputs = self.model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                max_length=self.max_length,
                num_beams=8,
                early_stopping=True,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8
            )
        
        try:
            caption = self.tokenizer.decode(
                outputs[0], 
                skip_special_tokens=True
            )
            
            caption = caption.strip()
            caption = ' '.join(caption.split())  
            
            return caption if caption else "No meaningful caption generated."
        
        except Exception as e:
            print(f"Error decoding caption: {e}")
            return "Caption generation failed."
    

captioner = ChartCaptioner()
single_image_caption = captioner.generate_caption("ImageList/7.png")
print("Single Image Caption:", single_image_caption)


Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "architectures": [
    "ViTModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.46.3"
}

Config of the decoder: <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> is overwritten by shared decoder config: GPT2Config {
  "activation_function": "gelu_new",
  "add_cross_attention": true,
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "decoder_start_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_rang

model.safetensors:   0%|          | 0.00/982M [00:00<?, ?B/s]

Single Image Caption: a collage of photos of various types of electronic equipment


## vit-gpt2 (fine-tuned)

In [2]:
import os
import pandas as pd
from PIL import Image
import torch
import re
from torch.utils.data import Dataset, DataLoader
from transformers import VisionEncoderDecoderModel, BertTokenizer, ViTFeatureExtractor
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

data = pd.read_csv("captions.csv")

def clean_caption(caption):
    caption = caption.lower()
    caption = re.sub(r'\b(figure|fig)\s*\d+(\.\d+)?', '', caption, flags=re.IGNORECASE)
    caption = re.sub(r"[^a-z0-9A-Z\s]", "", caption)
    caption = re.sub(r"\s+", " ", caption).strip()
    return caption

class ChartCaptionDataset(Dataset):
    def __init__(self, csv_file, img_folder, tokenizer, feature_extractor, max_length=128):
        self.data = pd.read_csv(csv_file)
        self.data['full_caption'] = self.data['full_caption'].apply(clean_caption)
        self.img_folder = img_folder
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_id = self.data.iloc[idx]['imageid']
        caption = self.data.iloc[idx]['full_caption']
        img_path = os.path.join(self.img_folder, f"{img_id}.png")

        image = Image.open(img_path).convert('RGB')
        pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values

        inputs = self.tokenizer(
            caption,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        labels = inputs.input_ids.squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'pixel_values': pixel_values.squeeze(),
            'labels': labels
        }

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

dataset = ChartCaptionDataset(
    csv_file="captions.csv",
    img_folder="ImageList",
    tokenizer=tokenizer,
    feature_extractor=feature_extractor
)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    learning_rate=10e-5,
    evaluation_strategy="no",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_dir="./logs",
    output_dir="./vit_finetune",
    num_train_epochs=3,
    save_steps=500,
    save_total_limit=2,
    fp16=True
)
    
def collate_fn(batch):
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = torch.stack([item['labels'] for item in batch])
    return {
        'pixel_values': pixel_values,
        'labels': labels
    }

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)

trainer.train()
model.save_pretrained("./vit_finetune")

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "architectures": [
    "ViTModel"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.46.3"
}

Config of the decoder: <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> is overwritten by shared decoder config: GPT2Config {
  "activation_function": "gelu_new",
  "add_cross_attention": true,
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "decoder_start_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_rang

Step,Training Loss
500,6.6324


In [3]:
def generate_caption(image_path):
    image = Image.open(image_path).convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(model.device)

    output_ids = model.generate(pixel_values, max_length=200,do_sample=True,top_k=50,top_p=0.95)
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption

test_image_path = "ImageList/7.png"
generated_caption = generate_caption(test_image_path)
print(f"Generated Caption: {generated_caption}")

Generated Caption: ##ingen estimated technologies blue as a and non different of global sea surface temperature anomalies percentage a hemisphere global mean temperature by the blacks from land for a of land for global mean temperature red and c change and b2 emissions for two to 2100 from different b2 emissions due to 2100 the range of the vertical red temperature degc from the scenario estimates a anomalies percentage lines change to the scenario estimates of the fi in the mean of theal the mean of the mean of the mean are shown from the mean of a et al of the mean in the bars on the mean of the c is a the mean of the climate of the mean of the mean of range of the mean of
