In [None]:
# One-shot long-context Q&A and summary with AI21 Jamba Reasoning 3B (no RAG)
import os
import requests, os, pathlib
from pathlib import Path

# --- Configuration -----------------------------------------------------------
MODEL_NAME = "ai21labs/AI21-Jamba-Reasoning-3B"
GEN_MAX_NEW_TOKENS = 4096  # reserve generation budget

print("Model:", MODEL_NAME)

PDF_URL = "https://www.oecd.org/content/dam/oecd/en/publications/reports/2025/09/key-findings-and-integration-strategies-on-the-impact-of-digital-technologies-on-students-learning_fad2ee0b/ab309c32-en.pdf"   # <- put your link here
PDF_PATH = "/kaggle/working/doc.pdf"
# (connect, read) timeouts in seconds. The read timeout bounds each wait for
# data from the server, not the total download time.
DOWNLOAD_TIMEOUT_S = (10, 120)

# --- Download ----------------------------------------------------------------
# Create the target directory so the cell also works outside Kaggle.
pathlib.Path(PDF_PATH).parent.mkdir(parents=True, exist_ok=True)

# Stream the PDF to disk in 64 KiB chunks instead of holding it all in memory.
# Without a timeout, requests would wait indefinitely on a stalled server.
with requests.get(PDF_URL, stream=True, timeout=DOWNLOAD_TIMEOUT_S) as r:
    r.raise_for_status()  # fail loudly on 4xx/5xx rather than saving an error page
    with open(PDF_PATH, "wb") as f:
        for chunk in r.iter_content(1024 * 64):
            f.write(chunk)

print("Saved:", PDF_PATH, pathlib.Path(PDF_PATH).stat().st_size, "bytes")


In [None]:
# Load PDF as a single text blob using PyPDFLoader and merge pages
from langchain_community.document_loaders import PyPDFLoader
from langchain.schema import Document

# Read the PDF page by page, then glue every page into a single Document.
loader = PyPDFLoader(PDF_PATH)
page_docs = loader.load()  # one Document per page
print("Loaded {} pages".format(len(page_docs)))

# Pages are separated by a blank line in the merged text.
page_texts = [page.page_content for page in page_docs]
full_text = "\n\n".join(page_texts)
merged_doc = Document(metadata={"source": PDF_PATH}, page_content=full_text)

print("Merged characters:", len(merged_doc.page_content))



In [None]:
# Load tokenizer/model on GPU and report tokenizer/model limits
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# This notebook only targets GPU execution; stop early when no CUDA device exists.
print("CUDA available:", torch.cuda.is_available())
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to test long-context for Jamba")

# Load tokenizer first to inspect max length
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# NOTE(review): when the tokenizer config does not set a limit, transformers
# reports a very large sentinel here rather than the real context window.
model_max_len = getattr(tokenizer, "model_max_length", None)
print("Tokenizer model_max_length:", model_max_len)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    # CUDA is guaranteed by the guard above, so no float32 fallback is needed.
    dtype=torch.bfloat16,
)
model.to("cuda")
# Enable fast kernels if available
# NOTE(review): the flag is set after the model is built; presumably the layers
# read it at construction time, in which case this has no effect -- confirm.
if hasattr(model.config, "use_mamba_kernels"):
    model.config.use_mamba_kernels = True
model.eval()  # inference mode: disables dropout and similar training-only behaviour

print("Model loaded on:", next(model.parameters()).device)



In [None]:
# Helper: apply chat template and generate without truncating the document
from typing import List, Dict

def build_chat_with_document(doc_text: str, user_query: str | None = None) -> List[Dict[str, str]]:
    messages = []
    system_prompt = (
        "You are a helpful assistant. You will receive the full document content in one message. "
        "Answer questions or summarize based strictly on the provided content."
    )
    messages.append({"role": "system", "content": system_prompt})

    if user_query is None:
        user_content = (
            "Read the following document and be ready to answer questions about it.\n\n" + doc_text
        )
    else:
        user_content = (
            "Here is the full document context:\n\n" + doc_text + "\n\nQuestion: " + user_query
        )
    messages.append({"role": "user", "content": user_content})
    return messages


def count_tokens(text: str) -> int:
    """Return how many token ids the notebook's global ``tokenizer`` produces for ``text``."""
    encoding = tokenizer(text)
    token_ids = encoding.input_ids
    return len(token_ids)


def generate_with_long_context(messages: List[Dict[str, str]], max_new_tokens: int = GEN_MAX_NEW_TOKENS) -> str:
    """Render ``messages`` into a prompt, verify it fits, and sample a completion.

    The prompt is never truncated: if the prompt plus the generation budget
    exceeds the tokenizer's reported limit, the call fails instead of silently
    dropping document content.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the newly generated tokens only; the prompt
        (including the embedded document) is not echoed back.

    Raises:
        ValueError: If ``input tokens + max_new_tokens`` exceeds
            ``tokenizer.model_max_length``.
    """
    try:
        formatted = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        used_chat_template = True
    except Exception:
        # Fallback if chat template not provided
        formatted = "\n\n".join(f"{m['role'].upper()}: {m['content']}" for m in messages)
        used_chat_template = False

    # Convert to tokens; we do NOT set truncation=True because we must not drop context.
    # Chat-templated text already carries the special tokens it needs, so adding
    # them again would duplicate e.g. BOS. The plain-text fallback still gets them.
    inputs = tokenizer(formatted, return_tensors="pt", add_special_tokens=not used_chat_template)

    # Safety: if too long for the model, raise with actionable hint
    # NOTE(review): model_max_length can be a very large sentinel when the
    # tokenizer config sets no limit, in which case this check never fires.
    input_len = inputs["input_ids"].shape[-1]
    limit = getattr(tokenizer, "model_max_length", None)
    if limit is not None and input_len + max_new_tokens > limit:
        raise ValueError(
            f"Prompt too long for model context window: input={input_len}, gen={max_new_tokens}, limit={limit}. "
            "Reduce document size or increase context window (different model)."
        )

    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.6,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation for decoder-only models; keep
    # only the continuation so callers do not get the whole document back.
    new_token_ids = outputs[0][input_len:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True)



In [None]:
# Inspect token count of the merged document to ensure it fits
# Compare the document size against the tokenizer's limit and the generation budget.
merged_tokens = count_tokens(merged_doc.page_content)
fit_report = {
    "merged_characters": len(merged_doc.page_content),
    "merged_tokens": merged_tokens,
    "tokenizer_model_max_length": getattr(tokenizer, "model_max_length", None),
    "gen_tokens_budget": GEN_MAX_NEW_TOKENS,
}
print(fit_report)

In [None]:
def ask(question: str, hide_thinking_trace: bool = True):
    """Answer ``question`` against the full merged document and print the exchange.

    Args:
        question: Question placed after the document in the user message.
        hide_thinking_trace: When true, print only the text that follows the
            last ``</think>`` marker; when false, print the raw model output.
    """
    chat = build_chat_with_document(merged_doc.page_content, user_query=question)
    answer = generate_with_long_context(chat, max_new_tokens=GEN_MAX_NEW_TOKENS)

    if hide_thinking_trace:
        # Keeps everything after the last marker, or the whole text if absent.
        shown = answer.rpartition("</think>")[2]
    else:
        shown = answer

    print("Q:", question)
    print("A:\n", shown)

In [None]:
def summarize(prompt: str = "Summarize the following document concisely, preserving key metrics and conclusions", hide_thinking_trace: bool = True):
    """Print a one-shot summary of the full merged document (no RAG).

    Args:
        prompt: Instruction placed before the document text in the user message.
        hide_thinking_trace: When true, print only the text that follows the
            last ``</think>`` marker; when false, print the raw model output.
    """
    # Only the system message from the helper is reused, so pass an empty
    # document instead of rendering a full-document prompt that is discarded.
    messages_sum = build_chat_with_document("", user_query=None)
    # Replace the user content to explicitly ask for a concise summary.
    messages_sum[-1]["content"] = f"{prompt}\n\n" + merged_doc.page_content

    summary = generate_with_long_context(messages_sum, max_new_tokens=GEN_MAX_NEW_TOKENS)
    if hide_thinking_trace:
        summary = summary.rsplit("</think>")[-1]
    print(summary)



# Q&A and Document Summary


In [None]:
# Targeted lookup: which chapter covers cyberbullying and what it recommends.
ask("What chapter is devoted to cyberbullying, and what types of interventions does it recommend?")

In [None]:
# Whole-document enumeration: asks for every table and figure in the report.
ask("List all the tables and figures included in the report")

In [None]:
# Cross-reference question: locate a table by topic and the studies it cites.
ask('Which table mentions the challenge of "digital equity" and what example studies does it cite as evidence for this challenge?')

In [None]:
# Summarize the merged document with the default prompt; the result is printed.
summarize()