## Configure the Colab Environment

Make sure this notebook is opened from your own Google Drive folder, not directly from GitHub. Otherwise your code will not be saved for later use during the workshop. If you have not done so yet, uncomment and run the cell below to clone the RVSS2022 GitHub repository into your Google Drive. Then quit this Colab session and reopen the notebook from your own Google Drive (https://drive.google.com).

In [None]:
# import os
# from google.colab import drive

# drive.mount('/content/drive')

# %cd /content/drive/MyDrive/
# # Clone the repository on first use; otherwise pull the latest changes
# if not os.path.exists('RVSS2022'):
#   !git clone https://github.com/Tobias-Fischer/RVSS2022
# else:
#   %cd /content/drive/MyDrive/RVSS2022
#   !git pull
# %cd /content/drive/MyDrive/RVSS2022/Visual_Learning/segmentation

In [None]:
import os
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/RVSS2022/Visual_Learning/segmentation

## Train The Networks
Make sure the runtime type is set to GPU.

__The runtime type can be changed from the menu:__
_Runtime --> Change runtime type --> Hardware accelerator --> GPU_
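To confirm that the change took effect, you can check from a code cell whether PyTorch can see a GPU. This is a minimal sketch that assumes PyTorch is the training backend (the notebook imports `torch.utils.data`). The helper name `gpu_available` is hypothetical and is not part of the workshop code:

```python
import importlib.util


def gpu_available() -> bool:
    """Return True if PyTorch is installed and reports a usable CUDA device."""
    # Avoid an ImportError on machines where PyTorch is not installed
    if importlib.util.find_spec("torch") is None:
        return False
    import torch
    return torch.cuda.is_available()


print("GPU available:", gpu_available())
```

If this prints `False` on Colab, re-check the runtime type before starting training.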

In [None]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from segmenter import Segmenter
import utils.cmd_printer as cmd_printer
from trainer import Trainer
from args import args
from models import Res18Baseline
from models import Res18Skip


## Train the Face Segmentation Model

The hyper-parameters and their default values are:


| Hyper-parameter | Description | Default Value |
| :-: | :-: | :-: |
|n_classes| Number of classes (background excluded)| - |
|lr| Learning rate | 1e-3 |
|epochs| Number of training epochs | 10 |
|batch_size| Batch size | 16 |
|weight_decay| Weight decay (regularisation) coefficient | 1e-4 |
|scheduler_step| The learning rate decays every X epochs | 5 |
|scheduler_gamma| At each decay, the learning rate is multiplied by X | 0.5 |
|weights_dir| Path for saving trained weights | - |
|model| Network architecture, <br> either res18_baseline or res18_skip|res18_baseline|
|dataset_dir| Path to the training dataset|-|
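Taken together, `lr`, `scheduler_step` and `scheduler_gamma` describe a step-decay schedule. The sketch below assumes the trainer applies them like PyTorch's `torch.optim.lr_scheduler.StepLR`, decaying once every `scheduler_step` epochs and counting from epoch 0. Under that assumption, the learning rate at each epoch with the default values is:

```python
def step_lr(base_lr: float, epoch: int, step: int = 5, gamma: float = 0.5) -> float:
    """Learning rate at a given epoch under a StepLR-style schedule."""
    # The rate is multiplied by gamma once for every completed block of `step` epochs
    return base_lr * gamma ** (epoch // step)


for epoch in range(10):
    print(f"epoch {epoch}: lr = {step_lr(1e-3, epoch):.0e}")
```

With the defaults (10 epochs), this gives a learning rate of 1e-3 for epochs 0 to 4 and 5e-4 for epochs 5 to 9.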

In [None]:
from imdb import FaceIMDB

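# Override the default arguments for the face dataset
args.dataset_dir = 'datasets/faces'
args.weights_dir = 'weights/face_baseline'
args.n_classes = 2  # hair and face (background excluded)
args.batch_size = 32
# Print the hyper-parameters used for this run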
cmd_printer.divider(text="Hyper-parameters", line_max=60)
for arg in vars(args):
    print(f"   {arg}: {getattr(args, arg)}")
cmd_printer.divider(line_max=60)


train_loader = DataLoader(dataset=FaceIMDB(args.dataset_dir, mode='train'),
                          batch_size=args.batch_size, shuffle=True,
                          num_workers=4, drop_last=True)

eval_loader = DataLoader(dataset=FaceIMDB(args.dataset_dir, mode='eval'),
                         batch_size=args.batch_size, shuffle=False,
                         num_workers=4, drop_last=False)
   
if args.model == 'res18_baseline':
    model = Res18Baseline(args)
elif args.model == 'res18_skip':
    model = Res18Skip(args)
else:
    raise ValueError(f"Unknown model '{args.model}': choose res18_baseline or res18_skip")
trainer = Trainer(args)
trainer.fit(model, train_loader, eval_loader)
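The two loaders differ in `shuffle` and `drop_last`. Training batches are shuffled and a final incomplete batch is dropped, while evaluation keeps every sample. These settings determine the number of batches per epoch. The sketch below reproduces that arithmetic for a hypothetical dataset of 1000 images; the real dataset size is not given here:

```python
def num_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader yields per epoch for a map-style dataset."""
    if drop_last:
        # The trailing partial batch is discarded
        return n_samples // batch_size
    # The trailing partial batch is kept, so round up
    return -(-n_samples // batch_size)


n = 1000  # hypothetical dataset size
print("train batches:", num_batches(n, 32, drop_last=True))   # 31
print("eval batches:", num_batches(n, 32, drop_last=False))   # 32
```

Because the training set is shuffled, the few samples left out by `drop_last=True` change from epoch to epoch, so no image is permanently excluded.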

## Visualise Face Segmentation
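`ckpt` is the path to a checkpoint, i.e. a file of saved network weights.

Set `ckpt = ''` to run the network without trained weights. This lets you test whether your network architecture is implemented correctly before training.

Once training has finished, set `ckpt = <path_to_weight_file>` to inspect the outputs of the trained network.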
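Before creating the segmenter, it can help to check that the weight file exists, so that a typo in `ckpt` fails with a clear message. This is a minimal sketch. `check_ckpt` is a hypothetical helper. Treating `''` as "no trained weights" follows the convention described above and is not taken from `Segmenter`'s actual code:

```python
import os


def check_ckpt(ckpt: str) -> str:
    """Return ckpt unchanged if it is '' (untrained run) or an existing file."""
    if ckpt == '':
        # Hypothetical convention: an empty path means no weights are loaded
        print("No checkpoint given: running without trained weights.")
        return ckpt
    if not os.path.isfile(ckpt):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt!r}. Has training finished?")
    return ckpt


check_ckpt('')
```

For example, `check_ckpt('weights/face_baseline/model.best.pth')` raises `FileNotFoundError` until training has written that file.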

In [None]:
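# Uncomment to test the untrained architecture:
# ckpt = ''
ckpt = 'weights/face_baseline/model.best.pth'
model = 'res18_baseline'
# Or uncomment to use the res18_skip model instead:
# ckpt = 'res18_skip_weights.pth'
# model = 'res18_skip'
segmenter = Segmenter(ckpt, use_gpu=False, model=model)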

In [None]:
test_dir = "test_images/faces"
pred_dir = os.path.join(test_dir, model + '_output')
os.makedirs(pred_dir, exist_ok=True)
# Only .jpg files are processed; other formats in test_dir are skipped
all_test_images = [f for f in os.listdir(test_dir) if f.endswith('.jpg')]
for image_name in all_test_images:
    np_img = np.array(Image.open(os.path.join(test_dir, image_name)))
    # pred is the raw prediction; colour_map is its colour-coded visualisation
    pred, colour_map = segmenter.segment_single_image(
        np_img, resize_to=(256, 256), labels=['hair', 'face', 'bg'])
    # Show the input and the prediction side by side
    fig, axs = plt.subplots(1, 2, figsize=(15, 10))
    for ax, pic, title in zip(axs, [np_img, colour_map], ["Input", "Prediction"]):
        ax.imshow(pic, interpolation='nearest')
        ax.set_title(title)
        ax.axis('off')
    # Save the figure into pred_dir under the same file name
    fig.savefig(os.path.join(pred_dir, image_name))
    plt.show()  # display inline and release the figure's memory
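The `colour_map` returned by `segment_single_image` is a colour-coded version of the prediction. How `Segmenter` chooses its colours is not shown here. The sketch below illustrates the general idea with NumPy. It assumes `pred` is an `H x W` array of class indices in which index `i` corresponds to `labels[i]`. The palette itself is arbitrary and hypothetical:

```python
import numpy as np

# Hypothetical palette: one RGB colour per class index (assumed order = labels)
labels = ['hair', 'face', 'bg']
palette = np.array([[255, 0, 0],     # hair -> red
                    [0, 255, 0],     # face -> green
                    [0, 0, 0]],      # bg   -> black
                   dtype=np.uint8)


def colourise(pred: np.ndarray) -> np.ndarray:
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    # Integer-array indexing looks up every pixel's colour in one step
    return palette[pred]


pred = np.array([[0, 0, 1],
                 [1, 2, 2]])
rgb = colourise(pred)
print(rgb.shape)  # (2, 3, 3)
```

Indexing the palette with the label array avoids a per-pixel Python loop, which matters for full-size images.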

## Train The Fruit Segmentation Model

| Hyper-parameter | Description | Default Value |
| :-: | :-: | :-: |
|n_classes| Number of classes (background excluded)| - |
|lr| Learning rate | 1e-3 |
|epochs| Number of training epochs | 10 |
|batch_size| Batch size | 16 |
|weight_decay| Weight decay (regularisation) coefficient | 1e-4 |
|scheduler_step| The learning rate decays every X epochs | 5 |
|scheduler_gamma| At each decay, the learning rate is multiplied by X | 0.5 |
|weights_dir| Path for saving trained weights | - |
|model| Network architecture, <br> either res18_baseline or res18_skip|res18_baseline|
|dataset_dir| Path to the training dataset|-|
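Because `n_classes` excludes the background, the network presumably produces `n_classes + 1` scores per pixel. For the face model, this would be 2 classes plus background, which matches the three labels passed to `segment_single_image`. Assuming that layout, each pixel's predicted label is the index of its highest score. In the NumPy sketch below, random numbers stand in for the network output, using the fruit setting `n_classes = 4`:

```python
import numpy as np

n_classes = 4                    # fruit classes, background excluded
height, width = 4, 6             # tiny illustrative image size

# Random stand-in for the network's per-pixel scores: (n_classes + 1, H, W)
rng = np.random.default_rng(0)
scores = rng.standard_normal((n_classes + 1, height, width))

# Each pixel takes the class whose score is highest
pred = scores.argmax(axis=0)     # shape (H, W), values in 0..n_classes
print(pred.shape)                # (4, 6)
```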

In [None]:
# from imdb import FruitIMDB

# # args
# args.dataset_dir = 'datasets/fruit'
# args.weights_dir = 'weights/fruit_baseline'
# args.n_classes = 4
# args.batch_size = 64
# # print args
# cmd_printer.divider(text="Hyper-parameters", line_max=60)
# for arg in vars(args):
#     print(f"   {arg}: {getattr(args, arg)}")
# cmd_printer.divider(line_max=60)


# train_loader = DataLoader(dataset=FruitIMDB(args.dataset_dir, mode='train'),
#                           batch_size=args.batch_size, shuffle=True,
#                           num_workers=4, drop_last=True)

# eval_loader = DataLoader(dataset=FruitIMDB(args.dataset_dir, mode='eval'),
#                          batch_size=args.batch_size, shuffle=False,
#                          num_workers=4, drop_last=False)

# if args.model == 'res18_baseline':
#     model = Res18Baseline(args)
# elif args.model == 'res18_skip':
#     model = Res18Skip(args)
# else:
#     raise ValueError(f"Unknown model '{args.model}': choose res18_baseline or res18_skip")
# trainer = Trainer(args)
# trainer.fit(model, train_loader, eval_loader)