# Steam Palette Extractor
Reference: https://github.com/woctezuma/steam-palette-extractor

## Install Python packages

In [None]:
%pip install -qq img2dataset mediapy tqdm

## Download images from Steam (only the first time)

In [None]:
# Default image file requested for every app (the Steam "capsule" artwork).
IMG_NAME = 'capsule_616x353.jpg'

def get_image_url(app_id, img_name=IMG_NAME):
  """Build the Steam CDN URL of image `img_name` for the app `app_id`."""
  url = f'https://cdn.cloudflare.steamstatic.com/steam/apps/{app_id}/{img_name}'
  return url

def write_to_text_file(app_ids, fname, img_name=IMG_NAME):
  """Write the image URL of each app ID to `fname`, one URL per line.

  The file is overwritten. URLs are built with get_image_url(app_id, img_name).
  """
  url_lines = (f'{get_image_url(app_id, img_name)}\n' for app_id in app_ids)
  with open(fname, 'w') as output_file:
    output_file.writelines(url_lines)

In [None]:
import json

APPID_FNAME = "appids.json"

def get_app_ids():
  with open(APPID_FNAME) as f:
    return [str(app_id) for app_id in json.load(f)]

In [None]:
!curl -OL https://github.com/woctezuma/steam-palette-extractor/releases/download/games/{APPID_FNAME}

In [None]:
# Write one capsule-image URL per app to myimglist.txt, the file passed as
# --url_list to the img2dataset command below.
app_ids = get_app_ids()
write_to_text_file(app_ids, fname='myimglist.txt')

In [None]:
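# The commands below are prefixed with `echo`, so running this cell only prints
# them; remove `echo` to actually download and zip the images.
# The download process took ~30 minutes.
# Out of 95,800 images, 92,249 were successfully downloaded.
# The output folder uses ~8 GB of disk space.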
!echo img2dataset --url_list=myimglist.txt --output_folder=steam_images --resize_mode=no
!echo zip -r steam_images.zip steam_images

## Check the content of the image folder

In [None]:
# Outputs of the filtering step: for every image found on disk, its position in
# the app-ID list, and the corresponding Steam app ID (same order in both files).
FILTERED_INDICES_FNAME = 'filtered_indices.json'
FILTERED_APP_IDS_FNAME = 'filtered_appids.json'

In [None]:
import glob

IMG_FOLDER = "steam_images"

def get_test_fnames(image_folder, file_ext = '.jpg'):
  return sorted(glob.glob(image_folder +'/*' + file_ext))

test_fnames = get_test_fnames(f'{IMG_FOLDER}/*')
print(f'#images = {len(test_fnames)}')

In [None]:
from pathlib import Path

app_ids = get_app_ids()

# Each image's file name (without extension) is read as a 0-based position in
# `app_ids`, which maps the downloaded files back to their Steam app IDs.
filtered_indices = [int(Path(fname).stem) for fname in test_fnames]
filtered_app_ids = [app_ids[index] for index in filtered_indices]

for out_fname, payload in (
    (FILTERED_INDICES_FNAME, filtered_indices),
    (FILTERED_APP_IDS_FNAME, filtered_app_ids),
):
  with open(out_fname, 'w') as f:
    json.dump(payload, f)

## Utils

In [None]:
# Reference: https://stackoverflow.com/questions/3241929/python-find-dominant-most-common-color-in-an-image/61730849#61730849

def get_dominant_colors(pil_img, palette_size=16, num_colors=10):
    """Return the `num_colors` most frequent colors of a PIL image.

    The image is shrunk to fit in 100x100 pixels, then reduced to an adaptive
    palette of `palette_size` colors; palette entries are ranked by the number
    of pixels that use them.

    Args:
      pil_img: PIL image. It is converted to RGB first, so non-RGB modes
        (e.g. 'L', 'RGBA', 'CMYK') are handled; the caller's image is not
        modified.
      palette_size: number of colors of the quantized palette.
      num_colors: number of colors to return.

    Returns:
      A list of `num_colors` [r, g, b] lists, most frequent first. If the
      quantized image has fewer distinct colors, the list is padded with
      [0, 0, 0].
    """
    # convert() returns a new image (a plain copy when already RGB), so the
    # in-place thumbnail() below never touches the caller's image. Forcing RGB
    # means the palette quantization always sees the same input mode.
    img = pil_img.convert('RGB')
    img.thumbnail((100, 100))  # shrink to speed up processing

    # Reduce colors with Pillow's adaptive palette.
    paletted = img.convert('P', palette=Image.ADAPTIVE, colors=palette_size)

    palette = paletted.getpalette()
    # getcolors() yields (pixel_count, palette_index) pairs; most frequent first.
    color_counts = sorted(paletted.getcolors(), reverse=True)

    dominant_colors = [
        palette[palette_index * 3:palette_index * 3 + 3]
        for _, palette_index in color_counts[:max(num_colors, 0)]
    ]
    # Pad with black so callers always receive `num_colors` entries.
    dominant_colors += [[0, 0, 0] for _ in range(num_colors - len(dominant_colors))]

    return dominant_colors

In [None]:
import mediapy as media

from PIL import Image

def extract_colors(path_or_url, num_colors=10):
  """Read an image from a local path or a URL and return its dominant colors.

  The result is the list of [r, g, b] lists produced by get_dominant_colors
  (with its default palette size) for `num_colors` colors.
  """
  pixels = media.read_image(path_or_url)
  return get_dominant_colors(Image.fromarray(pixels), num_colors=num_colors)

In [None]:
from PIL import ImageColor, ImageDraw

# Reference: https://stackoverflow.com/questions/54165439/what-are-the-exact-color-names-available-in-pils-imagedraw

def show_colors(c, cols=None):
  """Display a sequence of RGB colors as a grid of 170x30-pixel swatches.

  Args:
    c: sequence of RGB triplets (each turned into a fill color with `tuple`).
    cols: number of swatches per row. When omitted, the global NUM_COLORS is
      used (looked up at call time), which keeps existing calls unchanged.
  """
  if cols is None:
    cols = NUM_COLORS

  n = len(c)
  rows = -(-n // cols)  # ceil(n / cols); 0 rows for an empty sequence
  cell_height = 30
  cell_width = 170

  canvas = Image.new("RGB", (cell_width * cols, cell_height * rows), (0, 0, 0))
  draw = ImageDraw.Draw(canvas)

  for idx, rgb in enumerate(c):
    y0 = cell_height * (idx // cols)
    x0 = cell_width * (idx % cols)
    draw.rectangle(
        [x0, y0, x0 + cell_width, y0 + cell_height],
        fill=tuple(rgb),
        outline='black',
    )

  media.show_image(canvas)

In [None]:
from colorsys import rgb_to_hsv

def to_hsv(r, g, b):
  """Convert 0-255 RGB components to 0-255 HSV components, truncated to int."""
  # Reference: https://stackoverflow.com/a/37656972/376454
  normalized = (r / 255, g / 255, b / 255)
  return tuple(int(component * 255) for component in rgb_to_hsv(*normalized))

In [None]:
import torch

def to_linear_hsv(dominant_colors, change_coordinates=True):
  """Convert RGB colors to HSV, optionally in Cartesian (x, y, value) form.

  Args:
    dominant_colors: RGB triplets with 0-255 components, either a list of
      lists or a tensor with one row per color.
    change_coordinates: if True, scale HSV to [0, 1] and map (hue, saturation)
      from polar to Cartesian coordinates, so that Euclidean distances respect
      the circular nature of hue.

  Returns:
    A tensor with one row per color: integer 0-255 (h, s, v) values, or float
    (x, y, v) values when `change_coordinates` is True.
  """
  if torch.is_tensor(dominant_colors):
    # Work on plain Python numbers: colorsys then runs in float64 exactly as for
    # list inputs, instead of slow per-scalar float32 tensor arithmetic that
    # may round differently from the reference palettes.
    dominant_colors = dominant_colors.tolist()

  # reshape keeps an empty palette 2-D so the column indexing below works.
  v = torch.tensor([to_hsv(*rgb) for rgb in dominant_colors]).reshape(-1, 3)

  # Caveat: convert the HSV values before computing the distance!
  # https://stackoverflow.com/a/39113477/376454

  if change_coordinates:
    v = v.float() / 255

    theta = 2 * torch.pi * v[:, 0]  # hue as an angle
    radius = v[:, 1]  # saturation as a radius

    x = radius * torch.cos(theta)
    y = radius * torch.sin(theta)

    v[:, 0] = x
    v[:, 1] = y

  return v

## Compute the palette for each Steam game

In [None]:
# Number of dominant colors kept per image; also part of the palette file name.
NUM_COLORS = 8
PALETTE_FNAME = f"steam_palette_{NUM_COLORS}.pth"

In [None]:
from tqdm import tqdm

extract_from_scratch = False

if extract_from_scratch:
  # This extraction process takes ~ 20 minutes.

  # One palette per image: (num_images, NUM_COLORS, 3) integer RGB values.
  d = torch.zeros(len(test_fnames), NUM_COLORS, len("RGB"), dtype=int)

  # enumerate() has no length, so pass `total` for tqdm to show progress/ETA.
  for i, fname in tqdm(enumerate(test_fnames), total=len(test_fnames)):
    # Only rows still all-zero are extracted (every row of a freshly zeroed d).
    if torch.all(d[i] == 0):
      dominant_colors = extract_colors(fname, num_colors=NUM_COLORS)
      d[i] = torch.tensor(dominant_colors)

  torch.save(d, PALETTE_FNAME)

## Load pre-computed data

In [None]:
!curl -OL https://github.com/woctezuma/steam-palette-extractor/releases/download/colors/{FILTERED_APP_IDS_FNAME}
!curl -OL https://github.com/woctezuma/steam-palette-extractor/releases/download/colors/{PALETTE_FNAME}

In [None]:
import json

# App IDs of the games with a pre-computed palette. The matching loop pairs
# filtered_app_ids[i] with d[i], so this list must follow the palette order.
with open(FILTERED_APP_IDS_FNAME) as f:
  filtered_app_ids = json.load(f)

In [None]:
import torch

# The palette file is downloaded from the internet: weights_only=True restricts
# unpickling to tensors and plain containers instead of arbitrary objects.
d = torch.load(PALETTE_FNAME, weights_only=True)

## Load data intended to evaluate the results

In [None]:
SOLUTIONS_FNAME = "egs_solutions.json"
POPULAR_APPIDS_FNAME = "popular_appids.json"

In [None]:
!curl -OL https://github.com/woctezuma/steam-palette-extractor/releases/download/solutions/{SOLUTIONS_FNAME}
!curl -o {POPULAR_APPIDS_FNAME} -L https://github.com/woctezuma/steam-popular-appids/releases/download/data/{APPID_FNAME}

In [None]:
import json

# JSON text is UTF-8 (RFC 8259); an explicit encoding avoids depending on the
# platform's default locale encoding (e.g. cp1252 on Windows).
with open(SOLUTIONS_FNAME, encoding='utf-8') as f:
  egs_solutions = json.load(f)

# App IDs are normalized to strings, as get_app_ids does for the main list.
with open(POPULAR_APPIDS_FNAME, encoding='utf-8') as f:
  popular_appids = [str(app_id) for app_id in json.load(f)]

In [None]:
def to_steam_url(app_id):
  """Build a Steam image URL: egs_solutions' steam 'url' + app_id + 'suffix'."""
  steam = egs_solutions["image"]["steam"]
  prefix, suffix = steam['url'], steam['suffix']
  return f"{prefix}{app_id}{suffix}"

def to_egs_url(index, md5):
  """Build an EGS image URL from a gift's index and md5 hash.

  The first entry of egs_solutions["image"]["egs"]["keyword"] is used when
  index == 1, the second entry for every other index.
  """
  egs = egs_solutions["image"]["egs"]
  keyword_position = 0 if index == 1 else 1
  keyword = egs["keyword"][keyword_position]
  return f"{egs['url']}{index}{keyword}{egs['resolution']}{md5}"

In [None]:
def from_gift_to_steam_url(gift, app_id_index=0):
  """Return the Steam image URL of the `app_id_index`-th app ID of `gift`."""
  return to_steam_url(gift["appids"][app_id_index])

def from_gift_to_egs_url(gift):
  """Return the EGS image URL of `gift`, built from its "index" and "md5"."""
  return to_egs_url(gift["index"], gift["md5"])

# Preview the Steam and EGS image URLs of one gift.
gift_index = 13
gift = egs_solutions["gift"][gift_index]
print(from_gift_to_steam_url(gift))
print(from_gift_to_egs_url(gift))

## Run the workflow

In [None]:
def to_score(minimal_distances, indices, exponent):
  """Sum nearest-neighbor distances, each weighted by (1 + index) ** exponent.

  `indices` are the positions of the matched colors in the other palette, so
  matches against later palette entries weigh more when exponent > 0.
  """
  weights = (indices + 1) ** exponent
  return (minimal_distances * weights).sum()

def compute_distance_between_palettes(v, w, top_k = 3, max_components=8, exponent = 1.00):
  """Symmetric, rank-weighted nearest-neighbor distance between two palettes.

  Args:
    v, w: tensors with one color per row (cast to float).
    top_k, max_components: accepted for call compatibility; not used here.
    exponent: exponent of the rank weight applied by to_score.

  Returns:
    A 0-d tensor: the score of w's colors matched in v plus the score of
    v's colors matched in w.
  """
  distances = torch.cdist(v.float(), w.float())

  # Nearest color of v for every color of w (reduce over v's rows)...
  score_for_w = to_score(*distances.min(dim=0), exponent)
  # ...and nearest color of w for every color of v, to make it symmetrical.
  score_for_v = to_score(*distances.min(dim=1), exponent)

  return score_for_w + score_for_v

In [None]:
# The algorithm is much slower when using HSV.
use_HSV = True  # compare palettes in HSV space (via to_linear_hsv) instead of RGB
change_coordinates = True  # map (hue, saturation) to Cartesian x, y before comparing

# NOTE(review): both are passed to compute_distance_between_palettes, whose
# current body does not use them.
max_components = 8
top_k = min(max_components, 6)

In [None]:
exponent = 1.0 # TODO: choose a value; rank weight (1 + index) ** exponent in to_score.
# Exponent = 0.25 ---> rank 59 for Ghostrunner
# Exponent = 0.50 ---> rank 41 for Ghostrunner :D
# Exponent = 0.75 ---> rank 47 for Ghostrunner
# Exponent = 1.00 ---> rank 81 for Ghostrunner
# Exponent = 1.50 ---> rank 166 for Ghostrunner

# Exponent = 0.25 ---> rank ?? for Escape Academy
# Exponent = 0.50 ---> rank 105 for Escape Academy
# Exponent = 0.75 ---> rank 62 for Escape Academy
# Exponent = 1.00 ---> rank 57 for Escape Academy
# Exponent = 1.50 ---> rank ?? for Escape Academy

### Define the target

In [None]:
# Use the EGS image of one gift as the query palette.
gift_index = 14
gift = egs_solutions["gift"][gift_index]
path_or_url = from_gift_to_egs_url(gift)

dominant_colors = extract_colors(path_or_url, num_colors=NUM_COLORS)
reference_colors = (
    to_linear_hsv(dominant_colors, change_coordinates)
    if use_HSV
    else torch.tensor(dominant_colors)
)

media.show_image(media.read_image(path_or_url), width=616)
show_colors(dominant_colors)

### Check the ground truth

In [None]:
# Use the first appID listed for the gift. If the game has another edition
# (e.g. GOTY), another app_id_index may be worth trying.
app_id_index = 0
ground_truth_app_id = gift["appids"][app_id_index]

path_or_url = get_image_url(ground_truth_app_id)
dominant_colors = extract_colors(path_or_url, num_colors=NUM_COLORS)
ground_truth_colors = (
    to_linear_hsv(dominant_colors, change_coordinates)
    if use_HSV
    else torch.tensor(dominant_colors)
)

distance = compute_distance_between_palettes(
    reference_colors, ground_truth_colors, top_k, max_components, exponent,
)

print(f'\tappID: {ground_truth_app_id} ; distance: {distance:.2f}')
show_colors(dominant_colors)
media.show_image(media.read_image(path_or_url))

### Compare the target with every Steam game

In [None]:
import torch

from tqdm import tqdm

best_app_id = None
best_distance = None
distance_dict = {}  # appID -> distance between its palette and the target's

# d[i] is the palette of filtered_app_ids[i]; `total` lets tqdm show progress.
for i, app_id in tqdm(enumerate(filtered_app_ids), total=len(filtered_app_ids)):
  dominant_colors = d[i]

  if use_HSV:
    # Same coordinate setting as the reference palette, so both palettes live
    # in the same space even when change_coordinates is False.
    test_colors = to_linear_hsv(dominant_colors, change_coordinates)
  else:
    # d[i] is already a tensor: torch.tensor() would only copy it and warn.
    test_colors = dominant_colors

  # Stored as a Python float: sorting ~90k floats later is much cheaper than
  # comparing 0-d tensors, and the ordering is unchanged.
  distance = float(compute_distance_between_palettes(
      reference_colors,
      test_colors,
      top_k,
      max_components,
      exponent,
      ))

  distance_dict[app_id] = distance

  if best_distance is None or distance < best_distance:
    best_app_id = app_id
    best_distance = distance

    path_or_url = get_image_url(app_id)
    print(f'\tappID: {app_id} ; distance: {distance:.2f} ; url: {path_or_url}')
    media.show_image(media.read_image(path_or_url))

In [None]:
# App IDs from most to least similar palette (smallest distance first).
most_similar_app_ids = sorted(distance_dict, key=distance_dict.get)

if ground_truth_app_id:
  ground_truth_key = str(ground_truth_app_id)
  if ground_truth_key in distance_dict:
    # 1-based, matching the numbering of the ranked list displayed below.
    ground_truth_rank = most_similar_app_ids.index(ground_truth_key) + 1
    print(f"Ground truth (appID = {ground_truth_app_id}) is ranked n°{ground_truth_rank}.")
  else:
    # Not every app has a palette (e.g. its image failed to download).
    print(f"Ground truth (appID = {ground_truth_app_id}) has no palette and cannot be ranked.")

In [None]:
MAX_NUM_DISPLAYED_IMAGES = 25
DISPLAYED_IMAGE_WIDTH = 300

# Show the closest matches, best first, numbered from 1.
top_app_ids = most_similar_app_ids[:MAX_NUM_DISPLAYED_IMAGES]
for i, app_id in enumerate(top_app_ids, start=1):
  distance, path_or_url = distance_dict[app_id], get_image_url(app_id)
  print(f'\t{i}) appID: {app_id} ; distance: {distance:.2f} ; url: {path_or_url}')
  media.show_image(media.read_image(path_or_url), width=DISPLAYED_IMAGE_WIDTH)