# Split the final CCT into separate CCT files, one per train/val/test split

In [1]:
import json
import os

import pandas as pd
from tqdm import tqdm

In [2]:
# Experiment name: selects both the interim input directory and the processed output directory.
experiment = 'animl'

# Root of the classification data tree. os.path.expanduser('~') returns $HOME on POSIX
# (same result as before) but also works where HOME is unset, whereas
# os.environ['HOME'] raises KeyError in that case.
data_root = os.path.join(os.path.expanduser('~'), 'animl-ml', 'classification', 'data')
interim_dir = os.path.join(data_root, 'interim', experiment)
output_path = os.path.join(data_root, 'processed', experiment)

# Combined CCT (images / annotations / categories / info) for the whole experiment.
data_file = os.path.join(interim_dir, 'classification_cct.json')
with open(data_file, 'r') as f:
    js = json.load(f)

# Split assignments: split name -> list of [dataset, location] pairs (as unpacked below).
splits_file = os.path.join(interim_dir, 'splits.json')
with open(splits_file, 'r') as f:
    splits = json.load(f)

In [3]:
# Split names to write, in output order. 'train', 'val' and 'test' are always written
# (possibly empty, as before); any other split names present in splits.json are appended
# so their images are kept instead of raising a KeyError.
split_names = list(dict.fromkeys(['train', 'val', 'test', *splits]))
imgs = {split: [] for split in split_names}
annos = {split: [] for split in split_names}

# Map each location to its split. Entries in splits.json are [dataset, location] pairs,
# but images are matched on location alone, so a location name listed under two
# different splits would be silently reassigned (later entry wins) -- warn about it.
location_to_split = {}
for split, loc_pairs in splits.items():
    for dataset, location in loc_pairs:
        previous_split = location_to_split.get(location)
        if previous_split is not None and previous_split != split:
            print(f'Warning: location {location} (dataset {dataset}) is listed in both '
                  f'{previous_split} and {split}; assigning it to {split}')
        location_to_split[location] = split

# Assign every image to the split of its location. Images from locations that are not in
# any split are skipped and counted per location, so each location is reported once
# instead of once per image.
image_to_split = {}
skipped_images_per_location = {}
for img in js['images']:
    location = img['location']
    if location not in location_to_split:
        skipped_images_per_location[location] = skipped_images_per_location.get(location, 0) + 1
        continue
    assigned_split = location_to_split[location]
    imgs[assigned_split].append(img)
    image_to_split[img['id']] = assigned_split

for location, n_skipped in skipped_images_per_location.items():
    print(f"Couldn't find location {location} in any splits ({n_skipped} images skipped). "
          'All of the samples from this location may not have been present in enough locations '
          'to pass the --min-locs filter during create_classifications_dataset.py')

# Route each annotation to its image's split. Annotations on images skipped above are an
# expected consequence of the skipped locations and are only counted; annotations whose
# image_id does not exist in the CCT at all are reported individually.
known_image_ids = {img['id'] for img in js['images']}
n_dropped_annos = 0
for anno in js['annotations']:
    assigned_split = image_to_split.get(anno['image_id'])
    if assigned_split is not None:
        annos[assigned_split].append(anno)
    elif anno['image_id'] in known_image_ids:
        n_dropped_annos += 1
    else:
        print(f"Couldn't find image for image_id {anno['image_id']}")
if n_dropped_annos:
    print(f'Dropped {n_dropped_annos} annotations whose images are in locations '
          'not assigned to any split')

# Write one CCT per split; categories and info are copied unchanged from the source CCT.
os.makedirs(output_path, exist_ok=True)  # the processed dir may not exist on a fresh checkout
for split in split_names:
    print(f'CCT built, saving {split} as {split}_cct.json')
    new_cct = {
        'images': imgs[split],
        'annotations': annos[split],
        'categories': js['categories'],
        'info': js['info'],
    }
    out_file = os.path.join(output_path, f'{split}_cct.json')
    with open(out_file, 'w') as f:
        json.dump(new_cct, f)

CCT built, saving train as train_cct.json
CCT built, saving val as val_cct.json
CCT built, saving test as test_cct.json
