In [18]:
!pip install tf-keras

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [19]:
import os
import requests
import logging

import gradio as gr
from pymongo import MongoClient
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Connect to Qdrant (adjust host/port if different)
qdrant_client = QdrantClient(host='rag_qdrant', port=6333)
collection_name = 'rag_collection'

# Embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')


INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cpu
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2


In [29]:
import json
def retrieve_relevant_chunks(query_embedding, top_k=5):
    search_results = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_embedding,
        limit=top_k
    )
    retrieved_texts = [result.payload['chunk'] for result in search_results]
    return retrieved_texts

def call_ollama(prompt, model="llama2"):
    payload = {
        "model": model,
        "prompt": prompt,
        "options": {
            "stream": False  # Request a full response, not streaming
        }
    }

    url = os.getenv("OLLAMA_HOST", "http://host.docker.internal:11434")
    try:
        # Send the POST request
        response = requests.post(f"{url}/api/generate", json=payload, timeout=30)  # Add timeout
        response.raise_for_status()

        # Handle raw response in case of multiple JSON objects
        raw_response = response.text.strip()  # Raw response as text
        print("Raw Response:", raw_response)

        # Combine 'response' fields from multiple JSON objects
        responses = []
        for line in raw_response.splitlines():
            try:
                data = json.loads(line)
                if 'response' in data:
                    responses.append(data['response'])
            except json.JSONDecodeError as e:
                logger.error(f"Error parsing line as JSON: {e} - Line: {line}")

        # Join all parts of the response
        full_response = ''.join(responses)
        return full_response if full_response else "No response generated."
    except requests.exceptions.RequestException as e:
        logger.error(f"Error calling Ollama API: {e}")
        return "Sorry, I'm having trouble connecting to the language model."


In [21]:
def generate_response(query):
    if not query.strip():
        return "Please enter a query."
    
    # Embed the user query
    query_embedding = embedding_model.encode(query)

    # Retrieve relevant chunks from Qdrant
    retrieved_texts = retrieve_relevant_chunks(query_embedding, top_k=5)

    # Construct the prompt
    context = "\n".join(retrieved_texts)
    prompt = f"Context:\n{context}\n\nQuestion:\n{query}\n\nAnswer:"

    answer = call_ollama(prompt, model="llama2").strip()
    if not answer:
        answer = "I'm not sure how to answer that."
    return answer


In [30]:
test_query = "How do I navigate to a specific pose in ROS2?"
print(generate_response(test_query))


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

INFO:httpx:HTTP Request: POST http://rag_qdrant:6333/collections/rag_collection/points/search "HTTP/1.1 200 OK"


Raw Response: {"model":"llama2","created_at":"2024-12-05T23:03:07.413905Z","response":"To","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.450304Z","response":" navigate","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.486667Z","response":" to","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.523052Z","response":" a","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.559369Z","response":" specific","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.595412Z","response":" pose","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.631034Z","response":" in","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.667145Z","response":" R","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.703449Z","response":"OS","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.739351Z","response":"2","done":false}
{"model":"llama2","created_at":"2024-12-05T23:03:07.775048

In [23]:
import requests

# Replace with the endpoint you want to test
ollama_url = "http://ollama:11434/api/generate"

# Example payload for the Ollama API
payload = {
    "prompt": "Generate a creative story",
    "parameters": {"max_tokens": 50}
}

try:
    response = requests.post(ollama_url, json=payload)
    if response.status_code == 200:
        print("Connection successful!")
        print("Response:", response.json())
    else:
        print(f"Connection failed with status code {response.status_code}")
        print("Response:", response.text)
except requests.exceptions.RequestException as e:
    print("Error connecting to Ollama:", e)


Connection failed with status code 403
Response: 
