# Train, Inference, Evaluate (WIP)

## Creating our working directory

For our experiments, we'll use the following folder to save the model, training artifacts, and our working configs.

In [None]:
## TODO: Oumi -> add a `return_full_text` field so the user can stop the model
# from echoing the input text in its response.
# see https://huggingface.co/docs/transformers/v4.17.0/main_classes/pipelines

## Add reference: https://github.com/QwenLM/Qwen2-VL

In [None]:
# Gor the following warning:
# /home/gcpuser/miniconda3/lib/python3.10/site-packages/tqdm/auto.py:21:
# TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.
# See https://ipywidgets.readthedocs.io/en/stable/user_install.html
#   from .autonotebook import tqdm as notebook_tqdm

!pip install ipywidgets

In [1]:
from pathlib import Path

# Folder (relative to the current working directory) used by this tutorial.
# The name stays a plain string because later cells interpolate it directly.
tutorial_dir = "vision_language_tutorial"

# Create the folder up front; an already existing folder is not an error.
tutorial_path = Path(tutorial_dir)
tutorial_path.mkdir(parents=True, exist_ok=True)

In [2]:
import os

# Set the TOKENIZERS_PARALLELISM switch to "false" for this process before any
# tokenizer is built (presumably to silence the Hugging Face tokenizers
# fork/parallelism warning; confirm).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [None]:
from oumi.builders import build_tokenizer
from oumi.core.configs import ModelParams
from oumi.datasets.vision_language.vqav2_small import Vqav2SmallDataset

# One checkpoint name is used for both the tokenizer and the dataset's processor.
model_name = "Qwen/Qwen2-VL-2B-Instruct"

# Build the tokenizer from the model parameters.
model_params = ModelParams(model_name=model_name)
tokenizer = build_tokenizer(model_params)

# Assemble the dataset arguments, then instantiate the dataset.
dataset_kwargs = dict(tokenizer=tokenizer, processor_name=model_name)
# dataset_kwargs["limit"] = 1000  # Limit the number of examples to load for demonstration purposes
dataset = Vqav2SmallDataset(**dataset_kwargs)

print("Examples included:", len(dataset))

In [None]:
import io

from PIL import Image

from oumi.core.types.conversation import Type

# Maximum number of characters of message text shown per message.
MAX_PREVIEW_CHARS = 100

# Print a few examples: for each message show the image (user turn) and a
# short text preview.
for i in range(2):
    conversation = dataset.conversation(i)
    print(f"Example {i + 1}:")
    for message in conversation.messages:
        ## More pythonic way to display image below???
        if message.role == "user":  # User poses a question, regarding an image
            img_content = message.content[0]
            assert (
                img_content.type == Type.IMAGE_BINARY
            ), "Oumi encodes image content in binary."
            image = Image.open(io.BytesIO(img_content.binary))
            display(image.resize((256, 256)))  # Resize for display

        # The user turn above holds a list of content items, so slicing
        # `message.content` would keep whole items (the image item included)
        # instead of shortening the text. Build the preview from text only.
        # The `str` branch is defensive: content may also be a plain string
        # (not confirmed in this notebook).
        if isinstance(message.content, str):
            text = message.content
        else:
            # `str(item)` mirrors how text items are turned into prompts in
            # the inference cells of this notebook.
            text = " ".join(str(item) for item in message.text_content_items)

        print(f"{message.role}: {text[:MAX_PREVIEW_CHARS]}...")  # Truncate for brevity
    print("\n")

In [None]:
# To inspect the underlying records directly, look at `dataset.data`.
dataset.data.head()  # Display the first few rows of the dataset

In [None]:
%%writefile $tutorial_dir/infer.yaml

# Inference settings for the base (not yet fine-tuned) model.
# A later cell loads this file with `InferenceConfig.from_yaml`.
model:
  model_name: "Qwen/Qwen2-VL-2B-Instruct"
  torch_dtype_str: "bfloat16" # Assumes your GPU supports bfloat16 (Ampere or newer)
  chat_template: "qwen2-vl-instruct"
  model_max_length: 4096
  trust_remote_code: True
  
generation:
  max_new_tokens: 64
  batch_size: 1
  
engine: NATIVE # We are using a native engine for inference, consider VLLM if available for much faster inference

In [None]:
from oumi.core.configs import InferenceConfig
from oumi.infer import infer

# Load the inference settings written to infer.yaml.
config = InferenceConfig.from_yaml(str(Path(tutorial_dir) / "infer.yaml"))

# Use the data of the i-th conversation as input. Its first message is fetched
# once and reused for both the image and the text, instead of building the
# conversation a second time for the second lookup.
conversation_id = 1
first_message = dataset.conversation(conversation_id).messages[0]
query_img = first_message.image_content_items
query_text = first_message.text_content_items

print(query_text)

# Run inference on the first text item together with the first image.
results = infer(
    config=config,
    inputs=[str(query_text[0])],
    # inputs=["Describe the image"],
    input_image_bytes=query_img[0].binary,
)

In [None]:
# Show the first entry of the inference results as the cell output.
x = results[0]
x

In [None]:
from PIL import ImageDraw

# Load the image that was sent to the model in the inference cell
query_img_bytes = query_img[0].binary
image = Image.open(io.BytesIO(query_img_bytes))

# Define bounding box coordinates based on the given format
# NOTE(review): these values are hard-coded, presumably copied from a model
# response; confirm they match the image and query in use.
top_left = (101, 39)  # (X_top_left, Y_top_left)
bottom_right = (341, 694)  # (X_bottom_right, Y_bottom_right)

# Draw the bounding box
draw = ImageDraw.Draw(image)
draw.rectangle([top_left, bottom_right], outline="red", width=1)

# Show the image with bounding box. `display` renders it inline as cell output
# (as the example-preview cell does); `image.show()` would instead launch an
# external viewer, which shows nothing in the notebook or on a headless machine.
display(image)

In [None]:
%%writefile $tutorial_dir/train.yaml

# Training settings for fine-tuning the model on a subset of VQAv2-small.
# The training cell passes this file to `oumi train -c`.
model:
  model_name: "Qwen/Qwen2-VL-2B-Instruct"
  torch_dtype_str: "bfloat16"
  model_max_length: 4096
  trust_remote_code: True
  attn_implementation: "sdpa"
  chat_template: "qwen2-vl-instruct"
  freeze_layers:
    - "visual"     # Let's train only the language component of the model for faster training

data:
  train:
    collator_name: "vision_language_with_padding" # simple padding collator
    datasets:
      - dataset_name: "merve/vqav2-small"
        split: "validation" # This dataset has only a validation split
        shuffle: True
        seed: 42
        transform_num_workers: "auto"
        dataset_kwargs:
          processor_name: "Qwen/Qwen2-VL-2B-Instruct" # i.e., the default for our model
          limit: 4096
          return_tensors: True      

training:
  output_dir: "vision_language_tutorial"
  trainer_type: "TRL_SFT"
  enable_gradient_checkpointing: True
  per_device_train_batch_size: 1 # Must be 1: the model generates variable-sized image features.
  gradient_accumulation_steps: 32
  
  # ***NOTE***
  # As written, this trains for one full epoch (`num_train_epochs: 1`).
  # For a quick check that everything works, comment out `num_train_epochs`
  # and uncomment `max_steps` below.
  # TODO: state how long one epoch takes on a single A100-40GB GPU (was "XXX hours").
  # max_steps: 20
  num_train_epochs: 1

  gradient_checkpointing_kwargs:
    # Reentrant docs: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
    use_reentrant: False
  ddp_find_unused_parameters: False
  empty_device_cache_steps: 1

  optimizer: "adamw_torch_fused"
  learning_rate: 2e-5
  warmup_ratio: 0.03
  weight_decay: 0.0
  lr_scheduler_type: "cosine"

  logging_steps: 5
  save_steps: 0
  dataloader_main_process_only: False
  dataloader_num_workers: 2
  dataloader_prefetch_factor: 8
  include_performance_metrics: True
  enable_wandb: True # Set to False if you don't want to use Weights & Biases

In [None]:
# Launch training from the shell, using the train.yaml written to the tutorial folder.
!oumi train -c "$tutorial_dir/train.yaml"

## Use the Fine-tuned Model

Once we're happy with the results, we can serve the fine-tuned model for interactive inference:

In [None]:
%%writefile $tutorial_dir/trained_infer.yaml

# Inference settings for the fine-tuned model.
# `model_name` is a local folder name here; it matches the `output_dir` value
# in train.yaml.
model:
  model_name: "vision_language_tutorial"  
  torch_dtype_str: "bfloat16" # Assumes your GPU supports bfloat16 (Ampere or newer)
  chat_template: "qwen2-vl-instruct"
  model_max_length: 4096
  trust_remote_code: True

generation:
  max_new_tokens: 64
  batch_size: 1
  
engine: NATIVE 

In [None]:
# Load the inference settings written to trained_infer.yaml.
config = InferenceConfig.from_yaml(str(Path(tutorial_dir) / "trained_infer.yaml"))

# Use the data of the first conversation as input. Its first message is fetched
# once and reused for both the image and the text, instead of building the
# conversation a second time for the second lookup.
first_message = dataset.conversation(0).messages[0]
query_img = first_message.image_content_items
query_text = first_message.text_content_items

print(query_text)

# Run inference on the first text item together with the first image.
results = infer(
    config=config,
    inputs=[str(query_text[0])],
    # inputs=["Describe the image"],
    input_image_bytes=query_img[0].binary,
)

# Show the first result as the cell output.
results[0]

In [None]:
# TODO: Try a new dataset. Draft of the loading code:
# dataset_id = "HuggingFaceM4/the_cauldron"
# subset = "geomverse"
# dataset = load_dataset(dataset_id, subset, split="train")

# # Selecting a subset of 3K samples for fine-tuning
# dataset = dataset.select(range(3000))
# print(f"Using a sample size of {len(dataset)} for fine-tuning.")
# print(dataset)