# Import images and fine-tune PaliGemma

In [1]:
from dotenv import load_dotenv
from pathlib import Path
import sys


# Put the parent of the current working directory on the import path
# (presumably so the local `training_toolkit` package resolves — confirm).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
parent_dir = Path("..").resolve().as_posix()
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Load environment variables via python-dotenv; the return value is bound to
# `_` so the notebook does not echo it as the cell output.
_ = load_dotenv()

In [None]:
# Output locations for the converted dataset, both under "fashion_twitter_raw".
_converted_root = Path("fashion_twitter_raw")
# Absolute directory that receives one "<tweet id>.jpg" file per downloaded image.
target_images_path = (_converted_root / "images").resolve()
# JSONL annotation file; intentionally left relative to the working directory.
target_annotations_path = _converted_root / "fashion_twitter_converted.jsonl"

## Convert annotations

In [None]:
from pydantic import BaseModel, HttpUrl, Field
from typing import List, Optional, Union
import requests
import json

In [None]:
# Input JSONL file with the raw tweets (parsed one line at a time further down).
raw_tweets_path = Path("fashion_twitter_raw").joinpath("fashion_twitter_raw.jsonl")


# Shape shared by the fields that may arrive as a single string or a list.
_TextOrTexts = Union[str, List[str]]


class Recommendations(BaseModel):
    """Recommendation fields nested inside a raw tweet record.

    Every field is optional and falls back to an empty list when it is missing
    from the input. `problems`, `fixes` and `positives` are lists of strings;
    `advice`, `item_description` and `wearing_suggestions` accept either a
    single string or a list of strings.
    """

    problems: List[str] = Field(default_factory=list)
    fixes: List[str] = Field(default_factory=list)
    positives: List[str] = Field(default_factory=list)
    advice: _TextOrTexts = Field(default_factory=list)
    item_description: _TextOrTexts = Field(default_factory=list)
    wearing_suggestions: _TextOrTexts = Field(default_factory=list)


class Tweet(BaseModel):
    """One raw record parsed from a line of the tweets JSONL file.

    NOTE(review): `description` is annotated `Optional[str]` without a default.
    Under pydantic v2 semantics (this notebook uses the v2 `model_validate_json`
    API) that makes the key required while still allowing a null value —
    confirm that this is the intended contract.
    """

    id: str  # also used as the image file stem ("<id>.jpg") by later cells
    image_url: HttpUrl  # URL the image is fetched from
    description: Optional[str]  # may be None
    recommendations: Recommendations  # nested Recommendations model


# Parse the raw tweets: one JSON document per line.
# - UTF-8 is requested explicitly; without it the platform's locale encoding is
#   used, which can mis-decode non-ASCII tweet text (e.g. on Windows).
# - The file object is iterated directly instead of materialising every line
#   with readlines().
# - Whitespace-only lines (such as a stray blank line at the end of the file)
#   are skipped rather than failing validation.
with raw_tweets_path.open("r", encoding="utf-8") as f:
    tweets = [Tweet.model_validate_json(line) for line in f if line.strip()]

In [None]:
# `advice`, `item_description` and `wearing_suggestions` may each hold a bare
# value instead of a list; wrap any non-list value in a one-element list so the
# three fields are always lists after this cell has run.
for tweet in tweets:
    recs = tweet.recommendations
    for field_name in ("advice", "item_description", "wearing_suggestions"):
        current = getattr(recs, field_name)
        if type(current) is not list:
            setattr(recs, field_name, [current])

In [None]:
# 1. merge advice, fixes, wearing suggestions
# 2. merge description and item description
# 3. convert recommendations into description, strengths, flaws, advice


def _as_text_list(value: Union[str, List[str], None]) -> List[str]:
    """Return `value` as a new list of strings.

    None becomes an empty list, a single string becomes a one-element list and
    any other value is copied into a list.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


class TargetLabel(BaseModel):
    """Label built from a tweet: descriptions, strengths, flaws and advice.

    Every field is a list of strings and defaults to an empty list.
    """

    descriptions: List[str] = Field(default_factory=list)
    strengths: List[str] = Field(default_factory=list)
    flaws: List[str] = Field(default_factory=list)
    advice: List[str] = Field(default_factory=list)

    @classmethod
    def from_tweet(cls, tweet: Tweet) -> "TargetLabel":
        """Build a label from a parsed tweet.

        Mapping:
          * descriptions = tweet description (if any) + item descriptions
          * strengths    = positives
          * flaws        = problems
          * advice       = fixes + advice + wearing suggestions (in that order)

        Args:
            tweet: Parsed tweet whose recommendations are merged into the label.

        Returns:
            The populated TargetLabel.
        """
        recs = tweet.recommendations
        return cls(
            # `tweet.description` is Optional: a None entry would fail the
            # List[str] validation, so a missing description is simply omitted.
            descriptions=_as_text_list(tweet.description)
            + _as_text_list(recs.item_description),
            strengths=recs.positives,
            flaws=recs.problems,
            # The Union[str, List[str]] fields are coerced here so this method
            # does not rely on another cell having normalised them to lists
            # (str + list concatenation would raise TypeError otherwise).
            advice=_as_text_list(recs.fixes)
            + _as_text_list(recs.advice)
            + _as_text_list(recs.wearing_suggestions),
        )

# Convert every tweet into its label; the bare expression on the last line
# makes the notebook display how many labels were produced.
labels = list(map(TargetLabel.from_tweet, tweets))
len(labels)

In [None]:
target_images_path.mkdir(exist_ok=True, parents=True)

# Upper bound (seconds) for connecting to / reading from a server. Without a
# timeout a single unresponsive host would block this cell indefinitely.
DOWNLOAD_TIMEOUT_SECONDS = 30

# Download each image, remembering which tweets could not be fetched.
failed_ids = []
for tweet in tweets:
    image_path = target_images_path.joinpath(f"{tweet.id}.jpg")

    # Re-running the cell should not fetch files that are already on disk;
    # empty files are treated as missing and downloaded again.
    if image_path.is_file() and image_path.stat().st_size > 0:
        print(f"Skipping {tweet.id}.jpg (already downloaded)")
        continue

    try:
        # HttpUrl is a pydantic URL object; pass requests a plain string.
        response = requests.get(str(tweet.image_url), timeout=DOWNLOAD_TIMEOUT_SECONDS)
    except requests.RequestException as error:
        # One unreachable URL must not abort the remaining downloads.
        print(f"Failed to download image from {tweet.image_url}: {error}")
        failed_ids.append(tweet.id)
        continue

    if response.status_code == 200:
        with image_path.open("wb") as f:
            f.write(response.content)
        print(f"Downloaded {tweet.id}.jpg")
    else:
        print(f"Failed to download image from {tweet.image_url}")
        failed_ids.append(tweet.id)

# Only claim success when every image was actually obtained.
if failed_ids:
    print(f"Finished with {len(failed_ids)} failed download(s) out of {len(tweets)} tweets.")
else:
    print("All images downloaded.")

In [None]:
# Write one JSON line per tweet: the image path plus the label serialised as a
# JSON string. Tweets whose image file is not on disk (for example because the
# download failed) are left out so the annotation file never points at a
# missing image.
skipped_count = 0
with target_annotations_path.open("w", encoding="utf-8") as f:
    for tweet, label in zip(tweets, labels):
        image_path = target_images_path.joinpath(f"{tweet.id}.jpg")
        if not image_path.is_file():
            skipped_count += 1
            continue
        json_dict = {
            "image": image_path.as_posix(),
            "json": label.model_dump_json(),
        }
        f.write(f"{json.dumps(json_dict)}\n")

if skipped_count:
    print(f"Skipped {skipped_count} tweet(s) without a downloaded image.")

## Import dataset

In [None]:
from training_toolkit import ImageJSONImporter

In [None]:
# Run the importer on the image directory and the annotation file (both passed
# as POSIX-style strings), then save the resulting dataset under
# "fashion_twitter".
image_importer = ImageJSONImporter()
dataset = image_importer(
    "fashion_twitter",
    target_images_path.as_posix(),
    target_annotations_path.as_posix(),
)
dataset.save_to_disk("fashion_twitter")

## Train the model

In [2]:
from training_toolkit import build_trainer, paligemma_image_preset, image_json_preset

A preset is a Pydantic model instance that contains default parameters for training.
We can access those parameters directly as attributes in order to change them, e.g. `paligemma_image_preset.training_args["num_train_epochs"] = 8`.

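To check exactly what goes into the trainer, we can install [Rich](https://github.com/Textualize/rich) (`pip install rich`) and pretty-print a preset's keyword arguments, e.g. `rich.print(paligemma_image_preset.as_kwargs())`.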

Set up the trainer by passing the necessary arguments into the `build_trainer` function.

In [None]:
# Override two training arguments on the model preset before building the trainer.
training_overrides = {
    "num_train_epochs": 8,
    "eval_strategy": "no",
}
for arg_name, arg_value in training_overrides.items():
    paligemma_image_preset.training_args[arg_name] = arg_value

# Keyword arguments from the data preset (pointed at "fashion_twitter", with the
# train/test split turned off) and from the model preset are combined into a
# single build_trainer call.
data_kwargs = image_json_preset.with_path("fashion_twitter").as_kwargs(apply_train_test_split=False)
model_kwargs = paligemma_image_preset.as_kwargs()

trainer = build_trainer(**data_kwargs, **model_kwargs)

In [None]:
# Start training with the trainer returned by `build_trainer`; as the last
# expression of the cell, the call's return value is displayed by the notebook.
trainer.train()