In [1]:
!pip install --upgrade transformers -q
!pip install accelerate -q

# GPTQ Dependencies
!pip install --upgrade optimum -q
!pip install --upgrade auto-gptq -q

# RAG Dependencies
!pip install langchain -q
!pip install -U sentence-transformers -q
!pip install faiss-cpu -q

# BERT
!pip install bert-extractive-summarizer -q

# Hosting Deps
!pip install pyngrok -q

In [2]:
from langchain.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from summarizer import Summarizer
from pathlib import Path
import nest_asyncio

from auto_gptq import exllama_set_max_input_length

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import sys, json, re

from flask import Flask, jsonify, request
from pyngrok import ngrok
import os
import threading

# Hugging Face Hub base model; the revision names the quantized branch to load
# (the branch name indicates 4-bit, group size 32, act-order GPTQ).
model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
revision = "gptq-4bit-32g-actorder_True"

# Hub repo holding the fine-tuned PEFT adapters used for story generation.
adapters_name = "SyedTalha/Mistral-7B-Instruct-v0.2-PEFT-adapters-v2"

# Text file the latest narrative is written to before it is indexed for RAG.
PATH = "/content/game.txt"

In [3]:
# Tokenizer shared by both generation pipelines. Left padding is the usual
# choice for batched generation with decoder-only models.
# `padding` is an argument of the tokenizer *call*, not of the constructor, so
# the former `padding=True` here had no effect and has been dropped to avoid
# suggesting that padding is switched on globally.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    padding_side="left",
    use_fast=True,
)
# Reuse EOS as the pad token (presumably the checkpoint ships without one —
# confirm against its tokenizer_config.json).
tokenizer.pad_token = tokenizer.eos_token

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
# Fine-tuned model used for story generation, loaded from the PEFT adapter repo
# (presumably resolved onto its base model by transformers' PEFT integration).
# trust_remote_code is False, matching the base-model cell: Mistral is a
# built-in transformers architecture, so there is no need to execute Python
# code shipped inside a third-party repository. If loading ever fails asking
# for remote code, audit the repo before enabling it.
model_finetuned = AutoModelForCausalLM.from_pretrained(
    adapters_name,
    device_map="auto",
    trust_remote_code=False,
    revision="main"
)
# Raise the exllama kernels' maximum input length to 8192 tokens.
model_finetuned = exllama_set_max_input_length(model_finetuned, 8192)



In [5]:
# Base GPTQ model without adapters. It backs `pipe`, which produces the
# branching-choice JSON, runs the JSON-fixer prompt and answers RAG questions.
model = exllama_set_max_input_length(
    AutoModelForCausalLM.from_pretrained(
        model_id,
        revision=revision,
        trust_remote_code=False,
        device_map="auto",
    ),
    8192,  # maximum input length (tokens) for the exllama kernels
)

In [6]:
def _make_generation_pipeline(generation_model):
    """Build a text-generation pipeline with the sampling settings shared by all models.

    Args:
        generation_model: a loaded causal-LM model; the shared ``tokenizer`` is used with it.

    Returns:
        A transformers text-generation pipeline that samples up to 512 new tokens
        and returns only the newly generated text, not the prompt.
    """
    return pipeline(
        task='text-generation',
        model=generation_model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,  # sample instead of greedy decoding for more creative text
        temperature=1,
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1,
        return_full_text=False,  # only return the completion, not the prompt
    )


# Base-model pipeline: branching choices, JSON repair and RAG answers.
pipe = _make_generation_pipeline(model)

# Fine-tuned pipeline: story continuation.
pipe_finetuned = _make_generation_pipeline(model_finetuned)

In [7]:
# Extractive BERT summarizer; summarize_scenerio() uses it to condense each new
# narrative into the `image_prompt` returned by the /generate endpoint.
bert_model = Summarizer()

In [8]:
# Embedding models keyed by model name, so each sentence-transformer is loaded
# once per kernel instead of on every make_vdb() call.
_EMBEDDINGS_CACHE = {}


def _get_embeddings(model_name):
  """Return a cached HuggingFaceEmbeddings instance for `model_name`."""
  if model_name not in _EMBEDDINGS_CACHE:
    _EMBEDDINGS_CACHE[model_name] = HuggingFaceEmbeddings(model_name=model_name)
  return _EMBEDDINGS_CACHE[model_name]


def make_vdb(path, index_dir="/content/faiss_index", chunk_size=10, chunk_overlap=0,
             embedding_model='sentence-transformers/all-mpnet-base-v2'):
  """Index the text file at `path` and merge it into the persisted FAISS index.

  The file is split with CharacterTextSplitter, embedded, and loaded into a new
  FAISS store. If an index already exists in `index_dir`, it is loaded and merged
  into the new store. The combined store is then written back to `index_dir`.

  Args:
    path: text file to index.
    index_dir: folder where the FAISS index is persisted.
    chunk_size: CharacterTextSplitter chunk size in characters. The default of 10
      is kept for backward compatibility; it is far smaller than typical text,
      so chunks are presumably one separator-delimited piece each — confirm.
    chunk_overlap: CharacterTextSplitter overlap in characters.
    embedding_model: sentence-transformers model name used for embeddings.

  Returns:
    The merged FAISS vector store (also saved to `index_dir`).
  """
  documents = TextLoader(path).load()
  splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
  chunked_documents = splitter.split_documents(documents)

  embeddings = _get_embeddings(embedding_model)
  db = FAISS.from_documents(chunked_documents, embeddings)

  index_path = Path(index_dir)
  if index_path.exists():
    # Loading unpickles the docstore. This is acceptable only because the index
    # is written by this function; never point `index_dir` at untrusted data.
    old_db = FAISS.load_local(str(index_path), embeddings, allow_dangerous_deserialization=True)
    db.merge_from(old_db)
  db.save_local(str(index_path))
  return db

def make_rag_query(query, db):
  """Return the page content of the chunk in `db` most similar to `query`.

  Only the top hit is used, so exactly one result is requested (k=1) rather than
  the vector store's default number of hits.

  Returns:
    The text of the best-matching document, or "" when the store returns nothing
    (instead of raising IndexError).
  """
  docs = db.similarity_search(query, k=1)
  if not docs:
    return ""
  return docs[0].page_content

def save_to_txt(content, filename="/content/game.txt"):
    """Overwrite `filename` with `content`.

    Args:
        content: text to write; replaces any existing file contents.
        filename: destination path. The default equals the ``PATH`` constant that
            game() passes to make_vdb(), so existing callers keep writing to the
            file that gets indexed.
    """
    with open(filename, 'w') as file:
        file.write(content)

def rag_optimize(context, query):
    """Answer `query` from a retrieved `context` snippet using the base-model pipeline.

    Args:
        context: text retrieved from the vector store (game() passes the result
            of make_rag_query()).
        query: the user's question.

    Returns:
        The model's answer with surrounding whitespace stripped.
    """
    # Debug trace of what retrieval produced.
    print("context: ", context, type(context))
    # The f-string's indentation and blank lines are part of the prompt text.
    # Fixed typo: "simplified statemnet" -> "a simplified statement".
    system = f"""
    You are excellent at creating a simplified statement about the context given to you. Use this simplified statement to answer the question [{query}]. The answer must be concise.
    [{context}]

    """

    # Zero-shot prompt: a single user turn with no worked example.
    chat = [
      {"role": "user", "content": system}
    ]

    # Render the chat template to a plain string, then generate from it.
    prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    outputs = pipe(prompt)
    return outputs[0]["generated_text"].strip()

def summarize_scenerio(scenerio):
  """Condense `scenerio` with the BERT extractive summarizer (min_length=10).

  The summarizer's result goes through ``"".join``. That returns it unchanged
  when it is a string and concatenates the pieces when it is a sequence of strings.
  """
  summary = bert_model(scenerio, min_length=10)
  return "".join(summary)

def extract_first_json(text):
    """Return the text of the first JSON object in `text` that has a nested-object value.

    Each ``{`` in `text` is tried as the start of a JSON document using
    ``json.JSONDecoder.raw_decode``, which consumes exactly one complete value.
    The first candidate that decodes to a dict with at least one dict-valued
    entry (the ``{"key": {...}}`` shape of the option schema) is returned. The
    return value is the exact decoded substring, so ``json.loads`` on it succeeds.

    Decoding, rather than regex matching, handles objects nested more than one
    level deep and braces that appear inside string values. It never returns an
    unbalanced fragment.

    Args:
        text: raw model output; may be None or empty.

    Returns:
        The JSON substring, or None if no suitable object is found.
    """
    if not text:
        return None
    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", text):
        start = brace.start()
        try:
            candidate, length = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            continue  # not a complete JSON value starting here; try the next brace
        if isinstance(candidate, dict) and any(isinstance(value, dict) for value in candidate.values()):
            return text[start:start + length]
    return None

def json_fixer(malformed_json, valid_json_schema):
    """Build a repair prompt that asks the model to turn `malformed_json` into valid JSON.

    The prompt first shows `valid_json_schema` as the required shape, then the
    malformed text. It ends with an opened ```json fence so that the model's
    continuation is the repaired JSON itself. Only the prompt string is
    returned; nothing is generated here.
    """
    return f"""Generate valid JSON from the malformed JSON fixing missing commas, quotes and brackets.

valid JSON should strictly follow the following "json-valid" schema:
```json-valid
{valid_json_schema}
```

Here is a malformed json:
```json-malformed
{malformed_json}
```

Here is a fixed JSON, with fixed missing commas, quotes and brackets:
```json"""

In [9]:
import getpass

# FLASK_DEBUG is a boolean flag: any value other than "0"/"false"/"no" turns
# debug mode on. The previous value "development" therefore enabled Werkzeug's
# interactive debugger on a server that ngrok exposes to the public internet.
# Keep it off; set "1" only when running locally without the tunnel.
os.environ["FLASK_DEBUG"] = "0"
app = Flask(__name__)
port = 5000  # local port the ngrok tunnel forwards to (also Flask's app.run default)

# Never hard-code the ngrok auth token. Read it from the NGROK_AUTH_TOKEN
# environment variable, or prompt for it without echoing. The token that used
# to be committed in this cell must be treated as leaked and revoked.
ngrok_auth_token = os.environ.get("NGROK_AUTH_TOKEN") or getpass.getpass("ngrok auth token: ")
ngrok.set_auth_token(ngrok_auth_token)
public_url = ngrok.connect(port).public_url
print(public_url)
app.config["BASE_URL"] = public_url

https://c2ca-34-83-116-245.ngrok-free.app


In [10]:
# Keys the option prompt asks the model to emit (in display order), and the
# fields each option object must contain.
_OPTION_KEYS = ("option1", "option2", "option3")
_OPTION_FIELDS = ("text", "outcome", "damage")

# Every /generate response carries all of these keys; unused ones are null.
_RESPONSE_KEYS = ("story", "options", "outcomes", "points", "image_prompt", "rag_response")

# Example output shown to the model as the required schema, and reused as the
# target schema for the JSON-fixer prompt. Its indentation is part of the text.
_OPTION_SCHEMA_EXAMPLE = """
      {
        "option1": {
          "text": "This is choice 1",
          "outcome": "This is the narrative for choice 1",
          "damage": 0
        },
        "option2": {
          "text": "This is choice 2",
          "outcome": "This is the narrative for choice 2",
          "damage": 5
        },
        "option3": {
          "text": "This is choice 3",
          "outcome": "This is the narrative for choice 3",
          "damage": 0
        }
      }
      """


def _game_response(status=200, **fields):
  """Return (jsonify(body), status) where body has every _RESPONSE_KEYS key (None unless given)."""
  body = dict.fromkeys(_RESPONSE_KEYS)
  body.update(fields)
  return jsonify(body), status


def _parse_options(model_output):
  """Extract and validate the options JSON from raw model output.

  Returns the parsed dict when every key in _OPTION_KEYS maps to an object that
  contains all of _OPTION_FIELDS. Otherwise returns None, so the caller can fall
  back to the JSON-fixer prompt instead of crashing on a KeyError.
  """
  json_text = extract_first_json(model_output)
  if json_text is None:
    return None
  try:
    parsed = json.loads(json_text)
  except ValueError:  # json.JSONDecodeError subclasses ValueError
    return None
  if not isinstance(parsed, dict):
    return None
  for key in _OPTION_KEYS:
    option = parsed.get(key)
    if not isinstance(option, dict) or not all(field in option for field in _OPTION_FIELDS):
      return None
  return parsed


def _continue_story(story):
  """Continue `story` with the fine-tuned pipeline and return the new narrative text."""
  # The f-string's indentation and blank lines are part of the prompt sent to the model.
  story_user = f"""
    You are an AI dungeon master that provides any kind of roleplaying game content.

    Instructions:

    - Be specific, descriptive, and creative.
    - Avoid repetition and avoid summarization.
    - Generally use second person (like this: 'He looks at you.'). But use third person if that's what the story seems to follow.
    - Never decide or write for the user. If the input ends mid sentence, continue where it left off.
    - Make sure you always give responses continuing mid sentence even if it stops partway through.

    Continue the story below:

    {story}

    """

  chat = [
      {"role": "user", "content": story_user}
  ]

  # Render the chat template to a string and generate the next narrative.
  story_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
  outputs = pipe_finetuned(story_prompt)
  return outputs[0]["generated_text"].strip()


def _generate_options(narrative):
  """Ask the base model for three choices for `narrative`.

  If the first answer is not usable, a single repair attempt is made with the
  JSON-fixer prompt.

  Returns:
    The validated options dict, or None if neither attempt produced one.
  """
  # Trailing backslashes join lines inside the string. The schema's ```json fence
  # is now closed so that the story excerpt is not presented as part of the JSON.
  option_user = f"""
    You are an expert interactive fiction writer who specializes in crafting short and creative branching narratives.\
    Create three branching narratives for the story excerpt provided below.\
    The generated narrative should be one-liner sentences with less than 20 words referencing the keywords from the story and highlighting the key details.\
    Also, present each narrative in the form of user-visible choice as well. The choices must be of a few words capturing the essence of the resulting narrative.\
    Additionally, associate damage score with each choice: 5 damage if selecting this choice can bring damage, otherwise 0 damage.
    Generate only one JSON object containing the narratives and corresponding choices using the following json schema.

    ```json
    {_OPTION_SCHEMA_EXAMPLE}
    ```

    Story Excerpt:

    {narrative}

    """

  # The assistant turn shows the expected output shape before the real request.
  chat = [
      {"role": "user", "content": option_user},
      {"role": "assistant", "content": _OPTION_SCHEMA_EXAMPLE},
      {"role": "user", "content": "```json"},
  ]

  option_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
  output = pipe(option_prompt)[0]["generated_text"].strip()

  options_json = _parse_options(output)
  if options_json is None:
    app.logger.warning("Option JSON was malformed; retrying with the JSON-fixer prompt.")
    fix_prompt = json_fixer(output, _OPTION_SCHEMA_EXAMPLE)
    fixed_output = pipe(fix_prompt)[0]["generated_text"].strip()
    options_json = _parse_options(fixed_output)
  return options_json


def _generate_story_turn(story):
  """Run one story turn: continue `story`, derive three choices, and index the narrative."""
  narrative = _continue_story(story)

  options_json = _generate_options(narrative)
  if options_json is None:
    return _game_response(
        status=502,
        story=narrative,
        error="The model did not return valid option JSON, even after a repair attempt.",
    )

  chosen = [options_json[key] for key in _OPTION_KEYS]
  image_prompt = summarize_scenerio(narrative)

  # Persist the narrative and merge it into the FAISS index for later RAG queries.
  save_to_txt(narrative)
  make_vdb(PATH)

  return _game_response(
      story=narrative,
      options=[option["text"] for option in chosen],
      outcomes=[option["outcome"] for option in chosen],
      points=[option["damage"] for option in chosen],
      image_prompt=image_prompt,
  )


def _answer_rag_query(rag_query):
  """Answer `rag_query` using the FAISS index built from earlier narratives."""
  # NOTE(review): make_vdb() always indexes a file, so placeholder text is
  # written and indexed only to obtain a store merged with the saved index.
  # Each query therefore adds a "dummy text" chunk to the persisted index, and
  # that chunk can be retrieved as context. Loading the saved index directly
  # would avoid this.
  save_to_txt("dummy text")
  db = make_vdb(PATH)

  result = make_rag_query(rag_query, db)
  return _game_response(rag_response=rag_optimize(result, rag_query))


@app.route('/generate', methods=['POST'])
def game():
  """Handle POST /generate.

  Expects a JSON object body. When ``rag_query`` is truthy, the question is
  answered from the FAISS index and only ``rag_response`` is filled in.
  Otherwise ``story`` (the story so far) is continued, three choices with
  outcomes and damage points are generated, and the new narrative is indexed.

  Every response contains the keys story, options, outcomes, points,
  image_prompt and rag_response (null when not applicable). Error responses
  add an ``error`` key: 400 when no story is given, and 502 when the model's
  choices could not be parsed.
  """
  payload = request.get_json(silent=True)
  if not isinstance(payload, dict):
    payload = {}

  rag_query = payload.get("rag_query")
  if rag_query:
    return _answer_rag_query(rag_query)

  story = payload.get("story")
  if story is None:
    return _game_response(status=400, error='Request JSON must contain "story" (or a non-empty "rag_query").')
  return _generate_story_turn(story)

In [11]:
# Serve the Flask app from a background thread so the notebook stays usable.
# Pass the same `port` the ngrok tunnel forwards to, so the two cannot drift
# apart. The reloader is disabled because it installs signal handlers, which
# only work in the main thread.
threading.Thread(
    target=app.run,
    kwargs={"port": port, "use_reloader": False},
).start()