# Sample images

Get 100 randomly sampled images from each of the following 4 classes:
1. flowering
1. not flowering
1. fruiting
1. not fruiting

In [1]:
import sys

sys.path.append('..')

In [15]:
import shutil
import sqlite3
from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

from herbarium.pylib import db
from herbarium.pylib.herbarium_dataset import HerbariumDataset

In [3]:
# Every path below hangs off the data directory one level above the
# notebook's working directory.
DATA_DIR = Path('..', 'data')

# Output location for the QC samples, and the image directory.
TEMP_DIR = DATA_DIR.joinpath('temp')
IMAGE_DIR = DATA_DIR.joinpath('images')

# SQLite database queried throughout this notebook.
DB = DATA_DIR.joinpath('angiosperms.sqlite')

## Build training, validation, and test splits

The classes are woefully imbalanced, so I'm going to weight the losses per class. I'm also going to make sure that the test and validation splits contain a representative number of images from every class. Also note that any image may belong to multiple classes.

I'm saving the splits so that I don't wind up training on my test data.

### Get count of image classes

In [4]:
# Summary query over the images joined to their angiosperm records:
# `n` is the number of joined rows, and each remaining column is the sum of
# one class column over those rows.
sql = """
    select count(*) as n,
           sum(flowering) as flowering,
           sum(not_flowering) as not_flowering, 
           sum(fruiting) as fruiting,
           sum(not_fruiting) as not_fruiting
      from images
      join angiosperms using (coreid)
"""

In [5]:
# First row of the summary query: the row count plus one sum per class.
counts = db.rows_as_dicts(DB, sql)[0]
total = counts['n']

# Drop the overall row count and keep the per-class (name, count) pairs,
# ordered from the rarest class to the most common one.
counts = sorted(
    (pair for pair in counts.items() if pair[0] != 'n'),
    key=lambda pair: pair[1],
)

print(f'{total=}')
counts

total=15123


[('not_flowering', 1370.0),
 ('not_fruiting', 1434.0),
 ('fruiting', 4543.0),
 ('flowering', 11254.0)]

In [9]:
# Label identifying one set of train/val/test assignments; the weight
# computation below filters the splits table on it.
split_run = 'first_split'
# NOTE(review): presumably a one-time setup call that is left disabled so the
# table is not recreated — confirm before re-enabling.
# db.create_split_table(DB)

### Get records from each class

In [10]:
# NOTE(review): this whole cell is disabled. Presumably it was run once to
# populate the splits table and then commented out so the saved splits are
# not regenerated — confirm before re-enabling.
#
# As written it walks the classes in `counts` order, assigns each coreid to
# the first class it shows up under (tracked in `used`), and divides each
# class's ids into train/val/test records tagged with `split_run`.
# `train_test_split` is called without `random_state`, so a re-run would not
# reproduce the same assignment.

# used = set()

# for cls, count in counts:
#     sql = f"""select coreid
#         from images join angiosperms using (coreid)
#         where {cls} = 1"""
#     rows = db.rows_as_dicts(DB, sql)

#     coreids = {row['coreid'] for row in rows} - used
#     used |= coreids

#     coreids = list(coreids)

#     train_ids, test_ids = train_test_split(coreids)
#     train_ids, val_ids = train_test_split(train_ids)

#     train_recs = [{
#         'split_run': split_run,
#         "split": "train",
#         "coreid": i
#     } for i in train_ids]
#     db.insert_splits(DB, train_recs)

#     val_recs = [{
#         'split_run': split_run,
#         "split": "val",
#         "coreid": i
#     } for i in val_ids]
#     db.insert_splits(DB, val_recs)

#     test_recs = [{
#         'split_run': split_run,
#         "split": "test",
#         "coreid": i
#     } for i in test_ids]
#     db.insert_splits(DB, test_recs)

## Compute weights

In [19]:
# Class names; each one is also a column summed by the query below.
classes = 'flowering not_flowering fruiting not_fruiting'.split()

# Row count and per-class column sums for one split of one split run.
sql = """
    select count(*) as n,
           sum(flowering) as flowering,
           sum(not_flowering) as not_flowering, 
           sum(fruiting) as fruiting,
           sum(not_fruiting) as not_fruiting
      from splits
      join angiosperms using (coreid)
     where split_run = ?
       and split = ?
"""
train_counts = db.rows_as_dicts(DB, sql, params=[split_run, 'train'])[0]
total = train_counts['n']

# Print, in `classes` order, the ratio of training rows without the class
# to training rows with it.
for name in classes:
    positives = train_counts[name]
    negatives = total - positives
    pos_weight = negatives / positives
    print(pos_weight)

0.3096855743403381
9.764935064935065
2.265957446808511
9.083941605839415


## Pick 100 records from each class for QC

In [4]:
def get_image_class(cls, limit=100):
    """Save a random sample of the images flagged with one class, for QC.

    The sampled rows are written to ``TEMP_DIR/<cls>.csv`` and every sampled
    image file is copied into ``TEMP_DIR/<cls>/``. Each call draws a new
    random sample.

    Parameters
    ----------
    cls : str
        Name of the class column to filter on; rows where it equals 1 are
        eligible. The name is interpolated into the SQL text as an
        identifier, so it must be a trusted column name, never user input.
    limit : int, default 100
        Maximum number of rows to sample.

    Returns
    -------
    None
    """
    sql = f"""
        select *
          from angiosperms
          join images using (coreid)
         where {cls} = 1
      order by random()
         limit ?
    """

    # `with sqlite3.connect(...)` only commits or rolls back the transaction;
    # it does not close the connection, so close it explicitly.
    cxn = sqlite3.connect(DB)
    try:
        df = pd.read_sql(sql, cxn, params=[limit])
    finally:
        cxn.close()

    # Create the output directory before writing anything: parents=True also
    # creates TEMP_DIR itself, which the CSV below needs to exist.
    dir_ = TEMP_DIR / f'{cls}'
    dir_.mkdir(parents=True, exist_ok=True)

    df.to_csv(TEMP_DIR / f'{cls}.csv', index=False)

    # The stored paths are resolved against the parent directory.
    for rel_path in df['path']:
        shutil.copy(Path('..') / rel_path, dir_ / Path(rel_path).name)

In [5]:
# Uncomment to (re)generate the QC samples. Each call draws a new random
# sample, so it overwrites TEMP_DIR/<cls>.csv and copies the newly sampled
# images into TEMP_DIR/<cls>/ alongside any that are already there.
# get_image_class('flowering')
# get_image_class('not_flowering')
# get_image_class('fruiting')
# get_image_class('not_fruiting')