In [None]:
from dotenv import load_dotenv
from pathlib import Path
import sys


# Make the parent directory importable (so `src.*` modules resolve) without
# appending a duplicate `sys.path` entry every time this cell is re-run.
_repo_root = Path("..").resolve().as_posix()
if _repo_root not in sys.path:
    sys.path.append(_repo_root)
# Load environment variables from a `.env` file via python-dotenv.
_ = load_dotenv()

In [None]:
from pydantic import BaseModel, HttpUrl
from typing import List, Optional
import requests
from datasets import Dataset, Image
from collections import defaultdict
import torch
import json

In [None]:
# Input JSONL of tweets (one `Tweet` per line) and the directory that
# downloaded images are written to. `images_path` is absolute, and the
# per-image paths built from it are what get stored in the dataset.
raw_tweets_path = Path("fashion_twitter_raw") / "raw_fashion_twitter.jsonl"
images_path = Path("fashion_twitter_raw", "images").resolve()

## Create the dataset

In [None]:
class Recommendations(BaseModel):
    """The nested ``recommendations`` object attached to each tweet."""

    # NOTE: field order here is the key order of `model_dump_json` output,
    # which is the string stored in the dataset's "json" column.
    # Required list of problem strings.
    problems: List[str]
    # Required list of fix strings.
    fixes: List[str]
    # Optional: may be absent (defaults to an empty list) or explicitly null.
    # Pydantic copies non-hashable defaults per instance, so `[]` is not shared.
    positives: Optional[List[str]] = []
    # Required free-text advice.
    advice: str

class Tweet(BaseModel):
    """A single tweet record parsed from one line of the raw JSONL file."""

    # Used as the filename stem of the downloaded image (`<id>.jpg`).
    id: str
    # Validated as an HTTP(S) URL; the image is downloaded from here.
    image_url: HttpUrl
    description: str
    recommendations: Recommendations

In [None]:
# Parse one `Tweet` per non-blank line of the JSONL file. The encoding is
# explicit so non-ASCII text (e.g. emoji) decodes the same on every platform
# instead of depending on the OS default encoding.
with raw_tweets_path.open("r", encoding="utf-8") as f:
    tweets = [Tweet.model_validate_json(line) for line in f if line.strip()]

tweets[0]

In [None]:
# Preview the serialised training target: the tweet without `id`/`image_url`.
# Pydantic's `include` expects a set (or dict) of field names.
tweets[0].model_dump_json(include={"description", "recommendations"})

In [None]:
images_path.mkdir(exist_ok=True, parents=True)

# Download each tweet's image to `<images_path>/<tweet.id>.jpg`.
# - Images already on disk are skipped, so re-running the cell is cheap.
# - A timeout stops a stalled request from hanging the notebook indefinitely.
# - Network errors are reported per tweet instead of aborting the whole loop.
DOWNLOAD_TIMEOUT_S = 30
failed_downloads = []
for tweet in tweets:
    image_file = images_path.joinpath(f"{tweet.id}.jpg")
    if image_file.exists():
        continue
    try:
        response = requests.get(str(tweet.image_url), timeout=DOWNLOAD_TIMEOUT_S)
    except requests.RequestException as exc:
        failed_downloads.append(tweet.id)
        print(f"Failed to download image from {tweet.image_url}: {exc}")
        continue
    if response.status_code == 200:
        image_file.write_bytes(response.content)
        print(f"Downloaded {tweet.id}.jpg")
    else:
        failed_downloads.append(tweet.id)
        print(f"Failed to download image from {tweet.image_url} (HTTP {response.status_code})")

# Only claim success when every image is actually present.
if failed_downloads:
    print(f"{len(failed_downloads)} image(s) failed to download: {failed_downloads}")
else:
    print("All images downloaded.")

In [None]:
dataset_dict = defaultdict(list)

# One row per tweet: the local image path (cast to an `Image()` column when the
# Dataset is built) and the JSON training target. Tweets whose image is not on
# disk (e.g. a failed download) are skipped; otherwise the row would point at
# a non-existent file that only fails later, when the example is decoded.
skipped_ids = []
for tweet in tweets:
    image_file = images_path.joinpath(f"{tweet.id}.jpg")
    if not image_file.is_file():
        skipped_ids.append(tweet.id)
        continue
    dataset_dict["image"].append(image_file.as_posix())
    dataset_dict["json"].append(
        tweet.model_dump_json(include={"description", "recommendations"})
    )

if skipped_ids:
    print(f"Skipped {len(skipped_ids)} tweet(s) without an image: {skipped_ids}")
assert dataset_dict, "no tweets have a downloaded image"

# Summarise the columns instead of dumping every row.
{column: len(values) for column, values in dataset_dict.items()}

In [None]:
# Build a Hugging Face dataset from the collected columns, cast the "image"
# path column to the `datasets.Image` feature, and persist it to disk.
dataset = Dataset.from_dict(dataset_dict)
dataset = dataset.cast_column("image", Image())
dataset.save_to_disk("fashion_twitter")

## Create a data preset

In [None]:
PROMPT = "extract JSON."

class ImageJSONCollatorWithPadding:
    """Collate dataset examples into one padded processor batch.

    Each example must provide an ``"image"`` tensor and a ``"json"`` string.
    Every example gets the same text prompt (``PROMPT``); the target sequence,
    passed to the processor as ``suffix``, is the example's JSON serialised by
    :meth:`json2token`.
    """

    def __init__(self, processor):
        """Store the processor used to build the batch (text + images + suffix)."""
        self.processor = processor

    def __call__(self, examples):
        """Build a batch from a list of examples.

        Args:
            examples: Sequence of mappings with ``"json"`` (a JSON string) and
                ``"image"`` (a tensor whose first dimension is treated as the
                channel axis — NOTE(review): depends on the dataset format set
                elsewhere; confirm it is channel-first).

        Returns:
            Whatever ``self.processor`` returns when called with
            ``return_tensors="pt"`` and ``padding="longest"``.
        """
        labels = [self.json2token(json.loads(example["json"])) for example in examples]

        images = []
        for example in examples:
            image = example["image"]
            if image.shape[0] == 1:
                # Single-channel (grayscale) image: replicate it into 3 channels.
                image = torch.cat([image, image, image], dim=0)
            images.append(image)

        texts = [PROMPT] * len(examples)

        return self.processor(
            text=texts,
            images=images,
            suffix=labels,
            return_tensors="pt",
            padding="longest",
        )

    def json2token(self, obj, sort_json_key: bool = True):
        """Serialise a JSON value into a tagged token sequence (Donut-style).

        - dict: each entry becomes ``<s_{key}>value</s_{key}>``; a dict whose
          only key is ``"text_sequence"`` is returned verbatim.
        - list: items are serialised and joined with ``<sep/>``.
        - anything else: ``str(obj)``.

        The tags keep key names and list boundaries in the target; without them
        keys are dropped and list items run together (``["a", "b"]`` -> ``"ab"``),
        so the target cannot be mapped back to the JSON. The tags are not
        registered as special tokens in this cell, so they are presumably
        tokenised as ordinary text.

        Args:
            obj: A value as produced by ``json.loads``.
            sort_json_key: If True, dict keys are emitted in reverse-sorted
                order (deterministic); otherwise in insertion order.

        Returns:
            The serialised string.
        """
        if isinstance(obj, dict):
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
            return "".join(
                f"<s_{key}>" + self.json2token(obj[key], sort_json_key) + f"</s_{key}>"
                for key in keys
            )
        if isinstance(obj, list):
            return "<sep/>".join(self.json2token(item, sort_json_key) for item in obj)
        return str(obj)

In [None]:
from src.core.types import DataPreset

# Data preset pointing at the dataset saved above, with the JSON collator.
# `train_test_split=0.2` presumably holds out 20% for evaluation — TODO confirm
# against DataPreset.
json_image_preset = DataPreset(
    path="fashion_twitter",
    collator_cls=ImageJSONCollatorWithPadding,
    train_test_split=0.2,
)

## Train

In [None]:
from src.model_presets import paligemma_preset
from src.train_builder import build_trainer

In [None]:
# Merge the data preset and the model preset into the trainer's keyword
# arguments (a key present in both raises TypeError at the call).
data_kwargs = json_image_preset.as_kwargs()
model_kwargs = paligemma_preset.as_kwargs()
trainer = build_trainer(**data_kwargs, **model_kwargs)

In [None]:
# Run training; the returned value is shown as the cell output.
train_output = trainer.train()
train_output