In [1]:
import os
from fastai.vision.all import *
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold
import wandb
import warnings
from config import WANDB_PROJECT, ENTITY, RAW_DATA_AT
warnings.filterwarnings('ignore')

In [2]:
# Open a W&B run for the data-splitting step; project and entity come from config.py.
run = wandb.init(
    project=WANDB_PROJECT,
    entity=ENTITY,
    job_type="data_split",
)

[34m[1mwandb[0m: Currently logged in as: [33mdarek[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Reference the latest version of the raw-data artifact from this run,
# download its files, and keep the local directory as a Path.
raw_data_at = run.use_artifact(f'{RAW_DATA_AT}:latest')
artifact_dir = raw_data_at.download()
path = Path(artifact_dir)

[34m[1mwandb[0m: Downloading large artifact bdd_sample_1k:latest, 856.80MB. 4003 files... 
[34m[1mwandb[0m:   4003 of 4003 files downloaded.  
Done. 0:0:1.0


In [4]:
# Show what the downloaded artifact directory contains.
path.ls()

(#5) [Path('artifacts/bdd_sample_1k:v1/media'),Path('artifacts/bdd_sample_1k:v1/eda_table.table.json'),Path('artifacts/bdd_sample_1k:v1/labels'),Path('artifacts/bdd_sample_1k:v1/LICENSE.txt'),Path('artifacts/bdd_sample_1k:v1/images')]

In [5]:
# Image file names found in the artifact's images/ folder.
images_dir = path/'images'
fnames = os.listdir(images_dir)
# Grouping key for the split: the part of each file name before its first '-'
# (the whole name when there is no '-').
groups = [fname.partition('-')[0] for fname in fnames]

In [6]:
# Fetch the object stored under the name "eda_table" in the raw-data artifact.
orig_eda_table = raw_data_at.get("eda_table")

[34m[1mwandb[0m: Downloading large artifact bdd_sample_1k:latest, 856.80MB. 4003 files... 
[34m[1mwandb[0m:   4003 of 4003 files downloaded.  
Done. 0:0:0.3


In [7]:
# Stratification target: the 'train' column of the EDA table.
# `fnames` comes from os.listdir(), whose ordering is arbitrary, while the
# table has its own row order. Taking the column positionally could pair a
# file with another file's label, so each label is looked up by File_Name
# (the same key the tables are joined on later in this notebook).
label_by_file = dict(zip(orig_eda_table.get_column('File_Name'),
                         orig_eda_table.get_column('train')))
unlabeled = [fname for fname in fnames if fname not in label_by_file]
# Fail loudly rather than silently stratifying on misaligned labels.
assert not unlabeled, (
    f"{len(unlabeled)} image file(s) have no row in eda_table, e.g. {unlabeled[:3]}"
)
# y[k] is now guaranteed to be the label of fnames[k] (and groups[k]).
y = [label_by_file[fname] for fname in fnames]

In [8]:
# One row per image file; 'fold' starts at -1, meaning "not assigned yet".
df = pd.DataFrame({'File_Name': fnames, 'fold': -1})

In [9]:
# Split into 10 folds, stratified on y and grouped by `groups`.
# Each row is tagged with the index of the fold in which it is held out.
cv = StratifiedGroupKFold(n_splits=10)
fold_splits = cv.split(fnames, y, groups)
for fold_id, (_, holdout_idxs) in enumerate(fold_splits):
    df.loc[holdout_idxs, 'fold'] = fold_id

In [10]:
# Fold 0 becomes the test set and fold 1 the validation set; every other fold trains.
stage_by_fold = {0: 'test', 1: 'valid'}
df['Stage'] = df['fold'].map(stage_by_fold).fillna('train')
# The fold id was only needed to derive Stage, so it is dropped before export.
df = df.drop(columns=['fold'])
df['Stage'].value_counts()

train    800
valid    100
test     100
Name: Stage, dtype: int64

In [11]:
# Write the split assignment to data_split.csv, omitting the index column.
df.to_csv('data_split.csv', index=False)

In [12]:
# New artifact named "bdd_sample_1k_split" of type "split_data"; it is filled in the following cells.
processed_data_at = wandb.Artifact("bdd_sample_1k_split", type="split_data")

In [13]:
# Attach the split CSV written above, then the whole downloaded raw-data directory.
processed_data_at.add_file('data_split.csv')
processed_data_at.add_dir(path)

[34m[1mwandb[0m: Adding directory to artifact (./artifacts/bdd_sample_1k:v1)... Done. 0.6s


In [14]:
# Only the file name and its assigned stage go into the W&B table.
split_columns = ['File_Name', 'Stage']
data_split_table = wandb.Table(dataframe=df[split_columns])

In [15]:
# Join the original EDA table with the split table, using "File_Name" as the join key.
join_table = wandb.JoinedTable(orig_eda_table, data_split_table, "File_Name")

In [16]:
# Store the joined table in the artifact under the name "eda_table_data_split".
processed_data_at.add(join_table, "eda_table_data_split")

<ManifestEntry digest: Vt6JPzWbyCwbq+h1eHGn7Q==>

In [17]:
# Log the assembled artifact to the run, then finish the run.
run.log_artifact(processed_data_at)
run.finish()