In [1]:
# RAG Application with Groq API for Google Colab
# Install required packages (warnings about opentelemetry are normal and can be ignored)
import sys
import subprocess

def install_packages():
    """Install the third-party packages this notebook depends on.

    Each package is installed with the running interpreter's own pip
    (``sys.executable -m pip install -q``). pip's output is captured to keep
    the notebook output clean, but a non-zero exit status is reported instead
    of being silently discarded, so a later ``ImportError`` can be traced back
    to the install step that caused it.
    """
    packages = [
        'groq',
        'chromadb',
        'sentence-transformers',
        'pypdf2',
        'openpyxl',
        'python-docx',
        'gradio',
    ]
    print("üì¶ Installing dependencies... This may take a minute.")

    failed = []
    for pkg in packages:
        result = subprocess.run(
            [sys.executable, '-m', 'pip', 'install', '-q', pkg],
            capture_output=True,
            text=True,
            errors='replace',
        )
        if result.returncode != 0:
            # Keep only the last stderr line: pip puts the actual error there.
            last_line = (result.stderr or "").strip().splitlines()[-1:]
            failed.append((pkg, last_line[0] if last_line else "no error output"))

    if failed:
        print("WARNING: some packages could not be installed; imports below may fail:")
        for pkg, reason in failed:
            print(f"  - {pkg}: {reason}")
        print()
    else:
        print("‚úÖ Installation complete! Starting application...\n")

install_packages()

import os
import gradio as gr
from groq import Groq
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import PyPDF2
import openpyxl
import docx
from typing import List, Tuple
import io
import warnings
warnings.filterwarnings('ignore')

class RAGApplication:
    """Retrieval-augmented question answering over user-uploaded documents.

    Workflow: ``verify_api_key`` creates the Groq client and loads the
    sentence-transformer embedding model; ``process_documents`` extracts,
    chunks and embeds the uploaded files into a ChromaDB collection;
    ``answer_question`` retrieves the closest chunks and asks the Groq chat
    model to answer using only that context.
    """

    def __init__(self):
        """Create the ChromaDB client; the Groq client and embedding model are set up later."""
        # Populated by verify_api_key().
        self.client = None
        self.api_key = None
        # Loaded on the first successful key verification.
        self.embedding_model = None
        self.chroma_client = chromadb.Client(Settings(
            anonymized_telemetry=False,
            allow_reset=True
        ))
        # (Re)built by process_documents().
        self.collection = None
        self.documents = []
        self.doc_count = 0

    def verify_api_key(self, api_key: str) -> str:
        """Validate a Groq API key and make sure the embedding model is loaded.

        The key is checked by sending a minimal chat-completion request.

        Args:
            api_key: Key as entered by the user; surrounding whitespace is stripped.

        Returns:
            A status message for display. On any failure the stored client
            and key are cleared.
        """
        if not api_key or not api_key.strip():
            return "‚ùå Please enter an API key"

        try:
            self.api_key = api_key.strip()
            self.client = Groq(api_key=self.api_key)

            # Smallest useful request; it raises if the key is rejected.
            self.client.chat.completions.create(
                messages=[{"role": "user", "content": "test"}],
                model="llama-3.3-70b-versatile",
                max_tokens=5
            )

            # Initialize embedding model after successful API verification
            if self.embedding_model is None:
                print("Loading embedding model...")
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
                print("‚úÖ Embedding model loaded")

            return "‚úÖ API Key verified successfully! You can now upload documents."
        except Exception as e:
            self.client = None
            self.api_key = None
            error_msg = str(e)
            if "invalid" in error_msg.lower() or "unauthorized" in error_msg.lower():
                return "‚ùå Invalid API Key. Please check your Groq API key and try again."
            return f"‚ùå API Key verification failed: {error_msg}"

    def extract_text_from_pdf(self, file_content: bytes) -> str:
        """Return the text of every page of a PDF, one page per line group.

        Pages for which no text is extracted are skipped.

        Raises:
            Exception: If the PDF cannot be read; the original error is chained.
        """
        try:
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
            pages = []
            for page in pdf_reader.pages:
                extracted = page.extract_text()
                if extracted:
                    pages.append(extracted + "\n")
            return "".join(pages).strip()
        except Exception as e:
            raise Exception(f"Error reading PDF: {str(e)}") from e

    def extract_text_from_excel(self, file_content: bytes) -> str:
        """Return workbook contents as text, one ``|``-separated line per row.

        Each sheet is introduced by a ``=== Sheet: <title> ===`` header.
        Rows whose cells are all empty are skipped.

        Raises:
            Exception: If the workbook cannot be read; the original error is chained.
        """
        try:
            workbook = openpyxl.load_workbook(io.BytesIO(file_content), data_only=True)
            parts = []
            for sheet in workbook.worksheets:
                parts.append(f"\n=== Sheet: {sheet.title} ===\n")
                for row in sheet.iter_rows(values_only=True):
                    cells = ["" if cell is None else str(cell) for cell in row]
                    # Test the cells, not the joined string: a blank
                    # multi-column row joins to " |  | ", which is not empty.
                    if any(cell.strip() for cell in cells):
                        parts.append(" | ".join(cells) + "\n")
            return "".join(parts).strip()
        except Exception as e:
            raise Exception(f"Error reading Excel: {str(e)}") from e

    def extract_text_from_docx(self, file_content: bytes) -> str:
        """Return the non-blank paragraphs of a Word document, one per line.

        Raises:
            Exception: If the document cannot be read; the original error is chained.
        """
        try:
            doc = docx.Document(io.BytesIO(file_content))
            paragraphs = [para.text + "\n" for para in doc.paragraphs if para.text.strip()]
            return "".join(paragraphs).strip()
        except Exception as e:
            raise Exception(f"Error reading Word document: {str(e)}") from e

    def extract_text_from_txt(self, file_content: bytes) -> str:
        """Decode a text file as UTF-8, dropping undecodable bytes, and strip it.

        Raises:
            Exception: If decoding fails; the original error is chained.
        """
        try:
            return file_content.decode('utf-8', errors='ignore').strip()
        except Exception as e:
            raise Exception(f"Error reading text file: {str(e)}") from e

    def process_documents(self, files) -> str:
        """Extract, chunk, embed and index the uploaded files.

        Any previously indexed documents are discarded first. Each file is
        handled independently: a file that fails is recorded and reported,
        and the remaining files are still processed.

        Args:
            files: Iterable of file paths (strings) or file-like objects
                with a ``read`` method.

        Returns:
            A human-readable summary of processed and failed files, or an
            error message.
        """
        if not files:
            return "‚ùå No files uploaded. Please select at least one document."

        if self.client is None:
            return "‚ùå Please verify your API key first before uploading documents."

        if self.embedding_model is None:
            return "‚ùå Embedding model not loaded. Please verify API key again."

        # Reset collection. Deleting is best-effort: on the first upload there
        # is no "documents" collection yet and the client raises.
        try:
            self.chroma_client.delete_collection("documents")
        except Exception:
            pass

        self.collection = self.chroma_client.create_collection(
            name="documents",
            metadata={"description": "Document collection for RAG"}
        )

        self.documents = []
        self.doc_count = 0
        processed_files = []
        failed_files = []

        for file in files:
            file_name = "unknown_file"
            file_content = None

            try:
                # Handle filepath (Gradio passes file paths as strings)
                if isinstance(file, str):
                    file_name = os.path.basename(file)
                    with open(file, 'rb') as f:
                        file_content = f.read()
                # Handle file object
                elif hasattr(file, 'read'):
                    file_name = getattr(file, 'name', 'uploaded_file')
                    try:
                        file.seek(0)
                    except (AttributeError, OSError, ValueError):
                        # Not seekable; read from the current position.
                        pass
                    file_content = file.read()
                else:
                    file_name = str(file)
                    failed_files.append(f"{file_name} (unknown file type)")
                    continue

                if not file_content:
                    failed_files.append(f"{file_name} (empty file)")
                    continue

                # Extract text based on file type
                lower_name = file_name.lower()
                if lower_name.endswith('.pdf'):
                    text = self.extract_text_from_pdf(file_content)
                elif lower_name.endswith(('.xlsx', '.xls')):
                    text = self.extract_text_from_excel(file_content)
                elif lower_name.endswith('.docx'):
                    text = self.extract_text_from_docx(file_content)
                elif lower_name.endswith('.txt'):
                    text = self.extract_text_from_txt(file_content)
                else:
                    failed_files.append(f"{file_name} (unsupported format)")
                    continue

                if not text or len(text.strip()) < 10:
                    failed_files.append(f"{file_name} (no extractable text)")
                    continue

                # Create chunks
                chunks = self.create_chunks(text, file_name)
                if chunks:
                    self.documents.extend(chunks)
                    processed_files.append(file_name)
                else:
                    failed_files.append(f"{file_name} (could not create chunks)")

            except Exception as e:
                error_msg = str(e)
                failed_files.append(f"{file_name} (error: {error_msg})")

        # Create embeddings and store in ChromaDB
        if self.documents:
            try:
                print(f"Creating embeddings for {len(self.documents)} chunks...")
                texts = [doc['text'] for doc in self.documents]
                embeddings = self.embedding_model.encode(
                    texts,
                    show_progress_bar=True,
                    batch_size=32
                ).tolist()

                self.collection.add(
                    embeddings=embeddings,
                    documents=texts,
                    metadatas=[{
                        "source": doc['source'],
                        "chunk_id": doc['chunk_id']
                    } for doc in self.documents],
                    ids=[f"doc_{i}" for i in range(len(self.documents))]
                )

                self.doc_count = len(processed_files)

                result_msg = f"‚úÖ Successfully processed {len(processed_files)} document(s) with {len(self.documents)} chunks\n\n"
                result_msg += "Processed files:\n" + "\n".join([f"  ‚Ä¢ {f}" for f in processed_files])

                if failed_files:
                    result_msg += "\n\n‚ö†Ô∏è Failed to process:\n" + "\n".join([f"  ‚Ä¢ {f}" for f in failed_files])

                result_msg += "\n\nüîç You can now ask questions about these documents!"
                return result_msg

            except Exception as e:
                # Nothing was indexed, so do not leave chunks behind that make
                # answer_question() believe documents are loaded.
                self.documents = []
                self.doc_count = 0
                return f"‚ùå Error creating embeddings: {str(e)}"
        else:
            if failed_files:
                return f"‚ùå No text could be extracted from any files.\n\nFailed files:\n" + "\n".join([f"  ‚Ä¢ {f}" for f in failed_files])
            else:
                return "‚ùå No valid files were processed."

    def create_chunks(self, text: str, source: str, chunk_size: int = 500, overlap: int = 50) -> List[dict]:
        """Split text into overlapping word-based chunks.

        Consecutive windows of ``chunk_size`` words start every
        ``chunk_size - overlap`` words. Windowing stops as soon as a window
        reaches the end of the text, so no trailing chunk is emitted that is
        entirely contained in the previous one. Chunks of 20 characters or
        fewer are dropped.

        Args:
            text: Text to split on whitespace.
            source: Name recorded in each chunk's ``source`` field.
            chunk_size: Maximum number of words per chunk; must be positive.
            overlap: Words shared by adjacent chunks; ``0 <= overlap < chunk_size``.

        Returns:
            A list of dicts with keys ``text``, ``source`` and ``chunk_id``
            (sequential per call, starting at 0).

        Raises:
            ValueError: If ``chunk_size`` or ``overlap`` is out of range.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        chunks = []
        words = text.split()
        step = chunk_size - overlap

        for start in range(0, len(words), step):
            chunk_text = ' '.join(words[start:start + chunk_size])
            if len(chunk_text) > 20:
                chunks.append({
                    'text': chunk_text,
                    'source': source,
                    'chunk_id': len(chunks)
                })
            # This window already reached the last word; any further window
            # would only repeat the overlap region.
            if start + chunk_size >= len(words):
                break

        return chunks

    def retrieve_relevant_chunks(self, query: str, n_results: int = 3) -> List[str]:
        """Return the stored chunk texts closest to the query.

        Args:
            query: Question to embed and search with.
            n_results: Maximum number of chunks to return; capped at the
                number of indexed chunks.

        Returns:
            A list of chunk texts, or an empty list when nothing is indexed
            or the lookup fails (the error is printed).
        """
        if self.collection is None or self.embedding_model is None:
            return []

        limit = min(n_results, len(self.documents))
        if limit <= 0:
            return []

        try:
            query_embedding = self.embedding_model.encode([query]).tolist()
            results = self.collection.query(
                query_embeddings=query_embedding,
                n_results=limit
            )

            return results['documents'][0] if results['documents'] else []
        except Exception as e:
            print(f"Error retrieving chunks: {e}")
            return []

    def answer_question(self, question: str) -> str:
        """Answer a question using only the indexed document chunks.

        Args:
            question: The user's question; ``None`` or blank input is rejected.

        Returns:
            The model's answer prefixed with a marker, or an error/status
            message when the key is unverified, no documents are loaded, the
            question is blank, nothing relevant is found, or the request fails.
        """
        if not self.client:
            return "‚ùå Please verify your API key first."

        if not self.documents:
            return "‚ùå Please upload documents first. No documents are currently loaded."

        if not question or not question.strip():
            return "‚ùå Please enter a question."

        try:
            # Retrieve relevant chunks
            relevant_chunks = self.retrieve_relevant_chunks(question, n_results=4)

            if not relevant_chunks:
                return "‚ùå No relevant information found in the documents for this question."

            # Create context from chunks
            context = "\n\n---\n\n".join(relevant_chunks)

            # Create prompt
            prompt = f"""You are a helpful assistant that answers questions based ONLY on the provided context from documents.

Context from documents:
{context}

Question: {question}

Instructions:
- Answer the question using ONLY information from the context above
- If the answer cannot be found in the context, respond with "I cannot find this information in the provided documents."
- Be concise and accurate
- Do not make up or infer information that is not explicitly stated in the context

Answer:"""

            # Get response from Groq
            response = self.client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama-3.3-70b-versatile",
                max_tokens=1024,
                temperature=0.2
            )

            answer = response.choices[0].message.content
            return f"üìù {answer}"

        except Exception as e:
            return f"‚ùå Error generating answer: {str(e)}\n\nPlease try again or rephrase your question."

# Initialize RAG application
# A single module-level instance is shared by every Gradio event handler
# registered further down. Constructing it only creates the ChromaDB client;
# the Groq client and the embedding model are created later, when the user
# verifies an API key.
print("Initializing RAG Application...")
rag_app = RAGApplication()
print("‚úÖ Application initialized!\n")

# Create Gradio interface
# The page is laid out top to bottom in three sections separated by rules:
# API key verification, document upload, and question answering. Components
# are declared in display order, then wired to the rag_app methods.
with gr.Blocks(title="RAG Application", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # üìö RAG Application with Groq API
    ### Upload documents and ask questions about their content
    """)

    # --- Section 1: API key entry (wide column) and verify button (narrow column) ---
    with gr.Row():
        with gr.Column(scale=3):
            api_key_input = gr.Textbox(
                label="üîë Groq API Key",
                type="password",
                placeholder="Enter your Groq API key (starts with gsk_...)"
            )
        with gr.Column(scale=1):
            verify_btn = gr.Button("Verify API Key", variant="primary", size="lg")

    # Read-only box showing the message returned by verify_api_key.
    api_status = gr.Textbox(label="Status", interactive=False, show_label=False)

    gr.Markdown("---")

    # --- Section 2: document upload ---
    # type="filepath" means the handler receives file paths; the extension
    # list matches the formats process_documents dispatches on.
    file_upload = gr.File(
        label="üìÑ Upload Documents",
        file_count="multiple",
        file_types=[".pdf", ".xlsx", ".xls", ".docx", ".txt"],
        type="filepath"
    )

    # Read-only box showing the summary returned by process_documents.
    upload_status = gr.Textbox(label="Upload Status", interactive=False, lines=8)

    gr.Markdown("---")

    # --- Section 3: question input and answer display ---
    with gr.Row():
        question_input = gr.Textbox(
            label="‚ùì Ask a Question",
            placeholder="What would you like to know about the documents?",
            lines=2,
            scale=4
        )
        submit_btn = gr.Button("Get Answer", variant="primary", size="lg", scale=1)

    answer_output = gr.Textbox(
        label="üí¨ Answer",
        lines=10,
        interactive=False
    )

    # Event handlers
    # Verify button -> validate the key and show the status message.
    verify_btn.click(
        fn=rag_app.verify_api_key,
        inputs=[api_key_input],
        outputs=[api_status]
    )

    # Any change to the upload widget -> rebuild the document index.
    file_upload.change(
        fn=rag_app.process_documents,
        inputs=[file_upload],
        outputs=[upload_status]
    )

    # The answer handler is bound twice so that both the button and the
    # textbox's submit event trigger it.
    submit_btn.click(
        fn=rag_app.answer_question,
        inputs=[question_input],
        outputs=[answer_output]
    )

    question_input.submit(
        fn=rag_app.answer_question,
        inputs=[question_input],
        outputs=[answer_output]
    )

    # Static usage instructions rendered at the bottom of the page.
    gr.Markdown("""
    ---
    ### üìñ How to Use:
    1. **Get a Groq API Key**: Visit [console.groq.com](https://console.groq.com) to get your free API key
    2. **Enter API Key**: Paste your key above and click "Verify API Key"
    3. **Upload Documents**: Upload PDF, Excel, Word, or text files (multiple files supported)
    4. **Wait for Processing**: The system will chunk and embed your documents
    5. **Ask Questions**: Type questions about your documents and get AI-powered answers

    **Note**: The application only answers questions based on the uploaded documents. If information isn't in the documents, it will let you know.
    """)

# Launch the interface
# share=False: no public link is created (Gradio's startup note says to set
# share=True for one). debug=False: per the same startup notes, errors are not
# shown in the Colab notebook output unless debug=True is passed.
print("üöÄ Launching Gradio interface...")
demo.launch(debug=False, share=False)


üì¶ Installing dependencies... This may take a minute.
‚úÖ Installation complete! Starting application...





Initializing RAG Application...
‚úÖ Application initialized!

üöÄ Launching Gradio interface...
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.
* To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

