# AestheticSorter

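This notebook sorts a folder of images by aesthetic rating. Each image is embedded with CLIP ViT-L/14 and scored by a pretrained aesthetic model. Copies of the images are then written to a separate folder, with their rank and score prefixed to the file name; the originals are left untouched. You can choose from these models: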
* laion_aesthetic by [LAION](https://x.com/laion_ai)
* laion_aesthetic_improved by [christophschuhmann](https://github.com/christophschuhmann)
* aesthetika by [pharmapsychotic](https://x.com/pharmapsychotic)

In [1]:
#@title Check GPU
# List the GPUs visible to this runtime. The code below falls back to the CPU
# when CUDA is unavailable, but CLIP scoring is much slower there.
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-f5156b3d-9a0c-ecbf-01bb-1cf4de9307e9)


In [2]:
#@title Mount Google Drive
# Mount Google Drive at /content/gdrive so that the default `images_folder` /
# `sorted_folder` paths in the sorting cell (under /content/gdrive/MyDrive)
# resolve to folders in your Drive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
#@title Install dependencies
!pip install requests safetensors tqdm
!pip install git+https://github.com/openai/CLIP.git

In [4]:
#@title Code definitions
import hashlib
import os
import shutil
from pathlib import Path
from typing import Optional

import clip
import requests
import torch
import torch.nn as nn
from PIL import Image
from safetensors import safe_open
from tqdm import tqdm

# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# CLIP model and its image transform. They start out empty and are filled in
# by the sorting cell (via `clip.load`) the first time it runs.
clip_model = None
preprocess = None

class BaseAestheticPredictor(nn.Module):
    """Base class for aesthetic-score heads whose weights are downloaded on demand.

    Subclasses set `URL` (location of the weight file) and optionally `HASH`
    (expected SHA-256 hex digest of that file). Their constructors then call
    `_download_if_needed()` to obtain a local path to the weights.
    """

    URL: str
    HASH: Optional[str] = None  # None disables checksum verification

    def __init__(self, cache_dir: str = ".cache"):
        super().__init__()
        self.cache_dir = Path(cache_dir)
        # parents=True so a nested cache path (e.g. "models/.cache") also works.
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _download_if_needed(self) -> Path:
        """Return the local path of the weight file, downloading it if absent.

        The cached file is named after the last URL path component, with the
        query string dropped and "+" replaced by "_". When `HASH` is set, the
        file's SHA-256 is checked on every call.

        Raises:
            requests.HTTPError: if the download responds with an error status.
            ValueError: if the file does not match `HASH`. The bad file is
                deleted so the next call downloads it again.
        """
        url_path = self.URL.split("?")[0]  # remove query params
        fname = self.cache_dir / Path(url_path).name.replace("+", "_")

        if not fname.exists():
            self._fetch(self.URL, fname)

        if self.HASH is not None:
            file_hash = self._sha256(fname)
            if file_hash != self.HASH:
                fname.unlink()
                raise ValueError(f"hash verification failed for {fname}")

        return fname

    @staticmethod
    def _fetch(url: str, dest: Path, chunk_size: int = 8192, timeout: float = 60) -> None:
        """Stream `url` into `dest` through a temporary ".part" file.

        `dest` only appears once the transfer has finished. An interrupted
        download therefore never leaves a truncated file that later calls
        would treat as cached.
        """
        tmp = dest.with_name(dest.name + ".part")
        try:
            # Using the response as a context manager releases the streamed
            # connection. `timeout` bounds the connect and each socket read.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(tmp, "wb") as f:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
            tmp.replace(dest)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise

    @staticmethod
    def _sha256(path: Path, chunk_size: int = 1 << 20) -> str:
        """Return the hex SHA-256 of the file at `path`, read in chunks to bound memory."""
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for block in iter(lambda: f.read(chunk_size), b""):
                digest.update(block)
        return digest.hexdigest()

class Aesthetika(BaseAestheticPredictor):
    """pharmapsychotic's "aesthetika" rating head over 768-d CLIP image embeddings.

    The weights come from a safetensors file that the base class downloads and
    SHA-256 checks.
    """

    URL = "https://pharmapsychotic.com/models/aesthetika_20240424.safetensors"
    HASH = "fad60cfdfb7857c22c8f1c99c05b463f5aded3766124bf92aa21b7633c035982"

    def __init__(self, cache_dir: str = ".cache"):
        super().__init__(cache_dir)
        hidden = nn.Linear(768, 1024)
        head = nn.Linear(1024, 1)
        # The parameter-free Identity occupies index 2, so `head` is registered
        # as `layers.3.*`. That presumably matches the checkpoint's key names
        # (load_state_dict below is strict) — TODO confirm against the file.
        self.layers = nn.Sequential(hidden, nn.ReLU(), nn.Identity(), head)
        weights_path = self._download_if_needed()
        # Tensors are read onto the CPU; moving to another device happens later via .to().
        with safe_open(weights_path, framework="pt", device="cpu") as reader:
            state_dict = {name: reader.get_tensor(name) for name in reader.keys()}
        self.load_state_dict(state_dict)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map embeddings whose last dimension is 768 to a score (last dimension 1)."""
        return self.layers(x)

class LaionAesthetic(BaseAestheticPredictor):
    """LAION's linear aesthetic predictor: a single 768 -> 1 layer over CLIP embeddings."""

    URL = "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_vit_l_14_linear.pth?raw=true"
    HASH = "2cd4e60f4f24ae3bcd57b847b13c1f3ba27edc28cc1a7f9ce74ee9f421243cba"

    def __init__(self, cache_dir: str = ".cache"):
        super().__init__(cache_dir)
        self.layers = nn.Sequential(nn.Linear(768, 1))
        # weights_only=True restricts unpickling to tensors and plain containers,
        # which matters because the checkpoint is downloaded from the internet.
        checkpoint = torch.load(self._download_if_needed(), weights_only=True)
        # The checkpoint stores the bare Linear's keys. Prefix them so they
        # address the first module of `self.layers` (the load is strict).
        self.load_state_dict({"layers.0." + key: tensor for key, tensor in checkpoint.items()})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map embeddings whose last dimension is 768 to a score (last dimension 1)."""
        return self.layers(x)

class LaionAestheticImproved(BaseAestheticPredictor):
    """christophschuhmann's improved aesthetic MLP (768 -> 1024 -> 128 -> 64 -> 16 -> 1)."""

    URL = "https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/sac+logos+ava1-l14-linearMSE.pth?raw=true"
    HASH = "21dd590f3ccdc646f0d53120778b296013b096a035a2718c9cb0d511bff0f1e0"

    def __init__(self, cache_dir: str = ".cache"):
        super().__init__(cache_dir)
        # The module order fixes the `layers.<index>` parameter names, which
        # must match the checkpoint because load_state_dict below is strict.
        # There are no activation layers between the Linear layers.
        modules = [
            nn.Linear(768, 1024), nn.Dropout(0.2),
            nn.Linear(1024, 128), nn.Dropout(0.2),
            nn.Linear(128, 64), nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*modules)
        checkpoint_path = self._download_if_needed()
        self.load_state_dict(torch.load(checkpoint_path, weights_only=True))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map embeddings whose last dimension is 768 to a score (last dimension 1)."""
        return self.layers(x)

def create_model(name: str, **kwargs) -> BaseAestheticPredictor:
    """Construct the aesthetic predictor registered under `name`.

    Args:
        name: "aesthetika", "laion_aesthetic" or "laion_aesthetic_improved".
        **kwargs: forwarded to the model constructor (e.g. `cache_dir`).

    Returns:
        A new predictor instance; its constructor loads the weights.

    Raises:
        ValueError: if `name` is not one of the registered models.
    """
    registry = {
        "aesthetika": Aesthetika,
        "laion_aesthetic": LaionAesthetic,
        "laion_aesthetic_improved": LaionAestheticImproved,
    }
    model_cls = registry.get(name)
    if model_cls is None:
        raise ValueError(f"Invalid model {name}")
    return model_cls(**kwargs)

def predict_rating(model, image_path):
    """Score one image file with `model` applied to its CLIP image embedding.

    Relies on the module-level `clip_model`, `preprocess` and `device`. CLIP
    must therefore be loaded (the sorting cell does this) before calling.

    Args:
        model: aesthetic predictor, already moved to `device`.
        image_path: path to an image file that PIL can open.

    Returns:
        The predicted rating as a Python float.
    """
    # The context manager closes the file handle promptly instead of leaving it
    # to GC. convert() returns a loaded in-memory copy that stays valid afterwards.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    pixels = preprocess(rgb).unsqueeze(0).to(device)
    with torch.inference_mode():
        image_embed = clip_model.encode_image(pixels)
        # NOTE(review): LAION's reference inference for its aesthetic predictors
        # L2-normalizes the CLIP embedding before the head. This code feeds the
        # raw embedding — confirm what each model here expects.
        aesthetic_rating = model(image_embed.float())
    return aesthetic_rating.item()

In [None]:
#@title Let's sort!
#@markdown Click Files folder on left then right click a folder and choose
#@markdown "Copy path" and paste that into `images_folder`. You can choose an
#@markdown existing path in your Google Drive or upload to a new folder.
#@markdown <br><br>

images_folder = "/content/gdrive/MyDrive/my_images" #@param {type:"string"}
sorted_folder = "/content/gdrive/MyDrive/my_images/sorted" #@param {type:"string"}
aesthetic_model = 'laion_aesthetic_improved' #@param ['aesthetika', 'laion_aesthetic', 'laion_aesthetic_improved']

# Score every image directly inside `images_folder` (subfolders are ignored).
# Copies go to `sorted_folder` as "<rank>_<rating>_<original name>", best
# first; the original files are never modified.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

# Extensions are matched case-insensitively (so "IMG_1.JPG" is included), only
# regular files are kept, and the list is sorted so the order is deterministic.
files = sorted(
    name for name in os.listdir(images_folder)
    if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
    and os.path.isfile(os.path.join(images_folder, name))
)
if not files:
    raise FileNotFoundError(f"No image files found in {images_folder}")

# CLIP is loaded once per kernel session and reused on later runs of this cell.
if clip_model is None:
    print("Loading CLIP model...")
    clip_model, preprocess = clip.load("ViT-L/14", device=device)

# calculate aesthetic ratings for all images
model = create_model(aesthetic_model)
model.eval()
model.to(device)
ratings = [
    predict_rating(model, os.path.join(images_folder, file))
    for file in tqdm(files, desc="Calculating aesthetic ratings")
]

# sort by aesthetic rating, highest first
ordering = sorted(range(len(ratings)), key=ratings.__getitem__, reverse=True)

print(f"Saving in sorted order to {sorted_folder}...")
os.makedirs(sorted_folder, exist_ok=True)
for rank, idx in enumerate(ordering):
    file = files[idx]
    rating = ratings[idx]
    base, ext = os.path.splitext(file)
    # `ext` from splitext already starts with "." — adding another one
    # produced names like "photo..png".
    dest = os.path.join(sorted_folder, f"{rank:04d}_{rating:.2f}_{base}{ext}")
    if os.path.exists(dest):
        os.remove(dest)
    shutil.copyfile(os.path.join(images_folder, file), dest)

print("Sorted order:")
print([os.path.splitext(files[idx])[0] for idx in ordering])


Calculating aesthetic ratings: 100%|██████████| 8/8 [00:00<00:00, 12.01it/s]


Saving in sorted order to /content/gdrive/MyDrive/my_images/sorted...
Sorted order:
['0004', '0005', '0003', '0008', '0006', '0007', '0009', '0002']
