In [1]:
# Reload edited modules automatically before each cell is executed, so changes to
# the imported project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import copy
from glob import glob
import json
from pathlib import Path

from multicamera_labelling_and_training.labelling_set_to_training_set import *
import numpy as np
import yaml

## JARVIS --> COCO

Convert JARVIS-format annotations into a COCO-format keypoint dataset. See the COCO data format spec here: https://cocodataset.org/#format-data
 

In [3]:
# Directory the training set is saved to.
output_directory = Path("/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/COCO_format")

# Name of the dataset.
trainingset_name = "JP_CW_scale_annos"

# One entry per item found in the JARVIS_format directory.
# glob() returns its matches in arbitrary order, so they are sorted here: the image
# and annotation ids assigned later depend on the order in which the datasets are
# visited, and sorting keeps that order the same from run to run.
dataset_locs = [
    {"location": loc, "use_in_validation": True}
    for loc in sorted(
        glob("/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/JARVIS_format/*")
    )
]

In [4]:
# The ordering below has to agree with the order used in mmpose and with the JARVIS
# info.yaml file. The mmpose order currently comes from the airflow pipeline,
# e.g. .../tim_240731/skeletons/weinreb15.py

# Gimbal-compatible order for these keypoints.
keypoints_order = [
    # spine and tail
    "spine_low", "tail_base", "tail_tip", "spine_mid", "spine_high",
    # head
    "left_ear", "right_ear", "forehead", "nose_tip",
    # fore paws
    "left_fore_paw", "right_fore_paw",
    # hind paws
    "left_hind_paw_back", "left_hind_paw_front",
    "right_hind_paw_back", "right_hind_paw_front",
]

# Each entry names the two keypoints joined by one skeleton edge.
skeleton = [
    # tail and spine
    ["tail_tip", "tail_base"], ["tail_base", "spine_low"],
    ["spine_low", "spine_mid"], ["spine_mid", "spine_high"],
    # head
    ["spine_high", "left_ear"], ["spine_high", "right_ear"],
    ["spine_high", "forehead"], ["forehead", "nose_tip"],
    # hind limbs
    ["left_hind_paw_back", "left_hind_paw_front"], ["spine_low", "left_hind_paw_back"],
    ["right_hind_paw_back", "right_hind_paw_front"], ["spine_low", "right_hind_paw_back"],
    # fore limbs
    ["spine_high", "left_fore_paw"], ["spine_high", "right_fore_paw"],
]

In [5]:
# No keypoints are excluded from this training set.
keypoints_to_ignore = []

# Build the base dataset dictionary; the helper also returns the keypoint count,
# which is stored with every annotation further down.
dataset_dict, n_keypoints = initialize_jarvis_trainingset(
    output_directory, keypoints_to_ignore, keypoints_order, dataset_locs
)

# Display the freshly initialised dictionary for inspection.
dataset_dict

{'info': {'contributor': '',
  'date_created': '2024-11-01',
  'description': '',
  'url': '',
  'versoin': '1.0',
  'year': 2024},
 'keypoint_names': ['spine_low',
  'tail_base',
  'tail_tip',
  'spine_mid',
  'spine_high',
  'left_ear',
  'right_ear',
  'forehead',
  'nose_tip',
  'left_fore_paw',
  'right_fore_paw',
  'left_hind_paw_back',
  'left_hind_paw_front',
  'right_hind_paw_back',
  'right_hind_paw_front'],
 'skeleton': [{'keypointA': 'tail_tip',
   'keypointB': 'tail_base',
   'length': 0,
   'name': 'Joint 1'},
  {'keypointA': 'tail_base',
   'keypointB': 'spine_low',
   'length': 0,
   'name': 'Joint 2'},
  {'keypointA': 'spine_low',
   'keypointB': 'spine_mid',
   'length': 0,
   'name': 'Joint 3'},
  {'keypointA': 'spine_mid',
   'keypointB': 'spine_high',
   'length': 0,
   'name': 'Joint 4'},
  {'keypointA': 'spine_high',
   'keypointB': 'left_ear',
   'length': 0,
   'name': 'Joint 5'},
  {'keypointA': 'spine_high',
   'keypointB': 'right_ear',
   'length': 0,
   'na

In [6]:
# Two independent deep copies of the base dataset dictionary.
# NOTE(review): the split cell further down rebuilds `train_dict` and `val_dict`
# from `dataset_dict`, so `training_dict` appears to be unused — confirm before
# relying on it.
training_dict, val_dict = (copy.deepcopy(dataset_dict) for _ in range(2))

In [7]:
# Running COCO ids; both counters start from zero.
image_id = annotation_id = 0

# Fraction of the images held out for validation.
validation_fraction = 0.1

# Margin added on every side of the labelled keypoints when building a bounding box.
padding = 60

# Annotation files already added, so the conversion can be restarted after an
# error without repeating them.
completed_anno_files = []

In [8]:
# Convert every JARVIS annotation file into COCO image / annotation records.
#
# Records are collected per annotation file and only added to `dataset_dict` once
# the whole file has been converted. If an error interrupts the cell, re-running it
# therefore reprocesses the unfinished file from scratch instead of duplicating the
# rows that had already been added. Ids skipped this way are simply left unused.
for ds in dataset_locs:

    # Load the dataset yaml (prob unnecessary, should match across all data in a given training set)
    dataset_loc = Path(ds["location"])
    yaml_files = sorted(dataset_loc.glob("*.yaml"))
    if not yaml_files:
        print(f"Skipping {dataset_loc} because it contains no yaml file")
        continue
    dataset_yaml = yaml_files[0]
    # NOTE(review): `use_in_validation` and `dataset_info` are read here but not used
    # anywhere in this loop.
    use_in_validation = ds["use_in_validation"]
    with open(dataset_yaml, "r") as file:
        dataset_info = yaml.safe_load(file)

    # Go through each JARVIS annotation file (sorted so ids are assigned in a stable order)
    annotations_files = sorted(dataset_yaml.parent.glob("**/annotations.csv"))
    for annotations_file in annotations_files:
        print(annotations_file)

        # Skip if this file has already been processed here (checked before loading,
        # so a restart does not re-read files that are already done)
        if annotations_file in completed_anno_files:
            print(f"Skipping {annotations_file} because it has already been processed")
            continue

        # Try to load the jarvis annotations
        try:
            annotations = load_jarvis_annotations(annotations_file)
        except Exception as e:
            print(f"Could not load annotations file: {e}")
            continue

        # Skip if a given keypoint is missing (sometimes happens if it is never annotated in that camera in that session)
        if len(annotations.columns.get_level_values("bodypart").unique()) != len(keypoints_order):
            print(f"Skipping {annotations_file} because it has the wrong number of keypoints")
            continue

        # Records produced by this file; committed to dataset_dict only at the end
        file_images = []
        file_annotations = []

        # Create the image and annotation dictionaries
        for idx, row in annotations.iterrows():

            # Get metadata about the image for COCO
            image_loc = annotations_file.parent / row.name
            width, height = get_image_size(image_loc.as_posix())
            file_name = image_loc.name

            # An image is recorded once, and only if at least one entity in it
            # yields an annotation. Recording it once per entity would put the
            # same image id into the dataset several times.
            image_recorded = False

            # Loop over each entity (each mouse) in the image
            # TODO: update this so that we can have multiple entity types (e.g. mouse vs pup)
            entity_labels = row.index.get_level_values("entities")
            for entity in np.unique(entity_labels):
                entity_mask = entity_labels == entity

                # Get keypoints in format [x, y, state]
                keypoints = [
                    get_row_keypoints(row[entity_mask], keypoint)
                    for keypoint in dataset_dict["keypoint_names"]
                ]
                keypoints_list = list(np.concatenate(keypoints))

                # Create a bounding box from the non-zero coordinates
                # (presumably zeros mark unlabelled keypoints — confirm against get_row_keypoints)
                keypoint_array = np.stack(keypoints)
                xvals = keypoint_array[:, 0]
                yvals = keypoint_array[:, 1]
                xvals = xvals[xvals != 0]
                yvals = yvals[yvals != 0]
                if len(xvals) == 0 or len(yvals) == 0:
                    # Nothing to build a box from; np.min/np.max would fail on an empty array
                    continue

                # Pad the keypoint extent and clip it to the image
                xmin = max([0, np.min(xvals) - padding])
                ymin = max([0, np.min(yvals) - padding])
                xmax = min([width, np.max(xvals) + padding])
                ymax = min([height, np.max(yvals) + padding])

                area = (xmax - xmin) * (ymax - ymin)
                # bbox is stored as [x, y, width, height]
                bbox = [
                    float(xmin),
                    float(ymin),
                    float(xmax - xmin),
                    float(ymax - ymin),
                ]

                # Save image metadata (once per image)
                if not image_recorded:
                    file_images.append(
                        {
                            "coco_url": "",
                            "date_captured": "",
                            "file_name": file_name,
                            "file_path": image_loc.as_posix(),
                            "flickr_url": "",
                            "height": height,
                            "id": image_id,
                            "license": 1,
                            "width": width,
                        }
                    )
                    image_recorded = True

                # Save the annotation data for COCO
                file_annotations.append(
                    {
                        "area": area,
                        "bbox": bbox,
                        "category_id": 0,
                        "id": annotation_id,
                        "image_id": image_id,
                        "iscrowd": 0,
                        "keypoints": keypoints_list,
                        "num_keypoints": n_keypoints,
                        "segmentation": [],
                    }
                )
                annotation_id += 1
            image_id += 1

        # The whole file converted cleanly: commit its records and mark it done
        dataset_dict["images"].extend(file_images)
        dataset_dict["annotations"].extend(file_annotations)
        completed_anno_files.append(annotations_file)

# Ids of all recorded images, used by the uniqueness check in the next cell.
# (The name is historical: these are image ids, not annotation ids.)
ann_ids = [a['id'] for a in dataset_dict["images"]]

/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/JARVIS_format/20230903_J01601/bottom/annotations.csv
/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/JARVIS_format/20230903_J01601/side1/annotations.csv
/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/JARVIS_format/20230903_J01601/side2/annotations.csv
/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/JARVIS_format/20230903_J01601/side3/annotations.csv
/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/JARVIS_format/20230903_J01601/side4/annotations.csv
/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/JARVIS_format/20230903_J01601/top/annotations.csv
/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/JARVIS_format/20230903_J01602/bottom/annotations.csv
/n/groups/datta/6cam_keypoint_networks/training_data/JP_CW_scale_annos/JARVIS_format/20230903_J01602/top/annotations.csv
/n/groups/datta/6c

In [9]:
# mmpose will complain if ids are not unique, so check both id spaces here.
# The ids are recomputed from dataset_dict so the check does not depend on a
# variable left behind by the conversion loop.
image_ids = [img["id"] for img in dataset_dict["images"]]
annotation_ids = [anno["id"] for anno in dataset_dict["annotations"]]
assert len(set(image_ids)) == len(image_ids), "duplicate image ids in dataset_dict['images']"
assert len(set(annotation_ids)) == len(annotation_ids), "duplicate annotation ids in dataset_dict['annotations']"

In [10]:
# Split the dataset dict into train and validation data
train_dict = copy.deepcopy(dataset_dict)
val_dict = copy.deepcopy(dataset_dict)

# Seeded generator so that re-running the notebook reproduces the same split.
split_seed = 0
split_rng = np.random.default_rng(split_seed)

# De-duplicate the ids before shuffling: a repeated id could otherwise end up on
# both sides of the cut and leak an image from training into validation.
all_img_ids = np.unique([img["id"] for img in dataset_dict["images"]])
all_img_ids = split_rng.permutation(all_img_ids)
n_val = int(len(all_img_ids) * validation_fraction)
val_img_ids = all_img_ids[:n_val]
train_img_ids = all_img_ids[n_val:]

# Sets give constant-time membership tests; `in` on a numpy array scans the whole
# array for every image and annotation.
val_id_lookup = set(val_img_ids.tolist())
train_id_lookup = set(train_img_ids.tolist())
assert val_id_lookup.isdisjoint(train_id_lookup), "an image id was assigned to both splits"

train_dict["images"] = [img for img in dataset_dict["images"] if img["id"] in train_id_lookup]
train_dict["annotations"] = [anno for anno in dataset_dict["annotations"] if anno["image_id"] in train_id_lookup]

val_dict["images"] = [img for img in dataset_dict["images"] if img["id"] in val_id_lookup]
val_dict["annotations"] = [anno for anno in dataset_dict["annotations"] if anno["image_id"] in val_id_lookup]


In [11]:
# Both COCO annotation files live in the "annotations" folder of the output directory.
annotations_dir = output_directory / "annotations"
train_json = annotations_dir / "instances_train.json"
val_json = annotations_dir / "instances_val.json"

In [12]:
# Save the train / validation dictionaries as JSON files.
# Make sure the target folder exists first; open(..., "w") does not create it.
val_json.parent.mkdir(parents=True, exist_ok=True)
train_json.parent.mkdir(parents=True, exist_ok=True)

with open(val_json, "w") as file:
    json.dump(val_dict, file, cls=Int32Encoder)

with open(train_json, "w") as file:
    json.dump(train_dict, file, cls=Int32Encoder)

In [13]:
# Check that both files can be opened and parsed again.
# Each file is read back into its own dictionary: the validation file into
# val_dict and the training file into train_dict.
with open(val_json, "r") as f:
    val_dict = json.load(f)

with open(train_json, "r") as f:
    train_dict = json.load(f)

In [14]:
# NOTE(review): everything in this cell is commented out, so no symlinks are
# created when the notebook runs.

# # symlink the images from their paths to the output directories
# train_dir = output_directory / "train"
# val_dir = output_directory / "val"
# train_dir.mkdir(exist_ok=True, parents=True)
# val_dir.mkdir(exist_ok=True, parents=True)

# # Just symlink all the images to both directories, encountering some issues
# for img in dataset_dict["images"]:
#     img_loc = Path(img["file_path"])
#     new_loc = train_dir / img_loc.name
#     if not new_loc.exists():
#         new_loc.symlink_to(img_loc)

#     new_loc = val_dir / img_loc.name
#     if not new_loc.exists():
#         new_loc.symlink_to(img_loc)



# NOTE(review): if the per-split variant below is re-enabled, note that
# Path.symlink_to() takes `target_is_directory`, not `exist_ok`; guard with
# `if not new_loc.exists()` as in the loop above instead.
# for img in train_dict["images"]:
#     img_loc = Path(img["file_path"])
#     new_loc = train_dir / img_loc.name
#     new_loc.symlink_to(img_loc, exist_ok=True)
    
# for img in val_dict["images"]:
#     img_loc = Path(img["file_path"])
#     new_loc = val_dir / img_loc.name
#     new_loc.symlink_to(img_loc)