In [1]:
import sys
import torch
import pandas as pd
from IPython.display import display

sys.path.append("../lib")

device = torch.device("cuda")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

### Step 1: Loading the model

We prepared our custom Steam dataset with over 80,000 games and corresponding images for each of them. We can read this dataset from the `data` directory.

In [2]:
import msgspec
from steam_model import SteamGame
from pathlib import Path

data_root = Path("./data/steam")
steam_games = data_root / "games.json"

games = msgspec.json.decode(steam_games.read_text(), type=list[SteamGame])

excluded_genres = set([
    "Animation & Modeling",
    "Audio Production",
    "Design & Illustration",
    "Photo Editing",
    "Software Training",
    "Utilities",
    "Video Production",
    "Web Publishing"
])

dataset = [
    {
        "id": game.id,
        "title": game.page_information.title,
        "image": image,
        "genres": ",".join(game.page_information.genres)
    }
    for game
    in games[:10000]
    if len(game.page_information.images) > 0
        and game.popularity_rank < 5000
        and len(game.page_information.genres) > 0
        and next((genre for genre in game.page_information.genres if genre in excluded_genres), None) is None
    for image in game.page_information.images
]
df = pd.DataFrame.from_records(dataset)
df.head()

Unnamed: 0,id,title,image,genres
0,553850,HELLDIVERS™ 2,images/553850/ss_0c79f56fc7be1bd0102f2ca1c92c8...,Action
1,553850,HELLDIVERS™ 2,images/553850/ss_33e684e9cb2517af1599f0ca2b57d...,Action
2,553850,HELLDIVERS™ 2,images/553850/ss_8949ed7dd24a02d5ea13b08fc5c04...,Action
3,553850,HELLDIVERS™ 2,images/553850/ss_50afbbc4d811c38fe9f64c1fc8d7e...,Action
4,553850,HELLDIVERS™ 2,images/553850/ss_cb276fe9f0b09683bdbc496f82b40...,Action


Let's define a few convenience functions to easily get an image for a given game:

In [3]:
import typing
from transformers import CLIPProcessor, CLIPModel

model_id = "openai/clip-vit-large-patch14"
processor_model_id = model_id

processor = typing.cast(CLIPProcessor, CLIPProcessor.from_pretrained(processor_model_id))
model = typing.cast(CLIPModel, CLIPModel.from_pretrained(model_id))
model

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05,

In [4]:
import PIL.Image
from typing import TypedDict
from torch.utils.data import Dataset

class SteamGenreItem(TypedDict):
    image: PIL.Image.Image
    genres: str

class SteamGenresDataset(Dataset):
    def __init__(self, root: Path, dataset: pd.DataFrame, processor: CLIPProcessor):
        self.root = root
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int) -> SteamGenreItem:
        item = self.dataset.iloc[idx]
        image_path = self.root / item["image"]
        image = PIL.Image.open(image_path)
        genres = item["genres"]

        return { "image": image, "genres": genres, }

def collate_fn(batch: list[SteamGenreItem]):
    def toinp(text: str) -> torch.Tensor:
        input_ids = processor.tokenizer(text, padding="max_length", truncation=True).input_ids
        return input_ids

    def topix(image: PIL.Image.Image) -> torch.Tensor:
        pixel_values = processor.image_processor(image, return_tensors="pt").pixel_values
        return pixel_values.squeeze()

    return {
        'pixel_values': torch.stack([topix(x["image"]) for x in batch]),
        'input_ids': torch.tensor([toinp(x["genres"]) for x in batch]),
        "return_loss": True
    }

In [5]:
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Split our data into training and validation sets
train_df, eval_df = train_test_split(df, test_size=0.2, random_state=42)

# Create and verify that the dataset works as we expect
train_dataset = SteamGenresDataset(root=data_root, dataset=train_df, processor=processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)
eval_dataset = SteamGenresDataset(root=data_root, dataset=eval_df, processor=processor)
eval_dataloader = DataLoader(eval_dataset, shuffle=True, batch_size=1)

In [6]:
collate_fn([train_dataset[0]])

{'pixel_values': tensor([[[[-1.7631, -1.7923, -1.7923,  ..., -1.6171, -1.6171, -1.6317],
           [-1.7923, -1.7923, -1.7923,  ..., -1.6171, -1.6171, -1.6317],
           [-1.7923, -1.7923, -1.7923,  ..., -1.6171, -1.6171, -1.6317],
           ...,
           [-1.7631, -1.7631, -1.7485,  ..., -1.7193, -1.7193, -1.7193],
           [-1.7631, -1.7777, -1.7777,  ..., -1.7193, -1.7193, -1.7193],
           [-1.7339, -1.7777, -1.7777,  ..., -1.7339, -1.7339, -1.7193]],
 
          [[-1.1668, -1.1818, -1.1668,  ..., -1.3769, -1.3769, -1.3919],
           [-1.1818, -1.1518, -1.1368,  ..., -1.3769, -1.3769, -1.3919],
           [-1.1668, -1.1218, -1.1518,  ..., -1.3769, -1.3769, -1.3919],
           ...,
           [ 0.1539,  0.2890,  0.2589,  ..., -1.6621, -1.6621, -1.6621],
           [ 0.1989,  0.3340,  0.3190,  ..., -1.6621, -1.6621, -1.6621],
           [ 0.1839,  0.2740,  0.2439,  ..., -1.6771, -1.6771, -1.6621]],
 
          [[-0.5559, -0.5275, -0.5133,  ..., -1.0110, -1.0110, -1.0252

In [7]:
from transformers import TrainingArguments, Trainer, IntervalStrategy

training_args = TrainingArguments(
    output_dir="steam_genre_classifier",
    remove_unused_columns=False,
    evaluation_strategy=IntervalStrategy.STEPS,
    save_strategy=IntervalStrategy.STEPS,
    save_total_limit=10,
    eval_steps=250,
    save_steps=250,
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    tf32=True,
    bf16=True,
    optim="adamw_bnb_8bit",
    # Comment these out when testing new models to check for CUDA OOM errors quicker, but for debugging, it just makes things more difficult
    torch_compile=True,
    torch_compile_backend="inductor",
    torch_compile_mode="reduce-overhead"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    tokenizer=processor.image_processor
)


In [None]:
trainer.train()