# Streamlit RAG (Retrieval-Augmented Generation)
Build an in-memory RAG app: upload files → chunk → embed → index into FAISS → ask questions → get answers with citations. Answers can come from OpenAI, Gemini, Anthropic, Groq, or Hugging Face models; embeddings can come from local SentenceTransformers, OpenAI, or the Hugging Face Inference API.


# Installation (commented)

In [None]:
# !pip install streamlit faiss-cpu pypdf chromadb tiktoken
# !pip install sentence-transformers openai anthropic google-generativeai groq huggingface_hub


# Imports

In [None]:
import os
import io
import re
import streamlit as st
from typing import List


# Helpers (heavy libraries are imported inside the functions that use them, to keep app start-up fast)

In [None]:
def load_pdf_text(file_bytes: bytes) -> str:
    """Extract the text of every page of an in-memory PDF.

    Pages for which pypdf yields no text contribute an empty string, so the
    result has one newline-separated segment per page.
    """
    from pypdf import PdfReader

    pdf = PdfReader(io.BytesIO(file_bytes))
    page_texts = (page.extract_text() or '' for page in pdf.pages)
    return "\n".join(page_texts)

def split_text(text: str, chunk_size=800, chunk_overlap=120) -> List[str]:
    """Split *text* into overlapping fixed-size character windows.

    Consecutive windows start ``chunk_size - chunk_overlap`` characters apart,
    so each window repeats the last ``chunk_overlap`` characters of the
    previous one. Windows are whitespace-stripped and empty ones are dropped.

    Args:
        text: The text to split.
        chunk_size: Maximum window length in characters; must be positive.
        chunk_overlap: Characters shared by consecutive windows; must be
            smaller than ``chunk_size`` so the window always advances.

    Returns:
        The non-empty, stripped chunks in document order.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``chunk_overlap`` is
            not smaller than ``chunk_size``.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if chunk_overlap >= chunk_size:
        # The window would never move forward, looping forever.
        raise ValueError(
            f"chunk_overlap ({chunk_overlap}) must be smaller than chunk_size ({chunk_size})"
        )
    chunks = []
    start = 0
    while start < len(text):
        end = min(len(text), start + chunk_size)
        chunks.append(text[start:end])
        if end == len(text):
            # This window reached the end of the text. Stepping back by the
            # overlap here would re-emit the same tail window indefinitely.
            break
        start = max(0, end - chunk_overlap)
    return [c.strip() for c in chunks if c.strip()]

@st.cache_resource(show_spinner=False)
def get_embedder(name: str):
    """Return a callable mapping a list of strings to a list of embedding vectors.

    The result is cached with ``st.cache_resource``. Because of that, the
    SentenceTransformers model and API clients are created once per process
    rather than on every Streamlit rerun (both "Build index" and "Answer" call
    this function).

    Args:
        name: ``'OpenAI'``, ``'SentenceTransformers'`` or
            ``'HuggingFace Inference'``.

    Returns:
        ``embed(texts: List[str]) -> List[List[float]]``. An empty input
        returns an empty list without calling the backend.

    Raises:
        ValueError: If *name* is not a known backend.
    """
    def _secret(key: str):
        # st.secrets may raise (instead of returning the default) when no
        # secrets file is configured, so fall back to the environment explicitly.
        try:
            value = st.secrets.get(key)
        except Exception:
            value = None
        return value or os.environ.get(key)

    if name == 'OpenAI':
        from openai import OpenAI
        api_key = _secret('OPENAI_API_KEY')
        if api_key:
            # Also export the key so other OpenAI() clients in the app find it.
            os.environ['OPENAI_API_KEY'] = api_key
        client = OpenAI(api_key=api_key)
        batch_size = 256  # keep each request well under the API's per-request input limit

        def _emb(texts: List[str]):
            vectors = []
            for i in range(0, len(texts), batch_size):
                resp = client.embeddings.create(model='text-embedding-3-small', input=texts[i:i + batch_size])
                # Order by the returned index so vectors line up with the inputs.
                vectors.extend(d.embedding for d in sorted(resp.data, key=lambda d: d.index))
            return vectors
        return _emb

    if name == 'SentenceTransformers':
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        def _emb(texts: List[str]):
            if not texts:
                return []
            return model.encode(texts, show_progress_bar=False, normalize_embeddings=True).tolist()
        return _emb

    if name == 'HuggingFace Inference':
        import numpy as np
        from huggingface_hub import InferenceClient
        hf = InferenceClient(token=_secret('HUGGINGFACEHUB_API_TOKEN'))
        model_id = 'sentence-transformers/all-MiniLM-L6-v2'

        def _emb(texts: List[str]):
            vectors = []
            for t in texts:
                arr = np.asarray(hf.feature_extraction(t, model=model_id), dtype='float32')
                # The returned array may carry extra leading axes (batch/token)
                # depending on the deployment. Collapse to one vector per text;
                # for an already-pooled (1, dim) output this is a no-op.
                vectors.append(arr.reshape(-1, arr.shape[-1]).mean(axis=0).tolist())
            return vectors
        return _emb

    raise ValueError('Unknown embedder')

def build_faiss_index(vectors):
    """Build an exact cosine-similarity FAISS index over *vectors*.

    Args:
        vectors: Equal-length embedding vectors (list of lists or 2-D array).
            The input is copied, so the caller's data is not modified.

    Returns:
        A ``faiss.IndexFlatIP`` holding the L2-normalised vectors, so the
        inner-product search scores are cosine similarities.

    Raises:
        ValueError: If *vectors* is empty or not 2-D (e.g. nothing was
            extracted from the uploaded files).
    """
    import numpy as np

    # np.array copies, which matters because normalize_L2 works in place.
    arr = np.array(vectors, dtype='float32')
    if arr.ndim != 2 or arr.shape[0] == 0 or arr.shape[1] == 0:
        raise ValueError(f"Expected a non-empty 2-D array of embeddings, got shape {arr.shape}")

    # Imported after validation so bad input fails with a clear message.
    import faiss

    index = faiss.IndexFlatIP(arr.shape[1])
    # Normalize for cosine similarity
    faiss.normalize_L2(arr)
    index.add(arr)
    return index

def search_index(index, query_vec, k=5):
    """Return the best *k* inner-product hits for *query_vec*.

    Args:
        index: FAISS index, e.g. one built by ``build_faiss_index``.
        query_vec: One embedding with the same dimension as the index; it is
            L2-normalised here so scores are cosine similarities against
            normalised index vectors.
        k: Maximum number of hits to return.

    Returns:
        ``(scores, idxs)`` as two 1-D numpy arrays of equal length. FAISS pads
        results with id ``-1`` when the index holds fewer than *k* vectors.
        Those padding slots are removed, so the arrays may be shorter than *k*.
    """
    import numpy as np
    import faiss
    q = np.array([query_vec], dtype='float32')
    faiss.normalize_L2(q)
    # No point asking for more neighbours than the index contains.
    k = max(1, min(int(k), index.ntotal))
    scores, idxs = index.search(q, k)
    found = idxs[0] >= 0
    return scores[0][found], idxs[0][found]

def generate_answer(provider: str, model: str, question: str, contexts: List[str], temperature: float, max_tokens: int) -> str:
    """Ask *provider*'s *model* to answer *question* using only *contexts*.

    The contexts are labelled ``[Source 1]``, ``[Source 2]``, ... in the
    prompt. The system instruction tells the model to answer only from them
    and to say it does not know otherwise.

    Args:
        provider: ``'OpenAI'``, ``'Gemini'``, ``'Anthropic'``, ``'Groq'`` or
            ``'HuggingFace'``.
        model: Provider-specific model id.
        question: The user's question.
        contexts: Retrieved text chunks, in the order they should be numbered.
        temperature: Sampling temperature. It is clamped to 1.0 for Anthropic
            and not sent to Hugging Face.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The model's answer text.

    Raises:
        ValueError: If *provider* is not supported. SDK errors (missing key,
            bad model id, network) propagate to the caller.
    """
    def _secret(key: str):
        # st.secrets may raise (instead of returning the default) when no
        # secrets file is configured, so fall back to the environment explicitly.
        try:
            value = st.secrets.get(key)
        except Exception:
            value = None
        return value or os.environ.get(key)

    system = (
        "You are a helpful assistant. Answer using only the provided context snippets. "
        "If the answer is not contained in the context, say you don't know."
    )
    context_block = "\n\n".join([f"[Source {i+1}]\n{c}" for i,c in enumerate(contexts)])
    prompt = f"Context:\n{context_block}\n\nQuestion: {question}\nAnswer:"
    chat_messages = [{"role": "system", "content": system}, {"role": "user", "content": prompt}]

    if provider == 'OpenAI':
        from openai import OpenAI
        client = OpenAI(api_key=_secret('OPENAI_API_KEY'))
        res = client.chat.completions.create(model=model, temperature=temperature, max_tokens=max_tokens,
                                             messages=chat_messages)
        return res.choices[0].message.content or ''
    if provider == 'Gemini':
        import google.generativeai as genai
        genai.configure(api_key=_secret('GEMINI_API_KEY'))
        g = genai.GenerativeModel(model)
        out = g.generate_content(
            f"{system}\n\n{prompt}",
            generation_config={"temperature": temperature, "max_output_tokens": max_tokens},
        )
        return getattr(out,'text',str(out))
    if provider == 'Anthropic':
        import anthropic
        a = anthropic.Anthropic(api_key=_secret('ANTHROPIC_API_KEY'))
        # Anthropic accepts temperatures in [0, 1]; the UI slider goes to 1.5.
        out = a.messages.create(model=model, max_tokens=max_tokens, temperature=min(temperature, 1.0), system=system,
                                messages=[{"role": "user", "content": prompt}])
        # Concatenate text blocks only; the response may contain other block types.
        return "".join(b.text for b in out.content if getattr(b, "type", None) == "text")
    if provider == 'Groq':
        from groq import Groq
        gq = Groq(api_key=_secret('GROQ_API_KEY'))
        out = gq.chat.completions.create(model=model, temperature=temperature, max_tokens=max_tokens,
                                         messages=chat_messages)
        return out.choices[0].message.content
    if provider == 'HuggingFace':
        from huggingface_hub import InferenceClient
        hf = InferenceClient(token=_secret('HUGGINGFACEHUB_API_TOKEN'))
        try:
            resp = hf.chat_completion(model=model, messages=chat_messages, max_tokens=max_tokens)
            return resp.choices[0].message.content
        except Exception:
            # Fall back to raw generation (e.g. models without a chat template).
            # text_generation takes the prompt positionally; it has no `inputs` kwarg.
            return hf.text_generation(f"{system}\n\n{prompt}", model=model, max_new_tokens=max_tokens)
    raise ValueError('Unsupported provider')


# UI

In [None]:
st.set_page_config(page_title="RAG", page_icon="📚")
st.title("📚 RAG — Upload, Index, Ask")

# Suggested model id per provider; the user can type any other id.
DEFAULT_MODELS = {
    'OpenAI': 'gpt-4o-mini',
    'Gemini': 'gemini-1.5-flash',
    'Anthropic': 'claude-3-5-sonnet-20241022',
    'Groq': 'llama-3.1-8b-instant',
    'HuggingFace': 'meta-llama/Llama-3.2-1B',
}

with st.sidebar:
    st.header("LLM & Embeddings")
    provider = st.selectbox('LLM Provider', list(DEFAULT_MODELS), index=0)
    model = st.text_input('LLM Model', DEFAULT_MODELS[provider])
    emb_backend = st.selectbox('Embedding backend', ['SentenceTransformers','OpenAI','HuggingFace Inference'], index=0)
    chunk_size = st.slider('Chunk size', 200, 2000, 800, 50)
    chunk_overlap = st.slider('Chunk overlap', 0, 400, 120, 10)
    temperature = st.slider('temperature', 0.0, 1.5, 0.2, 0.1)
    max_tokens = st.slider('max tokens', 64, 2048, 256, 32)

uploaded = st.file_uploader("Upload files (PDF/TXT/MD)", type=['pdf','txt','md'], accept_multiple_files=True)


# Build index on demand

In [None]:
if st.button('Build index'):
    # Read every uploaded file; an unreadable file is reported and skipped.
    docs = []  # (file name, extracted text)
    for f in uploaded or []:
        try:
            # getvalue() returns the whole upload regardless of the read cursor.
            data = f.getvalue()
            if f.name.lower().endswith('.pdf'):
                text = load_pdf_text(data)
            else:
                text = data.decode('utf-8', errors='ignore')
        except Exception as e:
            st.warning(f"Failed to read {f.name}: {e}")
        else:
            docs.append((f.name, text))

    try:
        # Chunk each file separately so no chunk mixes two documents and every
        # chunk can be traced back to the file it came from.
        chunks, sources = [], []
        for name, text in docs:
            file_chunks = split_text(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
            chunks.extend(file_chunks)
            sources.extend([name] * len(file_chunks))

        if not chunks:
            st.warning('No text to index. Upload at least one readable PDF/TXT/MD file.')
        else:
            with st.spinner(f"Embedding {len(chunks)} chunks..."):
                embed = get_embedder(emb_backend)
                vectors = embed(chunks)
                index = build_faiss_index(vectors)
            # Commit everything together, only after indexing succeeded, so the
            # stored chunks always match the stored index.
            st.session_state['rag_chunks'] = chunks
            st.session_state['rag_sources'] = sources
            st.session_state['rag_index'] = index
            st.session_state['rag_vectors'] = vectors
            # Remember the backend: queries must be embedded the same way.
            st.session_state['rag_emb_backend'] = emb_backend
            st.success(f"Indexed {len(chunks)} chunks.")
    except Exception as e:
        st.error(f"Indexing failed: {e}")


# Ask questions

In [None]:
q = st.text_input("Ask a question about your documents")
answer_clicked = st.button('Answer')
if answer_clicked and not q:
    st.warning('Please enter a question.')
elif answer_clicked:
    if 'rag_index' not in st.session_state:
        st.warning('Please build the index first.')
    else:
        chunks = st.session_state['rag_chunks']
        sources = st.session_state.get('rag_sources')
        # Embed the query with the backend the index was built with; another
        # backend produces vectors of a different dimension/space.
        backend = st.session_state.get('rag_emb_backend', emb_backend)
        if backend != emb_backend:
            st.info(f"Using the '{backend}' embeddings the index was built with.")
        try:
            qvec = get_embedder(backend)([q])[0]
            scores, idxs = search_index(st.session_state['rag_index'], qvec, k=5)
        except Exception as e:
            st.error(f"Retrieval failed: {e}")
        else:
            # Keep every chunk paired with its own score; drop invalid ids.
            hits = [(int(i), float(s)) for s, i in zip(scores, idxs) if 0 <= i < len(chunks)]
            ctx = [chunks[i] for i, _ in hits]
            try:
                ans = generate_answer(provider, model, q, ctx, temperature, max_tokens)
            except Exception as e:
                ans = f"Error: {e}"
            st.subheader('Answer')
            st.write(ans)
            st.subheader('Citations')
            for n, (i, s) in enumerate(hits, start=1):
                label = f"Source {n} (score {s:.3f})"
                if sources and i < len(sources):
                    label += f" — {sources[i]}"
                with st.expander(label):
                    st.write(chunks[i])


# Notes
# - Everything is in-memory; for persistence use Chroma/Milvus etc.
# - For larger docs, prefer background indexing and caching.