# Train and deploy a Semantic Segmentation model using PyTorch

In this lab, you will learn how to train a semantic segmentation model with a model from the [torchvision subpackage](https://pytorch.org/vision/stable/models.html#semantic-segmentation). We will be using the [DeepLabV3 ResNet50 model](https://arxiv.org/abs/1706.05587) and training it on the [SpaceNet dataset](https://spacenet.ai/spacenet-buildings-dataset-v2/).

We will explore how to implement a custom model with the PyTorch framework using script mode, and how to leverage local mode for debugging and testing.

Reference:
- Examples of Amazon SageMaker Local mode: https://github.com/aws-samples/amazon-sagemaker-local-mode

## Install libraries and dependencies

The following cells install the libraries and dependencies our kernel needs to build this model. One of them is the solaris library, which lets us convert GeoJSON into polygons and back. Because the SpaceNet dataset masks are defined in GeoJSON, we need this utility to convert them. Optionally, if your downstream systems expect GeoJSON output, the solaris library can produce that as well.

Required libraries. To run this notebook, you will need to install the following dependencies:
- rtree
- gdal 3.0.3
- geopandas
- solaris

If you are not using the lab2.yml CloudFormation template, uncomment the next three cells and run them.

In [None]:
#%conda install -c conda-forge rtree gdal=3.0.3

In [None]:
#!pip install solaris

In [None]:
#!pip install geopandas

In [None]:
import sagemaker
from sagemaker import get_execution_role
import os, time, json, random
from pathlib import Path
import numpy as np
import pandas as pd
import skimage
from skimage import io

sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
prefix = 'sagemaker/LAB-pytorch-semantic-segmentation'

role = sagemaker.get_execution_role()

In [None]:
import solaris as sol

## Get Data

We will download the Paris area of interest (AOI 3) from the SpaceNet dataset and uncompress it.

In [None]:
!aws s3 cp s3://spacenet-dataset/spacenet/SN2_buildings/tarballs/SN2_buildings_train_AOI_3_Paris.tar.gz ./dataset/

In [None]:
!tar -xf ./dataset/SN2_buildings_train_AOI_3_Paris.tar.gz -C./dataset/

### Plotting preferences

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
plt.style.use('seaborn-notebook')

## Review data

Let's review the dataset we've just downloaded.

In [None]:
data_dir = './dataset/AOI_3_Paris_Train/'
img_dir = os.path.join(data_dir, 'RGB-PanSharpen')
bldg_dir = os.path.join(data_dir, 'geojson', 'buildings')

# Common part of every filename (naming convention)
midfix = 'AOI_3_Paris_'

In [None]:
def scale_bands(img, lower_pct = 1, upper_pct = 99):
    """
    Rescale the bands of a multichannel image for display
    """
    # Loop through the image bands, rescaling each one
    img_scaled = np.zeros(img.shape, np.uint8)
    
    for i in range(img.shape[2]):
        
        band = img[:, :, i]
        
        # Pick out the lower and upper percentiles
        lower, upper = np.percentile(band, [lower_pct, upper_pct])
        
        # Normalize the band
        band = (band - lower) / (upper - lower) * 255
        
        # Clip the high and low values, and cast to uint8
        img_scaled[:, :, i] = np.clip(band, 0, 255).astype(np.uint8)
        
    return img_scaled

#### What's happening here?
The following cell is loading a sample image and mask from the SpaceNet dataset. We are leveraging the `sol.vector.mask.footprint_mask()` function to convert the geojson file into an array format. We are then using the plotting library to preview our image and mask.
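
To make the GeoJSON-to-array conversion concrete, here is a minimal pure-Python sketch of polygon rasterization. It is not how solaris implements `footprint_mask()`. It assumes the polygon coordinates are already in pixel space, whereas the real function uses the reference image to map geographic coordinates to pixels. The toy GeoJSON, the helper names `point_in_ring` and `rasterize`, and the burn value are all illustrative assumptions.

```python
import json

# Hypothetical GeoJSON with one square footprint, already in pixel coordinates.
geojson_text = json.dumps({
    "type": "FeatureCollection",
    "features": [{
        "type": "Feature",
        "properties": {},
        "geometry": {
            "type": "Polygon",
            "coordinates": [[[2, 2], [6, 2], [6, 6], [2, 6], [2, 2]]],
        },
    }],
})


def point_in_ring(x, y, ring):
    """Even-odd ray casting test for a point against a closed ring."""
    inside = False
    for (x1, y1), (x2, y2) in zip(ring, ring[1:]):
        # Does a horizontal ray starting at (x, y) cross this edge?
        if (y1 > y) != (y2 > y):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def rasterize(geojson_str, height, width, burn_value=255):
    """Burn every polygon's exterior ring into a height x width mask."""
    features = json.loads(geojson_str)["features"]
    mask = [[0] * width for _ in range(height)]
    for feature in features:
        exterior = feature["geometry"]["coordinates"][0]
        for row in range(height):
            for col in range(width):
                # Test the pixel centre against the polygon
                if point_in_ring(col + 0.5, row + 0.5, exterior):
                    mask[row][col] = burn_value
    return mask


toy_mask = rasterize(geojson_text, height=8, width=8)
for row in toy_mask:
    print("".join("#" if value else "." for value in row))
```

The printed grid shows the footprint as a block of `#` characters, which is the same kind of array that the plotting code below displays as the building mask.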

In [None]:
# Show a data sample
sample = 'img100' # chip ID, img? format

# Read in the pan-sharpened RGB image from GeoTIFF.
img_file = 'RGB-PanSharpen_' + midfix + sample + '.tif'
img_path = os.path.join(img_dir, img_file)
img = skimage.io.imread(img_path)
rgb = img

# Read in GeoJSON file and convert polygons to footprint mask.
bldg_file = 'buildings_'+ midfix + sample + '.geojson'
bldg_path = os.path.join(bldg_dir, bldg_file)
mask = sol.vector.mask.footprint_mask(bldg_path, reference_im=img_path)

# Display satellite image and building footprint mask.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(scale_bands(rgb))
ax[0].set_title('Satellite image')
ax[1].imshow(mask, cmap='Blues')
ax[1].set_title('Building footprint masks')
fig.show()

In the image above, you can see the original satellite image and the building mask generated from the accompanying GeoJSON file. We used the solaris library to convert the GeoJSON into an image so that we can visualise it.

**Note** The original satellite image from the dataset is an 11-bit image stored in 16-bit integers, so we created the `scale_bands()` function to rescale it to 8 bits for display.
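
The effect of the rescaling can be checked on synthetic data. The sketch below builds a random stand-in for one band (an assumption, not SpaceNet data) holding 11-bit values in a 16-bit array, then applies the same percentile stretch that `scale_bands()` performs per band.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for one band: 11-bit values (0..2047) held in uint16.
band = rng.integers(0, 2048, size=(64, 64), dtype=np.uint16)

# Percentile stretch: map the 1st..99th percentile range onto 0..255,
# then clip the tails and cast to uint8.
lower, upper = np.percentile(band, [1, 99])
stretched = np.clip((band - lower) / (upper - lower) * 255, 0, 255).astype(np.uint8)

print(band.dtype, int(band.max()))
print(stretched.dtype, int(stretched.min()), int(stretched.max()))
```

Clipping at the 1st and 99th percentiles sacrifices the extreme 2% of values so that the remaining range uses the full 8-bit scale, which usually gives a better-looking preview than dividing by the maximum.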

## Generate training data

Now that we have some clarity on our dataset, let's process the rest of the images and masks, using `scale_bands()` for the images and the `sol.vector.mask.footprint_mask()` function for the masks. We will also split our dataset into training and test sets and upload them to a designated S3 bucket.

In [None]:
data_dir = './data/'
training_dir = os.path.join(data_dir, 'train/')
test_dir = os.path.join(data_dir, 'test/')

training_img_dir = os.path.join(training_dir, 'img/')
training_mask_dir = os.path.join(training_dir, 'mask/')
test_img_dir = os.path.join(test_dir, 'img/')
test_mask_dir = os.path.join(test_dir, 'mask/')

# Create the full directory tree; directories that already exist are left untouched
for directory in (training_img_dir, training_mask_dir, test_img_dir, test_mask_dir):
    os.makedirs(directory, exist_ok=True)

### Split dataset

In [None]:
# Get list of images
ListImages=os.listdir(img_dir)

In [None]:
# Split dataset
train = ListImages[0:int(0.7 * len(ListImages))]
test = ListImages[int(0.7 * len(ListImages)):]
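
The slice above takes the first 70% of `ListImages` in whatever order `os.listdir()` returned it, and that order is arbitrary. A minimal sketch of a shuffled, reproducible alternative follows. The helper name `split_filenames`, the seed, and the sample filenames are illustrative assumptions rather than part of the lab.

```python
import random


def split_filenames(filenames, train_fraction=0.7, seed=42):
    """Return (train, test) lists after a reproducible shuffle."""
    ordered = sorted(filenames)            # fix the starting order first
    random.Random(seed).shuffle(ordered)   # private RNG, global state untouched
    cut = int(train_fraction * len(ordered))
    return ordered[:cut], ordered[cut:]


# Hypothetical chip names following the dataset's naming convention
names = [f"RGB-PanSharpen_AOI_3_Paris_img{i}.tif" for i in range(10)]
train_names, test_names = split_filenames(names)
print(len(train_names), len(test_names))
```

Sorting before shuffling means the same seed gives the same split on every machine, regardless of how the file system happens to list the directory.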

### Process dataset

Here we will process our images and masks and save them as TIFF files.

**Note** Typically this will be done using an [Amazon SageMaker Processing job](https://docs.aws.amazon.com/sagemaker/latest/dg/processing-job.html)

In [None]:
def process_images_and_mask(image_list, image_dir, geojson_dir, output_dir):
    # Expected filename prefix (named file_prefix to avoid confusion with the S3 `prefix`)
    file_prefix = 'RGB-PanSharpen_' + midfix

    for img_file in image_list:
        if not img_file.startswith(file_prefix):
            continue

        file_suffix = os.path.splitext(img_file[len(file_prefix):])[0]

        # Look up the GeoJSON that belongs to this image, not the earlier sample
        bldg_file = 'buildings_' + midfix + file_suffix + '.geojson'
        if not os.path.exists(os.path.join(geojson_dir, bldg_file)):
            continue

        img_path = os.path.join(image_dir, img_file)
        img = scale_bands(skimage.io.imread(img_path))

        # Create training mask
        create_masks(geojson_dir, file_suffix, img_path, output_dir)

        # Save paired image
        output_image_filename = file_suffix + '.tif'
        skimage.io.imsave(os.path.join(output_dir, 'img', output_image_filename), img, check_contrast=False)

def create_masks(geojson_dir, file_suffix, img_path, output_dir):
    bldg_file = 'buildings_'+ midfix + file_suffix + '.geojson'
    bldg_path = os.path.join(geojson_dir, bldg_file)
    output_mask_filename = file_suffix + '.tif'
    # Create mask and save
    _ = sol.vector.mask.footprint_mask(bldg_path, out_file=os.path.join(output_dir, 'mask/', output_mask_filename), reference_im=img_path)

In [None]:
# Process our training dataset
process_images_and_mask(train, img_dir, bldg_dir, training_dir)

In [None]:
# Process our test dataset
process_images_and_mask(test, img_dir, bldg_dir, test_dir)

### Upload dataset to S3

In [None]:
inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix=prefix)

In [None]:
inputs

## Setup SageMaker Experiments

With [Amazon SageMaker Experiments](https://aws.amazon.com/blogs/aws/amazon-sagemaker-experiments-organize-track-and-compare-your-machine-learning-trainings/), we can track multiple iterations of our training job, recording the hyperparameters, datasets and algorithms used for each trial so that they are easy to compare.

In this section we will create an experiment; in a later section we will create a trial.

In [None]:
!pip install sagemaker-experiments

In [None]:
import time
from time import strftime

import sagemaker

from smexperiments.experiment import Experiment
from smexperiments.trial import Trial
from smexperiments.trial_component import TrialComponent
from smexperiments.tracker import Tracker

In [None]:
create_date = strftime("%Y-%m-%d-%H-%M-%S")

seg_experiment = Experiment.create(
    experiment_name="spacenet-semantic-segmentation-{}".format(create_date), 
    description="Semantic Segmentation for the spacenet aerial images",
    tags = [{'Key': 'Environment', 'Value': 'demo1'}])

## Training

In this lab, we will train a model using [script mode](https://sagemaker-examples.readthedocs.io/en/latest/sagemaker-script-mode/sagemaker-script-mode.html), which lets us bring our own training script. To achieve this, we will use [PyTorch with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html).

By using the PyTorch estimator from the SDK, we run our training script on a pre-built PyTorch container. We provide the container with a custom training script, `script/train.py`. The script includes a data loader for our training and test datasets, a model definition that loads DeepLabV3 ResNet50 from the torchvision library, and supporting functions to initialise the neural network and save the model.

In [None]:
from sagemaker.pytorch import PyTorch

### Test training script in local mode

To facilitate debugging your training script, you can train your model using local mode. This can be achieved by setting the `instance_type` variable to `local`. The following code will execute a training job for one epoch.

#### What's happening here?
The following code launches a local training job (`instance_type='local'`) with PyTorch framework version 1.8.0 (`framework_version='1.8.0'`). The training job loads our training script **train.py** (`entry_point='train.py'`) from our source directory **script** (`source_dir="script"`). In the `hyperparameters` argument, we configure the job to run for 1 epoch with a batch size of 16 and a learning rate of 0.0005.
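
On the script side, script mode hands these hyperparameters to the entry point as command-line arguments, and exposes the input channel and model output locations through environment variables such as `SM_CHANNEL_TRAINING` and `SM_MODEL_DIR`. The actual `script/train.py` is not shown in this notebook, so the sketch below is only an assumption about how its argument parsing might look. The `--model-dir` and `--data-dir` flag names and the fallback paths are hypothetical.

```python
import argparse
import os


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    # One flag per key in the estimator's `hyperparameters` dict
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=0.0005)
    parser.add_argument("--log-interval", type=int, default=10)
    # Locations supplied by the training container through environment variables;
    # the fallback paths are hypothetical and only help when running by hand
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "./model"))
    parser.add_argument("--data-dir", default=os.environ.get("SM_CHANNEL_TRAINING", "./data"))
    return parser.parse_args(argv)


# Simulate the arguments the container would pass for the local job
args = parse_args(["--epochs", "1", "--batch-size", "16",
                   "--lr", "0.0005", "--log-interval", "10"])
print(args.epochs, args.batch_size, args.lr, args.log_interval)
```

The channel name `training` used in `fit({'training': inputs})` is what determines the `SM_CHANNEL_TRAINING` variable name.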

In [None]:
estimator = PyTorch(entry_point='train.py',
                    source_dir="script",
                    role=role,
                    py_version='py3',
                    framework_version='1.8.0',
                    instance_count=1,
                    instance_type='local',
                    hyperparameters={
                        'epochs': 1,
                        'batch-size': 16,
                        'lr': 0.0005,
                        'log-interval': 10
                    })

In [None]:
estimator.fit({'training': inputs})

### Execute a SageMaker Job

Once you're satisfied with your training script, the next step is to scale up model training by running an Amazon SageMaker training job. To do so, we set `instance_type` to one of the [Amazon SageMaker instance types](https://aws.amazon.com/sagemaker/pricing/). In this example, we use an `ml.g4dn.4xlarge` instance, which has a GPU.

When the `fit()` function is called, an Amazon SageMaker training job is initialised. If you have access to the console, you can monitor your training job at [https://ap-southeast-2.console.aws.amazon.com/sagemaker/home?region=ap-southeast-2#/jobs](https://ap-southeast-2.console.aws.amazon.com/sagemaker/home?region=ap-southeast-2#/jobs). This link targets the `ap-southeast-2` region, so adjust it if you are running elsewhere.
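
If you are working in another region, the console link can be rebuilt from the region name. This is a small sketch that assumes the URL pattern of the link above holds for other regions; the helper name `training_jobs_console_url` is hypothetical.

```python
def training_jobs_console_url(region):
    """Build the SageMaker training-jobs console URL for a region."""
    return (
        f"https://{region}.console.aws.amazon.com/sagemaker/home"
        f"?region={region}#/jobs"
    )


# In the notebook, the region could come from sagemaker_session.boto_region_name
print(training_jobs_console_url("ap-southeast-2"))
```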



In [None]:
seg_trial = Trial.create(trial_name = "spacenet-semantic-segmentation-{}-{}".format(create_date, int(time.time())),
                          experiment_name = seg_experiment.experiment_name,
                          tags = [{'Key': 'Environment', 'Value': 'demo1'}])


estimator = PyTorch(entry_point='train.py',
                    source_dir="script",
                    role=role,
                    py_version='py3',
                    framework_version='1.8.0',
                    instance_count=1,
                    instance_type='ml.g4dn.4xlarge',
                    volume_size=50,
                    hyperparameters={
                        'epochs': 5,
                        'lr': 0.005,
                        'batch-size': 16,
                        'log-interval': 10
                    })

In [None]:
estimator.fit({'training': inputs}, experiment_config = {
        # "ExperimentName"
        "TrialName" : seg_trial.trial_name,
        "TrialComponentDisplayName" : "TrainingJob",
    })

## Deploy model

### Create a real-time endpoint
As we are using a custom model and sending images as input, we need to override how the Amazon SageMaker inference container loads our model, processes the input request, and produces the output response. This is achieved by overriding the following functions in the `script/inference.py` file:

- model_fn: Override the model loading function to load the deeplabv3_resnet50 model with weights from our training job.
- input_fn: Override the input function to convert the incoming image payload into a tensor suitable for prediction.
- output_fn: Override the output function to convert the output prediction into a numpy array.
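
The contract between these handlers can be sketched without PyTorch. The code below is not the lab's `script/inference.py`: the model is a stand-in callable, the payload decoding is deliberately simplistic, and `predict_fn` (which the container normally provides by default) is written out only to show the full chain. Treat every function body here as an assumption.

```python
import json


def model_fn(model_dir):
    """Load the model. A stand-in callable replaces the real network here."""
    # Assumption: the real script would rebuild deeplabv3_resnet50 and load
    # the trained weights found under model_dir.
    def fake_model(pixels):
        # Pretend the "probability" is brightness scaled to 0..1
        return [[value / 255 for value in row] for row in pixels]
    return fake_model


def input_fn(request_body, request_content_type):
    """Turn the raw request payload into model input."""
    if request_content_type != "application/x-image":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    # Assumption: a real handler would decode the image bytes into a tensor.
    # Here each byte is treated as one pixel of a single-row image.
    return [list(request_body)]


def predict_fn(input_data, model):
    """Run the model on the prepared input."""
    return model(input_data)


def output_fn(prediction, accept):
    """Serialize the prediction for the HTTP response."""
    return json.dumps(prediction)


# Chain the handlers in the order the serving container calls them
model = model_fn("/opt/ml/model")
data = input_fn(bytes([0, 255]), "application/x-image")
body = output_fn(predict_fn(data, model), "application/json")
print(body)
```

Rejecting unexpected content types in `input_fn` makes payload mistakes surface as a clear error instead of a confusing failure inside the model.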

More information on how to override the inference functions here:
https://docs.aws.amazon.com/sagemaker/latest/dg/adapt-inference-container.html

#### What is happening here?
We are using the `PyTorchModel()` class to define our real-time inference configuration. The trained model is passed through the `model_data` parameter, and the source of the inference script is specified through `source_dir` for the directory and `entry_point` for the script.

The real-time inference endpoint is then deployed using the `deploy()` function, where we specify the initial number of instances (`initial_instance_count`) and the desired instance type (`instance_type`).

In [None]:
from sagemaker.pytorch import PyTorchModel

pytorch_serving_model = PyTorchModel(
    model_data=estimator.model_data,
    role=role,
    framework_version='1.8.0',
    py_version='py3',
    entry_point='inference.py',
    source_dir="script"  
)

In [None]:
hosted_predictor = pytorch_serving_model.deploy(initial_instance_count=1,
        instance_type='ml.m4.4xlarge')

In [None]:
hosted_predictor.endpoint_name

### Run some predictions

Now that our real-time endpoint is up, let us run some predictions. As we only trained the model for a few epochs in this lab, we should not expect accurate predictions.
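
To put a number on how close a predicted mask is to the ground truth, one common choice for segmentation is intersection over union (IoU). The sketch below is an optional addition rather than part of the lab. It assumes the prediction is a map of probabilities thresholded at 0.5, and the helper name `binary_iou` and the toy values are made up.

```python
import numpy as np


def binary_iou(probabilities, truth, threshold=0.5):
    """IoU between a thresholded probability map and a ground-truth mask."""
    pred = np.asarray(probabilities) > threshold
    true = np.asarray(truth) > 0
    union = np.logical_or(pred, true).sum()
    if union == 0:
        return 1.0  # both masks empty: treat as a perfect match
    return float(np.logical_and(pred, true).sum() / union)


# Toy 2x3 example with made-up values
probs = [[0.9, 0.8, 0.1],
         [0.2, 0.7, 0.4]]
truth = [[255, 255, 0],
         [255, 0, 0]]
print(binary_iou(probs, truth))
```

In this toy case two pixels are building in both masks and four pixels are building in at least one of them, so the score is 0.5.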

In [None]:
import boto3

# Using a test image
# img_path = './data/test/img/img100.tif'
# mask_path = './data/test/mask/img100.tif'

# Using a train image
img_path = './data/train/img/img1643.tif'
mask_path = './data/train/mask/img1643.tif'

with open(img_path, "rb") as f:
    payload = f.read()

sm_runtime = boto3.Session().client("sagemaker-runtime")

response = sm_runtime.invoke_endpoint(
    EndpointName=hosted_predictor.endpoint_name, ContentType="application/x-image", Body=payload
)
print(response)

In [None]:
# Decode the output response
result = json.loads(response["Body"].read().decode())

In [None]:
np.array(result).shape

In [None]:
np.array(result)

In [None]:
from PIL import Image

with open(img_path, "rb") as image_file, open(mask_path, "rb") as mask_file:
    image = Image.open(image_file).convert("RGB")
    mask = Image.open(mask_file).convert("L")
    
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(image)
    ax[0].set_title('Satellite image')
    ax[1].imshow(np.array(result)[0] > 0.5, cmap='Blues')
    ax[1].set_title('Building masks prediction')
    ax[2].imshow(mask, cmap='Blues')
    ax[2].set_title('Building ground truth')
    fig.show()

## Clean up
As the real-time inference endpoint runs 24/7 until it is deleted, it is best practice to delete it once we are done with testing. The `delete_endpoint()` function deletes our running endpoint and its associated configuration.

In [None]:
hosted_predictor.delete_endpoint()

## Test local inference

As we are overriding the inference functions with a custom script, it can be useful to run inference locally to test and debug the inference script before deploying it in the cloud. The following section demonstrates how you can achieve this.

#### What's happening here?
As in the deployment above, we define an instance of `PyTorchModel`. This time, however, we set the session to a local session (`sagemaker_session=LocalSession()`).

In the `deploy()` function, we set `instance_type` to `local`. When we run it, the SDK launches a Docker container within our notebook environment using a pre-built PyTorch container, uploads our inference script, and loads our model weights from S3. If you're interested in learning more, launch a terminal in a new tab and run `docker ps` to see the inference container running.

In [None]:
from sagemaker.local import LocalSession
from sagemaker.pytorch import PyTorchModel

pytorch_local_serving_model = PyTorchModel(
    model_data=estimator.model_data,
    role=role,
    framework_version='1.8.0',
    py_version='py3',
    sagemaker_session=LocalSession(),
    entry_point='inference.py',
    source_dir="script"
    
)

In [None]:
local_predictor = pytorch_local_serving_model.deploy(initial_instance_count=1,
        instance_type='local')

In [None]:
with open(img_path, "rb") as f:
    payload = f.read()

    
response = local_predictor.predict(payload, initial_args={'ContentType': 'application/x-image'})
print(response)

In [None]:
local_predictor.delete_endpoint()