In [None]:
# Install and configure ngrok
!apt install curl
!curl -s https://ngrok-agent.s3.amazonaws.com/ngrok.asc | tee /etc/apt/trusted.gpg.d/ngrok.asc >/dev/null && echo "deb https://ngrok-agent.s3.amazonaws.com buster main" | tee /etc/apt/sources.list.d/ngrok.list && apt update && apt install ngrok
!ngrok config add-authtoken #token

In [None]:
# Copy the flask repository
!git clone https://github.com/r0zh/Visionhub-flask-models

In [None]:
cd Visionhub-flask-models

In [None]:
# Install the requirements
!pip install -r requirements.txt

In [None]:
# Copy the shap-e repository
!git clone https://github.com/yashasvi-ranawat/shap-e

In [None]:
# Install the shap-e package
!pip install -e ./shap-e

In [None]:
cd shap-e/

In [None]:
import torch

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget

from pyngrok import ngrok

In [None]:
# Select the compute device: CUDA when a GPU is visible to torch, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Load the shap-e components once, up front, so every request reuses them:
#   diffusion -- diffusion object built from the 'diffusion' config, passed to sample_latents
#   xm        -- the 'transmitter' model, used by decode_latent_mesh to turn latents into meshes
#   model     -- the 'text300M' model, sampled with text prompts
diffusion_config = load_config('diffusion')
diffusion = diffusion_from_config(diffusion_config)

xm = load_model('transmitter', device=device)
model = load_model('text300M', device=device)

In [None]:
# FLASK SERVER
from flask import Flask, send_file, request
from shap_e.util.notebooks import decode_latent_mesh
from flask_cors import CORS
import os
import tempfile
import zipfile


# Create the Flask app and enable CORS on all of its routes (flask_cors defaults).
app = Flask(__name__)
CORS(app)
port = "5000"

# Publish the local server through an ngrok tunnel and report where it is reachable.
public_url = ngrok.connect(port).public_url
print(f' * ngrok tunnel "{public_url}" -> "http://127.0.0.1:{port}"')

# Record the public tunnel URL as the app's base URL.
app.config.update(BASE_URL=public_url)

@app.route('/generate', methods=['POST'])
def get_photo():
    """Generate 3D meshes for a text prompt and return them as a zip of OBJ files.

    Request body (JSON object):
        prompt (str): text the meshes are generated from.
        batchSize (int, or any value ``int()`` accepts): number of meshes to
            generate; must be >= 1.

    Returns:
        ``models.zip`` as an attachment, containing one ``model_<i>.obj`` per
        generated mesh; ``400`` with a message for a malformed request; or
        ``500`` with "Error in generator" if sampling or mesh export fails.
    """
    guidance_scale = 15.0

    # Parse the body once. silent=True returns None instead of raising for a
    # missing or non-JSON body, so the client gets a descriptive 400.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return "Request body must be a JSON object", 400

    prompt = payload.get('prompt')
    if not isinstance(prompt, str):
        return "'prompt' must be a string", 400

    try:
        batch_size = int(payload['batchSize'])
    except (KeyError, TypeError, ValueError):
        return "'batchSize' must be an integer", 400
    # A batch size below 1 makes `[prompt] * batch_size` empty, so nothing
    # could be generated.
    # NOTE(review): there is no upper bound; a large batchSize from a public
    # client may exhaust GPU memory — consider capping it.
    if batch_size < 1:
        return "'batchSize' must be a positive integer", 400

    print(prompt)

    try:
        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=64,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        # Write one OBJ file per latent into a throw-away directory, then zip them.
        with tempfile.TemporaryDirectory() as temp_dir:
            files = []
            for i, latent in enumerate(latents):
                t = decode_latent_mesh(xm, latent).tri_mesh()
                file_path = os.path.join(temp_dir, f'model_{i}.obj')
                with open(file_path, 'w') as f:
                    t.write_obj(f)
                files.append(file_path)

            zip_file_path = os.path.join(temp_dir, 'models.zip')
            with zipfile.ZipFile(zip_file_path, 'w') as zip_file:
                for file in files:
                    zip_file.write(file, os.path.basename(file))

            return send_file(zip_file_path, as_attachment=True, download_name='models.zip')
    except Exception:
        # Keep the full traceback in the server log; the client only gets a
        # generic 500 (this was previously unreachable dead code).
        app.logger.exception("Generation failed for prompt %r", prompt)
        return "Error in generator", 500

# Start the Flask server on the same port the ngrok tunnel forwards to.
# debug=False: this server is published to the internet through ngrok, and the
# Werkzeug debugger that debug mode enables lets a browser execute arbitrary
# Python code (it is only PIN-protected), so it must stay off here.
# use_reloader=False: the reloader restarts the process, which is not workable
# from inside a notebook kernel.
if __name__ == "__main__":
    app.run(port=int(port), debug=False, use_reloader=False)