# Sample images

Get 100 randomly sampled images from each of the following 4 classes:
1. flowering
1. not flowering
1. fruiting
1. not fruiting

In [1]:
import sys

# Add the parent directory to the module search path.
# NOTE(review): presumably this is what makes the `herbarium` package
# importable from this notebook's directory -- confirm.
sys.path.append('..')

In [12]:
import shutil
import sqlite3
from pathlib import Path

import pandas as pd

from herbarium.pylib import db

In [3]:
# All locations are relative paths rooted one level above the working
# directory, inside its `data` folder.
DATA_DIR = Path('..', 'data')

# Scratch area for the sampled CSVs/images, and the image directory.
TEMP_DIR = DATA_DIR.joinpath('temp')
IMAGE_DIR = DATA_DIR.joinpath('images')

# SQLite database file queried throughout the notebook.
DB = DATA_DIR.joinpath('angiosperms.sqlite')

In [4]:
def get_image_class(cls):
    """Sample up to 100 random images of one class and copy them aside.

    The sampled rows are written to ``TEMP_DIR/<cls>.csv`` and every image
    file named in the ``path`` column is copied into ``TEMP_DIR/<cls>/``.

    Args:
        cls: Name of the column used as the class flag; only rows where this
            column equals 1 are sampled. The name is interpolated directly
            into the SQL text (column names cannot be bound as parameters),
            so it must be a trusted identifier.
    """
    sql = f"""
        select *
          from angiosperms
          join images using (coreid)
         where {cls} = 1
      order by random()
         limit 100
    """

    # Using a sqlite3 connection as a context manager only commits or rolls
    # back the transaction; it does not close the connection. Close it
    # explicitly so repeated calls do not leak handles.
    cxn = sqlite3.connect(DB)
    try:
        df = pd.read_sql(sql, cxn)
    finally:
        cxn.close()

    # Create the output directory (and TEMP_DIR itself) before writing
    # anything, so the CSV write cannot fail on a missing TEMP_DIR.
    dir_ = TEMP_DIR / f'{cls}'
    dir_.mkdir(parents=True, exist_ok=True)

    df.to_csv(TEMP_DIR / f'{cls}.csv', index=False)

    # Only the path column is needed, so iterate it directly.
    for image_path in df['path']:
        src = Path('..') / image_path
        dst = dir_ / Path(image_path).name
        shutil.copy(src, dst)

In [5]:
# Sample each of the four classes in turn.
for class_name in ('flowering', 'not_flowering', 'fruiting', 'not_fruiting'):
    get_image_class(class_name)

## Build training, validation, and test datasets

The classes are woefully unbalanced, and many images have multiple labels, so I'm going to weight the losses per class. I'm also going to make sure that the test and validation datasets each contain a representative share of every class. Also note that any image may belong to multiple classes.

I'm saving the splits so that I don't wind up training on my test data.

### Get count of image classes

In [16]:
# One summary row: the number of joined image/angiosperm rows (`n`) plus the
# sum of each of the four class flag columns.
sql = """
    select count(*) as n,
           sum(flowering) as flowering,
           sum(not_flowering) as not_flowering, 
           sum(fruiting) as fruiting,
           sum(not_fruiting) as not_fruiting
      from images
      join angiosperms using (coreid)
"""

In [25]:
# The query yields a single summary row; take it as a mapping.
counts = db.rows_as_dicts(DB, sql)[0]

# `n` is the overall row count; every other key is a per-class sum.
total = counts['n']

# (class, sum) pairs ordered from the smallest class sum to the largest.
counts = sorted(
    ((cls_name, cls_sum) for cls_name, cls_sum in counts.items()
     if cls_name != 'n'),
    key=lambda pair: pair[1],
)

print(f'{total=}')
counts

total=15123


[('not_flowering', 1370.0),
 ('not_fruiting', 1434.0),
 ('fruiting', 4543.0),
 ('flowering', 11254.0)]

### Get records from each class

In [None]:
# NOTE(review): presumably holds the coreids already taken by an earlier
# class, so an image carrying several labels is only picked once -- confirm
# against the code that fills it.
used = set()

# `counts` was sorted ascending by class sum, so the smallest classes are
# visited first.
for cls, count in counts:
    sql = f"""select coreid
        from images join angiosperms using (coreid)
        where {cls} = 1"""
    coreids = db.rows_as_dicts(DB, sql)
    # A comprehension filters with `if` (`where` is SQL, not Python). The rows
    # are mappings, which cannot be tested for membership in a set, so compare
    # each row's `coreid` value instead.
    coreids = [i for i in coreids if i['coreid'] not in used]