In [None]:
!python -m pip install --upgrade pip
!python -m pip uninstall -y torch torchvision torchaudio
!python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install faiss-gpu-cu12 datasets sentence-transformers scikit-learn

In [None]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub before any model is downloaded.
# The access token is read from the HF_TOKEN environment variable instead of
# being hard-coded, so it is never saved or shared together with the notebook.
# When HF_TOKEN is unset, `token` is None and login() prompts for the token.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import numpy as np
import sqlite3
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from data_sql import example_data

In [None]:
def create_example_db(db_path='example_db.db'):
    """Create the SQLite database and (re)populate the ``documents`` table.

    The table is created when missing, every existing row is deleted, and the
    ``(topic, text_column)`` pairs from the module-level ``example_data`` are
    inserted. All of this happens in one transaction, so a failure leaves the
    previous table contents untouched.

    Note: because ``id`` is declared AUTOINCREMENT, deleting the rows does not
    restart the id sequence; ids keep increasing across repeated calls.

    Args:
        db_path: Path of the SQLite database file (created if absent).

    Raises:
        sqlite3.Error: If any statement fails; pending changes are rolled back.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Used as a context manager, the connection commits on success and
        # rolls back if an exception escapes the block.
        with conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    topic TEXT NOT NULL,
                    text_column TEXT NOT NULL)''')
            conn.execute('DELETE FROM documents')
            conn.executemany(
                'INSERT INTO documents (topic, text_column) VALUES (?, ?)',
                example_data,
            )
    finally:
        # `with conn` only manages the transaction; the handle itself must be
        # closed explicitly so it is not leaked when a statement raises.
        conn.close()

def fetch_data_from_db(db_path='example_db.db'):
    """Fetch all rows from the ``documents`` table, ordered by ``id``.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        A list of ``(id, topic, text_column)`` tuples in ascending id order.

    Raises:
        sqlite3.Error: If the query fails (for example, the table is missing).
    """
    conn = sqlite3.connect(db_path)
    try:
        # Explicit ORDER BY keeps the row order deterministic, so positions in
        # the returned list are stable between calls.
        cursor = conn.execute(
            "SELECT id, topic, text_column FROM documents ORDER BY id"
        )
        return cursor.fetchall()
    finally:
        # Always release the connection, even when the query raises.
        conn.close()

In [None]:
def load_embedding_model(model_name="google/embeddinggemma-300m"):
    """Load an embedding model with SentenceTransformer.

    Args:
        model_name: Identifier passed to ``SentenceTransformer``. Defaults to
            ``google/embeddinggemma-300m``, the model this function previously
            hard-coded.

    Returns:
        The ``SentenceTransformer`` instance constructed for ``model_name``.
    """
    return SentenceTransformer(model_name)

def generate_embeddings(model, texts):
    """Encode ``texts`` into embeddings with the given model.

    Args:
        model: Object exposing ``encode(texts, convert_to_tensor=...)``.
        texts: The texts to embed.

    Returns:
        Whatever ``model.encode`` returns when called with
        ``convert_to_tensor=False``.
    """
    embeddings = model.encode(texts, convert_to_tensor=False)
    return embeddings

def retrieve(query_embedding, doc_embeddings, top_k=3):
    """Retrieve the top_k documents most similar to the query by cosine similarity.

    Args:
        query_embedding: Embedding vector of the query.
        doc_embeddings: Sequence/array with one document embedding per row.
        top_k: Maximum number of hits to return. Values larger than the number
            of documents return all documents; values <= 0 return no hits.

    Returns:
        A list of ``(index, similarity)`` tuples sorted from most to least
        similar, where ``index`` is the row position in ``doc_embeddings``.
        Both values are built-in ``int`` / ``float``. The list is empty when
        ``top_k <= 0`` or there are no documents.
    """
    # Guard first: with top_k == 0 the slice `[-0:]` would select *every*
    # document, and an empty corpus cannot be scored at all.
    if top_k <= 0 or len(doc_embeddings) == 0:
        return []

    sims = cosine_similarity([query_embedding], doc_embeddings)[0]
    k = min(top_k, len(sims))
    ascending = np.argsort(sims)
    # Take the k largest scores from the tail, then flip to best-first order.
    top_indices = ascending[len(ascending) - k:][::-1]
    return [(int(i), float(sims[i])) for i in top_indices]

In [None]:
def load_gemma_2b():
    """Load the gemma-2-2b-it tokenizer and causal LM used for answer generation.

    The model is loaded in float16 and moved to CUDA when a GPU is available,
    otherwise in float32 on the default device, and is put in eval mode.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

    # Decide once whether a GPU is usable; it drives both dtype and placement.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        dtype = torch.float16
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", torch_dtype=dtype)
    if use_cuda:
        model.to('cuda')
    model.eval()
    return tokenizer, model

def build_prompt(question, context_chunks):
    """Assemble the generation prompt from a question and its context chunks.

    Args:
        question: The question to be answered.
        context_chunks: Iterable of context strings; they are joined with a
            blank line between consecutive chunks.

    Returns:
        A single prompt string: an instruction line, the joined context, the
        question, and a trailing ``Answer:`` cue.
    """
    chunk_separator = "\n\n"
    context = chunk_separator.join(context_chunks)
    prompt = f"Use the context to answer the question.\n\nContext:\n{context}\n\nQuestion: {question}\nAnswer:"
    return prompt

def generate_answer(tokenizer, model, question, context, max_length=256):
    """Generate an answer for the question based on the provided context chunks.

    Args:
        tokenizer: Tokenizer matching ``model``.
        model: Causal language model exposing ``generate`` and ``device``.
        question: The user question.
        context: Context chunks, forwarded to ``build_prompt``.
        max_length: Maximum number of *new* tokens to generate (forwarded as
            ``max_new_tokens``); the parameter name is kept for compatibility.

    Returns:
        The generated continuation only (prompt excluded), stripped of
        surrounding whitespace. Sampling is enabled, so output is not
        deterministic.
    """
    prompt = build_prompt(question, context)
    # NOTE(review): the prompt is truncated to 512 tokens; if the tokenizer
    # truncates on the right, a long context can cut off the trailing
    # question. Confirm the truncation side / budget for long contexts.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Place the inputs wherever the model actually lives rather than inferring
    # it from CUDA availability, so the two can never disagree.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_length, do_sample=True, top_p=0.9, temperature=0.7)
    # The output sequence is the prompt tokens followed by the new tokens.
    # Decode only the new ones: removing the prompt by string replacement
    # breaks whenever the decoded text is not byte-identical to the prompt
    # (e.g. after truncation or tokenizer normalisation).
    answer = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
    return answer

In [None]:
def main(question="Type your query here", top_k=3):
    """Run the end-to-end retrieval-augmented generation demo.

    Rebuilds the example database, embeds every document, retrieves the
    ``top_k`` documents most similar to ``question``, and prints an answer
    generated from those documents.

    Args:
        question: The user query. Defaults to the original placeholder text.
        top_k: Number of documents to retrieve as context.
    """
    # Create and prepare the example database
    create_example_db()
    rows = fetch_data_from_db()
    if not rows:
        # Without documents there is nothing to embed or retrieve; stop here
        # with a clear message instead of failing inside similarity scoring.
        print("No documents found in the database; nothing to retrieve.")
        return

    # Load embedding model and embed all documents (text is the third column)
    embed_model = load_embedding_model()
    texts = [row[2] for row in rows]
    doc_embeddings = generate_embeddings(embed_model, texts)

    # Embed the user query with the same model
    question_embed = generate_embeddings(embed_model, [question])[0]

    # Retrieve the most similar documents; hit indices are positions in `rows`
    hits = retrieve(question_embed, doc_embeddings, top_k=top_k)
    ctx_chunks = [rows[i][2] for i, _ in hits]

    # Load Gemma 2B model and tokenizer for answer generation
    tokenizer, gen_model = load_gemma_2b()

    # Generate answer based on user question
    answer = generate_answer(tokenizer, gen_model, question, ctx_chunks)

    print("Answer:", answer)

# Run the demo only when this code is executed directly (module name is
# "__main__"), not when it is imported from elsewhere.
if __name__ == "__main__":
    main()