# (Do not use) Try to fix class/order imbalances

I edited the "01" notebook to better balance the class/order combinations.

I won't be able to fix all imbalances but I may be able to improve them for well represented plant orders.

1. For each order get the counts of each class: Flowering, Fruiting, Neither.
1. Try to get to at least 500 samples per class within each order.

In [1]:
import json
import math
import re
import sqlite3
import time
from pathlib import Path
from types import SimpleNamespace

import pandas as pd
import requests
from tqdm.notebook import tqdm

In [2]:
INAT_DIR = Path("..") / "data" / "inat"

In [3]:
args = SimpleNamespace(
    base_url="https://api.inaturalist.org/v1/observations?",
    taxa=INAT_DIR / "taxa.csv.gz",  # Loction of iNat taxon info
    db=INAT_DIR / "inat.sqlite",
    obs_dir=INAT_DIR / "obs",  # Observation JSON data
    image_dir=INAT_DIR / "images",  # Save images here
    plantae=47126,  # The taxon ID for the plant kingdom
    per_page=100,  # How many iNaturalist records to get per API request
    pages=5,  # Limit to this number of a pages
)

## Get target plant orders

In [4]:
# !aws s3 cp s3://inaturalist-open-data/taxa.csv.gz $args.taxa --no-sign-request

In [5]:
sql = """
    select taxon_id, name as order_
    from   taxa
    where  rank = 'order'
    and    name in (select distinct order_ from obs)
"""

with sqlite3.connect(args.db) as cxn:
    df = pd.read_sql(sql, cxn)

orders = df.set_index("order_")["taxon_id"].to_dict()
len(orders)

64

## Make API requests for observation data to iNaturalist

### The URL for the API request to iNaturalist.

In [6]:
def get_url(taxon_id, term, per_page, page=1):
    url = "&".join(
        [
            args.base_url,
            "photos=true",
            "quality_grade=research",
            f"taxon_id={taxon_id}",
            "term_id=12",
            f"term_value_id={term}",
            f"per_page={per_page}",
            f"page={page}",
            "order=desc",
            "order_by=created_at",
            "photo_license=cc0,cc-by,cc-by-nc,cc-by-sa",
        ]
    )
    return url

### Make the API request.

In [7]:
def call_api(url, results):
    try:
        response = requests.get(url)
        results += response.json()["results"]
    except KeyError:
        print(response.json())
        raise  # ???????????????????????????
    return response.json()

### Loop through each plant order and make requests to the API

In [None]:
pheno = {
    "flowering": 13,
    "fruiting": 14,
    "neither": 21,
}

terms = [
    {"term": "flowering", "values": [13], "without": 14},
    {"term": "fruiting", "values": [14], "without": 13},
    {"term": "neither", "values": [21], "without": 0},
    {"term": "both", "values": [13, 14], "without": 0},
]

for order, taxon_id in tqdm(orders.items(), position=0, leave=None):

    for term in terms:
        sql = """select count(*) from obs where order = ?"""
        clauses = [f" and {pheno[v]} = 1" for v in [13, 14, 21] if v in term["values"]]
        clauses += [f" and {pheno[v]} = 0" for v in [13, 14, 21] if v == term["without"]]

        for v in [13, 14, 21]:
            if v in term["values"]:
                clauses.append(f" and {pheno[value]} = 1")

        count = cxn.execute(sql, (order,))
        count = count.fetchone()[0]
        if count >= args.pages * args.per_page:
            continue

        url = get_url(taxon_id, term["values"], term["without"])

        results = []
        response = call_api(url, results)

        if len(results) == 0:
            continue

        last = math.ceil(response["total_results"] / args.per_page) + 1
        last = min(last, args.pages + 1)

        for page in tqdm(range(2, last), position=1, leave=None):
            url = get_url(taxon_id, term["values"], term["without"], page=page)
            call_api(url, results)

        path = args.obs_dir / f"pheno_{term['term']}_order_{order}.json"
        with open(path, "w") as out_file:
            json.dump(results, out_file)

In [9]:
with sqlite3.connect(args.db) as cxn:
    for order, taxon_id in tqdm(orders.items(), position=0, leave=None):
        for term, term_id in {"flowering": 13, "fruiting": 14, "neither": 21}.items():
            sql = f"""
                select count(*) from obs
                where  {term} = 1
                and    order_ = ?
            """
            count = cxn.execute(sql, (order,))
            count = count.fetchone()[0]
            if count >= args.pages * args.per_page:
                continue

            url = get_url(taxon_id, term_id, args.per_page)

            results = []
            response = call_api(url, results)

            if len(results) == 0:
                continue

            last = math.ceil(response["total_results"] / args.per_page) + 1
            last = min(last, args.pages + 1)

            for page in tqdm(range(2, last), position=1, leave=None):
                url = get_url(taxon_id, term_id, args.per_page, page=page)
                call_api(url, results)

            path = args.obs_dir / f"pheno_{term}_order_{order}.json"
            with open(path, "w") as out_file:
                json.dump(results, out_file)

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

## Download images

### Try the image download a few times.

In [12]:
def download_image(url, path):
    if path.exists():
        return
    for attempt in range(3):
        try:
            image = requests.get(url).content
            with open(path, "wb") as out_file:
                out_file.write(image)
            return
        except (TimeoutError, ConnectionError, ConnectionResetError):
            time.sleep(20)
    else:
        raise TimeoutError

### Loop through each set of downloaded JSON observations and get their images.

In [15]:
size = "medium."  # medium large original

json_paths = sorted(args.obs_dir.glob("pheno_*_order_*.json"))

for json_path in tqdm(json_paths, position=0, leave=None):
    with open(json_path) as in_file:
        data = json.load(in_file)
    for result in tqdm(data, position=1, leave=None):
        for photo in result["photos"]:
            url = photo["url"].replace("square.", size)
            match = re.search(r"/(\d+)/[a-z]+\.([a-z]+)$", url, flags=re.I)
            if not match:
                continue
            image_path = args.image_dir / f"{match[1]}_{size}{match[2]}"
            download_image(url, image_path)

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/487 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/435 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/149 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/287 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/108 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/104 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/201 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/109 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/280 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/249 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

ConnectionError: ('Connection aborted.', ConnectionResetError(104, 'Connection reset by peer'))