# CRES Deep Learning: 
## --- *Image Segmentation with UNET* --- 

-----
**Overview**

* Below is an example of how to load a simulated dataset and train an image segmentation model on it.

-----
**Instructions**

* Start by uploading a .zip file containing the .spec file dataset. The upload may take a moment.
    0. Follow the directions in the [README](https://github.com/Helium6CRES/he6-cres-deep-learning) to make a training dataset, then compress the root directory that contains both the spec files and the labels.
    1. Click Files in the side menu bar.
    2. Click Upload.
    3. Upload the .zip containing the data files.
    4. TODO: add instructions to the README for creating these files.
* Then follow the cells below to visualize the dataset and train the lightning module. 
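
Once the upload finishes, the archive has to be unpacked before a dataset path can point at it. A minimal sketch using the standard-library `zipfile` module follows; the function name `extract_dataset` and the example paths are hypothetical and not part of the he6-cres-deep-learning package.

```python
import zipfile
from pathlib import Path


def extract_dataset(zip_path, dest_dir):
    """Extract zip_path into dest_dir and return the extracted file paths."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest)
    # Return a sorted list of files so the contents are easy to inspect.
    return sorted(p for p in dest.rglob("*") if p.is_file())
```

For example, `extract_dataset("simple_ds.zip", "training_data")` would unpack a hypothetical `simple_ds.zip` into `training_data/`; the extracted root directory is what `root_dir` should point to in the cells below.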

-----
**Tips**

* Restarting the runtime keeps your imported data, but restarting and deleting the runtime removes it.
* Changing the runtime type (GPU to CPU, for example) also removes all imported data.

-----
**Resources**

* [Pytorch docs](https://pytorch.org/docs/stable/index.html)
* [Torchvision docs](https://pytorch.org/vision/stable/index.html)
* [Useful for uploading to colab](https://medium.com/@vishakha1203/easiest-way-to-upload-large-datasets-to-google-colab-1f89231844dc)
* [Discussion of how to optimize num workers in Dataloader](https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813)

-----
**Project Links**: 
* [he6-cres-deep-learning github page](https://github.com/Helium6CRES/he6-cres-deep-learning)

## 1.&nbsp;Imports

In [1]:
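%%capture
## Put the below into a requirements.txt
# `%%capture` is a cell magic and must be the first line of the cell.
# pip version specifiers take no spaces around `==`.
! pip install torch==1.12.1
! pip install torchvision==0.13.1
! pip install pytorch-lightning==1.6.4
! pip install pytorch-lightning-bolts
! pip install torchmetrics==0.9.1
! pip install matplotlib==3.1.3
! pip install numpy==1.21.6
! pip install ipywidgets==7.7.0
# `re` is part of the Python standard library and is not installed with pip.
# scikit-image provides `skimage`, which the import cell below uses.
! pip install scikit-image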
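
Pinned versions are easy to drift from when a runtime is recreated. The following is a small sketch, assuming Python 3.8+ (for `importlib.metadata`), that compares installed distributions against a pin list. `PINS` and `check_pins` are illustrative names, and the pins are copied from the install cell above.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install cell; extend as needed.
PINS = {
    "torch": "1.12.1",
    "torchvision": "0.13.1",
    "pytorch-lightning": "1.6.4",
    "torchmetrics": "0.9.1",
}


def check_pins(pins):
    """Return {package: (wanted, found)} for every pin that is not satisfied."""
    problems = {}
    for name, wanted in pins.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            found = None  # the distribution is not installed
        if found != wanted:
            problems[name] = (wanted, found)
    return problems
```

Calling `check_pins(PINS)` returns an empty dict when every pin matches. The comparison is exact, so a build with a local suffix such as `+cu113` would be reported as a mismatch.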


In [2]:
# Deep learning imports.
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset

import torchvision
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks, make_grid
from torchvision.ops import masks_to_boxes
import torchvision.transforms.functional as TF
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

import torchmetrics

# Standard imports. 
from typing import List, Union
import gc
import matplotlib.pyplot as plt
import numpy as np 
import zipfile
from pathlib import Path
import re
import sys

# Necessary for creating our images.
from skimage.draw import line_aa

# Interactive widgets for data viz.
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

**Import the deep learning package from the Helium6CRES organization GitHub page**

* You may need to `git clone` the repository if you have not already.

In [3]:
%load_ext autoreload

In [9]:
%autoreload 2
# `sys.path` entries are not shell-expanded, so expand `~` explicitly.
sys.path.append(str(Path('~/Documents/dsir-1031/project/he6-cres-deep-learning/').expanduser()))
from he6_cres_deep_learning.deep_learning import ds
from he6_cres_deep_learning.deep_learning import util 
from he6_cres_deep_learning.deep_learning import model 
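
Python does not expand `~` in `sys.path` entries, so a path written with a tilde has to be expanded before it is appended. Below is a short sketch of a helper that does this and fails early if the clone is missing; `add_repo_to_path` is a hypothetical name, not part of the package.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_dir):
    """Expand `~`, verify the directory exists, and add it to sys.path once."""
    repo = Path(repo_dir).expanduser().resolve()
    if not repo.is_dir():
        raise FileNotFoundError(f"{repo} not found; clone the repository first.")
    if str(repo) not in sys.path:
        sys.path.append(str(repo))
    return repo
```

Raising `FileNotFoundError` at this point gives a clearer message than the `ModuleNotFoundError` that would otherwise appear at import time.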

## 3.&nbsp;Visualize Dataset

In [12]:
dataset_path = "/media/drew/T7 Shield/cres_deep_learning/training_data/config/simple_ds"
cres_dm = ds.CRES_DM(root_dir = dataset_path, max_pool = 16, file_max = 12, batch_size= 4)

In [13]:
%matplotlib qt
style = {'description_width': 'initial'}

@interact
def visualize_label_targets(display_num_imgs= widgets.IntSlider(style= style,value=4,min=1,max=8,step=1, description = "display_num_imgs"),
                            mask_alpha= widgets.FloatSlider(style= style,value=.4,min=0,max=1,step=.01, description = "mask alpha"),
                            display_size = widgets.IntSlider(style= style, value=15,min=5,max=50,step=1),
                            show_labels = widgets.Checkbox(style= style,value=True,description='target masks'),   
                            ): 


    dataiter = iter(cres_dm.train_dataloader())

    imgs, labels = next(dataiter)

    imgs = 255 - imgs.repeat(1, 3, 1, 1)
    imgs = imgs[:display_num_imgs]
    labels = labels[:display_num_imgs]
    masks = util.labels_to_masks(labels)

    # imgs holds at most one batch, so iterate over it rather than indexing up to display_num_imgs.
    result_images = list(imgs)
    
    # Note: this local dict is unused; display_masks_unet below receives cres_dm.class_map.
    class_map={
            0: {
                "name": "background",
                "target_color": (255, 255, 255),
            },
            1: {"name": "band 0", "target_color": (255, 0, 0)},
            2: {"name": "band 1", "target_color": (0, 255, 0)},
            3: {"name": "band 2", "target_color": (0, 0, 255)},
        }

    if show_labels: 
        result_images = util.display_masks_unet(imgs, masks, cres_dm.class_map, alpha = mask_alpha)

    grid = make_grid(result_images)
    util.show(grid, figsize = (display_size, display_size))



## 4.&nbsp;Train the Lightning Module

* Loading too many spec/label files can exhaust RAM. Start with `file_max = 10` and `max_pool = 16`, as below.
* If you have more classes, extend the `weights` tensor to one value per class and set `num_classes` to the number of classes including background (`max(labels) + 1` when labels start at 0).
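
As a rough illustration of how `num_classes` and a `weights` vector relate to the labels, the sketch below derives both from a flat list of integer class labels. It uses inverse-frequency weighting, which is one common heuristic and not necessarily how the `[1, 10]` values below were chosen; `class_weights` is a hypothetical helper.

```python
from collections import Counter


def class_weights(label_values):
    """Return (num_classes, weights) from an iterable of zero-based class labels."""
    counts = Counter(label_values)
    num_classes = max(counts) + 1  # labels 0..max inclusive
    total = sum(counts.values())
    # weight_c = total / (num_classes * count_c); classes never seen get 0.0.
    weights = [
        total / (num_classes * counts[c]) if counts[c] else 0.0
        for c in range(num_classes)
    ]
    return num_classes, weights
```

For a label set that is 90% background (class 0) and 10% track (class 1), this gives two classes with the rarer class weighted nine times more heavily. The resulting list could be passed to `torch.tensor(...).float()` in place of a hand-picked tensor.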

In [36]:
dataset_path = "/media/drew/T7 Shield/cres_deep_learning/training_data/config/simple_ds"
cres_dm = ds.CRES_DM(root_dir = dataset_path, max_pool = 16, file_max = 10, batch_size = 4)

# Define weights for loss function. 
weights = torch.tensor([1,10]).float()

# Create callback for ModelCheckpoints. 
checkpoint_callback = ModelCheckpoint(filename='{epoch:02d}', save_top_k = 50, monitor = "Loss/val_loss", every_n_epochs = 1)

# Define Logger. 
logger = TensorBoardLogger("tb_logs", name="cres_image_segmentation", log_graph = True)

# Create Instance of Lightning Module. 
img_seg_lm = model.LightningImageSegmentation(in_channels=1, 
                                        num_classes=2, 
                                        first_feature_num = 4, 
                                        num_layers = 2, 
                                        skip_connect = True, 
                                        kernel_size = 3, 
                                        bias = False, 
                                        weight_loss = weights)

# -----------Set device.------------------
device = "gpu" if torch.cuda.is_available() else "cpu"

# Create an instance of a Trainer.
trainer = pl.Trainer(logger = logger, callbacks = [checkpoint_callback], accelerator = device, max_epochs = 15, log_every_n_steps = 2)

# Fit. 
trainer.fit(img_seg_lm, cres_dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type             | Params | In sizes         | Out sizes       
-------------------------------------------------------------------------------------
0 | train_acc | Accuracy         | 0      | ?                | ?               
1 | train_f1  | F1Score          | 0      | ?                | ?               
2 | train_iou | JaccardIndex     | 0      | ?                | ?               
3 | val_acc   | Accuracy         | 0      | ?                | ?               
4 | val_f1    | F1Score          | 0      | ?                | ?               
5 | loss      | CrossEntropyLoss | 0      | ?                | ?               
6 | model     | UNET             | 30.5 K | [4, 1, 256, 256] | [4, 2, 256, 256]
-------------------------------------------------------------------------------------
30.5 K    Trainable params
0         Non

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


**TensorBoard**

In [32]:
%load_ext tensorboard
%tensorboard --logdir tb_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 6444), started 0:46:30 ago. (Use '!kill 6444' to kill it.)
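
With the logger and checkpoint callback configured above, each run is written to `tb_logs/cres_image_segmentation/version_<N>/checkpoints/` with files named like `epoch=03.ckpt`. Here is a sketch for listing which epochs a run has saved, assuming that layout; `saved_epochs` is an illustrative helper.

```python
import re
from pathlib import Path

# Matches checkpoint names produced by filename='{epoch:02d}', e.g. "epoch=03.ckpt".
CKPT_PATTERN = re.compile(r"epoch=(\d+)\.ckpt$")


def saved_epochs(log_dir, name, version):
    """Map epoch number -> checkpoint path for one logged run, sorted by epoch."""
    ckpt_dir = Path(log_dir) / name / f"version_{version}" / "checkpoints"
    epochs = {}
    for path in ckpt_dir.glob("*.ckpt"):
        match = CKPT_PATTERN.search(path.name)
        if match:
            epochs[int(match.group(1))] = path
    return dict(sorted(epochs.items()))
```

Checking `saved_epochs("tb_logs", "cres_image_segmentation", version)` before choosing an epoch avoids asking for a checkpoint that was never written.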

## 5.&nbsp; Visualize Predictions


**Get the test set**

In [37]:
test_dataiter = iter(cres_dm.test_dataloader())

**Run the following cell again to see more test data**

In [38]:
imgs, labels = next(test_dataiter)
masks = util.labels_to_masks(labels)
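
Re-running the cell advances the iterator by one batch and raises `StopIteration` once the test set is exhausted. Below is a small sketch of a wrapper that restarts the iterator instead; `make_batch_cycler` is a hypothetical helper, and it works with any iterator factory, not only a DataLoader.

```python
def make_batch_cycler(make_iter):
    """Return a function that yields the next batch, restarting when exhausted.

    make_iter: zero-argument callable that returns a fresh iterator.
    """
    it = make_iter()

    def next_batch():
        nonlocal it
        try:
            return next(it)
        except StopIteration:
            # The iterator ran out: build a new one and start over.
            it = make_iter()
            return next(it)

    return next_batch
```

In this notebook it could be used as `next_test_batch = make_batch_cycler(lambda: iter(cres_dm.test_dataloader()))`, followed by `imgs, labels = next_test_batch()` in the cell that is re-run.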

**Visualize predictions on the test set**

In [39]:
%matplotlib inline
version = 4

@interact
# max_epochs = 15 above gives checkpoint epochs 0-14, so the slider spans that range.
def visualize_labels_preds( epoch = widgets.IntSlider(value=9,min=0,max=14,step=1),
                            show_labels = widgets.Checkbox(value=False,description='display_labels'),
                            show_preds = widgets.Checkbox(value=False,description='display_preds'),
                            display_size = widgets.IntSlider(value=10,min=2,max=50,step=1), 
                            mask_threshold = widgets.FloatSlider(value=.5,min=0,max=1,step=.0001,description='mask_thresh'),
                        ): 

  
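    # Adjacent string literals are joined; a backslash continuation inside the
    # string would embed the next line's indentation in the path.
    PATH = ('/home/drew/He6CRES/he6-cres-deep-learning/demo/tb_logs/cres_image_segmentation/'
            'version_{}/checkpoints/epoch={:02d}.ckpt').format(version, epoch)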
  
    loaded_lm = model.LightningImageSegmentation.load_from_checkpoint(PATH)
    loaded_lm.eval()  # Inference mode for layers such as dropout or batch norm, if the model has any.

    with torch.no_grad():
        logits = loaded_lm(imgs)

    probs = logits.softmax(dim = 1)

    imgs_display = 255 - imgs.repeat(1, 3, 1, 1)
    result_image = [imgs_display[i] for i in range(len(imgs_display))]

    if  show_labels: 
        result_image = util.display_masks_unet(result_image, masks, cres_dm.class_map)
    if  show_preds: 
        preds = (probs > mask_threshold)
        result_image = util.display_masks_unet(result_image, preds, cres_dm.class_map)
    grid = make_grid(result_image)
    util.show(grid, figsize = (display_size,display_size), extent = [0, 35, 0, 1200])

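
The `mask_thresh` slider is applied to per-class softmax probabilities, so more than one class can pass a low threshold at the same pixel and none may pass a high one. The sketch below reproduces that logic for a single pixel in plain Python (the notebook does the same over whole tensors with `logits.softmax(dim = 1)`); `softmax` and `threshold_masks` are illustrative names.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of per-class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def threshold_masks(pixel_logits, mask_threshold=0.5):
    """Return one boolean per class: True where that class's probability exceeds the threshold."""
    return [p > mask_threshold for p in softmax(pixel_logits)]


# Logits [0.0, 2.0] give probabilities of roughly 0.12 (background) and 0.88 (track).
print(threshold_masks([0.0, 2.0], 0.5))  # [False, True]
print(threshold_masks([0.0, 2.0], 0.1))  # [True, True]
```

With two classes and a threshold of 0.5, at most one class is selected per pixel, which is why 0.5 is a natural default for the slider.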