# Colabanana — test your 🍌 Banana environment with Colab

## What is this?

This is a simple notebook to test your [Banana serverless](https://docs.banana.dev/banana-docs/core-concepts/inference-server/serverless-framework) environments right from your browser. It is assumed that you already know the core concepts of Banana and have downloaded/cloned the [Colabanana repo](https://github.com/vzakharov/colabanana) from GitHub.

For more information on Banana and its serverless framework, please refer to the [Banana docs](https://docs.banana.dev/banana-docs/).

For information on why this notebook exists, please refer to the [Colabanana repo](https://github.com/vzakharov/colabanana).

## How to use this notebook?

1. Open this notebook in Colab. (Funny because you are already here but hey, you never know.)

2. Go one by one through the cells and either edit them as suggested or enter the required information in the forms attached to them. I tried giving as much information as needed in each cell, so I won’t repeat it here.

3. Once done, go to `Runtime` > `Run all` and watch the magic happen.

The notebook is already pre-populated to use the HuggingFace [fill-mask](https://huggingface.co/tasks/fill-mask) model, so you can test it right away.

Note that working with the notebook will involve a lot of copying and pasting between your `.py` files and respective cells in the notebook, but that’s the price you pay for the convenience of using Colab. (I strongly suggest that you do the editing in your `.py` files and then copy-paste the code into the notebook, not the other way around.)

Once everything works perfectly, you can proceed to deploy your model to a serverless environment according to the [Banana docs](https://docs.banana.dev/banana-docs/core-concepts/inference-server/serverless-framework). You shouldn’t need any changes to your `.py` files for that. Well, at least *theoretically*.

In [None]:
#@title requirements.txt
## 👇 Copy the contents of `requirements.txt` between """ and """ in the code below

requirements_txt = """
# Basic, needed for the server to run
flask

# Model-specific
transformers
accelerate
"""

In [None]:
#@title Initial setup
#🚧🚧🚧 Ignore this cell; it's just some service code to make the notebook work

try:
  packages_installed
except NameError:
  packages_installed = []

def pip_install(package):
  
  global packages_installed

  if package not in packages_installed:
    print(f"Installing {package}")
    !pip install {package}
    packages_installed += [ package ]
  else:
    print(f"{package} already installed")

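for package in requirements_txt.split("\n"):
  # Strip surrounding whitespace so that indented or whitespace-only lines are handled correctly
  package = package.strip()
  if package and not package.startswith("#"):
    pip_install(package)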

pip_install("pyngrok")
ngrok_token = "" #@param {type:"string"}
#@markdown We need [ngrok](https://ngrok.com) to make the server accessible from the outside world via a public URL. If you don’t have an ngrok token, you can get one [here](https://dashboard.ngrok.com/get-started/your-authtoken). Ngrok’s free tier is enough for [most testing purposes](https://ngrok.com/pricing).

assert ngrok_token != "", "You need to enter your ngrok token in the cell above (see https://gyazo.com/c6f0aaf59cd6aee912da7357f2f736db)"

from google.colab import drive
# drive.mount('/content/drive')
# (Uncomment the above line if you want to use Google Drive to store the model weights)
# (Do NOT comment out `from google.colab import drive` because we use this as a flag to determine whether we are running on Colab or not)

In [None]:
#@title Download the model
#@markdown ## 👈 Copy the contents of `download.py` here
#@markdown **Note:** Ideally, you want to rewrite this code so that the model weights are downloaded to Google Drive once and reused from there. This way, you won't need to download the model every time you start a runtime. However, the code will be specific to the model you are using, so we leave it as an exercise for the notebook user.

# In this file, we define download_model_weights
# It runs during container build time to get model weights built into the container

# In this example: A Huggingface BERT model

from transformers import pipeline

def download_model_weights():

  # 🚧🚧🚧 Service code to avoid downloading the model every time the cell is run in Colab; ignore until the next 🚧🚧🚧
  global weights_downloaded
  # (Without `global`, the assignment at the end of this function would make `weights_downloaded` a local variable, and the check below would fail on every call)
  try:

    weights_downloaded
    # This trick avoids downloading the model whenever the cell is re-run. Once the model has been downloaded, `weights_downloaded` is set to True, so no NameError is raised and the except block (which does the download) is skipped.

  except NameError:
  # 🛣️🛣️🛣️ End of service code; proceed with your code below
  
    
    # do a dry run of loading the huggingface model, which will download weights
    pipeline('fill-mask', model='bert-base-uncased')


  # 🚧🚧🚧 Ignore all code below this line
    weights_downloaded = True

if __name__ == "__main__":
    download_model_weights()
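
The note in the cell above leaves Drive-backed caching as an exercise. One possible approach is sketched below, under two assumptions: Google Drive is mounted at `/content/drive` (that is, the `drive.mount` line in the setup cell is uncommented), and the Hugging Face libraries read their cache location from the `HF_HOME` environment variable. The helper name `use_drive_cache` and the `hf_cache` folder are hypothetical and not part of Colabanana.

```python
import os
from pathlib import Path


def use_drive_cache(drive_root="/content/drive/MyDrive", subdir="hf_cache"):
    """Point the Hugging Face cache at a folder on a mounted Google Drive.

    Both default paths are assumptions; adjust them to your own Drive layout.
    """
    root = Path(drive_root)
    if not root.is_dir():
        # Drive is not mounted (or we are not on Colab): keep the default cache
        return None
    cache_dir = root / subdir
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Must be set before `transformers` is imported for the first time
    os.environ["HF_HOME"] = str(cache_dir)
    return cache_dir
```

If you try this, call the helper in the setup cell, before any cell that imports `transformers`. Weights other than Hugging Face ones will need their own caching logic.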

In [None]:
from transformers import pipeline
import torch

# init() is run on server startup
# Load your model to GPU as a global variable here using the variable name "model"
def init():
  global model
  
  device = 0 if torch.cuda.is_available() else -1
  model = pipeline('fill-mask', model='bert-base-uncased', device=device)

# inference() is run for every server call
# Reference your preloaded global model variable here.
def inference(model_inputs:dict) -> dict:
  global model

  # Parse out your arguments
  prompt = model_inputs.get('prompt', None)
  if prompt is None:
    return {'message': "No prompt provided"}
  
  # Run the model
  result = model(prompt)

  # Return the results (note: the fill-mask pipeline returns a list of dicts rather than a single dict; the server's `jsonify` handles both)
  return result

#@title Define model init/inference functions
#@markdown ## 👈 Copy the contents of `app.py` here
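
The `init`/`inference` contract can be checked without downloading any weights by swapping in a stand-in model. The sketch below is only an illustration and not the code from `app.py`: `fake_model` is a hypothetical stub whose return value is loosely shaped like fill-mask output (a list of dicts).

```python
model = None


def fake_model(prompt):
    # Hypothetical stand-in for the real pipeline: returns one made-up candidate
    return [{"sequence": prompt.replace("[MASK]", "test"), "score": 1.0}]


def init():
    # Same role as init() in app.py, but loads the stub instead of real weights
    global model
    model = fake_model


def inference(model_inputs: dict):
    # Mirrors the input handling in app.py
    prompt = model_inputs.get("prompt")
    if prompt is None:
        return {"message": "No prompt provided"}
    return model(prompt)


init()
print(inference({}))
print(inference({"prompt": "Hello I am a [MASK] model."}))
```

Running a check like this first makes it easier to separate mistakes in input parsing from problems with the model or the server.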

In [None]:
#@title Define the test inputs
#@markdown ## 👈 Copy the contents of `test.py` here

# You generally only need to change the `default_model_inputs` variable, which defines the default inputs to your model

default_model_inputs = dict(

  prompt = "Hello I am a [MASK] model."

)

# 🚧🚧🚧 Do not modify the code below this line unless you're sure you know what you're doing 🚧🚧🚧

#@markdown **Notes:**
#@markdown There are two ways to run the test:
#@markdown - Using the /test endpoint after the server is running (see the last cell of this notebook), or
#@markdown - By running `python3 test.py` in the terminal (from anywhere, e.g. your local machine).
#@markdown
#@markdown In the case of the /test endpoint, you can provide the model inputs as query parameters (e.g. `http://<some_id>.ngrok.io/test?prompt=Hello%20I%20am%20a%20[MASK]%20model.`). In the case of the terminal, you can provide the model inputs as command line arguments (e.g. `python3 test.py --prompt "Hello I am a [MASK] model."`).
#@markdown
#@markdown #### Testing environment
#@markdown When running as a script (`python3 test.py`), you can specify which environment to use (dev/prod) by adding the `--env` argument, e.g. `python3 test.py --env prod`. If you don't specify the environment, you will be prompted to enter it or skip it (defaulting to `dev`):
#@markdown - Choosing `prod` will make a call to an already deployed Banana server;
#@markdown - Choosing `dev` will make a call to the server running in the current notebook.
#@markdown
#@markdown In the case of `dev`, you will need to provide the public URL of the server (which you can find in the console after running the last cell of this notebook). We do not store this information locally on purpose, as the URL is likely to change every time you run the notebook.

import json
import os
import requests
import sys

def test_inference(model_inputs=None):

  # Copy the inputs into a fresh dict (a mutable default argument such as `{}` would be shared between calls)
  model_inputs = dict(model_inputs or {})

  meta_args = {}
  meta_keys = ['env', 'api_key', 'model_key']

  # If no model_inputs are provided, either use the ones from command line (if any) or the default ones
  if not model_inputs:
    # Check if any command line arguments were provided. Keep in mind that the command line is of the form "python3 test.py arg1 arg2 arg3"
    if 'google.colab' not in sys.modules and len(sys.argv) > 1:

      print(f"Using command line arguments: {sys.argv[1:]}")

      # The inputs would be provided in the form --[json key] [json value]
      # For example: --prompt "Hello I am a [MASK] model."
      # For meta keys, add them to the meta_args dict instead
      for i in range(1, len(sys.argv), 2):
        key = sys.argv[i].replace('--', '')
        value = sys.argv[i+1]
        dict_to_use = model_inputs if key not in meta_keys else meta_args
        dict_to_use[key] = value

    else:
      # Default inputs
      model_inputs = default_model_inputs
  
  print(f"Using model inputs: {model_inputs}")
  print(f"Using meta args: {meta_args}")

  if 'google.colab' in sys.modules:

    url = ngrok_tunnel.public_url
    res = requests.post(url, json=model_inputs)

  else:

    # Check which environment we're in (dev/prod)
    env = meta_args['env'] if 'env' in meta_args else input("Enter environment (dev/prod) or press Enter to use dev: ") or 'dev'

    if env == 'prod':

      credential_prompts = dict(
        api_key = "Enter your API key (go to https://app.banana.dev/ to get one)",
        model_key = "Enter your model key (from your model's page on https://app.banana.dev/)"
      )

      # Load credentials from 'credentials.json' if it exists
      if os.path.exists('credentials.json'):
        with open('credentials.json', 'r') as f:
          credentials = json.loads(f.read())
          print("Loaded credentials from credentials.json")
      else:
        credentials = {}

      credentials_changed = False
      for key, prompt in credential_prompts.items():
        if key not in meta_args:
          if key in credentials:
            print(f"Using {key} from credentials.json")
          else:
            credentials[key] = input(f"{prompt}: ")
            credentials_changed = True
        else:
          credentials[key] = meta_args[key]
          credentials_changed = True
      
      # Save credentials to 'credentials.json' if they changed
      if credentials_changed:
        with open('credentials.json', 'w') as f:
          f.write(json.dumps(credentials))
          print("Saved credentials to credentials.json")

      res = requests.post("https://api.banana.dev/start/v4/", json=dict(
        apiKey = credentials['api_key'],
        modelKey = credentials['model_key'],
        modelInputs = model_inputs
      )).json()

    else:

      url = input("Enter the public URL of your server (e.g. http://<some_id>.ngrok.io, look for it in the console after running the notebook's last cell): ")
      res = requests.post(url, json=model_inputs)

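  # `res` is already parsed JSON for prod but a `requests.Response` otherwise, so show the body text in the latter case
  print(f"Response: {res if isinstance(res, dict) else res.text}")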

  return res

# If not running in Colab, run the test
if 'google.colab' not in sys.modules:
  test_inference()
else:
  print("Running in Colab, skipping test (you will be able to run it later by using the /test endpoint)")
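
The `--key value` convention described in the notes above can be sketched as a standalone helper. This is a variant written for illustration, not the code from `test.py`: the name `split_cli_args` is hypothetical, and unlike the loop in the cell above it rejects a dangling flag instead of failing with an `IndexError`.

```python
META_KEYS = ("env", "api_key", "model_key")


def split_cli_args(argv, meta_keys=META_KEYS):
    """Split ['--key', 'value', ...] into (model_inputs, meta_args)."""
    if len(argv) % 2 != 0:
        raise ValueError("Arguments must come in `--key value` pairs")
    model_inputs, meta_args = {}, {}
    for flag, value in zip(argv[0::2], argv[1::2]):
        if not flag.startswith("--"):
            raise ValueError(f"Expected a flag starting with `--`, got {flag!r}")
        key = flag[2:]
        # Meta keys configure the test run; everything else goes to the model
        target = meta_args if key in meta_keys else model_inputs
        target[key] = value
    return model_inputs, meta_args


inputs, meta = split_cli_args(
    ["--prompt", "Hello I am a [MASK] model.", "--env", "prod"]
)
print(inputs)
print(meta)
```

In a script, the argument list would be `sys.argv[1:]`. Note that every value arrives as a string, so numeric model inputs need converting inside `inference`.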

In [None]:
#@title Start the server
#@markdown Usually, you don’t need to change anything in this cell. However, if you change the `server.py` file, double-click on the cell to edit the code and paste the file’s contents.

import asyncio
import datetime
import time
import subprocess
import sys

import numpy as np
from flask import Flask, request, jsonify
from pyngrok import ngrok

try:
  port += 1
except NameError:
  port = 8000
  # (This is a hack to avoid "address already in use" errors when running the cell multiple times)

if 'google.colab' not in sys.modules:

  import app as user_src
  user_src.init()

else:

  try:
    user_src
    print("Model already initialized")
  except NameError:
    # Define a user_src object which has attributes for init and inference
    class user_src:
      init = init
      inference = inference

    print("Initializing model...")
    user_src.init()

  # Start the ngrok tunnel
  print("Starting tunnel")

  for tunnel in ngrok.get_tunnels():
    ngrok.disconnect(tunnel.public_url)      
  # (We need this to remove the already created tunnel if running the cell multiple times, as free ngrok only allows so many tunnels)

  ngrok_tunnel = ngrok.connect(port)
  print(ngrok_tunnel)
  print("This is the public URL of your server ☝️☝️☝️ (you can also use https)")
  # The public URL is printed to the console just before this message, so look for it there

# We do the model load-to-GPU step on server startup
# so the model object is available globally for reuse

# Create the http server app.
try:
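  # First, drop the references to any previous server and its server_thread (on Colab).
  # Note that this does not shut down the old thread: it keeps listening on its old port, which is why `port` is incremented above.
  del server
  del server_thread
  print("Dropped references to the previous server")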
except ( AssertionError, NameError ):
  pass
  
# Use timestamp in server name for debug purposes
server = Flask(f"server-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}")

# Healthchecks verify that the environment is correct on Banana Serverless
@server.route('/healthcheck', methods=['GET'])
def healthcheck():
  # dependency free way to check if GPU is visible
  gpu = False
  out = subprocess.run("nvidia-smi", shell=True)
  if out.returncode == 0: # success state on shell command
    gpu = True
  return jsonify(dict(
    state="healthy",
    gpu=gpu,
    timestamp=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  ))

# Inference POST handler at '/' is called for every http call from Banana
@server.route('/', methods=['POST'])
def inference():
  model_inputs = request.get_json()
  print(f"Received request: {model_inputs}")

  output = user_src.inference(model_inputs)
  print(f"Sending response: {output}")

  # If the output is an ndarray, convert it to a list (of lists, of lists, etc.)
  if isinstance(output, np.ndarray):
    # Go recursively through each dimension and convert to list
    def convert_to_list(x):
      if isinstance(x, np.ndarray):
        return [convert_to_list(y) for y in x]
      else:
        return float(x)

    output = convert_to_list(output)

  return jsonify(output)

if 'google.colab' in sys.modules:

  # GET /test to call a sample inference using the public URL from ngrok_tunnel
  @server.route('/test', methods=['GET'])
  def test_endpoint():


    global test_inference

    url = ngrok_tunnel.public_url
    print(f"Sending a test inference request to {url}")

    # Take model_inputs from query params
    model_inputs = request.args.to_dict()

    print(f"Request: {model_inputs}")
    print("✂=== Below are logs from the server processing the test request\n")

    res = test_inference(model_inputs)

    print("\n✂=== End of logs from the server processing the test request")
    print(f"Response: {res.json()}")
    return jsonify(res.json())

if __name__ == '__main__':
  # Start the server in a new thread
  print("Starting server")
  import threading
  server_thread = threading.Thread(target=server.run, kwargs={"host": "0.0.0.0", "port": port})
  server_thread.start()

  if 'google.colab' in sys.modules:
    # Print that the server is started and a test URL after a second (so the server has time to start)
    time.sleep(1)
    print(f"Server started; test: {ngrok_tunnel.public_url}/test")

    # Keep the cell running so we can see the logs
    print("Keeping the cell running so you can see the server logs")
    while True:
      time.sleep(1)
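
The cell above sidesteps "address already in use" errors by incrementing `port` on every run. A different technique, shown here only as a sketch and not as what the notebook does, is to ask the operating system for an unused port. The helper name `find_free_port` is hypothetical.

```python
import socket


def find_free_port(host="127.0.0.1"):
    """Ask the OS for a currently unused TCP port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))  # port 0 lets the OS pick a free port
        return sock.getsockname()[1]


port = find_free_port()
print(f"Free port: {port}")
```

The port is released when the socket closes, so another process could in principle take it before the server binds to it. For a single-user notebook that risk is usually small.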

## That’s it!
### Now test your model by either running `python3 test.py` in your *local* terminal or using the `/test` endpoint in Colab (see the logs in the cell above). Good luck!
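
If you use the `/test` endpoint, the model inputs travel as query parameters and have to be URL-encoded. A minimal sketch of building such a URL with the standard library follows; `build_test_url` is a hypothetical helper, and `http://example.ngrok.io` is a placeholder for the public URL printed by the server cell.

```python
from urllib.parse import urlencode


def build_test_url(base_url, **model_inputs):
    """Return the `/test` URL with model inputs encoded as query parameters."""
    query = urlencode(model_inputs)
    url = base_url.rstrip("/") + "/test"
    return f"{url}?{query}" if query else url


# Placeholder base URL; substitute the one from your own ngrok tunnel
print(build_test_url("http://example.ngrok.io", prompt="Hello I am a [MASK] model."))
```

Calling the helper without inputs gives the bare `/test` URL, in which case the server falls back to `default_model_inputs`.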

For more information on building with and deploying Banana’s serverless framework, see https://docs.banana.dev/banana-docs/core-concepts/inference-server/serverless-framework