In [None]:
!pip install git+https://github.com/huggingface/diffusers.git@refs/pull/9126/head#egg=diffusers transformers peft accelerate opencv-python sentencepiece
!pip install fastapi uvicorn nest-asyncio pyngrok --quiet

import diffusers
from tqdm import tqdm
import torch
from PIL import Image
import random
from typing import Dict, Tuple
import nest_asyncio
import uvicorn
from fastapi import FastAPI, HTTPException
from pyngrok import ngrok
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from uuid import uuid4
from fastapi.responses import FileResponse
import threading

# Art media used to fill the "a {medium} of a ..." slot of the randomly
# generated prompts. Order is significant for reproducible random choices.
MEDIUMS = [
    "painting", "drawing", "photograph", "HD photo",
    "illustration", "portrait", "sketch", "3d render",
    "digital painting", "concept art", "screenshot", "canvas painting",
    "watercolor art", "print", "mosaic", "sculpture",
    "cartoon", "comic art", "anime",
]

# Subjects used to fill the trailing "{subject}" slot of the randomly
# generated prompts, grouped loosely by theme. Order is significant for
# reproducible random choices.
SUBJECTS = [
    # animals
    "dog", "cat", "horse", "cow", "pig", "sheep", "lion", "elephant",
    "monkey", "bird", "chicken", "eagle", "parrot", "penguin", "fish",
    "shark", "dolphin", "whale", "octopus", "bee", "butterfly", "ant",
    "ladybug",
    # people
    "person", "man", "woman", "child", "baby", "boy", "girl",
    # vehicles
    "car", "boat", "airplane", "bicycle", "motorcycle", "train",
    # structures
    "building", "house", "bridge", "castle", "temple", "monument",
    # nature
    "tree", "flower", "mountain", "lake", "river", "ocean", "beach",
    # food and drink
    "fruit", "vegetable", "meat", "bread", "cake", "soup", "coffee",
    # everyday objects
    "toy", "book", "phone", "computer", "TV", "camera",
    "musical instrument", "furniture",
    # places and scenery
    "road", "park", "garden", "forest", "city", "sunset", "clouds",
]

In [2]:
class CLIPSlider:
    """Shift a prompt's text embeddings along per-descriptor directions.

    For every descriptor (for example ``'color': ('blue', 'red')``) the
    constructor estimates a direction as the mean difference between the
    pooled text-encoder outputs of prompts containing the target word and
    prompts containing the opposite word. ``generate`` adds scaled copies of
    those directions to the token embeddings of a prompt and hands the result
    to the wrapped pipeline.
    """

    def __init__(
            self,
            sd_pipe,
            device: torch.device,
            descriptors: Dict[str, Tuple[str, str]],
            iterations: int = 300,
    ):
        """Move the pipeline to ``device`` and compute one direction per descriptor.

        Args:
            sd_pipe: Pipeline object exposing ``tokenizer``, ``text_encoder``,
                ``to(...)`` and ``__call__(prompt_embeds=...)``.
            device: Device the pipeline and token ids are moved to.
            descriptors: Maps a descriptor name to ``(target_word, opposite)``.
            iterations: Number of random prompt pairs sampled per descriptor.
        """
        self.device = device
        # The pipeline is always cast to half precision here.
        self.pipe = sd_pipe.to(self.device, torch.float16)
        self.iterations = iterations
        self.descriptors = descriptors
        self.latent_directions = {}
        for descriptor, (target_word, opposite) in descriptors.items():
            avg_diff = self.find_latent_direction(target_word, opposite)
            self.latent_directions[descriptor] = avg_diff

    def _tokenize(self, text: str):
        """Return the padded/truncated token ids of ``text`` on ``self.device``."""
        tokenizer = self.pipe.tokenizer
        return tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).input_ids.to(self.device)

    def find_latent_direction(self, target_word: str, opposite: str):
        """Estimate the embedding direction from ``opposite`` towards ``target_word``.

        Samples ``self.iterations`` random (medium, subject) pairs, encodes the
        prompt once with ``target_word`` and once with ``opposite``, and
        averages the differences of the pooled encoder outputs.

        Returns:
            The mean difference, with a leading dimension of size 1 kept.
        """
        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(self.iterations), desc=f"Finding latent direction for '{target_word}' vs '{opposite}'"):
                # Both prompts of a pair share the same medium and subject so
                # that only the descriptor word differs between them.
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self._tokenize(pos_prompt)
                neg_toks = self._tokenize(neg_prompt)
                positives.append(self.pipe.text_encoder(pos_toks).pooler_output)
                negatives.append(self.pipe.text_encoder(neg_toks).pooler_output)

        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)
        diffs = positives - negatives
        avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff

    def generate(self,
        prompt="a photo of a house",
        scales: Dict[str, float] = None,
        seed=15,
        only_pooler=False,
        normalize_scales=False,
        correlation_weight_factor=1.0,
        **pipeline_kwargs
    ):
        """Run the pipeline on ``prompt`` with its embeddings shifted by ``scales``.

        Args:
            prompt: Text prompt to encode.
            scales: Maps descriptor names (keys of ``self.latent_directions``)
                to the strength applied to that direction. ``None`` or an
                empty mapping leaves the embeddings unchanged.
            seed: Value passed to ``torch.manual_seed`` before the pipeline runs.
            only_pooler: When true, only the embedding at the position returned
                by ``toks.argmax()`` is shifted; otherwise every token position
                is shifted, weighted by its similarity to that position.
            normalize_scales: When true, scales are divided by the sum of their
                absolute values. Skipped when that sum is zero.
            correlation_weight_factor: Interpolates the per-token weights
                between uniform (0.0) and the raw cosine similarities (1.0).
            **pipeline_kwargs: Forwarded unchanged to the pipeline call.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            KeyError: If ``scales`` names a descriptor with no stored direction.
        """
        scales = scales or {}

        # Normalize scales if required. An all-zero set of scales has nothing
        # to normalize and would otherwise divide by zero.
        if normalize_scales and scales:
            total_scale = sum(abs(s) for s in scales.values())
            if total_scale > 0:
                scales = {k: v / total_scale for k, v in scales.items()}

        # Everything that touches the text encoder runs without autograd: the
        # embeddings are only used for inference, so tracking gradients would
        # just waste memory.
        with torch.no_grad():
            toks = self._tokenize(prompt)
            prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state

            # Position of the largest token id in the (flattened) id tensor.
            # NOTE(review): presumably the end-of-text token for a single
            # prompt with a CLIP tokenizer; confirm for other tokenizers.
            anchor_index = toks.argmax()

            if only_pooler:
                for descriptor, scale in scales.items():
                    avg_diff = self.latent_directions[descriptor]
                    prompt_embeds[:, anchor_index] += avg_diff * scale
            else:
                seq_len, embed_dim = prompt_embeds.shape[1], prompt_embeds.shape[-1]

                # Cosine similarity of every token position to the anchor one.
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                # Use the encoder's actual width instead of a hard-coded 768.
                weights = sims[anchor_index, :][None, :, None].repeat(1, 1, embed_dim)

                standard_weights = torch.ones_like(weights)
                weights = standard_weights + (weights - standard_weights) * correlation_weight_factor

                for descriptor, scale in scales.items():
                    avg_diff = self.latent_directions[descriptor]
                    prompt_embeds = prompt_embeds + (
                        weights * avg_diff[None, :].repeat(1, seq_len, 1) * scale
                    )

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
        return images


In [None]:
from diffusers import StableDiffusionPipeline

# Prefer the GPU when one is present; fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Stable Diffusion v1.5 in half precision with the safety checker disabled.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
    torch_dtype=torch.float16,
)

# Each slider axis is described by a (target word, opposite word) pair.
descriptors = dict(
    color=('blue', 'red'),
    size=('small', 'big'),
    shape=('round', 'square'),
    brightness=('dark', 'bright'),
    material=('metal', 'plastic'),
)

# Building the slider computes one latent direction per descriptor,
# sampling 500 random prompt pairs for each.
slider = CLIPSlider(pipe, device, descriptors, iterations=500)

In [None]:
# Render a single image for one hand-picked set of slider values and save it
# under a file name that encodes those values.

scales = dict(
    color=-1.7,
    size=-1.8,
    shape=0.53,
    brightness=-0.48,
    material=0.2,
)

image = slider.generate(
    "a photo of a house",
    scales,
    seed=15,
    num_inference_steps=20,
)

file_name = "image_{}.png".format(
    "_".join(f"{name}_{amount}" for name, amount in scales.items())
)
image[0].save(file_name)

In [None]:
!ngrok config add-authtoken $NGROK_TOKEN

# FastAPI application that the route decorators below attach to
app = FastAPI()

# Module-level state shared by the request handlers:
#   server_initialized - flag checked by /register before creating a game
#   games              - game_id -> {'images': {image_key: file path}}
#   executor           - worker threads used to run image generation
server_initialized = False
games = dict()
executor = ThreadPoolExecutor(max_workers=4)

# Patch asyncio so event loops can be nested (useful in some environments),
# then grab the event loop used for scheduling work.
nest_asyncio.apply()
loop = asyncio.get_event_loop()

async def initialize_server():
    """Set the module-level ``server_initialized`` flag after a one-second pause."""
    global server_initialized
    # Simulate initialization delay
    await asyncio.sleep(1)
    server_initialized = True

# Schedule the server initialization task
# NOTE(review): asyncio.create_task requires an event loop that is already
# running in this thread (as in a notebook kernel); executed as a plain script
# this line would raise RuntimeError - confirm the intended environment.
asyncio.create_task(initialize_server())

@app.post("/register")
async def register_player():
    """Create a new game and return its id, or ask the client to retry later.

    While ``server_initialized`` is false no game is created and a
    ``"wait"`` status is returned. Otherwise a fresh UUID-keyed entry with an
    empty image cache is added to ``games``.
    """
    if not server_initialized:
        return {"status": "wait", "message": "Server is initializing. Please wait."}

    new_game_id = str(uuid4())
    games[new_game_id] = {'images': {}}
    return {"status": "ok", "game_id": new_game_id}

def scale_coordinates(coords):
    """Linearly map each coordinate from [-1e6, 1e6] onto [-2, 2].

    Values outside the input interval are extrapolated, not clamped.
    Returns a new list with one float per input coordinate.
    """
    rescaled = []
    for raw in coords:
        # First bring the value into [0, 1], then stretch to [-2, 2].
        unit = (raw + 1e6) / 2e6
        rescaled.append(unit * 4 - 2)
    return rescaled

def generate_image(scaled_coords):
    """Render the fixed house prompt with the five slider values applied.

    ``scaled_coords`` is indexed at positions 0-4, which are assigned, in
    order, to the 'color', 'size', 'shape', 'brightness' and 'material'
    sliders. Relies on the module-level ``slider`` object and returns the
    first image it produces.
    """
    axis_names = ('color', 'size', 'shape', 'brightness', 'material')
    # Index explicitly so that a too-short input still raises IndexError.
    scales = {axis: scaled_coords[position] for position, axis in enumerate(axis_names)}
    rendered = slider.generate(
        prompt="a photo of a house",
        scales=scales,
        seed=15,
        num_inference_steps=20,
    )
    return rendered[0]

@app.get("/get_image")
async def get_image(game_id: str, x: int, y: int, z: int, w: int, v: int):
    """Return the PNG for the given five coordinates, generating it on first use.

    Each game keeps a cache mapping the underscore-joined coordinates to the
    path of an already rendered image. On a cache miss the coordinates are
    rescaled, the image is generated on a worker thread, saved under
    ``images/<game_id>/<key>.png`` and recorded in the cache.

    Raises:
        HTTPException: 404 when ``game_id`` is not a registered game.
    """
    if game_id not in games:
        raise HTTPException(status_code=404, detail="Game ID not found.")

    coords = (x, y, z, w, v)
    image_key = '_'.join(map(str, coords))
    game = games[game_id]

    image_path = game['images'].get(image_key)
    if image_path is None:
        scaled_coords = scale_coordinates(coords)
        print(scaled_coords)
        # Use the loop that is actually running this handler. A loop object
        # captured at import time may not be the one the server runs on, and
        # awaiting a future bound to another loop fails.
        running_loop = asyncio.get_running_loop()
        # Generation is slow and blocking, so it runs on the thread pool to
        # keep the event loop responsive.
        image = await running_loop.run_in_executor(executor, generate_image, scaled_coords)
        image_dir = f"images/{game_id}"
        os.makedirs(image_dir, exist_ok=True)
        image_path = f"{image_dir}/{image_key}.png"
        image.save(image_path)
        game['images'][image_key] = image_path

    return FileResponse(image_path, media_type='image/png')

def main():
    """Expose port 8000 through an ngrok tunnel and serve the FastAPI app.

    The ngrok auth token is read from the ``NGROK_TOKEN`` environment
    variable. When the variable is unset, no token is set explicitly and
    ngrok falls back to whatever its own configuration provides. The tunnel
    process is killed when the server stops.
    """
    # The auth token is a credential and must not live in the source. Any
    # token that was ever committed to this file should be treated as
    # compromised and rotated.
    ngrok_auth_token = os.environ.get("NGROK_TOKEN")
    if ngrok_auth_token:
        ngrok.set_auth_token(ngrok_auth_token)

    # Start ngrok tunnel
    public_url = ngrok.connect(8000)
    print(f"Public URL: {public_url}")

    try:
        # Blocks until the server is shut down.
        uvicorn.run(app, host='0.0.0.0', port=8000, log_level="info")
    finally:
        # Ensure ngrok is properly terminated when the app stops
        ngrok.kill()

if __name__ == "__main__":
    main()
