## Testing Salesforce XGen model

- XGen blog: https://blog.salesforceairesearch.com/xgen/
- XGen 7B : https://huggingface.co/Salesforce/xgen-7b-8k-base?ref=blog.salesforceairesearch.com
- XGen 7B instruct (research only) : https://huggingface.co/Salesforce/xgen-7b-8k-inst
- vLLM does not support XGen yet

In [None]:
# Install the packages needed by the cells below (-q keeps pip's output quiet).
!pip install -q transformers accelerate sentencepiece bitsandbytes tiktoken

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path
import os

# Checkpoint to fetch. Use "Salesforce/xgen-7b-8k-base" here instead to test
# the base (non-instruct) model.
model_name = "Salesforce/xgen-7b-8k-inst"

# Glob patterns selecting which repository files are downloaded.
allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model", "*.py"]

# Local directory used as the download cache; created if it does not exist.
local_model_path = Path("./pretrained-models")
local_model_path.mkdir(exist_ok=True)

# Download the snapshot and keep the path returned by the hub client for
# the loading cells further down.
model_download_path = snapshot_download(
    repo_id=model_name,
    allow_patterns=allow_patterns,
    cache_dir=local_model_path,
)

In [None]:
# Show where the snapshot was stored on disk.
print(f"Local model download path: {model_download_path}")

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Load the tokenizer from the downloaded snapshot. trust_remote_code=True
# allows transformers to execute Python code shipped with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_download_path, trust_remote_code=True)

# Load the weights in bfloat16 and let device_map='auto' decide placement.
model = AutoModelForCausalLM.from_pretrained(
    model_download_path,
    low_cpu_mem_usage=True,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)

# Switch to evaluation mode; as the last expression of the cell this also
# displays the model in the notebook output.
model.eval()

In [None]:
# # Base model
# prompt = "What is the best food in the world?"

In [None]:
# Instruct model
# The user question that is substituted into the chat template below.
instruction = "How to implement arbitrage bot for cryptocurrency? please explain step by step."

# Chat-style prompt: a system preamble followed by a "### Human:" turn holding
# the instruction and an open "### Assistant:" turn for the model to complete.
# The role marker must be spelled exactly "Assistant"; a misspelled marker
# gives the model a role name it was not prompted with.
prompt = f"""
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.

### Human: {instruction}

### Assistant: 
"""
print(prompt)

In [None]:
# Text form of the end-of-sequence marker.
eos_token = "<|endoftext|>"

# Tokenize the marker and take the first id it maps to; that id is later
# handed to generate() as the stop token.
eos_input_ids = tokenizer(eos_token)["input_ids"]
eos_token_id = eos_input_ids[0]
print(eos_token_id)

In [None]:
# Tokenize the prompt into PyTorch tensors and move them to the model's device.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

In [None]:
%%time
# Generate a completion for the tokenized prompt.
# do_sample=True is required for temperature/top_p to take effect; without it
# generate() decodes greedily and those two settings are ignored.
sample = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,
    top_p=0.8,
    max_length=512,  # limit on total sequence length, prompt tokens included
    eos_token_id=eos_token_id,  # stop as soon as the end-of-text id is produced
)

In [None]:
# Full decoded sequence (prompt + completion), kept for inspection.
raw_output = tokenizer.decode(sample[0])

# Decode only the tokens produced after the prompt. Slicing by token count
# avoids relying on the decoded text starting with the exact prompt string.
prompt_token_count = inputs["input_ids"].shape[1]
result = tokenizer.decode(sample[0][prompt_token_count:])

# Remove the end-of-text marker only when it is actually there: a generation
# that stops at the length limit emits no marker, and trimming unconditionally
# would cut off the tail of the answer.
if result.endswith(eos_token):
    result = result[: -len(eos_token)]

In [None]:
# Display the extracted completion text.
print(result)