In [None]:
!pip install tf-keras

In [2]:
# ZenML
from zenml import pipeline, step

# General
import os
import numpy as np

# MongoDB
from pymongo import MongoClient

# Qdrant
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, VectorParams, Distance

# Embedding Model
from sentence_transformers import SentenceTransformer

# For logging
import logging

logger = logging.getLogger(__name__)


[1;35mgenerated new fontManager[0m


In [3]:
@step
def load_data_from_mongodb() -> list:
    """Load every document from the ``rag_db.raw_data`` MongoDB collection.

    The connection string is read from the ``MONGODB_URI`` environment
    variable and falls back to ``mongodb://rag_mongodb:27017/`` (the value
    that was previously hard-coded).

    Returns:
        A list of document dicts. The ``_id`` field (a BSON ObjectId) is
        converted to ``str`` so the step output only contains plain Python
        values when ZenML stores it as an artifact.
    """
    logger.info("Loading data from MongoDB...")

    mongo_uri = os.environ.get('MONGODB_URI', 'mongodb://rag_mongodb:27017/')

    # The context manager closes the client (and its connection pool) when
    # the block exits; previously the client was never closed.
    with MongoClient(mongo_uri) as client:
        collection = client['rag_db']['raw_data']
        documents = list(collection.find())

    for doc in documents:
        if '_id' in doc:
            doc['_id'] = str(doc['_id'])

    logger.info(f"Loaded {len(documents)} documents from MongoDB.")

    return documents


In [4]:
@step
def categorize_and_preprocess_data(documents: list) -> list:
    """Assign a category to each raw document and clean Python sources.

    Categorization rules:
      * ``source == 'github'``: ``.md``/``.rst`` files -> ``'article'``,
        ``.py`` files -> ``'code'``, anything else -> ``'other'``.
        Python files that do not parse are skipped (with a warning);
        parseable ones have comments and docstrings stripped.
      * ``source`` ``'web'`` or ``'youtube'`` -> ``'article'``.
      * any other source -> ``'unknown'``.

    Relies on ``is_valid_python_code`` and ``remove_comments_and_docstrings``,
    which are defined in a later notebook cell and resolved at call time.

    Args:
        documents: Raw MongoDB documents (dicts). Missing *or null*
            ``content``, ``path``, ``url``, ``repository`` and ``branch``
            fields are treated as empty strings.

    Returns:
        A list of dicts with keys ``url``, ``path``, ``repository``,
        ``branch``, ``content``, ``source`` and ``category``.
    """
    logger.info("Categorizing and preprocessing data...")

    processed_data = []

    for doc in documents:
        # `or ''` also covers keys that exist but hold None: `doc.get(k, '')`
        # would return None there and crash on `.endswith` / `len()` later.
        content = doc.get('content') or ''
        file_path = doc.get('path') or ''
        source = doc.get('source', 'unknown')
        url = doc.get('url') or ''

        # Determine the category
        if source == 'github':
            if file_path.endswith(('.md', '.rst')):
                category = 'article'
            elif file_path.endswith('.py'):
                category = 'code'
                if not is_valid_python_code(content):
                    logger.warning(f"Skipping invalid Python file: {file_path}")
                    continue
                content = remove_comments_and_docstrings(content)
            else:
                category = 'other'
        elif source in ('web', 'youtube'):
            category = 'article'
        else:
            category = 'unknown'

        processed_data.append({
            'url': url,
            'path': file_path,
            'repository': doc.get('repository') or '',
            'branch': doc.get('branch') or '',
            'content': content,
            'source': source,
            'category': category
        })

    logger.info(f"Categorized and processed {len(processed_data)} documents.")
    return processed_data


In [5]:
import ast

def is_valid_python_code(source):
    """
    Report whether ``source`` can be parsed as Python.

    The text is only parsed with ``ast.parse``, never executed. Any
    ``Exception`` raised while parsing (a syntax error, or e.g. a non-string
    argument) makes the answer False; otherwise the answer is True.
    """
    try:
        ast.parse(source)
    except Exception:
        # SyntaxError is an Exception subclass, so one handler covers both.
        return False
    return True


def remove_comments_and_docstrings(source):
    """
    Remove comments and docstrings from Python source code.

    The source is parsed into an AST, which never contains comments. The
    leading docstring is dropped from the module and from every function,
    async function and class, and the tree is turned back into text with
    ``ast.unparse`` (Python 3.9+). As a side effect, formatting is
    normalized.

    Args:
        source: Python source code as a string.

    Returns:
        The cleaned source. If parsing or processing fails, ``source`` is
        returned unchanged and a warning is logged.
    """
    try:
        # Parse the source code into an AST
        parsed = ast.parse(source)
        for node in ast.walk(parsed):
            if not isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef, ast.Module)):
                continue
            first = node.body[0] if node.body else None
            # Docstrings are string constants; ast.Str is deprecated and was
            # removed in Python 3.14, so check ast.Constant directly.
            if (isinstance(first, ast.Expr)
                    and isinstance(first.value, ast.Constant)
                    and isinstance(first.value.value, str)):
                node.body = node.body[1:]
                # A def/class whose body was only a docstring would otherwise
                # unparse to `def f():` with no body, which is invalid Python.
                if not node.body and not isinstance(node, ast.Module):
                    node.body = [ast.Pass()]
        return ast.unparse(parsed)
    except SyntaxError as e:
        logger.warning(f"Syntax error in Python code: {e}")
        return source  # Return the original source if parsing fails
    except Exception as e:
        logger.warning(f"Unexpected error parsing Python code: {e}")
        return source  # Return the original source if other errors occur

In [6]:
def _make_doc_id(doc, doc_index):
    """Build an identifier for ``doc``.

    The identifier is chosen as follows:
      * A file path is prefixed with its repository when one is known,
        because the same relative path (e.g. ``README.md``) occurs in many
        repositories.
      * Documents without a path use their URL with ``/`` replaced by ``_``.
      * Documents with neither use ``<source>_<doc_index>``. ``doc_index`` is
        the document's position in the input, so this fallback is distinct
        per document.
    """
    path = doc.get('path') or ''
    if path:
        repository = doc.get('repository') or ''
        return f"{repository}/{path}" if repository else path
    url = doc.get('url') or ''
    if url:
        return url.replace('/', '_')
    return f"{doc.get('source', 'unknown')}_{doc_index}"


@step
def chunk_data(processed_data: list, max_chunk_size: int = 500) -> list:
    """Split each document's content into fixed-size character chunks.

    Args:
        processed_data: Documents from the preprocessing step. Each dict is
            expected to carry ``content`` and may carry ``url``, ``path``,
            ``repository``, ``branch``, ``source`` and ``category``.
        max_chunk_size: Maximum chunk length in characters (not tokens).
            Adjust it to the embedding model's maximum input length.

    Returns:
        One dict per chunk with ``doc_id``, ``chunk_id`` (``<doc_id>_<n>``),
        the ``chunk`` text and a ``metadata`` dict copied from the document.
        Documents with empty content produce no chunks.

    Raises:
        ValueError: If ``max_chunk_size`` is not positive (``range`` would
            otherwise fail with a less helpful message).
    """
    if max_chunk_size <= 0:
        raise ValueError(f"max_chunk_size must be positive, got {max_chunk_size}")

    logger.info("Chunking data...")

    chunked_data = []

    for doc_index, doc in enumerate(processed_data):
        content = doc.get('content') or ''
        doc_id = str(_make_doc_id(doc, doc_index))

        # Split content into consecutive, non-overlapping chunks
        chunks = [
            content[start:start + max_chunk_size]
            for start in range(0, len(content), max_chunk_size)
        ]

        metadata = {
            'url': doc.get('url', ''),
            'path': doc.get('path', ''),
            'repository': doc.get('repository', ''),
            'branch': doc.get('branch', ''),
            'source': doc.get('source', ''),
            'category': doc.get('category', '')
        }

        for chunk_index, chunk in enumerate(chunks):
            chunked_data.append({
                'doc_id': doc_id,
                'chunk_id': f"{doc_id}_{chunk_index}",
                'chunk': chunk,
                # Each chunk gets its own metadata dict, as before.
                'metadata': dict(metadata)
            })

    logger.info(f"Created {len(chunked_data)} chunks from documents.")
    return chunked_data


In [7]:
@step
def generate_embeddings(
    chunked_data: list,
    model_name: str = 'all-MiniLM-L6-v2',
    batch_size: int = 32,
) -> list:
    """Attach a SentenceTransformer embedding to every chunk.

    Args:
        chunked_data: Chunk dicts from ``chunk_data``; each needs a ``chunk``
            text.
        model_name: SentenceTransformer model to load. It must match the
            model used at query time, or query and document vectors are not
            comparable.
        batch_size: Number of chunks passed to each ``model.encode`` call.
            Adjust it to your hardware.

    Returns:
        New dicts equal to the input chunks plus an ``embedding`` list of
        floats; the input dicts are not modified. If encoding a batch
        raises, the error is logged and that batch is left out, so the
        result can be shorter than the input. Empty input returns ``[]``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    logger.info("Generating embeddings...")

    if not chunked_data:
        logger.warning("No data to generate embeddings for!")
        return []

    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model = SentenceTransformer(model_name)
    total_batches = (len(chunked_data) + batch_size - 1) // batch_size

    embedded_chunks = []
    skipped_chunks = 0

    for batch_number, start in enumerate(range(0, len(chunked_data), batch_size), start=1):
        batch = chunked_data[start:start + batch_size]
        texts = [item['chunk'] for item in batch]
        try:
            vectors = model.encode(texts)
            batch_results = [
                {**item, 'embedding': vectors[idx].tolist()}
                for idx, item in enumerate(batch)
            ]
        except Exception as e:
            # Best-effort: drop only this batch so one failure does not abort the step.
            logger.error(f"Error generating embeddings for batch {batch_number}/{total_batches}: {e}")
            skipped_chunks += len(batch)
            continue
        embedded_chunks.extend(batch_results)
        logger.info(f"Processed batch {batch_number}/{total_batches}")

    if skipped_chunks:
        logger.warning(f"Skipped {skipped_chunks} chunks because their batch failed to encode.")

    logger.info(f"Generated embeddings for {len(embedded_chunks)} chunks.")
    return embedded_chunks

In [8]:
@step
def store_embeddings_in_qdrant(chunked_data: list):
    """Recreate the ``rag_collection`` Qdrant collection and upload all chunks.

    An existing collection with that name is deleted first, so every run
    replaces the previous index. The collection uses cosine distance, with
    the vector size taken from the first chunk's embedding. Each point's ID
    is its 1-based position in ``chunked_data``. The payload holds the chunk
    metadata plus ``chunk_id``, ``doc_id`` and the chunk text.

    Args:
        chunked_data: Chunk dicts carrying ``embedding``, ``metadata``,
            ``chunk_id``, ``doc_id`` and ``chunk``. All embeddings are
            assumed to have the same length as the first one.

    Raises:
        ValueError: If ``chunked_data`` is empty. There is nothing to size
            the collection from, and the existing collection is left intact.
        Exception: Qdrant client errors are logged and re-raised.
    """
    logger.info("Storing embeddings in Qdrant...")

    try:
        if not chunked_data:
            raise ValueError("No embeddings to store; leaving the existing Qdrant collection untouched.")

        client = QdrantClient(host='rag_qdrant', port=6333)

        # Define collection parameters
        collection_name = 'rag_collection'
        vector_size = len(chunked_data[0]['embedding'])

        # Drop any previous collection so the index is rebuilt from scratch.
        try:
            client.get_collection(collection_name)
            client.delete_collection(collection_name)
            logger.info(f"Deleted existing collection: {collection_name}")
        except Exception as e:
            # Most likely the collection does not exist yet; nothing to delete.
            logger.debug(f"Did not delete collection {collection_name}: {e}")

        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE)
        )

        batch_size = 100  # Adjust based on your needs
        for start in range(0, len(chunked_data), batch_size):
            batch = chunked_data[start:start + batch_size]
            # `start` is already the dataset offset of this batch, so
            # `start + 1` numbers points 1..N sequentially. The previous
            # `i * batch_size + idx + 1` multiplied the offset twice.
            points = [
                PointStruct(
                    id=point_id,
                    vector=item['embedding'],
                    payload={
                        **item['metadata'],
                        'chunk_id': item['chunk_id'],
                        'doc_id': item['doc_id'],
                        'chunk': item['chunk']
                    }
                )
                for point_id, item in enumerate(batch, start=start + 1)
            ]

            try:
                client.upsert(
                    collection_name=collection_name,
                    points=points
                )
                logger.info(f"Uploaded batch of {len(points)} embeddings to Qdrant (IDs {points[0].id} to {points[-1].id})")
            except Exception as e:
                logger.error(f"Error uploading batch: {str(e)}")
                # Log the first failing point for debugging
                if points:
                    logger.error(f"First point in failing batch - ID: {points[0].id}")
                raise

        logger.info(f"Successfully stored all {len(chunked_data)} embeddings in Qdrant.")
    except Exception as e:
        logger.error(f"Error storing embeddings in Qdrant: {str(e)}")
        raise

In [None]:
# Placeholder kept from the original notebook; nothing in this cell assigns it.
pipeline_results = None


@pipeline
def featurization_pipeline():
    """Wire the MongoDB -> preprocess -> chunk -> embed -> Qdrant steps.

    Inside a ZenML ``@pipeline`` function, step calls only compose the DAG.
    Each call returns a placeholder for the step's future output, not the
    data itself. Checks such as ``if not documents`` would therefore test
    the placeholder object rather than the loaded data, so empty-input
    handling belongs inside the steps. For example, ``generate_embeddings``
    returns ``[]`` for empty input.

    Returns:
        The output of ``generate_embeddings`` (the embedded chunks).
    """
    documents = load_data_from_mongodb()
    processed_data = categorize_and_preprocess_data(documents)
    chunked_data = chunk_data(processed_data)
    chunked_data_with_embeddings = generate_embeddings(chunked_data)
    store_embeddings_in_qdrant(chunked_data_with_embeddings)
    return chunked_data_with_embeddings


# Calling the decorated pipeline function triggers a ZenML run.
pipeline_instance = featurization_pipeline()



In [None]:
# Now test the results

# Loaded SentenceTransformer models keyed by name; loading is slow, so reuse
# them across search_qdrant calls in this kernel.
_SEARCH_MODEL_CACHE = {}


def search_qdrant(query_text: str, limit: int = 3,
                  collection_name: str = 'rag_collection',
                  model_name: str = 'all-MiniLM-L6-v2'):
    """Embed ``query_text`` and print the most similar chunks from Qdrant.

    Args:
        query_text: Natural-language query.
        limit: Maximum number of hits to retrieve.
        collection_name: Qdrant collection filled by the featurization
            pipeline.
        model_name: SentenceTransformer model. It must be the same model
            that produced the stored vectors.

    Returns:
        The hits returned by ``QdrantClient.search``; they are also printed.

    Raises:
        Exception: Connection, model-loading and search errors are logged
            and re-raised.
    """
    try:
        # Connect to Qdrant
        qdrant_client = QdrantClient(host='rag_qdrant', port=6333)

        model = _SEARCH_MODEL_CACHE.get(model_name)
        if model is None:
            model = SentenceTransformer(model_name)
            _SEARCH_MODEL_CACHE[model_name] = model

        # Generate embedding for the query
        query_vector = model.encode(query_text)

        # Qdrant expects a plain list of floats
        if not isinstance(query_vector, list):
            query_vector = query_vector.tolist()

        logger.debug(f"Query vector: {query_vector[:10]}...")  # Log first 10 elements

        # NOTE(review): newer qdrant-client releases deprecate `search` in
        # favour of `query_points`; confirm against the installed version.
        search_results = qdrant_client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit
        )

        # Display results
        print(f"\nSearch Results for: '{query_text}'")
        print("-" * 50)
        for result in search_results:
            # Guard against points stored without a payload or chunk text.
            payload = result.payload or {}
            chunk = payload.get('chunk') or ''
            print(f"Score: {result.score:.4f}")
            print(f"Repository: {payload.get('repository')}")
            print(f"Path: {payload.get('path')}")
            print(f"Category: {payload.get('category')}")
            print(f"Chunk: {chunk[:200]}...")  # Show first 200 chars
            print("-" * 50)

        return search_results
    except Exception as e:
        logger.error(f"Error searching Qdrant: {str(e)}")
        raise


# Example usage (assigned so the cell does not also echo the raw hit list):
search_results = search_qdrant("What is ROS?")