In [None]:
!pip install git+https://github.com/huggingface/diffusers.git@refs/pull/9126/head#egg=diffusers transformers 
!pip install Flask Flask-SQLAlchemy pyngrok Flask-Cors

# Art media used to fill the "a {medium} of a ..." slot of the sampled prompts.
MEDIUMS = [
    "painting", "drawing", "photograph", "HD photo",
    "illustration", "portrait", "sketch", "3d render",
    "digital painting", "concept art", "screenshot", "canvas painting",
    "watercolor art", "print", "mosaic", "sculpture",
    "cartoon", "comic art", "anime",
]

# Subjects used to fill the "... {word} {subject}" slot of the sampled prompts.
# Element order is unchanged; the comments only label the rough groupings.
SUBJECTS = [
    # land animals
    "dog", "cat", "horse", "cow", "pig", "sheep", "lion", "elephant", "monkey",
    # birds
    "bird", "chicken", "eagle", "parrot", "penguin",
    # sea life
    "fish", "shark", "dolphin", "whale", "octopus",
    # insects
    "bee", "butterfly", "ant", "ladybug",
    # people
    "person", "man", "woman", "child", "baby", "boy", "girl",
    # vehicles
    "car", "boat", "airplane", "bicycle", "motorcycle", "train",
    # structures
    "building", "house", "bridge", "castle", "temple", "monument",
    # nature
    "tree", "flower", "mountain", "lake", "river", "ocean", "beach",
    # food and drink
    "fruit", "vegetable", "meat", "bread", "cake", "soup", "coffee",
    # everyday objects
    "toy", "book", "phone", "computer", "TV", "camera",
    "musical instrument", "furniture",
    # places and scenery
    "road", "park", "garden", "forest", "city", "sunset", "clouds",
]

In [2]:
from tqdm import tqdm
import torch
import random
from typing import Dict, Tuple

class CLIPSlider:
    """Steer a Stable Diffusion prompt along directions in text-embedding space.

    For each descriptor (e.g. ``'color': ('blue', 'red')``) the constructor
    estimates a direction by averaging, over randomly sampled prompts, the
    difference between the text encoder's pooled output for the target word
    and for its opposite. ``generate`` then shifts a prompt's token embeddings
    along those directions before handing them to the pipeline.
    """

    def __init__(
            self,
            sd_pipe,
            device: torch.device,
            descriptors: Dict[str, Tuple[str, str]],
            iterations: int = 300,
    ):
        """Move the pipeline to ``device`` in float16 and learn every direction.

        Args:
            sd_pipe: Pipeline object exposing ``tokenizer``, ``text_encoder``,
                ``to(...)`` and being callable with ``prompt_embeds``.
            device: Device the pipeline and token tensors are placed on.
            descriptors: Maps a descriptor name to ``(target_word, opposite)``.
            iterations: Number of random prompt pairs sampled per descriptor.
        """
        self.device = device
        self.pipe = sd_pipe.to(self.device, torch.float16)
        self.iterations = iterations
        self.descriptors = descriptors
        self.latent_directions = {}
        for descriptor, (target_word, opposite) in descriptors.items():
            avg_diff = self.find_latent_direction(target_word, opposite)
            self.latent_directions[descriptor] = avg_diff

    def _tokenize(self, text: str):
        """Return token ids for ``text``, padded/truncated to the model max length."""
        tokenizer = self.pipe.tokenizer
        return tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).input_ids.to(self.device)

    def find_latent_direction(self, target_word: str, opposite: str):
        """Estimate the embedding direction pointing from ``opposite`` to ``target_word``.

        Samples ``self.iterations`` (medium, subject) pairs, encodes
        "a {medium} of a {word} {subject}" for both words, and returns the mean
        difference of the pooled outputs with a leading dimension of size 1.
        """
        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(self.iterations), desc=f"Finding latent direction for '{target_word}' vs '{opposite}'"):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                positives.append(self.pipe.text_encoder(self._tokenize(pos_prompt)).pooler_output)
                negatives.append(self.pipe.text_encoder(self._tokenize(neg_prompt)).pooler_output)

            diffs = torch.cat(positives, dim=0) - torch.cat(negatives, dim=0)
            return diffs.mean(0, keepdim=True)

    def generate(self,
        prompt="a photo of a house",
        scales: Dict[str, float] = None,
        seed=15,
        only_pooler=False,
        normalize_scales=False,
        correlation_weight_factor=1.0,
        **pipeline_kwargs
    ):
        """Generate images for ``prompt`` shifted along the learned directions.

        Args:
            prompt: Text prompt to encode.
            scales: Maps descriptor name to the strength applied to its
                direction; ``None`` or empty leaves the prompt unchanged.
            seed: Value passed to ``torch.manual_seed`` before sampling.
            only_pooler: If true, shift only the embedding at the position
                returned by ``toks.argmax()``; otherwise shift every token,
                weighted by its cosine similarity to that position.
            normalize_scales: If true, divide every scale by the sum of
                absolute scales (skipped when that sum is zero).
            correlation_weight_factor: Interpolates the per-token weights
                between uniform (0.0) and the raw similarities (1.0).
            **pipeline_kwargs: Forwarded unchanged to the pipeline call.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            KeyError: If ``scales`` names a descriptor that was not learned.
        """
        scales = scales or {}

        # Normalize scales if required; an all-zero set is left as is instead
        # of dividing by zero.
        if normalize_scales and scales:
            total_scale = sum(abs(s) for s in scales.values())
            if total_scale > 0:
                scales = {k: v / total_scale for k, v in scales.items()}

        # Inference only: keep the encoder pass and the embedding arithmetic
        # out of autograd so no graph is built or retained.
        with torch.no_grad():
            toks = self._tokenize(prompt)
            prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state

            # Position of the largest token id — presumably the end-of-text
            # token for CLIP tokenizers; confirm before using other tokenizers.
            anchor = toks.argmax()

            if only_pooler:
                for descriptor, scale in scales.items():
                    avg_diff = self.latent_directions[descriptor]
                    prompt_embeds[:, anchor] += avg_diff * scale
            else:
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                # Shape (1, seq_len, 1): broadcasting against the direction
                # handles any embedding width rather than a hard-coded 768.
                weights = sims[anchor, :][None, :, None]
                weights = 1.0 + (weights - 1.0) * correlation_weight_factor

                for descriptor, scale in scales.items():
                    avg_diff = self.latent_directions[descriptor]
                    prompt_embeds = prompt_embeds + weights * avg_diff * scale

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
        return images


In [None]:
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The scheduler config and the pipeline weights come from the same repository.
_MODEL_ID = "runwayml/stable-diffusion-v1-5"

# DDIM scheduler, loaded from the repository's "scheduler" subfolder.
scheduler = DDIMScheduler.from_pretrained(_MODEL_ID, subfolder="scheduler")

# Half-precision Stable Diffusion pipeline using the DDIM scheduler above,
# with the safety checker disabled.
pipe = StableDiffusionPipeline.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.float16,
    safety_checker=None,
    scheduler=scheduler,
)

# Slider axes: descriptor name -> (target word, opposite word).
descriptors = dict(
    color=('blue', 'red'),
    size=('small', 'big'),
    shape=('round', 'square'),
    brightness=('dark', 'bright'),
    material=('metal', 'plastic'),
)

# Constructing the slider learns one direction per descriptor,
# sampling 500 prompt pairs for each.
slider = CLIPSlider(pipe, device, descriptors, iterations=500)

In [None]:
from google.colab import userdata
from flask import Flask, request, jsonify, send_file
from flask_sqlalchemy import SQLAlchemy
import os
from uuid import uuid4
from pyngrok import ngrok
from flask_cors import CORS
import json
import hashlib  
import traceback
import hashlib

# Multiplier applied to each incoming coordinate to obtain its slider scale.
SCALE_FACTOR = 3 / 1_000_000

# Every texture is generated from this one fixed prompt.
prompt = "a photo of a house"

# The i-th coordinate of a position drives the i-th descriptor,
# in the key order of `descriptors`.
coordinate_descriptors = list(descriptors)

# Admission control for image generation: in-flight count and its upper bound.
ongoing_generations, max_ongoing_generations = 0, 10

app = Flask(__name__)
CORS(app)

# SQLite-backed storage; the engine's pool settings are meant to prevent
# connection pool exhaustion.
app.config.update(
    SQLALCHEMY_DATABASE_URI='sqlite:///game_data.db',
    SQLALCHEMY_TRACK_MODIFICATIONS=False,
    SQLALCHEMY_ENGINE_OPTIONS={
        'pool_size': 10,
        'max_overflow': 20,
        'pool_timeout': 60,
    },
)

db = SQLAlchemy(app)

class GameData(db.Model):
    """One saved game: a UUID string primary key plus its state stored as JSON."""
    __tablename__ = 'game_data'
    uuid = db.Column(db.String(36), primary_key=True)  # UUIDs are 36 characters
    data = db.Column(db.JSON)  # Stores the game data as JSON
    # Texture rows that reference this game; the backref gives each Texture a
    # `game_data` attribute pointing back to its owning game.
    textures = db.relationship('Texture', backref='game_data', lazy=True)

class Texture(db.Model):
    """Record of an image file associated with one position within one game."""
    __tablename__ = 'textures'
    id = db.Column(db.Integer, primary_key=True)
    # Foreign key to game_data.uuid: the game this texture belongs to.
    game_uuid = db.Column(db.String(36), db.ForeignKey('game_data.uuid'), nullable=False)
    # String key identifying the position this texture was created for.
    position_key = db.Column(db.String, nullable=False)
    # File name (or relative path) of the stored image.
    image_filename = db.Column(db.String, nullable=False)

@app.route('/register', methods=['POST'])
def register():
    """Create a game record with empty data and return its new UUID as JSON."""
    fresh_uuid = str(uuid4())
    session = db.session
    session.add(GameData(uuid=fresh_uuid, data={}))
    session.commit()
    session.close()
    payload = {'uuid': fresh_uuid}
    return jsonify(payload), 200

@app.route('/save', methods=['POST'])
def save_game():
    """Overwrite the stored data of an existing game.

    Expects a JSON object with ``uuid`` and ``gameData``. Responds with 400
    when the body is not a JSON object or a field is missing, 404 when no game
    has that uuid, and 200 with a confirmation otherwise.
    """
    # silent=True returns None for a missing or unparsable JSON body, so such
    # requests get the same JSON 400 as a missing field instead of crashing
    # on `.get` or producing a non-JSON error page.
    content = request.get_json(silent=True)
    if not isinstance(content, dict):
        return jsonify({'error': 'Missing uuid or gameData'}), 400
    game_uuid = content.get('uuid')
    game_data = content.get('gameData')
    # The primary key is a string; reject anything else before querying.
    if not isinstance(game_uuid, str) or not game_uuid or game_data is None:
        return jsonify({'error': 'Missing uuid or gameData'}), 400
    game = GameData.query.get(game_uuid)
    if not game:
        db.session.close()
        return jsonify({'error': 'Game data not found'}), 404
    game.data = game_data
    db.session.commit()
    db.session.close()
    return jsonify({'confirmation': 'Game data saved successfully.'}), 200

@app.route('/load', methods=['POST'])
def load_game():
    """Return the stored data of an existing game.

    Expects a JSON object with ``uuid``. Responds with 400 when the body is
    not a JSON object or ``uuid`` is missing, 404 when no game has that uuid,
    and 200 with ``{'gameData': ...}`` otherwise.
    """
    # silent=True returns None for a missing or unparsable JSON body, so such
    # requests get a JSON 400 instead of crashing on `.get`.
    content = request.get_json(silent=True)
    if not isinstance(content, dict):
        return jsonify({'error': 'Missing uuid'}), 400
    game_uuid = content.get('uuid')
    # The primary key is a string; reject anything else before querying.
    if not isinstance(game_uuid, str) or not game_uuid:
        return jsonify({'error': 'Missing uuid'}), 400
    game = GameData.query.get(game_uuid)
    if not game:
        db.session.close()
        return jsonify({'error': 'Game data not found'}), 404
    # Build the response before closing the session that owns `game`.
    result = jsonify({'gameData': game.data}), 200
    db.session.close()
    return result

@app.route('/reset', methods=['POST'])
def reset_game():
    """Delete a game and all texture rows that reference it.

    Expects a JSON object with ``uuid``. Responds with 400 when the body is
    not a JSON object or ``uuid`` is missing, 404 when no game has that uuid,
    and 200 with a confirmation otherwise.
    """
    # silent=True returns None for a missing or unparsable JSON body, so such
    # requests get a JSON 400 instead of crashing on `.get`.
    content = request.get_json(silent=True)
    if not isinstance(content, dict):
        return jsonify({'error': 'Missing uuid'}), 400
    game_uuid = content.get('uuid')
    # The primary key is a string; reject anything else before querying.
    if not isinstance(game_uuid, str) or not game_uuid:
        return jsonify({'error': 'Missing uuid'}), 400
    game = GameData.query.get(game_uuid)
    if not game:
        db.session.close()
        return jsonify({'error': 'Game data not found'}), 404
    # Remove dependent texture rows first, then the game itself, in one commit.
    Texture.query.filter_by(game_uuid=game_uuid).delete()
    db.session.delete(game)
    db.session.commit()
    db.session.close()
    return jsonify({'confirmation': 'Game data reset successfully.'}), 200

@app.route('/getStarTexture', methods=['POST'])
def get_star_texture():
    """Return the JPEG texture for a position, generating and caching it if needed.

    Expects a JSON object with ``uuid`` and ``position`` (a list of numeric
    coordinates). Each of the first ``len(coordinate_descriptors)`` coordinates
    is multiplied by ``SCALE_FACTOR`` and used as the scale of the descriptor
    at the same index.

    Responses: 400 for a malformed body or position, 404 for an unknown game,
    503 when ``max_ongoing_generations`` generations are already running,
    500 when generation fails, otherwise the image file.
    """
    global ongoing_generations
    # silent=True returns None for a missing or unparsable JSON body, so such
    # requests get a JSON 400 instead of crashing on `.get`.
    content = request.get_json(silent=True)
    if not isinstance(content, dict):
        return jsonify({'error': 'Missing uuid or position'}), 400
    game_uuid = content.get('uuid')
    position = content.get('position')  # Should be a list or tuple of coordinates
    if not isinstance(game_uuid, str) or not game_uuid or position is None:
        return jsonify({'error': 'Missing uuid or position'}), 400
    print(f"Received texture request for UUID {game_uuid}, position: {position}")

    # Reject malformed positions up front: otherwise they would only fail
    # inside the generation path, be reported as a 500 and occupy a slot.
    # Only the coordinates that are actually mapped to descriptors are checked.
    if not isinstance(position, (list, tuple)) or not all(
        isinstance(coord, (int, float))
        for coord in position[:len(coordinate_descriptors)]
    ):
        return jsonify({'error': 'Invalid position'}), 400

    # Convert position to a JSON string and create a hash for filename
    position_key = json.dumps(position)
    position_hash = hashlib.md5(position_key.encode()).hexdigest()

    # The uuid is only used as a directory name after it has been matched
    # against an existing game row.
    game = GameData.query.get(game_uuid)
    if not game:
        db.session.close()
        return jsonify({'error': 'Game data not found'}), 404
    base_image_folder = 'images'
    image_folder = os.path.join(base_image_folder, game_uuid)
    image_filename = f"{position_hash}.jpg"
    image_path = os.path.join(image_folder, image_filename)

    # makedirs creates the base folder as a parent when needed.
    os.makedirs(image_folder, exist_ok=True)
    texture = Texture.query.filter_by(game_uuid=game_uuid, position_key=position_key).first()

    # Cache hit: both the database row and the file on disk are present.
    if texture and os.path.exists(image_path):
        db.session.close()
        return send_file(image_path, mimetype='image/jpeg')

    # Limit the number of concurrent image generations
    # NOTE(review): this check-then-increment is not atomic; confirm whether
    # the server handles requests on multiple threads.
    if ongoing_generations >= max_ongoing_generations:
        db.session.close()
        return jsonify({'error': 'Server is busy. Please try again later.'}), 503
    ongoing_generations += 1
    try:
        # Scale the coordinates and map them to descriptors; zip stops at the
        # shorter sequence, so surplus coordinates are ignored.
        scales = {
            descriptor: coord * SCALE_FACTOR
            for descriptor, coord in zip(coordinate_descriptors, position)
        }
        images = slider.generate(
            prompt=prompt,
            scales=scales,
            seed=15,
            num_inference_steps=20
        )
        image = images[0]
        image.save(image_path)
        # A row may already exist if only the file had gone missing.
        if not texture:
            new_texture = Texture(
                game_uuid=game_uuid,
                position_key=position_key,
                image_filename=os.path.join(game_uuid, image_filename)
            )
            db.session.add(new_texture)
            db.session.commit()
        db.session.close()
        return send_file(image_path, mimetype='image/jpeg')

    except Exception as e:
        # Request boundary: log the failure and answer with a JSON 500.
        db.session.rollback()
        db.session.close()
        print(f"Error during image generation: {e}")
        traceback.print_exc() # for debugging
        return jsonify({'error': 'Image generation failed.'}), 500

    finally:
        # Always release the generation slot, on success or failure.
        ongoing_generations -= 1

# ensures sessions are properly removed
@app.teardown_appcontext
def shutdown_session(exception=None):
    """Call ``db.session.remove()`` when the application context is torn down.

    Registered through ``app.teardown_appcontext``. The ``exception`` argument
    is accepted but not used.
    """
    db.session.remove()

# Entry point: create the tables, open an ngrok tunnel, and serve the Flask app.
if __name__ == '__main__':
    # for running in colab
    # The ngrok auth token is read from the Colab userdata entry 'NGROK_TOKEN'.
    ngrok_auth_token = userdata.get('NGROK_TOKEN')
    ngrok.set_auth_token(ngrok_auth_token)

    # Create the tables for the declared models inside an application context.
    with app.app_context():
        db.create_all()
    port = 5000
    # Start ngrok when the server starts
    # The tunnel is requested on a fixed, hard-coded ngrok domain.
    ngrok_tunnel = ngrok.connect(port, domain='caiman-gentle-sunbird.ngrok-free.app')
    public_url = ngrok_tunnel.public_url
    print(f" * ngrok tunnel available at {public_url}")
    # Store the public URL in the app config under "BASE_URL".
    app.config["BASE_URL"] = public_url
    try:
        app.run(port=port)
    finally:
        # Kill ngrok once app.run returns or raises.
        ngrok.kill()
