# Split image dataset

Split the images into 3 datasets:
1. A training dataset (~60%).
1. A validation dataset (~20%). This is used to score the model as training progresses.
1. A test dataset (~20%). This is a holdout dataset used only for a final score of the trained model.
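
As a rough illustration of the 60/20/20 proportions, here is a minimal sketch that shuffles a list of IDs and cuts it into three parts. The helper `split_ids` and its parameters are hypothetical and are not part of the `herbarium` package; the notebook itself does the split per order and class further down.

```python
import random


def split_ids(ids, test_frac=0.2, val_frac=0.2, seed=0):
    """Shuffle ids and cut them into test/val/train lists (hypothetical helper)."""
    ids = list(ids)
    random.Random(seed).shuffle(ids)  # seeded so the split is reproducible

    n_test = round(len(ids) * test_frac)
    n_val = round(len(ids) * val_frac)

    return {
        'test': ids[:n_test],
        'val': ids[n_test:n_test + n_val],
        'train': ids[n_test + n_val:],  # whatever is left, roughly 60%
    }


splits = split_ids(range(100))
sizes = {name: len(members) for name, members in splits.items()}
```

Slicing with explicit counts keeps the three parts disjoint by construction, so no ID can land in two datasets.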

Within each dataset there are 4 classes, and an image may belong to more than one of them:
1. flowering
1. not flowering
1. fruiting
1. not fruiting

We currently use only one trait at a time, so a run uses one of these pairs of classes:
- flowering and not_flowering
- fruiting and not_fruiting

In [1]:
import sys

sys.path.append('..')

In [2]:
import shutil
import sqlite3
from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm

from herbarium import db
from herbarium.pylib.datasets.herbarium_dataset import HerbariumDataset

In [3]:
SPLIT_RUN = 'fruits_all_orders'

In [4]:
# CLASSES = 'flowering not_flowering fruiting not_fruiting'.split()
# CLASSES = 'flowering not_flowering'.split()
CLASSES = 'fruiting not_fruiting'.split()

In [5]:
DATA_DIR = Path('..') / 'data'

TEMP_DIR = DATA_DIR / 'temp'
IMAGE_DIR = DATA_DIR / 'images'

DB = DATA_DIR / 'angiosperms.sqlite'

## Build training, validation, and test splits

The classes are woefully unbalanced, so I'm going to weight the losses per class. I'm also going to make sure that the test and validation splits contain a representative share of every class. Also note that any image may belong to multiple classes.
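
To make the idea of per-class loss weights concrete, here is a minimal sketch of the "balanced" heuristic, `n_samples / (n_classes * class_count)`, which is the formula scikit-learn's `compute_class_weight` uses when `class_weight='balanced'`. The notebook imports that function but does not show how it is called, so the helper `balanced_class_weights` and the label list below are assumptions for illustration only. The sketch also assumes each image carries exactly one label, which ignores the overlap between classes.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class inversely to its frequency (hypothetical helper)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * n) for cls, n in counts.items()}


# Illustrative labels only: a 482 / 179 imbalance between the two classes
labels = ['fruiting'] * 482 + ['not_fruiting'] * 179
weights = balanced_class_weights(labels)

# The rarer class gets the larger weight, so its errors count for more
rare_class = max(weights, key=weights.get)
```

With these weights every class contributes the same total weight to the loss, regardless of how many images it has.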

I'm saving the splits so that I don't wind up training on my test data.
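
One way to persist splits so that an image can never end up in two of them is sketched below with the standard-library `sqlite3` module. The table layout is an assumption: the column names mirror the `split_set`, `coreid`, and `split` keys used in the records later in this notebook, but the actual schema created by `db.create_splits_table` is not shown here.

```python
import sqlite3

# Hypothetical schema: one row per image per split run
conn = sqlite3.connect(':memory:')
conn.execute("""
    create table splits (
        split_set text,
        coreid    text,
        split     text,
        primary key (split_set, coreid)
    )
""")

recs = [
    {'split_set': 'demo_run', 'coreid': 'a1', 'split': 'train'},
    {'split_set': 'demo_run', 'coreid': 'b2', 'split': 'val'},
    {'split_set': 'demo_run', 'coreid': 'c3', 'split': 'test'},
]
conn.executemany(
    'insert into splits values (:split_set, :coreid, :split)', recs)

# The primary key rejects filing a test image into train for the same run
duplicate_rejected = False
try:
    conn.execute("insert into splits values ('demo_run', 'c3', 'train')")
except sqlite3.IntegrityError:
    duplicate_rejected = True

# Later training code reads back only the rows for its own split
train_ids = [row[0] for row in conn.execute(
    'select coreid from splits where split_set = ? and split = ?',
    ['demo_run', 'train'])]
```

Keying on the run name and the image ID lets several split runs coexist in one table while each run stays free of train/test leakage.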

### Get count of image classes

In [6]:
sql = """
    select order_, count(*) as n,
           sum(flowering) as flowering,
           sum(not_flowering) as not_flowering, 
           sum(fruiting) as fruiting,
           sum(not_fruiting) as not_fruiting
      from images
      join angiosperms using (coreid)
  group by order_
  order by n desc
"""

In [7]:
counts = db.rows_as_dicts(DB, sql)

order_df = pd.DataFrame(counts)
order_df.head()

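|   | order_ | n | flowering | not_flowering | fruiting | not_fruiting |
|---|--------|---|-----------|---------------|----------|--------------|
| 0 | asterales | 2721 | 2469.0 | 102.0 | 482.0 | 179.0 |
| 1 | caryophyllales | 1946 | 1544.0 | 80.0 | 591.0 | 84.0 |
| 2 | poales | 1708 | 870.0 | 537.0 | 737.0 | 337.0 |
| 3 | lamiales | 1091 | 951.0 | 21.0 | 210.0 | 45.0 |
| 4 | fabales | 968 | 797.0 | 3.0 | 377.0 | 20.0 |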


In [8]:
# db.create_splits_table(DB)

### Limit data to two orders

In [9]:
# order_df = order_df.loc[order_df['order_'].isin(['asterales', 'fabales']), :]

### Get records for each order/trait combination

In [10]:
used = set()

for _, row in tqdm(order_df.iterrows()):
    order = row.order_

    for cls in CLASSES:

        sql = f"""
           select coreid
             from images join angiosperms using (coreid)
            where order_ = ? and {cls} = 1
         order by random()"""

        rows = db.rows_as_dicts(DB, sql, [order])

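        # Dedupe while keeping the random order from the query, and skip
        # images already assigned to a split by an earlier class
        coreids = [c for c in dict.fromkeys(r['coreid'] for r in rows)
                   if c not in used]
        used.update(coreids)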

        recs = [{'split_set': SPLIT_RUN, 'coreid': i} for i in coreids]

        count = len(coreids)

        # Cumulative cut points: first 20% -> test, next 20% -> val, rest -> train
        test_split = round(count * 0.2)
        val_split = round(count * 0.4)

        for i in range(count):
            if i < test_split:
                split = 'test'
            elif i < val_split:
                split = 'val'
            else:
                split = 'train'

            recs[i]['split'] = split

        db.insert_splits(DB, recs)

52it [00:38,  1.34it/s]
