In [1]:
import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import pandas as pd

In [2]:
from datasets import Dataset
import pyarrow as pa
import pyarrow.dataset as ds

In [3]:
from huggingface_hub import login
import numpy as np


In [4]:
def remove_chars(s):
    return s[2:-2]

df = pd.read_csv('/home/chats/data/July_llama2_finetune.csv')
df = df.drop('Unnamed: 0', axis=1)
df = df.rename(columns={'0': 'conversation'})
df = df[df['conversation']!= '[]']
#df[df['0']!='[]']
df = df.dropna()
# To reset the indices
df = df.reset_index(drop=True)
column_to_clean = 'conversation'
df[column_to_clean] = df[column_to_clean].apply(remove_chars)
string_to_remove = 'we are facing a technical issue please bear with us for sometime'

# Remove rows containing the specified string
df = df[~df['conversation'].str.contains(string_to_remove)]

In [5]:
dataset = Dataset(pa.Table.from_pandas(df))

In [None]:
login()

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/chats/.cache/huggingface/token
Login successful


In [7]:
# Activate 4-bit precision base model loading
use_4bit = True
# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"
# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"
# Activate nested quantization for 4-bit base models (double quantization)
use_double_nested_quant = False
# Get the type
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_use_double_quant=use_double_nested_quant,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype
)
     

In [8]:
#tokenizer = AutoTokenizer.from_pretrained("sharpbai/Llama-2-7b-chat")
#model = AutoModelForCausalLM.from_pretrained("sharpbai/Llama-2-7b-chat")

In [9]:
device_map = {"": 0}
model_id ="sharpbai/Llama-2-7b-chat"

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache = False, device_map=device_map)
model.config.pretraining_tp = 1

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Downloading (‚Ä¶)lve/main/config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

Downloading (‚Ä¶)model.bin.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/34 [00:00<?, ?it/s]

Downloading (‚Ä¶)l-00001-of-00034.bin:   0%|          | 0.00/396M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00002-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00003-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00004-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00005-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00006-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00007-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00008-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00009-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00010-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00011-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00012-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00013-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00014-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00015-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00016-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00017-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00018-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00019-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00020-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00021-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00022-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00023-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00024-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00025-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00026-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00027-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00028-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00029-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00030-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00031-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00032-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00033-of-00034.bin:   0%|          | 0.00/271M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00034-of-00034.bin:   0%|          | 0.00/262M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/34 [00:00<?, ?it/s]

Downloading (‚Ä¶)neration_config.json:   0%|          | 0.00/192 [00:00<?, ?B/s]

Downloading (‚Ä¶)okenizer_config.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (‚Ä¶)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Downloading (‚Ä¶)cial_tokens_map.json:   0%|          | 0.00/435 [00:00<?, ?B/s]

In [10]:
# LoRA attention dimension
lora_r = 64
# Alpha parameter for LoRA scaling
lora_alpha = 16
# Dropout probability for LoRA layers
lora_dropout = 0.1

# LoRA config based on QLoRA paper
peft_config = LoraConfig(
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        r=lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "gate_proj",
            "down_proj",
            "k_proj",
            "up_proj",
            "o_proj",
            "q_proj",
            "v_proj"
            ]
)

In [11]:
output_dir = "./results"
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
optim = "paged_adamw_32bit"
save_steps = 10
logging_steps = 10
learning_rate = 1e-4
max_grad_norm = 0.3
max_steps = 250
warmup_ratio = 0.03
lr_scheduler_type = "constant"

# Ste training parameters
training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    lr_scheduler_type=lr_scheduler_type,
    report_to=None
)


In [12]:
max_seq_length = 1024

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="conversation",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)



Map:   0%|          | 0/12642 [00:00<?, ? examples/s]

In [13]:
# Upcasting the layer norms to have for stable training
for name, module in trainer.model.named_modules():
    if "norm" in name:
        module = module.to(torch.float32)

In [14]:
trainer.train(resume_from_checkpoint=True) # there will not be a progress bar since tqdm is disabled


[34m[1mwandb[0m: Currently logged in as: [33manurag-pal[0m ([33myulu-bikes[0m). Use [1m`wandb login --relogin`[0m to force relogin


You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
70,0.4115
80,0.3663
90,0.3207
100,0.2359
110,0.3903
120,0.3614
130,0.2957
140,0.2394
150,0.2108
160,0.3927


TrainOutput(global_step=250, training_loss=0.2271800227165222, metrics={'train_runtime': 7120.3351, 'train_samples_per_second': 0.562, 'train_steps_per_second': 0.035, 'total_flos': 5.056123191194419e+16, 'train_loss': 0.2271800227165222, 'epoch': 0.32})

In [15]:
trainer.save_model()


In [16]:
# Empty VRAM
del model
del trainer
import gc
gc.collect()
gc.collect()
torch.cuda.empty_cache() # PyTorch thing


In [17]:
gc.collect()


0

In [19]:
from peft import AutoPeftModelForCausalLM

new_model = AutoPeftModelForCausalLM.from_pretrained(
    output_dir,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)

# Merge LoRA and base model
merged_model = new_model.merge_and_unload()

# Save the merged model
merged_model.save_pretrained("merged_model",safe_serialization=True)
tokenizer.save_pretrained("merged_model")


Loading checkpoint shards:   0%|          | 0/34 [00:00<?, ?it/s]

('merged_model/tokenizer_config.json',
 'merged_model/special_tokens_map.json',
 'merged_model/tokenizer.json')

In [1]:
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


In [None]:
from huggingface_hub import login
from huggingface_hub import HfApi

login()
api = HfApi()

# Upload all the content from the local folder to your remote Space.
# By default, files are uploaded at the root of the repo
api.upload_folder(
    folder_path="./merged_model",
    repo_id="YULU-BIKE/LLAMA_YULU",
    repo_type="model",
)

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/chats/.cache/huggingface/token
Login successful


model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

'https://huggingface.co/YULU-BIKE/LLAMA_YULU/tree/main/'

In [None]:
trainer.model.push_to_hub("YULU-BIKE/LLAMA-Shared-Ride")


In [4]:
torch.cuda.empty_cache()

In [16]:
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

output_dir = "./results/checkpoint-60/"
device_map = {"": 0}
import torch
new_model = AutoPeftModelForCausalLM.from_pretrained(
    output_dir,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)

# Merge LoRA and base model
merged_model = new_model.merge_and_unload()

# Save the merged model
merged_model.save_pretrained("merged_model",safe_serialization=True)

tokenizer = AutoTokenizer.from_pretrained('sharpbai/Llama-2-7b-chat', trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.save_pretrained("merged_model")


Downloading (‚Ä¶)lve/main/config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

Downloading (‚Ä¶)model.bin.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/34 [00:00<?, ?it/s]

Downloading (‚Ä¶)l-00001-of-00034.bin:   0%|          | 0.00/396M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00002-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00003-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00004-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00005-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00006-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00007-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00008-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00009-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00010-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00011-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00012-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00013-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00014-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00015-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00016-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00017-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00018-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00019-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00020-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

Downloading (‚Ä¶)l-00021-of-00034.bin:   0%|          | 0.00/405M [00:00<?, ?B/s]

KeyboardInterrupt: 

In [6]:
from huggingface_hub import login


In [None]:
login()

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/chats/.cache/huggingface/token
Login successful


In [None]:
merged_model.push_to_hub("YULU-BIKE/LLAMA-Shared-Ride")
tokenizer.push_to_hub("YULU-BIKE/LLAMA-Shared-Ride")

In [31]:
sys = 'You are a Customer Support Agent for Yulu a Microbility company.Help Cusotmers regarding their queries in polite manner.If a question does not make any sense, or is not factually coherent,explain why instead of answering something not correct. If you donesnot know the answer to a question, please do not share false information.'


In [42]:
#sample = dataset[randrange(len(dataset))]

prompt = f"""<s>[INST]<<SYS>>You are a Customer Support Agent for Yulu a Microbility company.Help Cusotmers regarding their queries in polite manner.If a question does not make any sense, or is not factually coherent,explain why instead of answering something not correct. 
                If you do not know the answer to a question, please do not share false information.Refund 20 ruppees if it's bike issue an 50 ruppess if it's battery issue.<</SYS>> 
            ### Human: my battery is fluctuating[/INST]</s>"""

input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
# with torch.inference_mode():
outputs = merged_model.generate(input_ids=input_ids, max_new_tokens=60, do_sample=True, top_p=0.9,temperature=0.5)

print(f"Prompt:\n{prompt}\n")
print(f"\nGenerated instruction:\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}")
#print(f"\nGround truth:\n{sample['output']}")

Prompt:
<s>[INST]<<SYS>>You are a Customer Support Agent for Yulu a Microbility company.Help Cusotmers regarding their queries in polite manner.If a question does not make any sense, or is not factually coherent,explain why instead of answering something not correct. 
                If you do not know the answer to a question, please do not share false information.Refund 20 ruppees if it's bike issue an 50 ruppess if it's battery issue.<</SYS>> 
            ### Human: my battery is fluctuating[/INST]</s>


Generated instruction:
eetings, thank you for contacting yulu support. we understand your concern. please share the current location to check the issue. [/INST] looks like you have stepped away, we are closing this chat for now.\nplease feel free to reopen this chat anytime to


tensor([[    1,     1, 29961, 25580, 29962,  9314, 14816, 29903,  6778,  3492,
           526,   263, 21886, 18601, 28330,   363,   612, 21528,   263, 20140,
         29890,  1793,  5001, 29889, 29648,   315,   375,   327, 13269, 11211,
          1009,  9365,   297,  1248,   568,  8214, 29889,  3644,   263,  1139,
           947,   451,  1207,   738,  4060, 29892,   470,   338,   451,  2114,
          1474, 16165,   261,   296, 29892,  4548,  7420,  2020,  2012,   310,
         22862,  1554,   451,  1959, 29889, 29871,    13, 18884,   960,   366,
           437,   451,  1073,   278,  1234,   304,   263,  1139, 29892,  3113,
           437,   451,  6232,  2089,  2472, 19423,   829, 14816, 29903,  6778,
         29871,    13,  9651,   835, 12968, 29901,   306,   505,   337, 25389,
           287,   363, 29871, 29906, 29900, 29900,  5796,   412,   267,   541,
           372, 29915, 29879,   451,  9432,   292,   297,   278,   623, 29961,
         29914, 25580, 29962,    13,  9651,   835,  

In [55]:
get_prompt('my battery is dead',[("Alice", "28"),("Bob", "35")],sys)

'<s>[INST] <<SYS>>You are a Customer Support Agent for Yulu a Microbility company.Help Cusotmers regarding their queries in polite manner.If a question does not make any sense, or is not factually coherent,explain why instead of answering something not correct. If you donesnot know the answer to a question, please do not share false information.<</SYS>>Alice [/INST] 28 </s><s>[INST]Bob [/INST] 35 </s><s>[INST]my battery is dead [/INST]'

In [13]:
model=merged_model

In [56]:
from threading import Thread
from typing import Iterator
from transformers import TextIteratorStreamer
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>{system_prompt}<</SYS>>']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST]')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]</s>')
    return ''.join(texts)


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50) -> Iterator[str]:
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')

    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)

In [15]:
DEFAULT_SYSTEM_PROMPT = "You are a Customer Support Agent for Yulu a Microbility company.Help Cusotmers regarding their queries in polite manner.If a question does not make any sense, or is not factually coherent,explain why instead of answering something not correct. If you do not know the answer to a question, please do not share false information."

In [27]:
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 50
MAX_INPUT_TOKEN_LENGTH = 1024


In [17]:
DESCRIPTION = """
Yulu Chat bot MVP (LLAMA 7b version (50 epochs))
"""
LICENSE = """
<p/>
---
As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""


In [18]:
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    return '', message


def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    history.append((message, ''))
    return history


def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    try:
        message, _ = history.pop()
    except IndexError:
        message = ''
    return history, message or ''


def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError

    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        first_response = next(generator)
        yield history + [(message, first_response)]
    except StopIteration:
        yield history + [(message, '')]
    for response in generator:
        yield history + [(message, response)]


def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
    for x in generator:
        pass
    return '', x


def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    input_token_length = get_input_token_length(message, chat_history, system_prompt)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')


In [19]:
import gradio as gr


In [57]:
import gradio as gr
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')

    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type a message...',
                scale=10,
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    with gr.Row():
        retry_button = gr.Button('üîÑ  Retry', variant='secondary')
        undo_button = gr.Button('‚Ü©Ô∏è Undo', variant='secondary')
        clear_button = gr.Button('üóëÔ∏è  Clear', variant='secondary')

    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        )


    gr.Markdown(LICENSE)

    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue(max_size=20).launch(share=True)

Running on local URL:  http://127.0.0.1:7861
Running on public URL: https://04e7ae6c7ae3ff6660.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




In [34]:
demo.close()

Closing server running on port: 7861
