# U-Net training - PyTorch 2.1
This notebook shows how to fine-tune a pretrained U-Net PyTorch model on AWS Trainium (trn1 instances) using the Neuron SDK.\
The model implementation comes from milesial/Pytorch-UNet.



The example has two stages:
1. Compile the model with the `neuron_parallel_compile` utility, which extracts the training graphs in a short trial run and compiles them ahead of time for the AWS Trainium device.
2. Run the fine-tuning script to train the model on an image segmentation task. The training job uses 32 workers with data parallelism to speed up training.

It has been tested on a trn1.32xlarge instance using 256 x 256 input images for binary segmentation with a batch size of 4.
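With data parallelism, every worker processes its own batch at each step, so the effective (global) batch size is the per-worker batch size times the number of workers. The sketch below works through that arithmetic for the tested configuration. It assumes the dataset is sharded the way PyTorch's `DistributedSampler` does it (each worker gets `ceil(N / W)` samples) and that `--drop_last` drops each worker's final incomplete batch; both are assumptions about `train.py`, and the helper names are hypothetical.

```python
import math


def global_batch_size(num_workers: int, per_worker_batch: int) -> int:
    # Each data-parallel worker consumes one batch of its own per step.
    return num_workers * per_worker_batch


def steps_per_epoch(num_samples: int, num_workers: int, per_worker_batch: int) -> int:
    # Assumption: DistributedSampler-style sharding gives each worker
    # ceil(N / W) samples, and drop_last discards the incomplete last batch.
    samples_per_worker = math.ceil(num_samples / num_workers)
    return samples_per_worker // per_worker_batch


# 32 workers x batch size 4, using the 4579 training images reported in the log below.
print(global_batch_size(32, 4))      # 128
print(steps_per_epoch(4579, 32, 4))  # 36
```

Raising the worker count from 2 to 32 reduces the optimizer steps per epoch by roughly 16x, so the learning rate may need retuning when you change `num_workers`.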

**Reference:** 
milesial, U-Net: Semantic segmentation with PyTorch, GitHub repository
https://github.com/milesial/Pytorch-UNet

In [1]:
!pip freeze | grep -E 'neuron|torch|pill|glob|sci|timm|transformers|tensorboard'

aws-neuronx-runtime-discovery==2.9
libneuronxla==0.5.971
neuronx-cc==2.13.72.0+78a426937
neuronx-distributed==0.7.0
pillow==10.3.0
scikit-learn==1.3.2
scipy==1.10.1
tensorboard==2.14.0
tensorboard-data-server==0.7.2
tensorboard-plugin-neuronx==2.6.7.0
timm==0.9.16
torch==1.13.1
torch-neuronx==1.13.1.1.14.0
torch-xla==1.13.1+torchneurone
torchvision==0.14.1
transformers==4.40.1


## 1) Install dependencies

In [None]:
# Install the additional Python packages used by the training script
%pip install -U "timm" "tensorboard" torchvision==0.16.*
%pip install -U "Pillow" "glob2" "scikit-learn"
# Add --force-reinstall if you run into errors when importing these modules.
# Restart the kernel after the installation finishes.

## 2) Download Carvana dataset
This example uses the Carvana dataset, which you must download manually before training:\
https://www.kaggle.com/competitions/carvana-image-masking-challenge/data

1. Download `train.zip` and `train_masks.zip`.
2. Create a `carvana` directory.
3. Unzip both archives into it.
4. Check that the directory structure is:\
carvana/train/\
carvana/train_masks/

Then set `dataset_path` in the parameters below to the path of your `carvana` directory.
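Before starting a long compilation, it can help to check the dataset layout. The sketch below assumes Carvana's naming convention, where each image `train/<id>.jpg` has a mask named `train_masks/<id>_mask.gif` (the same `_mask` suffix that milesial/Pytorch-UNet's Carvana loader expects). `check_carvana_layout` is a hypothetical helper and is not part of `train.py`.

```python
from pathlib import Path


def check_carvana_layout(dataset_path: str) -> list:
    """Return the ids of images that have no matching mask (empty list means OK)."""
    root = Path(dataset_path)
    img_dir, mask_dir = root / "train", root / "train_masks"
    for d in (img_dir, mask_dir):
        if not d.is_dir():
            raise FileNotFoundError(f"Expected directory not found: {d}")
    # Strip the "_mask" suffix from mask stems, e.g. "abc_01_mask" -> "abc_01".
    mask_ids = {
        p.stem[: -len("_mask")]
        for p in mask_dir.iterdir()
        if p.is_file() and p.stem.endswith("_mask")
    }
    # Skip hidden files such as .DS_Store.
    image_ids = sorted(
        p.stem for p in img_dir.iterdir() if p.is_file() and not p.name.startswith(".")
    )
    return [i for i in image_ids if i not in mask_ids]


if Path("./carvana/").is_dir():
    missing = check_carvana_layout("./carvana/")
    print(f"{len(missing)} images without a matching mask")
```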

## 3) Set the parameters

In [1]:
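# Number of data-parallel workers (one per NeuronCore).
# Use 32 on trn1.32xlarge; 2 is used here for a quick test run.
# num_workers = 32
num_workers = 2
dataloader_num_workers = 2  # passed to train.py as --num_workers
image_dim = 256
# num_epochs = 20
num_epochs = 2  # increase (e.g. to 20) for a full fine-tuning run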

In [2]:
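learning_rate = 2e-4
batch_size = 4  # batch size per worker
# Environment variables prepended to the launch commands below:
# - NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS: max in-flight async execution requests in the Neuron runtime
# - NEURON_CC_FLAGS: compiler options, including the compilation cache directory
env_var_options = "NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3  " + \
    "NEURON_CC_FLAGS='--cache_dir=./compiler_cache --model-type=cnn-training'"
dataset_path = "./carvana/"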
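`env_var_options` is prepended to a shell command string, which is why `NEURON_CC_FLAGS` needs single quotes around a value that contains a space. As an alternative, if you prefer to avoid `shell=True`, you can pass the same variables through `subprocess`'s `env` argument and build the command as an argument list, so no quoting is needed. `neuron_env` is a hypothetical helper, and the short `argv` below is illustrative rather than the full command used in this notebook.

```python
import os
import shlex


def neuron_env(cache_dir: str = "./compiler_cache") -> dict:
    """Return a copy of os.environ with the Neuron settings used in this notebook."""
    env = os.environ.copy()  # copy so the notebook's own environment is untouched
    env["NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS"] = "3"
    # Passed verbatim to the child process, so the embedded space needs no quoting.
    env["NEURON_CC_FLAGS"] = f"--cache_dir={cache_dir} --model-type=cnn-training"
    return env


env = neuron_env()
argv = ["torchrun", "--nproc_per_node=2", "train.py", "--batch_size", "4"]

# On a shell command line the same value must be quoted, as env_var_options does:
print("NEURON_CC_FLAGS=" + shlex.quote(env["NEURON_CC_FLAGS"]))
print(shlex.join(argv))  # shlex.join requires Python 3.8+
# To launch without a shell: subprocess.run(argv, env=env, check=True)
```

For the compilation step, you would prepend `"neuron_parallel_compile"` to `argv`.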

## 4) Compile the model with neuron_parallel_compile

In [3]:
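# neuron_parallel_compile only needs a short trial run to extract the graphs,
# so --num_epochs is fixed at 2 here instead of using num_epochs.
COMPILE_CMD = f"""{env_var_options} neuron_parallel_compile torchrun --nproc_per_node={num_workers} \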
   train.py \
    --num_workers {dataloader_num_workers} \
    --image_dim {image_dim} \
    --num_epochs 2 \
    --batch_size {batch_size} \
    --drop_last \
    --data_dir {dataset_path} \
    --lr {learning_rate}"""
print(COMPILE_CMD)

NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3  NEURON_CC_FLAGS='--cache_dir=./compiler_cache --model-type=cnn-training' neuron_parallel_compile torchrun --nproc_per_node=2    train.py     --num_workers 2     --image_dim 256     --num_epochs 2     --batch_size 4     --drop_last     --data_dir ./carvana/     --lr 0.0002
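The compilation and fine-tuning cells build almost the same command. They differ only in the `neuron_parallel_compile` prefix, the epoch count, and `--do_eval`. The sketch below shows a shared builder that keeps the two in sync. It assumes the `train.py` flags used in this notebook, and `build_cmd` is a hypothetical helper.

```python
def build_cmd(env_prefix: str, num_workers: int, script_args: dict,
              flags: tuple = (), parallel_compile: bool = False) -> str:
    """Assemble the torchrun command line used in this notebook."""
    parts = [env_prefix] if env_prefix else []
    if parallel_compile:
        parts.append("neuron_parallel_compile")
    parts += ["torchrun", f"--nproc_per_node={num_workers}", "train.py"]
    for key, value in script_args.items():
        parts += [f"--{key}", str(value)]
    # Boolean switches such as --drop_last or --do_eval take no value.
    parts += [f"--{flag}" for flag in flags]
    return " ".join(parts)


args = {"num_workers": 2, "image_dim": 256, "num_epochs": 2,
        "batch_size": 4, "data_dir": "./carvana/", "lr": 2e-4}
prefix = "NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3"
compile_cmd = build_cmd(prefix, 2, args, flags=("drop_last",), parallel_compile=True)
train_cmd = build_cmd(prefix, 2, {**args, "num_epochs": 20}, flags=("do_eval", "drop_last"))
print(compile_cmd)
print(train_cmd)
```

argparse accepts flags in any order, so the order here does not need to match the cells above.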


In [5]:
# Equivalent command for a full run with 32 workers on trn1.32xlarge:
# NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3  NEURON_CC_FLAGS='--cache_dir=./compiler_cache --model-type=cnn-training' neuron_parallel_compile torchrun --nproc_per_node=32 train.py --num_workers 2 --image_dim 256 --num_epochs 2 --batch_size 4 --drop_last --data_dir ./carvana/ --lr 0.0002

In [6]:
%%time
import subprocess
print("Compile model")
COMPILE_CMD = f"""{env_var_options} neuron_parallel_compile torchrun --nproc_per_node={num_workers} \
   train.py \
    --num_workers {dataloader_num_workers} \
    --image_dim {image_dim} \
    --num_epochs 2 \
    --batch_size {batch_size} \
    --drop_last \
    --data_dir {dataset_path} \
    --lr {learning_rate}"""

print(f'Running command: \n{COMPILE_CMD}')
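# check_call returns 0 on success and raises CalledProcessError on a
# non-zero exit code, so handle the exception instead of testing the result.
try:
    subprocess.check_call(COMPILE_CMD, shell=True)
except subprocess.CalledProcessError as err:
    print(f"There was an error with the compilation command (exit code {err.returncode})")
else:
    print("Compilation Success!!!")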

Compile model
Running command: 
NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3  NEURON_CC_FLAGS='--cache_dir=./compiler_cache --model-type=cnn-training' neuron_parallel_compile torchrun --nproc_per_node=2    train.py     --num_workers 2     --image_dim 256     --num_epochs 2     --batch_size 4     --drop_last     --data_dir ./carvana/     --lr 0.0002
2024-05-03 20:42:14.000627:  1253951  INFO ||NEURON_PARALLEL_COMPILE||: Removing existing workdir /tmp/ubuntu/parallel_compile_workdir
2024-05-03 20:42:14.000627:  1253951  INFO ||NEURON_PARALLEL_COMPILE||: Running trial run (add option to terminate trial run early; also ignore trial run's generated outputs, i.e. loss, checkpoints)


*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
==> Preparing data..
==> Preparing data..
train_dataset : 4579, test_dataset : 509
Image shape : torch.Size([3, 256, 256]), mask shape : torch.Size([1, 256, 256])
Epoch 1 train begin 2024-05-03 20:42:29.472695
2024-05-03 20:42:29.000699:  1254571  INFO ||NEURON_CACHE||: Compile cache path: ./compiler_cache
2024-05-03 20:42:29.000700:  1254571  INFO ||NEURON_CC_WRAPPER||: Extracting graphs (/home/ubuntu/torch_neuronx_exploration/from_samples/unet/compiler_cache/neuronxcc-2.13.72.0+78a426937/MODULE_1839587966621543786+ade7b014/model.hlo.pb) for ahead-of-time parallel compilation. No compilation was done.
2024-05-03 20:42:30.000259:  1254679  INFO ||NEURON_CACHE||: Compile cache path: ./compiler_cach
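The log shows each extracted graph being written to a `MODULE_*` directory under `compiler_cache/neuronxcc-<version>/`. After the trial run, you can count these directories with the sketch below. It assumes that layout (taken from the log lines above) and treats the presence of `model.hlo.pb` as "graph extracted". `list_cached_modules` is a hypothetical helper.

```python
from pathlib import Path


def list_cached_modules(cache_dir: str = "./compiler_cache") -> dict:
    """Map each MODULE_* directory name to whether it holds a model.hlo.pb file."""
    root = Path(cache_dir)
    if not root.is_dir():
        return {}
    modules = {}
    # Layout seen in the log: <cache_dir>/neuronxcc-<version>/MODULE_<hash>/model.hlo.pb
    for module_dir in sorted(root.glob("neuronxcc-*/MODULE_*")):
        if module_dir.is_dir():
            modules[module_dir.name] = (module_dir / "model.hlo.pb").is_file()
    return modules


mods = list_cached_modules()
print(f"{len(mods)} cached modules, {sum(mods.values())} with extracted HLO graphs")
```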

## 5) Compile and Fine-tune the model

In [None]:
%%time
import subprocess
print("Fine-tune model")
TRAIN_CMD = f"""{env_var_options} torchrun --nproc_per_node={num_workers} \
    train.py \
    --num_workers {dataloader_num_workers} \
    --image_dim {image_dim} \
    --num_epochs {num_epochs} \
    --batch_size {batch_size} \
    --do_eval \
    --drop_last \
    --data_dir {dataset_path} \
    --lr {learning_rate}"""

print(f'Running command: \n{TRAIN_CMD}')
# check_call raises CalledProcessError on a non-zero exit code.
try:
    subprocess.check_call(TRAIN_CMD, shell=True)
except subprocess.CalledProcessError as err:
    print(f"There was an error with the fine-tune command (exit code {err.returncode})")
else:
    print("Fine-tune Successful!!!")