# 準備資料

In [1]:
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset_id = "AlignmentLab-AI/gpt4vsent4"
train_dataset = load_dataset(dataset_id,split="train")


In [3]:
train_dataset

Dataset({
    features: ['caption', 'id', 'url', 'similarity', 'prompt_3', 'response_3', 'idx', 'prompt_1', 'prompt_2', 'prompt_4', 'response_1', 'response_2', 'response_4'],
    num_rows: 1227
})

In [4]:
from PIL import Image
import requests
from io import BytesIO

In [32]:
from PIL import Image


def convert_to_rgb(image):
    """Convert image to RGB format if not already in RGB."""
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    return alpha_composite.convert("RGB")


def reduce_image_size(image, scale=0.5):
    """Reduce image size by a given scale."""
    original_width, original_height = image.size
    new_width = int(original_width * scale)
    new_height = int(original_height * scale)
    return image.resize((new_width, new_height))

def download_image(url):
    image = None
    if isinstance(url, str) and url.startswith("http"):
        try:
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"❌ Failed to download image: {url}\n  Reason: {e}")
            return None  # Skip this sample if the image fails to load
    return image  
def format_data(sample):
    image = download_image(sample["url"])
    if image:
        image = convert_to_rgb(image)  
        image = reduce_image_size(image)
    return {
        "messages": [
            {"role": "system", "content": "You are chatGPT, a large language model trained by OpenAI."},
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": sample["prompt_1"],
                    },
                    {
                        "type": "image",
                        "image": image,  
                    }
                ],
            },
            {"role": "assistant", "content": sample["response_1"]},
        ],
    }
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    # Iterate through each conversation
    
    for msg in messages:
        # Get content (ensure it's a list)
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        # Check each content element for images
        for element in content:
            if isinstance(element, dict) and (
                "image" in element or element.get("type") == "image"
            ):
                # Get the image and convert to RGB
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                image_inputs.append(image.convert("RGB"))
    return image_inputs

def dataset_clean_none_image(dataset):
    return [sample for sample in dataset if sample["messages"][0]["content"][1]["image"] is not None]
    

In [7]:
# Convert dataset to OAI messages
# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
dataset = [format_data(sample) for sample in train_dataset]
dataset = dataset_clean_none_image(dataset)

print(dataset[105]["messages"])

❌ Failed to download image: https://www.exelmansgalerie.be/wp-content/uploads/2014/10/noyelle-jos-exelmans-galerie-kunstgalerie-beeldentuin-belgie-462x490.jpg
  Reason: 403 Client Error: Forbidden for url: https://www.exelmansgalerie.be/wp-content/uploads/2014/10/noyelle-jos-exelmans-galerie-kunstgalerie-beeldentuin-belgie-462x490.jpg
❌ Failed to download image: https://cdn.xxl.thumbs.canstockphoto.com/canstock26184371.jpg
  Reason: HTTPSConnectionPool(host='cdn.xxl.thumbs.canstockphoto.com', port=443): Max retries exceeded with url: /canstock26184371.jpg (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)')))
❌ Failed to download image: https://media.istockphoto.com/photos/smiling-woman-sitting-in-chair-picture-id184321060
  Reason: 400 Client Error: Bad Request for url: https://media.istockphoto.com/photos/smiling-woman-sitting-in-chair-picture-id184321060
❌ Failed to download image: https:

TypeError: string indices must be integers

In [34]:
dataset[3]["messages"]

[{'role': 'system',
  'content': 'You are chatGPT, a large language model trained by OpenAI.'},
 {'role': 'user',
  'content': [{'type': 'text',
    'text': 'Describe this image in up to two paragraphs? Specify any objects within the image backgrounds scenery interactions and gestures or poses. If they are multiple of any object please specify how many. Is there text in the image and if so what does it say? If there is any lighting in the image can you identify where it is and what it looks like? What style is the image? If there are people or characters in the image what emotions are they conveying? Please keep your descriptions terse but complete. DO NOT add any unnecessary speculation about the things that are not part of the image such as DO NOT add things as these descriptions are interpretations and not a part of the image itself.'},
   {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=300x450>}]},
 {'role': 'assistant',
  'content': "The image displays a close-up p

In [9]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

In [10]:
# Hugging Face model id
model_id = "google/gemma-3-4b-pt" 

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it",use_fast=True)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.61s/it]


In [29]:
from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-product-description",     # directory to save and repository id
    num_train_epochs=1,                         # number of training epochs
    per_device_train_batch_size=1,              # batch size per device during training
    gradient_accumulation_steps=4,              # number of steps before performing a backward/update pass
    gradient_checkpointing=True,                # use gradient checkpointing to save memory
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=5,                            # log every 5 steps
    save_strategy="epoch",                      # save checkpoint every epoch
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                          # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",               # use constant learning rate scheduler
    push_to_hub=False,                           # push model to hub
    report_to="tensorboard",                    # report metrics to tensorboard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use reentrant checkpointing
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
)
args.remove_unused_columns = False # important for collator

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
    labels = batch["input_ids"].clone()

    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100

    batch["labels"] = labels
    return batch


In [12]:
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

In [30]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [31]:
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()

You are chatGPT, a large language model trained by OpenAI.
[{'type': 'text', 'text': 'Describe this image in up to two paragraphs? Specify any objects within the image backgrounds scenery interactions and gestures or poses. If they are multiple of any object please specify how many. Is there text in the image and if so what does it say? If there is any lighting in the image can you identify where it is and what it looks like? What style is the image? If there are people or characters in the image what emotions are they conveying? Please keep your descriptions terse but complete. DO NOT add any unnecessary speculation about the things that are not part of the image such as DO NOT add things as these descriptions are interpretations and not a part of the image itself.'}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=225x160 at 0x7FDB104518D0>}]
The image features a single adult person presented in a profile view against a plain, light-colored background. The person has 

AttributeError: 'NoneType' object has no attribute 'convert'

In [60]:
del model
#del trainer
torch.cuda.empty_cache()

In [6]:
import torch

# Load Model with PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
  args.output_dir,
  device_map="cuda:0",
  torch_dtype=torch.bfloat16,
  attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(args.output_dir)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.93it/s]


In [17]:
from transformers import GenerationConfig
generation_config = GenerationConfig.from_pretrained("google/gemma-3-4b-pt")
generation_config.cache_implementation = "dynamic"

In [54]:

img = Image.open(requests.get("https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg", stream=True).raw).convert("RGB")
messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "Describe this image in up to two paragraphs? Specify any objects within the image backgrounds scenery interactions and gestures or poses. If they are multiple of any object please specify how many. Is there text in the image and if so what does it say? If there is any lighting in the image can you identify where it is and what it looks like? What style is the image? If there are people or characters in the image what emotions are they conveying? Please keep your descriptions terse but complete. DO NOT add any unnecessary speculation about the things that are not part of the image such as DO NOT add things as these descriptions are interpretations and not a part of the image itself."}],
        },
        {
          "role": "user",
          "content": [
                {"type": "image", "image": img},

            ],
        },
    ]
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# Process the image and text
image_inputs = process_vision_info(messages)
# Tokenize the text and process the images
inputs = processor(
    text=[text],
    images=image_inputs,
    padding=True,
    return_tensors="pt",
)
# Move the inputs to the device
inputs = inputs.to(model.device)

# Generate the output
stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
generated_ids = model.generate(
    **inputs,
    max_new_tokens=100,
    temperature=0.9,
    top_p=0.95,
    do_sample=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    generation_config=generation_config,
)
input_len = inputs["input_ids"].shape[-1]
generation = generated_ids[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)




In [65]:
messages = [
{
    "role": "system",
    "content": [{"type": "text", "text":"what can you see"}],
},
{
  "role": "user",
  "content": [
        {"type": "image", "image": "https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg"},
    ],
},
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.3,
        top_p=0.95,
        do_sample=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id
    )
print(processor.tokenizer.decode(outputs[0], skip_special_tokens=True))


user
what can you see






model





























































































































































































































































































































































































































































































user


































