# Supervisely Tutorial #5

# Neural Network: automating training workflow using API

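This tutorial shows how to automate a complete neural-network workflow with the Supervisely API:

1. clone the pretrained "UNet (VGG weights)" model and the demo projects `lemons_annotated` / `lemons_test` from Explore (only if they are not in the workspace yet);
2. run a DTL graph (`dtl_segmentation_graph.json`) to build the training project `lemons_train`;
3. train the model on an agent, producing the model `nn_lemon_kiwi`;
4. apply the trained model to `lemons_test`, writing the results to `lemons_test_inf`.

Prerequisites: the `SERVER_ADDRESS` and `API_TOKEN` environment variables, and a running agent (its name is set in the agent cell below).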

![Tutorial illustration](https://i.imgur.com/Qp8RgiI.png "Tutorial illustration")

**TODO:** add training charts and a few inference results.

# Imports

In [1]:
%matplotlib inline
import supervisely_lib as sly
import os
import matplotlib.pyplot as plt
import json

# Initialize API object

In [2]:
# Read the server address and the user's API token from environment variables,
# so that no credentials are hardcoded in the notebook.
address = os.environ['SERVER_ADDRESS']
token = os.environ['API_TOKEN']

# Create the API object used by every following cell.
api = sly.Api(address, token)

# Never echo the full token: notebook outputs are saved (and often committed or shared),
# so printing it leaks the credential. Show only a short prefix/suffix for identification.
token_masked = "{}...{}".format(token[:4], token[-4:]) if len(token) > 8 else "***"

print("Server address: ", address)
print("Your API token: ", token_masked)

Server address:  http://192.168.1.69:5555
Your API token:  OfaV...ejKT


# Define context (team/workspace)

In [3]:
# Use the first team available to this API token; fail with a clear message
# instead of an IndexError when the account has no teams.
teams = api.team.get_list()
if not teams:
    raise RuntimeError("No teams are available for this API token")
team = teams[0]

# Reuse the tutorial workspace if it already exists, otherwise create it,
# so re-running the notebook does not try to create it twice.
workspace_name = "tutorial_05"
if api.workspace.exists(team.id, workspace_name):
    workspace = api.workspace.get_info_by_name(team.id, workspace_name)
else:
    workspace = api.workspace.create(team.id, workspace_name)

print("Team: id={}, name={}".format(team.id, team.name))
print("Workspace: id={}, name={}".format(workspace.id, workspace.name))

Team: id=9, name=max
Workspace: id=84, name=tutorial_05


# Check whether the "UNet (VGG weights)" model exists; if not, clone the model and its plugin from Explore

In [4]:
model_name = "unet_vgg"
# Clone the pretrained model (and its plugin) from Explore only if it is not in the workspace yet.
if not api.model.exists(workspace.id, model_name):
    task_id = api.model.clone_from_explore('Supervisely/Model Zoo/UNet (VGG weights)', workspace.id, model_name)
    # Wait for the clone task so the lookup below can see the new model.
    api.task.wait(task_id, api.task.Status.FINISHED)
model = api.model.get_info_by_name(workspace.id, model_name)
# Same guard as the trained-model cell: a missing model gives a clear error, not an AttributeError on None.
if model is None:
    raise RuntimeError("Model {!r} not found".format(model_name))
print("Model: id = {}, name = {!r}".format(model.id, model.name))

Model: id = 272, name = 'unet_vgg'


# Check whether the "lemons_annotated" and "lemons_test" projects exist; if not, clone them from Explore

In [5]:
def _clone_project(project_name, explore_path, workspace_id=None):
    """Return the info of ``project_name``, cloning it from Explore first if it is missing.

    Args:
        project_name: name of the project in the target workspace.
        explore_path: Explore path to clone from, e.g. 'Supervisely/Demo/lemons_test'.
        workspace_id: id of the target workspace. Defaults to ``workspace.id`` of the
            notebook-level ``workspace`` (the original behavior), read at call time.

    Returns:
        The project info object returned by ``api.project.get_info_by_name``.

    Raises:
        RuntimeError: if the project still cannot be found after the (optional) clone.
    """
    if workspace_id is None:
        workspace_id = workspace.id
    if not api.project.exists(workspace_id, project_name):
        task_id = api.project.clone_from_explore(explore_path, workspace_id, project_name)
        # Wait for the clone task so the lookup below can see the new project.
        api.task.wait(task_id, api.task.Status.FINISHED)
    project = api.project.get_info_by_name(workspace_id, project_name)
    if project is None:
        raise RuntimeError("Project {!r} not found".format(project_name))
    return project

In [6]:
# Local names of the annotated source project and the held-out test project.
project_annotated_name, project_test_name = "lemons_annotated", "lemons_test"

In [7]:
# Make sure both demo projects exist in the workspace; _clone_project clones each from
# its Explore path only when a project with that name is not already present.
project_annotated = _clone_project(project_annotated_name, 'Supervisely/Demo/lemons_annotated')
project_test = _clone_project(project_test_name, 'Supervisely/Demo/lemons_test')

print("Annotated project: id = {}, name = {!r}".format(project_annotated.id, project_annotated.name))
print("Test project: id = {}, name = {!r}".format(project_test.id, project_test.name))

Annotated project: id = 951, name = 'lemons_annotated'
Test project: id = 952, name = 'lemons_test'


# Check if agent exists and is running

In [8]:
# Please type your agent name here.
agent_name = "agent_01"

agent = api.agent.get_info_by_name(team.id, agent_name)
if agent is None:
    raise RuntimeError("Agent {!r} not found".format(agent_name))
# NOTE(review): `agent.status` may be returned as a plain value (e.g. 'waiting') rather than an
# enum member. In that case an identity test (`is`) against the enum member never matches and a
# stopped agent slips through. Normalize through the Status enum (a no-op for an existing member)
# and compare by equality.
if api.agent.Status(agent.status) == api.agent.Status.WAITING:
    raise RuntimeError("Agent {!r} is not running".format(agent_name))

# Step 1: Run the DTL plugin to prepare the training dataset

In [14]:
# Name of the project the DTL graph writes to; it is used as the training set below.
project_train_name = "lemons_train"

In [15]:
# read graph template and define input/output projects
path_dtl_graph = './dtl_segmentation_graph.json'
with open(path_dtl_graph, 'r') as file:
    dtl_graph_str = file.read()
dtl_graph_str = dtl_graph_str.replace('%SRC_PROJECT_NAME%', project_annotated_name)
dtl_graph_str = dtl_graph_str.replace('%DST_PROJECT_NAME%', project_train_name)
dtl_graph = json.loads(dtl_graph_str)

In [16]:
# Reset first: the wait cell below treats task_id None as "DTL was not launched".
task_id = None
# Build the training project only if it does not exist yet, so re-runs skip DTL.
if not api.project.exists(workspace.id, project_train_name):
    task_id = api.task.run_dtl(workspace.id, dtl_graph, agent.id)
    print('DTL task (id={}) is started'.format(task_id))

In [17]:
# Wait for DTL only when the previous cell actually launched it.
if task_id is not None:
    api.task.wait(task_id, api.task.Status.FINISHED)

In [19]:
project_train = api.project.get_info_by_name(workspace.id, project_train_name)
# Fail with a clear message (instead of an AttributeError on None) if DTL did not produce the project.
if project_train is None:
    raise RuntimeError("Project {!r} not found".format(project_train_name))
print("Training dataset {!r} contains {} images".format(project_train.name, api.project.get_images_count(project_train.id)))

Training dataset 'lemons_train' contains 72 images


# Step 2: Train the neural network

In [21]:
# Name under which the trained model is looked up after the training task finishes.
trained_model_name = "nn_lemon_kiwi"

# Training settings passed as-is to api.task.run_train.
# NOTE(review): field semantics are defined by the training plugin. In particular,
# "dataset_tags" presumably selects train/val images by tag, and "special_classes"
# names the neutral/background classes — confirm against the plugin docs.
training_config = dict(
    lr=0.001,
    epochs=10,
    val_every=0.5,
    batch_size=dict(val=6, train=12),
    input_size=dict(width=256, height=256),
    gpu_devices=[0],
    data_workers=dict(val=0, train=3),
    dataset_tags=dict(val="val", train="train"),
    special_classes=dict(neutral="neutral", background="bg"),
    weights_init_type="transfer_learning",
)

In [22]:
# Launch training of the base model on the training project via the agent.
# The resulting model is looked up by trained_model_name once the task finishes.
task_id = api.task.run_train(agent.id, project_train.id, model.id, trained_model_name, training_config)
print('Train task (id={}) is started'.format(task_id))

Train task (id=1563) is started


In [23]:
# Wait for the training task to reach the FINISHED status before using its output model.
api.task.wait(task_id, api.task.Status.FINISHED)
print('Train task (id={}) is finished'.format(task_id))

In [24]:
# Look up the model produced by training; fail early with a clear message if it is missing.
trained_model = api.model.get_info_by_name(workspace.id, trained_model_name)
if trained_model is None:
    raise RuntimeError("Model {!r} not found".format(trained_model_name))

print("trained model: id = {}, name = {!r}".format(trained_model.id, trained_model.name))

trained model: id = 273, name = 'nn_lemon_kiwi'


# Step 3: Run NN inference on the "lemons_test" project

In [25]:
# Name of the project that receives the inference results.
project_inf_name = "lemons_test_inf"

# Inference settings passed as-is to api.task.run_inference.
# NOTE(review): "full_image" mode and the model_classes options (save_classes="__all__",
# add_suffix="_unet") are interpreted by the inference plugin — confirm exact semantics there.
inference_config = dict(
    mode=dict(
        name="full_image",
        model_classes=dict(add_suffix="_unet", save_classes="__all__"),
    ),
    model=dict(gpu_device=0),
)

In [26]:
# Apply the trained model to the test project on the agent. Results go to the project
# named project_inf_name (presumably created by the task — confirm).
task_id = api.task.run_inference(agent.id, project_test.id, trained_model.id, project_inf_name, inference_config)
print('Inference task (id={}) is started'.format(task_id))

Inference task (id=1564) is started


In [28]:
# Wait for the inference task to reach the FINISHED status.
api.task.wait(task_id, api.task.Status.FINISHED)
print('Inference task (id={}) is finished'.format(task_id))

Inference task (id=1564) is finished


# Workflow Finished!