In [1]:
import torch
from peft import LoraConfig, PeftModel
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig)
import threading
from flask import Flask, request, jsonify, render_template
from transformers import StoppingCriteria, StoppingCriteriaList
import requests

In [2]:
app = Flask(__name__)
@app.route('/')
def home():
    return render_template('index.html')

In [3]:
model_identifier = "D:/chat model/model"
enable_4bit = True
compute_dtype_bnb = "float16"
quant_type_bnb = "nf4"
double_quant_flag = False
device_assignment = {"": 0}

In [4]:
dtype_computation = getattr(torch, compute_dtype_bnb)

# BitsAndBytes configuration for model quantization
bnb_setup = BitsAndBytesConfig(load_in_4bit=enable_4bit,
                               bnb_4bit_quant_type=quant_type_bnb,
                               bnb_4bit_use_double_quant=double_quant_flag,
                               bnb_4bit_compute_dtype=dtype_computation)

In [5]:

llama_model = AutoModelForCausalLM.from_pretrained(model_identifier, quantization_config=bnb_setup, device_map=device_assignment)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
adapters_path = 'D:/chat model/adapter'
model = PeftModel.from_pretrained(llama_model, adapters_path)

In [7]:
llama_tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True)
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "right"

In [8]:
class StopAtPunctuation(StoppingCriteria):
    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs):
        last_token_id = input_ids[0, -1].item()
        return last_token_id in self.stop_token_ids


In [9]:
def generate_text(prompt, model, tokenizer, max_length=1000):
    inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
    stop_tokens = [tokenizer.encode("#")[0], tokenizer.encode("!")[0], tokenizer.encode("?")[0]]
    stop_criteria = StoppingCriteriaList([StopAtPunctuation(stop_tokens)])
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length,max_new_tokens=200, pad_token_id=tokenizer.eos_token_id,stopping_criteria=stop_criteria )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
        

In [10]:
@app.route('/predict', methods=['POST'])
def predict():
    user_prompt = request.json.get('prompt')
    generated_text = generate_text(f"{user_prompt}", model, llama_tokenizer)
    return jsonify(generated_text)


In [11]:
@app.route('/submit_prompt', methods=['POST'])
def submit_prompt():
    user_prompt = request.form['prompt']
    url = 'http://192.168.29.193:5000/predict'  # Ensure this points to the correct LLM API
    try:
        response = requests.post(url, json={'prompt': user_prompt})
        
        if response.status_code == 200:
            generated_text = response.text.replace('\\n', '<br>')
        else:
            generated_text = f"Error: Received status code {response.status_code}"

    except requests.exceptions.RequestException as e:
        generated_text = f"Error connecting to the LLM API: {e}"

    return render_template('index.html', prompt=user_prompt, result=generated_text)

In [12]:
def run_flask():
    app.run(host='0.0.0.0', port=5000, use_reloader=False)

In [None]:
flask_thread = threading.Thread(target=run_flask)
flask_thread.start()

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://192.168.29.193:5000
Press CTRL+C to quit
192.168.29.13 - - [24/Sep/2024 21:58:37] "GET /favicon.ico HTTP/1.1" 404 -
192.168.29.13 - - [24/Sep/2024 21:58:44] "GET / HTTP/1.1" 200 -
Both `max_new_tokens` (=200) and `max_length`(=1000) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
You are not running the flash-attention implementation, expect numerical differences.
192.168.29.193 - - [24/Sep/2024 21:59:10] "POST /predict HTTP/1.1" 200 -
192.168.29.13 - - [24/Sep/2024 21:59:10] "POST /submit_prompt HTTP/1.1" 200 -
Both `max_new_tokens` (=200) and `max_length`(=1000) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/tex

In [14]:
flask_thread.join(timeout=1)
