## Training classifier

Using SmolVLM for extracting text and image embeddings

Can expand to Qwen2.5VM if needed

In [1]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:42:46_Pacific_Standard_Time_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0


In [1]:
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoModelForImageTextToText
from qwen_vl_utils import process_vision_info
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TrainingArguments, Trainer
from tqdm import tqdm
import json
import base64
import os

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
data_path = "./data/2. Sorted/Abandoned Trolleys/Other Trolleys/4156788.0.full.jpeg"

In [None]:
data_path = "./data/2. Sorted/Roads & Footprints/Faulty Steetlight/5174732.0.full.jpeg"

In [None]:
data_path = "./data/2. Sorted/Housing/Playground & Fitness Facilities Maintenance/4223509.0.full.jpeg"

## SmolVLM2 Instruct Model

In [2]:
model_path = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForImageTextToText.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2"
).to("cuda")

Fetching 2 files: 100%|██████████| 2/2 [05:38<00:00, 169.25s/it]


ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.

## Data Loader

In [None]:
# For example, a flattened version of your label tree:
ALL_CLASSES = [
    "Cold Storage", "FairPrice", "Giant", "Ikea", "Mustafa", "Other Trolleys", "ShengSiong",
    "Bird Issues", "Cat Issues", "Dead Animal", "Dog Issues", "Injured Animal", "Other Animal Issues",
    "Bulky Waste in Common Areas", "Dirty Public Areas", "High-rise Littering", "Overflowing Litter Bin",
    "Construction Noise",
    "Choked Drain or Stagnant Water", "Damaged Drain", "Flooding", "Sewage Smell", "Sewer Choke or Overflow",
    "No Water", "Water Leak", "Water Pressure", "Water Quality",
    "Common Area Maintenance", "HDB Car Park Maintenance", "Lightning Maintenance", "Playground & Fitness Facilities Maintenance",
    "HDB or URA Car Park", "Motorcycle at Void Deck", "Road",
    "Others",  # this is often used as a catch-all category
    "Fallen Tree or Branch", "Other Parks and Greenery Issues", "Overgrown Grass", "Park Facilities Maintenance", "Park Lighting Maintenance",
    "Bee & Hornets", "Cockroaches in Food Establishment", "Mosquitoes", "Rodents in Common Areas", "Rodents in Food Establishment",
    "Covered Linkway Maintenance", "Damaged Road Signs", "Faulty Streetlight", "Footpath Maintenance", "Road Maintenance",
    "Anywheel", "HelloRide", "Other Bicycles",
    "Food Premises", "Other Public Areas", "Parks & Park Connectors"
]

label_to_idx = {label: idx for idx, label in enumerate(ALL_CLASSES)}
NUM_CLASSES = len(ALL_CLASSES)

data_root = "./data/2. Sorted"  # Adjust to your data folder path
label_vector_file = "./label_vectors.json"  # Path to your label vector JSON file

# System prompt for the model
system_prompt = (
    "You are an expert in municipal services issues. Your task is to analyze the provided input, "
    "which may include an image and a description, and categorize the issue into one or more categories "
    "from the predefined list of municipal service issue types. Additionally, assess the severity of the issue "
    "as one of the following: Low, Medium, or High.\n\n"
    "The predefined list of categories is as follows:\n"
    + "\n".join(f"- {category}" for category in ALL_CLASSES) +
    "\n\nYour response should be in the following JSON format:\n"
    "{\n"
    "    \"categories\": [categories],\n"
    "    \"Severity\": severity\n"
    "}\n\n"
    "Ensure that the categories are selected from the provided list of issue types, and the severity is determined "
    "based on the details provided in the input."
)


### Think have to redo this a bit
Thought collate_fn in the finetuning example is the data loader. collate_fn which is kinda what i made here is the batching function. Data loader just returns the dataset raw data

In [None]:
# Define a training dataset
class FixMyStreetDataset(Dataset):
    def __init__(self, data_root, label_vector_file):
        """
        data_root: Folder containing the image and JSON files.
        label_vector_file: Path to the JSON file mapping report IDs to label lists.
        """
        self.data_root = data_root
        
        # Load label mapping (file names without extension as keys)
        with open(label_vector_file, 'r', encoding='utf-8') as f:
            self.label_vector = json.load(f)
            
        self.report_ids = list(self.label_vector.keys())

    def __len__(self):
        return len(self.report_ids)

    def __getitem__(self, idx):
        report_id = self.report_ids[idx]
        
        # Get label vector entry for the current report ID
        label_entry = self.label_vector[report_id]

        image_path = label_entry["image_path"]

        # Load the JSON metadata
        relative_path = os.path.relpath(image_path, start="./data/2. Sorted")  # Adjust to relative path
        json_path = os.path.join(self.data_root, relative_path)
        json_path = os.path.splitext(json_path)[0] + ".json"  # Replace the image extension with .json
        with open(json_path, 'r', encoding="utf-8") as f:
            metadata = json.load(f)

        # Prepare the text content
        text = metadata["description"] + "\n\n"
        text += "Nearby location tags: " + ", ".join([f"{k}: {v}" for tag in metadata["tags"]["nearby"] for k, v in tag.items()]) + "\n\n"
        text += "Enclosing location tags: " + ", ".join([f"{k}: {v}" for tag in metadata["tags"]["enclosing"] for k, v in tag.items()])
        
        # Check if the image exists
        if image_path:
            # Load the image
            image = Image.open(image_path)
            if image.mode != "RGB":
                image = image.convert("RGB")
        else:
            # Placeholder for missing image
            image = torch.zeros((3, 224, 224), dtype=torch.uint8)


        return text, image
        

In [None]:
# Batching function
def collate_fn(data):
    print(data)
    '''
    user_content = [{"type": "text", "text": input_text}]

    messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": answer}
        ]

        # Process the inputs using the processor
        text = processor.apply_chat_template(
            messages,
            add_generation_prompt=False,
        )
    '''

In [None]:
# Create dataset and dataloader
dataset = FixMyStreetDataset(data_root=data_root,
                             label_vector_file=label_vector_file)

data_loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [None]:
dir(dataset)

In [None]:
split_ds = dataset["validation"].train_test_split(test_size=0.5)
train_ds = split_ds["train"]

## Fine-tuning

In [None]:
model_name = model_path.split("/")[-1]

training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=1,
    optim="paged_adamw_8bit", # for 8-bit, keep this, else adamw_hf
    bf16=True, # underlying precision for 8bit
    output_dir=f"./{model_name}-vqav2",
    hub_model_id=f"{model_name}-vqav2",
    report_to="tensorboard",
    remove_unused_columns=False,
    gradient_checkpointing=True
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
)

## COMPLETE LABEL BULLSHIT

In [None]:
# System prompt for the model
system_prompt = (
    "You are an expert in municipal services issues. Your task is to analyze the provided input, "
    "which may include an image and a description, and categorize the issue into one or more categories "
    "from the predefined list of municipal service issue types. Additionally, assess the severity of the issue "
    "as one of the following: Low, Medium, or High.\n\n"
    "The predefined list of categories is as follows:\n"
    + "\n".join(f"- {category}" for category in ALL_CLASSES) +
    "\n\nYour response should be in the following JSON format:\n"
    "{\n"
    "    \"categories\": [categories],\n"
    "    \"Severity\": severity\n"
    "}\n\n"
    "Ensure that the categories are selected from the provided list of issue types, and the severity is determined "
    "based on the details provided in the input."
)

data_root = "./data/2. Sorted"
label_vector_file = "./label_vectors.json"

# iterate through label_vector_file and get the image path and json path
with open(label_vector_file, 'r', encoding='utf-8') as f:
    label_vector = json.load(f)

    for report_id, label_entry in label_vector.items():
        # Load the JSON metadata
        # get path by accessing image_path and replacing extension with .json
        json_path = os.path.splitext(label_entry["image_path"])[0] + ".json"
        with open(json_path, 'r', encoding="utf-8") as f:
            metadata = json.load(f)

        # Prepare the text content for the processor
        input_text = metadata["description"] + "\n\n"
        input_text += "Nearby location tags: " + ", ".join([f"{k}: {v}" for tag in metadata["tags"]["nearby"] for k, v in tag.items()]) + "\n\n"
        input_text += "Enclosing location tags: " + ", ".join([f"{k}: {v}" for tag in metadata["tags"]["enclosing"] for k, v in tag.items()])
        user_content = [{"type": "text", "text": input_text}]

        # Load the image
        image_path = label_entry["image_path"]
        if image_path:
            image = Image.open(image_path)
            if image.mode != "RGB":
                image = image.convert("RGB")
        else:
            image = torch.zeros((3, 224, 224), dtype=torch.uint8)


## Qwen2.5VM stuff (dont touch for now)

In [None]:
# @title inference function
def inference(image_path, prompt, sys_prompt="You are a helpful assistant.", max_new_tokens=4096, return_input=False):
    image = Image.open(image_path)
    image_local_path = "file://" + image_path
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": [
                {"type": "text", "text": prompt},
                {"image": image_local_path},
            ]
        },
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print("text:", text)
    # image_inputs, video_inputs = process_vision_info([messages])
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to('cuda')

    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    if return_input:
        return output_text[0], inputs
    else:
        return output_text[0]
    
#  base 64 编码格式
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

In [None]:
model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",device_map="auto")
processor = AutoProcessor.from_pretrained(model_path)

In [None]:
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)