In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import janitor
import numpy as np
import pandas as pd

from pybbbc import BBBC021

In [None]:
# Build the dataset from every MoA label listed in `BBBC021.MOA` except the
# literal "null" label.
annotated_moas = [label for label in BBBC021.MOA if label != "null"]
bbbc021 = BBBC021(moa=annotated_moas)

In [None]:
# Count the distinct compounds recorded for each non-"null" MoA.
compound_moa_pairs = (
    bbbc021.moa_df[["compound", "moa"]].query('moa != "null"').drop_duplicates()
)
compound_moa_pairs.groupby("moa").count()

# Splitting strategy

* Training and validation sets each get at least one compound per MoA
* Within each MoA, the compound with the most images goes to the training set, the next to the validation set, and the third (if any) to the test set
* The test set will not have a compound for the Eg5 inhibitor or cholesterol-lowering MoAs, since each has only two compounds
* MoAs with 4 compounds will have two sent to the training set (the compounds with the most and the fewest images)
* DMSO images will be split by plate, as closely to the desired ratio as possible

In [None]:
# Distinct plate identifiers, in the order they first appear in the image table.
plate_column = bbbc021.image_df["plate"]
plates = plate_column.unique()
plates

In [None]:
# Number of entries in `plates`, i.e. how many distinct plates there are.
num_plates = len(plates)
num_plates

In [None]:
# Evenly spaced cumulative fractions in [0, 1], one entry per plate, used to
# find the plate index at which each DMSO split ends.  The length is taken
# from `num_plates` instead of a hard-coded count so that `cdf` always lines
# up element-for-element with `plates`.
cdf = np.linspace(0, 1, num_plates)
cdf

In [None]:
# Target shares of plates whose DMSO images go to the training and validation
# splits; the plates left over after both cut-offs form the test split.
dmso_train_frac, dmso_val_frac = 0.5, 0.35

In [None]:
# Positions in `cdf` that are still within the training share, and within the
# combined training + validation share, respectively.
train_positions = np.flatnonzero(cdf <= dmso_train_frac)
train_val_positions = np.flatnonzero(cdf <= dmso_train_frac + dmso_val_frac)

# Exclusive stop index of each split: one past the last qualifying position.
train_stop_idx = train_positions[-1] + 1
val_stop_idx = train_val_positions[-1] + 1

In [None]:
# Carve `plates` into three consecutive, non-overlapping runs:
# [0, train_stop_idx) -> train, [train_stop_idx, val_stop_idx) -> validation,
# and everything from val_stop_idx onwards -> test.
train_plates, val_plates, test_plates = (
    plates[:train_stop_idx],
    plates[train_stop_idx:val_stop_idx],
    plates[val_stop_idx:],
)

train_plates, val_plates, test_plates

In [None]:
def _fetch_dmso_idcs(plate_subset) -> np.ndarray:
    """Return the `image_idx` values of DMSO images located on the given plates.

    Parameters
    ----------
    plate_subset
        Collection of plate identifiers to select from `bbbc021.image_df`.

    Returns
    -------
    np.ndarray
        `image_idx` of every row whose `plate` is in `plate_subset` and whose
        `compound` is "DMSO".
    """
    # `@plate_subset` makes `query` read the local variable of that name.
    dmso_on_plates = bbbc021.image_df.query(
        'plate in @plate_subset and compound == "DMSO"'
    )
    return dmso_on_plates["image_idx"].values


# One call per split replaces three copies of the same query.
train_dmso_idcs = _fetch_dmso_idcs(train_plates)
val_dmso_idcs = _fetch_dmso_idcs(val_plates)
test_dmso_idcs = _fetch_dmso_idcs(test_plates)

In [None]:
# Number of images per (compound, MoA) pair, DMSO excluded.
non_dmso_image_df = bbbc021.image_df.query('compound != "DMSO"')
images_per_compound = non_dmso_image_df.groupby(["compound", "moa"])["site"].count()

# Keep only pairs that actually have images, then order the rows by MoA and,
# within each MoA, from the most-imaged compound to the least-imaged one.
cmpd_im_count_df = images_per_compound.to_frame("num_images")
cmpd_im_count_df = cmpd_im_count_df.query("num_images > 0").reset_index()
cmpd_im_count_df = cmpd_im_count_df.sort_values(
    ["moa", "num_images"], ascending=[True, False]
)

# Assign whole compounds to splits, one MoA at a time, so that a compound
# never appears in more than one split.
train_compounds = []
val_compounds = []
test_compounds = []

for moa, cur_moa_df in cmpd_im_count_df.groupby("moa"):
    # Rows keep the order they have in `cmpd_im_count_df`, so position 0 is
    # the compound that frame lists first for this MoA.
    ranked_compounds = cur_moa_df["compound"].tolist()

    if not ranked_compounds:
        continue

    # Validate before touching any list so a failure leaves no partial state.
    if len(ranked_compounds) < 2:
        raise ValueError(
            f"MoA {moa!r} has only one compound ({ranked_compounds[0]!r}); "
            "at least two are needed to cover both training and validation."
        )

    train_compounds.append(ranked_compounds[0])
    val_compounds.append(ranked_compounds[1])

    # The test split only receives a compound when a third one exists.
    if len(ranked_compounds) > 2:
        test_compounds.append(ranked_compounds[2])

    # The fourth compound, and any further ones, go to training so that no
    # compound is silently left out of every split.
    train_compounds.extend(ranked_compounds[3:])


def fetch_compound_idcs(compounds) -> np.ndarray:
    """Return the `image_idx` values of all images of the given compounds.

    Parameters
    ----------
    compounds
        Collection of compound names to look up in `bbbc021.image_df`.

    Returns
    -------
    np.ndarray
        `image_idx` of every row whose `compound` is contained in `compounds`.
    """
    # `@compounds` makes `query` read the function argument of that name.
    matching_images = bbbc021.image_df.query("compound in @compounds")
    image_idcs = matching_images["image_idx"]
    return image_idcs.values


# Resolve each split's compound list to the image indices it covers.
train_compound_idcs, val_compound_idcs, test_compound_idcs = map(
    fetch_compound_idcs, (train_compounds, val_compounds, test_compounds)
)

In [None]:
# Each split is the union of its compound images and its DMSO images,
# returned as a single array sorted in ascending order.
train_idcs = np.sort(np.concatenate((train_compound_idcs, train_dmso_idcs)))
val_idcs = np.sort(np.concatenate((val_compound_idcs, val_dmso_idcs)))
test_idcs = np.sort(np.concatenate((test_compound_idcs, test_dmso_idcs)))



In [None]:
# Number of images in the training, validation and test splits.
tuple(map(len, (train_idcs, val_idcs, test_idcs)))