In [None]:
import weaviate
import weaviate.classes as wvc
import os
from dotenv import load_dotenv
from weaviate.classes.query import MetadataQuery
from transformers import BitsAndBytesConfig, AutoModel, AutoTokenizer, BitsAndBytesConfig
from torch.nn.functional import Tensor
import torch, gc

import torch.nn.functional as F

# Run the rest of the notebook from the parent of the directory the kernel
# started in. The starting directory is captured only once: re-running this
# cell would otherwise read the *already changed* working directory and climb
# one level higher on every execution.
if "cwd" not in globals():
    cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
os.chdir(parent_dir)

# Pull secrets from the environment (populated by python-dotenv).
load_dotenv()
hf_token = os.getenv("HUGGINGFACE_TOKEN")

# 4-bit NF4 quantisation with double quantisation; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Embedding model used to vectorise the question before the Weaviate lookup.
# This module-level `model` is read by convert_text_to_tokens().
model = AutoModel.from_pretrained(
    'Salesforce/SFR-Embedding-Mistral',
    trust_remote_code=True,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)

def last_token_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Pick one hidden-state vector per row: the one at the last attended position.

    Args:
        last_hidden_states: tensor indexed as [row, position, ...]
            (presumably (batch, seq_len, hidden) -- confirm against caller).
        attention_mask: tensor indexed as [row, position]; assumed to hold
            1 for real tokens and 0 for padding.

    Returns:
        A tensor with one entry per row of ``last_hidden_states``.
    """
    n_rows = attention_mask.shape[0]

    # If every row has its final position unmasked, the batch is treated as
    # left-padded and the final position is the last token for all rows.
    if attention_mask[:, -1].sum() == n_rows:
        return last_hidden_states[:, -1]

    # Otherwise use each row's own last attended position (mask sum minus one).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_positions]

def convert_text_to_tokens(text:str, tokenizer, max_length):
    """Embed ``text`` with the module-level ``model`` and return a NumPy vector.

    Despite the name, the return value is an embedding, not token ids: the
    text is tokenised, run through ``model``, pooled with ``last_token_pool``
    and the first row of the pooled result is returned.

    Args:
        text: input passed straight to ``tokenizer``.
        tokenizer: callable tokenizer returning a batch with an
            ``attention_mask`` entry and supporting ``.to(device)``.
        max_length: truncation length forwarded to ``tokenizer``.

    Returns:
        float32 NumPy array holding the pooled embedding of the first row.
    """
    # Place the inputs on the device the model actually lives on instead of
    # assuming 'cuda'.
    batch_dict = tokenizer(
        text,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)

    # Inference only: without no_grad the forward pass records an autograd
    # graph and keeps activations alive, wasting GPU memory.
    with torch.no_grad():
        output = model(**batch_dict)

    pooled = last_token_pool(output.last_hidden_state, batch_dict['attention_mask'])
    # Cast to float32 before leaving torch (the original does the same .float() step).
    embeddings = pooled[0].float().cpu().numpy()
    return embeddings

# OpenAI key is read from the environment and forwarded to Weaviate as a header.
openai_api_key = os.getenv("OPENAI_KEY")

# Connect to a locally running Weaviate instance (HTTP 8080, gRPC 50051).
client = weaviate.connect_to_local(
    port=8080,
    grpc_port=50051,
    headers={"X-OpenAI-Api-Key": openai_api_key},
    additional_config=weaviate.config.AdditionalConfig(timeout=(60, 180)),
)

In [None]:
# --- Embed the question and retrieve the closest documents -----------------
tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-Mistral')
max_length = 4096
question = "What do I do if my neighbour is having a party"
collection = client.collections.get("citizens_info_docs")
question_embeddings = convert_text_to_tokens(question, tokenizer, max_length)

response = collection.query.near_vector(
    near_vector=question_embeddings.tolist(),  # single query vector as a plain Python list
    target_vector='default',
    return_properties=['body', 'title'],
    limit=2,
    return_metadata=MetadataQuery(distance=True)
)

# Show each hit together with its distance to the question vector.
# NOTE: the loop variable `o` is still read by a later cell, so it keeps its name.
for o in response.objects:
    print(o.properties)
    print(o.metadata.distance)

# --- Release the embedding model before the generator model is loaded ------
# After this point `model` and `tokenizer` no longer exist; re-running this
# cell requires re-running the cell that loads the embedding model first.
model.cpu()
del model, tokenizer
# Collect first so that anything only reachable through reference cycles is
# freed, then hand the now-unused cached blocks back to the CUDA driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Generator model: mistralai/Mistral-7B-Instruct-v0.2
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer

model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.2"

# Start from the published config and override the maximum position count.
# NOTE(review): 8096 may have been intended as 8192 -- confirm.
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config.max_position_embeddings = 8096

# 4-bit NF4 quantisation with double quantisation and bfloat16 compute.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True,
)

# `tokenizer` and `model` are rebound here to the generator's tokenizer/model.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    config=config,
    quantization_config=quantization_config,
    trust_remote_code=True,
    device_map="cuda",
    offload_folder="./offload",
)

# tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=hf_token)

In [None]:
# Build the context from the retrieval result itself rather than from the loop
# variable `o` left over from the printing loop in an earlier cell: that
# variable points at the *last* hit and is undefined when nothing was
# retrieved. The first hit is used here; Weaviate vector search is expected to
# return hits ordered by increasing distance (closest first).
if not response.objects:
    raise RuntimeError("No documents were retrieved for the question; cannot build a context.")
context = response.objects[0].properties['body']
prompt = [
    {"role": "user", "content": f"Based on the following context {context}, can you provide an answer to this {question}. If the information is not clear say I don't know but don't make up any information"},
]

# The chat template returns the prompt token ids; move them to the device the
# model is on so generate() does not receive CPU inputs for a GPU model.
encodeds = tokenizer.apply_chat_template(prompt, return_tensors="pt").to(model.device)

# Sampling is enabled, so the answer can differ between runs.
generated_ids = model.generate(encodeds, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
decoded = tokenizer.batch_decode(generated_ids)

# Extract the answer after the [/INST] token
start_token = "[/INST]"
start_index = decoded[0].find(start_token)

print(f'{question}: \n')
if start_index != -1:
    start_index += len(start_token)
    answer = decoded[0][start_index:].strip()
    print(answer)
else:
    print("Token not found in the generated text.")