In [29]:
from pathlib import Path

import pandas as pd
from sklearn.model_selection import GroupShuffleSplit, GroupKFold

See which files are present in the targets and list them

In [62]:
# Load and inspect the targets.
df_params = pd.read_csv("../data/params.csv")
df_persons = pd.read_csv("../data/sample_to_person.csv")
df = df_params.merge(df_persons, on="sample_id")

df["filename"] = [str(index) + ".tif" for index in df["sample_id"]]
filenames = list(str(fn.name) for fn in Path("../data/stacks").glob("*.tif"))  # Make sure the image is there!
df = df[df["filename"].isin(filenames)]
print(f"there are {len(df)} eligible samples")

there are 51 eligible samples


First, shuffle the dataset and split the dataset in train and test, while making sure no person is in both the training and testing set.

In [34]:
# Split the full dataset in train and test.
gss = GroupShuffleSplit(1, test_size=int(0.05 * len(df)), random_state=42)
for split in gss.split(df["sample_id"], groups=df["person_id"]):
    train, test = split
    train_df = df[df["sample_id"].isin(train)]
    test_df = df[df["sample_id"].isin(test)]
    train_df.to_csv("../data/splits/train.csv", index=False)
    test_df.to_csv("../data/splits/test.csv", index=False)

Make cross validation splits, again making sure no person is in both the training and validation sets.

In [35]:
train_path = "../data/splits/train.csv"
# Split the train dataset into train and validation for cross-validation.
df = pd.read_csv(train_path)
gkf = GroupKFold(3)
for fold, split in enumerate(gkf.split(df["sample_id"], groups=df["person_id"])):
    train, val = split
    train_df = df[df["sample_id"].isin(train)]
    val_df = df[df["sample_id"].isin(val)]
    train_df.to_csv(f"../data/splits/fold-{fold}-split-train.csv", index=False)
    val_df.to_csv(f"../data/splits/fold-{fold}-split-val.csv", index=False)

In [6]:
# TODO: as in 10-sj-preprocess.ipynb, make plots of the distribution of the samples into folds.

In [63]:
df.columns = map(str.lower, df.columns)
df[map(str.lower, ["A", "k", "xc"])].to_numpy()

array([[ 3.30971613, 28.96359874,  1.20185066],
       [ 5.88107624, 24.94981488,  1.17793236],
       [ 3.66373522, 21.9629815 ,  1.26584654],
       [ 2.42745287, 17.7658721 ,  1.28946096],
       [ 3.14678712, 46.67170878,  1.40218732],
       [ 4.10505452, 29.31514811,  1.32206248],
       [ 3.48449005, 24.49080838,  1.2175642 ],
       [ 3.30196127, 23.38083588,  1.23564071],
       [ 2.95022377, 21.6432568 ,  1.22797869],
       [ 3.32131185, 20.25453183,  1.24279591],
       [ 1.5499594 , 19.87204168,  1.30100379],
       [ 2.56055668, 30.44839372,  1.31593115],
       [ 4.73983948, 21.27736641,  1.23834535],
       [ 3.5970366 , 20.82117566,  1.2267107 ],
       [ 1.7854561 , 21.7919852 ,  1.23368762],
       [ 3.79041797, 22.27355876,  1.27490985],
       [ 0.77196126, 31.14547501,  1.26364692],
       [ 1.22452766, 30.05749807,  1.1664169 ],
       [ 0.86433824, 24.69739825,  1.1741678 ],
       [ 1.95450038, 27.00940134,  1.18021803],
       [ 1.3003575 , 28.84285736,  1.197

array([[ 3.30971613, 28.96359874,  1.20185066],
       [ 5.88107624, 24.94981488,  1.17793236],
       [ 4.10505452, 29.31514811,  1.32206248],
       [ 3.48449005, 24.49080838,  1.2175642 ],
       [ 3.30196127, 23.38083588,  1.23564071],
       [ 2.95022377, 21.6432568 ,  1.22797869],
       [ 3.32131185, 20.25453183,  1.24279591],
       [ 1.5499594 , 19.87204168,  1.30100379],
       [ 2.56055668, 30.44839372,  1.31593115],
       [ 4.73983948, 21.27736641,  1.23834535],
       [ 3.5970366 , 20.82117566,  1.2267107 ],
       [ 1.7854561 , 21.7919852 ,  1.23368762],
       [ 3.79041797, 22.27355876,  1.27490985],
       [ 0.77196126, 31.14547501,  1.26364692],
       [ 3.81469879, 22.90397478,  1.21774181],
       [ 3.48598793, 39.7891409 ,  1.17091117],
       [10.05513089, 16.92341643,  1.36711246],
       [ 6.35045975, 16.47766992,  1.26356157],
       [ 4.38591213, 20.85055381,  1.19793373],
       [ 7.16358037, 19.62378149,  1.26913023],
       [ 3.98169997, 21.01658934,  1.199

In [66]:
gss = GroupShuffleSplit(1, test_size=int(0.05 * len(df)), random_state=42)
for split in gss.split(df["sample_id"], groups=df["person_id"]):
    super_train, test = split
    super_train_df = df[df["sample_id"].isin(super_train)]
    test_df = df[df["sample_id"].isin(test)]

gss = GroupShuffleSplit(1, test_size=int(0.1 * len(super_train_df)), random_state=42)
for split in gss.split(super_train_df["sample_id"], groups=super_train_df["person_id"]):
    train, val = split
    train_df = df[df["sample_id"].isin(train)]
    val_df = df[df["sample_id"].isin(val)]

In [72]:
test_df[["a", "k", "xc"]]

Unnamed: 0,a,k,xc
2,3.663735,21.962982,1.265847
3,2.427453,17.765872,1.289461
4,3.146787,46.671709,1.402187
17,1.224528,30.057498,1.166417
18,0.864338,24.697398,1.174168
19,1.9545,27.009401,1.180218
20,1.300357,28.842857,1.197153
21,0.414474,32.777148,1.161752
22,0.47335,29.272698,1.179033
23,0.553763,81.788814,1.264695
