In [1]:
import os
import torch
import torch.nn as nn
from os.path import expanduser  # pylint: disable=import-outside-toplevel
from urllib.request import urlretrieve  # pylint: disable=import-outside-toplevel
def get_aesthetic_model(clip_model="vit_l_14"):
    """load the aethetic model"""
    home = expanduser("~")
    cache_folder = home + "/.cache/emb_reader"
    path_to_model = cache_folder + "/sa_0_4_"+clip_model+"_linear.pth"
    if not os.path.exists(path_to_model):
        os.makedirs(cache_folder, exist_ok=True)
        url_model = (
            "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_"+clip_model+"_linear.pth?raw=true"
        )
        urlretrieve(url_model, path_to_model)
    if clip_model == "vit_l_14":
        m = nn.Linear(768, 1)
    elif clip_model == "vit_b_32":
        m = nn.Linear(512, 1)
    else:
        raise ValueError()
    s = torch.load(path_to_model)
    m.load_state_dict(s)
    m.eval()
    return m

In [2]:
import torchvision.transforms as transforms
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
import clip

# Load the aesthetic model
aesthetic_model = get_aesthetic_model()

# Load CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-L/14", device=device)

# Function to preprocess image
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])
    return transform(image).unsqueeze(0)

# Function to get aesthetic score
def get_aesthetic_score(image_path):
    image = preprocess_image(image_path).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        aesthetic_score = aesthetic_model(image_features.float()).item()
    return aesthetic_score


  Referenced from: <CFED5F8E-EC3F-36FD-AAA3-2C6C7F8D3DD9> /opt/anaconda3/lib/python3.11/site-packages/torchvision/image.so
  warn(
100%|███████████████████████████████████████| 890M/890M [00:26<00:00, 35.1MiB/s]


In [4]:
# List of image paths
image_paths = ["/Users/petrsushko/Desktop/local_photobench/clean/benchmarks/benchmark_images/input/jy4v49.jpg",
               "/Users/petrsushko/Desktop/local_photobench/clean/benchmarks/benchmark_images/magic_brush_output/jy4v49.jpg",
               "/Users/petrsushko/Desktop/local_photobench/clean/benchmarks/benchmark_images/reddit_output/gd10gef_2.png"]

# Get aesthetic scores for all images
for image_path in image_paths:
    score = get_aesthetic_score(image_path)
    rounded_score = round(score, 1)
    print(f"Aesthetic score for {image_path}: {rounded_score}")

Aesthetic score for /Users/petrsushko/Desktop/local_photobench/clean/benchmarks/benchmark_images/input/jy4v49.jpg: 42.4
Aesthetic score for /Users/petrsushko/Desktop/local_photobench/clean/benchmarks/benchmark_images/magic_brush_output/jy4v49.jpg: 28.4
Aesthetic score for /Users/petrsushko/Desktop/local_photobench/clean/benchmarks/benchmark_images/reddit_output/gd10gef_2.png: 44.4
