# Train-validation tagging

This notebook shows how to split a training dataset into train and validation folds using image tags.

**Input**:
- Source project
- Train-validation split ratio

**Output**:
- New project with images randomly tagged with `train` or `val`, based on the split ratio
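Each image is tagged independently. It gets `val` with probability equal to the validation fraction and `train` otherwise, so on a small dataset the realised split only approximates the ratio. The notebook also does not seed the random generator, so every run produces a different split. Below is a minimal pure-Python sketch of the per-image rule with a fixed seed, plus an exact-count alternative that shuffles first. Both function names and the seed value are illustrative assumptions, not part of the notebook or of Supervisely.

```python
import random
from typing import Dict, List


def tag_bernoulli(names: List[str], val_fraction: float, seed: int = 0) -> Dict[str, str]:
    """Tag each image independently with a per-image coin flip (hypothetical helper)."""
    rng = random.Random(seed)  # a fixed seed makes the split reproducible (assumption)
    return {name: ("val" if rng.random() < val_fraction else "train") for name in names}


def tag_exact(names: List[str], val_fraction: float, seed: int = 0) -> Dict[str, str]:
    """Shuffle, then tag exactly round(n * fraction) images as val (hypothetical alternative)."""
    rng = random.Random(seed)
    shuffled = list(names)
    rng.shuffle(shuffled)
    n_val = round(len(shuffled) * val_fraction)
    val = set(shuffled[:n_val])
    return {name: ("val" if name in val else "train") for name in names}


images = ["img_{}.jpg".format(i) for i in range(5)]
print(tag_bernoulli(images, 0.4))  # count of "val" varies with the seed
print(tag_exact(images, 0.4))      # always 2 of 5 images tagged "val"
```

The per-image rule needs no global view of the dataset, which suits the batch-by-batch loop used later; the exact-count variant needs the full image list up front.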

## Configuration

Edit the following settings to match your setup.

In [1]:
import supervisely_lib as sly
from tqdm import tqdm
import random
import os

In [2]:
team_name = "jupyter_tutorials"
workspace_name = "cookbook"
project_name = "tutorial_project"

dst_project_name = "tutorial_project_tagged"

validation_fraction = 0.4

tag_meta_train = sly.TagMeta('train', sly.TagValueType.NONE)
tag_meta_val = sly.TagMeta('val', sly.TagValueType.NONE)

# Read the server address and your API token from environment variables.
# If you run this notebook on your own PC, set these variables or edit the values below.
address = os.environ['SERVER_ADDRESS']
token = os.environ['API_TOKEN']
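`os.environ[...]` raises a bare `KeyError` when a variable is missing. A small helper can turn that into an actionable message. `require_env` is a hypothetical name used for illustration, not a Supervisely function.

```python
import os


def require_env(name: str) -> str:
    """Return an environment variable, or fail with a readable message (hypothetical helper)."""
    value = os.environ.get(name)
    if not value:  # treats an empty string the same as a missing variable
        raise RuntimeError(
            "Environment variable {!r} is not set; export it or edit this cell".format(name)
        )
    return value


# Possible use in the configuration cell:
# address = require_env('SERVER_ADDRESS')
# token = require_env('API_TOKEN')
```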

In [3]:
# Initialize API object
api = sly.Api(address, token)

## Verify input values

Check that the context (team, workspace, project) exists.

In [4]:
# Get IDs of team, workspace and project by names

team = api.team.get_info_by_name(team_name)
if team is None:
    raise RuntimeError("Team {!r} not found".format(team_name))

workspace = api.workspace.get_info_by_name(team.id, workspace_name)
if workspace is None:
    raise RuntimeError("Workspace {!r} not found".format(workspace_name))
    
project = api.project.get_info_by_name(workspace.id, project_name)
if project is None:
    raise RuntimeError("Project {!r} not found".format(project_name))
    
print("Team: id={}, name={}".format(team.id, team.name))
print("Workspace: id={}, name={}".format(workspace.id, workspace.name))
print("Project: id={}, name={}".format(project.id, project.name))

Team: id=30, name=jupyter_tutorials
Workspace: id=76, name=cookbook
Project: id=898, name=tutorial_project
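The three lookups in the cell above share one pattern: fetch an object by name and fail loudly if the result is `None`. A sketch that factors this out is shown below. `get_or_fail` is hypothetical, and the `lookup` callable stands in for calls such as `api.team.get_info_by_name`; the commented usage assumes those calls keep the signatures used above.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def get_or_fail(lookup: Callable[..., Optional[T]], kind: str, *args) -> T:
    """Call lookup(*args) and raise if it returns None (hypothetical helper)."""
    info = lookup(*args)
    if info is None:
        # The last positional argument is the name that was looked up
        raise RuntimeError("{} {!r} not found".format(kind, args[-1]))
    return info


# Possible use against the API (not executed here):
# team = get_or_fail(api.team.get_info_by_name, "Team", team_name)
# workspace = get_or_fail(api.workspace.get_info_by_name, "Workspace", team.id, workspace_name)
# project = get_or_fail(api.project.get_info_by_name, "Project", workspace.id, project_name)
```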


## Get Source ProjectMeta

In [5]:
meta_json = api.project.get_meta(project.id)
meta = sly.ProjectMeta.from_json(meta_json)
print("Source ProjectMeta: \n", meta)

Source ProjectMeta: 
 ProjectMeta:
Object Classes
+--------+-----------+----------------+
|  Name  |   Shape   |     Color      |
+--------+-----------+----------------+
|  bike  | Rectangle | [246, 255, 0]  |
|  car   |  Polygon  | [190, 85, 206] |
|  dog   |  Polygon  |  [253, 0, 0]   |
| person |   Bitmap  |  [0, 255, 18]  |
+--------+-----------+----------------+
Image Tags
+-------------+--------------+-----------------------+
|     Name    |  Value type  |    Possible values    |
+-------------+--------------+-----------------------+
| cars_number |  any_number  |          None         |
|     like    |     none     |          None         |
|   situated  | oneof_string | ['inside', 'outside'] |
+-------------+--------------+-----------------------+
Object Tags
+---------------+--------------+-----------------------+
|      Name     |  Value type  |    Possible values    |
+---------------+--------------+-----------------------+
|   car_color   |  any_string  |          None     

## Construct Destination ProjectMeta

In [6]:
dst_meta = meta.add_img_tag_metas([tag_meta_train, tag_meta_val])
print("Destination ProjectMeta:\n", dst_meta)

Destination ProjectMeta:
 ProjectMeta:
Object Classes
+--------+-----------+----------------+
|  Name  |   Shape   |     Color      |
+--------+-----------+----------------+
|  bike  | Rectangle | [246, 255, 0]  |
|  car   |  Polygon  | [190, 85, 206] |
|  dog   |  Polygon  |  [253, 0, 0]   |
| person |   Bitmap  |  [0, 255, 18]  |
+--------+-----------+----------------+
Image Tags
+-------------+--------------+-----------------------+
|     Name    |  Value type  |    Possible values    |
+-------------+--------------+-----------------------+
| cars_number |  any_number  |          None         |
|     like    |     none     |          None         |
|   situated  | oneof_string | ['inside', 'outside'] |
|    train    |     none     |          None         |
|     val     |     none     |          None         |
+-------------+--------------+-----------------------+
Object Tags
+---------------+--------------+-----------------------+
|      Name     |  Value type  |    Possible values

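The destination meta is the source meta with `train` and `val` appended to the image tags. If the source project already defined an image tag with either name, the two tags would collide; how `add_img_tag_metas` reacts in that case is not shown in this notebook. A pure-Python pre-check over tag names is sketched below. `check_new_tag_names` is hypothetical, and in practice the existing names would be read from the source `ProjectMeta`.

```python
from typing import Iterable, List


def check_new_tag_names(existing: Iterable[str], new: Iterable[str]) -> List[str]:
    """Return the merged name list, refusing names that already exist (hypothetical check)."""
    existing = list(existing)
    new = list(new)  # materialise once so generators are not consumed twice
    clashes = sorted(set(existing) & set(new))
    if clashes:
        raise ValueError("Tag names already defined in the source project: {}".format(clashes))
    return existing + new


source_tags = ["cars_number", "like", "situated"]
print(check_new_tag_names(source_tags, ["train", "val"]))
# ['cars_number', 'like', 'situated', 'train', 'val']
```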
## Create Destination project

In [7]:
# Check whether the destination project already exists; if it does, generate a new free name
if api.project.exists(workspace.id, dst_project_name):
    dst_project_name = api.project.get_free_name(workspace.id, dst_project_name)
print("Destination project name: ", dst_project_name)

Destination project name:  tutorial_project_tagged
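`get_free_name` returns a project name that is not yet used in the workspace. The suffix scheme it actually applies belongs to the library; the sketch below is a hypothetical stand-in that appends a zero-padded counter, only to illustrate the idea. `free_name` and the `_001` format are assumptions.

```python
from typing import Set


def free_name(taken: Set[str], name: str) -> str:
    """Return `name`, or `name_001`, `name_002`, ... until unused (hypothetical scheme)."""
    if name not in taken:
        return name
    counter = 1
    while True:
        candidate = "{}_{:03d}".format(name, counter)
        if candidate not in taken:
            return candidate
        counter += 1


taken = {"tutorial_project", "tutorial_project_tagged"}
print(free_name(taken, "tutorial_project_tagged"))  # tutorial_project_tagged_001
```

Checking for a free name and then creating the project are two separate calls, so in a shared workspace another user could still take the name in between.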


In [8]:
dst_project = api.project.create(workspace.id, dst_project_name)
api.project.update_meta(dst_project.id, dst_meta.to_json())
print("Destination project has been created: id={}, name={!r}".format(dst_project.id, dst_project.name))

Destination project has been created: id=1328, name='tutorial_project_tagged'


## Iterate over all images, tag them, and add them to the destination project

In [9]:
for dataset in api.dataset.get_list(project.id):
    print('Dataset: {}'.format(dataset.name), flush=True)
    dst_dataset = api.dataset.create(dst_project.id, dataset.name)
    
    images = api.image.get_list(dataset.id)
    with tqdm(total=len(images), desc="Process annotations") as progress_bar:
        for batch in sly.batched(images):
            image_ids = [image_info.id for image_info in batch]
            image_names = [image_info.name for image_info in batch]
            
            ann_infos = api.annotation.download_batch(dataset.id, image_ids)

            anns_to_upload = []
            for ann_info in ann_infos:
                ann = sly.Annotation.from_json(ann_info.annotation, meta)

                tag = sly.Tag(tag_meta_val) if random.random() <= validation_fraction else sly.Tag(tag_meta_train)
                ann = ann.add_tag(tag)
                anns_to_upload.append(ann)
            
            dst_image_infos = api.image.upload_ids(dst_dataset.id, image_names, image_ids)
            dst_image_ids = [image_info.id for image_info in dst_image_infos]
            api.annotation.upload_anns(dst_image_ids, anns_to_upload)
            progress_bar.update(len(batch))

Dataset: dataset_01
Process annotations: 100%|██████████| 3/3 [00:00<00:00, 13.86it/s]
Dataset: dataset_02
Process annotations: 100%|██████████| 2/2 [00:00<00:00, 11.67it/s]
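`sly.batched` splits the image list into chunks so that each download and upload call handles a bounded number of images. Inside a batch, the loop pairs annotations, image names and image IDs by position. A pure-Python equivalent of the chunking is sketched below; the batch size here is an explicit assumption and may differ from the default used by `sly.batched`.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def batched(seq: Sequence[T], batch_size: int = 50) -> Iterator[List[T]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(seq), batch_size):
        yield list(seq[start:start + batch_size])


ids = list(range(7))
print(list(batched(ids, 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```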


In [10]:
print("Project {!r} has been successfully uploaded".format(dst_project.name))
print("Number of images: ", api.project.get_images_count(dst_project.id))

Project 'tutorial_project_tagged' has been successfully uploaded
Number of images:  5
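Because tags are assigned at random, it can be useful to check the realised split after upload. The sketch below works on a plain list of per-image tag names; collecting those names from the destination project through the API is left out, and the sample list is made up for illustration. `split_report` is a hypothetical helper.

```python
from collections import Counter
from typing import Iterable


def split_report(tags: Iterable[str], target_val: float) -> str:
    """Summarise train/val counts and the realised validation fraction (hypothetical helper)."""
    counts = Counter(tags)
    total = counts["train"] + counts["val"]
    actual = counts["val"] / total if total else 0.0
    return "train={} val={} val_fraction={:.2f} (target {:.2f})".format(
        counts["train"], counts["val"], actual, target_val
    )


print(split_report(["train", "val", "train", "train", "val"], 0.4))
# train=3 val=2 val_fraction=0.40 (target 0.40)
```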
