# IDEFICS: A Flamingo-based model, trained at scale for the community
# Finetuning Demo Notebook:

Credit: [Flamingo blog](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)

This google colab notebook shows how to run predictions with the 4-bit quantized 🤗 [Idefics-9B model](https://huggingface.co/HuggingFaceM4/idefics-9b) and finetune it on a specific dataset.

[IDEFICS](https://huggingface.co/HuggingFaceM4/idefics-80b) is a multi-modal model based on the [Flamingo](https://arxiv.org/abs/2204.14198) architecture. It can take images and texts as input and return text outputs but it does not support image generation. \\
IDEFICS is built on top of two unimodal open-access pre-trained models to connect the two modalities. Newly initialized parameters in the form of Transformer blocks bridge the gap between the vision encoder and the language model. The model is trained on a mixture of image/text pairs and unstrucutred multimodal web documents. \\
The [finetuned versions](https://huggingface.co/HuggingFaceM4/idefics-80b-instruct) of IDEFICS behave like LLM chatbots while also understanding visual input. \\
You can play with the [demo here](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground)

The code for this notebook was contributed to by *Léo Tronchon, Younes Belkada, and Stas Bekman*, the IDEFICS model has been contributed to by: *Lucile Saulnier, Léo Tronchon, Hugo Laurençon, Stas Bekman, Amanpreet Singh, Siddharth Karamcheti, and Victor Sanh*

# Install and import necessary libraries

In [None]:
!pip install -q datasets
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q bitsandbytes sentencepiece accelerate loralib
!pip install -q -U git+https://github.com/huggingface/peft.git

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/547.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━[0m [32m501.8/547.8 kB[0m [31m16.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.1/316.1 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━

In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from PIL import Image
from transformers import IdeficsForVisionText2Text, AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig
import torchvision.transforms as transforms

# Load quantized model
First get the quantized version of the model. This will allow us to use the 9B version of Idefics with a single 16GB gpu



In [None]:
from google.colab import output
output.enable_custom_widget_manager()
!pip install huggingface_hub

import os
os.environ['HF_TOKEN'] = '' # change token




In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# checkpoint = "HuggingFaceM4/tiny-random-idefics"
checkpoint = "HuggingFaceM4/idefics-9b"

# Here we skip some special modules that can't be quantized properly
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_skip_modules=["lm_head", "embed_tokens"],
)

processor = AutoProcessor.from_pretrained(checkpoint, use_auth_token=True)
# Simply take-off the quantization_config arg if you want to load the original model
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, quantization_config=bnb_config, device_map="auto")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/281 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.36k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/61.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.41k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/99.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/19 [00:00<?, ?it/s]

model-00001-of-00019.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

model-00002-of-00019.safetensors:   0%|          | 0.00/1.82G [00:00<?, ?B/s]

model-00003-of-00019.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00004-of-00019.safetensors:   0%|          | 0.00/1.93G [00:00<?, ?B/s]

model-00005-of-00019.safetensors:   0%|          | 0.00/1.93G [00:00<?, ?B/s]

model-00006-of-00019.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00007-of-00019.safetensors:   0%|          | 0.00/1.89G [00:00<?, ?B/s]

model-00008-of-00019.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00009-of-00019.safetensors:   0%|          | 0.00/1.93G [00:00<?, ?B/s]

model-00010-of-00019.safetensors:   0%|          | 0.00/1.93G [00:00<?, ?B/s]

model-00011-of-00019.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00012-of-00019.safetensors:   0%|          | 0.00/1.89G [00:00<?, ?B/s]

model-00013-of-00019.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00014-of-00019.safetensors:   0%|          | 0.00/1.93G [00:00<?, ?B/s]

model-00015-of-00019.safetensors:   0%|          | 0.00/1.93G [00:00<?, ?B/s]

model-00016-of-00019.safetensors:   0%|          | 0.00/1.97G [00:00<?, ?B/s]

model-00017-of-00019.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00018-of-00019.safetensors:   0%|          | 0.00/1.97G [00:00<?, ?B/s]

model-00019-of-00019.safetensors:   0%|          | 0.00/705M [00:00<?, ?B/s]

Instantiating IdeficsAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

If you print the model, you will see that all `nn.Linear` layers are in fact replaced by `bnb.nn.Linear4bit` layers.

In [None]:
print(model)

IdeficsForVisionText2Text(
  (model): IdeficsModel(
    (embed_tokens): IdeficsDecoupledEmbedding(
      num_embeddings=32000, num_additional_embeddings=2, embedding_dim=4096, partially_freeze=False
      (additional_embedding): Embedding(2, 4096)
    )
    (vision_model): IdeficsVisionTransformer(
      (embeddings): IdeficsVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(257, 1280)
      )
      (pre_layrnorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (encoder): IdeficsVisionEncoder(
        (layers): ModuleList(
          (0-31): 32 x IdeficsVisionEncoderLayer(
            (self_attn): IdeficsVisionAttention(
              (k_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
              (v_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
              (q_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
        

# Check inference

In [None]:
def convert_to_rgb(image):
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    return alpha_composite.convert("RGB")


def check_inference(model, processor, prompts, max_new_tokens=5):
    image_size = processor.image_processor.image_size
    image_mean = processor.image_processor.image_mean
    image_std = processor.image_processor.image_std

    image_transform = transforms.Compose([
        convert_to_rgb,
        transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=image_mean, std=image_std),
    ])


    tokenizer = processor.tokenizer
    bad_words = ["<image>", "<fake_token_around_image>"]
    if len(bad_words) > 0:
        bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids

    eos_token = "</s>"
    eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

    inputs = processor(prompts,transform=image_transform, return_tensors="pt").to(device)
    generated_ids = model.generate(**inputs, eos_token_id=[eos_token_id], bad_words_ids=bad_words_ids, max_new_tokens=max_new_tokens, early_stopping=True)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

# Data preprocess

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import pandas as pd
import random
from sklearn.metrics import accuracy_score, f1_score

base_path = '/content/drive/My Drive/Dataset'

dev_path = os.path.join(base_path, 'dev')
test_path = os.path.join(base_path, 'test')
training_path = os.path.join(base_path, 'train')

dev_folder = [f for f in os.listdir(dev_path) if os.path.isdir(os.path.join(dev_path, f))]
test_folder = [f for f in os.listdir(test_path) if os.path.isdir(os.path.join(test_path, f))]
training_folder = [f for f in os.listdir(training_path) if os.path.isdir(os.path.join(training_path, f))]


In [None]:
test_folder

['CCAH',
 'ACCFP',
 'CCUIM',
 'CCSAD',
 'SCCC',
 'EIB',
 'GGCC',
 'TICC',
 'WICC',
 'EWCC']

In [None]:
dev_folder

['HCCAB',
 'HUSNS',
 'SAPFS',
 'HRDCC',
 'MACC',
 'CCIAP',
 'CICC',
 'EFCC',
 'FIJI',
 'CCGFS']

In [None]:
len(training_folder)

80

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.5.0,>=2023.1.0 (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.5.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from datasets import Dataset
dev_image_caption_pairs = []

for folder in dev_folder:
    folder_path = os.path.join(dev_path, folder)
    print(folder_path)
    if os.path.isdir(folder_path):
        frames_folder = None
        csv_file = None
        for item in os.listdir(folder_path):
            if item.endswith("_frames"):
                frames_folder = item
            elif item == f"{folder}.csv":
                csv_file = item

        if frames_folder and csv_file:
            frames_path = os.path.join(folder_path, frames_folder)
            csv_file_path = os.path.join(folder_path, csv_file)
            print(csv_file_path)

            annotations = pd.read_csv(csv_file_path)

            counter = 1
            for index, row in annotations.iterrows():
              image_file = os.path.join(frames_path, f'{folder}-{counter:03d}.jpg')
              if os.path.exists(image_file) and image_file.endswith('.jpg'):
                    caption = row.iloc[-1]
                    # print(caption)
                    if pd.isna(caption) or caption.endswith(']') or caption.strip() == '':
                        caption = 'NONE'
                    else:
                        caption = caption.lower()

                    label = row['label'] if 'label' in row else 'None'

                    counter += 1
                    dev_image_caption_pairs.append({"image": image_file, "caption": caption, "label": int(label)})

dev_ds = Dataset.from_pandas(pd.DataFrame(dev_image_caption_pairs))


/content/drive/My Drive/Dataset/dev/HCCAB
/content/drive/My Drive/Dataset/dev/HCCAB/HCCAB.csv
/content/drive/My Drive/Dataset/dev/HUSNS
/content/drive/My Drive/Dataset/dev/HUSNS/HUSNS.csv
/content/drive/My Drive/Dataset/dev/SAPFS
/content/drive/My Drive/Dataset/dev/SAPFS/SAPFS.csv
/content/drive/My Drive/Dataset/dev/HRDCC
/content/drive/My Drive/Dataset/dev/HRDCC/HRDCC.csv
/content/drive/My Drive/Dataset/dev/MACC
/content/drive/My Drive/Dataset/dev/MACC/MACC.csv
/content/drive/My Drive/Dataset/dev/CCIAP
/content/drive/My Drive/Dataset/dev/CCIAP/CCIAP.csv
/content/drive/My Drive/Dataset/dev/CICC
/content/drive/My Drive/Dataset/dev/CICC/CICC.csv
/content/drive/My Drive/Dataset/dev/EFCC
/content/drive/My Drive/Dataset/dev/EFCC/EFCC.csv
/content/drive/My Drive/Dataset/dev/FIJI
/content/drive/My Drive/Dataset/dev/FIJI/FIJI.csv
/content/drive/My Drive/Dataset/dev/CCGFS
/content/drive/My Drive/Dataset/dev/CCGFS/CCGFS.csv


In [None]:
dev_image_caption_pairs[0]

{'image': '/content/drive/My Drive/Dataset/dev/HCCAB/HCCAB_frames/HCCAB-001.jpg',
 'caption': 'how climate change affects biodiversity',
 'label': 0}

In [None]:
from collections import Counter

labels = [d['label'] for d in dev_image_caption_pairs]

label_counts = Counter(labels)

print(label_counts)

Counter({1: 204, 2: 130, 0: 83})


In [None]:
from collections import Counter

labels = [d['label'] for d in dev_image_caption_pairs]

label_counts = Counter(labels)

print(label_counts)

Counter({1: 204, 2: 130, 0: 83})


In [None]:
print(len(dev_image_caption_pairs))

417


In [None]:
train_image_caption_pairs = []

for folder in training_folder:
    folder_path = os.path.join(training_path, folder)
    if os.path.isdir(folder_path):
        frames_folder = None
        csv_file = None
        for item in os.listdir(folder_path):
            if item.endswith("_frames"):
                frames_folder = item
            elif item == f"{folder}.csv":
                csv_file = item

        if frames_folder and csv_file:
            frames_path = os.path.join(folder_path, frames_folder)
            csv_file_path = os.path.join(folder_path, csv_file)
            print(csv_file_path)

            annotations = pd.read_csv(csv_file_path)

            counter = 1
            for index, row in annotations.iterrows():
              image_file = os.path.join(frames_path, f'{folder}-{counter:03d}.jpg')
              if os.path.exists(image_file) and image_file.endswith('.jpg'):
                    caption = row.iloc[-1]
                    print(caption)
                    if pd.isna(caption) or caption.endswith(']') or caption.strip() == '':
                        caption = 'NONE'
                    else:
                        caption = caption.lower()

                    label = row['label'] if 'label' in row else 'None'


                    counter += 1
                    train_image_caption_pairs.append({"image": image_file, "caption": caption, "label": int(label)})
                    # train_image_caption_pairs.append((image_file, caption, row[0]))

train_ds = Dataset.from_pandas(pd.DataFrame(train_image_caption_pairs))


/content/drive/My Drive/Dataset/train/HCCIG/HCCIG.csv
How climate change influences geopolitics – Interview with Francesco Femia
climate change can exacerbate the drivers of migration other drivers of migration whether that's you know food or water stress
climate change can exacerbate the drivers of migration other drivers of migration whether that's you know food or water stress
climate change can exacerbate the drivers of migration other drivers of migration whether that's you know food or water stress 
that is a big problem
major migration flows refugee flows can have a significant regional and international security impact 
those are some of the new risks we need to really be worried about 
but climate change also impacts the existing geopolitical environment as we know it 
in the Arctic obviously it's creating a new ocean in the South China Sea we see fish stocks moving north fishing fleets moving into contested waters so you have tension between China and its neighbors and that b

In [None]:
len(train_image_caption_pairs)

3372

In [None]:
from collections import Counter

labels = [d['label'] for d in train_image_caption_pairs]

label_counts = Counter(labels)

print(label_counts)
print(len(train_image_caption_pairs))

Counter({1: 1449, 0: 1036, 2: 887})
3372


In [None]:
from collections import Counter

labels = [d['label'] for d in train_image_caption_pairs]

label_counts = Counter(labels)

print(label_counts)
print(len(train_image_caption_pairs))

Counter({1: 1449, 0: 1036, 2: 887})
3372


In [23]:
test_image_caption_pairs = []

for folder in test_folder:
    folder_path = os.path.join(test_path, folder)
    if os.path.isdir(folder_path):
        frames_folder = None
        csv_file = None
        for item in os.listdir(folder_path):
            if item.endswith("_frames"):
                frames_folder = item
            elif item == f"{folder}.csv":
                csv_file = item

        if frames_folder and csv_file:
            frames_path = os.path.join(folder_path, frames_folder)
            csv_file_path = os.path.join(folder_path, csv_file)
            print(csv_file_path)

            annotations = pd.read_csv(csv_file_path)

            counter = 1
            for index, row in annotations.iterrows():
              image_file = os.path.join(frames_path, f'{folder}-{counter:03d}.jpg')
              if os.path.exists(image_file) and image_file.endswith('.jpg'):
                    caption = row.iloc[-1]
                    if pd.isna(caption) or caption.endswith(']') or caption.strip() == '':
                        caption = 'NONE'
                    else:
                        caption = caption.lower()

                    label = row['label'] if 'label' in row else 'None'

                    counter += 1
                    # test_image_caption_pairs.append((image_file, caption, row.iloc[0]))
                    test_image_caption_pairs.append({"image": image_file, "caption": caption, "label": int(label)})



/content/drive/My Drive/Dataset/test/CCAH/CCAH.csv
/content/drive/My Drive/Dataset/test/ACCFP/ACCFP.csv
/content/drive/My Drive/Dataset/test/CCUIM/CCUIM.csv
/content/drive/My Drive/Dataset/test/CCSAD/CCSAD.csv
/content/drive/My Drive/Dataset/test/SCCC/SCCC.csv
/content/drive/My Drive/Dataset/test/EIB/EIB.csv
/content/drive/My Drive/Dataset/test/GGCC/GGCC.csv
/content/drive/My Drive/Dataset/test/TICC/TICC.csv
/content/drive/My Drive/Dataset/test/WICC/WICC.csv
/content/drive/My Drive/Dataset/test/EWCC/EWCC.csv


In [25]:
from collections import Counter

labels = [d['label'] for d in test_image_caption_pairs]

label_counts = Counter(labels)

print(label_counts)
print(len(test_image_caption_pairs))

Counter({1: 194, 2: 153, 0: 73})
420


In [27]:
print(len(train_image_caption_pairs))
print(len(dev_image_caption_pairs))
print(len(test_image_caption_pairs))

3372
417
420


## Instances count

In [None]:
from collections import Counter

labels = [d['label'] for d in all_image_caption_pairs]

label_counts = Counter(labels)

print(label_counts)

Counter({1: 1890, 0: 1319, 2: 993})


# Check before fine tune

In [None]:
# oppose
local_image_path = "/content/drive/My Drive/Dataset/train/NASA/NASA_frames/NASA-06.jpg"
prompts = [
    # "Instruction: provide an answer to the question. Use the image to answer.\n",
    local_image_path,
    "Climate change is one of the most complex issues facing us today. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:",
]
check_inference(model, processor, prompts, max_new_tokens=5)




"/content/drive/My Drive/Dataset/train/NASA/NASA_frames/NASA-06.jpgClimate change is one of the most complex issues facing us today. Question: What's the stance of this image-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer: 0. The image"

In [None]:
# support
local_image_path = "/content/drive/My Drive/Dataset/dev/HCCAE/HCCAE_frames/HCCAE-024.jpg"
prompts = [
    # "Instruction: provide an answer to the question. Use the image to answer.\n",
    # "Nick isn't completely sure he wants to commit installing solar panels can be very expensive",
    local_image_path,
    "living healthier and more energy-efficient lives is the first step in doing this. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:",
]
check_inference(model, processor, prompts, max_new_tokens=5)


"/content/drive/My Drive/Dataset/dev/HCCAE/HCCAE_frames/HCCAE-024.jpgliving healthier and more energy-efficient lives is the first step in doing this. Question: What's the stance of this image-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer: 0.\n\n"

In [None]:
# neutral
local_image_path = "/content/drive/My Drive/Dataset/dev/HCCAE/HCCAE_frames/HCCAE-01.jpg"
prompts = [
    # "Instruction: provide an answer to the question. Use the image to answer.\n",
    # "Nick isn't completely sure he wants to commit installing solar panels can be very expensive",
    local_image_path,
    "How Climate Change Affects the Ecosystem. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:",
]
check_inference(model, processor, prompts, max_new_tokens=5)


"/content/drive/My Drive/Dataset/dev/HCCAE/HCCAE_frames/HCCAE-01.jpgHow Climate Change Affects the Ecosystem. Question: What's the stance of this image-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer: 0. The image"

In [None]:
test_image_caption_pairs[0]

{'image': '/content/drive/My Drive/Dataset/test/CCAH/CCAH_frames/CCAH-001.jpg',
 'caption': ' the crucial connection: climate change and health | kaiser permanente',
 'label': 1}

In [None]:
import re
def convert_to_rgb(image):
    # if not isinstance(image, Image.Image):
    #     raise TypeError("The provided input is not a PIL Image.")
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    return alpha_composite.convert("RGB")


def batch_inference(model, processor, data, max_new_tokens=5):
    image_size = processor.image_processor.image_size
    image_mean = processor.image_processor.image_mean
    image_std = processor.image_processor.image_std

    image_transform = transforms.Compose([
        convert_to_rgb,
        transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=image_mean, std=image_std),
    ])

    predictions = []
    true_labels = []


    for image_path, text, true_label in data:
        image = Image.open(image_path)
        prompt = [
            image,
            f"{text}. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support, and 2 for oppose. Answer:",
        ]
        result = check_inference(model, processor, prompt, max_new_tokens=5)
        # inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)


        try:
            predicted_answer = result.split("Answer:")[1]
            match = re.search(r'\b(0|1|2)\b', predicted_answer)
            stance = match.group() if match else None
        except (IndexError, AttributeError):
            stance = None

        # print(f"Predicted Answer: {predicted_answer.strip()}")
        # print(f"Extracted Stance: {stance}")

        predictions.append(stance)
        true_labels.append(true_label)

    return predictions, true_labels


predictions, true_labels = batch_inference(model, processor, test_image_caption_pairs)


In [None]:
int_pred = [int(item) if isinstance(item, str) else None for item in predictions]
print(int_pred)
print(true_labels)

[0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, None, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 

In [None]:
from sklearn.metrics import accuracy_score, f1_score

filtered_true_labels = [true for true, pred in zip(true_labels, int_pred) if true is not None and pred is not None]
filtered_int_pred = [pred for true, pred in zip(true_labels, int_pred) if true is not None and pred is not None]

acc = accuracy_score(filtered_true_labels, filtered_int_pred)
f1 = f1_score(filtered_true_labels, filtered_int_pred, average='weighted')
print(f"Accuracy: {acc}")
print(f"F1 Score: {f1}")

Accuracy: 0.34688995215311
F1 Score: 0.2704995353236151


# Dataset preparation

In [None]:
import random
random.shuffle(train_image_caption_pairs)

In [None]:
# new one
def convert_to_rgb(image):
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite

def ds_transforms(example_batch):
    image_size = processor.image_processor.image_size
    image_mean = processor.image_processor.image_mean
    image_std = processor.image_processor.image_std

    image_transform = transforms.Compose([
        convert_to_rgb,
        transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=image_mean, std=image_std),
    ])

    prompts = []
    for i, frame_path in enumerate(example_batch["image"]):
        subtitle = example_batch["caption"][i]
        label = example_batch["label"][i]
        frame = Image.open(frame_path)
        # print(subtitle)
        # print(label)
        # print(frame_path)

        prompts.append(
            [
                frame,
                f"{subtitle}. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:{label} </s>",
            ]
        )
        # print(prompts)

    inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)

    inputs["labels"] = inputs["input_ids"]

    return inputs


train_ds.set_transform(ds_transforms)
dev_ds.set_transform(ds_transforms)

In [None]:
train_ds[0]

{'input_ids': tensor([    1, 32000, 32001, 32000,   920, 23622,  1735,  7112,  2063,  1737,
         13242,   277,  1199,   785, 15593,   411, 10838,  1111,  4445,   423,
         29889,   894, 29901,  1724, 29915, 29879,   278,   380,   749,   310,
           445,  3515, 29899,   726,  5101,  7113, 23622,  1735, 29973, 14542,
           852,   697,  1546, 29871, 29900,   363, 21104, 29892, 29871, 29896,
           363,  2304,   322, 29871, 29906,   363,  4575,   852, 29889,   673,
         29901, 29900, 29871,     2], device='cuda:0'),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0'),
 'pixel_values': tensor([[[[-0.8434, -0.7996, -0.7412,  ..., -1.2229, -1.2229, -1.2083],
           [-0.8142, -0.7850, -0.7704,  ..., -1.2521, -1.2521, -1.2375],
           [-0.7850, -0.7850, -0.8142, 

# LoRA
After specifying the low-rank adapters (LoRA) config, we load the PeftModel using the get_peft_model utility function

In [None]:
model_name = checkpoint.split("/")[1]
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, config)

In [None]:
model.print_trainable_parameters()

trainable params: 19,750,912 || all params: 8,949,430,544 || trainable%: 0.2207


# Fine tune

### HP-1

In [None]:
training_args = TrainingArguments(
    output_dir=f"{model_name}-multiclimate",
    learning_rate=2e-4,
    fp16=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    dataloader_pin_memory=False,
    save_total_limit=3,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=100,
    eval_steps=100,
    logging_steps=20,
    max_steps=400,
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    report_to=None,
    optim="paged_adamw_8bit",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=dev_ds,
)

trainer.train()

max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss,Validation Loss
100,0.941,0.896346
200,0.8634,0.903844
300,0.7766,0.931764


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Step,Training Loss,Validation Loss
100,0.941,0.896346
200,0.8634,0.903844
300,0.7766,0.931764
400,0.6628,0.980927


TrainOutput(global_step=400, training_loss=0.8827014517784119, metrics={'train_runtime': 6880.5502, 'train_samples_per_second': 1.86, 'train_steps_per_second': 0.058, 'total_flos': 5.783630887547539e+16, 'train_loss': 0.8827014517784119, 'epoch': 3.795966785290629})

# Check after fine tune

In [None]:
check_inference(model, processor, prompts, max_new_tokens=10)



"/content/WISE-011.jpgNick isn't completely sure he wants to commit installing solar panels can be very expensive. Question: What's the stance of this image-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:0 "

In [None]:
# oppose
local_image_path = "/content/drive/My Drive/Dataset/train/NASA/NASA_frames/NASA-06.jpg"
prompts = [
    # "Instruction: provide an answer to the question. Use the image to answer.\n",
    local_image_path,
    "Climate change is one of the most complex issues facing us today. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:",
]
check_inference(model, processor, prompts, max_new_tokens=5)




"/content/drive/My Drive/Dataset/train/NASA/NASA_frames/NASA-06.jpgClimate change is one of the most complex issues facing us today. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:0 "

In [None]:
# support
local_image_path = "/content/drive/My Drive/Dataset/dev/HCCAE/HCCAE_frames/HCCAE-024.jpg"
prompts = [
    # "Instruction: provide an answer to the question. Use the image to answer.\n",
    # "Nick isn't completely sure he wants to commit installing solar panels can be very expensive",
    local_image_path,
    "living healthier and more energy-efficient lives is the first step in doing this. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:",
]
check_inference(model, processor, prompts, max_new_tokens=5)


"/content/drive/My Drive/Dataset/dev/HCCAE/HCCAE_frames/HCCAE-024.jpgliving healthier and more energy-efficient lives is the first step in doing this. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:0 "

In [None]:
# neutral
local_image_path = "/content/drive/My Drive/Dataset/dev/HCCAE/HCCAE_frames/HCCAE-01.jpg"
prompts = [
    # "Instruction: provide an answer to the question. Use the image to answer.\n",
    # "Nick isn't completely sure he wants to commit installing solar panels can be very expensive",
    local_image_path,
    "How Climate Change Affects the Ecosystem. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:",
]
check_inference(model, processor, prompts, max_new_tokens=5)


"/content/drive/My Drive/Dataset/dev/HCCAE/HCCAE_frames/HCCAE-01.jpgHow Climate Change Affects the Ecosystem. Question: What's the stance of this frame-text pair towards climate change? Choose one between 0 for neutral, 1 for support and 2 for oppose. Answer:0 "

In [None]:
import re
def convert_to_rgb(image):
    # if not isinstance(image, Image.Image):
    #     raise TypeError("The provided input is not a PIL Image.")
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    return alpha_composite.convert("RGB")


def batch_inference(model, processor, data, max_new_tokens=5):
    image_size = processor.image_processor.image_size
    image_mean = processor.image_processor.image_mean
    image_std = processor.image_processor.image_std

    image_transform = transforms.Compose([
        convert_to_rgb,
        transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=image_mean, std=image_std),
    ])

    predictions = []
    true_labels = []

    for image_path, text, true_label in data:
        image = Image.open(image_path)
        prompt = [
            image,
            f"{text}. Question: What's the stance of this image-text pair towards climate change? Choose one between 0 for neutral, 1 for support, and 2 for oppose. Answer:",
        ]
        result = check_inference(model, processor, prompt, max_new_tokens=5)

        try:
            predicted_answer = result.split("Answer:")[1]
            match = re.search(r'\b(0|1|2)\b', predicted_answer)
            stance = match.group() if match else None
        except (IndexError, AttributeError):
            stance = None

        # print(f"Predicted Answer: {predicted_answer.strip()}")
        # print(f"Extracted Stance: {stance}")

        predictions.append(stance)
        true_labels.append(true_label)

    return predictions, true_labels


predictions, true_labels = batch_inference(model, processor, test_image_caption_pairs)


In [None]:
int_pred = [int(item) if isinstance(item, str) else None for item in predictions]
print(int_pred)
print(true_labels)

[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 1, 2, 1, 2, 2, 2, 2, 2, 0, 0, 0, 2, 0, 0, 0, 1, 2, 2, 1, 1, 1, 1, 1, 0, 1, 1, 2, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2, 2, 0, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 0, 0, 0, 0, 2, 1, 2, 2, 2, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 0, 1, 0, 0, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 0, 0, 1, 0, 0, 0, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 0, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 0, 0, 0, 2, 2, 0, 2, 1, 1, 2, 2, 2, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 

In [None]:
test_image_caption_pairs[:10]

[['/content/drive/My Drive/Dataset/test/CCAH/CCAH_frames/CCAH-001.jpg',
  ' the crucial connection: climate change and health | kaiser permanente',
  1],
 ['/content/drive/My Drive/Dataset/test/CCAH/CCAH_frames/CCAH-002.jpg',
  ' we know climate-related changes on our planet cause extreme weather.',
  1],
 ['/content/drive/My Drive/Dataset/test/CCAH/CCAH_frames/CCAH-003.jpg',
  ' and extreme weather has life-changing effects.',
  2],
 ['/content/drive/My Drive/Dataset/test/CCAH/CCAH_frames/CCAH-004.jpg',
  ' we see this in the rise of weather-related disasters around the globe.',
  2],
 ['/content/drive/My Drive/Dataset/test/CCAH/CCAH_frames/CCAH-005.jpg',
  ' but there’s one aspect of this topic we don’t hear much about: how does climate change affect our health?',
  0],
 ['/content/drive/My Drive/Dataset/test/CCAH/CCAH_frames/CCAH-006.jpg',
  ' well, as temperatures rise, record-setting heat waves become more frequent.',
  2],
 ['/content/drive/My Drive/Dataset/test/CCAH/CCAH_frames/

In [None]:
from sklearn.metrics import accuracy_score, f1_score

filtered_true_labels = [true for true, pred in zip(true_labels, int_pred) if true is not None and pred is not None]
filtered_int_pred = [pred for true, pred in zip(true_labels, int_pred) if true is not None and pred is not None]

# Calculate accuracy and F1 score
acc = accuracy_score(filtered_true_labels, filtered_int_pred)
f1 = f1_score(filtered_true_labels, filtered_int_pred, average='weighted')
print(f"Accuracy: {acc}")
print(f"F1 Score: {f1}")

Accuracy: 0.6
F1 Score: 0.5905311658071217
