<a href="https://colab.research.google.com/github/tommasocarzaniga/CNM_CycNucMed/blob/main/EMBA_Exam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cyclotrons for Nuclear Medicine

###Setup
First, make the necessary imports.
Note that further imports may have to be made in addition to the ones below, if your application uses additional fetures such as loaders and tools. You can find the code for these imports in the respective sections of the tutorial notebooks.

In [1]:
!pip install -q langchain langchain-community langchain-core langchain-openai langchain-huggingface

from google.colab import userdata
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_openai import ChatOpenAI
import os
import pprint
import getpass
from IPython.display import Markdown
import IPython.display as ipd
from PIL import Image
import urllib.request
import hashlib

Then, assign the API keys to be able to use OpenAI, Google Serper, Huggingface, etc.

When working with sensitive information like API keys or passwords in Google Colab, it's crucial to handle data securely. As you learnt in the tutorial session, two common approaches for this are using **Colab's Secrets Manager**, which stores and retrieves secrets without exposing them in the notebook, and `getpass`, a Python function that securely prompts users to input secrets during runtime without showing them. Both methods help ensure your sensitive data remains protected.

In [2]:
#You can remove the keys you will not use

#When using Colab Secret Manager
os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')
#When using getpass
#os.environ['OPENAI_API_KEY'] = getpass.getpass()

#When using Colab Secret Manager
os.environ["SERPER_API_KEY"] = userdata.get('SERPER_API_KEY')
#When using getpass
#os.environ['SERPER_API_KEY'] = getpass.getpass()

#When using Colab Secret Manager
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
#When using getpass
#os.environ['HF_TOKEN'] = getpass.getpass()

Why do I have to install this:

In [3]:
!apt-get update
!apt-get install -y \
  libatk1.0-0 \
  libatk-bridge2.0-0 \
  libcups2 \
  libdrm2 \
  libxkbcommon0 \
  libxcomposite1 \
  libxdamage1 \
  libxfixes3 \
  libxrandr2 \
  libgbm1 \
  libpango-1.0-0 \
  libcairo2 \
  libasound2

0% [Working]            Hit:1 https://cli.github.com/packages stable InRelease
0% [Connecting to archive.ubuntu.com (185.125.190.81)] [Connecting to security.                                                                               Hit:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
0% [Waiting for headers] [Waiting for headers] [Connected to r2u.stat.illinois.                                                                               Hit:3 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:5 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:7 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Reading package lists... Done
W: Skipping acquire of configur

Why do I have to install this:

In [4]:
!pip install playwright pandas
!playwright install chromium



What am I doing here:

In [5]:
from playwright.async_api import async_playwright

async def test_playwright():
    async with async_playwright() as p:
        browser = await p.chromium.launch(headless=True)
        page = await browser.new_page()
        await page.goto("https://nucleus.iaea.org/sites/accelerators/Pages/Cyclotron.aspx")
        print(await page.title())
        await browser.close()

await test_playwright()


Pages - Cyclotrons used for Radionuclide Production


Web scraping of the IAEA cyclotron database

This script automatically extracts structured data from the IAEA public web page
listing cyclotron facilities worldwide:
https://nucleus.iaea.org/sites/accelerators/Pages/Cyclotron.aspx

The website is implemented using SharePoint and loads data dynamically across
multiple pages. Because of this, a simple HTTP request is not sufficient and a
browser automation tool (Playwright) is used to simulate real user navigation.

Main features of the script:
- Uses Playwright (async) to control a headless Chromium browser.
- Navigates through all pages of the table by clicking the "Next" button.
- Extracts tabular data from each page.
- Cleans and normalizes rows to handle SharePoint quirks (e.g. header rows
  injected into the table body, extra cells in the first row of each page).
- Deduplicates rows using a hash-based fingerprint.
- Stores the final structured dataset into a CSV file.

Extracted fields:
- Country
- City
- Facility
- Manufacturer
- Model
- Proton energy (MeV)

The final output is saved as:
iaea_cyclotrons_normalized.csv

This approach demonstrates:
- Practical web scraping of JavaScript-heavy websites
- Asynchronous programming in Python
- Robust data cleaning
- Reproducible data extraction for research purposes

In [6]:
# ============================================================
# IAEA Cyclotron List Scraper (SharePoint) — Colab-ready
# ------------------------------------------------------------
# What this does:
# - Opens the IAEA SharePoint list view of cyclotrons
# - Iterates through pages by clicking the "Next" button
# - Extracts the table rows efficiently (one JS call per page)
# - Cleans SharePoint quirks (header labels inside tbody, multi-line / tabbed cells)
# - Deduplicates rows (hash fingerprint)
# - Saves a CSV sorted by Country and City
#
# Output:
#   /content/iaea_cyclotrons.csv
# ============================================================

import asyncio
import pandas as pd
from playwright.async_api import async_playwright
import os
import hashlib
import re

# -----------------------------
# Configuration
# -----------------------------
BASE_URL = "https://nucleus.iaea.org/sites/accelerators/Pages/Cyclotron.aspx"
HASH_PREFIX = "#InplviewHashd5afe566-18ad-4ac0-8aeb-ccf833dbc282="
OUTPUT_CSV = os.path.join(os.getcwd(), "iaea_cyclotrons.csv")

# Expected columns in the IAEA list view
EXPECTED_COLS = 6  # Country, City, Facility, Manufacturer, Model, Proton energy (MeV)

# SharePoint sometimes injects these header labels inside table rows
HEADER_LABELS = [
    "country", "city", "facility", "manufacturer", "model",
    "proton energy (mev)", "proton energy"
]
HEADER_SET = set(HEADER_LABELS)

# IMPORTANT: target the actual SharePoint "list view" table (reduces missing rows)
TABLE_ROW_SELECTOR = "table.ms-listviewtable tbody tr"


# -----------------------------
# Helpers for cleaning/parsing
# -----------------------------
def norm_text(s: str) -> str:
    """Normalize text for comparisons."""
    return " ".join((s or "").split()).strip().lower()


def row_fingerprint(cells):
    """
    Make a stable hash of the cleaned row values.
    Used to deduplicate rows across pages (SharePoint sometimes repeats).
    """
    norm = [" ".join((c or "").split()) for c in cells]
    return hashlib.md5(" | ".join(norm).encode("utf-8")).hexdigest()


def strip_header_prefix(cell: str) -> str:
    """
    Convert 'City: Vienna' -> 'Vienna' (for known header labels).
    Sometimes SharePoint prepends labels inside a cell.
    """
    c = (cell or "").strip()
    c_norm = norm_text(c)
    for h in HEADER_LABELS:
        if re.match(rf"^{re.escape(h)}(\s*[:\-]?\s+)", c_norm):
            return re.sub(rf"(?i)^{re.escape(h)}\s*[:\-]?\s+", "", c).strip()
    return c


def flatten_multiline_cells(raw_cells):
    """
    SharePoint quirks:
    - multiple values in one cell separated by tabs
    - multiple values separated by newlines
    This function splits on tabs/newlines and flattens into a token list.
    """
    tokens = []
    for c in raw_cells:
        if not c:
            continue
        parts = re.split(r"[\t\r\n]+", str(c))
        for part in parts:
            part = part.strip()
            if part:
                tokens.append(part)
    return tokens


def clean_and_align_tokens(raw_cells):
    """
    Convert raw table cells -> exactly 6 cleaned fields.
    Strategy:
    - flatten tabs/newlines
    - remove header-label tokens
    - strip 'Label: value' prefixes
    - if too many tokens remain, pick the best contiguous window of length 6
    """
    tokens = flatten_multiline_cells(raw_cells)

    processed = []
    for t in tokens:
        t_norm = norm_text(t)
        if t_norm in HEADER_SET:
            continue

        t2 = strip_header_prefix(t)
        t2_norm = norm_text(t2)
        if not t2 or t2_norm in HEADER_SET:
            continue

        processed.append(t2.strip())

    # Not enough data to form a row
    if len(processed) < EXPECTED_COLS:
        return None

    # Exactly the correct number of fields
    if len(processed) == EXPECTED_COLS:
        # Drop if it still looks like a header row
        if any(norm_text(x) in HEADER_SET for x in processed):
            return None
        return processed

    # More than 6 tokens => choose the best "slice" of 6 tokens
    def badness(x: str) -> int:
        xn = norm_text(x)
        if xn in HEADER_SET:
            return 100
        if xn.isdigit():  # sometimes SharePoint injects numeric IDs
            return 10
        if any(xn.startswith(h + " ") for h in HEADER_LABELS):  # e.g. "city zurich"
            return 10
        return 0

    best_window = None
    best_score = None

    for start in range(0, len(processed) - EXPECTED_COLS + 1):
        window = processed[start:start + EXPECTED_COLS]
        if any(norm_text(w) in HEADER_SET for w in window):
            continue

        score = sum(badness(w) for w in window)
        if best_score is None or score < best_score:
            best_score = score
            best_window = window

    return best_window


# -----------------------------
# Page navigation helpers
# -----------------------------
async def wait_for_table_refresh(page, prev_first_row_text, timeout_ms=20000):
    """
    After clicking Next, wait until the first row content changes.
    This is more reliable than a fixed sleep on SharePoint pages.
    """
    try:
        await page.wait_for_function(
            """(prev) => {
                const r = document.querySelector('table.ms-listviewtable tbody tr');
                return r && r.innerText && r.innerText !== prev;
            }""",
            arg=prev_first_row_text,
            timeout=timeout_ms
        )
    except:
        # Fallback: small sleep if SharePoint is slow
        await page.wait_for_timeout(1200)


async def robust_click(el):
    """
    SharePoint "Next" can be visible but outside viewport.
    Try:
    1) scroll into view
    2) force-click
    3) JS click as fallback
    """
    try:
        await el.scroll_into_view_if_needed()
    except:
        pass

    try:
        await el.click(force=True, timeout=20000)
        return True
    except:
        try:
            await el.evaluate("node => node.click()")
            return True
        except:
            return False


# -----------------------------
# Main scraper
# -----------------------------
async def scrape_all_pages():
    all_rows = []
    seen_rows = set()

    async with async_playwright() as p:
        # Launch headless browser
        browser = await p.chromium.launch(headless=True)

        # Create a browser context and block heavy resources for speed
        context = await browser.new_context()
        await context.route(
            "**/*",
            lambda route: route.abort()
            if route.request.resource_type in ("image", "font", "media", "stylesheet")
            else route.continue_()
        )

        page = await context.new_page()
        page.set_default_timeout(20000)

        # Open the first page (hash-based view)
        url = BASE_URL + HASH_PREFIX
        print(f"Loading first page: {url}")
        await page.goto(url, timeout=60000, wait_until="domcontentloaded")

        page_index = 0

        while True:
            page_index += 1
            print(f"\n--- Page {page_index} ---")

            # Ensure list view rows exist
            try:
                await page.wait_for_selector(TABLE_ROW_SELECTOR, timeout=20000)
            except:
                print("No list view table rows found → stopping")
                break

            # Extract all rows from the list view in ONE browser call (fast)
            raw_rows = await page.eval_on_selector_all(
                TABLE_ROW_SELECTOR,
                """trs => trs.map(tr =>
                    Array.from(tr.querySelectorAll('td')).map(td => td.innerText)
                )"""
            )

            print(f"Rows extracted from DOM this page: {len(raw_rows)}")

            new_rows_this_page = 0
            kept_after_parsing = 0

            # Parse each extracted row
            for raw_cells in raw_rows:
                cells = clean_and_align_tokens(raw_cells)
                if cells is None or len(cells) != EXPECTED_COLS:
                    continue

                # Safety: never allow header labels as data
                if any(norm_text(x) in HEADER_SET for x in cells):
                    continue

                kept_after_parsing += 1

                # Deduplicate across pages
                fp = row_fingerprint(cells)
                if fp in seen_rows:
                    continue

                seen_rows.add(fp)
                new_rows_this_page += 1

                all_rows.append({
                    "Country": cells[0],
                    "City": cells[1],
                    "Facility": cells[2],
                    "Manufacturer": cells[3],
                    "Model": cells[4],
                    "Proton energy (MeV)": cells[5],
                })

            print(f"Rows kept after parsing this page: {kept_after_parsing}")
            print(f"New unique rows added this page: {new_rows_this_page}")
            print(f"Total unique rows so far: {len(all_rows)}")

            # Capture first row text to detect refresh after clicking Next
            try:
                prev_first = await page.locator(TABLE_ROW_SELECTOR).first.inner_text()
            except:
                prev_first = ""

            # Find "Next" and click it
            next_el = page.locator(
                'a[title="Next"], a[aria-label="Next"], a[title="Next page"], a[aria-label="Next page"], a:has-text("Next")'
            ).first

            # Stop if no next page
            if await next_el.count() == 0 or not await next_el.is_visible() or not await next_el.is_enabled():
                print("No Next button → stopping")
                break

            ok = await robust_click(next_el)
            if not ok:
                print("Could not click Next → stopping")
                break

            # Wait for the table to refresh
            await wait_for_table_refresh(page, prev_first)

        await context.close()
        await browser.close()

    # Build final DataFrame, sort, and save
    df = pd.DataFrame(all_rows)
    df = df.sort_values(["Country", "City"], kind="mergesort").reset_index(drop=True)
    df.to_csv(OUTPUT_CSV, index=False)

    print("\nDONE")
    print(f"Saved {len(df)} unique cyclotron rows to:")
    print(OUTPUT_CSV)


# -----------------------------
# Run in Colab (top-level await is supported in Colab notebooks)
# -----------------------------
await scrape_all_pages()

Loading first page: https://nucleus.iaea.org/sites/accelerators/Pages/Cyclotron.aspx#InplviewHashd5afe566-18ad-4ac0-8aeb-ccf833dbc282=

--- Page 1 ---
Rows extracted from DOM this page: 30
Rows kept after parsing this page: 30
New unique rows added this page: 30
Total unique rows so far: 30

--- Page 2 ---
Rows extracted from DOM this page: 30
Rows kept after parsing this page: 30
New unique rows added this page: 29
Total unique rows so far: 59

--- Page 3 ---
Rows extracted from DOM this page: 30
Rows kept after parsing this page: 30
New unique rows added this page: 30
Total unique rows so far: 89

--- Page 4 ---
Rows extracted from DOM this page: 30
Rows kept after parsing this page: 29
New unique rows added this page: 24
Total unique rows so far: 113

--- Page 5 ---
Rows extracted from DOM this page: 30
Rows kept after parsing this page: 30
New unique rows added this page: 27
Total unique rows so far: 140

--- Page 6 ---
Rows extracted from DOM this page: 30
Rows kept after parsing 

Now the first part of the multimodal

Creation of the function: print_country_report

In [7]:
import pandas as pd
import re

# =========================
# 1) Load your CSV
# =========================
CSV_PATH = "/content/iaea_cyclotrons.csv"   # <- adjust if different
df = pd.read_csv(CSV_PATH)

# Normalize whitespace (helpful for grouping)
for col in ["Country","City","Facility","Manufacturer","Model","Proton energy (MeV)"]:
    if col in df.columns:
        df[col] = df[col].astype(str).str.strip()

# =========================
# 2) Helper: parse energy to numeric (best-effort)
#    Handles: "11", "16.5", "16-18", "16 / 18", etc.
# =========================
def parse_energy_to_float(x):
    if x is None:
        return None
    s = str(x).strip()
    if s == "" or s.lower() in ("nan", "none"):
        return None
    # find numbers in the string
    nums = re.findall(r"\d+(?:\.\d+)?", s)
    if not nums:
        return None
    # if range-like, take the max (or change to avg if you prefer)
    vals = [float(n) for n in nums]
    return max(vals)

df["Energy_num"] = df["Proton energy (MeV)"].apply(parse_energy_to_float)

# =========================
# 3) Country summary function
# =========================
def country_summary(country, top_n=15):
    """
    Return a structured summary for a country:
    - total cyclotrons
    - cities list + counts
    - facilities list + counts
    - manufacturer / model breakdown
    - energy stats
    """
    # case-insensitive match
    sub = df[df["Country"].str.lower() == str(country).strip().lower()].copy()
    if sub.empty:
        # fuzzy suggestion: show close matches by containment
        candidates = df[df["Country"].str.lower().str.contains(str(country).strip().lower(), na=False)]["Country"].unique()
        return {
            "country": country,
            "found": False,
            "message": f"No exact match for '{country}'.",
            "did_you_mean": sorted(candidates)[:20]
        }

    total = len(sub)

    cities = (sub.groupby("City")
                .size()
                .sort_values(ascending=False))

    facilities = (sub.groupby("Facility")
                    .size()
                    .sort_values(ascending=False))

    manufacturers = (sub.groupby("Manufacturer")
                       .size()
                       .sort_values(ascending=False))

    models = (sub.groupby("Model")
                .size()
                .sort_values(ascending=False))

    # manufacturer-model combo
    manu_model = (sub.groupby(["Manufacturer","Model"])
                    .size()
                    .sort_values(ascending=False))

    energy_stats = {
        "count_numeric": int(sub["Energy_num"].notna().sum()),
        "min": float(sub["Energy_num"].min()) if sub["Energy_num"].notna().any() else None,
        "median": float(sub["Energy_num"].median()) if sub["Energy_num"].notna().any() else None,
        "max": float(sub["Energy_num"].max()) if sub["Energy_num"].notna().any() else None,
    }

    return {
        "country": country,
        "found": True,
        "total_cyclotrons": total,
        "cities_top": cities.head(top_n),
        "facilities_top": facilities.head(top_n),
        "manufacturers": manufacturers,
        "models_top": models.head(top_n),
        "manufacturer_model_top": manu_model.head(top_n),
        "energy_stats": energy_stats,
        "all_cities_count": int(cities.shape[0]),
        "all_facilities_count": int(facilities.shape[0]),
    }

# =========================
# 4) Pretty-print function (human readable)
# =========================
def print_country_report(country, top_n=10):
    out = country_summary(country, top_n=top_n)
    if not out["found"]:
        print(out["message"])
        if out.get("did_you_mean"):
            print("Did you mean one of these?")
            for c in out["did_you_mean"]:
                print(" -", c)
        return

    print(f"=== {out['country']} ===")
    print(f"Total cyclotrons: {out['total_cyclotrons']}")
    print(f"Cities covered: {out['all_cities_count']}")
    print(f"Facilities covered: {out['all_facilities_count']}")
    print()

    es = out["energy_stats"]
    print("Energy (MeV) stats (numeric rows only):")
    print(f"  numeric entries: {es['count_numeric']}")
    print(f"  min / median / max: {es['min']} / {es['median']} / {es['max']}")
    print()

    print(f"Top {top_n} cities by count:")
    print(out["cities_top"].to_string())
    print()

    print(f"Top {top_n} facilities by count:")
    print(out["facilities_top"].to_string())
    print()

    print("Manufacturer counts:")
    print(out["manufacturers"].to_string())
    print()

    print(f"Top {top_n} models:")
    print(out["models_top"].to_string())
    print()

    print(f"Top {top_n} (Manufacturer, Model) pairs:")
    print(out["manufacturer_model_top"].to_string())
    print()

# =========================
# 5) Example usage
# =========================
print_country_report("Switzerland", top_n=10)

=== Switzerland ===
Total cyclotrons: 4
Cities covered: 4
Facilities covered: 4

Energy (MeV) stats (numeric rows only):
  numeric entries: 4
  min / median / max: 16.0 / 17.0 / 18.0

Top 10 cities by count:
City
Bern         1
Genève       1
Schlieren    1
Zürich       1

Top 10 facilities by count:
Facility
SWAN / Uni. Bern                                                             1
Universitätsspital Zueurich, Labor Schlieren, Klinik für Onkologie (Wagi)    1
Universitätsspital Zürich (USZ)                                              1
unspecified                                                                  1

Manufacturer counts:
Manufacturer
GE     2
IBA    2

Top 10 models:
Model
PETtrace            2
CYCLONE 18          1
CYCLONE 18/18 HC    1

Top 10 (Manufacturer, Model) pairs:
Manufacturer  Model           
GE            PETtrace            2
IBA           CYCLONE 18          1
              CYCLONE 18/18 HC    1



Ora altro

In [8]:
# Step 1: Set up OpenAI LLM
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(temperature=0.9, model="gpt-4.1-mini")


Convert the structured report into a prompt for the LLM

In [9]:
def country_report_for_llm(country, top_n=10):
    out = country_summary(country, top_n=top_n)

    if not out["found"]:
        return f"No data found for country: {country}"

    # Convert tables into readable text blocks
    cities_text = out["cities_top"].to_string()
    facilities_text = out["facilities_top"].to_string()
    manufacturers_text = out["manufacturers"].to_string()
    models_text = out["models_top"].to_string()

    es = out["energy_stats"]
    energy_text = (
        f"Numeric entries: {es['count_numeric']}\n"
        f"Min energy: {es['min']}\n"
        f"Median energy: {es['median']}\n"
        f"Max energy: {es['max']}"
    )

    # Final report string fed to the LLM
    report = f"""
IAEA Cyclotron Dataset Report for {out['country']}

Total cyclotrons: {out['total_cyclotrons']}
Cities covered: {out['all_cities_count']}
Facilities covered: {out['all_facilities_count']}

Top cities:
{cities_text}

Top facilities:
{facilities_text}

Manufacturers:
{manufacturers_text}

Top models:
{models_text}

Energy statistics:
{energy_text}
"""

    return report.strip()

This cell sends a structured country-level report to an LLM and asks it to generate a concise executive summary.
The model is prompted to focus on market-relevant aspects such as:
- Infrastructure scale
- Geographic concentration
- Major suppliers
- Notable patterns in the data
The output is rendered nicely using Markdown.


In [10]:
from IPython.display import Markdown, display

country = "Italy"

report_text = country_report_for_llm(country)

response = llm.invoke(
    f"""
You are a market analyst in nuclear medicine.
Write a concise executive summary (max 150 words) based on this dataset report.

Focus on:
- Overall infrastructure scale
- Geographic concentration
- Major suppliers
- Any notable patterns

Report:
{report_text}
"""
)

display(Markdown(response.content))

Italy’s cyclotron infrastructure comprises 40 units across 39 facilities in 30 cities, indicating broad but moderately concentrated nuclear medicine capacity. The highest cyclotron density is in Rome and Milan, each with 4 units, followed by Naples with 3, reflecting key urban hubs for radiopharmaceutical production. Major suppliers dominate the market, with GE leading (20 units), followed by IBA (13) and Siemens (4), suggesting a preference for well-established manufacturers. The PETtrace, CYCLONE 18, and MiniTrace models are the most prevalent, accounting for nearly three-quarters of installations, signaling a focus on proven, versatile systems. Energy output ranges from 10 to 19 MeV, with a median of 16 MeV, suitable for diverse isotope production. Notably, a small number of facilities operate multiple cyclotrons, such as Castelfranco Veneto Radiopharmacy with 2 units, underscoring targeted investment in select centers. This distribution supports a robust but regionally focused cyclotron network aligned with clinical and research demands.

Now we should make an analysis of all countries in the IAEA database and create a document with all executive summaries.

Install dependencies

In [11]:
!pip -q install reportlab geopy geopandas shapely pyproj fiona


Check the country name for consistency

In [12]:
!pip -q install country_converter

import os
import json
import country_converter as coco
import pandas as pd

# File paths
IN_CSV  = "/content/iaea_cyclotrons.csv"
OUT_CSV = "/content/iaea_cyclotrons_clean.csv"
CACHE   = "/content/country_fix_cache.json"

df = pd.read_csv(IN_CSV)

# Load or create cache for LLM fixes
if os.path.exists(CACHE):
    with open(CACHE, "r", encoding="utf-8") as f:
        fix_cache = json.load(f)
else:
    fix_cache = {}

# ✅ Allowed SHORT names
allowed = set(coco.CountryConverter().data["name_short"].dropna().astype(str))

def llm_fix_country(raw):
    raw = str(raw).strip()
    if raw in fix_cache:
        return fix_cache[raw]

    prompt = f"""
Map this value to the correct country name (common short form).
Return only the country name, nothing else.

Examples of desired style: Italy, Belarus, North Macedonia, Czechia, Russia.

Input: {raw}
""".strip()

    resp = llm.invoke(prompt)
    candidate = (resp.content or "").strip().strip('"').strip("'")

    if candidate in allowed:
        fixed = candidate
    else:
        fixed = coco.convert(names=candidate, to="name_short", not_found=raw)

    fix_cache[raw] = fixed
    with open(CACHE, "w", encoding="utf-8") as f:
        json.dump(fix_cache, f, ensure_ascii=False, indent=2)

    return fixed

# -----------------------
# Step A: deterministic conversion first (SHORT names)
# -----------------------
unique = df["Country"].dropna().unique()
short_names = coco.convert(names=list(unique), to="name_short", not_found=None)
mapping = dict(zip(unique, short_names))
df["Country_clean"] = df["Country"].map(mapping)

# -----------------------
# Step B: LLM fallback only for missing
# -----------------------
missing = df.loc[df["Country_clean"].isna(), "Country"].unique()
print("LLM resolving:", missing)

for val in missing:
    fixed = llm_fix_country(val)
    df.loc[df["Country"] == val, "Country_clean"] = fixed

# -----------------------
# Step C: overwrite original column
# -----------------------
df["Country"] = df["Country_clean"]
df = df.drop(columns=["Country_clean"])

# -----------------------
# Step C2: ADD ISO3 (use this for maps!)
# -----------------------
df["Country_iso3"] = coco.convert(
    names=df["Country"].fillna(""),
    to="ISO3",
    not_found=None
)

# Quick debug for your two cases
print("\nCheck Bahrain / Belarus:")
print(df[df["Country"].isin(["Bahrain", "Belarus"])][["Country", "Country_iso3"]].drop_duplicates())

# Show any countries still failing ISO3 conversion (these will fail mapping too)
bad = df.loc[df["Country"].notna() & df["Country_iso3"].isna(), "Country"].unique()
print("\nCountries with missing ISO3:", bad)

# -----------------------
# Step D: save cleaned file
# -----------------------
df.to_csv(OUT_CSV, index=False, encoding="utf-8")
print(f"\n✅ Cleaned file saved to: {OUT_CSV}")



LLM resolving: []





Check Bahrain / Belarus:
    Country Country_iso3
36  Bahrain          BHR
39  Belarus          BLR

Countries with missing ISO3: []

✅ Cleaned file saved to: /content/iaea_cyclotrons_clean.csv


In [13]:
print(sorted(df["Country"].unique()))

['Algeria', 'Argentina', 'Armenia', 'Australia', 'Austria', 'Azerbaijan', 'Bahrain', 'Bangladesh', 'Belarus', 'Belgium', 'Bolivia', 'Brazil', 'Brunei Darussalam', 'Bulgaria', 'Canada', 'Chile', 'China', 'Colombia', 'Costa Rica', 'Croatia', 'Cuba', 'Cyprus', 'Czechia', 'Denmark', 'Dominican Republic', 'Dubai', 'Ecuador', 'Egypt', 'Finland', 'France', 'Germany', 'Greece', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Iran', 'Iraq', 'Ireland', 'Israel', 'Italy', 'Jamaica', 'Japan', 'Jordan', 'Kazakhstan', 'Kenya', 'Kuwait', 'Latvia', 'Lebanon', 'Libya', 'Lithuania', 'Malaysia', 'Mexico', 'Morocco', 'Myanmar', 'Netherlands', 'New Zealand', 'Nigeria', 'North Macedonia', 'Northern Ireland', 'Norway', 'Oman', 'Pakistan', 'Panama', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Qatar', 'Romania', 'Russia', 'Saudi Arabia', 'Singapore', 'Slovakia', 'South Africa', 'South Korea', 'Spain', 'Sweden', 'Switzerland', 'Syria', 'Taiwan', 'Thailand', 'Tunisia', 'Türkiye', 'Ukraine', 'United

Manufacturer cleaner

In [14]:
# ============================================================
# Manufacturer cleaning (LLM-heavy) with:
# - Canon set (persisted)
# - LLM cache (persisted)
# - Manual truth rules (Siemens/CTI, PMB->Avelion, ABT->BCS, ACSI/ASCI->Advanced Cyclotron Systems)
# - Guardrails to prevent "everything maps to one canonical"
# - Single entry point: canonicalize_manufacturers(..., overwrite=True)
# ============================================================

import os, json, re
import pandas as pd

# Optional but recommended for guardrails similarity checks
!pip -q install rapidfuzz
from rapidfuzz import fuzz

# -----------------------
# Paths
# -----------------------
CANON_PATH = "/content/manufacturer_canon_set.json"
CACHE_PATH = "/content/manufacturer_llm_cache.json"

# -----------------------
# Seed canon (your known truths)
# -----------------------
SEED_CANON = {
    "Siemens Healthineers",        # acquired CTI
    "Avelion",                     # PMB -> Avelion (Alcen)
    "Best Cyclotron Systems",      # ABT -> BCS
    "Advanced Cyclotron Systems",  # ACSI/ASCI
    "Rosatom",                     # Rosatom
}

# -----------------------
# Load/init canon + cache
# -----------------------
if os.path.exists(CANON_PATH):
    with open(CANON_PATH, "r", encoding="utf-8") as f:
        canon_set = set(json.load(f))
else:
    canon_set = set(SEED_CANON)

if os.path.exists(CACHE_PATH):
    with open(CACHE_PATH, "r", encoding="utf-8") as f:
        llm_cache = json.load(f)
else:
    llm_cache = {}

def save_canon():
    with open(CANON_PATH, "w", encoding="utf-8") as f:
        json.dump(sorted(canon_set), f, ensure_ascii=False, indent=2)

def save_cache():
    with open(CACHE_PATH, "w", encoding="utf-8") as f:
        json.dump(llm_cache, f, ensure_ascii=False, indent=2)

# -----------------------
# Basic normalization
# -----------------------
def basic_cleanup(s: str) -> str:
    if s is None or (isinstance(s, float) and pd.isna(s)):
        return ""
    s = str(s).strip()
    s = re.sub(r"\s+", " ", s)
    s = re.sub(r"[,\.;:\-]+$", "", s).strip()
    return s

def norm_key(s: str) -> str:
    s = basic_cleanup(s).lower()
    s = re.sub(r"\(([^)]{1,20})\)", " ", s)   # remove short bracket chunks
    s = re.sub(r"[^\w\s]", " ", s)            # remove punctuation
    s = re.sub(r"\s+", " ", s).strip()
    return s

# -----------------------
# Manual truth rules (highest priority)
# -----------------------
def manual_map(raw: str) -> str | None:
    k = norm_key(raw)
    if not k:
        return None

    # Rossatom -> Rosatom (typo)
    if k in {"rossatom", "ross atom", "rosatom"} or "rossatom" in k:
        return "Rosatom"

    # ACSI/ASCI (typo) -> Advanced Cyclotron Systems
    if (
        k in {"acsi", "asci"}
        or k.startswith("acsi ")
        or k.startswith("asci ")
        or "advanced cyclotron systems" in k
    ):
        return "Advanced Cyclotron Systems"

    # Siemens / CTI
    if "siemens" in k or k == "cti" or k.startswith("cti "):
        return "Siemens Healthineers"

    # PMB -> Avelion
    if "pmb" in k:
        return "Avelion"

    # ABT -> Best Cyclotron Systems
    if (
        k == "abt"
        or k.startswith("abt ")
        or "advanced beam technologies" in k
        or "bcs" in k
        or "best cyclotron systems" in k
    ):
        return "Best Cyclotron Systems"

    return None

def looks_like_acsi(raw: str) -> bool:
    k = norm_key(raw)
    return (
        k in {"acsi", "asci"}
        or k.startswith("acsi ")
        or k.startswith("asci ")
        or "advanced cyclotron systems" in k
    )

# -----------------------
# LLM chooser with guardrails
# - LLM must return exact canon OR NEW:<name>
# - If LLM chooses a canon, we validate similarity; otherwise force NEW
# - Special-case: only allow mapping to "Advanced Cyclotron Systems" if it truly looks like ACSI/ASCI
# -----------------------
def llm_choose_canonical_or_new(raw: str, canon_list: list[str]) -> str:
    raw0 = basic_cleanup(raw)
    if not raw0:
        return raw0

    if raw0 in llm_cache:
        return llm_cache[raw0]

    canon_preview = canon_list if len(canon_list) <= 400 else canon_list[:400]

    prompt = f"""
You standardize manufacturer/company names.

Return EXACTLY one of:
1) An EXACT string from the canonical list below (character-for-character), OR
2) NEW:<canonical name> if it is NOT clearly the same company.

CRITICAL RULE:
- Do NOT guess. If you are not highly confident it's the same company, return NEW:<...>.
- Only choose an existing canonical when the input is the same company name, an obvious acronym, or a trivial typo.

Rules:
- Remove legal suffixes (Inc., GmbH, AG, SA, Co., Ltd., Ltd, LLC, BV, SRL, SpA, etc.)
- Resolve acronyms when appropriate
- Prefer the best-known short public brand name
- Output ONLY the chosen canonical OR NEW:<...>

Canonical list:
{chr(10).join("- " + c for c in canon_preview)}

Input: {raw0}
""".strip()

    resp = llm.invoke(prompt)
    ans = (resp.content or "").strip().strip('"').strip("'")

    # Parse NEW:
    if ans.startswith("NEW:"):
        canon = basic_cleanup(ans[4:].strip())
        llm_cache[raw0] = canon
        save_cache()
        return canon

    chosen = basic_cleanup(ans)

    # If LLM chose a canonical, validate.
    if chosen in canon_set:
        # Special case: don't let everything collapse to ACS
        if chosen == "Advanced Cyclotron Systems":
            if looks_like_acsi(raw0):
                llm_cache[raw0] = chosen
                save_cache()
                return chosen
            else:
                # reject this mapping; treat as NEW using cleaned raw
                new_name = basic_cleanup(raw0)
                llm_cache[raw0] = new_name
                save_cache()
                return new_name

        # Generic validation for other canonicals
        score = fuzz.token_sort_ratio(norm_key(raw0), norm_key(chosen))
        if score >= 90:
            llm_cache[raw0] = chosen
            save_cache()
            return chosen
        else:
            # too dissimilar -> force NEW
            new_name = basic_cleanup(raw0)
            llm_cache[raw0] = new_name
            save_cache()
            return new_name

    # If it returned something not in canon without NEW:, accept it as a new canonical candidate
    llm_cache[raw0] = chosen
    save_cache()
    return chosen

# -----------------------
# Main entry point
# -----------------------
def canonicalize_manufacturers(
    df: pd.DataFrame,
    col="Manufacturer",
    out_col="Manufacturer_clean",
    overwrite=False,
    grow_canon=True,
    keep_backup=True
):
    """
    overwrite=False -> creates out_col
    overwrite=True  -> overwrites col and drops out_col
    keep_backup=True -> saves original to f"{col}_raw" before overwrite
    """
    if keep_backup and overwrite:
        raw_col = f"{col}_raw"
        if raw_col not in df.columns:
            df[raw_col] = df[col]

    uniq = sorted(set(basic_cleanup(x) for x in df[col].dropna().astype(str).unique()))
    uniq = [u for u in uniq if u]

    mapping = {}

    # Ensure seeds are in canon_set
    canon_set.update(SEED_CANON)
    save_canon()

    print(f"Unique manufacturers: {len(uniq)}")
    print(f"Canon set size (start): {len(canon_set)}")

    for i, raw in enumerate(uniq, start=1):
        # Manual truth layer first
        m = manual_map(raw)
        if m:
            mapping[raw] = m
            canon_set.add(m)
            continue

        canon_list = sorted(canon_set)
        chosen = llm_choose_canonical_or_new(raw, canon_list)

        mapping[raw] = chosen
        if grow_canon and chosen:
            canon_set.add(chosen)

        if i % 50 == 0:
            print(f"  resolved {i}/{len(uniq)} (canon now {len(canon_set)})")

    save_canon()

    df[out_col] = df[col].map(
        lambda x: mapping.get(basic_cleanup(x), basic_cleanup(x)) if pd.notna(x) else None
    )

    if overwrite:
        df[col] = df[out_col]
        df.drop(columns=[out_col], inplace=True)

    print(f"Canon set size (end): {len(canon_set)}")
    return df, mapping

# -----------------------
# Optional: clear specific bad cached entries (run if needed)
# -----------------------
def clear_manufacturer_cache_keys(keys):
    removed = 0
    for k in keys:
        if k in llm_cache:
            del llm_cache[k]
            removed += 1
    save_cache()
    print(f"Cleared {removed} cached entries.")

# Example use if ACSI/ASCI were cached wrongly earlier:
# clear_manufacturer_cache_keys(["ACSI","ASCI","A.C.S.I.","acsi","asci"])

# -----------------------
# USAGE:
# df, mapping = canonicalize_manufacturers(df, overwrite=True)
# Then use df["Manufacturer"] (cleaned). Backup in df["Manufacturer_raw"] if enabled.
# -----------------------

Usage and test

In [15]:
df, mapping = canonicalize_manufacturers(df, overwrite=True, keep_backup=True)
print(df["Manufacturer"].value_counts().head(30))

Unique manufacturers: 29
Canon set size (start): 20
Canon set size (end): 20
Manufacturer
GE                                       396
IBA                                      237
Siemens Healthineers                     190
Sumitomo                                 161
Advanced Cyclotron Systems                46
Best Cyclotron Systems                    21
Longevous Beamtech                        14
TCC                                        9
Niiefa                                     6
Avelion                                    4
Scanditronix                               4
TRIUMF                                     1
Rosatom                                    1
Lawrence Berkeley National Laboratory      1
W M Brobeck                                1
Ionetix                                    1
Name: count, dtype: int64


Load data + basic cleaning

In [16]:
import os, re, hashlib, math
import pandas as pd

CSV_PATH = "/content/iaea_cyclotrons_clean.csv"  # <-- adjust if needed
df = pd.read_csv(CSV_PATH)

# Clean whitespace
for col in ["Country","City","Facility","Manufacturer","Model","Proton energy (MeV)"]:
    if col in df.columns:
        df[col] = df[col].astype(str).str.strip()

# Standardize "unspecified" to NaN for analysis
def to_nan_if_unspecified(x):
    if x is None:
        return None
    s = str(x).strip()
    if s == "" or s.lower() in ("nan","none","unspecified","n/a","na"):
        return None
    return s

for col in ["Facility","Manufacturer","Model","City","Country","Proton energy (MeV)"]:
    if col in df.columns:
        df[col] = df[col].apply(to_nan_if_unspecified)

# Parse energy numeric (best-effort)
def parse_energy_to_float(x):
    if x is None:
        return None
    s = str(x).strip()
    if s == "":
        return None
    nums = re.findall(r"\d+(?:\.\d+)?", s)
    if not nums:
        return None
    vals = [float(n) for n in nums]
    return max(vals)  # choose max if ranges exist

df["Energy_num"] = df["Proton energy (MeV)"].apply(parse_energy_to_float)

df.head()

Unnamed: 0,Country,City,Facility,Manufacturer,Model,Proton energy (MeV),Country_iso3,Energy_num
0,Algeria,Alger,Clinique Fatema Al Azhare,IBA,CYCLONE KIUBE,18.0,DZA,18.0
1,Algeria,Constantine,,IBA,CYCLONE KIUBE,18.0,DZA,18.0
2,Algeria,Tizi Ouzou,Hôpital Chahids Mahmoudi,GE,PETtrace,16.5,DZA,16.5
3,Argentina,Bariloche,Centro Nacional de Medicina Nuclear y Radioter...,IBA,CYCLONE 18,18.0,ARG,18.0
4,Argentina,Buenos Aires,Fleni,GE,PETtrace,16.0,ARG,16.0


Country report text generator for LLM

In [17]:
def country_summary(country, top_n=15):
    sub = df[df["Country"].str.lower() == str(country).strip().lower()].copy()
    if sub.empty:
        candidates = df[df["Country"].str.lower().str.contains(str(country).strip().lower(), na=False)]["Country"].dropna().unique()
        return {"country": country, "found": False, "did_you_mean": sorted(candidates)[:20]}

    total = len(sub)
    cities = sub.groupby("City").size().sort_values(ascending=False)
    facilities = sub.groupby("Facility").size().sort_values(ascending=False)
    manufacturers = sub.groupby("Manufacturer").size().sort_values(ascending=False)
    models = sub.groupby("Model").size().sort_values(ascending=False)
    manu_model = sub.groupby(["Manufacturer","Model"]).size().sort_values(ascending=False)

    energy_stats = {
        "count_numeric": int(sub["Energy_num"].notna().sum()),
        "min": float(sub["Energy_num"].min()) if sub["Energy_num"].notna().any() else None,
        "median": float(sub["Energy_num"].median()) if sub["Energy_num"].notna().any() else None,
        "max": float(sub["Energy_num"].max()) if sub["Energy_num"].notna().any() else None,
    }

    return {
        "country": country,
        "found": True,
        "total_cyclotrons": total,
        "cities_top": cities.head(top_n),
        "facilities_top": facilities.head(top_n),
        "manufacturers": manufacturers,
        "models_top": models.head(top_n),
        "manufacturer_model_top": manu_model.head(top_n),
        "energy_stats": energy_stats,
        "all_cities_count": int(cities.shape[0]),
        "all_facilities_count": int(facilities.shape[0]),
    }

def country_report_for_llm(country, top_n=10):
    out = country_summary(country, top_n=top_n)
    if not out["found"]:
        return f"No exact match for '{country}'. Suggestions: {out.get('did_you_mean', [])}"

    cities_text = out["cities_top"].to_string()
    facilities_text = out["facilities_top"].to_string()
    manufacturers_text = out["manufacturers"].to_string()
    models_text = out["models_top"].to_string()

    es = out["energy_stats"]
    energy_text = (
        f"Numeric entries: {es['count_numeric']}\n"
        f"Min energy: {es['min']}\n"
        f"Median energy: {es['median']}\n"
        f"Max energy: {es['max']}"
    )

    report = f"""
IAEA Cyclotron Dataset Report for {out['country']}

Total cyclotrons: {out['total_cyclotrons']}
Cities covered: {out['all_cities_count']}
Facilities covered: {out['all_facilities_count']}

Top cities:
{cities_text}

Top facilities:
{facilities_text}

Manufacturers:
{manufacturers_text}

Top models:
{models_text}

Energy statistics:
{energy_text}
"""
    return report.strip()

def llm_exec_summary(country, top_n=10):
    report_text = country_report_for_llm(country, top_n=top_n)
    prompt = f"""
You are a market analyst in nuclear medicine.
Write a concise executive summary (max 150 words) based on this dataset report.

Focus on:
- Overall infrastructure scale
- Geographic concentration
- Major suppliers
- Any notable patterns

Report:
{report_text}
""".strip()
    return llm.invoke(prompt).content.strip()

Geocoding + caching (for map dots)

In [18]:
from geopy.geocoders import Nominatim
from geopy.extra.rate_limiter import RateLimiter
import os
import pandas as pd

GEO_CACHE = "/content/geocode_cache.csv"

# Load existing cache if available
if os.path.exists(GEO_CACHE):
    geo_cache = pd.read_csv(GEO_CACHE)
else:
    geo_cache = pd.DataFrame(columns=["Country_iso3", "Country", "City", "lat", "lon", "display_name"])

# Ensure string columns exist and are strings
for col in ["Country_iso3", "Country", "City"]:
    if col not in geo_cache.columns:
        geo_cache[col] = ""
    geo_cache[col] = geo_cache[col].astype(str)

geolocator = Nominatim(user_agent="iaea_cyclotron_colab_geocoder")
geocode = RateLimiter(geolocator.geocode, min_delay_seconds=1.0, swallow_exceptions=True)

def _norm(x):
    return str(x).strip().lower()

def get_latlon(country, city, country_iso3=None):
    global geo_cache  # must come first

    country = str(country).strip()
    city = str(city).strip()
    country_iso3 = None if country_iso3 is None else str(country_iso3).strip()

    # -----------------------
    # Cache lookup (prefer ISO3 match if available)
    # -----------------------
    if country_iso3:
        hit = geo_cache[
            (geo_cache["Country_iso3"].apply(_norm) == _norm(country_iso3)) &
            (geo_cache["City"].apply(_norm) == _norm(city))
        ]
        if not hit.empty and pd.notna(hit.iloc[0]["lat"]) and pd.notna(hit.iloc[0]["lon"]):
            return float(hit.iloc[0]["lat"]), float(hit.iloc[0]["lon"])

    # fallback cache lookup by country name
    hit = geo_cache[
        (geo_cache["Country"].apply(_norm) == _norm(country)) &
        (geo_cache["City"].apply(_norm) == _norm(city))
    ]
    if not hit.empty and pd.notna(hit.iloc[0]["lat"]) and pd.notna(hit.iloc[0]["lon"]):
        return float(hit.iloc[0]["lat"]), float(hit.iloc[0]["lon"])

    # -----------------------
    # Best-effort geocode with fallback query strategies
    # -----------------------
    queries = []
    queries.append(f"{city}, {country}")                 # normal
    if country_iso3:
        queries.append(f"{city}, {country_iso3}")        # sometimes helps
    queries.append(city)                                  # last resort

    loc = None
    for q in queries:
        loc = geocode(q)
        if loc is not None:
            break

    if loc is None:
        # store miss to avoid retry storms
        new_row = {
            "Country_iso3": country_iso3 or "",
            "Country": country,
            "City": city,
            "lat": None,
            "lon": None,
            "display_name": None
        }
        geo_cache = pd.concat([geo_cache, pd.DataFrame([new_row])], ignore_index=True)
        geo_cache.to_csv(GEO_CACHE, index=False)
        return None, None

    lat, lon = loc.latitude, loc.longitude
    display_name = getattr(loc, "address", None)

    new_row = {
        "Country_iso3": country_iso3 or "",
        "Country": country,
        "City": city,
        "lat": lat,
        "lon": lon,
        "display_name": display_name
    }

    geo_cache = pd.concat([geo_cache, pd.DataFrame([new_row])], ignore_index=True)
    geo_cache.to_csv(GEO_CACHE, index=False)

    return lat, lon

def geocode_country_cities(country):
    # Use the cleaned Country string, but also pass ISO3 where possible
    sub = df[df["Country"].str.lower() == str(country).strip().lower()].copy()
    cities = sorted([c for c in sub["City"].dropna().unique() if str(c).strip()])

    # try to get ISO3 for this country from your df
    iso3 = None
    if "Country_iso3" in sub.columns:
        iso3_vals = sub["Country_iso3"].dropna().unique()
        iso3 = iso3_vals[0] if len(iso3_vals) else None

    pts = []
    for city in cities:
        lat, lon = get_latlon(country, city, country_iso3=iso3)
        if lat is not None and lon is not None:
            pts.append((city, lat, lon))
    return pts

Creates charts + maps

In [19]:
!pip -q install rasterio geopandas shapely

def save_country_map(country, year=2020):
    import geopandas as gpd
    import matplotlib.pyplot as plt
    import numpy as np
    import rasterio
    from rasterio.mask import mask
    from matplotlib.colors import LogNorm
    from shapely.geometry import mapping, Point

    # --- City points (your existing function) ---
    pts = geocode_country_cities(country)
    if not pts:
        return None

    # --- ISO3 for robust polygon match (assumes df has Country_iso3) ---
    sub = df[df["Country"].str.lower() == str(country).strip().lower()].copy()
    iso3 = None
    if "Country_iso3" in sub.columns:
        v = sub["Country_iso3"].dropna().unique()
        iso3 = v[0] if len(v) else None

    # --- Country polygon (Natural Earth) ---
    world = gpd.read_file(
        "https://naturalearth.s3.amazonaws.com/110m_cultural/ne_110m_admin_0_countries.zip"
    )
    if iso3 and "ISO_A3" in world.columns:
        country_poly = world[world["ISO_A3"] == iso3].copy()
    else:
        country_poly = world[world["NAME"].str.lower() == str(country).strip().lower()].copy()

    if country_poly.empty:
        # fallback: just plot points if polygon missing
        country_poly = None

    # --- GPW v4 density raster (2020) ---
    # NOTE: This is a public mirror that links to a .tiff. It can be large.
    # We read it as a remote dataset via /vsicurl/ (no full download required if server supports range requests).
    GPW_URL_2020 = "https://pacific-data.sprep.org/system/files/Global_2020_PopulationDensity30sec_GPWv4.tiff"
    gpw_url = GPW_URL_2020

    # Output
    path = os.path.join(IMG_DIR, f"{country}_map.png")

    # Points GeoDF
    gdf_pts = gpd.GeoDataFrame(
        [{"City": c, "geometry": Point(lon, lat)} for c, lat, lon in pts],
        crs="EPSG:4326"
    )

    plt.figure(figsize=(7, 4.5))
    ax = plt.gca()

    # ---- Open raster remotely ----
    # /vsicurl/ enables HTTP range reads (fast) if the host supports it.
    # If it fails, we fall back to downloading once (see below).
    raster_path = f"/vsicurl/{gpw_url}"

    try:
        with rasterio.Env(GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR"):
            with rasterio.open(raster_path) as src:
                # Reproject polygon to raster CRS (usually EPSG:4326)
                if country_poly is not None:
                    poly_in_crs = country_poly.to_crs(src.crs)
                    geoms = [mapping(geom) for geom in poly_in_crs.geometry if geom is not None]
                else:
                    geoms = None

                if geoms:
                    out_img, out_transform = mask(src, geoms, crop=True, filled=True)
                    data = out_img[0]
                    extent = (
                        out_transform[2],
                        out_transform[2] + out_transform[0] * data.shape[1],
                        out_transform[5] + out_transform[4] * data.shape[0],
                        out_transform[5],
                    )
                else:
                    # No polygon: plot a small window around points
                    # compute bounds in raster CRS
                    pts_in = gdf_pts.to_crs(src.crs)
                    minx, miny, maxx, maxy = pts_in.total_bounds
                    pad_x = (maxx - minx) * 0.25 if maxx > minx else 2e5
                    pad_y = (maxy - miny) * 0.25 if maxy > miny else 2e5
                    window = rasterio.windows.from_bounds(minx - pad_x, miny - pad_y, maxx + pad_x, maxy + pad_y, src.transform)
                    out_img = src.read(1, window=window, masked=True)
                    data = out_img
                    win_transform = src.window_transform(window)
                    extent = (
                        win_transform[2],
                        win_transform[2] + win_transform[0] * data.shape[1],
                        win_transform[5] + win_transform[4] * data.shape[0],
                        win_transform[5],
                    )

    except Exception as e:
        print("⚠️ Remote raster read failed, consider local download. Error:", e)
        return None

    # --- Plot raster with log scale (density is very skewed) ---
    arr = np.array(data)
    arr = np.ma.masked_invalid(arr)
    arr = np.ma.masked_where(arr <= 0, arr)

    if arr.count() == 0:
        print("⚠️ No valid density values found in clipped raster.")
        return None

    vmin = float(np.percentile(arr.compressed(), 5))
    vmax = float(np.percentile(arr.compressed(), 99))
    vmin = max(vmin, 1e-3)

    im = ax.imshow(arr, extent=extent, origin="upper", norm=LogNorm(vmin=vmin, vmax=vmax))
    plt.colorbar(im, ax=ax, shrink=0.75, label="Population density (persons/km², log scale)")

    # --- Plot country outline on top (optional) ---
    if country_poly is not None:
        # outline only so raster is visible
        country_poly.plot(ax=ax, facecolor="none", edgecolor="black", linewidth=0.8, zorder=4)

    # --- Red dots on top ---
    gdf_pts.plot(
        ax=ax,
        color="red",
        markersize=60,
        edgecolor="black",
        linewidth=0.4,
        zorder=6
    )

    plt.title(f"{country}: Cyclotron locations + GPWv4 population density (2020)")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(path, dpi=220)
    plt.close()
    return path

Comparison + data quality sections (tables)

In [20]:
def global_comparison_tables():
    # Top countries by cyclotron count
    top_countries = df["Country"].dropna().value_counts().head(25)

    # Manufacturer global counts
    top_manu = df["Manufacturer"].dropna().value_counts().head(15)

    # Energy by country (numeric only)
    energy_country = (
        df.dropna(subset=["Country"])
          .groupby("Country")["Energy_num"]
          .agg(["count","min","median","max"])
          .sort_values("count", ascending=False)
          .head(25)
    )

    return top_countries, top_manu, energy_country

def data_quality_summary():
    total = len(df)
    missing = {}
    for col in ["City","Facility","Manufacturer","Model","Proton energy (MeV)","Energy_num"]:
        missing[col] = int(df[col].isna().sum())

    return {
        "total_rows": total,
        "missing_counts": missing,
        "missing_pct": {k: (v/total*100 if total else 0) for k,v in missing.items()}
    }

Multimodal nested call to comment on each map

In [21]:
# OPTIONAL: only if you have a multimodal model client that can accept images
# You can keep it disabled by default.

USE_MAP_LLM = False  # set True only if you have an image-capable model client

def llm_map_insight(country, map_path):
    if not USE_MAP_LLM:
        return None

    # Example placeholder: adapt to your multimodal client API
    # Many multimodal APIs require: [{"type":"text",...}, {"type":"image_url",...}]
    # Here we leave it as a stub so you can plug your provider.
    raise NotImplementedError("Plug in your multimodal LLM client here.")

Build the PDF (ReportLab)

In [30]:
# ============================================================
# IAEA Cyclotron PDF Report — FULL CODE
# - Precomputes LLM summaries + images + facility landscape (LLM JSON -> validated -> markdown)
# - Builds PDF from precomputed artifacts
# - Provides run_test(country=...) and run_full()
# ============================================================

import os
import json
import hashlib
import pandas as pd
import matplotlib.pyplot as plt

from reportlab.platypus import (
    SimpleDocTemplate, Paragraph, Spacer, Image, PageBreak, Table, TableStyle
)
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors
from xml.sax.saxutils import escape as xml_escape
from PIL import Image as PILImage

# -----------------------
# CONFIG
# -----------------------
PDF_PATH_FULL = "/content/IAEA_Cyclotron_Country_Executive_Summaries.pdf"
PDF_PATH_TEST = "/content/IAEA_Cyclotron_TEST_ONE_COUNTRY.pdf"
IMG_DIR = "/content/iaea_report_imgs"
os.makedirs(IMG_DIR, exist_ok=True)

styles = getSampleStyleSheet()


# -----------------------
# Helpers
# -----------------------
def escape_paragraph_text(s: str) -> str:
    """Escape &, <, > for ReportLab Paragraph mini-HTML + preserve newlines."""
    if s is None:
        return ""
    s = xml_escape(str(s))
    return s.replace("\n", "<br/>")


def df_to_table(dataframe, max_rows=30):
    df2 = dataframe.copy()
    if len(df2) > max_rows:
        df2 = df2.head(max_rows)

    data = [list(df2.reset_index().columns)] + df2.reset_index().values.tolist()
    t = Table(data, repeatRows=1)
    t.setStyle(TableStyle([
        ("BACKGROUND", (0,0), (-1,0), colors.lightgrey),
        ("GRID", (0,0), (-1,-1), 0.25, colors.grey),
        ("FONTNAME", (0,0), (-1,0), "Helvetica-Bold"),
        ("FONTSIZE", (0,0), (-1,-1), 8),
        ("VALIGN", (0,0), (-1,-1), "TOP"),
        ("LEFTPADDING", (0,0), (-1,-1), 4),
        ("RIGHTPADDING", (0,0), (-1,-1), 4),
        ("TOPPADDING", (0,0), (-1,-1), 3),
        ("BOTTOMPADDING", (0,0), (-1,-1), 3),
    ]))
    return t

def two_column_block(left_title, left_flowable, right_title, right_flowable, gap=12):
    """
    Create a 2-column block (ReportLab Table) with a small title + content in each column.
    left_flowable / right_flowable are usually Tables returned by df_to_table(...)
    """
    left = [Paragraph(left_title, styles["Heading2"]), Spacer(1, 6), left_flowable]
    right = [Paragraph(right_title, styles["Heading2"]), Spacer(1, 6), right_flowable]

    t = Table(
        [[left, right]],
        colWidths=[(A4[0] - 72 - gap) / 2, (A4[0] - 72 - gap) / 2],  # A4 width minus margins (36+36)
    )
    t.setStyle(TableStyle([
        ("VALIGN", (0,0), (-1,-1), "TOP"),
        ("LEFTPADDING", (0,0), (-1,-1), 0),
        ("RIGHTPADDING", (0,0), (-1,-1), 0),
        ("TOPPADDING", (0,0), (-1,-1), 0),
        ("BOTTOMPADDING", (0,0), (-1,-1), 0),
    ]))
    return t

def rl_image_fit(img_path, max_w=480, max_h=300):
    """Fit image into max_w x max_h (points) preserving aspect ratio."""
    with PILImage.open(img_path) as im:
        w, h = im.size
    scale = min(max_w / w, max_h / h)
    return Image(img_path, width=w * scale, height=h * scale)


def plots_2x2_block(imgs, cell_w=255, cell_h=185, pad=6):
    """
    imgs: dict keys: 'city','manufacturer','energy','map'
    Returns a ReportLab Table laid out 2x2.
    """
    def cell(title, path):
        if path and os.path.exists(path):
            return [
                Paragraph(title, styles["Heading4"]),
                rl_image_fit(path, max_w=cell_w, max_h=cell_h)
            ]
        else:
            return [Paragraph(title + " (not available)", styles["Normal"])]

    data = [
        [cell("Top cities", imgs.get("city")), cell("Manufacturers", imgs.get("manufacturer"))],
        [cell("Energy distribution", imgs.get("energy")), cell("Map (red dots)", imgs.get("map"))],
    ]

    t = Table(
        data,
        colWidths=[(cell_w + pad), (cell_w + pad)],
        rowHeights=[cell_h + 30, cell_h + 30],
    )
    t.setStyle(TableStyle([
        ("VALIGN", (0,0), (-1,-1), "TOP"),
        ("LEFTPADDING", (0,0), (-1,-1), pad),
        ("RIGHTPADDING", (0,0), (-1,-1), pad),
        ("TOPPADDING", (0,0), (-1,-1), pad),
        ("BOTTOMPADDING", (0,0), (-1,-1), pad),
    ]))
    return t


# ============================================================
# Facility Landscape (LLM) — Cache + Helpers
# ============================================================

FAC_LANDSCAPE_CACHE = "/content/facility_landscape_cache.json"

# Load cache once
if os.path.exists(FAC_LANDSCAPE_CACHE):
    with open(FAC_LANDSCAPE_CACHE, "r", encoding="utf-8") as f:
        facility_landscape_cache = json.load(f)
else:
    facility_landscape_cache = {}

def _save_facility_landscape_cache():
    with open(FAC_LANDSCAPE_CACHE, "w", encoding="utf-8") as f:
        json.dump(facility_landscape_cache, f, ensure_ascii=False, indent=2)

def _cache_key(country: str, top_n: int, version: str = "v1") -> str:
    raw = f"{version}|{country.strip().lower()}|top_n={top_n}"
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()

def top_facilities_table_df(country: str, top_n: int = 12) -> pd.DataFrame:
    """
    Deterministic table: top facilities by row count for that country.
    If Facility is mostly missing, falls back to top cities.
    """
    sub = df[df["Country"].str.lower() == str(country).strip().lower()].copy()
    if sub.empty:
        return pd.DataFrame(columns=["Facility_or_City", "Count", "Note"])

    fac = (
        sub["Facility"].fillna("").astype(str).str.strip()
        .replace("", pd.NA)
        .dropna()
        .value_counts()
        .head(top_n)
    )

    if not fac.empty:
        out = fac.reset_index()
        out.columns = ["Facility_or_City", "Count"]
        out["Note"] = "Facility"
        return out

    city = sub["City"].dropna().astype(str).str.strip().replace("", pd.NA).dropna().value_counts().head(top_n)
    out = city.reset_index()
    out.columns = ["Facility_or_City", "Count"]
    out["Note"] = "City proxy (Facility missing)"
    return out


def llm_facility_landscape_json(country: str, top_n: int = 12) -> dict:
    """
    Verbose facility landscape.
    Strict schema:
    {
      "overview": "...(multi-paragraph)...",
      "notes": "...",
      "facilities": [
        {
          "facility": "<exact name from input list>",
          "count": <int>,
          "interpretation": "<2–5 sentences>",
          "signals": ["...","..."],
          "typology": "<Hospital/University | Private radiopharmacy | Network/group | State/public institute | Unknown>",
          "confidence": "<High|Medium|Low>",
          "network_hint": "<brand if explicitly in name else —>",
          "caveat": "<1 sentence>"
        }, ...
      ]
    }
    """
    ck = _cache_key(country, top_n, version="v3_verbose")
    if ck in facility_landscape_cache:
        return facility_landscape_cache[ck]

    tdf = top_facilities_table_df(country, top_n=top_n)

    items = []
    for _, r in tdf.iterrows():
        items.append({"name": str(r["Facility_or_City"]), "count": int(r["Count"]), "note": str(r["Note"])})

    prompt = f"""
You are writing a verbose country report section about "key cyclotron facilities".

Country: {country}

Input list (top entries from database; do not add others):
{json.dumps(items, ensure_ascii=False, indent=2)}

CRITICAL RULES (non-hallucinating):
- You MUST NOT invent ownership facts.
- You MUST NOT claim a network/brand unless the facility name explicitly contains it (e.g., contains “Curium”).
- Use only NAME CUES (words in the facility string) and the input “note” field.
- If the entry is a city proxy (note contains "City proxy"), treat it as missing facility info.

TASK:
Return STRICT JSON ONLY (no markdown, no extra text) with this schema:

{{
  "overview": "Write 2–4 paragraphs (verbose) describing the facility landscape in this country. Mention data-quality caveats (missing facility names) if relevant. Explain what the naming patterns suggest about institution types (hospital/university vs private vs public), WITHOUT stating ownership as fact.",
  "notes": "Optional short notes (or empty string).",
  "facilities": [
    {{
      "facility": "<exact name from input list>",
      "count": <integer from input list>,
      "interpretation": "2–5 sentences describing what the name suggests (departmental hospital, university clinic, private company, institute etc.)",
      "signals": ["List 2–6 short name cues you used, e.g. 'University', 'Hospital', 'Ltd', 'Institute', 'Nuclear Medicine', etc. If none, use []"],
      "typology": "<ONE of: Hospital/University | Private radiopharmacy | Network/group | State/public institute | Unknown>",
      "confidence": "<High|Medium|Low>",
      "network_hint": "<If name explicitly contains a network brand (Curium/Siemens/AAA etc.) write it else '—'>",
      "caveat": "1 sentence caveat, e.g., 'Based on name cues only; ownership not verified.'"
    }}
  ]
}}

Output constraints:
- Every facility must be one of the input list names.
- For City proxy entries: typology=Unknown, confidence=Low, network_hint='—', signals=[], interpretation should say facility is missing.
""".strip()

    resp = llm.invoke(prompt)
    raw = (resp.content or "").strip()
    raw = raw.strip().strip("`")
    if raw.lower().startswith("json"):
        raw = raw[4:].strip()

    try:
        obj = json.loads(raw)
    except Exception as e:
        obj = {"overview": f"(Verbose facility landscape JSON parse failed: {e})", "notes": "", "facilities": []}

    allowed_typology = {
        "Hospital/University",
        "Private radiopharmacy",
        "Network/group",
        "State/public institute",
        "Unknown",
    }
    allowed_conf = {"High", "Medium", "Low"}

    allowed_names = {x["name"]: x for x in items}

    cleaned = []
    for r in obj.get("facilities", []):
        fac = str(r.get("facility", "")).strip()
        if fac not in allowed_names:
            continue

        count = int(allowed_names[fac]["count"])
        note = str(allowed_names[fac]["note"])

        typ = str(r.get("typology", "Unknown")).strip()
        conf = str(r.get("confidence", "Low")).strip()
        net = str(r.get("network_hint", "—")).strip() or "—"
        interp = str(r.get("interpretation", "")).strip()
        caveat = str(r.get("caveat", "")).strip()
        signals = r.get("signals", [])
        if not isinstance(signals, list):
            signals = []

        # enforce city-proxy rules
        if "City proxy" in note:
            typ, conf, net = "Unknown", "Low", "—"
            signals = []
            if not interp:
                interp = "Facility name is missing in the dataset; this entry is a city-level proxy."
            if not caveat:
                caveat = "Facility not available; classification not possible."

        if typ not in allowed_typology:
            typ = "Unknown"
        if conf not in allowed_conf:
            conf = "Low"
        if not interp:
            interp = "No strong name cues available; leaving classification as Unknown."
        if not caveat:
            caveat = "Based on name cues only; ownership not verified."

        cleaned.append({
            "facility": fac,
            "count": count,
            "interpretation": interp,
            "signals": [str(s).strip() for s in signals if str(s).strip()][:8],
            "typology": typ,
            "confidence": conf,
            "network_hint": net,
            "caveat": caveat
        })

    out = {
        "overview": str(obj.get("overview", "")).strip() or "—",
        "notes": str(obj.get("notes", "")).strip(),
        "facilities": cleaned
    }

    facility_landscape_cache[ck] = out
    _save_facility_landscape_cache()
    return out


def facility_landscape_to_markdown(obj: dict) -> str:
    """
    Render verbose facility landscape into readable markdown-like text
    (ReportLab Paragraph will display it after escape_paragraph_text).
    """
    overview = obj.get("overview", "—").strip()
    notes = obj.get("notes", "").strip()
    facilities = obj.get("facilities", [])

    lines = []
    lines.append(overview)

    if notes:
        lines.append("")
        lines.append(f"Notes: {notes}")

    lines.append("")
    lines.append("Key facilities (name-cue based interpretation):")

    for f in facilities:
        sig = ", ".join(f.get("signals", [])) if f.get("signals") else "—"
        lines.append("")
        lines.append(f"• {f['facility']} — {int(f['count'])} entries")
        lines.append(f"  Typology: {f['typology']} | Confidence: {f['confidence']} | Network hint: {f['network_hint']}")
        lines.append(f"  Signals: {sig}")
        lines.append(f"  Interpretation: {f['interpretation']}")
        lines.append(f"  Caveat: {f['caveat']}")

    return "\n".join(lines)


def precompute_facility_landscapes(countries, top_n=12):
    """
    Precompute facility landscape markdown per country.
    Uses JSON cache internally -> cheap on reruns.
    """
    out = {}
    for country in countries:
        try:
            obj = llm_facility_landscape_json(country, top_n=top_n)
            out[country] = facility_landscape_to_markdown(obj)
        except Exception as e:
            out[country] = f"(Facility landscape failed: {e})"
    print(f"Precomputed facility landscapes: {len(out)} countries")
    return out


# ============================================================
# Core pipeline functions
# (These assume you already have:
#   - df (loaded)
#   - llm_exec_summary(country, top_n=10)
#   - country_summary(country, top_n=10)
#   - global_comparison_tables()
#   - data_quality_summary()
#   - save_country_map(country)
# ============================================================

def save_city_bar_chart(country, top_n=10):
    out = country_summary(country, top_n=top_n)
    if not out["found"] or out["cities_top"].empty:
        return None
    s = out["cities_top"]
    path = os.path.join(IMG_DIR, f"{country}_cities.png")
    plt.figure()
    s[::-1].plot(kind="barh")
    plt.title(f"{country}: Top {min(top_n, len(s))} cities by cyclotron count")
    plt.xlabel("Count")
    plt.tight_layout()
    plt.savefig(path, dpi=180)
    plt.close()
    return path


def save_manufacturer_bar_chart(country, top_n=10):
    sub = df[df["Country"].str.lower() == str(country).strip().lower()].copy()
    s = sub["Manufacturer"].dropna().value_counts().head(top_n)
    if s.empty:
        return None
    path = os.path.join(IMG_DIR, f"{country}_manufacturers.png")
    plt.figure()
    s[::-1].plot(kind="barh")
    plt.title(f"{country}: Top manufacturers")
    plt.xlabel("Count")
    plt.tight_layout()
    plt.savefig(path, dpi=180)
    plt.close()
    return path


def save_energy_hist(country):
    sub = df[df["Country"].str.lower() == str(country).strip().lower()].copy()
    vals = sub.get("Energy_num", pd.Series(dtype=float)).dropna()
    if vals.empty:
        return None
    path = os.path.join(IMG_DIR, f"{country}_energy.png")
    plt.figure()
    plt.hist(vals, bins=12)
    plt.title(f"{country}: Proton energy distribution (MeV)")
    plt.xlabel("MeV")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(path, dpi=180)
    plt.close()
    return path


def precompute_summaries(countries, top_n=10):
    summaries_by_country = {}
    for country in countries:
        try:
            summaries_by_country[country] = llm_exec_summary(country, top_n=top_n)
        except Exception as e:
            summaries_by_country[country] = f"(LLM summary failed: {e})"
    print(f"Precomputed LLM summaries: {len(summaries_by_country)} countries")
    return summaries_by_country


def precompute_images(countries, top_n=10):
    images_by_country = {}
    for country in countries:
        imgs = {}
        try: imgs["city"] = save_city_bar_chart(country, top_n=top_n)
        except Exception: imgs["city"] = None
        try: imgs["manufacturer"] = save_manufacturer_bar_chart(country, top_n=top_n)
        except Exception: imgs["manufacturer"] = None
        try: imgs["energy"] = save_energy_hist(country)
        except Exception: imgs["energy"] = None
        try: imgs["map"] = save_country_map(country)
        except Exception: imgs["map"] = None
        images_by_country[country] = imgs
    print(f"Precomputed images for: {len(images_by_country)} countries")
    return images_by_country


def build_pdf(countries, summaries_by_country, images_by_country, facility_landscape_by_country, pdf_path, title_suffix=""):
    doc = SimpleDocTemplate(
        pdf_path,
        pagesize=A4,
        rightMargin=36,
        leftMargin=36,
        topMargin=36,
        bottomMargin=36
    )
    story = []

    title = "IAEA Cyclotron Database — Executive Summaries by Country"
    if title_suffix:
        title += f" ({title_suffix})"
    story.append(Paragraph(title, styles["Title"]))
    story.append(Spacer(1, 10))
    story.append(Paragraph("Automatically generated from the scraped IAEA SharePoint cyclotron list.", styles["Normal"]))
    story.append(PageBreak())

    story.append(Paragraph("Global Comparison", styles["Heading1"]))
    top_countries, top_manu, energy_country = global_comparison_tables()

    # --- Global Comparison (2 columns: countries left, manufacturers right)
    left_tbl  = df_to_table(top_countries.to_frame("count"))
    right_tbl = df_to_table(top_manu.to_frame("count"))

    story.append(two_column_block(
        left_title="Top countries by cyclotron count",
        left_flowable=left_tbl,
        right_title="Top manufacturers (global)",
        right_flowable=right_tbl,
        gap=18
    ))
    story.append(Spacer(1, 14))

    # Keep energy table below (same page if it fits; otherwise it will spill naturally)
    story.append(Paragraph("Energy statistics by country (numeric rows only)", styles["Heading2"]))
    story.append(df_to_table(energy_country))
    story.append(PageBreak())


    dq = data_quality_summary()
    story.append(Paragraph("Data Quality Summary", styles["Heading1"]))
    story.append(Paragraph(f"Total rows: <b>{dq['total_rows']}</b>", styles["Normal"]))
    story.append(Spacer(1, 8))

    dq_table_df = pd.DataFrame({
        "missing_count": dq["missing_counts"],
        "missing_pct": {k: f"{v:.1f}%" for k, v in dq["missing_pct"].items()}
    })
    story.append(df_to_table(dq_table_df))
    story.append(PageBreak())

    story.append(Paragraph("Country Executive Summaries", styles["Heading1"]))
    story.append(Paragraph("Each country section includes an LLM-generated summary plus charts and a map.", styles["Normal"]))
    story.append(PageBreak())

    for idx, country in enumerate(countries, start=1):
        story.append(Paragraph(escape_paragraph_text(country), styles["Heading1"]))
        story.append(Spacer(1, 6))

        summary = summaries_by_country.get(country, "")
        story.append(Paragraph(escape_paragraph_text(summary), styles["Normal"]))
        story.append(Spacer(1, 10))

        imgs = images_by_country.get(country, {})

        # Page 1: plots
        story.append(plots_2x2_block(imgs, cell_w=255, cell_h=185, pad=6))
        story.append(Spacer(1, 10))

        # Page 2: Facility landscape
        story.append(PageBreak())
        story.append(Paragraph("Facility Landscape and Ownership Typology", styles["Heading2"]))
        story.append(Spacer(1, 6))

        tdf = top_facilities_table_df(country, top_n=12)
        story.append(Paragraph("Top facilities (from database)", styles["Heading3"]))
        story.append(df_to_table(tdf.set_index("Facility_or_City")))
        story.append(Spacer(1, 10))

        story.append(Paragraph("LLM-assisted classification (name-cue based, non-hallucinating)", styles["Heading3"]))
        fac_text = facility_landscape_by_country.get(country, "—")
        story.append(Paragraph(escape_paragraph_text(fac_text), styles["Normal"]))
        story.append(Spacer(1, 10))

        if idx != len(countries):
            story.append(PageBreak())

    doc.build(story)
    print(f"✅ PDF written: {pdf_path}")
    return pdf_path


# ============================================================
# RUNNERS — choose what to run each time
# ============================================================

def run_test(country="Belarus", top_n=10):
    all_countries = sorted(df["Country"].dropna().unique())
    if country not in all_countries:
        print(f"⚠️ '{country}' not found in df['Country']. Using first available country instead.")
        country = all_countries[0] if all_countries else country

    countries = [country]
    print("TEST MODE countries:", countries)

    facility_landscape_by_country = precompute_facility_landscapes(countries, top_n=12)
    summaries_by_country = precompute_summaries(countries, top_n=top_n)
    images_by_country = precompute_images(countries, top_n=top_n)

    return build_pdf(
        countries=countries,
        summaries_by_country=summaries_by_country,
        images_by_country=images_by_country,
        facility_landscape_by_country=facility_landscape_by_country,
        pdf_path=PDF_PATH_TEST,
        title_suffix=f"TEST: {country}"
    )


def run_full(top_n=10):
    countries = sorted(df["Country"].dropna().unique())
    print(f"FULL MODE countries: {len(countries)}")

    facility_landscape_by_country = precompute_facility_landscapes(countries, top_n=12)
    summaries_by_country = precompute_summaries(countries, top_n=top_n)
    images_by_country = precompute_images(countries, top_n=top_n)

    return build_pdf(
        countries=countries,
        summaries_by_country=summaries_by_country,
        images_by_country=images_by_country,
        facility_landscape_by_country=facility_landscape_by_country,
        pdf_path=PDF_PATH_FULL,
        title_suffix="FULL"
    )

Run one

In [31]:
# run_test("Bahrain") or Belarus
run_test("Italy")
# run_full()

TEST MODE countries: ['Italy']
Precomputed facility landscapes: 1 countries
Precomputed LLM summaries: 1 countries
Precomputed images for: 1 countries
✅ PDF written: /content/IAEA_Cyclotron_TEST_ONE_COUNTRY.pdf


'/content/IAEA_Cyclotron_TEST_ONE_COUNTRY.pdf'