In [1]:
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import PeftModel, PeftConfig
from utils import prepare_prompt


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /mnt/appl/software/CUDA/11.4.1/lib/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 114
CUDA SETUP: Loading binary /home/halamvac/venvs/venv39/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda114.so...


In [2]:
class AskRedditModel:
    """Answer AskReddit-style questions with a PEFT-adapted LLaMA model.

    The base LLaMA checkpoint named in the adapter's PEFT config is loaded in
    8-bit, the adapter stored at ``model_path`` is applied on top, and the pair
    is exposed as a callable: ``model("question")`` returns the answer text.
    """

    def __init__(self, model_path):
        """Load the base model, the PEFT adapter and the tokenizer.

        Args:
            model_path: Location of the PEFT adapter and tokenizer files; the
                adapter config's ``base_model_name_or_path`` selects the base
                LLaMA checkpoint.
        """
        config = PeftConfig.from_pretrained(model_path)
        model = LlamaForCausalLM.from_pretrained(config.base_model_name_or_path,
                                                 load_in_8bit=True,
                                                 torch_dtype=torch.float16,
                                                 device_map='auto')
        self.model = PeftModel.from_pretrained(model, model_path)
        # Inference only: make sure dropout (including any in the adapter) is off.
        self.model.eval()
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path)

        # device_map='auto' has already placed the (8-bit) weights, possibly
        # across several GPUs. A blanket .to(device) would undo that placement,
        # and moving bitsandbytes 8-bit weights with .to() is unsupported, so the
        # model is left where it is. Inputs are sent to the device of the first
        # parameter (presumably the input embedding table for LLaMA).
        self.device = next(self.model.parameters()).device

    def __call__(self, question, min_length=20, max_new_tokens=128):
        """Generate an answer to ``question``.

        Args:
            question: Question text; it is turned into a full prompt by
                ``prepare_prompt``.
            min_length: Minimum number of newly generated tokens (forwarded as
                ``min_new_tokens``).
            max_new_tokens: Maximum number of newly generated tokens.

        Returns:
            The generated answer only: no prompt text, no special tokens, and
            no leading/trailing whitespace.
        """
        prompt = prepare_prompt(question)
        inputs = self.tokenizer(prompt, return_tensors='pt')
        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        generated = self.model.generate(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        no_repeat_ngram_size=3,
                                        num_beams=4,
                                        max_new_tokens=max_new_tokens,
                                        min_new_tokens=min_length,
                                        early_stopping=True)

        # For a decoder-only model, generate() returns prompt + continuation,
        # so slice the prompt off by token count. This replaces splitting the
        # decoded text on "Response:", which picks the wrong span if that word
        # appears in the question or answer. Special tokens (<s>, </s>) are
        # dropped by the tokenizer instead of chopping a fixed 4 characters,
        # which used to eat real text whenever generation hit max_new_tokens
        # without emitting </s>.
        new_tokens = generated[0, input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        

In [3]:
# Adapter checkpoint directory; its PEFT config names the base LLaMA model to load.
MODEL_PATH = "askreddit_v1"
model = AskRedditModel(MODEL_PATH)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [4]:
# Require at least 20 newly generated tokens for this answer.
question = "What is something that is way more dangerous than people think it is?"
resp = model(question, min_length=20)
print(resp)



Mosquitoes. They've killed more people than any other animal in the history of the planet.


In [5]:
# Require at least 20 newly generated tokens for this answer.
question = "What will always be dirty no matter how often it's cleaned?"
resp = model(question, min_length=20)
print(resp)

The inside of a toilet.

You can clean it, but it'll still be dirty.


In [6]:
# No minimum answer length: the model may reply with just a few tokens.
question = "What do people take way too seriously?"
resp = model(question, min_length=0)
print(resp)

Politics


In [7]:
# No minimum answer length: the model may reply with just a few tokens.
question = "What is something that people don't take seriously enough?"
resp = model(question, min_length=0)
print(resp)

Mental health


In [8]:
# Require at least 10 newly generated tokens for this answer.
question = "What does japan do better than the rest of the world?"
resp = model(question, min_length=10)
print(resp)

Japanese toilet paper is the best
