# Abstract Query Tool

Here, we take your abstract text, find the top-K most similar abstracts based on SentenceTransformers embeddings,
and generate a summary of the literature using a Hugging Face LLM (by default Llama 3.1 8B, loaded through Unsloth).

## Step 0 — Configuration
Set the MongoDB connection info, base model, LoRA adapter path, S3 bucket/prefix, corpus date range, and generation parameters.

In [1]:
# --- Paths and model source -------------------------------------------------
app_path = "../"            # project root, relative to the notebook's directory
local_model_path = "models"  # local directory that holds (or receives) adapter weights

use_s3 = True                # fetch adapter weights from S3 when an adapted model is used
s3_bucket = "watspeed-data-gr-project"
s3_prefix = "models"

base_model_name = "unsloth/Meta-Llama-3.1-8B"
use_adapted_model = False
# Relative to local_model_path / s3_prefix,
# e.g. 'unsloth_Llama-3.2-3B_20250812_090718/lora_weights'
adapter_path = None

# --- MongoDB ----------------------------------------------------------------
mongo_uri = "mongodb://localhost:27017/"
mongo_db_name = "biorxiv"
mongo_db_collection = "abstracts"

# --- Corpus date window (None = take the bound from the collection) ----------
start_date = "2025-07-01"
end_date = None

# --- Retrieval and generation -------------------------------------------------
top_k = 5
model_max_length = 2048
max_new_tokens = 256
temperature = 0.5
dtype = None
load_in_4bit = True

# --- Query abstract -----------------------------------------------------------
abstract_text = """Glioblastomas harbor diverse cell populations, including rare glioblastoma stem cells (GSCs)
                that drive tumorigenesis. To characterize functional diversity within this population, we performed 
                single-cell RNA sequencing on >69,000 GSCs cultured from the tumors of 26 patients. We observed a high 
                degree of inter- and intra-GSC transcriptional heterogeneity that could not be fully explained by DNA 
                somatic alterations. Instead, we found that GSCs mapped along a transcriptional gradient spanning two 
                cellular states reminiscent of normal neural development and inflammatory wound response. 
                Genome-wide CRISPR–Cas9 dropout screens independently recapitulated this observation, with each state 
                characterized by unique essential genes. Further single-cell RNA sequencing of >56,000 malignant cells 
                from primary tumors found that the majority organize along an orthogonal astrocyte maturation gradient 
                yet retain expression of founder GSC transcriptional programs. We propose that glioblastomas grow out 
                of a fundamental GSC-based neural wound response transcriptional program, which is a promising target 
                for new therapy development."""


In [2]:
import gc
import json
import os
import sys
import warnings
from datetime import datetime
from functools import partial

# Unsloth must be imported before transformers / trl / peft (and anything that
# pulls them in, such as sentence_transformers) so that its patches are applied.
from unsloth import FastLanguageModel # FastLanguageModel for LLMs
from unsloth import is_bfloat16_supported

import numpy as np
import torch
from peft import prepare_model_for_kbit_training
from pymongo import MongoClient
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import TrainingArguments
from trl import SFTTrainer

  from .autonotebook import tqdm as notebook_tqdm

Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import is_bfloat16_supported


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


## Step 1 — Load Models and Data

In [3]:
# Resolve app_path against the directory the kernel started in, remembered on
# first execution. A bare os.chdir('../') would climb one more level every time
# this cell is re-run.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.abspath(os.path.join(_NOTEBOOK_START_DIR, app_path)))

In [4]:
# Imported here rather than in the import cell: `utils` is resolved relative to
# the working directory set in the previous cell.
from utils.aws import get_boto3_client

# Always bind `s3` so later cells can test `s3 is not None` without risking a
# NameError when S3 is disabled.
s3 = get_boto3_client("s3") if use_s3 else None

Loaded .env — assuming local environment


In [5]:
# Decide which weights to load: a LoRA-adapted checkpoint (optionally synced
# from S3 first) or the plain base model, then load once.
if use_adapted_model:
    assert adapter_path is not None, "Adapter path must be specified when using adapted model."
    full_local_model_path = os.path.join(local_model_path, adapter_path)

    if use_s3:
        assert s3 is not None, "S3 client is not initialized."
        # Trailing slash: without it the prefix would also match sibling
        # prefixes that merely share the same leading characters.
        s3_model_prefix = f"{s3_prefix}/{adapter_path}".rstrip("/") + "/"
        os.makedirs(full_local_model_path, exist_ok=True)

        paginator = s3.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=s3_bucket, Prefix=s3_model_prefix):
            for obj in page.get('Contents', []):
                key = obj['Key']
                if key.endswith('/'):  # Skip folder placeholder objects
                    continue
                # Keep the key's path below the prefix so nested files do not
                # collapse onto (and overwrite) each other locally.
                rel_parts = key[len(s3_model_prefix):].split('/')
                local_path = os.path.join(full_local_model_path, *rel_parts)
                os.makedirs(os.path.dirname(local_path), exist_ok=True)

                print(f"Downloading {key} to {local_path}")
                s3.download_file(s3_bucket, key, local_path)

    model_name_or_path = full_local_model_path
else:
    model_name_or_path = base_model_name

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name_or_path,
    max_seq_length = model_max_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

==((====))==  Unsloth 2025.8.4: Fast Llama patching. Transformers: 4.55.0.
   \\   /|    NVIDIA GeForce RTX 4060 Laptop GPU. Num GPUs = 1. Max memory: 7.996 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!




In [6]:
# Report where inference will run and whether any PEFT adapter config is
# attached to the loaded model. getattr() with a default never raises, so no
# try/except is needed; None means no `peft_config` attribute is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("PEFT adapters:", getattr(model, "peft_config", None))
print("CUDA available:", torch.cuda.is_available(), "| device:", device)

PEFT adapters: None
CUDA available: True | device: cuda


## Step 2 — Load Corpus given specified date range

In [7]:
# --- Mongo → load docs (light filter) ---
col = MongoClient(mongo_uri)[mongo_db_name][mongo_db_collection]

# `{"$ne": ""}` on its own also matches documents whose `abstract` field is
# missing or null; requiring a string type keeps only usable abstracts.
has_abstract = {"$type": "string", "$ne": ""}

# Fill in open date bounds from the earliest / latest document with an abstract.
if start_date is None:
    min_doc = col.find_one({"abstract": has_abstract}, sort=[("date", 1)])
    start_date = min_doc["date"] if min_doc else None
print("Start Date: {}".format(start_date))
if end_date is None:
    max_doc = col.find_one({"abstract": has_abstract}, sort=[("date", -1)])
    end_date = max_doc["date"] if max_doc else None
print("End Date: {}".format(end_date))

# Build the query with the (inclusive) date range
query = {
    "abstract": has_abstract,
    "date": {"$gte": start_date, "$lte": end_date}
}
docs = list(col.find(query, {"_id": 1, "title": 1, "abstract": 1}))
assert docs, "No docs found in Mongo. Check DB/collection."
print(f"Loaded {len(docs)} abstracts from Mongo.")



Start Date: 2025-07-01
End Date: 2025-08-14
Loaded 3030 abstracts from Mongo.


## Step 3 — Compute Embeddings and Select the Top-K Abstracts

In [8]:
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Keep `docs` and `abstracts` index-aligned: similarity indices computed on the
# embeddings are later used to index `docs`, so any document without a usable
# abstract has to be dropped from `docs` itself, not only from `abstracts`.
docs = [
    doc for doc in docs
    if isinstance(doc.get('abstract'), str) and doc['abstract'].strip()
]
assert docs, "No documents with a non-empty abstract to embed."
abstracts = [doc['abstract'] for doc in docs]

# Generate embeddings (one row per entry of `docs`)
embeddings = embedder.encode(abstracts, show_progress_bar=True)

Batches: 100%|██████████████████████████████████████████████████████████████████████████| 95/95 [00:04<00:00, 20.10it/s]


In [9]:
# Embed the query abstract and score it against every corpus embedding.
query_embedding = embedder.encode([abstract_text])
similarities = cosine_similarity(query_embedding, embeddings)[0]  # shape: (num_abstracts,)

# Use the `top_k` chosen in the configuration cell (no local override) and cap
# it at the corpus size. A zero value is rejected explicitly because a
# `[-0:]` slice would silently select the whole corpus.
num_selected = min(int(top_k), len(similarities))
assert num_selected > 0, "top_k must be a positive integer and the corpus must be non-empty."
top_k_indices = np.argsort(similarities)[::-1][:num_selected]  # most similar first

## Step 4 — Use Llama for Summarization
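Iterates over the top-K abstracts in rank order. At each step it builds a summarization prompt from the query abstract, the running summary, and the next abstract, trims the running summary when the prompt would exceed the model's context window, and generates an updated summary.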

In [10]:
def _build_summary_prompt(query_abstract, current_summary, new_abstract, rank):
    """Return the plain-text completion prompt for one summary-update step."""
    return (
        "You are a scientific summarization assistant.\n\n"
        "Your task is to update a literature summary based on a new abstract and a query abstract.\n\n"
        "Inputs:\n"
        f"- Current summary: '{current_summary}'\n"
        f"- Query abstract: '{query_abstract}'\n"
        f"- New abstract (ranked #{rank} most similar to query): '{new_abstract}'\n\n"
        "Instructions:\n"
        "1. If the current summary is empty, initialize it by summarizing how the new abstract relates to the query abstract.\n"
        "2. If the current summary is not empty, update it by integrating any new insights or findings from the new abstract that are relevant to the query abstract.\n\n"
        "Write the updated summary below:\nSummary:"
    )


def _token_count(tokenizer, text):
    """Return the number of tokens the tokenizer produces for `text`."""
    return len(tokenizer(text, truncation=False)["input_ids"])


def _drop_tokens(tokenizer, text, n_drop, from_head):
    """Remove `n_drop` tokens from the head or tail of `text` and decode the rest."""
    # Special tokens are excluded so that only real content tokens are counted.
    token_ids = list(tokenizer(text, add_special_tokens=False)["input_ids"])
    if n_drop >= len(token_ids):
        return ""
    kept = token_ids[n_drop:] if from_head else token_ids[:len(token_ids) - n_drop]
    return tokenizer.decode(kept, skip_special_tokens=True)


def summarize_literature(
    model,
    tokenizer,
    query_abstract,
    top_k_abstracts,
    max_new_tokens=256,
    temperature=0.7,
    repetition_penalty=1.0,
    max_context_length=None,
):
    """Iteratively fold the top-K abstracts into one running literature summary.

    For each document (in the order given) a completion prompt is built from the
    query abstract, the summary so far and the document's abstract; the model's
    continuation becomes the new running summary.

    Args:
        model: Causal LM exposing `generate`, `device` and
            `config.max_position_embeddings`.
        tokenizer: Tokenizer matching `model` (callable, with `decode` and
            `eos_token_id`).
        query_abstract: The abstract the literature is being related to.
        top_k_abstracts: Iterable of dicts with an "abstract" key, most similar
            first. Entries with a missing or blank abstract are skipped.
        max_new_tokens: Generation budget per step.
        temperature: Sampling temperature; a value of 0 (or None) switches to
            greedy decoding.
        repetition_penalty: Passed through to `model.generate`.
        max_context_length: Total token budget (prompt + generation). Defaults
            to `model.config.max_position_embeddings`; pass the sequence length
            the model was actually loaded with when that is smaller.

    Returns:
        The summary produced by the last step ("" if nothing was summarized).
    """
    if max_context_length is None:
        max_context_length = model.config.max_position_embeddings
    buffer_tokens = 32
    prompt_budget = max_context_length - max_new_tokens - buffer_tokens

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        # EOS token ID should be the standard tokenizer EOS for non-instruct models
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        generation_kwargs["do_sample"] = False

    current_summary = ""

    for i, doc in enumerate(top_k_abstracts):
        rank = i + 1
        abstract_i = (doc.get("abstract") or "").strip()
        if not abstract_i:
            continue

        prompt = _build_summary_prompt(query_abstract, current_summary, abstract_i, rank)
        excess = _token_count(tokenizer, prompt) - prompt_budget

        if excess > 0:
            print(f"⚠️ Prompt too long by {excess} tokens. Trimming current summary.")
            # Drop the oldest part of the running summary first.
            current_summary = _drop_tokens(tokenizer, current_summary, excess, from_head=True)
            prompt = _build_summary_prompt(query_abstract, current_summary, abstract_i, rank)
            excess = _token_count(tokenizer, prompt) - prompt_budget

            if excess > 0:
                # The summary alone could not absorb the overflow: shorten the
                # new abstract from its tail so the closing "Summary:" cue of
                # the prompt is never truncated away.
                print(f"⚠️ Still too long by {excess} tokens. Trimming new abstract.")
                abstract_i = _drop_tokens(tokenizer, abstract_i, excess, from_head=False)
                if not abstract_i:
                    print(f"⚠️ No room left for abstract {rank}; skipping it.")
                    continue
                prompt = _build_summary_prompt(query_abstract, current_summary, abstract_i, rank)

        # Final tokenization (tokenizer-level truncation kept as a last resort)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
        prompt_token_count = len(inputs["input_ids"][0])

        outputs = model.generate(**inputs, **generation_kwargs)

        # The output sequence starts with the prompt tokens; slice them off by
        # position instead of string-matching the decoded text, which breaks
        # whenever decoding does not reproduce the prompt exactly.
        generated_ids = outputs[0][prompt_token_count:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # Log for debugging
        print(f"\n--- Iteration {rank} ---")
        print("Prompt token count:", prompt_token_count)
        print("Generated token count:", len(generated_ids))
        print(generated_text)

        # Update summary
        current_summary = generated_text

    return current_summary


In [11]:
# Map the ranked similarity indices back to their Mongo documents (best match first).
top_k_abstracts = [docs[doc_idx] for doc_idx in top_k_indices]

# NOTE: the generation settings below are passed explicitly and therefore take
# precedence over `max_new_tokens` / `temperature` from the configuration cell.
final_summary = summarize_literature(
    model,
    tokenizer,
    abstract_text,
    top_k_abstracts,
    max_new_tokens=1014,
    temperature=0.7,
)

AttributeError: 'Tensor' object has no attribute 'input_ids'

In [None]:
# Display the summary returned by summarize_literature for the query abstract.
print(final_summary)