In [None]:
import torch
import chromadb
import numpy as np
from transformers import MllamaForCausalLM, AutoTokenizer
from transformers import TextIteratorStreamer
from threading import Thread
import warnings
import gc
import re
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Silence all warnings for cleaner notebook output.
warnings.filterwarnings("ignore")

## Set up LLM
# Weights are loaded in bfloat16; device_map="auto" lets transformers decide
# where the weights are placed.
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(f"You are running the model on: {model.device}")

## Set up vector database
# On-disk Chroma store; the collection is created if it does not exist yet.
chroma_path = '../chroma_prod'
client = chromadb.PersistentClient(chroma_path)
collection = client.get_or_create_collection('clinical_trials_data')


In [2]:
def _stop_token_ids(stop_tag):
    """Collect the token ids that should end generation.

    Includes the last token of ``stop_tag`` (when given) plus the tokenizer's
    EOS id and any EOS ids configured on ``model.generation_config``, so the
    model can still stop at its own end-of-sequence token.

    NOTE(review): the tag's final token may be a short piece (e.g. ``>``) that
    also occurs elsewhere in the output, so generation can stop earlier than
    the full tag; transformers' ``stop_strings`` would match exactly.
    """
    stop_ids = []
    if stop_tag:
        tag_ids = tokenizer.encode(stop_tag, add_special_tokens=False)
        if tag_ids:
            stop_ids.append(tag_ids[-1])
    default_eos = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
    if isinstance(default_eos, int):
        default_eos = [default_eos]
    for token_id in [tokenizer.eos_token_id, *(default_eos or [])]:
        if token_id is not None and token_id not in stop_ids:
            stop_ids.append(token_id)
    return stop_ids


def generate_text_stream(prompt, max_new_tokens=1024, stop_tag="</explanation>"):
    """Generate a completion for ``prompt`` by streaming tokens from ``model``.

    ``model.generate`` runs in a background thread that feeds a
    ``TextIteratorStreamer``. This function drains the streamer and returns the
    concatenated text.

    Args:
        prompt: Raw prompt string, tokenized as-is with ``tokenizer``.
        max_new_tokens: Upper bound on the number of generated tokens.
        stop_tag: Text whose last token is used as an extra stop token, so
            generation can end after the closing tag of the expected output
            format. The model's own EOS token(s) remain stop tokens as well.

    Returns:
        The generated text (prompt skipped, special tokens stripped), or
        ``None`` if generation failed (the error is printed).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Exceptions raised inside the worker thread are collected here and
    # re-raised in this thread once the stream has been drained.
    worker_errors = []

    try:
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.7,
            temperature=0.2,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            # A list keeps the natural EOS alongside the tag token; replacing it
            # with only the tag token would stop the model ending normally.
            eos_token_id=_stop_token_ids(stop_tag) or None,
        )

        def _generate():
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                # Without this the consumer loop below would block forever
                # waiting for tokens that will never arrive.
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()

        generated_text = "".join(streamer)

        thread.join()  # Wait for the generation to finish

        if worker_errors:
            raise worker_errors[0]

        return generated_text
    except Exception as e:
        print(f"\nAn error occurred during text generation: {str(e)}")
        return None
    finally:
        # Drop the input tensors and release cached GPU memory on every path.
        del inputs
        torch.cuda.empty_cache()

In [3]:
def _dedupe_lines(block):
    """Drop repeated non-blank lines, keeping each line's first occurrence.

    Lines are compared after stripping surrounding whitespace. Blank lines are
    kept so paragraph breaks survive, but leading blank lines and runs of blank
    lines (e.g. left behind by removed duplicates) are collapsed.
    """
    seen = set()
    kept = []
    for line in block.split('\n'):
        key = line.strip()
        if not key:
            if kept and kept[-1].strip():
                kept.append(line)
            continue
        if key in seen:
            continue
        seen.add(key)
        kept.append(line)
    return '\n'.join(kept).strip()


def extract_content(text):
    """Pull the ``<response>`` and ``<explanation>`` sections out of model output.

    Each section is stripped of surrounding whitespace and of any remaining
    tag-like markup (``<b>``, ``</x>``...). Repeated lines are removed while
    paragraph breaks are preserved. Literal comparisons such as
    ``age < 65 and BMI > 30`` are left intact.

    Args:
        text: Raw generated text. ``None`` or ``""`` is accepted.

    Returns:
        ``(response, explanation)``. Either is ``""`` when its tag pair is
        missing.
    """
    if not text:
        return "", ""

    def _section(tag):
        match = re.search(rf'<{tag}>(.*?)</{tag}>', text, re.DOTALL)
        if not match:
            return ""
        # Only strip things shaped like tags; a bare `<.*?>` would also eat
        # inequality text between a '<' and a later '>'.
        body = re.sub(r'</?[A-Za-z][\w-]*>', '', match.group(1).strip())
        return _dedupe_lines(body)

    return _section('response'), _section('explanation')

In [4]:
def generalize_query(query):
    """Ask the LLM to expand ``query`` with synonyms/abbreviations for retrieval.

    The expanded query and the model's explanation are printed and returned.

    Args:
        query: The user's original question.

    Returns:
        ``(expanded_query, explanation)``. If the model replies but no
        ``<response>`` section can be parsed, ``expanded_query`` falls back to
        ``query`` so downstream search never runs on an empty string.
        ``(None, None)`` if generation itself fails.
    """
    prompt = f"""You are tasked with thoughtfully expanding a user's query to improve recall from a clinical trials database. The goal is to broaden the search terms while maintaining high relevance and preserving the specificity of key medical conditions.

    Original Query:
    <query>
    {query}
    </query>

    Instructions:
    1. Preserve specific medical conditions mentioned in the query.
    2. Include alternative spellings or common abbreviations of medical terms if applicable.
    3. If specific medications are mentioned, include their generic names, but keep the specific medication name in the query.
    4. Do not add related symptoms or comorbidities unless they are direct synonyms of the condition mentioned.
    5. Retain any specific criteria that are central to the query's intent.
    6. Expand only to closely related terms or alternative spellings, without generalizing beyond the specific condition mentioned.

    Provide the expanded query and a brief explanation of your changes using the following format:

    <response>
    [Your expanded query here]
    </response>

    <explanation>
    [Brief explanation of changes made]
    </explanation>
    """

    try:
        # Generate expanded query
        response = generate_text_stream(prompt)

        if not response:
            print("Failed to generate expanded query.")
            return None, None

        expanded_query, explanation = extract_content(response)
        if not expanded_query:
            # The reply did not follow the <response> format; searching with an
            # empty string would be meaningless, so keep the user's wording.
            print("Could not parse an expanded query; falling back to the original query.")
            expanded_query = query

        # Print the expanded query and explanation
        print("\n" + "="*50)
        print("Original Query:")
        print(query)
        print("\nExpanded Query:")
        print(expanded_query)
        print("\nExplanation:")
        print(explanation)
        print("="*50 + "\n")

        return expanded_query, explanation
    except Exception as e:
        print(f"\nAn error occurred during query generalization: {str(e)}")
        return None, None
    finally:
        # Clear any remaining GPU memory
        torch.cuda.empty_cache()
        gc.collect()

In [22]:
def find_natural_break(data, min_results=10, threshold_multiplier=3.0, verbose=False):
    """Return how many leading items of ``data`` come before the first large jump.

    Intended for ascending values such as nearest-first search distances. The
    gaps between consecutive values that end at index ``min_results`` or later
    are compared against ``mean + threshold_multiplier * std`` of those same
    gaps. The first gap above that threshold marks the break.

    Args:
        data: 1-D sequence of numbers.
        min_results: Minimum number of items always kept before a break.
            Values below 1 are treated as 1.
        threshold_multiplier: How many standard deviations above the mean gap
            a gap must be to count as a break.
        verbose: If True, print the computed threshold.

    Returns:
        An ``int`` cutoff ``k`` such that ``data[:k]`` lies before the break.
        ``len(data)`` when no break is found or ``data`` has at most
        ``min_results`` items.
    """
    n = len(data)
    # A non-positive minimum would make the slice below start from the end.
    min_results = max(min_results, 1)
    if n <= min_results:
        return n

    # gap k = data[k + 1] - data[k]; candidate gaps are k >= min_results - 1,
    # so candidate j sits between data[min_results - 1 + j] and data[min_results + j].
    candidate_gaps = np.diff(np.asarray(data, dtype=float))[min_results - 1:]
    threshold = candidate_gaps.mean() + threshold_multiplier * candidate_gaps.std()
    if verbose:
        print(threshold)

    above = np.flatnonzero(candidate_gaps > threshold)
    if above.size:
        return int(min_results + above[0])
    return n

def tag_relevant_documents(documents, min_results=10, threshold_multiplier=3.0):
    """Flatten the first query's Chroma results and tag each hit by relevance.

    Hits before the natural break in the distance sequence (see
    ``find_natural_break``) are tagged ``'relevant'``; the rest are
    ``'less_relevant'``.

    Args:
        documents: A ``collection.query`` result whose 'ids', 'distances',
            'metadatas' and 'documents' are lists of per-query lists. Only
            index 0 (the first query) is used.
        min_results: Minimum number of hits always tagged relevant.
        threshold_multiplier: Forwarded to ``find_natural_break``. It sets how
            large a gap must be to count as the break.

    Returns:
        Dict with flat 'ids', 'distances', 'metadatas', 'documents' lists and a
        parallel 'relevance' list of labels.
    """
    distances = np.array(documents['distances'][0])

    # Find the natural break
    cutoff_index = find_natural_break(distances, min_results, threshold_multiplier)

    # Relevant hits are the leading prefix up to the cutoff.
    relevance = ['relevant' if i < cutoff_index else 'less_relevant' for i in range(len(distances))]

    return {
        'ids': documents['ids'][0],
        'distances': documents['distances'][0],
        'metadatas': documents['metadatas'][0],
        'documents': documents['documents'][0],
        'relevance': relevance
    }
    
def filter_relevant_documents(documents):
    """Keep only the hits tagged ``'relevant'``.

    Args:
        documents: Dict with parallel 'documents', 'metadatas', 'distances',
            'ids' and 'relevance' lists (as built by ``tag_relevant_documents``).

    Returns:
        Dict with the same four parallel lists, restricted to relevant hits and
        in their original order. The 'relevance' key is not carried over.
    """
    keep = [idx for idx, label in enumerate(documents['relevance']) if label == 'relevant']
    fields = ('documents', 'metadatas', 'distances', 'ids')
    return {field: [documents[field][idx] for idx in keep] for field in fields}

def plot_search_distances(documents):
    """Plot search distances and their consecutive differences, colored by relevance.

    Top panel: one marker per hit at its rank, blue for 'relevant' and red for
    'less_relevant', with the metadata title shown on hover. Bottom panel: the
    difference between each distance and the previous one. A dashed green line
    marks the relevance break when there is one.

    Args:
        documents: Dict with parallel 'distances', 'relevance' and 'metadatas'
            lists, as returned by ``tag_relevant_documents``. Relevant hits are
            assumed to form a leading prefix, which is how that function tags
            them.
    """
    distances = documents['distances']
    relevance = documents['relevance']
    # Tolerate None metadata entries as well as entries without a title.
    titles = [
        (meta or {}).get('title', f"Document {i}")
        for i, meta in enumerate(documents['metadatas'])
    ]

    indices = list(range(len(distances)))
    differences = np.diff(distances)
    # difference k = distances[k + 1] - distances[k], drawn at x = k + 1
    diff_indices = indices[1:]

    relevant_indices = [i for i, rel in enumerate(relevance) if rel == 'relevant']
    less_relevant_indices = [i for i, rel in enumerate(relevance) if rel == 'less_relevant']

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1,
                        subplot_titles=("Search Result Distances", "Differences in Distances"))

    point_hover = "<b>%{text}</b><br>Index: %{x}<br>Distance: %{y:.4f}<extra></extra>"
    for label, color, group in (('Relevant', 'blue', relevant_indices),
                                ('Less Relevant', 'red', less_relevant_indices)):
        fig.add_trace(
            go.Scatter(
                x=group,
                y=[distances[i] for i in group],
                mode='markers',
                name=label,
                marker=dict(color=color, size=10),
                text=[titles[i] for i in group],
                hovertemplate=point_hover,
            ),
            row=1, col=1,
        )

    # Differences inside the relevant prefix are blue; the break gap and all
    # later ones are red. Clamp at 0 so an empty relevant set does not become a
    # negative slice (which would mis-color all but the last difference).
    split = max(len(relevant_indices) - 1, 0)
    diff_hover = "Index: %{x}<br>Difference: %{y:.4f}<extra></extra>"
    for label, color, xs, ys in (
        ('Relevant Differences', 'blue', diff_indices[:split], differences[:split]),
        ('Less Relevant Differences', 'red', diff_indices[split:], differences[split:]),
    ):
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                mode='lines+markers',
                name=label,
                line=dict(color=color),
                marker=dict(size=6),
                hovertemplate=diff_hover,
            ),
            row=2, col=1,
        )

    break_point = len(relevant_indices)
    # Only draw the break line when it actually separates two non-empty groups.
    if 0 < break_point < len(distances):
        fig.add_vline(x=break_point - 0.5, line_dash="dash", line_color="green",
                      annotation_text="Relevance Break", row=1, col=1)
        fig.add_vline(x=break_point - 0.5, line_dash="dash", line_color="green", row=2, col=1)

    fig.update_layout(
        height=800,  # room for two stacked subplots
        title_text="Search Result Distances and Differences",
        legend_title="Relevance",
        hovermode='closest'
    )

    fig.update_xaxes(title_text="Result Index", row=2, col=1)
    fig.update_yaxes(title_text="Distance", row=1, col=1)
    fig.update_yaxes(title_text="Difference", row=2, col=1)

    fig.show()

In [25]:
def summarize_trials(trials, query):
    """Summarize each retrieved trial with the LLM in the context of ``query``.

    Each summary and explanation is printed as soon as it is generated.

    Args:
        trials: Dict with parallel 'documents' and 'metadatas' lists (e.g. the
            output of ``filter_relevant_documents``). A missing or ``None``
            metadata entry, or one without a 'title', is labeled
            "Document {i}" instead of raising.
        query: Query text inserted into each summarization prompt.

    Returns:
        List of ``{'title', 'summary', 'explanation'}`` dicts, one per document
        whose generation succeeded, in document order. Empty when ``trials`` is
        falsy or has no documents.
    """
    doc_summary_collection = []
    if trials and trials.get('documents'):
        metadatas = trials.get('metadatas') or []
        for i, doc in enumerate(trials['documents']):
            meta = metadatas[i] if i < len(metadatas) else None
            title = (meta or {}).get('title', f"Document {i}")

            prompt = f"""Summarize the following document in the context of the original query. Focus only on information directly relevant to the query.

            Original Query:
            <query>
            {query}
            </query>

            Document Title: {title}

            Document Content:
            {doc}

            Instructions:
            1. Provide a brief summary (2-3 sentences) highlighting key points relevant to the query.
            2. Mention any specific therapies or treatments discussed in the document that relate to the query.
            3. Include only the most pertinent numerical data, if any.
            4. Use clear, concise language. Avoid repetition.

            Provide the summary and an explanation of the document's relevance using the following format:

            <response>
            [Insert your 2-3 sentence summary here]
            </response>

            <explanation>
            [Explain how this document is relevant to the original query. If not relevant, briefly state why.]
            </explanation>
            """

            # Generate summary
            response = generate_text_stream(prompt)

            if response:
                summary, explanation = extract_content(response)
                doc_summary_collection.append({'title': title, 'summary': summary, 'explanation': explanation})

                # Print the summary immediately after generation
                print("\n" + "="*50)
                print(f"{i}. Title: {title}")
                print("="*50)
                print("\nSummary:")
                print(summary)
                print("\nExplanation:")
                print(explanation)
                print("\n" + "="*50)
            else:
                print(f"Failed to generate summary for document: {title}")

            # Clear any remaining GPU memory
            torch.cuda.empty_cache()
            gc.collect()

    return doc_summary_collection

In [36]:
def generate_overall_summary(trials_summary_collection, query, max_new_tokens=4096):
    """Synthesize per-trial summaries into one answer to ``query``.

    Args:
        trials_summary_collection: List of dicts with 'title' and 'summary'
            keys (as produced by ``summarize_trials``).
        query: Query text inserted into the synthesis prompt.
        max_new_tokens: Generation budget for the combined answer.

    Returns:
        ``(response, overall_summary, explanation)``: the raw generated text
        plus the parsed ``<response>`` and ``<explanation>`` sections. When
        there are no summaries, or generation fails, ``response`` is falsy and
        both parsed fields are ``""``.
    """
    if not trials_summary_collection:
        # Without any trial evidence the model could only invent an answer.
        print("No trial summaries to combine; skipping overall summary.")
        return None, "", ""

    # Combine all summaries into a single string
    all_summaries = "\n\n".join([f"Title: {trial['title']}\nSummary: {trial['summary']}" for trial in trials_summary_collection])

    prompt = f"""Summarize the following document in the context of the original query. Focus only on information directly relevant to the query.

        Original Query:
        <query>
        {query}
        </query>

        Document Content:
        {all_summaries}

        Instructions:
        1. Analyze the document summaries and extract key information relevant to the original query.
        2. Synthesize this information into a coherent, detailed answer addressing the user's question.
        3. Include specific therapies, treatments, or clinical trial details mentioned in the summaries that are directly relevant to the query.
        4. Present any pertinent numerical data or statistics from the summaries.
        5. Ensure the answer is comprehensive yet focused, avoiding irrelevant information.
        6. Use clear, professional language appropriate for discussing medical topics.
        7. If there are conflicting findings or information gaps, mention these briefly.

        Provide your response using the following format:

        <response>
        [Insert your comprehensive summary here. This should be a detailed few paragraphs covering the main points, relevant information, and data that directly address the query.]
        </response>

        <explanation>
        [Explain how you synthesized this information from the document. Discuss any challenges in interpreting the data, conflicting information, or gaps in the available information. Explain how you determined which information was most relevant to include in the summary based on the original query.]
        </explanation>"""

    # Generate overall summary
    response = generate_text_stream(prompt, max_new_tokens=max_new_tokens)

    if response:
        overall_summary, explanation = extract_content(response)

        # Print the overall summary immediately after generation
        print("\n" + "="*50)
        print("Overall Summary")
        print("="*50)
        if overall_summary:
            print("\nSummary:")
            print(overall_summary)
        else:
            print("\nWarning: No summary was generated.")

        if explanation:
            print("\nExplanation:")
            print(explanation)
        else:
            print("\nWarning: No explanation was provided.")
        print("\n" + "="*50)
    else:
        print("Failed to generate overall summary")
        overall_summary = ""
        explanation = ""

    # Clear any remaining GPU memory
    torch.cuda.empty_cache()
    gc.collect()

    return response, overall_summary, explanation

## Generalize the user query

In [None]:
# User question to answer from the clinical-trials collection.
# Another example to try: "What are the latest therapies for lpa?"
query = "What are the latest therapies for perimenopause?"
expanded_query, explanation = generalize_query(query=query)

## Retrieve documents, tag their relevance by distance, and summarize the relevant ones

In [None]:
# At least `min_results` hits are always tagged relevant; `max_results` is how
# many nearest neighbours are requested from Chroma.
min_results = 5
max_results = 100
# generalize_query returns None when generation fails; search with the user's
# own wording rather than passing None to Chroma.
if not expanded_query:
    expanded_query = query
raw_trials = collection.query(query_texts=expanded_query, n_results=max_results, include=["documents", "metadatas", "distances"])
trials = tag_relevant_documents(raw_trials, min_results=min_results)
plot_search_distances(trials)

In [None]:
# Keep only hits before the relevance break, then summarize each one.
relevant_trials = filter_relevant_documents(documents=trials)
trial_summary_collection = summarize_trials(trials=relevant_trials, query=expanded_query)

In [None]:
# Synthesize the per-trial summaries into one answer (rebinds `explanation`).
response, summary, explanation = generate_overall_summary(
    trials_summary_collection=trial_summary_collection,
    query=expanded_query,
)