<div align="center" dir="auto">
<p dir="auto">

<a href="https://colab.research.google.com/github/write-with-neurl/deepgram-content/blob/main/MultimodalAI/Finetuning_Idefics_9B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

</p>
</div>



# ⚡ [Deepgram] Finetuning IDEFICS 9B LLM for describing Pokémon card images

In this example, we fine-tune IDEFICS on `TheFusion21/PokemonCards`, a dataset of Pokémon trading cards from Hugging Face. [IDEFICS](https://huggingface.co/docs/transformers/main/en/tasks/idefics) is an open-access vision and language model based on [Flamingo](https://huggingface.co/papers/2204.14198), a state-of-the-art visual language model initially developed by DeepMind. The model accepts arbitrary sequences of image and text inputs and generates coherent text as output. It can answer questions about images, describe visual content, and create stories grounded in multiple images.


> Check out the accompanying article for this notebook "Multimodal AI in Action" on the [Deepgram blog](https://deepgram.com/learn/article).

## 🧑‍💻 Installations and Set Up

In [None]:
!pip install -q datasets
!pip install -q git+https://github.com/huggingface/transformers
!pip install -q bitsandbytes sentencepiece accelerate loralib
!pip install -q -U git+https://github.com/huggingface/peft.git

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.2/102.2 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m280.0/280.0 kB[0m [31m33.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadat

In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from PIL import Image
from transformers import IdeficsForVisionText2Text, AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig
import torchvision.transforms as transforms

## ⏳ Loading and Quantization of Idefics 9B
Since high-memory GPUs may not be available, we load a quantized version of the model. To load the model in 4-bit precision, we pass a `BitsAndBytesConfig` to the `from_pretrained` method, and the weights are quantized on the fly during loading. The processor is loaded as usual, since it holds no model weights to quantize.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
checkpoint = "HuggingFaceM4/idefics-9b"

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_skip_modules=["lm_head", "embed_tokens"]
)
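
To get a feel for why 4-bit loading helps, here is a rough back-of-the-envelope estimate of the memory taken by the weights alone. It is only a sketch: it uses the total parameter count that `print_trainable_parameters()` reports further down in this notebook, and it ignores activations, quantization constants, and the modules listed in `llm_int8_skip_modules`, which stay in higher precision. The helper `weight_gib` is a hypothetical name introduced here.

```python
# Rough estimate of weight storage only; all runtime overheads are ignored.
TOTAL_PARAMS = 8_949_430_544  # "all params" figure reported in this notebook


def weight_gib(num_params: int, bits_per_param: float) -> float:
    """Return the storage needed for the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 2**30


fp16_gib = weight_gib(TOTAL_PARAMS, 16)  # half-precision weights
nf4_gib = weight_gib(TOTAL_PARAMS, 4)    # 4-bit quantized weights
print(f"fp16: {fp16_gib:.1f} GiB, 4-bit: {nf4_gib:.1f} GiB")
```

Under these assumptions the 4-bit weights need about a quarter of the half-precision footprint, which is what brings the model within reach of a single mid-range GPU.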

In [None]:
processor = AutoProcessor.from_pretrained(checkpoint)


In [None]:
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, quantization_config=bnb_config, device_map="auto")


In [None]:
model

IdeficsForVisionText2Text(
  (model): IdeficsModel(
    (embed_tokens): IdeficsDecoupledEmbedding(
      num_embeddings=32000, num_additional_embeddings=2, embedding_dim=4096, partially_freeze=False
      (additional_embedding): Embedding(2, 4096)
    )
    (vision_model): IdeficsVisionTransformer(
      (embeddings): IdeficsVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(257, 1280)
      )
      (pre_layrnorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (encoder): IdeficsVisionEncoder(
        (layers): ModuleList(
          (0-31): 32 x IdeficsVisionEncoderLayer(
            (self_attn): IdeficsVisionAttention(
              (k_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
              (v_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
              (q_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
        

## Inference Function for Generating Responses
The `do_inference` function runs generation with the 4-bit quantized IDEFICS 9B model and its processor for a given list of prompts. It takes the tokenizer from the processor and uses it to look up the IDs of tokens that should never appear in the output (the `bad_words`, here the image placeholder tokens) and the ID of the end-of-sequence token `</s>`, which tells the model when to stop. The processor then encodes the prompts into input tensors, and `model.generate` produces a continuation that avoids the bad words and stops at the end-of-sequence token or after `max_new_tokens` new tokens, whichever comes first. Finally, the generated IDs are decoded back into text with special tokens skipped, and the result is printed.

In [None]:
# Inference
def do_inference(model, processor, prompts, max_new_tokens=50):
    tokenizer = processor.tokenizer
    # Image placeholder tokens that must never be generated
    bad_words = ["<image>", "<fake_token_around_image>"]
    # Default to None so the call below still works if the list is emptied
    bad_words_ids = None
    if len(bad_words) > 0:
        bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
    eos_token = "</s>"
    eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

    # Relies on the global `device` defined earlier in the notebook
    inputs = processor(prompts, return_tensors="pt").to(device)
    generated_ids = model.generate(
        **inputs,
        eos_token_id=[eos_token_id],
        bad_words_ids=bad_words_ids,
        max_new_tokens=max_new_tokens,
        early_stopping=True
    )
    # Decode the first (and only) sequence, dropping special tokens
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)
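
The three rules that shape the output (banned tokens, the end-of-sequence token, and the cap on new tokens) can be illustrated with a toy greedy decoder. This is a sketch for intuition only, not how `transformers` implements generation: `toy_generate` and `fake_scores` are hypothetical names, and the "model" is just a function that returns made-up scores.

```python
from typing import Callable, Dict, List, Set


def toy_generate(
    score_fn: Callable[[List[str]], Dict[str, float]],
    prompt: List[str],
    eos_token: str,
    bad_words: Set[str],
    max_new_tokens: int,
) -> List[str]:
    """Greedy decoding with banned tokens, an EOS stop, and a length cap."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        scores = score_fn(tokens)
        # Banned tokens are filtered out before the best token is picked
        allowed = {tok: s for tok, s in scores.items() if tok not in bad_words}
        next_token = max(allowed, key=allowed.get)
        tokens.append(next_token)
        if next_token == eos_token:
            break  # stop as soon as the end-of-sequence token appears
    return tokens


def fake_scores(tokens: List[str]) -> Dict[str, float]:
    """Made-up scores: "<image>" always scores highest, then a fixed continuation."""
    continuation = ["Two", "kittens", "</s>"]
    step = len(tokens) - 1  # the prompt used below has exactly one token
    wanted = continuation[min(step, len(continuation) - 1)]
    scores = {"<image>": 0.9, "Two": 0.1, "kittens": 0.1, "</s>": 0.1}
    scores[wanted] = 0.5
    return scores


full = toy_generate(fake_scores, ["Answer:"], "</s>", {"<image>"}, max_new_tokens=10)
capped = toy_generate(fake_scores, ["Answer:"], "</s>", {"<image>"}, max_new_tokens=1)
print(full)
print(capped)
```

Even though `<image>` always has the highest score, it never appears in the result because it is banned. The first call ends at `</s>`, and the second is cut off by the token cap.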

In [None]:
url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg"
prompts = [
    url,
    "Question: What's on the picture? Answer:",
]

In [None]:
do_inference(model, processor, prompts, max_new_tokens=50)



Question: What's on the picture? Answer: Two kittens playing in the grass.


# 🛠 Finetuning Setup

## Preprocessing of the Pokémon cards dataset

The function `convert_to_rgb` ensures that any given image is in the RGB color space. If the image is already in RGB, it is returned as is; if not, it converts the image to RGBA to blend it with a white background, effectively removing any transparency, before converting it to RGB.



In [None]:
def convert_to_rgb(image):
  if image.mode == "RGB":
    return image

  image_rgba = image.convert("RGBA")
  background = Image.new("RGBA", image_rgba.size, (255,255,255))
  alpha_composite = Image.alpha_composite(background, image_rgba)
  alpha_composite = alpha_composite.convert("RGB")
  return alpha_composite
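
The blend onto a white background can be checked by hand for a single pixel. The sketch below applies the standard "over" formula for an opaque white background in plain Python. `blend_over_white` is a hypothetical helper, and Pillow's integer arithmetic may round a channel differently by one level.

```python
def blend_over_white(rgba):
    """Composite one RGBA pixel (channels in 0-255) over opaque white."""
    r, g, b, a = rgba
    alpha = a / 255
    # out = alpha * foreground + (1 - alpha) * background, with background = 255
    return tuple(round(alpha * c + (1 - alpha) * 255) for c in (r, g, b))


opaque_red = blend_over_white((255, 0, 0, 255))     # fully opaque: colour is kept
transparent_red = blend_over_white((255, 0, 0, 0))  # fully transparent: becomes white
half_black = blend_over_white((0, 0, 0, 128))       # half transparent: mid grey
print(opaque_red, transparent_red, half_black)
```

Transparent regions of a card image therefore turn white rather than black, which is what a naive `convert("RGB")` would typically give for pixels whose stored colour is black.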

`ds_transforms` is a preprocessing function applied to each batch of examples, where every example has an image URL, a caption, and a name. It defines a chain of image transformations: conversion to RGB, a random resized crop, conversion to a tensor, and normalization with the mean and standard deviation stored in the processor. For each example, it keeps the first sentence of the caption and builds a question-answer prompt that states the card's name followed by that sentence. The processor then turns the prompts and transformed images into input tensors, and the input IDs are copied into `labels` for supervised training. The function returns these inputs, ready for the model.

In [None]:
def ds_transforms(example_batch):
  # Image size and normalization statistics come from the processor
  image_size = processor.image_processor.image_size
  image_mean = processor.image_processor.image_mean
  image_std = processor.image_processor.image_std

  image_transform = transforms.Compose([
      convert_to_rgb,
      transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
      transforms.ToTensor(),
      transforms.Normalize(mean=image_mean, std=image_std)
  ])

  prompts = []
  print(example_batch)  # debug output: dumps the raw batch
  for i in range(len(example_batch['caption'])):
    # Keep only the first sentence of the caption
    caption = example_batch['caption'][i].split(".")[0]
    prompts.append(
        [
            example_batch['image_url'][i],
            # Index `name` per example; without [i] the whole list of names in the batch is interpolated
            f"Question: What's on the picture? Answer: This is {example_batch['name'][i]}. {caption}",
        ],
    )
  inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)
  # Causal LM training: the labels are the input IDs themselves
  inputs["labels"] = inputs["input_ids"]
  return inputs
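
The prompt-building step can be tried on its own, without the processor or any images. The sketch below assumes the batch is a dict of equally long lists, as the training log in this notebook suggests. `build_prompts` is a hypothetical helper, the captions are shortened copies of logged ones, and the `name` values are assumed for illustration because the log is cut off before that column.

```python
from typing import Dict, List


def build_prompts(example_batch: Dict[str, List[str]]) -> List[List[str]]:
    """Return one [image_url, text] prompt per example in the batch."""
    prompts = []
    for url, name, caption in zip(
        example_batch["image_url"], example_batch["name"], example_batch["caption"]
    ):
        first_sentence = caption.split(".")[0]  # keep only the first sentence
        text = f"Question: What's on the picture? Answer: This is {name}. {first_sentence}"
        prompts.append([url, text])
    return prompts


# Hypothetical batch: URLs and caption openings are taken from the training
# log, the names are assumptions.
batch = {
    "image_url": [
        "https://images.pokemontcg.io/ex10/14_hires.png",
        "https://images.pokemontcg.io/xy2/11_hires.png",
    ],
    "name": ["Slowking", "Charizard-EX"],
    "caption": [
        "A Stage 1 Pokemon Card of type Psychic with the title Slowking and 70 HP "
        "of rarity Rare Holo evolved from Slowpoke from the set Unseen Forces. "
        "It has the attack Aftermath.",
        "A Basic, EX Pokemon Card of type Fire with the title Charizard-EX and 180 HP "
        "of rarity Rare Holo EX from the set Flashfire. It has the attack Stoke.",
    ],
}

prompts = build_prompts(batch)
for url, text in prompts:
    print(url)
    print(text)
```

Each prompt carries exactly one name. If the whole `name` list were formatted into the string instead, every training target would contain a Python list literal such as `['Slowking', 'Charizard-EX']`, and the model would learn to reproduce that pattern.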

## Load and prepare the data

In [None]:
ds = load_dataset("TheFusion21/PokemonCards")
ds = ds["train"].train_test_split(test_size=0.002)
train_ds = ds["train"]
eval_ds = ds["test"]
train_ds.set_transform(ds_transforms)
eval_ds.set_transform(ds_transforms)

In [None]:
model_name = checkpoint.split("/")[1]
config = LoraConfig(
    r = 16,
    lora_alpha = 32,
    target_modules = ["q_proj", "k_proj", "v_proj"],
    lora_dropout = 0.05,
    bias="none"
)

In [None]:
model = get_peft_model(model, config)

In [None]:
model.print_trainable_parameters()

trainable params: 19,750,912 || all params: 8,949,430,544 || trainable%: 0.2206946230030432
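
The reported numbers can be sanity-checked with a little arithmetic. For a linear layer with `in_features` inputs and `out_features` outputs, a rank-`r` LoRA adapter adds two small matrices, one of shape `r x in_features` and one of shape `out_features x r`; with `bias="none"` no bias terms are trained. The sketch below applies this to the vision encoder's attention projections, whose 1280-to-1280 shape and 32-layer count appear in the model printout above. `lora_param_count` is a hypothetical helper, and the rest of the trainable parameters come from `q_proj`, `k_proj`, and `v_proj` modules in parts of the model that the truncated printout does not show.

```python
def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is r x in, B is out x r."""
    return r * in_features + out_features * r


R = 16  # LoRA rank from the LoraConfig above

# Vision encoder: q_proj, k_proj, v_proj are each Linear(1280 -> 1280)
per_projection = lora_param_count(1280, 1280, R)
vision_total = per_projection * 3 * 32  # 3 projections x 32 encoder layers

# Figures reported by print_trainable_parameters()
trainable, total = 19_750_912, 8_949_430_544
trainable_percent = 100 * trainable / total

print(per_projection, vision_total, f"{trainable_percent:.4f}%")
```

So roughly a fifth of the trainable parameters sit in the vision encoder, and in total only about 0.22% of the model is updated during fine-tuning.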


In [None]:
training_args = TrainingArguments(
    output_dir = f"{model_name}-PokemonCards",
    learning_rate = 2e-4,
    fp16 = True,
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 2,
    gradient_accumulation_steps = 8,
    dataloader_pin_memory = False,
    save_total_limit = 3,
    evaluation_strategy = "steps",  # newer transformers releases name this argument eval_strategy
    save_strategy = "steps",
    eval_steps = 10,
    save_steps = 25,
    max_steps = 25,
    logging_steps = 5,
    remove_unused_columns = False,
    push_to_hub=False,
    label_names = ["labels"],
    load_best_model_at_end = False,
    report_to = "none",
    optim = "paged_adamw_8bit",
)

In [None]:
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = train_ds,
    eval_dataset = eval_ds
)

In [None]:
trainer.train()

{'id': ['ex10-14', 'ex10-51'], 'image_url': ['https://images.pokemontcg.io/ex10/14_hires.png', 'https://images.pokemontcg.io/ex10/51_hires.png'], 'caption': ["A Stage 1 Pokemon Card of type Psychic with the title Slowking and 70 HP of rarity Rare Holo evolved from Slowpoke from the set Unseen Forces.  It has the attack Aftermath with the cost Psychic, Colorless, the energy cost 2 and the damage of 20+ with the description: Does 20 damage plus 10 more damage for each Pokemon Tool card in your discard pile. You can't add more than 60 damage in this way. It has the ability Item Search with the description: Once during your turn (before your attack), you may search your deck for a Pokemon Tool card, show it to your opponent, and put it into your hand. Shuffle your deck afterward. This power can't be used if Slowking is affected by a Special Condition. It has weakness against Grass 2. ", 'A Basic Pokemon Card of type Grass with the title Chikorita and 50 HP of rarity Common from the set Uns

| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 10   | 1.5975        | 1.175423        |
| 20   | 0.8949        | 0.887494        |


{'id': ['xy2-11', 'dp1-72'], 'image_url': ['https://images.pokemontcg.io/xy2/11_hires.png', 'https://images.pokemontcg.io/dp1/72_hires.png'], 'caption': ['A Basic, EX Pokemon Card of type Fire with the title Charizard-EX and 180 HP of rarity Rare Holo EX from the set Flashfire.  It has the attack Stoke with the cost Colorless, the energy cost 1 with the description: Flip a coin. If heads, search your deck for up to 3 basic Energy cards and attach them to this Pokemon. Shuffle your deck afterward. It has the attack Fire Blast with the cost Fire, Colorless, Colorless, Colorless, the energy cost 4 and the damage of 120 with the description: Discard an Energy attached to this Pokemon. It has weakness against Water 2. ', 'A Basic Pokemon Card of type Water with the title Buizel and 60 HP of rarity Common from the set Diamond & Pearl and the flavor text: It has a flotation sac that is like an inflatable collar. It floats on water with its head out. It has the attack Splash About with the cos

TrainOutput(global_step=25, training_loss=1.4273568725585937, metrics={'train_runtime': 545.5301, 'train_samples_per_second': 0.733, 'train_steps_per_second': 0.046, 'total_flos': 1878661821522816.0, 'train_loss': 1.4273568725585937, 'epoch': 0.03})
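
The throughput figures in `TrainOutput` follow from the training arguments. The short check below recomputes them, assuming a single GPU (the notebook does not state the device count): each optimizer step consumes `per_device_train_batch_size * gradient_accumulation_steps` examples.

```python
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
max_steps = 25
train_runtime = 545.5301  # seconds, from TrainOutput
num_devices = 1           # assumption: a single GPU

# Examples consumed per optimizer step and over the whole run
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_devices
samples_seen = effective_batch * max_steps

samples_per_second = round(samples_seen / train_runtime, 3)
steps_per_second = round(max_steps / train_runtime, 3)
print(effective_batch, samples_seen, samples_per_second, steps_per_second)
```

The recomputed rates match the reported `train_samples_per_second` and `train_steps_per_second`. The run therefore saw only 400 examples, consistent with the reported `epoch` of 0.03, so the result should be read as a quick demonstration and not as a fully trained model.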

## Inference on fine-tuned model
In this section, we test the fine-tuned model by passing it the URL of a Pokémon card together with a text prompt and printing the generated text.

In [None]:
url = "https://images.pokemontcg.io/pop6/2_hires.png"

In [None]:
prompts = [
    url,
    "Question: What's on the picture? Answer:",
]

In [None]:
do_inference(model, processor, prompts, max_new_tokens=100)



Question: What's on the picture? Answer: This is ['Lucario-GX', 'Lucario']. A Basic Pokemon Card of type Fire with the title Lucario-GX and 90 HP of rarity Rare Holo from the set Unbound Legends and the flavor text: It's a Pokemon that can use its tail as a weapon. It's a Pokemon that can use its tail as a weapon. It evolves from Lucario when it is traded with
