In [None]:
import os, warnings
import wandb

import pandas as pd
from fastai.vision.all import *
from sklearn.model_selection import StratifiedGroupKFold

import params
# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation notices from the libraries used below.
warnings.filterwarnings('ignore')

In [None]:
# Open a W&B run for this notebook; project and entity come from params.py,
# and the run is tagged with the "data_split" job type.
run = wandb.init(
    project=params.WANDB_PROJECT,
    entity=params.ENTITY,
    job_type="data_split",
)

In [None]:
# Declare the latest raw-data artifact as an input of this run,
# then download it and keep the local directory as a Path.
raw_data_at = run.use_artifact(f'{params.RAW_DATA_AT}:latest')
download_dir = raw_data_at.download()
path = Path(download_dir)

[34m[1mwandb[0m: Downloading large artifact bdd_sample_1k:latest, 823.99MB. 4003 files... Done. 0:0:0.2


In [None]:
# Show the top-level contents of the downloaded artifact directory.
path.ls()

(#5) [Path('artifacts/bdd_sample_1k:v0/images'),Path('artifacts/bdd_sample_1k:v0/labels'),Path('artifacts/bdd_sample_1k:v0/LICENSE.txt'),Path('artifacts/bdd_sample_1k:v0/eda_table.table.json'),Path('artifacts/bdd_sample_1k:v0/media')]

In [None]:
# List the image files in a fixed (sorted) order. os.listdir returns entries in
# arbitrary, filesystem-dependent order, and the fold assignment further down is
# positional, so an unsorted listing makes the split differ between machines.
fnames = sorted(os.listdir(path/'images'))
# Group key = the part of the file name before the first '-'. These keys are
# passed to the grouped splitter so files sharing a key land in the same fold.
groups = [fname.split('-')[0] for fname in fnames]

In [None]:
# Fetch the object stored under the name "eda_table" in the raw-data artifact.
orig_eda_table = raw_data_at.get("eda_table")

[34m[1mwandb[0m: Downloading large artifact bdd_sample_1k:latest, 823.99MB. 4003 files... Done. 0:0:0.1


In [None]:
# Stratification labels, one per entry of `fnames` and in the same order.
# The table's row order is not guaranteed to match the directory listing, so
# each label is looked up by file name (the 'File_Name' column, which is also
# the key the table is joined on later) instead of being paired by position.
# A KeyError here means an image on disk has no row in the EDA table.
train_label_by_file = dict(zip(orig_eda_table.get_column('File_Name'),
                               orig_eda_table.get_column('train')))
y = [train_label_by_file[fname] for fname in fnames]

In [None]:
# One row per image file; 'fold' starts at -1, meaning "not assigned yet".
df = pd.DataFrame({'File_Name': fnames, 'fold': -1})

In [None]:
# Ten folds, stratified on `y` and grouped by `groups`. Only the held-out
# indices of each split are needed: every row is labelled with the number of
# the fold in which it is held out.
cv = StratifiedGroupKFold(n_splits=10)
for fold_id, (_, holdout_idxs) in enumerate(cv.split(fnames, y, groups)):
    df.loc[holdout_idxs, 'fold'] = fold_id

In [None]:
# Everything is training data by default; fold 0 is carved out as the test
# set and fold 1 as the validation set.
df['Stage'] = 'train'
for fold_id, stage_name in ((0, 'test'), (1, 'valid')):
    df.loc[df['fold'] == fold_id, 'Stage'] = stage_name
# The fold number was only needed to derive Stage, so drop it.
del df['fold']
# Display how many files ended up in each stage.
df['Stage'].value_counts()

train    800
test     100
valid    100
Name: Stage, dtype: int64

In [None]:
# Write the File_Name -> Stage assignment to disk, without the index column,
# so it can be attached to the output artifact below.
df.to_csv('data_split.csv', index=False)

In [None]:
# Create the output artifact (type "split_data") that will hold the split.
processed_data_at = wandb.Artifact(params.PROCESSED_DATA_AT, type="split_data")

In [None]:
# Attach the split CSV written above ...
processed_data_at.add_file('data_split.csv')
# ... and the whole downloaded raw-data directory, so the output artifact
# carries the data together with its split.
processed_data_at.add_dir(path)

[34m[1mwandb[0m: Adding directory to artifact (./artifacts/bdd_sample_1k:v0)... Done. 0.8s


In [None]:
# Build a W&B table from the two columns that describe the split.
data_split_table = wandb.Table(dataframe=df[['File_Name', 'Stage']])

In [None]:
# Join the original EDA table with the split table on their shared
# "File_Name" column.
join_table = wandb.JoinedTable(orig_eda_table, data_split_table, "File_Name")

In [None]:
# Store the joined table in the output artifact under the name
# "eda_table_data_split".
processed_data_at.add(join_table, "eda_table_data_split")

<ManifestEntry digest: rjcJpd0b5kx/sxKReMRIqQ==>

In [None]:
# Log the finished artifact to the run, then close the run.
run.log_artifact(processed_data_at)
run.finish()