**Welcome to the KoboldAI Colab Service GPT-J-6B Notebook!**<br/>
*Note: This colab is intended to be used with the KoboldAI Client, [which can be downloaded from GitHub here](https://github.com/KoboldAI/KoboldAI-Client).*

In [None]:
#@title <b>Step 1 - Install Dependencies</b>
#@markdown Press the Play button and wait for the script to finish.
from IPython.display import clear_output
from termcolor import colored
import os

!pip install flask-ngrok
!pip install termcolor
!pip install flask_cloudflared
!apt install zstd
if not os.path.isdir("step_383500"):
   !time wget https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
   !time tar -I zstd -xf step_383500_slim.tar.zstd
!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt
!pip install mesh-transformer-jax/ jax==0.2.12
clear_output()
print(colored("Installing DONE!", "green"))

In [None]:
#@title <b>Step 2 - Adjust Your Settings</b>
#@markdown 1. Connect via Ngrok or Cloudflare?
connect_method = "Cloudflare" #@param ["Ngrok", "Cloudflare"]
#@markdown 2. Press Play button to lock in settings <b>(Do not skip!)</b>
# connect_method is read by the Step 4 cell to pick the tunnel helper:
# "Cloudflare" -> flask_cloudflared.run_with_cloudflared,
# "Ngrok"      -> flask_ngrok.run_with_ngrok.

In [None]:
#@title <b>Step 3 - Initialize Model</b> { display-mode: "form" }
#@markdown Press the Play button. Wait for the model to complete
#@markdown initialization. This can take 5+ minutes.<br/>
#@markdown When the word DONE! is displayed, you can move on to
#@markdown the next Step.<br/><br/>
#@markdown <b>>> If you get an error when running this cell, run it again! <<</b>

from flask import Flask, redirect, url_for, request
import json
import torch
import requests
import subprocess
import tarfile
from jax.config import config
import time

# Sometimes the next step errors for some reason, just run it again
import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers
from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

# Initialize the model
print(colored("Initializing model, please wait...", "magenta"))

# COLAB_TPU_ADDR ("host:port") is only present on a TPU runtime; fail with an
# actionable message instead of a bare KeyError.
if "COLAB_TPU_ADDR" not in os.environ:
    raise RuntimeError(
        "COLAB_TPU_ADDR is not set. Select Runtime > Change runtime type > TPU "
        "and run the notebook again."
    )
tpu_addr = os.environ["COLAB_TPU_ADDR"]
colab_tpu_addr = tpu_addr.split(':')[0]

# Request the tpu_driver version named in the URL from the TPU host.
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# Point jax at the remote TPU driver over gRPC.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + tpu_addr

# GPT-J-6B architecture / sharding settings handed to CausalTransformer.
# More entries ("sampler", "optimizer") are attached later in this cell.
params = dict(
    layers=28,
    d_model=4096,
    n_heads=16,
    n_vocab=50400,
    norm="layernorm",
    pe="rotary",
    pe_rotary_dims=64,
    # Context window and batching/sharding layout.
    seq=2048,
    cores_per_replica=8,
    per_replica_batch=1,
)

# Frequently used settings exposed as plain module-level names.
per_replica_batch, cores_per_replica, seq = (
    params[key] for key in ("per_replica_batch", "cores_per_replica", "seq")
)


# Sampler used by network.generate, plus a placeholder optimizer that scales
# every update by 0 (nothing in this notebook trains the model).
params["sampler"] = nucleaus_sample
params["optimizer"] = optax.scale(0)

# Arrange the devices as a (dp, mp) mesh with cores_per_replica devices per row.
n_devices = jax.device_count()
if n_devices < cores_per_replica or n_devices % cores_per_replica:
    raise RuntimeError(
        f"Expected a multiple of {cores_per_replica} TPU cores, but jax sees "
        f"{n_devices} device(s). Make sure the runtime type is TPU and re-run this cell."
    )
mesh_shape = (n_devices // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
# Number of rows in each generation batch (the request handler replicates the
# prompt this many times).
total_batch = per_replica_batch * n_devices // cores_per_replica
print(colored("Creating CausalTransformer instance...", "magenta"))
network = CausalTransformer(params)
print(colored("Reading checkpoint...", "magenta"))
network.state = read_ckpt(network.state, "step_383500/", devices.shape[1])
print(colored("Calling move_xmap...", "magenta"))
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

clear_output()
print(colored("DONE!", "green"))

In [None]:
#@title <b>Step 4 - Run Web Service</b> { display-mode: "form" }
#@markdown Press the Play button. Flask will start and give you an
#@markdown Ngrok or Cloudflare address. A Cloudflare address looks like this:<br/>
#@markdown <i>https://\<unique id\>.trycloudflare.com/</i><br/>
#@markdown You will need to right-click this and copy the address.
#@markdown Start the KoboldAI Client on your computer and choose
#@markdown Google Colab as the model. You will be asked to paste
#@markdown this address into the terminal.<br/><br/>
#@markdown If your session is interrupted, you can just restart
#@markdown this cell to get a new address without reinitializing
#@markdown the model.<br/><br/>
#@markdown <b>The first generation takes around a minute due to 
#@markdown compilation, but after that it should only take about 
#@markdown 10 seconds per sample.</b>

# Mesh environment captured once so the request handler can re-install it.
tenv = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

# Import the tunnel helper chosen in Step 2 (before the app is created).
tunnel_starter = None
if connect_method == "Cloudflare":
    from flask_cloudflared import run_with_cloudflared
    tunnel_starter = run_with_cloudflared
elif connect_method == "Ngrok":
    from flask_ngrok import run_with_ngrok
    tunnel_starter = run_with_ngrok

app = Flask(__name__)

# Attach the selected tunnel helper to the app.
if tunnel_starter is not None:
    tunnel_starter(app)

@app.route("/")
def home():
    """Landing page: return a static banner saying the service is running."""
    banner = "<h1>KoboldAI Colab Service Running!</h1>"
    return banner

@app.route('/request', methods=['POST'])
def koboldrequest():
    """Generate a continuation of the posted prompt and return it as JSON.

    The JSON request body must contain ``text`` (the prompt), ``min`` and
    ``max`` (used to derive the generation length), ``rep_pen``,
    ``temperature`` and ``top_p``.  ``rep_pen`` must be present but is not
    applied: ``network.generate`` is only given ``top_p`` and ``temp``.

    Returns:
        HTTP 200 with ``{"data": {"seqs": [<generated text>]}}`` on success,
        or HTTP 400 with ``{"error": {"extensions": {"code": <message>}}}``
        if anything fails (missing keys, over-long prompt, model error, ...).
    """
    try:
        clear_output()
        js = request.json
        txt = js["text"]
        min_tokens = js["min"]
        max_tokens = js["max"]
        rep_pen = js["rep_pen"]  # required in the payload; not used by this sampler
        temp = js["temperature"]
        top_p = js["top_p"]

        # Number of new tokens: max - min + 1 (min/max treated as an inclusive range).
        gen_len = max_tokens - (min_tokens - 1)

        print(colored("Received Data: {0}".format(txt), "yellow"))
        print(colored("Generating text, please wait...", "green"))

        # The mesh env has to be re-installed on every call, else jax produces a
        # threading error (presumably because requests run on another thread).
        maps.thread_resources.env = tenv

        tokens = tokenizer.encode(txt)
        provided_ctx = len(tokens)
        if provided_ctx > seq:
            # np.pad would otherwise fail with an obscure negative-width error.
            raise ValueError(
                "Prompt is {0} tokens long; the model accepts at most {1}.".format(provided_ctx, seq)
            )
        # Left-pad the prompt with zeros to exactly `seq` tokens, then replicate
        # it once per batch row.
        pad_amount = seq - provided_ctx
        padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
        batched_tokens = np.array([padded_tokens] * total_batch)
        length = np.ones(total_batch, dtype=np.uint32) * provided_ctx
        sampler_options = {
            "top_p": np.ones(total_batch) * top_p,
            "temp": np.ones(total_batch) * temp,
        }
        output = network.generate(batched_tokens, length, gen_len, sampler_options)

        # Rows of decoded_tokens[:, :, 0] are the per-batch samples.  All rows
        # come from the same prompt and only the first is returned, so decode
        # just that one.
        decoded_tokens = output[1][0]
        genout = tokenizer.decode(decoded_tokens[0, :, 0])

        print(colored("Generated Text: {0}".format(genout), "cyan"))
        return app.response_class(
            response=json.dumps({"data": {"seqs": [genout]}}),
            status=200,
            mimetype='application/json'
        )

    except Exception as e:
        print(colored("[ERROR] Something went wrong during generation!", "red"))
        print(colored("{0}".format(e), "red"))
        # This response must be returned: a view returning None makes Flask
        # raise, and the client would get a generic 500 without this message.
        return app.response_class(
            response=json.dumps({"error": {"extensions": {"code": "Something went wrong during generation! {0}".format(e)}}}),
            status=400,
            mimetype='application/json'
        )

# Start the Flask development server; this call blocks the cell while serving.
print(colored("Startup complete! Running web service.", "green"))
app.run()