In [None]:
!pip install git+https://github.com/huggingface/diffusers.git@refs/pull/9126/head#egg=diffusers transformers safetensors
!pip install Flask Flask-SQLAlchemy pyngrok Flask-Cors

# Art media and subjects that are combined at random into the prompt
# templates used when estimating latent directions
# (see CLIPSlider.find_latent_direction).
MEDIUMS = [
    "painting", "drawing", "photograph", "HD photo", "illustration",
    "portrait", "sketch", "3d render", "digital painting", "concept art",
    "screenshot", "canvas painting", "watercolor art", "print", "mosaic",
    "sculpture", "cartoon", "comic art", "anime",
]

SUBJECTS = [
    # animals
    "dog", "cat", "horse", "cow", "pig", "sheep", "lion", "elephant",
    "monkey", "bird", "chicken", "eagle", "parrot", "penguin", "fish",
    "shark", "dolphin", "whale", "octopus", "bee", "butterfly", "ant",
    "ladybug",
    # people
    "person", "man", "woman", "child", "baby", "boy", "girl",
    # vehicles and structures
    "car", "boat", "airplane", "bicycle", "motorcycle", "train", "building",
    "house", "bridge", "castle", "temple", "monument",
    # nature
    "tree", "flower", "mountain", "lake", "river", "ocean", "beach",
    # food and drink
    "fruit", "vegetable", "meat", "bread", "cake", "soup", "coffee",
    # objects
    "toy", "book", "phone", "computer", "TV", "camera",
    "musical instrument", "furniture",
    # places and scenery
    "road", "park", "garden", "forest", "city", "sunset", "clouds",
]

In [13]:
import os
import torch
from typing import List, Tuple, Dict
from tqdm import tqdm
from safetensors.torch import save_file, load_file
import random

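# CLIPSlider.find_latent_direction reads the module-level MEDIUMS and SUBJECTS
# lists (defined in the first cell), so run that cell before using this class.
# Minimal example if you need to define them yourself:
# MEDIUMS = ['painting', 'drawing', 'sketch', 'photograph']
# SUBJECTS = ['cat', 'dog', 'tree', 'mountain']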

class CLIPSlider:
    """Steer Stable Diffusion prompts along CLIP text-embedding "slider" directions.

    For each ``(negative_word, positive_word)`` descriptor pair, a latent
    direction is estimated as the mean difference between the text encoder's
    pooled outputs for random prompts containing the positive word versus the
    negative word. ``generate`` adds scaled directions to the prompt
    embeddings before running the diffusion pipeline.

    Training prompts are built from the module-level ``MEDIUMS`` and
    ``SUBJECTS`` lists, which must exist before ``find_latent_direction``
    is called.
    """

    def __init__(
            self,
            sd_pipe,
            device: torch.device,
            descriptors: List[Tuple[str, str]],
            iterations: int = 300,
    ):
        """Wrap ``sd_pipe``, moving it to ``device`` in float16.

        Args:
            sd_pipe: Stable Diffusion pipeline exposing ``tokenizer`` and
                ``text_encoder`` and callable with ``prompt_embeds=``.
            device: Device used for tokens and embeddings.
            descriptors: ``(negative_word, positive_word)`` pairs. Any
                two-item sequence is accepted (e.g. lists decoded from JSON);
                pairs are stored as tuples so they can be used as dict keys.
            iterations: Number of random prompt pairs averaged per direction.
        """
        self.device = device
        self.pipe = sd_pipe.to(self.device, torch.float16)
        self.iterations = iterations
        # Lists (as produced by json.load) are unhashable and would break the
        # dict lookups keyed by descriptor pair (latent_directions, scales).
        self.descriptors = [tuple(pair) for pair in descriptors]
        self.latent_directions = {}  # (negative_word, positive_word) -> direction tensor
        self.current_progress = 0  # For progress reporting

    def find_latent_direction(self, negative_word: str, positive_word: str):
        """Estimate the embedding direction from ``negative_word`` to ``positive_word``.

        Builds ``self.iterations`` random prompt pairs of the form
        "a <medium> of a <word> <subject>" and averages the difference of the
        text encoder's ``pooler_output`` (positive minus negative).

        Returns:
            Tensor of shape ``(1, D)``, where D is the pooled-output width.
        """
        positives = []
        negatives = []
        with torch.no_grad():
            for _ in tqdm(
                range(self.iterations),
                desc=f"Finding latent direction for '{negative_word}' vs '{positive_word}'"
            ):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                neg_prompt = f"a {medium} of a {negative_word} {subject}"
                pos_prompt = f"a {medium} of a {positive_word} {subject}"

                neg_toks = self.pipe.tokenizer(
                    neg_prompt,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.pipe.tokenizer.model_max_length
                ).input_ids.to(self.device)
                pos_toks = self.pipe.tokenizer(
                    pos_prompt,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.pipe.tokenizer.model_max_length
                ).input_ids.to(self.device)

                neg = self.pipe.text_encoder(neg_toks).pooler_output
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                negatives.append(neg)
                positives.append(pos)

        negatives = torch.cat(negatives, dim=0)
        positives = torch.cat(positives, dim=0)
        diffs = positives - negatives
        avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff

    def compute_latent_directions(self):
        """Compute and store a direction for every pair in ``self.descriptors``."""
        for descriptor_pair in self.descriptors:
            negative_word, positive_word = descriptor_pair
            avg_diff = self.find_latent_direction(negative_word, positive_word)
            self.latent_directions[descriptor_pair] = avg_diff

    def save_latent_directions(self, file_path: str):
        """Write all directions to a safetensors file.

        Each tensor is moved to CPU and stored under the key
        ``"<negative>_<positive>"`` (see ``_pair_from_key`` for parsing).
        """
        tensors_to_save = {}
        for descriptor_pair, tensor in self.latent_directions.items():
            key = f"{descriptor_pair[0]}_{descriptor_pair[1]}"
            tensors_to_save[key] = tensor.cpu()
        save_file(tensors_to_save, file_path)
        print(f"Latent directions saved to {file_path}")

    def _pair_from_key(self, key: str) -> Tuple[str, str]:
        """Map a saved ``"<negative>_<positive>"`` key back to its descriptor pair.

        Keys are matched against ``self.descriptors`` first so that words that
        themselves contain underscores round-trip correctly; otherwise the key
        is split on its first underscore.

        Raises:
            ValueError: if the key contains no underscore.
        """
        for pair in self.descriptors:
            if f"{pair[0]}_{pair[1]}" == key:
                return tuple(pair)
        negative_word, separator, positive_word = key.partition('_')
        if not separator:
            raise ValueError(f"Malformed latent direction key: {key!r}")
        return negative_word, positive_word

    def load_latent_directions(self, file_path: str):
        """Load directions saved by ``save_latent_directions`` onto ``self.device``.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Latent directions file not found: {file_path}")
        loaded_tensors = load_file(file_path)
        for key, tensor in loaded_tensors.items():
            self.latent_directions[self._pair_from_key(key)] = tensor.to(self.device)
        print(f"Latent directions loaded from {file_path}")

    def generate(
        self,
        prompt="a photo of a house",
        scales: Dict[Tuple[str, str], float] = None,
        seed=15,
        only_pooler=False,
        normalize_scales=False,
        correlation_weight_factor=1.0,
        **pipeline_kwargs
    ):
        """Generate images for ``prompt`` shifted along the learned directions.

        Args:
            prompt: Text prompt to encode.
            scales: Maps descriptor pairs (keys of ``latent_directions``) to
                the strength applied along that direction.
            seed: Passed to ``torch.manual_seed`` (global RNG) before sampling.
            only_pooler: If True, shift only the embedding at ``toks.argmax()``
                (the highest token id, presumably the end-of-text token);
                otherwise shift every token, weighted by its cosine similarity
                to that token's embedding.
            normalize_scales: Divide every scale by the sum of absolute scales.
                Skipped when that sum is zero.
            correlation_weight_factor: Interpolates per-token weights between
                1 (factor 0) and the raw similarities (factor 1).
            **pipeline_kwargs: Forwarded to the pipeline call.

        Returns:
            The pipeline output's ``images``.
        """
        scales = scales or {}
        with torch.no_grad():
            toks = self.pipe.tokenizer(
                prompt,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.pipe.tokenizer.model_max_length
            ).input_ids.to(self.device)
            prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state

        # Normalize scales if required
        if normalize_scales and scales:
            total_scale = sum(abs(s) for s in scales.values())
            # All-zero scales: nothing to normalise (and avoid dividing by zero).
            if total_scale:
                scales = {k: v / total_scale for k, v in scales.items()}

        if only_pooler:
            for descriptor_pair, scale in scales.items():
                avg_diff = self.latent_directions[descriptor_pair]
                prompt_embeds[:, toks.argmax()] += avg_diff * scale
        else:
            normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
            sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
            weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, prompt_embeds.shape[-1])

            standard_weights = torch.ones_like(weights)
            weights = standard_weights + (weights - standard_weights) * correlation_weight_factor

            for descriptor_pair, scale in scales.items():
                avg_diff = self.latent_directions[descriptor_pair]
                prompt_embeds = prompt_embeds + (
                    weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale
                )

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
        return images


In [None]:
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Initialize the scheduler
# DDIM scheduler config taken from the SD 1.5 repository's "scheduler" subfolder.
scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# Initialize the Stable Diffusion pipeline with the scheduler
# Weights are loaded in float16; safety_checker=None loads the pipeline without a safety checker.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    scheduler=scheduler,  # Use the DDIM scheduler
    safety_checker=None,
    torch_dtype=torch.float16
)

# Prefer CUDA and fall back to CPU.
# NOTE(review): CLIPSlider casts the pipeline to float16 on whatever device it is
# given, so the CPU fallback probably needs float32 to run — confirm before relying on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the default CLIPSlider
# (negative, positive) descriptor pairs used for games that have no custom
# directions saved on disk (the server falls back to default_slider then).
default_descriptors = [
    ('blue', 'red'),
    ('small', 'big'),
    ('round', 'square'),
    ('dark', 'bright'),
    ('metal', 'plastic'),
]

# Random prompt pairs averaged per direction (CLIPSlider's own default is 300).
DEFAULT_ITERATIONS = 10  # Adjust as needed

default_slider = CLIPSlider(
    sd_pipe=pipe,
    device=device,
    descriptors=default_descriptors,
    iterations=DEFAULT_ITERATIONS
)
# Runs the text encoder 2 * DEFAULT_ITERATIONS times per pair when this cell executes.
default_slider.compute_latent_directions()


In [None]:
import hashlib
import json
import os
import random
import threading
import traceback
from typing import List, Tuple, Dict
from uuid import uuid4

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from flask import Flask, request, jsonify, send_file, Response, stream_with_context
from flask_cors import CORS
from flask_sqlalchemy import SQLAlchemy
from google.colab import userdata
from pyngrok import ngrok
from safetensors.torch import save_file, load_file
from tqdm import tqdm

# Prompt used for any game that has not stored its own prompt.
current_prompt = "a photo of a house"  # Default prompt

# Position coordinates are multiplied by this factor to get slider scales.
SCALE_FACTOR = 3 / 1_000_000

DEFAULT_ITERATIONS = 10

# Number of texture generations currently in progress, and its upper bound.
ongoing_generations = 0
max_ongoing_generations = 10

app = Flask(__name__)
CORS(app)

# SQLite database plus connection-pool limits (to avoid pool exhaustion).
app.config.update(
    SQLALCHEMY_DATABASE_URI='sqlite:///game_data.db',
    SQLALCHEMY_TRACK_MODIFICATIONS=False,
    SQLALCHEMY_ENGINE_OPTIONS={
        'pool_size': 10,
        'max_overflow': 20,
        'pool_timeout': 60,
    },
)

db = SQLAlchemy(app)

class GameData(db.Model):
    """One registered game, keyed by the UUID issued by /register."""
    __tablename__ = 'game_data'
    uuid = db.Column(db.String(36), primary_key=True)  # UUIDs are 36 characters
    prompt = db.Column(db.String, nullable=True)  # Per-game prompt; when empty the server uses current_prompt
    # Descriptor pairs set via /setDirections. Stored as JSON, so pairs written
    # as tuples are persisted as arrays and read back as lists.
    descriptors = db.Column(db.JSON, nullable=True)
    data = db.Column(db.JSON, nullable=True)  # Client game state written by /save and returned by /load
    # Texture rows for this game; removed together with the game (delete-orphan cascade).
    textures = db.relationship('Texture', backref='game_data', lazy=True, cascade="all, delete-orphan")

class Texture(db.Model):
    """Record of one generated texture image for a game position."""
    __tablename__ = 'textures'
    id = db.Column(db.Integer, primary_key=True)
    game_uuid = db.Column(db.String(36), db.ForeignKey('game_data.uuid'), nullable=False)  # Owning game
    # json.dumps of the position sent by the client; used to look the texture up.
    position_key = db.Column(db.String, nullable=False)
    # Path relative to the images/ folder: "<game_uuid>/<md5 of position_key>.jpg".
    image_filename = db.Column(db.String, nullable=False)

# In-memory cache of CLIPSlider instances, keyed by game UUID.
game_sliders = dict()

@app.route('/register', methods=['POST'])
def register():
    """Create an empty game record and return its freshly generated UUID."""
    new_uuid = str(uuid4())
    session = db.session
    session.add(GameData(uuid=new_uuid, data={}))
    session.commit()
    session.close()
    payload = {'uuid': new_uuid}
    return jsonify(payload), 200

@app.route('/save', methods=['POST'])
def save_game():
    """Persist the client's game data and, if cached, the game's latent directions.

    Expects JSON ``{"uuid": ..., "gameData": ...}``. Directions of an
    in-memory slider are written to
    ``latent_directions/<uuid>/directions.safetensors``.
    """
    payload = request.get_json()
    uuid_value = payload.get('uuid')
    incoming = payload.get('gameData')
    if incoming is None or not uuid_value:
        return jsonify({'error': 'Missing uuid or gameData'}), 400

    record = GameData.query.get(uuid_value)
    if record is None:
        db.session.close()
        return jsonify({'error': 'Game data not found'}), 404
    record.data = incoming

    target_dir = os.path.join('latent_directions', uuid_value)
    os.makedirs(target_dir, exist_ok=True)
    active_slider = game_sliders.get(uuid_value)
    if active_slider:
        active_slider.save_latent_directions(os.path.join(target_dir, 'directions.safetensors'))

    db.session.commit()
    db.session.close()
    return jsonify({'confirmation': 'Game data and latent directions saved successfully.'}), 200

@app.route('/load', methods=['POST'])
def load_game():
    """Return a game's saved data and warm the slider cache from disk.

    If ``latent_directions/<uuid>/`` holds both ``directions.safetensors``
    and ``descriptors.json`` and no slider is cached for the game, a
    CLIPSlider is rebuilt from them and stored in ``game_sliders``.
    """
    global game_sliders
    content = request.get_json()
    game_uuid = content.get('uuid')
    if not game_uuid:
        return jsonify({'error': 'Missing uuid'}), 400
    game = GameData.query.get(game_uuid)
    if not game:
        db.session.close()
        return jsonify({'error': 'Game data not found.'}), 404

    # Load the latent directions for the game
    latent_dir = os.path.join('latent_directions', game_uuid)
    latent_file = os.path.join(latent_dir, 'directions.safetensors')
    descriptors_file = os.path.join(latent_dir, 'descriptors.json')

    if (game_uuid not in game_sliders
            and os.path.exists(latent_file)
            and os.path.exists(descriptors_file)):
        with open(descriptors_file, 'r') as f:
            # JSON has no tuple type, so pairs decode as lists. Lists are
            # unhashable and would make get_star_texture fail when it uses
            # slider.descriptors entries as dict keys.
            descriptors = [tuple(pair) for pair in json.load(f)]
        slider = CLIPSlider(
            sd_pipe=pipe,
            device=device,
            descriptors=descriptors,
            iterations=DEFAULT_ITERATIONS
        )
        slider.load_latent_directions(latent_file)
        game_sliders[game_uuid] = slider

    # Build the response before closing the session that owns ``game``.
    result = jsonify({'gameData': game.data}), 200
    db.session.close()
    return result

@app.route('/reset', methods=['POST'])
def reset_game():
    """Delete a game together with its textures, cached images and latent directions.

    Removes the ``Texture`` rows, the ``latent_directions/<uuid>`` and
    ``images/<uuid>`` folders, the cached slider and the ``GameData`` row.
    """
    global game_sliders
    import shutil

    content = request.get_json()
    game_uuid = content.get('uuid')
    if not game_uuid:
        return jsonify({'error': 'Missing uuid'}), 400
    game = GameData.query.get(game_uuid)
    if not game:
        db.session.close()
        return jsonify({'error': 'Game data not found'}), 404

    Texture.query.filter_by(game_uuid=game_uuid).delete()
    # Remove on-disk artefacts. The generated images were previously left
    # behind after a reset, leaking files for a game that no longer exists.
    for folder in (os.path.join('latent_directions', game_uuid),
                   os.path.join('images', game_uuid)):
        if os.path.exists(folder):
            shutil.rmtree(folder)
    # Remove the slider from game_sliders
    game_sliders.pop(game_uuid, None)
    db.session.delete(game)
    db.session.commit()
    db.session.close()
    return jsonify({'confirmation': 'Game data reset successfully.'}), 200

# Serialises use of the shared Stable Diffusion pipeline. Flask's development
# server handles requests on separate threads by default, but a diffusers
# pipeline keeps per-call state on shared objects (e.g. the scheduler's
# timesteps), and CLIPSlider.generate seeds torch's global RNG. Two
# generations must therefore not run on it at the same time.
_pipeline_lock = threading.Lock()
# Makes the busy check and the increment/decrement of ongoing_generations atomic.
_generation_counter_lock = threading.Lock()


def _get_or_load_slider(game_uuid):
    """Return the game's CLIPSlider: cached, loaded from disk, or the default one.

    A slider loaded from ``latent_directions/<uuid>/`` is cached in
    ``game_sliders``. When no saved directions exist, ``default_slider`` is
    returned and not cached.
    """
    slider = game_sliders.get(game_uuid)
    if slider:
        return slider
    latent_dir = os.path.join('latent_directions', game_uuid)
    latent_file = os.path.join(latent_dir, 'directions.safetensors')
    descriptors_file = os.path.join(latent_dir, 'descriptors.json')
    if not (os.path.exists(latent_file) and os.path.exists(descriptors_file)):
        # Use the default slider if no custom descriptors are set
        return default_slider
    with open(descriptors_file, 'r') as f:
        # JSON arrays decode as lists; tuples are needed because pairs are used as dict keys.
        descriptors = [tuple(pair) for pair in json.load(f)]
    slider = CLIPSlider(
        sd_pipe=pipe,
        device=device,
        descriptors=descriptors,
        iterations=DEFAULT_ITERATIONS
    )
    slider.load_latent_directions(latent_file)
    game_sliders[game_uuid] = slider
    return slider


@app.route('/getStarTexture', methods=['POST'])
def get_star_texture():
    """Return the texture image for a game position, generating it on first request.

    Expects JSON ``{"uuid": ..., "position": [x, y, ...]}``. Each coordinate,
    multiplied by ``SCALE_FACTOR``, becomes the scale of the descriptor pair
    at the same index in the game's slider; extra coordinates are ignored.
    Images are cached as ``images/<uuid>/<md5 of position>.jpg`` and recorded
    as ``Texture`` rows.

    Returns the JPEG, or a JSON error with status 400 (bad request), 404
    (unknown game), 503 (too many generations in flight) or 500 (failure).
    Accepted generations queue on a lock and run one at a time.
    """
    global ongoing_generations
    content = request.get_json()
    game_uuid = content.get('uuid')
    position = content.get('position')  # Should be a list or tuple of coordinates
    if not game_uuid or position is None:
        return jsonify({'error': 'Missing uuid or position'}), 400
    print(f"Received texture request for UUID {game_uuid}, position: {position}")
    # Reject malformed positions up front instead of failing with a 500 later.
    if not isinstance(position, (list, tuple)) or not all(
            isinstance(coord, (int, float)) for coord in position):
        return jsonify({'error': 'position must be a list of numbers'}), 400

    # Convert position to a JSON string and create a hash for filename
    position_key = json.dumps(position)
    position_hash = hashlib.md5(position_key.encode()).hexdigest()

    game = GameData.query.get(game_uuid)
    if not game:
        db.session.close()
        return jsonify({'error': 'Game data not found.'}), 404

    # Retrieve the saved prompt; fallback to current_prompt if not set
    prompt = game.prompt if game.prompt else current_prompt

    image_folder = os.path.join('images', game_uuid)
    image_filename = f"{position_hash}.jpg"
    image_path = os.path.join(image_folder, image_filename)
    os.makedirs(image_folder, exist_ok=True)

    texture = Texture.query.filter_by(game_uuid=game_uuid, position_key=position_key).first()
    if texture and os.path.exists(image_path):
        db.session.close()
        return send_file(image_path, mimetype='image/jpeg')

    # Limit the number of concurrent image generations
    with _generation_counter_lock:
        busy = ongoing_generations >= max_ongoing_generations
        if not busy:
            ongoing_generations += 1
    if busy:
        db.session.close()
        return jsonify({'error': 'Server is busy. Please try again later.'}), 503

    try:
        with _pipeline_lock:
            slider = _get_or_load_slider(game_uuid)
            # Pair each coordinate with the descriptor at the same index; zip
            # stops at the shorter sequence, so extra coordinates are ignored.
            scales = {
                tuple(pair): coord * SCALE_FACTOR
                for pair, coord in zip(slider.descriptors, position)
            }
            images = slider.generate(
                prompt=prompt,
                scales=scales,
                seed=15,
                num_inference_steps=20
            )
        images[0].save(image_path)
        if not texture:
            db.session.add(Texture(
                game_uuid=game_uuid,
                position_key=position_key,
                image_filename=os.path.join(game_uuid, image_filename)
            ))
            db.session.commit()
        db.session.close()
        return send_file(image_path, mimetype='image/jpeg')

    except Exception as e:
        db.session.rollback()
        db.session.close()
        print(f"Error during image generation: {e}")
        traceback.print_exc()
        return jsonify({'error': 'Image generation failed.'}), 500

    finally:
        with _generation_counter_lock:
            ongoing_generations -= 1

@app.route('/setDirections', methods=['POST'])
def set_directions():
    """Set a game's descriptor pairs and compute their latent directions, streaming progress.

    Expects JSON ``{"uuid": ..., "descriptors": [[neg, pos], ...],
    "iterations": <positive int, optional>}`` with at least 4 pairs. The pairs
    are saved on the ``GameData`` row right away. The response is a
    ``text/event-stream`` emitting ``data:<percent>`` after each pair, or
    ``data:error`` on failure.

    On success, the directions and then ``descriptors.json`` are written to
    ``latent_directions/<uuid>/``, and the slider is cached in ``game_sliders``.
    """
    global game_sliders
    content = request.get_json()
    descriptors_list = content.get('descriptors')
    if not isinstance(descriptors_list, list) or len(descriptors_list) < 4:
        return jsonify({'error': 'At least 4 descriptors are required.'}), 400
    # Each entry must be a [negative, positive] pair of strings; e.g. a bare
    # string "ab" would otherwise silently become the pair ('a', 'b').
    if not all(isinstance(desc, (list, tuple)) and len(desc) == 2
               and all(isinstance(word, str) for word in desc)
               for desc in descriptors_list):
        return jsonify({'error': 'Each descriptor must be a pair of strings.'}), 400
    descriptors = [tuple(desc) for desc in descriptors_list]
    iterations = content.get('iterations', DEFAULT_ITERATIONS)
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(iterations, bool) or not isinstance(iterations, int) or iterations < 1:
        return jsonify({'error': 'iterations must be a positive integer.'}), 400

    game_uuid = content.get('uuid')
    if not game_uuid:
        return jsonify({'error': 'Missing uuid'}), 400

    # Save descriptors to the GameData model
    game = GameData.query.get(game_uuid)
    if not game:
        db.session.close()
        return jsonify({'error': 'Game data not found.'}), 404
    game.descriptors = descriptors
    db.session.commit()

    latent_dir = os.path.join('latent_directions', game_uuid)
    os.makedirs(latent_dir, exist_ok=True)
    descriptors_file = os.path.join(latent_dir, 'descriptors.json')
    latent_file = os.path.join(latent_dir, 'directions.safetensors')

    def generate_progress():
        try:
            slider = CLIPSlider(
                sd_pipe=pipe,
                device=device,
                descriptors=descriptors,
                iterations=iterations
            )

            total_descriptors = len(descriptors)
            for completed, descriptor_pair in enumerate(descriptors, start=1):
                negative_word, positive_word = descriptor_pair
                # Compute latent direction for each descriptor pair
                slider.latent_directions[descriptor_pair] = slider.find_latent_direction(
                    negative_word, positive_word)
                progress = int((completed / total_descriptors) * 100)
                yield f"data:{progress}\n\n"

            slider.save_latent_directions(latent_file)
            # Write descriptors.json only after the matching directions file.
            # The loaders require both files, so a failed or aborted run then
            # never leaves new descriptors paired with stale/missing directions.
            with open(descriptors_file, 'w') as f:
                json.dump(descriptors, f)

            # Store the slider instance for future use
            game_sliders[game_uuid] = slider

        except Exception as e:
            print(f"Error in setDirections: {e}")
            traceback.print_exc()
            yield "data:error\n\n"

    return Response(stream_with_context(generate_progress()), mimetype='text/event-stream')


@app.route('/getDirections', methods=['GET'])
def get_directions():
    """Return the descriptor pairs stored for the game named by the ``uuid`` query parameter."""
    requested = request.args.get('uuid')
    if not requested:
        return jsonify({'error': 'Missing uuid parameter.'}), 400

    record = GameData.query.get(requested)
    if record is None:
        return jsonify({'error': 'Game data not found.'}), 404

    stored = record.descriptors
    if stored:
        return jsonify({'descriptors': stored}), 200
    return jsonify({'error': 'Descriptors not set for this game.'}), 404

@app.route('/setPrompt', methods=['POST'])
def set_prompt():
    """Store a new prompt for a game and discard every texture made with the old one.

    Expects JSON ``{"prompt": ..., "uuid": ...}``. Cached images under
    ``images/<uuid>`` and the game's ``Texture`` rows are deleted, so textures
    are regenerated with the new prompt on their next request.
    """
    global game_sliders
    import shutil

    body = request.get_json()
    new_prompt = body.get('prompt')
    uuid_value = body.get('uuid')  # Expecting 'uuid' in the request

    if not new_prompt:
        return jsonify({'error': 'Prompt is required.'}), 400
    if not uuid_value:
        return jsonify({'error': 'Game UUID is required.'}), 400

    record = GameData.query.get(uuid_value)
    if record is None:
        return jsonify({'error': 'Game data not found.'}), 404
    record.prompt = new_prompt

    folder = os.path.join('images', uuid_value)
    if os.path.exists(folder):
        try:
            shutil.rmtree(folder)
        except Exception as err:
            print(f"Error deleting images for UUID {uuid_value}: {err}")
            return jsonify({'error': 'Failed to delete existing images.'}), 500
        print(f"Deleted images in {folder}")

    # This commit also persists the new prompt set above.
    try:
        for stale in Texture.query.filter_by(game_uuid=uuid_value).all():
            db.session.delete(stale)
        db.session.commit()
        print(f"Deleted Texture records for UUID {uuid_value}")
    except Exception as err:
        db.session.rollback()
        print(f"Error deleting Texture records for UUID {uuid_value}: {err}")
        return jsonify({'error': 'Failed to delete texture records.'}), 500
    finally:
        db.session.close()

    # Drop the cached slider; the next texture request reloads it from disk
    # (or falls back to the default). The prompt itself is read from the
    # database on each request, not from the slider.
    game_sliders.pop(uuid_value, None)

    return jsonify({'confirmation': 'Prompt set successfully and existing images deleted.'}), 200

@app.route('/getPrompt', methods=['POST'])
def get_prompt():
    """Return the game's stored prompt, or the global ``current_prompt`` if none is set.

    Expects JSON ``{"uuid": ...}`` in the request body.
    """
    body = request.get_json()
    if not body:
        return jsonify({'error': 'Missing JSON payload.'}), 400

    uuid_value = body.get('uuid')
    if not uuid_value:
        return jsonify({'error': 'UUID is required in the payload.'}), 400

    record = GameData.query.get(uuid_value)
    if record is None:
        return jsonify({'error': 'Game data not found.'}), 404

    return jsonify({'prompt': record.prompt or current_prompt}), 200

# Release the request's database session whenever the app context is torn down.
@app.teardown_appcontext
def shutdown_session(exception=None):
    """Remove the scoped SQLAlchemy session; ``exception`` (passed by Flask) is unused."""
    scoped = db.session
    scoped.remove()

if __name__ == '__main__':
    # for running in colab
    ngrok_auth_token = userdata.get('NGROK_TOKEN')
    ngrok.set_auth_token(ngrok_auth_token)

    with app.app_context():
        db.create_all()

    # The port and reserved ngrok domain can be overridden through environment
    # variables; the defaults are the values this server was built around.
    port = int(os.environ.get('GAME_SERVER_PORT', 5000))
    ngrok_domain = os.environ.get('GAME_NGROK_DOMAIN', 'caiman-gentle-sunbird.ngrok-free.app')
    try:
        # Open the tunnel inside the try so the ngrok process is killed even
        # if creating the tunnel (e.g. the domain is unavailable) or starting
        # Flask fails.
        ngrok_tunnel = ngrok.connect(port, domain=ngrok_domain)
        public_url = ngrok_tunnel.public_url
        print(f" * ngrok tunnel available at {public_url}")
        app.config["BASE_URL"] = public_url
        app.run(port=port)
    finally:
        ngrok.kill()