In [None]:
import os, re, time
from collections import deque
from typing import Dict, List, Optional, Tuple

import requests
from tqdm.auto import tqdm

# Wikimedia Commons API endpoint that _commons_query() sends its GET requests to.
COMMONS_API = "https://commons.wikimedia.org/w/api.php"

# Maps the short key accepted by download_bears() to the Commons category title
# used as the root of the recursive file search for that class.
BEAR_SPECIES_CATEGORIES: Dict[str, str] = {
    "polar": "Category:Ursus maritimus",
    "brown": "Category:Ursus arctos",
    "american_black": "Category:Ursus americanus",
    "asiatic_black": "Category:Ursus thibetanus",
    "sun": "Category:Helarctos malayanus",
    "sloth": "Category:Melursus ursinus",
    "spectacled": "Category:Tremarctos ornatus",
    "giant_panda": "Category:Ailuropoda melanoleuca",
    # Toy bears rather than a biological species.
    "teddy_bear": "Category:Teddy_bears",

}

# Filename suffixes treated as images; compared against lower-cased file titles.
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp")


def _session_with_retries(total_retries: int = 5) -> requests.Session:
    """Create a requests session for talking to Wikimedia Commons.

    Args:
        total_retries: Value passed as ``max_retries`` to the HTTP adapter
            mounted for both ``https://`` and ``http://`` URLs.

    Returns:
        A new ``requests.Session`` with the adapter mounted and a descriptive
        ``User-Agent`` header set on every request.
    """
    session = requests.Session()
    session.headers.update({"User-Agent": "bear-thumb-downloader/1.1 (educational)"})

    # One adapter instance is shared by both URL schemes.
    retrying_adapter = requests.adapters.HTTPAdapter(max_retries=total_retries)
    for scheme in ("https://", "http://"):
        session.mount(scheme, retrying_adapter)

    return session


def _commons_query(session: requests.Session, params: dict) -> dict:
    """Send one GET request to the Commons API and return the decoded JSON body.

    Args:
        session: Session used to perform the request.
        params: Query parameters; ``format=json`` is always forced on a copy,
            so the caller's dict is not modified.

    Returns:
        The parsed JSON response.

    Raises:
        requests.HTTPError: If the response has an error status code.
    """
    query = dict(params)
    query["format"] = "json"

    response = session.get(COMMONS_API, params=query, timeout=30)
    response.raise_for_status()
    return response.json()


def _safe_filename(title: str) -> str:
    name = title.split(":", 1)[-1]
    name = re.sub(r"[^\w\-. ]+", "_", name).strip().replace(" ", "_")
    return name


def _category_members(
    session: requests.Session,
    category: str,
    cmtype: str,
    limit: int = 200,
    cmcontinue: Optional[str] = None,
) -> Tuple[List[str], Optional[str]]:
    """Fetch a single page of members of a Commons category.

    Args:
        session: Session used for the API request.
        category: Category title sent as ``cmtitle``.
        cmtype: Member type to list, e.g. ``"file"`` or ``"subcat"``.
        limit: Value sent as ``cmlimit``.
        cmcontinue: Continuation token from a previous call, if any.

    Returns:
        ``(titles, next_token)`` where ``titles`` are the non-empty member
        titles in the response and ``next_token`` is the ``cmcontinue`` value
        from the response's ``continue`` section, or ``None`` if absent.
    """
    query = {
        "action": "query",
        "list": "categorymembers",
        "cmtitle": category,
        "cmtype": cmtype,
        "cmlimit": limit,
    }
    if cmcontinue:
        query["cmcontinue"] = cmcontinue

    payload = _commons_query(session, query)

    titles = []
    for member in payload.get("query", {}).get("categorymembers", []):
        title = member.get("title")
        if title:
            titles.append(title)

    next_token = payload.get("continue", {}).get("cmcontinue")
    return titles, next_token


def collect_files_recursive(
    session: requests.Session,
    root_category: str,
    max_files: int = 200,
    max_categories: int = 200,
) -> List[str]:
    """
    Breadth-first walk of a category tree, collecting unique image file titles.

    Args:
        session: Session used for the API requests.
        root_category: Category title to start from.
        max_files: Stop as soon as this many file titles have been collected.
        max_categories: Upper bound on the number of categories (root
            included) that are ever queued, and therefore visited.

    Returns:
        Up to ``max_files`` distinct titles whose lower-cased name ends with
        one of ``IMAGE_EXTS``, in discovery order.
    """
    if max_files <= 0 or max_categories <= 0:
        return []

    image_exts = tuple(IMAGE_EXTS)
    seen_cats = {root_category}
    seen_files: set = set()
    queue = deque([root_category])
    file_titles: List[str] = []

    # The category budget is enforced when enqueueing (below), so every queued
    # category is actually visited instead of being dropped once the number of
    # discovered categories passes the limit.
    while queue and len(file_titles) < max_files:
        cat = queue.popleft()

        # 1) Collect files in this category (may be 0 for top-level taxonomy cats)
        cmcontinue = None
        while len(file_titles) < max_files:
            files, cmcontinue = _category_members(session, cat, cmtype="file", cmcontinue=cmcontinue)
            for title in files:
                # A file can belong to several categories; keep it only once.
                if title in seen_files or not title.lower().endswith(image_exts):
                    continue
                seen_files.add(title)
                file_titles.append(title)
                if len(file_titles) >= max_files:
                    break
            if not cmcontinue:
                break
            time.sleep(0.15)

        # Enough files already: no need to spend a request on subcategories.
        if len(file_titles) >= max_files:
            break

        # 2) Enqueue subcategories while the category budget allows
        cmcontinue = None
        while len(seen_cats) < max_categories:
            subs, cmcontinue = _category_members(session, cat, cmtype="subcat", cmcontinue=cmcontinue)
            for subcat in subs:
                if len(seen_cats) >= max_categories:
                    break
                if subcat not in seen_cats:
                    seen_cats.add(subcat)
                    queue.append(subcat)
            if not cmcontinue:
                break
            time.sleep(0.15)

    return file_titles

def get_thumbnail_urls(session, titles, thumb_width=224):
    """Look up a download URL for each file title via ``prop=imageinfo``.

    Titles are queried in batches of 50 with ``iiurlwidth=thumb_width``. For
    every returned page that has a title and image info, the ``thumburl`` is
    used when present, otherwise the plain ``url``.

    Returns:
        A list of ``(title, url)`` tuples, using the titles reported by the
        API. Pages without image info or without any URL are omitted.
    """
    batch_size = 50
    resolved = []

    for start in range(0, len(titles), batch_size):
        chunk = titles[start:start + batch_size]

        data = _commons_query(
            session,
            {
                "action": "query",
                "prop": "imageinfo",
                "titles": "|".join(chunk),
                "iiprop": "url",
                "iiurlwidth": thumb_width,  # ask for a scaled thumbnail of this width
            },
        )

        for page in data.get("query", {}).get("pages", {}).values():
            title = page.get("title")
            infos = page.get("imageinfo", [])
            if title and infos:
                first = infos[0]
                # Prefer the scaled thumbnail; fall back to the original file URL.
                url = first.get("thumburl") or first.get("url")
                if url:
                    resolved.append((title, url))

        # Short pause between batches.
        time.sleep(0.15)

    return resolved



def download_urls(
    session: requests.Session,
    pairs: List[Tuple[str, str]],
    out_dir: str,
    max_images: int = 30,
) -> List[str]:
    """Download images into ``out_dir`` until ``max_images`` have been saved.

    Candidates are tried in order. A failed download is reported and skipped,
    and the next candidate is tried, so spare entries in ``pairs`` act as
    replacements instead of being ignored.

    Args:
        session: Session used for the downloads.
        pairs: ``(title, url)`` candidates; the title determines the filename.
        out_dir: Destination directory, created if missing.
        max_images: Maximum number of files to save.

    Returns:
        Paths of the files written, each path at most once.
    """
    os.makedirs(out_dir, exist_ok=True)
    saved: List[str] = []
    if max_images <= 0 or not pairs:
        return saved

    target = min(max_images, len(pairs))
    used_paths = set()

    with tqdm(total=target, desc=f"Downloading → {out_dir}") as progress:
        for title, url in pairs:
            if len(saved) >= target:
                break

            filename = _safe_filename(title)
            path = os.path.join(out_dir, filename)
            # Skip titles that sanitise to nothing, or to a name already
            # written in this run (which would overwrite that file and count
            # it twice).
            if not filename or path in used_paths:
                continue

            try:
                resp = session.get(url, timeout=30)
                resp.raise_for_status()
                with open(path, "wb") as f:
                    f.write(resp.content)
            except Exception as e:
                # Deliberate best-effort: one bad file must not abort the run.
                print(f"Skip: {title} ({e})")
            else:
                used_paths.add(path)
                saved.append(path)
                progress.update(1)

            # Short pause between requests.
            time.sleep(0.1)

    return saved


def download_bears(
    species_key: str,
    out_root: str = "bear_images_small",
    max_images: int = 30,
    thumb_width: int = 224,
    max_categories: int = 200,
) -> List[str]:
    """Download thumbnails for one entry of ``BEAR_SPECIES_CATEGORIES``.

    Args:
        species_key: Key into ``BEAR_SPECIES_CATEGORIES``.
        out_root: Parent directory; files go to ``out_root/species_key``.
        max_images: Maximum number of images passed to ``download_urls``.
        thumb_width: Thumbnail width requested from the API.
        max_categories: Category budget forwarded to
            ``collect_files_recursive``.

    Returns:
        Paths of the saved files; an empty list when no files were found.

    Raises:
        ValueError: If ``species_key`` is not a known key.
    """
    if species_key not in BEAR_SPECIES_CATEGORIES:
        raise ValueError(f"Unknown species_key={species_key}. Choose one of: {sorted(BEAR_SPECIES_CATEGORIES)}")

    root_cat = BEAR_SPECIES_CATEGORIES[species_key]

    # The context manager closes the session (and its pooled connections)
    # even when one of the steps below raises.
    with _session_with_retries() as session:
        # Collect file titles recursively because top-level taxonomy categories can be mostly subcats.
        # Over-collect (x5) so there are spare candidates.
        titles = collect_files_recursive(
            session,
            root_cat,
            max_files=max_images * 5,
            max_categories=max_categories,
        )

        if not titles:
            print(f"No files found under {root_cat}. Try increasing max_categories or using a more specific subcategory.")
            return []

        pairs = get_thumbnail_urls(session, titles, thumb_width=thumb_width)
        out_dir = os.path.join(out_root, species_key)

        return download_urls(session, pairs, out_dir, max_images=max_images)



In [6]:
# Example: download up to 20 brown-bear thumbnails (requested width 224)
# into bear_images_small/brown and report how many files were saved.
paths = download_bears("brown", max_images=20, thumb_width=224)
print("Saved:", len(paths))


Downloading → bear_images_small/brown: 100%|██████████| 20/20 [02:06<00:00,  6.30s/it]

Saved: 20





In [4]:

# Example: download up to 20 polar-bear thumbnails (requested width 224)
# into bear_images_small/polar and report how many files were saved.
paths = download_bears("polar", max_images=20, thumb_width=224)
print("Saved:", len(paths))


Downloading → bear_images_small/polar: 100%|██████████| 20/20 [02:04<00:00,  6.22s/it]

Saved: 20



