In [1]:
from geopy.geocoders import Nominatim
from geopy.extra.rate_limiter import RateLimiter
import hashlib, json, os
from IPython.display import HTML
from collections import defaultdict
import requests, time
import folium, json, hashlib
from folium.plugins import MarkerCluster
import datetime
from trials_to_graph import convert_and_render

[NbConvertApp] Converting notebook trials_to_graph.ipynb to script
[NbConvertApp] Writing 22171 bytes to trials_to_graph.py


In [2]:
# Analysis date as an ISO-8601 calendar date (YYYY-MM-DD).
run_date = datetime.date.today().isoformat()

In [3]:
_GEOCODER = None
_CACHE_PATH = os.path.expanduser("~/.ctgov_geocode_cache.json")
try:
    with open(_CACHE_PATH, "r") as f:
        _GEO_CACHE = json.load(f)
except:
    _GEO_CACHE = {}

def _save_cache():
    """Persist ``_GEO_CACHE`` as JSON to ``_CACHE_PATH`` (best effort).

    The cache is first written to a temporary sibling file and then swapped in
    with ``os.replace``, so an interrupted run cannot leave a truncated,
    unparseable cache behind. Failures are reported as a warning rather than
    silently swallowed, and they never abort the caller.
    """
    import warnings

    tmp_path = _CACHE_PATH + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(_GEO_CACHE, f)
        os.replace(tmp_path, _CACHE_PATH)  # atomic swap, also overwrites on Windows
    except (OSError, TypeError, ValueError) as exc:
        warnings.warn(f"Could not save geocode cache to {_CACHE_PATH!r}: {exc}")

def _geocoder():
    """Return the shared Nominatim geocoder, creating it on first use.

    The returned object has an extra ``rate_limiter`` attribute: a geopy
    ``RateLimiter`` wrapping ``geocode`` that waits at least 1 s between calls.

    ``swallow_exceptions=False`` makes the limiter re-raise once its retries are
    exhausted. The default (True) returns ``None`` instead, which ``geocode_addr``
    cannot distinguish from "address not found", so a transient network error
    would be cached permanently as a miss.
    """
    global _GEOCODER
    if _GEOCODER is None:
        # Nominatim requires a unique user_agent; use something identifiable.
        # geopy's default timeout is 1 s, which is tight for Nominatim.
        _GEOCODER = Nominatim(user_agent="ctgov_mapper_tim_robinson/1.0", timeout=10)
        # be gentle: 1 call per second
        _GEOCODER.rate_limiter = RateLimiter(
            _GEOCODER.geocode, min_delay_seconds=1, swallow_exceptions=False
        )
    return _GEOCODER

def geocode_addr(addr_dict):
    """
    Geocode one site address with Nominatim, using the on-disk cache.

    addr_dict: {"facility": "...", "city":"...", "state":"...", "country":"..."}
        Non-empty parts are joined in that order with ", " to form the query.
    Returns (lat, lon) or (None, None) -- always a tuple.

    Results, including "not found" misses, are cached under the SHA-1 of the
    query string and saved to disk. An exception that reaches this function is
    not cached, so that address is retried on a later call. Whether network
    errors surface here as exceptions depends on how the geocoder's rate
    limiter is configured.
    """
    parts = [
        addr_dict.get("facility"),
        addr_dict.get("city"),
        addr_dict.get("state"),
        addr_dict.get("country"),
    ]
    q = ", ".join([p for p in parts if p])
    if not q:
        return None, None

    key = hashlib.sha1(q.encode("utf-8")).hexdigest()
    if key in _GEO_CACHE:
        # Entries loaded from JSON are lists; normalise so callers always get a tuple.
        return tuple(_GEO_CACHE[key])

    try:
        loc = _geocoder().rate_limiter(q)
    except Exception:
        return None, None  # do not cache: retry this address next time

    res = (loc.latitude, loc.longitude) if loc else (None, None)
    _GEO_CACHE[key] = res
    _save_cache()
    return res


In [4]:
BASE = "https://clinicaltrials.gov/api/v2/studies"

# Keep this list tight; add fields only as you use them.
# - officialTitle is requested because to_features() falls back to it when
#   briefTitle is missing.
# - eligibilityCriteria is not requested: the structured-conditions filter in run()
#   is what screens out studies that only mention the disease in free text.
FIELDS = [
    "protocolSection.identificationModule.nctId",
    "protocolSection.identificationModule.briefTitle",
    "protocolSection.identificationModule.officialTitle",
    "protocolSection.statusModule.overallStatus",
    "protocolSection.designModule.studyType",
    "protocolSection.designModule.phases",
    "protocolSection.statusModule.startDateStruct",
    "protocolSection.statusModule.completionDateStruct",
    "protocolSection.conditionsModule.conditions",
    "protocolSection.armsInterventionsModule.interventions",
    "protocolSection.sponsorCollaboratorsModule.leadSponsor",
    "protocolSection.contactsLocationsModule.locations",
]

def fetch_studies_by_condition(cond_term: str, page_size=100, max_pages=50):
    """
    Yield studies that list ``cond_term`` as a condition, following pagination.

    Uses the v2 structured condition filter (query.cond), so only studies that
    *list* this disease as a condition are returned. Each yielded study is the
    raw JSON record restricted to the ``FIELDS`` paths.

    page_size: studies requested per page.
    max_pages: page limit; ``None`` fetches every page. If the limit is reached
        while the API still reports a next page, a RuntimeWarning is emitted
        rather than truncating the results silently.
    Raises requests.HTTPError on a non-2xx response.
    """
    import itertools
    import warnings

    token = None
    total = None   # totalCount; the API returns it with the first page when countTotal=true
    fetched = 0
    pages = itertools.count() if max_pages is None else range(max_pages)
    for _ in pages:
        params = {
            "format": "json",
            "pageSize": page_size,
            "fields": ",".join(FIELDS),
            "query.cond": cond_term,   # structured condition filter
            "countTotal": "true",
        }
        if token:
            params["pageToken"] = token

        r = requests.get(BASE, params=params, timeout=30)
        r.raise_for_status()
        data = r.json()
        if total is None:
            total = data.get("totalCount")

        studies = data.get("studies") or []
        fetched += len(studies)
        yield from studies

        token = data.get("nextPageToken")
        if not token:
            return
        time.sleep(0.2)  # be polite

    if token:
        warnings.warn(
            f"fetch_studies_by_condition({cond_term!r}): stopped after {max_pages} pages "
            f"with {fetched} of {total if total is not None else 'unknown'} studies; "
            "results are truncated - raise max_pages or page_size.",
            RuntimeWarning,
            stacklevel=2,
        )

def _extract_addr_geo(loc: dict):
  """Resilient extraction for address + geoPoint from a locations item."""
  if not isinstance(loc, dict):
    return {}, {}
  facility = loc.get("facility")
  location = facility.get("location") if isinstance(facility, dict) else None
  if not isinstance(location, dict):
    location = loc.get("location") if isinstance(loc.get("location"), dict) else {}
  addr = location.get("address") if isinstance(location.get("address"), dict) else {}
  geo  = location.get("geoPoint") if isinstance(location.get("geoPoint"), dict) else {}
  if not geo and isinstance(loc.get("geoPoint"), dict):
    geo = loc["geoPoint"]
  return addr, geo

def to_features(study, enable_geocoding=True):
    """Yield one GeoJSON Point ``Feature`` per usable site of a v2 study record.

    Coordinates come from the site's ``geoPoint``. When that is missing and
    ``enable_geocoding`` is true, the site's facility/city/state/country are
    geocoded with ``geocode_addr``. Sites that still have no coordinates are
    skipped.

    Each feature's ``properties`` carry trial-level metadata (nctId, title,
    status, phases, conditions, interventions, sponsor, url) plus the site's
    ``address`` dict (facility/city/state/country).
    """
    ps   = study.get("protocolSection", {}) or {}
    idm  = ps.get("identificationModule", {}) or {}
    stat = ps.get("statusModule", {}) or {}
    des  = ps.get("designModule", {}) or {}
    cond = ps.get("conditionsModule", {}) or {}
    arms = ps.get("armsInterventionsModule", {}) or {}
    spon = ps.get("sponsorCollaboratorsModule", {}) or {}
    locs = (ps.get("contactsLocationsModule", {}) or {}).get("locations") or []

    nct    = idm.get("nctId")
    title  = idm.get("briefTitle") or idm.get("officialTitle")
    status = stat.get("overallStatus")
    phases = des.get("phases") or []
    conds  = cond.get("conditions") or []
    inters = arms.get("interventions") or []
    sponsor = (spon.get("leadSponsor") or {}).get("name")

    for loc in locs:
        if not isinstance(loc, dict):
            continue
        addr, geo = _extract_addr_geo(loc)
        lat, lon = geo.get("lat"), geo.get("lon")

        # In the v2 API `facility` is a plain name string; nested shapes use {"name": ...}.
        facility = loc.get("facility")
        if isinstance(facility, dict):
            facility_name = facility.get("name")
        elif isinstance(facility, str):
            facility_name = facility
        else:
            facility_name = None

        # Prefer a nested address; fall back to the flat keys v2 puts on the location item.
        addr_out = {
            "facility": facility_name,
            "city": addr.get("city") or loc.get("city"),
            "state": addr.get("state") or loc.get("state"),
            "country": addr.get("country") or loc.get("country"),
        }
        # if no geoPoint, try geocoding from address (facility/city/state/country)
        if (lat is None or lon is None) and enable_geocoding:
            lat, lon = geocode_addr(addr_out)

        if lat is None or lon is None:
            continue  # still nothing, skip

        yield {
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [float(lon), float(lat)]},
            "properties": {
                "nctId": nct,
                "title": title,
                "status": status,
                "phases": phases,
                "conditions": conds,
                "interventions": inters,
                "sponsor": sponsor,
                "address": addr_out,
                "url": f"https://clinicaltrials.gov/study/{nct}",
            },
        }


def run(disease_term: str, enable_geocoding=True):
    """Build a GeoJSON FeatureCollection of active interventional trial sites for one disease.

    A study is kept only when all of these hold:

    * ``studyType`` is INTERVENTIONAL;
    * ``overallStatus`` is NOT_YET_RECRUITING, RECRUITING or ACTIVE_NOT_RECRUITING;
    * at least one structured condition contains ``disease_term``
      (case-insensitive substring). Requiring the term in the conditions list
      is what excludes studies that only mention the disease elsewhere, e.g. in
      their exclusion criteria.

    Args:
        disease_term: condition term, also sent to the API as ``query.cond``.
        enable_geocoding: forwarded to ``to_features``; when false, sites without
            a ``geoPoint`` are dropped instead of being geocoded.

    Returns:
        dict: ``{"type": "FeatureCollection", "features": [...]}``.
    """
    disease_lc = disease_term.lower()
    features = []
    allowed_status = {"NOT_YET_RECRUITING", "RECRUITING", "ACTIVE_NOT_RECRUITING"}

    for s in fetch_studies_by_condition(disease_term):
        ps = s.get("protocolSection", {}) or {}

        # --- Filter: must be Interventional ---
        # `or ""` (rather than a .get default) also covers an explicit null value.
        study_type = ((ps.get("designModule") or {}).get("studyType") or "").upper()
        if study_type != "INTERVENTIONAL":
            continue

        # --- Filter: trial status ---
        status = ((ps.get("statusModule") or {}).get("overallStatus") or "").upper()
        if status not in allowed_status:
            continue

        # --- Filter: disease must be in structured conditions ---
        conds = (ps.get("conditionsModule") or {}).get("conditions") or []
        if not any(disease_lc in (c or "").lower() for c in conds):
            continue

        features.extend(to_features(s, enable_geocoding=enable_geocoding))

    return {"type": "FeatureCollection", "features": features}


def run_multi(disease_terms):
    """Merge the per-disease ``run`` results into one FeatureCollection.

    ``run(term)`` is called once per term, in order, before any merging. Sites
    are de-duplicated on ``(nctId, lon, lat)``. The first occurrence's
    properties are shallow-copied and given a ``"diseases"`` list naming every
    term whose run produced that site, in ``disease_terms`` order.
    """
    results = [(term, run(term)) for term in disease_terms]

    merged = {}  # (nctId, lon, lat) -> merged feature
    for term, collection in results:
        for feature in collection["features"]:
            lon, lat = feature["geometry"]["coordinates"]
            site_key = (feature["properties"].get("nctId"), float(lon), float(lat))

            existing = merged.get(site_key)
            if existing is None:
                props_copy = dict(feature["properties"])
                props_copy["diseases"] = [term]
                merged[site_key] = {
                    "type": "Feature",
                    "geometry": feature["geometry"],
                    "properties": props_copy,
                }
                continue

            tags = existing["properties"].setdefault("diseases", [])
            if term not in tags:
                tags.append(term)

    return {"type": "FeatureCollection", "features": list(merged.values())}


def nct_color(nctId: str) -> str:
    """Return a stable ``#rrggbb`` colour derived from an NCT ID.

    The hue comes from the MD5 digest of the ID. Saturation is fixed and
    lightness is confined to 0.35-0.55, so every colour stays clearly visible.
    Taking the first six hex digits of the digest directly could give
    near-white colours that vanish on light map tiles and the white chart
    panel, or near-black ones that are hard to tell apart.
    Returns ``"#000000"`` for an empty or missing ID.
    """
    import colorsys

    if not nctId:
        return "#000000"  # default black
    digest = hashlib.md5(nctId.encode("utf-8")).hexdigest()
    hue = int(digest[:8], 16) / 2**32                       # in [0, 1)
    lightness = 0.35 + 0.20 * int(digest[8:10], 16) / 255   # in [0.35, 0.55]
    r, g, b = colorsys.hls_to_rgb(hue, lightness, 0.65)
    return "#{:02x}{:02x}{:02x}".format(round(r * 255), round(g * 255), round(b * 255))



In [5]:
import folium, json, hashlib, datetime
from folium.plugins import MarkerCluster

def nct_color(nctId: str) -> str:
    """Return a stable ``#rrggbb`` colour derived from an NCT ID.

    The hue comes from the MD5 digest of the ID. Saturation is fixed and
    lightness is confined to 0.35-0.55, so every marker and bar stays clearly
    visible. Taking the first six hex digits of the digest directly could give
    near-white colours that vanish on light map tiles and the white chart
    panel, or near-black ones that are hard to tell apart.
    Returns ``"#000000"`` for an empty or missing ID.
    """
    import colorsys

    if not nctId:
        return "#000000"
    digest = hashlib.md5(nctId.encode("utf-8")).hexdigest()
    hue = int(digest[:8], 16) / 2**32                       # in [0, 1)
    lightness = 0.35 + 0.20 * int(digest[8:10], 16) / 255   # in [0.35, 0.55]
    r, g, b = colorsys.hls_to_rgb(hue, lightness, 0.65)
    return "#{:02x}{:02x}{:02x}".format(round(r * 255), round(g * 255), round(b * 255))

def _intervention_names(interventions):
    names = []
    if isinstance(interventions, list):
        for it in interventions:
            if isinstance(it, dict):
                name = it.get("name")
                if name:
                    names.append(name.strip())
            elif isinstance(it, str):
                names.append(it.strip())
    seen, uniq = set(), []
    for n in names:
        if n and n.lower() not in seen:
            uniq.append(n); seen.add(n.lower())
    return "; ".join(uniq) if uniq else "Intervention: n/a"

def _map_points(feats):
    """Flatten GeoJSON site features into the plain dicts the in-browser script renders.

    Each output dict has keys lat, lng, nct, color, title, status, phases,
    interventions (a '; '-joined string), sponsor, address (city/state/country
    joined by ', '), url and diseases (``[]`` unless the features came from
    ``run_multi``). All sites of one NCT ID share one colour.
    """
    points = []
    color_by_nct = {}
    for f in feats:
        lon, lat = f["geometry"]["coordinates"]
        p = f["properties"]
        nct = p["nctId"]
        if nct not in color_by_nct:
            color_by_nct[nct] = nct_color(nct)
        addr = p.get("address") or {}  # tolerate a missing or null address
        points.append({
            "lat": float(lat),
            "lng": float(lon),
            "nct": nct,
            "color": color_by_nct[nct],
            "title": p.get("title") or "",
            "status": p.get("status") or "",
            "phases": p.get("phases") or [],
            "interventions": _intervention_names(p.get("interventions")),
            "sponsor": p.get("sponsor") or "",
            "address": ", ".join(x for x in (addr.get("city"), addr.get("state"), addr.get("country")) if x),
            "url": p.get("url") or "",
            "diseases": p.get("diseases") or [],  # [] in single-disease mode, list in run_multi
        })
    return points


def make_map_with_cluster_chart(feature_collection, outfile="trials_map.html"):
    """Render trial sites as a clustered Leaflet map with an interactive per-NCT bar chart.

    Markers, clustering and the chart are all built in the browser from a JSON
    payload, so the page can filter and recount without Python. The floating
    panel shows, per NCT ID, the number of sites inside the clicked cluster
    (solid bar) against the count over all included sites (faint bar). Disease
    checkboxes and preset buttons filter both the markers and the chart.

    Args:
        feature_collection: GeoJSON FeatureCollection as produced by ``run`` or
            ``run_multi`` (``properties["diseases"]`` enables disease filtering).
        outfile: path of the standalone HTML file to write.

    Returns:
        folium.Map: the map, also saved to ``outfile``; displays inline in Jupyter.

    Raises:
        ValueError: if the collection has no features.
    """
    feats = feature_collection.get("features", [])
    if not feats:
        raise ValueError("No geocoded locations.")

    # Payload the browser will render/control entirely
    points = _map_points(feats)

    # Bare map (no Python-side markers)
    m = folium.Map(location=[20, 0], zoom_start=2)

    # Force Folium to include Leaflet.markercluster assets
    MarkerCluster().add_to(m)

    map_var = m.get_name()
    # json.dumps does not escape "</": a title containing "</script>" would close the
    # <script> element early. "<\/" is an equivalent escape inside JS string literals.
    js_points = json.dumps(points).replace("</", "<\\/")
    run_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")

    html_js = f"""
    <style>
      #nct-panel {{
        position:absolute; top:10px; right:10px; z-index:9999;
        background:white; padding:8px 10px; width:520px; max-height:420px;
        box-shadow:0 2px 8px rgba(0,0,0,0.2); border-radius:8px; font-family:sans-serif; overflow:auto;
      }}
      #nct-panel h4 {{ margin:0 0 6px 0; font-size:14px; }}
      .legend-note {{ font-size:11px; color:#666; margin-bottom:6px; }}
      #nct-presets {{ display:flex; flex-wrap:wrap; gap:8px; margin:6px 0 6px 0; }}
      #nct-presets .preset {{ padding:4px 8px; font-size:12px; }}
      #disease-filters {{ display:flex; flex-wrap:wrap; gap:10px; margin: 6px 0 8px 0; }}
      #disease-filters label {{ font-size:12px; user-select:none; }}
      .bar-row {{ display:flex; align-items:center; gap:8px; margin:6px 0; }}
      .bar-label {{ width:160px; font-size:12px; white-space:nowrap; overflow:hidden; text-overflow:ellipsis; }}
      .bar-track {{ position:relative; flex:1; background:#f0f0f0; height:14px; border-radius:7px; overflow:hidden; }}
      .bar-global {{ position:absolute; left:0; top:0; height:14px; opacity:.35; }}
      .bar-local {{ position:absolute; left:0; top:0; height:14px; }}
      .bar-count {{ width:90px; text-align:right; font-size:12px; color:#333; }}
      /* sticky footer stays visible at panel bottom */
      #nct-footer {{
        position: sticky; bottom: 0; left: 0;
        background: white; padding-top: 4px;
        font-size: 11px; color: #999; margin-top: 6px;
      }}
    </style>

    <div id="nct-panel" class="leaflet-control">
      <h4>Sites per NCT (cluster vs global)</h4>
      <div class="legend-note">Solid = selected diseases in current area • Faint = total worldwide (selected diseases)</div>

      <div id="nct-presets">
        <button class="preset" data-diseases='["Parkinson Disease"]'>Parkinson’s</button>
        <button class="preset" data-diseases='["Multiple System Atrophy"]'>MSA</button>
        <button class="preset" data-diseases='["Lewy Body Disease"]'>LBD</button>
        <button class="preset" data-diseases="ALL">All</button>
      </div>

      <div id="disease-filters"></div>
      <div id="nct-rows"></div>
      <div id="nct-footer">Analysis run: {run_date}</div>
    </div>

    <script>
    (function(){{
      var MAP_NAME = "{map_var}";
      var DATA = {js_points};  // array of points with all needed meta

      // HTML-escape text before inserting it into markup: titles, sponsors etc.
      // come from the registry and may contain "<", "&" or quotes.
      var HTML_ESC = {{'&':'&amp;','<':'&lt;','>':'&gt;','"':'&quot;',"'":'&#39;'}};
      function esc(s){{
        return String(s == null ? '' : s).replace(/[&<>"']/g, function(c){{ return HTML_ESC[c]; }});
      }}

      function initWhenReady(tries){{
        var map = window[MAP_NAME];
        if (!map || !map.getContainer) {{
          if ((tries||0) < 200) return setTimeout(function(){{initWhenReady((tries||0)+1)}}, 50);
          return;
        }}

        // Panel into map
        var panel = document.getElementById('nct-panel');
        var mapContainer = map.getContainer();
        if (panel.parentNode !== mapContainer) mapContainer.appendChild(panel);
        // Clicks / wheel-scrolling inside the panel must not drag or zoom the map.
        L.DomEvent.disableClickPropagation(panel);
        L.DomEvent.disableScrollPropagation(panel);

        // Build cluster group & markers (in JS, so we fully control them)
        var clusterGroup = L.markerClusterGroup();
        map.addLayer(clusterGroup);

        // Build marker registry plus per-NCT lookups used by the bar chart
        var MARKERS = []; // each: {{ marker, nct, lat, lng, diseases, color }}
        var COLOR_BY_NCT = Object.create(null);          // first colour seen per NCT
        var INTERVENTIONS_BY_NCT = Object.create(null);  // first intervention summary seen per NCT
        DATA.forEach(function(p){{
          var m = L.circleMarker([p.lat, p.lng], {{
            radius: 6, color: p.color, fillColor: p.color, fillOpacity: 0.85
          }});
          var phases = (p.phases || []).join(', ');
          var html = '<b>' + esc(p.title) + '</b><br>'
                   + 'NCT: <a href="' + esc(p.url||'#') + '" target="_blank" rel="noopener">' + esc(p.nct) + '</a><br>'
                   + 'Status: ' + esc(p.status) + ' | Phase: ' + esc(phases) + '<br>'
                   + 'Interventions: ' + esc(p.interventions||'n/a') + '<br>'
                   + 'Sponsor: ' + esc(p.sponsor||'—') + '<br>'
                   + esc(p.address);
          m.bindPopup(html, {{maxWidth:350}});
          MARKERS.push({{ marker:m, nct:p.nct, lat:p.lat, lng:p.lng, diseases:(p.diseases||[]), color:p.color }});
          if (!(p.nct in COLOR_BY_NCT)) COLOR_BY_NCT[p.nct] = p.color;
          if (!(p.nct in INTERVENTIONS_BY_NCT)) INTERVENTIONS_BY_NCT[p.nct] = p.interventions;
        }});

        // Add all markers initially
        MARKERS.forEach(function(r){{ clusterGroup.addLayer(r.marker); }});

        // ==== Disease selection UI ====
        // Universe of diseases
        var ALL_DISEASES = [];
        (function(){{
          var seen = new Set();
          DATA.forEach(function(p){{
            (p.diseases||[]).forEach(function(d){{ if(!seen.has(d)){{seen.add(d); ALL_DISEASES.push(d);}} }});
          }});
          ALL_DISEASES.sort();
        }})();
        var sel = new Set(ALL_DISEASES);

        // Checkboxes (built with DOM APIs so disease names are never parsed as HTML)
        var df = document.getElementById('disease-filters');
        df.innerHTML = '';
        ALL_DISEASES.forEach(function(d){{
          var input = document.createElement('input');
          input.type = 'checkbox';
          input.id = 'dchk_' + d.replace(/\\W+/g,'_');
          input.checked = sel.has(d);
          var label = document.createElement('label');
          label.appendChild(input);
          label.appendChild(document.createTextNode(' ' + d));
          df.appendChild(label);
          input.addEventListener('change', function(e){{
            if (e.target.checked) sel.add(d); else sel.delete(d);
            applyMarkerFilter();
            rerender(currentBounds());
          }});
        }});

        // Presets
        function setSelectionTo(listOrAll){{
          sel.clear();
          if (listOrAll === "ALL") ALL_DISEASES.forEach(function(d){{ sel.add(d); }});
          else (listOrAll||[]).forEach(function(d){{ sel.add(d); }});
          // sync UI
          ALL_DISEASES.forEach(function(d){{
            var id = 'dchk_' + d.replace(/\\W+/g,'_');
            var el = document.getElementById(id);
            if (el) el.checked = sel.has(d);
          }});
          applyMarkerFilter();
          rerender(currentBounds());
        }}
        document.querySelectorAll('#nct-presets .preset').forEach(function(btn){{
          btn.addEventListener('click', function(){{
            var val = btn.getAttribute('data-diseases');
            if (val === 'ALL') return setSelectionTo('ALL');
            try {{ setSelectionTo(JSON.parse(val)); }} catch(e) {{}}
          }});
        }});

        // ==== Filtering + chart logic ====
        function markerIncluded(rec){{
          if (!ALL_DISEASES.length) return true; // single-disease mode
          var ds = rec.diseases || [];
          for (var i=0;i<ds.length;i++) if (sel.has(ds[i])) return true;
          return false;
        }}
        function applyMarkerFilter(){{
          MARKERS.forEach(function(rec){{
            var has = clusterGroup.hasLayer(rec.marker);
            var should = markerIncluded(rec);
            if (should && !has) clusterGroup.addLayer(rec.marker);
            else if (!should && has) clusterGroup.removeLayer(rec.marker);
          }});
        }}

        function withinBounds(lat, lng, bounds) {{
          return bounds.contains(L.latLng(lat, lng));
        }}
        function currentBounds(){{
          return map.getBounds ? map.getBounds() : L.latLngBounds([[-90,-180],[90,180]]);
        }}

        function localCounts(bounds){{
          var counts = {{}};
          MARKERS.forEach(function(r){{
            if (!markerIncluded(r)) return;
            if (!bounds || withinBounds(r.lat, r.lng, bounds)) {{
              counts[r.nct] = (counts[r.nct]||0) + 1;
            }}
          }});
          return counts;
        }}
        function globalCountsSelected(){{
          var counts = {{}};
          MARKERS.forEach(function(r){{
            if (!markerIncluded(r)) return;
            counts[r.nct] = (counts[r.nct]||0) + 1;
          }});
          return counts;
        }}

        function renderBars(local){{
          var cont = document.getElementById('nct-rows');
          cont.innerHTML = '';
          var globalSel = globalCountsSelected();

          var keys = Object.keys(local);
          if (!keys.length) {{
            cont.innerHTML = '<div style="color:#666;font-size:12px;">No sites in selection for chosen diseases.</div>';
            return;
          }}

          keys.sort(function(a,b){{
            if (local[b] !== local[a]) return local[b] - local[a];
            return (globalSel[b]||0) - (globalSel[a]||0);
          }});

          var maxG = 1;
          keys.forEach(function(k){{ maxG = Math.max(maxG, globalSel[k]||0); }});
          keys.forEach(function(nct){{
            var g = globalSel[nct] || 0;
            var l = local[nct] || 0;
            var color = COLOR_BY_NCT[nct] || '#888';
            var gPct = Math.max(4, Math.round(100 * g / (maxG||1)));
            var lPct = Math.min(gPct, Math.round(gPct * (l / (g||1))));

            var row = document.createElement('div'); row.className = 'bar-row';

            var label = document.createElement('div'); label.className = 'bar-label';
            label.textContent = nct;
            label.title = INTERVENTIONS_BY_NCT[nct] || 'Intervention: n/a';

            var track = document.createElement('div'); track.className = 'bar-track';
            track.title = label.title;

            var gdiv = document.createElement('div'); gdiv.className='bar-global';
            gdiv.style.width = gPct + '%'; gdiv.style.background = color;
            var ldiv = document.createElement('div'); ldiv.className='bar-local';
            ldiv.style.width = lPct + '%'; ldiv.style.background = color;

            var cnt = document.createElement('div'); cnt.className='bar-count';
            cnt.textContent = l + ' / ' + g;

            track.appendChild(gdiv); track.appendChild(ldiv);
            row.appendChild(label); row.appendChild(track); row.appendChild(cnt);
            cont.appendChild(row);
          }});
        }}

        function rerender(bounds){{ renderBars(localCounts(bounds)); }}

        // Initial sync
        applyMarkerFilter();
        rerender(L.latLngBounds([[-90,-180],[90,180]]));

        // Cluster click → render for that area
        clusterGroup.on('clusterclick', function(e){{
          rerender(e.layer.getBounds());
        }});
      }}
      initWhenReady(0);
    }})();
    </script>
    """

    m.get_root().html.add_child(folium.Element(html_js))
    m.save(outfile)
    return m


In [6]:
# m = make_map_with_cluster_chart(fc, outfile="C:\\Users\\robin\\Programs\\Proteinopathy\\reveal.js\\trials_map.html")
# m  # inline

In [7]:
m = make_map_with_cluster_chart(run_multi(["Lewy Body Disease", "Parkinson Disease", "Multiple System Atrophy"]), outfile="C:\\Users\\robin\\Programs\\Proteinopathy\\reveal.js\\trials_map.html")
m  # inline

In [8]:
convert_and_render(
    geojson_or_path=run_multi(["Lewy Body Disease", "Parkinson Disease", "Multiple System Atrophy"]),  # or pass a dict
    out_json="synuc_trials_graph_data.json",
    out_html="C:\\Users\\robin\\Programs\\Proteinopathy\\reveal.js\\synuc_trials_graph.html",
    title="Synucleinopathies"  # optional
)

(WindowsPath('synuc_trials_graph_data.json'),
 WindowsPath('C:/Users/robin/Programs/Proteinopathy/reveal.js/synuc_trials_graph.html'))