In [None]:
pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

In [None]:
pip show transformers

In [None]:
pip install faiss-cpu langchain

In [None]:
pip install -U langchain-community

In [None]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
# Fetch the Hugging Face access token from Kaggle's secret store (instead of
# hardcoding it in the notebook) and authenticate this session with the Hub.
# NOTE(review): presumably required because the google/gemma-3 checkpoint loaded
# later is gated — confirm the account has accepted the model licence.
user_secrets = UserSecretsClient()

# The secret name must match the one configured in the Kaggle notebook's secrets
hf_token = user_secrets.get_secret("HUGGINGFACE_TOKEN")
login(token = hf_token)

In [None]:
import json
import torch
import re
import time
from transformers import Gemma3ForConditionalGeneration, AutoProcessor, AutoModel, AutoTokenizer
from tqdm import tqdm
import numpy as np
import faiss

In [None]:
# Load data
def load_data(file_path):
    """Parse a UTF-8 encoded JSON file and return the resulting Python object.

    Downstream, ``chunk_documents`` reads the "Текст" and "Ссылка" keys from each
    item, so the file is expected to hold a list of such records.
    """
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)

In [None]:
# Embedding wrapper for E5 models (good multilingual retrieval quality)
class E5Embedder:
    """Sentence embedder for E5-family models (e.g. ``intfloat/multilingual-e5-large``).

    E5 models expect a ``"passage: "`` prefix on indexed documents and a
    ``"query: "`` prefix on search queries; :meth:`encode` and
    :meth:`encode_queries` add these prefixes. Embeddings are mean-pooled over
    non-padding tokens and L2-normalised, so inner product equals cosine similarity.

    Attributes:
        device: Device requested at construction time.
        tokenizer: Hugging Face tokenizer for ``model_name``.
        model: Hugging Face encoder model, in eval mode.
    """

    def __init__(self, model_name="intfloat/multilingual-e5-large", device=None):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        print(f"Loading embedding model {model_name} on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()
        print("Embedding model loaded")

    def _model_device(self):
        """Return the device holding the model weights; inputs are moved there.

        ``self.model`` may be moved after construction (``run_improved_legal_chatbot``
        does ``embedder.model = embedder.model.cpu()`` under memory pressure). Sending
        inputs to a stale ``self.device`` would then raise a device-mismatch error.
        """
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return self.device

    def _average_pool(self, last_hidden_states, attention_mask):
        """Mean-pool token embeddings, ignoring padding positions.

        Args:
            last_hidden_states: (batch, seq_len, hidden) token embeddings.
            attention_mask: (batch, seq_len) mask, non-zero for real tokens.

        Returns:
            (batch, hidden) tensor of per-sequence mean embeddings.
        """
        # Zero out padded positions so they don't contribute to the sum
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def _embed(self, prefixed_texts, batch_size, show_progress_bar):
        """Embed already-prefixed texts in batches of `batch_size`.

        Returns:
            numpy array of shape (len(prefixed_texts), hidden), L2-normalised rows.
            For empty input, a (0, hidden) array (hidden taken from the model config
            when available) instead of failing inside ``np.vstack``.
        """
        if not prefixed_texts:
            hidden = getattr(getattr(self.model, "config", None), "hidden_size", 0)
            return np.zeros((0, hidden), dtype=np.float32)

        device = self._model_device()
        batch_starts = range(0, len(prefixed_texts), batch_size)
        if show_progress_bar:
            batch_starts = tqdm(batch_starts)

        all_embeddings = []
        for start in batch_starts:
            batch = prefixed_texts[start:start + batch_size]
            inputs = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                embeddings = self._average_pool(outputs.last_hidden_state, inputs['attention_mask'])
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            all_embeddings.append(embeddings.cpu().numpy())

        return np.vstack(all_embeddings)

    def encode(self, texts, batch_size=8, show_progress_bar=True):
        """Embed documents (with the E5 ``"passage: "`` prefix).

        Args:
            texts: List of document strings (a single string is treated as one text).
            batch_size: Texts per forward pass; bounds peak memory.
            show_progress_bar: Show a tqdm progress bar over batches.

        Returns:
            numpy array of shape (len(texts), hidden).
        """
        if isinstance(texts, str):
            texts = [texts]
        return self._embed([f"passage: {text}" for text in texts], batch_size, show_progress_bar)

    def encode_queries(self, queries, batch_size=8):
        """Embed search queries (with the E5 ``"query: "`` prefix).

        Args:
            queries: A query string or a list of query strings.
            batch_size: Queries per forward pass (previously ignored; now honoured).

        Returns:
            numpy array of shape (n_queries, hidden).
        """
        if isinstance(queries, str):
            queries = [queries]
        return self._embed([f"query: {query}" for query in queries], batch_size, show_progress_bar=False)

In [None]:
# Split documents into overlapping word windows for retrieval
def chunk_documents(legal_data, max_chunk_size=256, overlap=50):
    """Split each document's text into overlapping chunks of at most `max_chunk_size` words.

    Args:
        legal_data: Iterable of dicts with "Текст" (document text) and "Ссылка"
            (source reference) keys.
        max_chunk_size: Maximum words per chunk. Documents with at most this many
            words are kept whole, with their reference unchanged.
        overlap: Words shared by consecutive chunks of the same document;
            must satisfy ``0 <= overlap < max_chunk_size``.

    Returns:
        ``(chunks, references)`` — parallel lists. Chunks of a split document get
        the reference ``"<reference> - Фрагмент <n>"`` with n = 1, 2, 3, …
        (``search_documents`` strips this suffix to group fragments by source).

    Raises:
        ValueError: if `overlap` is negative or not smaller than `max_chunk_size`
            (the window would never advance, looping forever).
    """
    if overlap < 0 or overlap >= max_chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_chunk_size, "
            f"got overlap={overlap}, max_chunk_size={max_chunk_size}"
        )
    step = max_chunk_size - overlap

    chunks = []
    references = []

    for item in legal_data:
        text = item["Текст"]
        reference = item["Ссылка"]
        words = text.split()

        # Short texts are kept as-is
        if len(words) <= max_chunk_size:
            chunks.append(text)
            references.append(reference)
            continue

        # An explicit counter is required: the window advances by `step`, so
        # deriving the number from `start // max_chunk_size` repeats values.
        start, fragment_no = 0, 1
        while True:
            end = min(start + max_chunk_size, len(words))
            chunks.append(" ".join(words[start:end]))
            references.append(f"{reference} - Фрагмент {fragment_no}")
            if end == len(words):
                # This window reached the end; a further one would be a subset of it
                break
            start += step
            fragment_no += 1

    return chunks, references

In [None]:
# Build an exact inner-product FAISS index over chunked documents
def create_improved_faiss_index(legal_data, embedder, chunk_size=256, overlap=50):
    """Chunk `legal_data`, embed every chunk and index the vectors with FAISS.

    Args:
        legal_data: Documents as accepted by ``chunk_documents``.
        embedder: Object whose ``encode(list_of_texts)`` returns a 2-D (n, dim) array.
        chunk_size: Maximum words per chunk (forwarded as ``max_chunk_size``).
        overlap: Words shared between consecutive chunks of one document.

    Returns:
        ``(index, chunks, references)``; row i of the index corresponds to
        ``chunks[i]`` and ``references[i]``.
    """
    print("Chunking documents...")
    chunks, references = chunk_documents(legal_data, max_chunk_size=chunk_size, overlap=overlap)
    print(f"Created {len(chunks)} chunks from {len(legal_data)} documents")

    print("Creating embeddings...")
    chunk_vectors = embedder.encode(chunks)

    print("Building FAISS index...")
    # IndexFlatIP is exhaustive (exact) inner-product search. With L2-normalised
    # vectors (as E5Embedder produces) inner product equals cosine similarity.
    flat_index = faiss.IndexFlatIP(chunk_vectors.shape[1])
    flat_index.add(chunk_vectors)

    return flat_index, chunks, references

In [None]:
# Retrieve relevant chunks, one per source document
def search_documents(query, index, chunks, references, embedder, top_k=8, min_score=0.2):
    """Retrieve the most relevant chunks for `query`, at most one per source document.

    Args:
        query: User question (str).
        index: FAISS-style index exposing ``search(vectors, k) -> (scores, indices)``.
        chunks: Chunk texts, positionally aligned with the index rows.
        references: Source reference for each chunk (same order as ``chunks``).
        embedder: Object with ``encode_queries(text)`` returning a (1, dim) array.
        top_k: Neighbours requested per search and maximum number of results returned.
        min_score: Hits whose score is not strictly greater than this are dropped.

    Returns:
        List of ``{"text", "reference", "score"}`` dicts sorted by descending score,
        keeping only the best chunk per base reference (the part before " - Фрагмент").

    If fewer than 3 hits survive, the search is repeated once with a keyword-expanded
    query (``expand_query``) and chunks not already found are merged in.
    """
    def _hits(text):
        # FAISS pads results with index -1 when fewer than k vectors are available;
        # checking only `idx < len(chunks)` would let -1 through and return chunks[-1].
        scores, indices = index.search(embedder.encode_queries(text), top_k)
        return [
            (int(idx), float(score))
            for idx, score in zip(indices[0], scores[0])
            if 0 <= idx < len(chunks) and score > min_score
        ]

    results = [
        {"text": chunks[idx], "reference": references[idx], "score": score}
        for idx, score in _hits(query)
    ]

    # Few or no results: retry with query expansion
    if len(results) < 3:
        expanded_query = expand_query(query)
        if expanded_query != query:
            seen_refs = {result["reference"] for result in results}
            for idx, score in _hits(expanded_query):
                if references[idx] not in seen_refs:
                    seen_refs.add(references[idx])
                    results.append({
                        "text": chunks[idx],
                        "reference": references[idx],
                        "score": score
                    })

    # Highest score first; the stable sort keeps retrieval order among ties
    results.sort(key=lambda result: result["score"], reverse=True)

    # Keep the best-scoring chunk for each source document
    seen_base_refs = set()
    deduplicated_results = []
    for result in results:
        base_ref = result["reference"].split(" - Фрагмент")[0]
        if base_ref not in seen_base_refs:
            seen_base_refs.add(base_ref)
            deduplicated_results.append(result)

    return deduplicated_results[:top_k]

In [None]:
# Query expansion to improve retrieval
def expand_query(query, rng=None):
    """Append one related legal term for each keyword stem found in `query`.

    Args:
        query: User question.
        rng: Optional object with a ``choice(seq)`` method (e.g. ``random.Random(seed)``)
            used to pick each expansion term. Defaults to the module-level ``random``
            generator (unchanged behaviour); pass a seeded generator to make
            retrieval reproducible.

    Returns:
        ``query`` unchanged when no stem matches; otherwise ``query`` followed by one
        space-separated term per matching stem, in the dictionary's order.

    NOTE(review): matching is a plain substring test on the lower-cased query, so
    e.g. "государственн..." contains the stem "дар" and triggers the gift expansion.
    """
    import random  # once per call, not once per matching keyword

    chooser = random if rng is None else rng

    legal_keywords = {
        "налог": ["налогообложение", "налоговый вычет", "налоговая ставка"],
        "квартир": ["недвижимость", "жилье", "собственность", "жилая площадь"],
        "земл": ["земельный участок", "кадастр", "земельный налог"],
        "наследств": ["наследование", "наследодатель", "наследник"],
        "договор": ["сделка", "соглашение", "контракт"],
        "дар": ["дарение", "дарственная", "даритель", "одаряемый"],
        "пошлин": ["государственная пошлина", "сбор", "платеж"],
        "регистрац": ["регистрация", "росреестр", "оформление"],
        "опек": ["опекунство", "попечительство", "несовершеннолетний"]
    }

    lower_query = query.lower()
    extra_terms = [
        chooser.choice(expansions)
        for keyword, expansions in legal_keywords.items()
        if keyword in lower_query
    ]

    return " ".join([query, *extra_terms]) if extra_terms else query

In [None]:
# Heuristic filter: does the user's message look like a legal question?
def is_legal_question(query):
    """Return True if `query` looks like a legal question.

    The query is lower-cased and accepted by the first check that matches:
      1. it contains one of the legal word stems below (plain substring match);
      2. it matches one of the Russian question templates about rights/obligations;
      3. it ends with "?" (deliberately lenient).
    """
    stems = (
        # general legal vocabulary
        'закон', 'право', 'суд', 'иск', 'кодекс', 'юрист', 'адвокат', 'нотариус',
        'юридическ', 'физическ', 'лиц', 'гражданск',
        # contracts, obligations, liability, labour
        'договор', 'обязательств', 'ответственност', 'доверенност', 'лиценз',
        'патент', 'претензи', 'банкротств', 'компенсац', 'возмещени', 'штраф',
        'пени', 'увольнени',
        # property and real estate
        'имуществ', 'собственност', 'недвижимост', 'кадастр', 'земл', 'квартир',
        'дом', 'участ', 'аренд', 'залог', 'ипотек', 'кредит', 'страхов',
        'продаж', 'покупк', 'дарени',
        # inheritance and family
        'наследств', 'наследов', 'завещани', 'брак', 'развод', 'алимент',
        'опек', 'попечит', 'усыновле',
        # taxes, fees, payments, benefits
        'налог', 'льгот', 'пошлин', 'выплат', 'доход', 'пенси', 'социальн', 'пособи',
        # documents and registration
        'регистрац', 'оформить', 'заявлени', 'документ', 'паспорт', 'снилс', 'инн',
        'удостоверени', 'срок', 'выписк', 'свидетельств', 'прописк',
        # citizenship and migration
        'гражданств', 'вид на жительство', 'миграци',
    )

    templates = (
        r'как\s+.*\s+(оформить|получить|подать|зарегистрировать)',
        r'что\s+.*\s+(делать|нужно|требуется)\s+.*\s+(если|для|при)',
        r'когда\s+.*\s+(необходимо|нужно|следует|можно|обязан)',
        r'где\s+.*\s+(оформ|получ|зарегистр|подать)',
        r'сколько\s+.*\s+(стоит|платить|налог|штраф|пошлина|срок)',
        r'какие\s+.*\s+(документы|права|обязанности)',
        r'кто\s+.*\s+(имеет право|должен|обязан)',
        r'можно ли\s+.*',
        r'обязан ли\s+.*',
        r'нужно ли\s+.*',
        r'(дарение|налог|договор|пошлина|срок)\s+.*\?',
    )

    lowered = query.lower()

    if any(stem in lowered for stem in stems):
        return True
    if any(re.search(template, lowered) for template in templates):
        return True
    # Lenient fallback: any question is treated as potentially legal
    return lowered.strip().endswith('?')

In [None]:
# Answer generation with Gemma
def _build_legal_messages(query, context_docs, max_docs=5, max_chars=None):
    """Build the chat messages (system + user turn) sent to the model for one question.

    Args:
        query: User question.
        context_docs: Retrieved fragments (dicts with "text" and "reference"), already
            sorted by relevance. Empty or None yields a prompt without excerpts.
        max_docs: Maximum number of fragments included in the prompt.
        max_chars: If set, each fragment is cut to this many characters and suffixed
            with "...", and the excerpt header is marked "(сокращенные)". Used by the
            low-memory retry.

    Returns:
        List of message dicts in the processor's chat-template format.
    """
    if not context_docs:
        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": "Ты профессиональный юрист, который отвечает на вопросы о российском законодательстве. Давай точные и информативные ответы, основанные на актуальном законодательстве РФ. Если ты не уверен, честно скажи об этом и дай общую информацию по теме. Твои ответы должны быть краткими, структурированными и понятными для обычного человека."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": f"Вопрос: {query}"}]
            }
        ]

    parts = []
    for i, doc in enumerate(context_docs[:max_docs]):
        if max_chars is None:
            parts.append(f"Документ {i+1}:\n{doc['text']}\nИсточник: {doc['reference']}\n\n")
        else:
            parts.append(f"Документ {i+1}:\n{doc['text'][:max_chars]}...\nИсточник: {doc['reference']}\n\n")
    context = "".join(parts)

    if max_chars is None:
        user_text = f"Вопрос: {query}\n\nФрагменты законодательства:\n{context}"
    else:
        user_text = f"Вопрос: {query}\n\nФрагменты законодательства (сокращенные):\n{context}"

    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "Ты профессиональный юрист, который отвечает на вопросы о российском законодательстве. Используй предоставленные фрагменты законодательства для составления ответа. Не упоминай в ответе фразы типа 'на основании предоставленной информации' или 'в предоставленных документах'. Просто дай юридически точный ответ, ссылаясь на соответствующие нормы закона. Если информации недостаточно, дай полезный ответ на основе общих знаний о законодательстве РФ, но отметь, что это общая информация. Твои ответы должны быть структурированными, краткими и понятными для обычного человека."}]
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": user_text}]
        }
    ]


def _generate_tokens(model, processor, messages, max_new_tokens, temperature, top_p=None, repetition_penalty=None):
    """Apply the chat template, run ``model.generate`` and return only the new token ids.

    A `temperature` that is None or <= 0 selects greedy decoding (transformers'
    temperature warper requires a strictly positive value when sampling).
    `top_p` is only passed when sampling.
    """
    # BatchFeature.to casts only floating-point tensors to bfloat16; token ids stay integer
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    gen_kwargs = {"max_new_tokens": max_new_tokens, "num_beams": 1}
    if temperature is not None and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
        if top_p is not None:
            gen_kwargs["top_p"] = top_p
    else:
        gen_kwargs["do_sample"] = False
    if repetition_penalty is not None:
        gen_kwargs["repetition_penalty"] = repetition_penalty

    with torch.inference_mode():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # The output contains the prompt followed by the continuation; keep the continuation
    return output_ids[0][input_len:]


def generate_answer(query, context_documents, model, processor, max_new_tokens=384, temperature=0.7, is_legal=True):
    """Generate an answer to a legal question, grounded in retrieved fragments.

    Args:
        query: User question.
        context_documents: Retrieved fragments (dicts with "text", "reference",
            "score"); may be empty or None.
        model: Generation model exposing ``generate`` and ``device``.
        processor: Processor exposing ``apply_chat_template`` and ``decode``.
        max_new_tokens: Generation budget. The low-memory retry uses at most 256.
        temperature: Sampling temperature; values <= 0 switch to greedy decoding.
            Also used by the low-memory retry (previously hard-coded to 0.7 there).
        is_legal: When False, a fixed refusal is returned without generating.

    Returns:
        The stripped, post-processed answer text. If CUDA runs out of memory and
        there is no context to shrink, a fixed apology string is returned instead.

    Raises:
        RuntimeError: generation errors other than CUDA OOM, or an OOM on the retry.
    """
    if not is_legal:
        return "Извините, я могу отвечать только на юридические вопросы. Пожалуйста, задайте вопрос о законодательстве, правах, налогах, документах или иных юридических аспектах."

    context_docs = sorted(context_documents or [], key=lambda doc: doc["score"], reverse=True)
    messages = _build_legal_messages(query, context_docs)

    # Free cached CUDA blocks before a memory-hungry generate() call
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Timing includes prompt tokenization
    start_time = time.time()

    try:
        generation = _generate_tokens(
            model, processor, messages, max_new_tokens, temperature,
            top_p=0.9, repetition_penalty=1.1
        )
    except RuntimeError as e:
        if "CUDA out of memory" not in str(e):
            raise
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if not context_docs:
            return "Извините, произошла ошибка при генерации ответа. Попробуйте упростить вопрос или задать его позже."
        # Retry once with a much smaller prompt: top 2 fragments, 300 characters each
        messages = _build_legal_messages(query, context_docs, max_docs=2, max_chars=300)
        generation = _generate_tokens(model, processor, messages, min(256, max_new_tokens), temperature)

    response = processor.decode(generation, skip_special_tokens=True)

    generation_time = time.time() - start_time
    print(f"Generation time: {generation_time:.2f} seconds")

    response = response.strip()

    # Rewrite template phrases the system prompt asks the model to avoid
    response = response.replace("На основании предоставленной информации", "Согласно законодательству РФ")
    response = response.replace("в предоставленных документах", "в российском законодательстве")
    response = response.replace("документах найти невозможно", "нормативных актах следует отметить")

    return response

In [None]:
# Append a ranked, de-duplicated source list to the answer
def format_answer_with_sources(answer, context_documents):
    """Return `answer` followed by a numbered "Источники" list of its sources.

    Fragments of the same document (references differing only in the
    " - Фрагмент N" suffix) are merged, keeping the highest score. Sources are
    listed by descending score with two-decimal relevance. Without documents,
    `answer` is returned unchanged.
    """
    if not context_documents:
        return answer

    best_score_by_source = {}
    for doc in context_documents:
        base_ref = doc['reference'].split(" - Фрагмент")[0]
        if base_ref not in best_score_by_source or doc['score'] > best_score_by_source[base_ref]:
            best_score_by_source[base_ref] = doc['score']

    ranked = sorted(best_score_by_source.items(), key=lambda pair: pair[1], reverse=True)
    source_lines = [
        f"{position}. {ref} (релевантность: {score:.2f})\n"
        for position, (ref, score) in enumerate(ranked, start=1)
    ]
    return answer + "\n\nИсточники:\n" + "".join(source_lines)

In [None]:
# Initialize chatbot with better error handling and memory optimization
def run_improved_legal_chatbot(model_name="google/gemma-3-4b-it", embedding_model_name="intfloat/multilingual-e5-large", legal_data_path=None, quantization=True):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # CUDA memory management
    if torch.cuda.is_available():
        # Try to free unused memory
        torch.cuda.empty_cache()
        
        # Check available memory
        try:
            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # in GB
            allocated_mem = torch.cuda.memory_allocated(0) / (1024**3)  # in GB
            free_mem = total_mem - allocated_mem
            print(f"GPU memory: total {total_mem:.2f} GB, free {free_mem:.2f} GB")
        except RuntimeError:
            print("Unable to check GPU memory")
            free_mem = 0
    else:
        free_mem = 0
    
    try:
        # Load data
        print("Loading legal data...")
        if not legal_data_path:
            print("ERROR: No path to legal data file specified")
            return None, None, None, None, None, None
        legal_data = load_data(legal_data_path)
        print(f"Loaded {len(legal_data)} legal documents")
        
        # Choose appropriate embedding model based on available resources
        if free_mem < 4.0 or device.type == "cpu":
            # For limited resources, use a smaller model
            if embedding_model_name == "intfloat/multilingual-e5-large":
                embedding_model_name = "intfloat/multilingual-e5-small"
                print(f"Switching to smaller embedding model due to memory constraints: {embedding_model_name}")
        
        # Initialize embedding model
        embedder = E5Embedder(model_name=embedding_model_name, device=device)
        
        # Create improved FAISS index with optimal chunking
        index, chunks, references = create_improved_faiss_index(legal_data, embedder)
        
        # Memory management - move embedder to CPU if memory is limited
        if torch.cuda.is_available() and free_mem < 4.0:
            embedder.model = embedder.model.cpu()
            torch.cuda.empty_cache()
            print("Embedding model moved to CPU to save CUDA memory")
        
        # Load LLM model with optimizations
        print("\nLoading language model...")
        if quantization:
            # Try 8-bit quantization to save memory
            try:
                model = Gemma3ForConditionalGeneration.from_pretrained(
                    model_name, 
                    device_map="auto",
                    load_in_8bit=True
                ).eval()
                print("Model loaded with 8-bit quantization")
            except:
                # If that fails, try standard variant
                print("Failed to load model with 8-bit quantization, using standard variant")
                model = Gemma3ForConditionalGeneration.from_pretrained(
                    model_name, 
                    device_map="auto",
                    torch_dtype=torch.bfloat16
                ).eval()
        else:
            model = Gemma3ForConditionalGeneration.from_pretrained(
                model_name, 
                device_map="auto",
                torch_dtype=torch.bfloat16
            ).eval()
        
        processor = AutoProcessor.from_pretrained(model_name)
        
        return model, processor, embedder, index, chunks, references
        
    except Exception as e:
        print(f"Error during initialization: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None, None, None

In [None]:
# Main function to start chatbot
def _answer_query(query, model, processor, embedder, index, chunks, references, show_retrieval=False):
    """Run the RAG pipeline for one query and print the chatbot's reply.

    show_retrieval: also print the top-3 retrieved fragments and put a blank line
        before "Chatbot response:" (the format used for test queries).
    """
    legal_question = is_legal_question(query)
    if not legal_question:
        print("Chatbot response:")
        print("Извините, я могу отвечать только на юридические вопросы. Пожалуйста, задайте вопрос о законодательстве, правах, налогах, документах или иных юридических аспектах.")
        return

    context_documents = search_documents(query, index, chunks, references, embedder, top_k=5, min_score=0.2)

    if show_retrieval:
        # Debug output: show the first few retrieved fragments
        print(f"Found relevant documents: {len(context_documents)}")
        for i, doc in enumerate(context_documents[:3]):
            print(f"Document {i+1} (relevance: {doc['score']:.4f}):")
            print(f"Reference: {doc['reference']}")
            print(f"Fragment: {doc['text'][:100]}...")

    answer = generate_answer(query, context_documents, model, processor, is_legal=legal_question)
    full_response = format_answer_with_sources(answer, context_documents)

    print("\nChatbot response:" if show_retrieval else "Chatbot response:")
    print(full_response)


def start_improved_chatbot(model, processor, embedder, index, chunks, references, interactive=True, test_queries=None):
    """Answer the given test queries (verbose), then optionally run an input() loop.

    The interactive loop ends on 'выход'/'exit'/'quit' (case-insensitive, surrounding
    whitespace ignored), and also when input is unavailable: EOF, Ctrl+C, or IPython's
    StdinNotImplementedError (a NotImplementedError subclass raised when the frontend
    cannot provide stdin). Blank input is ignored.
    """
    if test_queries:
        print("\n" + "="*50)
        print("TEST QUERIES")
        print("="*50)

        for query in test_queries:
            print("\n" + "="*50)
            print(f"Question: {query}")
            _answer_query(query, model, processor, embedder, index, chunks, references, show_retrieval=True)

    if interactive:
        print("\n" + "="*50)
        print("Interactive mode (type 'exit' to end):")

        while True:
            try:
                user_query = input("\nYour question: ")
            except (EOFError, KeyboardInterrupt, NotImplementedError):
                print("\nInput unavailable, leaving interactive mode.")
                break

            user_query = user_query.strip()
            if user_query.lower() in ['выход', 'exit', 'quit']:
                break
            if not user_query:
                continue

            _answer_query(user_query, model, processor, embedder, index, chunks, references)

# Example usage
if __name__ == "__main__":
    # Dataset attached to the Kaggle notebook; expected to be a JSON list of
    # records with "Текст"/"Ссылка" keys (as read by chunk_documents)
    legal_data_path = "/kaggle/input/legal-house-json/legal_documents.json"  # Path to your JSON file

    # Initialize models and index
    model, processor, embedder, index, chunks, references = run_improved_legal_chatbot(
        model_name="google/gemma-3-4b-it",
        embedding_model_name="intfloat/multilingual-e5-small",  # More memory efficient model
        legal_data_path=legal_data_path,
        quantization=True
    )

    # Test queries
    test_queries = [
        "Сколько в год налог на землю?",
        "Как можно подарить квартиру?",
        "Для чего нужен кадастровый номер?",
        "Как рассчитать налог на квартиру?",
        "Что обязательно должно быть в договоре купли продажи?"
    ]

    # Guard with if/else rather than exit(1): exit() is the interactive-shell helper
    # added by `site` (replaced by IPython's own exiter inside a kernel) and is not a
    # reliable way to stop a notebook cell.
    if model is None:
        print("Failed to initialize chatbot. Check errors above.")
    else:
        # Start chatbot with test queries
        start_improved_chatbot(model, processor, embedder, index, chunks, references,
                               interactive=True, test_queries=test_queries)