<a href="https://colab.research.google.com/github/wellecks/llmstep/blob/colab/python/colab/llmstep_colab_server.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [`llmstep`](https://github.com/wellecks/llmstep) server using a Colab notebook
To use this notebook, follow these steps:

1. Run all the cells in this Colab notebook to start your server.

2. In your local environment, set the environment variable `LLMSTEP_HOST` to the URL printed by this notebook (for example, `https://04fa-34-125-110-83.ngrok.io/`).

3. In your local environment, set the environment variable `LLMSTEP_SERVER=COLAB`.

4. Use `llmstep`.
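
Steps 2 and 3 can be sanity-checked from a terminal with a short script. The sketch below is only an illustration: the helper name `check_llmstep_env` is hypothetical and not part of `llmstep`, and it inspects the environment of the process that runs it, so it will not see variables that were set only inside an editor's settings.

```python
import os


def check_llmstep_env(env):
    """Return a list of human-readable problems with the llmstep variables.

    `env` is any mapping of variable names to values, e.g. os.environ.
    This helper is a hypothetical convenience, not part of llmstep.
    """
    problems = []
    host = env.get('LLMSTEP_HOST', '')
    if not host.startswith(('http://', 'https://')):
        problems.append('LLMSTEP_HOST should be the URL printed by the notebook.')
    if env.get('LLMSTEP_SERVER') != 'COLAB':
        problems.append('LLMSTEP_SERVER should be set to COLAB.')
    return problems


# Report anything that looks wrong in the current process environment.
for problem in check_llmstep_env(os.environ):
    print(problem)
```

An empty result means both variables look plausible; it does not confirm that the server is reachable.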


#### VS Code: steps (2) and (3)

To set environment variables in VS Code, go to:

- Settings (`Command` + `,` on Mac)
- Extensions -> Lean 4
- Add environment variables to `Server Env`.
- Then restart the Lean Server (`Command` + `t`, then type `> Lean 4: Restart Server`)


Authors: Rahul Saha, Sean Welleck

## Configuration

First we configure the model and generation parameters.

Below we set default values; modify them if needed for your setup.

In [None]:
# Prompt template for the default model.
# Change this if your model expects a different input format.
def llmstep_prompt(tactic_state, prefix):
  return '[GOAL]%s[PROOFSTEP]%s' % (tactic_state, prefix)


CONFIG = {
    'LLMSTEP_MODEL': 'wellecks/llmstep-mathlib4-pythia2.8b',

    # Sampling temperature(s)
    'LLMSTEP_TEMPERATURES': [0.5],

    # Number of generated suggestions per temperature
    'LLMSTEP_NUM_SAMPLES': 10,

    # Prompt template
    'LLMSTEP_PROMPT': llmstep_prompt
}
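
To see what the default template sends to the model, the snippet below applies `llmstep_prompt` to a made-up tactic state. The goal text and the prefix `simp` are illustrative values only; in practice the tactic state and prefix arrive in the request body.

```python
# Same template as in the configuration cell, repeated so this runs on its own.
def llmstep_prompt(tactic_state, prefix):
    return '[GOAL]%s[PROOFSTEP]%s' % (tactic_state, prefix)


# Illustrative inputs (assumed for this example, not taken from a real proof).
tactic_state = 'x : Nat\n⊢ x + 0 = x'
prefix = 'simp'

# The tactic state follows [GOAL]; the partial tactic follows [PROOFSTEP].
print(llmstep_prompt(tactic_state, prefix))
```

A model trained with a different input format would need a different function here, passed in through `CONFIG['LLMSTEP_PROMPT']`.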

### Install required libraries

In [None]:
!pip install pyngrok
!pip install flask
!pip install transformers
!pip install flask_ngrok

### Implementation of server and model utilities

In [None]:
import argparse
import transformers
import torch
import os
import time
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from pyngrok import ngrok


def load_hf(hf_model):
    print("Loading model...")
    if 'wellecks/llmstep-mathlib4-pythia' in hf_model:
        model = transformers.GPTNeoXForCausalLM.from_pretrained(
            hf_model,
            torch_dtype=torch.float16
        )
        tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(hf_model)
    else:
        raise NotImplementedError(hf_model)

    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    print("Done.")
    return model, tokenizer


def hf_generate(
    model,
    tokenizer,
    prompt,
    temperatures,
    num_samples,
    max_new_tokens=128
):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
    texts = []
    for temp in temperatures:
        out = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=temp > 0,
            temperature=temp,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=num_samples if temp > 0 else 1
        )
        output_tokens = out[:, input_ids.shape[1]:]
        texts.extend(tokenizer.batch_decode(
            output_tokens,
            skip_special_tokens=True
        ))
    texts = list(set(texts))
    return texts


class LLMStepServer(HTTPServer):
    def __init__(
        self, model, tokenizer, generate_function, config, *args, **kwargs
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.generate_function = generate_function
        self.config = config
        super().__init__(*args, **kwargs)


class LLMStepRequestHandler(BaseHTTPRequestHandler):
    def process_request(self, tactic_state, prefix):
        prompt = self.server.config['LLMSTEP_PROMPT'](tactic_state, prefix)
        texts = self.server.generate_function(
            model=self.server.model,
            tokenizer=self.server.tokenizer,
            prompt=prompt,
            temperatures=self.server.config['LLMSTEP_TEMPERATURES'],
            num_samples=self.server.config['LLMSTEP_NUM_SAMPLES']
        )
        texts = [prefix + text for text in texts]
        response = {"suggestions": texts}
        return response

    def do_POST(self):
        # Set response headers
        self.send_response(200)
        self.send_header('Content-type', 'application/json')
        self.end_headers()

        # Get the incoming POST data
        content_length = int(self.headers['Content-Length'])
        post_data = self.rfile.read(content_length).decode('utf-8')

        try:
            data = json.loads(post_data)
            result = self.process_request(data['tactic_state'], data['prefix'])
            response = result
            self.wfile.write(json.dumps(response).encode('utf-8'))
        except Exception as e:
            # Handle errors gracefully
            error_response = {'error': str(e)}
            self.wfile.write(json.dumps(error_response).encode('utf-8'))
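
`hf_generate` removes duplicate generations with `list(set(texts))`, and a Python `set` does not guarantee any particular order. If a stable order is wanted, one possible alternative is sketched below. The helper name `dedupe_keep_order` is hypothetical and is not part of the notebook.

```python
def dedupe_keep_order(texts):
    """Drop duplicate strings while keeping first-occurrence order.

    Relies on dicts preserving insertion order (Python 3.7+).
    """
    return list(dict.fromkeys(texts))


# Illustrative generations containing repeats.
samples = ['simp', 'rfl', 'simp', 'norm_num', 'rfl']
print(dedupe_keep_order(samples))  # ['simp', 'rfl', 'norm_num']
```

Swapping this in for `list(set(texts))` would keep the suggestions in the order the temperatures and samples produced them.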



### Run: load the model and start the server

The cell prints out the public URL, for instance: https://04fa-34-125-110-73.ngrok.io

Set the `LLMSTEP_HOST` environment variable in your local environment to this URL.

In [None]:
# Download and load the model (this takes a few minutes).
model, tokenizer = load_hf(CONFIG['LLMSTEP_MODEL'])
# load_hf already moves the model to the GPU when one is available.

PORT = 81

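# Open an HTTP tunnel. Recent pyngrok versions return a tunnel object,
# so read the URL string from its `public_url` attribute.
tunnel = ngrok.connect(PORT)
public_url = tunnel.public_url
print('Your public URL is:\n%s\n\nSet LLMSTEP_HOST to this URL.' % public_url)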

# Create the server
server_address = ('', PORT)
httpd = LLMStepServer(
    model, tokenizer, hf_generate, CONFIG,
    server_address, LLMStepRequestHandler
)

print('Server started')
httpd.serve_forever()
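
Once the server is running, it can be queried from a local machine without Lean. The sketch below is a minimal test client, not the client that `llmstep` itself uses. It assumes only what the request handler above shows: a JSON POST body with `tactic_state` and `prefix` fields, and a JSON response with either a `suggestions` list or an `error` string. The function names and the example goal are hypothetical.

```python
import json
import os
import urllib.request


def build_request(host, tactic_state, prefix):
    """Build a POST request carrying the two JSON fields the handler reads."""
    payload = json.dumps({'tactic_state': tactic_state, 'prefix': prefix})
    return urllib.request.Request(
        host,
        data=payload.encode('utf-8'),
        headers={'Content-Type': 'application/json'},
        method='POST',
    )


def query_server(host, tactic_state, prefix, timeout=60):
    """Send the request and return the list stored under 'suggestions'."""
    request = build_request(host, tactic_state, prefix)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.loads(response.read().decode('utf-8'))
    # The handler reports failures in the body under an 'error' key.
    if 'error' in body:
        raise RuntimeError(body['error'])
    return body['suggestions']


if __name__ == '__main__':
    host = os.environ.get('LLMSTEP_HOST')
    if host:
        # Illustrative goal and prefix; replace with a real tactic state.
        print(query_server(host, 'x : Nat\n⊢ x + 0 = x', 'simp'))
    else:
        print('LLMSTEP_HOST is not set; nothing to query.')
```

Because the handler prepends the prefix to each generation, every returned suggestion should start with the prefix that was sent.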