# Steam Palette Extractor
Reference: https://github.com/woctezuma/steam-palette-extractor

## Install Python packages

In [None]:
%pip install -qq img2dataset mediapy tqdm

## Download images from Steam (only the first time)

In [None]:
# Default image file name appended to each app's CDN path.
IMG_NAME = 'capsule_616x353.jpg'

def get_image_url(app_id, img_name=IMG_NAME):
  """Return the Steam CDN URL of the image `img_name` for the app `app_id`.

  `app_id` may be an int or a string; it is interpolated into the URL as-is.
  """
  image_url = f'https://cdn.cloudflare.steamstatic.com/steam/apps/{app_id}/{img_name}'
  return image_url

def write_to_text_file(app_ids, fname, img_name=IMG_NAME):
  """Write the image URL of every app in `app_ids` to `fname`, one per line.

  The file is opened in 'w' mode, so any previous content is replaced.
  URLs are written in the iteration order of `app_ids`.
  """
  with open(fname, 'w') as f:
    urls = (get_image_url(app_id, img_name) for app_id in app_ids)
    f.writelines(f'{url}\n' for url in urls)

In [None]:
import json

# Default name of the JSON file that holds the app ids.
APPID_FNAME = "appids.json"

def get_app_ids(fname=APPID_FNAME):
  """Load the app ids stored in a JSON file and return them as strings.

  Args:
    fname: path of the JSON file to read. Defaults to APPID_FNAME, so
      existing calls without arguments keep reading the same file.

  Returns:
    A list with `str(item)` for every item obtained by iterating over the
    top-level JSON value (expected to be an array of app ids).

  Raises:
    OSError: if the file cannot be opened.
    json.JSONDecodeError: if the file does not contain valid JSON.
  """
  # JSON text is UTF-8 by specification; do not depend on the locale default.
  with open(fname, encoding='utf-8') as f:
    return [str(app_id) for app_id in json.load(f)]

In [None]:
!curl -OL https://github.com/woctezuma/steam-palette-extractor/releases/download/games/{APPID_FNAME}

In [None]:
# Load every app id and write the matching image URLs to 'myimglist.txt',
# the file named by --url_list in the img2dataset command of the next cell.
app_ids = get_app_ids()
write_to_text_file(app_ids, fname='myimglist.txt')

In [None]:
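# The download process took ~30 minutes.
# Out of 95,800 images, 92,249 were successfully downloaded.
# The output folder uses ~8 GB of disk space.
# NOTE: the commands below are prefixed with `echo`, so they are only printed.
# Remove `echo` to actually run the download and create the archive.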
!echo img2dataset --url_list=myimglist.txt --output_folder=steam_images --resize_mode=no
!echo zip -r steam_images.zip steam_images

## Check the content of the image folder

In [None]:
import glob

IMG_FOLDER = "steam_images"

def get_test_fnames(image_folder, file_ext='.jpg'):
  """Return the sorted list of paths matching `<image_folder>/*<file_ext>`.

  `image_folder` is used verbatim as the start of a glob pattern, so it may
  itself contain wildcards.
  """
  pattern = ''.join((image_folder, '/*', file_ext))
  matches = glob.glob(pattern)
  matches.sort()
  return matches

# An extra '/*' level is appended to IMG_FOLDER, so the effective pattern is
# 'steam_images/*/*.jpg': .jpg files exactly one directory below IMG_FOLDER.
test_fnames = get_test_fnames(f'{IMG_FOLDER}/*')
print(f'#images = {len(test_fnames)}')

## Utils

In [None]:
# Reference: https://stackoverflow.com/questions/3241929/python-find-dominant-most-common-color-in-an-image/61730849#61730849

def get_dominant_colors(pil_img, palette_size=16, num_colors=10):
    """Return the most frequent colors of an image after palette reduction.

    Args:
        pil_img: a PIL image; it is not modified (a copy is processed).
        palette_size: number of colors of the adaptive palette the image is
            reduced to.
        num_colors: number of colors to return.

    Returns:
        A list of exactly `num_colors` entries, each a 3-item list of palette
        channel values, ordered from the most to the least frequent palette
        entry. When the reduced image holds fewer than `num_colors` distinct
        colors (e.g. `num_colors > palette_size`), the list is padded with
        `[0, 0, 0]`.
    """
    # Work on a shrunken copy to speed up processing; thumbnail() is applied
    # to the copy so that the caller's image is left untouched.
    img = pil_img.copy()
    img.thumbnail((100, 100))

    # The adaptive palette conversion rejects modes such as RGBA (e.g. PNG
    # files with an alpha channel). Convert only the modes that would fail,
    # so that results for RGB, L, and P inputs are unchanged.
    if img.mode not in ('RGB', 'L', 'P'):
        img = img.convert('RGB')

    # Reduce the image to `palette_size` colors.
    paletted = img.convert('P', palette=Image.ADAPTIVE, colors=palette_size)

    # getcolors() yields (count, palette_index) pairs: sort by count, largest
    # first, to rank the palette entries by how often they occur.
    palette = paletted.getpalette()
    color_counts = sorted(paletted.getcolors(), reverse=True)

    wanted = max(num_colors, 0)
    dominant_colors = []
    for _count, palette_index in color_counts[:wanted]:
        start = palette_index * 3
        dominant_colors.append(palette[start:start + 3])

    # Pad with black so that the output always has `num_colors` rows.
    missing = wanted - len(dominant_colors)
    dominant_colors.extend([0, 0, 0] for _ in range(missing))

    return dominant_colors

In [None]:
import mediapy as media

from PIL import Image

def extract_colors(path_or_url, num_colors=10):
  """Read an image from a path or URL and return its dominant colors.

  The image is read with mediapy, wrapped into a PIL image, and handed to
  `get_dominant_colors`, whose result is returned unchanged.
  """
  pixels = media.read_image(path_or_url)
  return get_dominant_colors(Image.fromarray(pixels), num_colors=num_colors)

In [None]:
import mediapy as media

from PIL import ImageColor, Image, ImageDraw

# Reference: https://stackoverflow.com/questions/54165439/what-are-the-exact-color-names-available-in-pils-imagedraw

def show_colors(c, cols=None):
  """Display the colors of `c` as a grid of filled rectangles.

  Args:
    c: sequence of RGB triples; each one is converted with `tuple()` and used
      as the fill color of one cell. Cells are laid out row by row.
    cols: number of cells per row. When None (the default), the global
      NUM_COLORS is used, looked up at call time as before.
  """
  if cols is None:
    cols = NUM_COLORS

  cell_height = 30
  cell_width = 170
  # Number of rows needed to fit len(c) cells with `cols` cells per row.
  rows = ((len(c) - 1) // cols) + 1

  canvas = Image.new("RGB", (cell_width * cols, cell_height * rows), (0, 0, 0))
  draw = ImageDraw.Draw(canvas)

  for idx, rgb in enumerate(c):
    row, col = divmod(idx, cols)
    x0 = cell_width * col
    y0 = cell_height * row
    # Integer pixel coordinates (the former `cellWidth / 1` produced a float).
    box = [x0, y0, x0 + cell_width, y0 + cell_height]
    draw.rectangle(box, fill=tuple(rgb), outline='black')

  media.show_image(canvas)

In [None]:
from colorsys import rgb_to_hsv

def to_hsv(r, g, b):
  """Convert an RGB color with 0-255 channels to HSV scaled to 0-255.

  Each channel is divided by 255 before calling `colorsys.rgb_to_hsv`; every
  component of the result is multiplied by 255 and truncated with `int()`.
  Returns a tuple `(h, s, v)` of ints.
  """
  # Reference: https://stackoverflow.com/a/37656972/376454
  unit_rgb = (r / 255, g / 255, b / 255)
  return tuple(int(component * 255) for component in rgb_to_hsv(*unit_rgb))

## Run

In [None]:
import torch

# Number of dominant colors extracted per image; show_colors also reads it as
# the number of cells per row.
NUM_COLORS = 8

# First image: extract its palette, display it, and keep the HSV version of
# the palette (one row per color) in `v`.
path_or_url = "https://cdn1.epicgames.com/offer/d5241c76f178492ea1540fce45616757/Free-Game-3-teaser_1920x1080-56b6434f5564766a6dc7a1d7fada8c18"
dominant_colors = extract_colors(path_or_url, num_colors=NUM_COLORS)
show_colors(dominant_colors)
v = torch.tensor([to_hsv(*rgb) for rgb in dominant_colors])
print(v)

# Second image (Steam image of app 1267910): same steps, HSV palette in `w`.
path_or_url = get_image_url(1267910)
dominant_colors = extract_colors(path_or_url, num_colors=NUM_COLORS)
show_colors(dominant_colors)
w = torch.tensor([to_hsv(*rgb) for rgb in dominant_colors])
print(w)

In [None]:
# Caveat: convert the HSV values before computing the distance!
# https://stackoverflow.com/a/39113477/376454
import math

def _hsv_to_cartesian(hsv):
  """Map rows of (h, s, v), each scaled to 0-255, to (cos(h)*s, sin(h)*s, v).

  The hue is rescaled from 0-255 to radians before taking cos/sin, and the
  result is a new float tensor, so fractional coordinates are not truncated
  (writing them back into the integer input tensor would truncate them).
  """
  hsv = hsv.float()
  angle = hsv[:, 0] * (2 * math.pi / 255)
  saturation = hsv[:, 1]
  return torch.stack(
      [torch.cos(angle) * saturation, torch.sin(angle) * saturation, hsv[:, 2]],
      dim=1,
  )

# As before, `v` and `w` hold the converted coordinates after this cell.
v = _hsv_to_cartesian(v)
w = _hsv_to_cartesian(w)

# Pairwise Euclidean distances: one row per color of v, one column per color of w.
distance = torch.cdist(v, w)
distance

# For each color of v, keep the distance to its closest color in w, then
# average the `top_k` smallest of these distances.
top_k = 3
score = torch.topk(distance.min(dim=1)[0], k=top_k, largest=False)[0].mean()
score

In [None]:
from pathlib import Path

app_ids = get_app_ids()

# The stem of every image file name is parsed as an integer and used as a
# position in `app_ids`.
filtered_indices = [int(Path(fname).stem) for fname in test_fnames]
filtered_app_ids = [app_ids[index] for index in filtered_indices]

# Save both lists as JSON, in the same order as `test_fnames`.
outputs = (
    ('filtered_indices.json', filtered_indices),
    ('filtered_appids.json', filtered_app_ids),
)
for out_fname, payload in outputs:
  with open(out_fname, 'w') as f:
    json.dump(payload, f)

In [None]:
from tqdm import tqdm

# This extraction process takes ~ 20 minutes.

# Allocate the palette tensor only when it is missing or has the wrong shape:
# a fresh kernel no longer fails with a NameError, and re-running this cell
# after an interruption resumes instead of starting over.
palette_shape = (len(test_fnames), NUM_COLORS, len("RGB"))
if 'd' not in globals() or tuple(d.shape) != palette_shape:
  d = torch.zeros(palette_shape, dtype=torch.int64)

# Wrap the list itself (not `enumerate`) so that tqdm knows the total.
for i, fname in enumerate(tqdm(test_fnames)):
  # An all-zero row is treated as "not extracted yet".
  if torch.all(d[i]==0):
    dominant_colors = extract_colors(fname, num_colors=NUM_COLORS)
    d[i] = torch.tensor(dominant_colors)

In [None]:
# Persist the palette tensor `d`; the file name embeds NUM_COLORS.
torch.save(d, f"steam_palette_{NUM_COLORS}.pth")