<a href="https://colab.research.google.com/github/ranjith-d7/style-transfer-app/blob/main/finetune_BLIP_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

We will fine-tune a vision-language model (BLIP) to generate captions describing charts (bar/pie charts) with trends and summaries. BLIP (“Bootstrapping Language-Image Pre-training”) is a state-of-the-art captioning model from Salesforce.


In [1]:
!pip install transformers datasets


Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl 

In [2]:
import torch
from datasets import Dataset, Features, Value, Image
from transformers import BlipProcessor, BlipForConditionalGeneration, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
import matplotlib.pyplot as plt
from PIL import Image as PILImage


In [3]:
charts = []
captions = []

# Example 1: Bar chart (ascending)
plt.figure(figsize=(4,3))
x = ['Q1','Q2','Q3','Q4']
y = [1, 3, 6, 9]
plt.bar(x, y, color='skyblue')
plt.title("Sales by Quarter")
plt.ylabel("Sales")
plt.savefig('bar_up.png')
plt.close()
charts.append('bar_up.png')
captions.append("Bar chart showing steadily increasing sales from Q1 to Q4.")

# Example 2: Bar chart (descending)
plt.figure(figsize=(4,3))
x = ['2016','2017','2018','2019']
y = [9, 6, 3, 1]
plt.bar(x, y, color='lightgreen')
plt.title("Downloads by Year")
plt.ylabel("Number of downloads (millions)")
plt.savefig('bar_down.png')
plt.close()
charts.append('bar_down.png')
captions.append("Bar chart showing decreasing downloads from 2016 to 2019.")

# Example 3: Pie chart
plt.figure(figsize=(4,3))
sizes = [50, 25, 15, 10]
labels = ['Category A','Category B','Category C','Category D']
plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90)
plt.title("Market Share Distribution")
plt.savefig('pie1.png')
plt.close()
charts.append('pie1.png')
captions.append("Pie chart showing market share, with Category A as the largest slice.")

# Example 4: Pie chart (different)
plt.figure(figsize=(4,3))
sizes = [40, 30, 20, 10]
labels = ['Red','Blue','Green','Yellow']
plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=140)
plt.title("Color Distribution")
plt.savefig('pie2.png')
plt.close()
charts.append('pie2.png')
captions.append("Pie chart showing distribution of colors: Red and Blue are the largest segments.")

# Assemble into a dataset (features: image and text)
features = Features({"image": Image(decode=True), "text": Value("string")})
data = {"image": charts, "text": captions}
dataset = Dataset.from_dict(data, features=features)
# Split into train and test
dataset = dataset.train_test_split(test_size=0.2)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]


In [5]:
from transformers import BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
prefix = "A chart showing"

def preprocess_function(examples):
    # examples["image"] is already a PIL Image, examples["text"] is the caption
    encoding = processor(
        images=examples["image"],
        text=[prefix] * len(examples["image"]),
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="np"
    )
    # copy input_ids → labels so the model computes loss
    encoding["labels"] = encoding["input_ids"].copy()
    return encoding

# apply to both train and eval
train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=["image","text"])
eval_dataset  = eval_dataset.map (preprocess_function, batched=True, remove_columns=["image","text"])



Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

In [6]:
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")


config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

In [9]:
data_collator = DataCollatorForSeq2Seq(tokenizer=processor.tokenizer, model=model, label_pad_token_id=-100)
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    logging_steps=10,
    remove_unused_columns=False # Instead, pass it to TrainingArguments
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=processor.tokenizer
)


  trainer = Seq2SeqTrainer(


In [10]:
trainer.train()




<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mranjithdhanasekaran[0m ([33mranjithdhanasekaran-b-s-abdir-rahman-crescent-institute-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss


TrainOutput(global_step=6, training_loss=10.868080139160156, metrics={'train_runtime': 47.1688, 'train_samples_per_second': 0.191, 'train_steps_per_second': 0.127, 'total_flos': 5340812870025216.0, 'train_loss': 10.868080139160156, 'epoch': 3.0})

In [27]:
test_image = PILImage.open('bar_up.png')
inputs = processor(images=test_image, text="A chart showing", return_tensors="pt")

# Move the input tensors to the same device as the model's weights
# Assuming your model is on a CUDA device, if not, replace 'cuda' with 'cpu'
inputs = inputs.to(model.device)

outputs = model.generate(pixel_values=inputs.pixel_values)
caption = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated caption:", caption)



Generated caption: a bar graph with the same number of saless
