## Set Up

In [49]:
# Colab setup: install transformers from the GitHub main branch plus the fine-tuning stack.
# NOTE(review): neither install is version-pinned, so a rerun may resolve different releases
# than the ones listed in the "Tested with" line below — confirm that line is still current.
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install  -U -q trl datasets bitsandbytes peft accelerate
# Tested with transformers==4.49.0.dev0, trl==0.14.0, datasets==3.2.0, bitsandbytes==0.45.2, peft==0.14.0, accelerate==1.3.0

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [50]:
# Upgrade huggingface-hub to the latest release (unpinned).
!pip install --upgrade huggingface-hub




## Dataset

In [51]:
from google.colab import drive
# Mount Google Drive; the project paths used below live under /content/drive (Colab only).
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [52]:
import os
import pandas as pd
from datasets import Dataset, Image

# Root of the project on the mounted Drive and the folder holding the training split.
PROJECT_ROOT = "/content/drive/MyDrive/dl-project"
annot_root = os.path.join(PROJECT_ROOT, "My-First-Project-2", "train")

df = pd.read_csv(os.path.join(annot_root, "annotations_train.csv"))

# Flatten the annotation columns into the single target string the model is trained to emit.
df['caption'] = df.apply(
    lambda row: f"max: {row['max']}, min: {row['min']}, range: {row['lower_range']}-{row['upper_range']}, title: {row['title']}, domain: {row['domain']}",
    axis=1
)
image_base_dir_abs = annot_root
# NOTE(review): the trailing [:-4] slices the *Series* (drops its last 4 rows), it does not
# trim 4 characters from each filename. After index alignment the last 4 rows therefore get
# a missing image path (and a None image downstream). Behavior is kept as-is; confirm whether
# those rows should be dropped explicitly or the slice removed.
df['absolute_image_path'] = df['image_filename'].apply(
    lambda x: os.path.join(image_base_dir_abs, x)
)[:-4]
df = df[['absolute_image_path', 'caption']]

# Turn the path column into a lazily decoded image feature named "image".
hf_dataset = Dataset.from_pandas(df)
hf_dataset = hf_dataset.cast_column("absolute_image_path", Image())
hf_dataset = hf_dataset.rename_column("absolute_image_path", "image")
dataset = hf_dataset

# Index the row first, then the column: dataset['image'][0] would decode every image in the
# split just to display the first one.
first_row = dataset[0]
dataset, first_row['image'], first_row['caption']

(Dataset({
     features: ['image', 'caption'],
     num_rows: 141
 }),
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024>,
 'max: 78.0, min: 12.0, range: 0.0-90.0, title: Sex ratio by age (100 year olds) (UNWPP, 2017) in the Year 2005, domain: healthcare')

In [53]:
# User-turn instruction sent alongside every graph image.
# Plain (non-f) string: there is nothing to interpolate, so the JSON braces need no escaping.
PROMPT = """
You are given an image of a graph. Your task is to analyze it and extract structured information.

Return your findings as a **valid, minified JSON object** with the following fields.
If any detail cannot be determined from the image, set its value to null (without quotes).

{
  "maximum": "The highest y-value visible on the graph.",
  "minimum": "The lowest y-value visible on the graph.",
  "range": "The overall span of the y-axis, written as 'min-max'.",
  "title": "The exact title text shown on the graph, if present. If not present, write null (wthout quotes)",
  "domain": "The subject domain of the graph. Choose ONE ONLY from the following: economics, healthcare, politics, environment, technology, entertainment, animal, linguistics, internet, miscellaneous. If none of these options are correct, output null."
}

**Guidelines:**
1. Base all answers strictly on what is visible in the graph; do not infer or invent data.
2. Include numerical values exactly as they appear (no rounding).
3. Maintain factual, neutral descriptions.
4. Output only the final JSON object — no text, commentary, or markdown.

Output ONLY the JSON object with string values for each aspect.
"""


In [54]:
# System-turn text placed first in every conversation built by format_data.
# NOTE(review): this text (and PROMPT) asks for minified JSON with "maximum"/"minimum" keys,
# while the assistant target comes from the `caption` column, which is formatted as
# "max: ..., min: ..., range: ..., title: ..., domain: ..." — confirm the mismatch is intended.
system_message = """
You are a specialized agent that extracts information from graphs.

Your output must be a **valid, minified JSON object** with exactly the following keys:

{
  "maximum": <number>,
  "minimum": <number>,
  "range": "<min>-<max>",
  "title": "<string_or_null>",
  "domain": "<string_or_null>"
}

Rules:
- Replace <...> with the extracted values.
- "range" must be a single string formatted EXACTLY as "min-max".
- If the title is not present, output null (without quotes).
- Output ONLY the JSON object. No explanation, no markdown.
- The final answer must be valid minified JSON (no spaces or newlines).
"""

def format_data(sample):
    """Wrap one dataset row into a system / user / assistant conversation.

    Args:
        sample: mapping with an "image" entry and a "caption" entry.

    Returns:
        A three-element list of chat turns: the system instructions, the user
        turn (image followed by PROMPT), and the assistant turn holding the caption.
    """
    def text_part(text):
        # Every textual content item shares this shape.
        return {"type": "text", "text": text}

    system_turn = {"role": "system", "content": [text_part(system_message)]}
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": sample["image"]},
            text_part(PROMPT),
        ],
    }
    assistant_turn = {"role": "assistant", "content": [text_part(sample["caption"])]}
    return [system_turn, user_turn, assistant_turn]

In [55]:
# Materialize the training split as a list of chat conversations (decodes each row's image).
train_dataset = [format_data(x) for x in dataset]
# Display the first conversation as a sanity check.
train_dataset[0]

[{'role': 'system',
  'content': [{'type': 'text',
    'text': '\nYou are a specialized agent that extracts information from graphs.\n\nYour output must be a **valid, minified JSON object** with exactly the following keys:\n\n{\n  "maximum": <number>,\n  "minimum": <number>,\n  "range": "<min>-<max>",\n  "title": "<string_or_null>",\n  "domain": "<string_or_null>"\n}\n\nRules:\n- Replace <...> with the extracted values.\n- "range" must be a single string formatted EXACTLY as "min-max".\n- If the title is not present, output null (without quotes).\n- Output ONLY the JSON object. No explanation, no markdown.\n- The final answer must be valid minified JSON (no spaces or newlines).\n'}]},
 {'role': 'user',
  'content': [{'type': 'image',
    'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024>},
   {'type': 'text',
    'text': '\nYou are given an image of a graph. Your task is to analyze it and extract structured information.\n\nReturn your findings as a **valid, mini

In [56]:
## code to generate annotations for test dataset
import csv
import json

# Convert the test split's JSONL annotations (one question/answer pair per line) into one
# CSV row per image, then build the chat-formatted test dataset from that CSV.
jsonl_file_path = os.path.join(PROJECT_ROOT, "My-First-Project-2/test/", "annotations.jsonl")
data = {}

with open(jsonl_file_path, 'r', encoding='utf-8') as f:
    for line in f:
        try:
            json_obj = json.loads(line)
        except json.JSONDecodeError:
            print(f"Skipping invalid JSON line: {line.strip()}")
            continue

        # Group every question (prefix) -> answer (suffix) pair under its image name.
        image_name = json_obj.get("image")
        data.setdefault(image_name, {})[json_obj.get("prefix")] = json_obj.get("suffix")

# The CSV below is written/read with a relative name, so the working directory matters.
os.chdir(os.path.join(PROJECT_ROOT,"My-First-Project-2/test/"))

with open("annotations_test.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["image_filename", "max", "min", "range", "title", "domain"])

    for img, info in data.items():
        writer.writerow([
            img,
            info.get("What is the maximum?", ""),
            info.get("What is the minimum?", ""),
            info.get("What is the range of the y-axis? Format as min-max (No spaces)", ""),
            info.get("What is the title?", ""),
            info.get("What is the domain?", "")
        ])

df_truth = pd.read_csv("annotations_test.csv")

df_truth['caption'] = df_truth.apply(
    lambda row: f"max: {row['max']}, min: {row['min']}, range: {row['range']}, title: {row['title']}, domain: {row['domain']}",
    axis=1
)

image_base_dir_abs = os.path.join(PROJECT_ROOT, "My-First-Project-2", "test")

# NOTE(review): [:-4] slices the Series (drops its last 4 rows); it does not trim filename
# characters. The last 4 rows end up without an image path. Kept as-is; confirm the intent.
df_truth['absolute_image_path'] = df_truth['image_filename'].apply(
    lambda x: os.path.join(image_base_dir_abs, x)
)[:-4]

df_truth = df_truth[['absolute_image_path', 'caption']]

# Fix: build the dataset from df_truth. The previous code passed `df` (the TRAIN frame),
# so "test_dataset" silently contained the training examples.
hf_dataset_truth = Dataset.from_pandas(df_truth)
hf_dataset_truth = hf_dataset_truth.cast_column("absolute_image_path", Image())
hf_dataset_truth = hf_dataset_truth.rename_column("absolute_image_path", "image")

# Chat-formatted conversations, same shape as train_dataset.
test_dataset = [format_data(x) for x in hf_dataset_truth]
test_dataset[0]


[{'role': 'system',
  'content': [{'type': 'text',
    'text': '\nYou are a specialized agent that extracts information from graphs.\n\nYour output must be a **valid, minified JSON object** with exactly the following keys:\n\n{\n  "maximum": <number>,\n  "minimum": <number>,\n  "range": "<min>-<max>",\n  "title": "<string_or_null>",\n  "domain": "<string_or_null>"\n}\n\nRules:\n- Replace <...> with the extracted values.\n- "range" must be a single string formatted EXACTLY as "min-max".\n- If the title is not present, output null (without quotes).\n- Output ONLY the JSON object. No explanation, no markdown.\n- The final answer must be valid minified JSON (no spaces or newlines).\n'}]},
 {'role': 'user',
  'content': [{'type': 'image',
    'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024>},
   {'type': 'text',
    'text': '\nYou are given an image of a graph. Your task is to analyze it and extract structured information.\n\nReturn your findings as a **valid, mini

## Model

In [57]:
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

# Hub id of the vision-language checkpoint used for fine-tuning and evaluation.
model_id = "ibm-granite/granite-vision-3.3-2b"
# Processor (chat template, tokenizer and image preprocessing) matching model_id.
processor = AutoProcessor.from_pretrained(model_id)


## Training

In [58]:
from transformers import BitsAndBytesConfig

# Feature switches: USE_QLORA selects 4-bit loading below; USE_LORA is read by the adapter
# setup and the validation cells.
USE_QLORA = True
USE_LORA = True

if USE_QLORA:
    # BitsAndBytesConfig int-4 config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="fp4", ## nf4
        bnb_4bit_compute_dtype=torch.bfloat16,
        llm_int8_skip_modules=["vision_tower", "lm_head"],  # Skip problematic modules
        llm_int8_enable_fp32_cpu_offload=True
    )
else:
    # No quantization config: the model is loaded with the bfloat16 dtype requested below.
    bnb_config = None

# Load model and tokenizer
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    dtype=torch.bfloat16,
    quantization_config=bnb_config,
    _attn_implementation=None,
)
# NOTE(review): a processor for the same model_id is already created in the Model section;
# this second load looks redundant — confirm before removing.
processor = AutoProcessor.from_pretrained(model_id)



Loading weights:   0%|          | 0/815 [00:00<?, ?it/s]

In [59]:
# Attach LoRA (DoRA-enabled) adapters to the language-model projection layers.
if USE_LORA:
    from peft import LoraConfig, get_peft_model

    # Configure LoRA
    # target_modules lists every module whose name contains both 'language_model' and
    # '_proj', taken from the model as it is at this point.
    peft_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=[name for name, _ in model.named_modules() if 'language_model' in name and '_proj' in name],
        use_dora=True,
        init_lora_weights="gaussian"
    )

    # Apply PEFT model adaptation
    # NOTE(review): the same config is attached twice here — once through
    # model.add_adapter(...) and again through get_peft_model(...) — and peft_config is later
    # passed to the trainer as well. Presumably only one of these is needed; verify which
    # path is intended before changing it.
    # model = get_peft_model(model, peft_config)
    model.add_adapter(peft_config)
    model.enable_adapters()
    model = get_peft_model(model, peft_config)

    # Print trainable parameters
    model.print_trainable_parameters()

else:
    # No adapters: downstream code receives peft_config=None.
    peft_config = None



trainable params: 15,032,320 || all params: 2,990,429,248 || trainable%: 0.5027


In [60]:
from trl import SFTConfig
# Single source of truth for where checkpoints and the final model are written.
output_dir = "./checkpoints/granite"

# Configure training arguments using SFTConfig
training_args = SFTConfig(
    output_dir=output_dir,  # reuse the variable instead of repeating the literal
    num_train_epochs=5,
    # max_steps=30,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    warmup_steps=10,
    learning_rate=1e-4,
    weight_decay=0.01,
    optim="adamw_torch_fused",
    bf16=True,
    push_to_hub=False,
    report_to="none",
    # The collator needs the raw conversation structure, so columns must not be pruned
    # and TRL's own dataset preparation is skipped.
    remove_unused_columns=False,
    gradient_checkpointing=True,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
    # NOTE(review): with the "epoch" strategies below, logging_steps / save_steps are
    # presumably not consulted; they are kept so the config object is unchanged.
    logging_strategy="epoch",
    logging_steps=10,
    save_strategy="epoch",
    save_steps=20,
    save_total_limit=1,
    # No evaluation strategy is configured, so an eval_dataset handed to the trainer
    # is not evaluated during training.
)

In [61]:
def _assistant_only_ignore_mask(token_ids, assistant_ids, eos_id):
    """Return one bool per token: True where the label must be ignored (-100).

    Loss is enabled for tokens *after* an occurrence of ``assistant_ids`` and stays
    enabled up to and including the next ``eos_id``; everything else is ignored.
    """
    marker_len = len(assistant_ids)
    ignore = []
    in_answer = False
    for j, token in enumerate(token_ids):
        ignore.append(not in_answer)
        # Same window test as before: the marker must end exactly at position j.
        if j >= marker_len + 1 and token_ids[j + 1 - marker_len:j + 1] == assistant_ids:
            in_answer = True
        if token == eos_id:
            in_answer = False
    return ignore


def collate_fn(examples):
    """Collate chat-formatted examples into a model batch with assistant-only labels.

    Args:
        examples: list of conversations as produced by ``format_data``; the image is
            read from ``example[1]['content'][0]['image']``.

    Returns:
        The processor output with an added ``"labels"`` tensor in which every token
        outside the assistant answer is set to -100.

    Raises:
        ValueError: if no example in the batch has an image, or if the end-of-text
            marker does not tokenize to exactly one id.
    """
    texts = []
    image_inputs = []
    # Build texts and images together so they can never get out of step. The previous
    # version popped from `texts` by loop index, which shifts later indices after the
    # first removal and pairs the wrong text with the wrong image.
    for example in examples:
        image = example[1]['content'][0]['image']
        if image is None:
            continue
        if image.mode != 'RGB':
            image = image.convert('RGB')
        texts.append(processor.apply_chat_template(example, tokenize=False))
        image_inputs.append([image])

    if not texts:
        raise ValueError("collate_fn received a batch in which no example has an image")

    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)

    labels = batch["input_ids"].clone()
    assistant_ids = processor.tokenizer("<|assistant|>", return_tensors="pt")['input_ids'][0].tolist()
    eos_ids = processor.tokenizer("<|end_of_text|>", return_tensors="pt")['input_ids'][0].tolist()
    if len(eos_ids) != 1:
        raise ValueError(f"expected '<|end_of_text|>' to be a single token, got ids {eos_ids}")

    # Scan plain Python lists instead of indexing the tensor element by element.
    ignore_rows = [
        _assistant_only_ignore_mask(row, assistant_ids, eos_ids[0])
        for row in batch["input_ids"].tolist()
    ]
    labels[torch.tensor(ignore_rows, dtype=torch.bool)] = -100

    batch["labels"] = labels
    return batch

In [62]:
from transformers import TrainerCallback

class LossLoggerCallback(TrainerCallback):
    """Trainer callback that records every logged training and evaluation loss.

    Attributes are plain lists, appended to in the order the trainer logs them.
    """

    def __init__(self):
        self.train_losses, self.eval_losses = [], []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append ``logs["loss"]`` / ``logs["eval_loss"]`` when those keys are present."""
        if logs is not None:
            for key, history in (("loss", self.train_losses), ("eval_loss", self.eval_losses)):
                if key in logs:
                    history.append(logs[key])


In [63]:
import matplotlib.pyplot as plt
from IPython.display import clear_output
from transformers import TrainerCallback

class LiveLossPlotCallback(TrainerCallback):
    """Trainer callback that redraws a loss curve in the notebook on every log event.

    Training and evaluation losses are accumulated in ``train_losses`` and
    ``eval_losses`` and plotted against their logging order.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.eval_losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record any loss values in ``logs`` and refresh the plot."""
        if logs is None:
            return

        if "loss" in logs:
            self.train_losses.append(logs["loss"])
        if "eval_loss" in logs:
            self.eval_losses.append(logs["eval_loss"])

        # Nothing recorded yet: skip drawing instead of showing an empty figure and
        # calling legend() with no labelled artists.
        if not (self.train_losses or self.eval_losses):
            return

        # live update plot
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(8, 5))
        if self.train_losses:
            ax.plot(self.train_losses, label="Train Loss")
        if self.eval_losses:
            ax.plot(self.eval_losses, label="Validation Loss")
        ax.set_xlabel("Logging Steps / Epochs")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss")
        ax.legend()
        plt.show()
        # Release the figure explicitly so repeated log events do not accumulate figures.
        plt.close(fig)



In [64]:
from trl import SFTTrainer

# Callbacks that collect / plot the losses reported by the trainer.
loss_logger = LossLoggerCallback()
# NOTE(review): loss_plot_callback is created but never passed to the trainer below.
loss_plot_callback = LiveLossPlotCallback()


# NOTE(review): `model` may already be PEFT-wrapped by the LoRA setup cell while peft_config
# is passed again here — confirm the adapter is not expected to be applied a second time.
# NOTE(review): eval_dataset is supplied, but training_args sets no evaluation strategy.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    processing_class=processor.tokenizer,
    callbacks=[loss_logger],          # attach callback

)
# Run fine-tuning; in a notebook the returned TrainOutput is displayed as the cell result.
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 0, 'bos_token_id': 0, 'pad_token_id': 0}.
  return fn(*args, **kwargs)


Step,Training Loss
9,3.542913
18,0.369685
27,0.207489
36,0.16145
45,0.102865


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


TrainOutput(global_step=45, training_loss=0.8768803238868713, metrics={'train_runtime': 4065.0825, 'train_samples_per_second': 0.173, 'train_steps_per_second': 0.011, 'total_flos': 9.407768244091085e+16, 'train_loss': 0.8768803238868713, 'epoch': 5.0})

In [65]:
# Release cached allocator blocks, wait for pending GPU work, then report memory usage.
torch.cuda.empty_cache()
torch.cuda.synchronize()

bytes_per_gb = 1024**3
for label, query in (
    ("allocated", torch.cuda.memory_allocated),
    ("reserved", torch.cuda.memory_reserved),
):
    print(f"GPU {label} memory: {query() / bytes_per_gb:.2f} GB")



GPU allocated memory: 4.63 GB
GPU reserved memory: 7.19 GB


In [66]:
# Persist the trained model to the same directory the validation cells load from.
trainer.save_model(training_args.output_dir)


## Validation

In [67]:
# Reload the model saved by trainer.save_model(...) for validation.
# `dtype` replaces the deprecated `torch_dtype` keyword (the library warns about it) and
# matches the other from_pretrained calls in this notebook.
base_model = AutoModelForVision2Seq.from_pretrained(
    training_args.output_dir,
    device_map="auto",
    dtype=torch.bfloat16,
    _attn_implementation=None,
)

# The processor is loaded from the original hub id rather than from the checkpoint folder.
processor = AutoProcessor.from_pretrained(model_id)

`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/815 [00:00<?, ?it/s]

In [68]:
# Wrap the reloaded model with the adapter stored in the training output directory.
# NOTE(review): when USE_LORA is False nothing is assigned here, so `model` keeps whatever
# object an earlier cell left in it rather than `base_model` — confirm this is intended.
if USE_LORA:
    from peft import PeftModel
    model = PeftModel.from_pretrained(base_model, training_args.output_dir)



In [78]:
from contextlib import nullcontext
def generate_text_from_sample(model, processor, sample, max_new_tokens=100, device="cuda"):
    """Generate the model's answer for one chat-formatted sample.

    Args:
        model: object exposing ``generate(**inputs, max_new_tokens=...)``.
        processor: object providing ``apply_chat_template``, ``__call__`` and ``batch_decode``.
        sample: conversation ``[system, user, assistant]``; the image is read from
            ``sample[1]['content'][0]['image']``.
        max_new_tokens: generation budget passed to ``model.generate``.
        device: device the processed inputs are moved to.

    Returns:
        ``(truth, prediction)`` where ``truth`` is the assistant turn (``sample[2]``) and
        ``prediction`` the decoded continuation, or ``(None, None)`` when the sample has
        no usable image.
    """
    # Prompt = system + user turns only, ending with the generation prompt.
    text_input = processor.apply_chat_template(sample[:2], add_generation_prompt=True)
    truth = sample[2]

    image = sample[1]['content'][0]['image']
    if image is None or image.mode is None:
        return None, None
    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

    # One prompt with one image, moved to the requested device.
    model_inputs = processor(
        text=text_input,
        images=[[rgb_image]],
        return_tensors="pt",
    ).to(device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Keep only the newly generated tokens by cutting off each prompt prefix.
    continuations = []
    for prompt_ids, full_ids in zip(model_inputs.input_ids, generated_ids):
        continuations.append(full_ids[len(prompt_ids):])

    decoded = processor.batch_decode(
        continuations,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return truth, decoded[0]




In [77]:
from tqdm import tqdm
import os



# Folder and CSV that accumulate the generated predictions.
path = os.path.join(PROJECT_ROOT, "results", "ibm-granite")
# Create the folder *before* entering it. Previously os.chdir ran first (failing on a fresh
# Drive) and os.makedirs(path) without exist_ok raised when the folder existed but the CSV
# did not.
os.makedirs(path, exist_ok=True)
os.chdir(path)

csv_path = os.path.join(path, "generated_output.csv")
if not os.path.exists(csv_path):
    # Start the file with just the header row so it can be appended to later.
    with open(csv_path, "w", encoding="utf-8") as f:
        f.write("id,truth,pred\n")


# Rows already written; used to skip samples that were processed in an earlier run.
df = pd.read_csv(csv_path)
# print(df)


# Loop through test dataset with progress bar
def collect_results(csv_path, dataset, df):
    """Generate predictions for ``dataset`` and append them to ``csv_path``.

    Args:
        csv_path: CSV file (columns id, truth, pred) that rows are appended to.
        dataset: indexable collection of samples. Each item is either a chat-formatted
            conversation (as built by ``format_data``) or a raw row mapping with
            "image"/"caption" keys, which is converted with ``format_data``.
        df: DataFrame of rows already present in ``csv_path``; sample indices listed in
            its ``id`` column are skipped.

    Samples without a usable image, or whose generation raises, are skipped.
    """
    # `i in df['id']` tests the Series *index*, not its values; compare against the
    # stored ids instead.
    done_ids = set(df['id'].tolist())

    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)  # one writer for the whole run
        for i in tqdm(range(len(dataset)), desc="Generating predictions"):
            if i in done_ids:
                continue

            # Read from the dataset that was passed in (previously the global
            # test_dataset was used regardless of the argument).
            sample = dataset[i]
            if isinstance(sample, dict):
                sample = format_data(sample)

            # Generate prediction
            try:
                truth, pred = generate_text_from_sample(model, processor, sample)
            except Exception as exc:
                # Best effort: report and move on instead of silently hiding the failure.
                print(f"Skipping sample {i}: {exc!r}")
                continue

            if truth is None:
                # generate_text_from_sample signals "no usable image" with (None, None).
                continue

            # Write output
            writer.writerow([i, truth['content'][0]['text'], pred])



## Eval

In [95]:
# EVAL on graph type: line, scatter, and bar
# grab annotations
import os
import pandas as pd
from datasets import Dataset, Image
from collections import defaultdict
import json
import csv
PROJECT_ROOT = "/content/drive/MyDrive/dl-project"

# Convert the validation split's JSONL annotations into one CSV row per image.
jsonl_file_path = os.path.join(PROJECT_ROOT, "My-First-Project-2/valid/", "annotations.jsonl")
# image name -> {question (prefix): answer (suffix)}
data = {}

with open(jsonl_file_path, 'r', encoding='utf-8') as f:

    for line in f:
        try:
            json_obj = json.loads(line)
            image_name = json_obj.get("image")
            prefix = json_obj.get("prefix")
            suffix = json_obj.get("suffix")

            # Collect every question/answer pair of the same image in one dict.
            if image_name not in data:
                data[image_name] = {}
            data[image_name][prefix] = suffix
        except json.JSONDecodeError:
            # Malformed lines are reported and skipped.
            print(f"Skipping invalid JSON line: {line.strip()}")
            continue

# print(data)

# The CSV below is written and read with a relative name, so switch into the split folder.
os.chdir(os.path.join(PROJECT_ROOT,"My-First-Project-2/valid/"))

with open("annotations_validation.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["image_filename", "max", "min", "range", "title", "domain"])

    # Each CSV column is filled from the answer to one fixed question; a missing
    # question yields an empty cell.
    for img, info in data.items():
        writer.writerow([
            img,
            info.get("What is the maximum?", ""),
            info.get("What is the minimum?", ""),
            info.get("What is the range of the y-axis? Format as min-max (No spaces)", ""),
            info.get("What is the title?", ""),
            info.get("What is the domain?", "")
        ])

# NOTE(review): df_truth is not used again in this cell; the same CSV is re-read into `df` below.
df_truth = pd.read_csv("annotations_validation.csv")


import os
from collections import defaultdict
import pandas as pd
from datasets import Dataset, Image as HFImage

# Point annot_root at the validation split and load the CSV written above.
# NOTE(review): this rebinds annot_root and df, which earlier cells use for the training split.
annot_root = os.path.join(PROJECT_ROOT, "My-First-Project-2", "valid")
df = pd.read_csv(os.path.join(annot_root, "annotations_validation.csv"))

# caption column: flatten the annotation fields into the single ground-truth string
df['caption'] = df.apply(
    lambda row: f"max: {row['max']}, min: {row['min']}, range: {row['range']}, title: {row['title']}, domain: {row['domain']}",
    axis=1
)

# --- collect image filenames by type ---
def list_images(path):
    """Return the set of entry names in ``path`` that end with an image extension.

    The extension check is case-insensitive; names are returned unchanged.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")
    names = set()
    for entry in os.listdir(path):
        if entry.lower().endswith(image_suffixes):
            names.add(entry)
    return names

# One sub-folder per graph type inside the validation split.
scatter_path = os.path.join(annot_root, "scatter")
bar_path     = os.path.join(annot_root, "bar")
line_path    = os.path.join(annot_root, "line")

scatter_images = list_images(scatter_path)
bar_images     = list_images(bar_path)
line_images    = list_images(line_path)

# filename -> (graph_type, abs_path); a later type overwrites an earlier one on name clashes
image_info = {}
for kind, folder, names in (
    ("scatter", scatter_path, scatter_images),
    ("bar", bar_path, bar_images),
    ("line", line_path, line_images),
):
    for f in names:
        image_info[f] = (kind, os.path.join(folder, f))

# Report CSV filenames that were not found in any of the three folders.
for fname in df["image_filename"].unique():
    if fname not in image_info:
        print("Missing image file:", fname)

# Annotate each row with its graph type and absolute image path (None when unknown).
def _graph_type_of(fn):
    return image_info.get(fn, (None, None))[0]

def _abs_path_of(fn):
    return image_info.get(fn, (None, None))[1]

df["graph_type"] = df["image_filename"].apply(_graph_type_of)
df["abs_image_path"] = df["image_filename"].apply(_abs_path_of)

# Convert to a HF Dataset and expose the path column as a decoded "image" feature.
hf_dataset = Dataset.from_pandas(df)
hf_dataset = hf_dataset.cast_column("abs_image_path", HFImage())
hf_dataset = hf_dataset.rename_column("abs_image_path", "image")

# Per-type subsets, each converted to chat-formatted conversations.
def _conversations_of_type(kind):
    subset = hf_dataset.filter(lambda x: x["graph_type"] == kind)
    return [format_data(row) for row in subset]

scatter_dataset = _conversations_of_type("scatter")
bar_dataset = _conversations_of_type("bar")
line_dataset = _conversations_of_type("line")



Filter:   0%|          | 0/40 [00:00<?, ? examples/s]

Filter:   0%|          | 0/40 [00:00<?, ? examples/s]

Filter:   0%|          | 0/40 [00:00<?, ? examples/s]

In [96]:
# Folder that receives one prediction CSV per graph type.
results_root = os.path.join(PROJECT_ROOT, "results", "ibm-granite")
os.makedirs(results_root, exist_ok=True)  # ensure directory exists

graph_datasets = {
    "scatter": scatter_dataset,
    "bar": bar_dataset,
    "line": line_dataset,
}

# Loop-local names are used on purpose: the previous loop variables `dataset` and `df`
# overwrote the notebook-level training dataset / dataframe of the same names.
for graph_type, graph_samples in graph_datasets.items():
    csv_path = os.path.join(results_root, f"base_model_generated_output_{graph_type}.csv")

    if os.path.exists(csv_path):
        done_rows = pd.read_csv(csv_path)
    else:
        # initialize empty CSV with desired columns
        done_rows = pd.DataFrame(columns=["id", "truth", "pred"])
        done_rows.to_csv(csv_path, index=False, encoding="utf-8")

    # NOTE(review): the file name says "base_model", but the predictions come from whatever
    # `model` currently holds — confirm the intended model is loaded before running this cell.
    collect_results(csv_path, graph_samples, done_rows)

Generating predictions: 100%|██████████| 8/8 [00:53<00:00,  6.70s/it]
Generating predictions: 100%|██████████| 16/16 [01:49<00:00,  6.82s/it]
Generating predictions: 100%|██████████| 16/16 [01:49<00:00,  6.83s/it]


In [86]:
# Root of the synthetic dataset. A dedicated name is used so the notebook-wide PROJECT_ROOT
# is no longer overwritten (other cells derive their paths from it).
syn_root = "/content/drive/MyDrive/dl-project/synthetic_dataset"

jsonl_file_path = os.path.join(syn_root, "annotations.jsonl")
data = {}

with open(jsonl_file_path, 'r', encoding='utf-8') as f:
    for line in f:
        try:
            json_obj = json.loads(line)
        except json.JSONDecodeError:
            print(f"Skipping invalid JSON line: {line.strip()}")
            continue

        # Group every question (prefix) -> answer (suffix) pair under its image name.
        image_name = json_obj.get("image")
        data.setdefault(image_name, {})[json_obj.get("prefix")] = json_obj.get("suffix")

# Write the CSV next to the annotations instead of into whatever the current working
# directory happens to be after earlier os.chdir calls.
syn_csv_path = os.path.join(syn_root, "syn_annotations.csv")

with open(syn_csv_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["image_filename", "max", "min", "range", "title", "domain"])

    for img, info in data.items():
        writer.writerow([
            img,
            info.get("What is the maximum?", ""),
            info.get("What is the minimum?", ""),
            info.get("What is the range of the y-axis? Format as min-max (No spaces)", ""),
            info.get("What is the title?", ""),
            info.get("What is the domain?", "")
        ])

df = pd.read_csv(syn_csv_path)

df['caption'] = df.apply(
    lambda row: f"max: {row['max']}, min: {row['min']}, range: {row['range']}, title: {row['title']}, domain: {row['domain']}",
    axis=1
)
image_base_dir_abs = syn_root
# NOTE(review): [:-4] slices the Series (drops its last 4 rows); it does not trim filename
# characters. The last 4 rows end up without an image path. Kept as-is; confirm the intent.
df['absolute_image_path'] = df['image_filename'].apply(
    lambda x: os.path.join(image_base_dir_abs, x)
)[:-4]
df = df[['absolute_image_path', 'caption']]

hf_dataset = Dataset.from_pandas(df)
hf_dataset = hf_dataset.cast_column("absolute_image_path", Image())
hf_dataset = hf_dataset.rename_column("absolute_image_path", "image")

# Chat-formatted conversations, the shape generate_text_from_sample indexes into
# (sample[1]['content'][0]['image'], sample[2]); raw rows do not have that structure.
syn_dataset = [format_data(x) for x in hf_dataset]





{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=501x501>,
 'caption': 'max: 3.101652, min: -1.750746, range: -2-4, title: Steps walked per day, domain: Society'}

In [87]:
# Prediction file for the synthetic dataset.
csv_path = os.path.join(results_root, "generated_output_synthetic.csv")

if not os.path.exists(csv_path):
    # First run: create the CSV with only the header columns.
    df = pd.DataFrame(columns=["id", "truth", "pred"])
    df.to_csv(csv_path, index=False, encoding="utf-8")
else:
    # Later runs: load the rows written so far.
    df = pd.read_csv(csv_path)

collect_results(csv_path, syn_dataset, df)


Generating predictions: 100%|██████████| 60/60 [12:18<00:00, 12.30s/it]


In [91]:
# Replace `model` with the checkpoint loaded straight from model_id (no adapter loading in
# this cell), using the same quantization config as for training.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    dtype=torch.bfloat16,
    quantization_config=bnb_config,
    _attn_implementation=None,
)



Loading weights:   0%|          | 0/815 [00:00<?, ?it/s]

In [90]:
# Prediction file for the base (non fine-tuned) model on the synthetic dataset.
csv_path = os.path.join(results_root, "base_model_generated_output_synthetic.csv")

if not os.path.exists(csv_path):
    # First run: create the CSV with only the header columns.
    df = pd.DataFrame(columns=["id", "truth", "pred"])
    df.to_csv(csv_path, index=False, encoding="utf-8")
else:
    # Later runs: load the rows written so far.
    df = pd.read_csv(csv_path)

collect_results(csv_path, syn_dataset, df)

Generating predictions: 100%|██████████| 60/60 [06:17<00:00,  6.30s/it]
