In [1]:
#!pip install git+https://github.com/huggingface/transformers.git
#!pip install datasets
#!pip install faiss-cpu

In [1]:
import json
import os
import random

import faiss
import numpy as np
import torch
from transformers import AutoModel
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, AutoTokenizer

# Shell setup (not Python): source new_env/bin/activate
# Checkpoints combined into the RAG model: a DPR question encoder and a BART generator.
QUESTION_ENCODER_CHECKPOINT = "facebook/dpr-question_encoder-single-nq-base"
GENERATOR_CHECKPOINT = "facebook/bart-large"
# Everything below is saved into the current working directory.
RAG_OUTPUT_DIR = "./"

model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
    QUESTION_ENCODER_CHECKPOINT,
    GENERATOR_CHECKPOINT,
)

# One tokenizer per sub-model; RagTokenizer wraps the pair.
question_encoder_tokenizer = AutoTokenizer.from_pretrained(QUESTION_ENCODER_CHECKPOINT)
generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_CHECKPOINT)
tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

# The retriever is configured through the model's config object.
# NOTE(review): with use_dummy_dataset=False and index_name="exact" the retriever
# presumably loads the full (very large) wiki_dpr index -- confirm this is intended,
# since the rest of the notebook retrieves from its own FAISS index instead.
rag_config = model.config
rag_config.use_dummy_dataset = False
rag_config.index_name = "exact"
retriever = RagRetriever(rag_config, question_encoder_tokenizer, generator_tokenizer)

# Persist model, tokenizer and retriever side by side.
model.save_pretrained(RAG_OUTPUT_DIR)
tokenizer.save_pretrained(RAG_OUTPUT_DIR)
retriever.save_pretrained(RAG_OUTPUT_DIR)


  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


: 

In [None]:
# Load the sentence-embedding model and its tokenizer from the SAME checkpoint so
# that vocabulary and weights always match.
# The previous `if 'embedding_model' not in globals()` fallback was removed: it could
# never run (the name had just been assigned), and had it run it would have paired an
# mpnet model with the MiniLM tokenizer.
EMBEDDING_CHECKPOINT = "sentence-transformers/all-MiniLM-L6-v2"

embedding_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_CHECKPOINT)
embedding_model = AutoModel.from_pretrained(EMBEDDING_CHECKPOINT)


In [None]:
# Function to embed text for retrieval
def embed_text(text, tokenizer=None, model=None):
    """Embed one string (or a list of strings) as mean-pooled sentence vectors.

    Args:
        text: A string or list of strings to embed.
        tokenizer: Callable returning a mapping with at least ``attention_mask``
            plus the model inputs. Defaults to the notebook-level
            ``embedding_tokenizer`` (the tokenizer that belongs to the embedding
            model, not the RAG tokenizer).
        model: Callable returning an object with ``last_hidden_state`` of shape
            (batch, tokens, hidden). Defaults to the notebook-level
            ``embedding_model``.

    Returns:
        A C-contiguous ``float32`` numpy array of shape (batch, hidden), which is
        the layout FAISS ``add``/``search`` expect.
    """
    tokenizer = embedding_tokenizer if tokenizer is None else tokenizer
    model = embedding_model if model is None else model

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        token_embeddings = model(**inputs).last_hidden_state

    # Mask-aware mean pooling: padding positions must not dilute the average
    # when several texts of different length are embedded in one batch.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against an all-padding row
    pooled = summed / token_counts

    return np.ascontiguousarray(pooled.cpu().numpy(), dtype=np.float32)

# Initialize FAISS index
# Take the vector width from the embedding model itself instead of hard-coding it:
# a hard-coded 768 does not match all-MiniLM-L6-v2 and makes `index.add` fail.
dimension = embedding_model.config.hidden_size
index = faiss.IndexFlatL2(dimension)
metadata = []

# Directory containing JSON files
json_dir = 'json'

# Process and index each JSON file.
# Sorted so FAISS ids map to the same passages on every run; non-JSON entries
# (e.g. checkpoint folders or OS metadata files) are skipped.
for json_file in sorted(os.listdir(json_dir)):
    if not json_file.lower().endswith('.json'):
        continue
    with open(os.path.join(json_dir, json_file), 'r', encoding='utf-8') as f:
        case_data = json.load(f)

    # Extract and split passages from case body text
    for section in case_data['casebody']['opinions']:
        text = section['text']
        # Split text into fixed 300-character passages
        passages = [text[start:start + 300] for start in range(0, len(text), 300)]

        for passage in passages:
            # Row N of the FAISS index corresponds to metadata[N]
            index.add(embed_text(passage))
            metadata.append({'file': json_file, 'text': passage})

# Save index and metadata
faiss.write_index(index, 'legal_cases_index.faiss')
with open('metadata.json', 'w', encoding='utf-8') as f:
    json.dump(metadata, f)

# Query and Retrieve Relevant Passages
query = "What is the ruling on minors' ability to enlist in the navy?"
query_embedding = embed_text(query)

# Perform retrieval (top 5). FAISS pads with label -1 when the index holds fewer
# than k vectors; those must be dropped or they would silently select metadata[-1].
distances, neighbor_ids = index.search(query_embedding, k=5)
hits = [
    (int(passage_id), float(distance))
    for passage_id, distance in zip(neighbor_ids[0], distances[0])
    if passage_id >= 0
]
if not hits:
    raise ValueError(f"No passages were indexed from {json_dir!r}; nothing to retrieve.")

# Retrieve passages based on FAISS indices
retrieved_passages = [metadata[passage_id]['text'] for passage_id, _ in hits]

# Use RAG to generate summary.
# The model was created without a retriever, so the retrieved context has to be
# supplied as token ids together with an attention mask and per-document scores.
# Each context string follows the "title / text // question" layout built from the
# separators stored in the RAG config.
rag_config = model.config
context_strings = [
    metadata[passage_id]['file'] + rag_config.title_sep
    + metadata[passage_id]['text'] + rag_config.doc_sep + query
    for passage_id, _ in hits
]
context_inputs = generator_tokenizer(
    context_strings,
    max_length=rag_config.max_combined_length,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
# Higher score = more relevant, so negate the L2 distances. Shape: (1, n_docs).
doc_scores = torch.tensor([[-distance for _, distance in hits]], dtype=torch.float32)

generated_ids = model.generate(
    context_input_ids=context_inputs["input_ids"],
    context_attention_mask=context_inputs["attention_mask"],
    doc_scores=doc_scores,
    n_docs=len(hits),
)
summary = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

print("Summary:", summary[0])

In [5]:
# Function to generate automated queries
def generate_queries(metadata, templates, num_queries=10, rng=None):
    """Build ``num_queries`` questions by filling random templates with case topics.

    Args:
        metadata: Sequence of dict-like case records. Each record must provide
            ``"name"`` unless it has a non-empty ``"topics"`` list.
        templates: Sequence of format strings with a single ``{}`` placeholder.
        num_queries: Number of queries to generate.
        rng: Optional object exposing ``choice`` (e.g. ``random.Random(seed)``) for
            reproducible output. Defaults to the module-level ``random`` generator.

    Returns:
        A list of ``num_queries`` formatted query strings.

    Raises:
        IndexError: If ``metadata`` or ``templates`` is empty and ``num_queries > 0``.
        KeyError: If a chosen record has no usable topics and no ``"name"``.
    """
    chooser = random if rng is None else rng
    queries = []
    for _ in range(num_queries):
        case_info = chooser.choice(metadata)
        template = chooser.choice(templates)

        # Prefer a random topic; fall back to the case name when the record has
        # no "topics" key or an empty topic list.
        topics = case_info.get("topics")
        topic = chooser.choice(topics) if topics else case_info["name"]

        queries.append(template.format(topic))

    return queries

# Query templates, each with one `{}` slot for a topic or case name.
# NOTE(review): `query_templates` was not defined anywhere in the notebook; this
# list is reconstructed from the wording of the recorded output -- extend as needed.
query_templates = [
    "How does this case interpret {}?",
    "Provide a summary of {}.",
    "What rights were discussed regarding {}?",
    "What was the ruling on {}?",
    "What are the legal principles established in {}?",
]

# Generate sample queries
# NOTE(review): generate_queries reads "topics"/"name" from each metadata record;
# confirm `metadata` here is the case-level metadata and not the passage list.
sample_queries = generate_queries(metadata, query_templates, num_queries=10)

# Print or save the queries
# Loop names are chosen so the notebook-level `query` variable is not overwritten.
print("Generated Queries:")
for position, sample_query in enumerate(sample_queries, 1):
    print(f"{position}. {sample_query}")


Generated Queries:
1. How does this case interpret 31 Stat. 356?
2. Provide a summary of 53 Pac. 536.
3. What rights were discussed regarding 81 Fed. 211?
4. What was the ruling on 90 Fed. 673?
5. Provide a summary of 31 Stat. 383.
6. What are the legal principles established in 5 Or. 438?
7. Provide a summary of THE CATHERINE SUDDEN.
8. What was the ruling on 31 Stat. 528?
9. How does this case interpret 23 Stat. 24?
10. How does this case interpret 25 L. Ed. 435?


In [None]:
# Thin wrapper so the evaluation loop has a single entry point into the RAG system
def query_rag_system(query):
    """Return the RAG system's response for ``query``.

    Delegates to ``rag_system.generate_response(query)`` and hands the result back
    unchanged.

    NOTE(review): ``rag_system`` is not defined in the visible notebook cells; this
    is a placeholder to be replaced with the actual API or function call.
    """
    return rag_system.generate_response(query)

# Run the queries and collect responses.
# A comprehension keeps the iteration variable local, so the notebook-level `query`
# set in the retrieval cell is not overwritten by the last sample query.
responses = [
    {"query": sample_query, "response": query_rag_system(sample_query)}
    for sample_query in sample_queries
]

# Optionally, save responses for analysis
with open("rag_responses.json", "w", encoding="utf-8") as f:
    json.dump(responses, f)
