<a href="https://colab.research.google.com/github/zacharylazzara/tent-detection/blob/main/Tent_Detector_2023.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Description
This notebook detects tents in satellite images using UNet.


TODO:
* Try using Dice loss https://www.jeremyjordan.me/semantic-segmentation/#:~:text=The%20most%20commonly%20used%20loss,one%2Dhot%20encoded%20target%20vector.

#Settings

In [36]:
#@title Initial Imports
import os
from google.colab import drive
from google.colab import files
from google.colab import output
from google.colab import runtime

In [37]:
#@title Models

DOWNLOAD_MODEL            = False #@param {type:"boolean"}
UPLOAD_MODEL              = False #@param {type:"boolean"}
FORCE_UPLOAD              = False #@param {type:"boolean"}
SEGMENTOR_MODEL_FILENAME  = 'model.pth' #@param {type:"string"}
DOWNLOAD_MODEL_URL        = 'https://github.com/zacharylazzara/tent-detection/blob/685f80327f493cf816fceed7a2094654f3107bc2/unet.pth' #@param {type:"string"}

uploaded_files = None
if FORCE_UPLOAD and (UPLOAD_MODEL or DOWNLOAD_MODEL):
  os.remove(SEGMENTOR_MODEL_FILENAME)
if DOWNLOAD_MODEL:
  !wget $DOWNLOAD_MODEL_URL
  if not (os.path.exists(SEGMENTOR_MODEL_FILENAME)):
    raise Exception(f'File "{SEGMENTOR_MODEL_FILENAME}" not found!')
else:
  if UPLOAD_MODEL and not (os.path.exists(SEGMENTOR_MODEL_FILENAME)):
    uploaded_files = files.upload()
    for filename in uploaded_files.keys():
      print(f'Uploaded file "{filename}"')
      if filename != SEGMENTOR_MODEL_FILENAME:
        raise Exception('Filename must match SEGMENTOR_MODEL_FILENAME!')
  if UPLOAD_MODEL and uploaded_files == {}:
    raise Exception('No files uploaded!')

In [38]:
#@title Output
SAVE_TO_GOOGLE_DRIVE      = True #@param {type:"boolean"}
DOWNLOAD_OUTPUT           = False #@param {type:"boolean"}
PLAY_SOUND_ON_COMPLETE    = True #@param {type:"boolean"}
KILL_RUNTIME_ON_COMPLETE  = True #@param {type:"boolean"}
MINUTES_TO_KILL_RUNTIME   = 5 #@param {type:"slider", min:0, max:30, step:1}
MINUTES_TO_KILL_RUNTIME   = MINUTES_TO_KILL_RUNTIME*60


In [39]:
#@title Config
TRAIN_MODEL   = True #@param {type:"boolean"}
N_EPOCHS      = 200 #@param {type:"number"}
BATCH_SIZE    = 8 #@param {type:"number"}
INIT_LR       = 0.0001 #@param {type:"number"}
IMAGE_HEIGHT  = 512 #@param {type:"number"}
IMAGE_WIDTH   = 512 #@param {type:"number"}
TEST_SPLIT    = 0.2 #@param {type:"number"}
RANDOM_STATE  = 42 #@param {type:"number"}
# OUTPUT_FORMAT = 'png' #@param ["png", "jpg"] {allow-input: true}
# INPUT_FORMAT  = 'jpg' #@param ["png", "jpg"] {allow-input: true}


OUTPUT_FORMATS = {'model':'pth', 'spreadsheet':'csv', 'image':'png'}
INPUT_FORMATS = {'model':'pth', 'spreadsheet':'csv', 'image':'jpg'}

#Definitions

##Initialization

In [40]:
#@title Imports
!pip install -q tqdm-thread
!pip install -q torchmetrics
import time
import math
import csv
import cv2
import torch
import threading
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torchvision.transforms.functional as TF
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from statistics import mean
from tqdm.auto import tqdm
from tqdm_thread import tqdm_thread
from pathlib import Path
from sklearn.model_selection import train_test_split
from torch.nn import Module
from torch.nn import Conv2d
from torch.nn import Linear
from torch.nn import MaxPool2d
from torch.nn import ReLU
from torch.nn import LogSoftmax
from torch.nn import Sequential
from torch.nn import ModuleList
from torch.nn import ConvTranspose2d
from torch.nn import Flatten
from torch.nn import functional
from torch.nn import BatchNorm2d
from torch.nn import Softplus
from torch.nn.modules.loss import BCEWithLogitsLoss
from torchmetrics.classification import BinaryJaccardIndex
from torchmetrics.classification import BinaryF1Score

from torch import flatten
from torch import cat
from torch import randn
from torchvision import transforms
from torchvision.transforms import CenterCrop
from torchvision.utils import save_image
from torch.utils.data import Dataset
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim import SGD

In [41]:
#@title Device
DEVICE = None
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
PIN_MEMORY = True if DEVICE != 'cpu' else False
print(f'Using device: {DEVICE}')

Using device: cuda


In [42]:
#@title Paths
SRC_PATH            = os.environ['SRC_PATH']            = f'sarpol-zahab-tents'
OUTPUT_PATH         = os.environ['OUTPUT_PATH']         = f'output'

DATA_PATH           = os.environ['DATA_PATH']           = f'{SRC_PATH}/data'
IMAGES_PATH         = os.environ['IMAGES_PATH']         = f'{DATA_PATH}/images'
MASKS_PATH          = os.environ['MASKS_PATH']          = f'{DATA_PATH}/labels'
LABELS_PATH         = os.environ['LABELS_PATH']         = f'{DATA_PATH}/sarpol_counts.csv'

G_DRIVE_MOUNT_POINT = os.environ['G_DRIVE_MOUNT_POINT'] = f'g_drive'
G_DRIVE_STORAGE     = os.environ['G_DRIVE_STORAGE']     = f'{G_DRIVE_MOUNT_POINT}/MyDrive'

In [43]:
#@title Mount
if SAVE_TO_GOOGLE_DRIVE:
  !mkdir -p $G_DRIVE_MOUNT_POINT
  drive.mount(G_DRIVE_MOUNT_POINT)

Drive already mounted at g_drive; to attempt to forcibly remount, call drive.mount("g_drive", force_remount=True).


In [44]:
#@title Environment
%%bash

if [ -d 'sample_data' ]; then
  rm -r sample_data
fi

if [ ! -d $SRC_PATH ]; then
  git clone https://github.com/tofighi/sarpol-zahab-tents.git
fi

if [ ! -d $OUTPUT_PATH ]; then
  mkdir -p $OUTPUT_PATH
fi

##Models

In [45]:
#@title UNet
# Adapted from: https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/
class Block(Module):
  def __init__(self, in_channels, out_channels):
    super(Block, self).__init__()
    self.double_conv2d = Sequential(
        Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
        BatchNorm2d(out_channels),
        ReLU(inplace=True),
        Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
        BatchNorm2d(out_channels),
        ReLU(inplace=True)
    )

  def forward(self, x):
    return self.double_conv2d(x)

class Encoder(Module):
  def __init__(self, channels=(3, 16, 32, 64)):
    super(Encoder, self).__init__()
    self.encoder_blocks = ModuleList([Block(channels[i], channels[i+1]) for i in range(len(channels)-1)])
    self.pool = MaxPool2d(2)

  def forward(self, x):
    block_outputs = []
    for block in self.encoder_blocks:
      x = block(x)
      block_outputs.append(x)
      x = self.pool(x)
    return block_outputs

class Decoder(Module):
  def __init__(self, channels=(64, 32, 16)):
    super(Decoder, self).__init__()
    self.up_convs = ModuleList([ConvTranspose2d(channels[i], channels[i+1], 2, 2) for i in range(len(channels)-1)])
    self.decoder_blocks = ModuleList([Block(channels[i], channels[i+1]) for i in range(len(channels)-1)])
  
  def crop(self, encoder_features, x):
    (_, _, H, W) = x.shape
    return CenterCrop([H, W])(encoder_features)
  
  def forward(self, x, encoder_features):
    for i in range(len(self.up_convs)):
      x = self.up_convs[i](x)
      encoder_feature = self.crop(encoder_features[i], x)
      x = cat([x, encoder_feature], dim=1)
      x = self.decoder_blocks[i](x)
    return x

class UNet(Module):
  def __str__(self) -> str:
    return 'UNet'

  def __init__(self, encoder_channels=(3, 16, 32, 64), decoder_channels=(64, 32, 16), classes=1, retain_dim=True, output_size=(512, 512)):
    super(UNet, self).__init__()
    self.encoder = Encoder(encoder_channels)
    self.decoder = Decoder(decoder_channels)
    self.head = Conv2d(decoder_channels[-1], classes, 1)
    self.retain_dim = retain_dim
    self.output_size = output_size

  def forward(self, x):
    encoder_features = self.encoder(x)
    decoder_features = self.decoder(encoder_features[::-1][0], encoder_features[::-1][1:])
    map = self.head(decoder_features)
    if self.retain_dim:
      map = functional.interpolate(map, self.output_size)
    return map

In [46]:
#@title Dataset
# Adapted from: https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/
class SegmentationDataset(Dataset):
  def __init__(self, dataframe, transformations = None):
    self.dataframe = dataframe
    self.transformations = transformations

  def __len__(self):
    return len(self.dataframe.index)

  def __getitem__(self, index):
    image = cv2.cvtColor(cv2.imread(self.dataframe.iloc[index]['image_paths']), cv2.COLOR_BGR2RGB)
    mask = cv2.threshold(cv2.imread(self.dataframe.iloc[index]['mask_paths'], cv2.IMREAD_GRAYSCALE), 150, 255, cv2.THRESH_BINARY)[1]
    if self.transformations is not None:
      image = self.transformations(image)
      mask = self.transformations(mask)
    return (image, mask, self.dataframe.iloc[index]['labels'], self.dataframe.index[index])

In [47]:
#@title Model
class Model:
  def __init__(self, model, loss_fn, opt_fn, output_dir=OUTPUT_PATH, metric_fns=None):
    self.model = model
    self.loss_fn = loss_fn
    self.metric_fns = metric_fns
    self.output_dir = output_dir

##Functions

In [48]:
#@title Load
def load_data(x_images_path=IMAGES_PATH, y_masks_path=MASKS_PATH, csv_path=LABELS_PATH):
  """
  Returns a dataframe with the feature and target paths, along with the image
  name and number of tents (i.e., labels). The dataset must be square otherwise
  tiling will fail.
  """
  with open(csv_path) as csv_file:
    rows = [row for row in csv.reader(csv_file)]
    assert len(rows) % 2 == 0 # We need our data to be square in order to tile it later
  return pd.DataFrame({
    'names'        : [row[0].split('.')[0] for row in rows],
    'image_paths'  : [str(next(Path(x_images_path).glob(row[0]))) for row in rows],
    'mask_paths'   : [str(next(Path(y_masks_path).glob(row[0]))) for row in rows],
    'labels'       : [int(row[1]) for row in rows]
  }).set_index('names').astype({'labels': 'int'})

In [49]:
#@title Train
# Adapted from: https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/
def train(model, t_loader, v_loader, loss_fn, opt_fn, metric_fns=None, epochs=N_EPOCHS):
  """Trains the model and outputs its metrics in a dataframe (loss and accuracy)."""
  
  history = pd.DataFrame({'t': {str(key):value for (key, value) in zip(['losses', *metric_fns], [[] for _ in [*metric_fns, '']])},
                          'v': {str(key):value for (key, value) in zip(['losses', *metric_fns], [[] for _ in [*metric_fns, '']])}})

  if metric_fns != None:
    for metric_fn in metric_fns:
      metric_fn.to(DEVICE)

  progress_bar = tqdm(range(epochs))
  for e in progress_bar:
    model.train()
    losses = []
    for (i, (x, y, _, _)) in enumerate(t_loader):
      (x, y) = (x.to(DEVICE), y.to(DEVICE))

      pred = model(x)
      loss = loss_fn(pred, y)
      losses.append(loss.item())

      if loss.requires_grad:
        opt_fn.zero_grad()
        loss.backward()
      opt_fn.step()

      if metric_fns != None:
        for metric_fn in metric_fns:
          metric_fn.update(pred, y)
        
    history['t']['losses'].append(mean(losses))
    if metric_fns != None:
      for metric_fn in metric_fns:
        history['t'][f'{metric_fn}'].append(metric_fn.compute().cpu().detach().numpy())
        metric_fn.reset()
    
    with torch.no_grad():
      model.eval()
      
      losses = []
      for (x, y, _, _) in v_loader:
        (x, y) = (x.to(DEVICE), y.to(DEVICE))

        pred = model(x)
        loss = loss_fn(pred, y)
        losses.append(loss.item())

        if metric_fns != None:
          for metric in metric_fns:
            metric.update(pred, y)
      
      history['v']['losses'].append(mean(losses))
      if metric_fns != None:
        for metric_fn in metric_fns:
          history['v'][f'{metric_fn}'].append(metric_fn.compute().cpu().detach().numpy())
          metric_fn.reset()

    progress_bar.set_description(f'Epoch({e+1}/{N_EPOCHS}) Training {model}, Train Loss: {history["t"]["losses"][-1]:.4f}, Test Loss: {history["v"]["losses"][-1]:.4f}')
  return history

In [50]:
#@title Predict
def predict(model, loader, output_dir, loss_fn=None, metric_fns=None, x_images_path=IMAGES_PATH, o_formats=OUTPUT_FORMATS, i_formats=INPUT_FORMATS):
  """
  Runs the model without training and saves the predictions to the output
  directory. Returns a dataframe containing the image paths, mask paths (i.e,
  prediction paths), and the tent count as determined by counting contours in the
  prediction.
  """
  history = pd.DataFrame({'l':{str(key):value for (key, value) in zip(['losses', *metric_fns], [[] for _ in [*metric_fns, '']])}})

  preds = []
  with torch.no_grad():
    model.eval()
    progress_bar = tqdm(loader)
    progress_bar.set_description(f'Evaluating {model}')


    losses = []
    for (x, y, _, name) in progress_bar:
      (x, y) = (x.to(DEVICE), y.to(DEVICE))
      p = model(x)

      if loss_fn != None:
        loss = loss_fn(p, y)
        losses.append(loss.item())
      if metric_fns != None:
        for metric_fn in metric_fns:
          metric_fn.update(p, y)

      for batch, img in enumerate(p.cpu().detach()):
        out_path = f'{output_dir}/{name[batch]}.{o_formats["image"]}'
        save_image(img, out_path)
        preds.append({'names':name[batch], 'image_paths':str(next(Path(x_images_path).glob(f'{name[batch]}.{i_formats["image"]}'))), 'mask_paths':out_path, 'labels':count_contours(img)})
        progress_bar.set_description(f'Saved prediction for {name[batch]} to {out_path}')

    if losses != []:
      history['l']['losses'].append(mean(losses))
    if metric_fns != None:
      for metric_fn in metric_fns:
        history['l'][f'{metric_fn}'].append(metric_fn.compute().cpu().detach().numpy())
        metric_fn.reset()
  return (pd.DataFrame(preds).set_index('names').fillna(np.nan), history)

In [51]:
#@title Region
def region_box(name, c, w, h):
  """Draws the region box with the tent count and name."""
  shape = [(0, 0), (w-1, h-1)]
  
  img = Image.new('RGBA', (w, h))
  draw = ImageDraw.Draw(img)
  font = ImageFont.truetype('LiberationMono-Regular.ttf', 50)
  draw.text((10, 10), f'{c}', font=font, fill=(255, 0, 0))

  _, _, tw, th = draw.textbbox((0, 0), name, font=font)

  tw = w-tw
  th = h-th

  draw.text((tw-20, th-10), name, font=font, fill=(255, 0, 0))

  rec = ImageDraw.Draw(img)  
  rec.rectangle(shape, fill = None, outline ='red')

  return img

def region(loader, output_dir, overview_path, o_formats=OUTPUT_FORMATS):
  """
  Generates an outline around each image and puts the number of tents
  in the upper left corner of the image, with the image name in the bottom right.
  Mostly redundant since we're using Seaborn for the heatmap anyway, but may
  be useful for troubleshooting.
  """
  region_paths = []
  progress_bar = tqdm(loader)
  progress_bar.set_description(f'Creating region overlays...')
  region_overview = None
  for (_, y, c, name) in progress_bar:
    regions = []
    for batch, img in enumerate(y.cpu().detach()):
      out_path = f'{output_dir}/{name[batch]}.{o_formats["image"]}'
      region = region_box(name[batch], c[batch], img.shape[1], img.shape[2])
      region.save(out_path)
      region_paths.append({'names':name[batch], 'region_paths':out_path})
      progress_bar.set_description(f'Saved region overlay for {name[batch]} to {out_path}')

      transform = transforms.Compose([transforms.PILToTensor()])
      region = transform(region)
      regions.append(region)

    region_overview = tiler(region_overview, regions).double()
  save_image(region_overview, overview_path)
  return pd.DataFrame(region_paths).set_index('names').fillna(np.nan)

In [52]:
#@title Tile
def tiler(tiles, tile_row):
  """Takes a row of tiles. The total number of tiles must be divisible by 2."""
  if tiles == None:
    tiles = torch.cat(tuple(tile_row), 2)
  else:
    tiles = torch.cat((tiles, torch.cat(tuple(tile_row), 2)), 1)
  return tiles
 
def tile(loader, output_path_x, output_path_y):
  """Tiles using a loader. Total number of tiles must be divisible by 2."""
  output_x = output_y = None
  progress_bar = tqdm(loader)
  progress_bar.set_description('Tiling')
  for (x, y, _, _) in progress_bar:
    if output_path_x != None:
      output_x = tiler(output_x, x)
    if output_path_y != None:
      output_y = tiler(output_y, y)

  if output_x != None:
    print(f'Saving {output_path_x}...')
    save_image(output_x, output_path_x)
  if output_y != None:
    print(f'Saving {output_path_y}...')
    save_image(output_y, output_path_y)
  print('Done.')

In [53]:
#@title Count Contours
# Adapted from https://stackoverflow.com/questions/48154642/how-to-count-number-of-dots-in-an-image-using-python-and-opencv
def contours(y):
  img = y.numpy().T.astype(np.uint8).copy()

  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))
  closing = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)
  cnts = cv2.findContours(closing, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

  max_area = 20
  xcnts = []
  for cnt in cnts:
    if cv2.contourArea(cnt) < max_area:
      xcnts.append(cnt)

  return xcnts

def count_contours(y):
  return int(len(contours(y)))

def localize_contours(y):
  """
  Note that not all detected contours will be localizable, as some are opened
  (i.e., look like a C) and thus result in a division by zero error (thus we 
  use isContourConvex() to prevent this, skipping the opened contours).
  """
  moments = [cv2.moments(cnt, True) for cnt in contours(y) if cv2.isContourConvex(cnt)]
  coordinates = [(m['m10']/m['m00'], m['m01']/m['m00']) for m in moments]
  return coordinates

def count(loader, x_images_path=IMAGES_PATH, i_formats=INPUT_FORMATS):
  """Unused."""
  preds = []
  progress_bar = tqdm(loader)
  progress_bar.set_description('Counting Contours...')
  for (_, y, _, name) in progress_bar:
    y = y.to(DEVICE)
    for batch, img in enumerate(y.cpu().detach()):
      preds.append({'names':name[batch], 'image_paths':str(next(x_images_path.glob(f'{name[batch]}{i_formats["image"]}'))), 'mask_paths':None, 'labels':count_contours(img)})
  return pd.DataFrame(preds).set_index('names').fillna(np.nan)

In [54]:
#@title Overlay
def overlay(background_path, foreground_path, output_path, bg_opacity = 1, fg_opacity = 1):
  with Image.open(background_path).convert('RGBA') as background:
    background = np.array(background)
  with Image.open(foreground_path).convert('RGBA') as foreground:
    foreground = np.array(foreground)

  for channel in range(1, 2):
    foreground[foreground[:,:,channel] > 0, channel] = 0

  overlay = cv2.addWeighted(background, bg_opacity, foreground, fg_opacity, 0)
  output = Image.fromarray(overlay)
  output.save(output_path)

  return output_path

In [55]:
#@title Heatmap
def heatmap(image_path, dataframe, output_path, title):
  data = [x for x in np.array_split(dataframe['labels'].replace(0, np.nan).tolist(), int(math.sqrt(dataframe.shape[0])))]

  sns.set(font_scale=1)
  fig, ax = plt.subplots(figsize=(15, 15))
  ax.set_title(title)
  
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

  ax.tick_params(left=False, bottom=False)
  sns.heatmap(data, annot=True, square=True, fmt='.5g', alpha=0.3, zorder=2, cbar_kws={'shrink': 0.7}, ax=ax)

  with Image.open(image_path).convert("RGB") as image:
    ax.imshow(image, aspect=ax.get_aspect(), extent=ax.get_xlim()+ax.get_ylim(), zorder=1)
  plt.savefig(output_path, bbox_inches='tight')
  plt.close()

In [56]:
#@title Make Dirs Helper
def make_dirs(output_dir):
  dirs = {'output'  :f'{output_dir}',
          'tiles'   :f'{output_dir}/tiles',
          'overlay' :f'{output_dir}/tiles/overlay'}
  for dir in dirs.values():
    if not os.path.exists(dir):
      os.makedirs(dir)
  return dirs

In [57]:
#@title Truths
def process_truths(output_dir, data, loader, o_formats=OUTPUT_FORMATS):
  """Processes ground truth tiles to create ground truth overviews, overlays and heatmaps."""
  y_dirs = make_dirs(output_dir)
  y_data = data#.combine_first(region(loader, y_dirs["region"], f'{y_dirs["output"]}/y_regions_overview.{o_formats["image"]}')).sort_index()
  
  # Tiling Feature and Ground Truth Overviews
  tile(loader, f'{y_dirs["output"]}/x_overview.{o_formats["image"]}', f'{y_dirs["output"]}/y_overview.{o_formats["image"]}')

  print('Saving overlays...')
  # Ground Truth Overlays
  np.vectorize(overlay)(y_data['image_paths'], 
                        y_data['mask_paths'], 
                        np.vectorize((lambda n: f'{y_dirs["overlay"]}/{n}.{o_formats["image"]}'))(y_data.index), 
                        0.7)
  overlay(f'{y_dirs["output"]}/x_overview.{o_formats["image"]}', 
          f'{y_dirs["output"]}/y_overview.{o_formats["image"]}', 
          f'{y_dirs["output"]}/y_overlay.{o_formats["image"]}', 
          0.7)
  
  # # Ground Truth Region Overlays
  # np.vectorize(overlay)(np.vectorize((lambda n: f'{y_dirs["overlay"]}/{n}.{o_formats["image"]}'))(y_data.index),
  #                       y_data['region_paths'], 
  #                       np.vectorize((lambda n: f'{y_dirs["region_overlay"]}/{n}.{o_formats["image"]}'))(y_data.index))
  # overlay(f'{y_dirs["output"]}/y_overlay.{o_formats["image"]}', 
  #         f'{y_dirs["output"]}/y_regions_overview.{o_formats["image"]}', 
  #         f'{y_dirs["output"]}/y_regions_overlay.{o_formats["image"]}')
  print('Done.')

  print('Saving heatmaps...')
  heatmap(f'{y_dirs["output"]}/y_overlay.{o_formats["image"]}', data, f'{y_dirs["output"]}/y_heatmap.{o_formats["image"]}', 'Actual Tents')
  print('Done.')

In [58]:
#@title Trainer
def trainer(output_dir, model, t_loader, v_loader, loss_fn, opt_fn, metric_fns=None, o_formats=OUTPUT_FORMATS):
  if not os.path.exists(output_dir):
    os.makedirs(output_dir)

  results = train(model, t_loader, v_loader, loss_fn, opt_fn, metric_fns)
  torch.save(model, f'{output_dir}/model.{o_formats["model"]}')

  #Model Loss
  plt.plot(results['t']['losses'], label='training')
  plt.plot(results['v']['losses'], label='validation')
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.ylim([0, 1])
  plt.legend(loc='lower right')
  loss_fn_name = str(loss_fn).replace('()', '')
  plt.title(f'{model} Loss ({loss_fn_name})')
  plt.savefig(f'{output_dir}/model_loss_{loss_fn_name}.{o_formats["image"]}')
  plt.close()

  if metric_fns != None:
    for metric_fn in metric_fns:
      #Model Metrics
      plt.plot(results['t'][f'{metric_fn}'], label='training')
      plt.plot(results['v'][f'{metric_fn}'], label='validation')
      metric_fn_name = str(metric_fn).replace('()', '')
      plt.xlabel('Epoch')
      plt.ylabel('Score')
      plt.ylim([0, 1])
      plt.legend(loc='lower right')
      plt.title(f'{model} Metric ({metric_fn_name})')
      plt.savefig(f'{output_dir}/model_metric_{metric_fn_name}.{o_formats["image"]}')
      plt.close()

  print(f'Saved model and metrics graph to: {output_dir}')

In [59]:
#@title Predictor
def predictor(x_overview_dir, output_dir, model, loader, loss_fn=None, metric_fns=None, o_formats=OUTPUT_FORMATS):
  p_dirs = make_dirs(output_dir)
  predictions, results = predict(model, loader, p_dirs["tiles"], loss_fn, metric_fns) # Prediction dataframe

  bar_data = {}
  if loss_fn != None:
    bar_data[str(loss_fn).replace('()', '')] = results['l']['losses'][0]
  if metric_fns != None:
    for metric_fn in metric_fns:
      bar_data[str(metric_fn).replace('()', '')] = results['l'][f'{metric_fn}'][0]
  if bar_data != {}:

    # TODO: no artists error
    plt.bar(list(bar_data.keys()), list(bar_data.values()))
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.title(f'{model} Prediction Performance')
    plt.savefig(f'{output_dir}/model_p_performance.{o_formats["image"]}')
    plt.close()

  transformations = transforms.Compose([transforms.ToPILImage(), transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)), transforms.ToTensor()])
  dataset = SegmentationDataset(predictions, transformations)
  row_size = int(math.sqrt(len(dataset))) # Batch of 16 since each row is 16 images long
  model_loader = DataLoader(dataset, shuffle=False, batch_size=row_size, pin_memory=PIN_MEMORY, num_workers=os.cpu_count(), persistent_workers=True)
  predictions.to_csv(f'{p_dirs["output"]}/labels.{o_formats["spreadsheet"]}')

  p_data = predictions#.combine_first(region(model_loader, p_dirs["region"], f'{p_dirs["output"]}/p_regions_overview.{o_formats["image"]}')).sort_index()
  
  # Tiling Prediction Overview
  tile(model_loader, None, f'{p_dirs["output"]}/p_overview.{o_formats["image"]}')
  
  print('Saving overlays...')
  # Prediction Overlays
  np.vectorize(overlay)(p_data['image_paths'], 
                        p_data['mask_paths'], 
                        np.vectorize((lambda n: f'{p_dirs["overlay"]}/{n}.{o_formats["image"]}'))(p_data.index), 
                        0.7)
  overlay(f'{x_overview_dir}/x_overview.{o_formats["image"]}', 
          f'{p_dirs["output"]}/p_overview.{o_formats["image"]}', 
          f'{p_dirs["output"]}/p_overlay.{o_formats["image"]}', 
          0.7)

  # # Prediction Region Overlays
  # np.vectorize(overlay)(np.vectorize((lambda n: f'{p_dirs["overlay"]}/{n}.{o_formats["image"]}'))(p_data.index), 
  #                       p_data['region_paths'], 
  #                       np.vectorize((lambda n: f'{p_dirs["region"]}/{n}.{o_formats["image"]}'))(p_data.index))
  # overlay(f'{p_dirs["output"]}/p_overlay.{o_formats["image"]}', 
  #         f'{p_dirs["output"]}/p_regions_overview.{o_formats["image"]}', 
  #         f'{p_dirs["output"]}/p_regions_overlay.{o_formats["image"]}')
  print('Done.')

  print('Saving heatmaps...')
  heatmap(f'{p_dirs["output"]}/p_overlay.{o_formats["image"]}', predictions, f'{p_dirs["output"]}/p_heatmap.{o_formats["image"]}', 'Detected Tents')
  print('Done.')

In [60]:
#@title Eval Models
def evaluator(models, t_loader, v_loader, training=TRAIN_MODEL):
  """Trains models and generates predictions from those models"""
  if training:
    for model in models:
      trainer(model.output_dir, model.model, t_loader, v_loader, model.loss_fn, Adam(model.model.parameters(), lr=INIT_LR), metric_fns=model.metric_fns)

  transformations = transforms.Compose([transforms.ToPILImage(), transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)), transforms.ToTensor()])
  data = load_data() # Ground truth dataframe
  dataset = SegmentationDataset(data, transformations)
  row_size = int(math.sqrt(len(dataset))) # Batch of 16 since each row is 16 images long
  loader = DataLoader(dataset, shuffle=False, batch_size=row_size, pin_memory=PIN_MEMORY, num_workers=os.cpu_count(), persistent_workers=True)

  # Outputs
  process_truths(f'{OUTPUT_PATH}/truths', data, loader)
  for model in models:
    predictor(f'{OUTPUT_PATH}/truths', f'{model.output_dir}/preds', model.model, loader, model.loss_fn, model.metric_fns)

#Main

##Dataset

In [61]:
#@title Training and Validation Loaders
transformations = transforms.Compose([transforms.ToPILImage(), transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)), transforms.ToTensor()])

# Split dataset into training and validation subsets
training_data, validation_data = train_test_split(load_data(), test_size=TEST_SPLIT, random_state=RANDOM_STATE)
training_dataset = SegmentationDataset(training_data, transformations)
validation_dataset = SegmentationDataset(validation_data, transformations)

t_loader = DataLoader(training_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY, num_workers=os.cpu_count(), persistent_workers=True)
v_loader = DataLoader(validation_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY, num_workers=os.cpu_count(), persistent_workers=True)

In [62]:
#@title Split Ratio
t_count = len(training_data)
v_count = len(validation_data)

t_ratio = 1-v_count/t_count
v_ratio = v_count/t_count

print(f'Training to Validation Ratio\n')
print(f'Training ({t_count}): \t{t_ratio*100:>10.2f}%')
print(f'Validation ({v_count}): \t{v_ratio*100:>10.2f}%')
print(f'Total ({t_count + v_count}): \t\t{t_ratio*100 + v_ratio*100:>10.2f}%') 

assert t_ratio + v_ratio == 1 # Sanity Check

Training to Validation Ratio

Training (204): 	     74.51%
Validation (52): 	     25.49%
Total (256): 		    100.00%


In [63]:
#@title Evaluate Models
class Model:
  def __init__(self, model, loss_fn, output_dir=OUTPUT_PATH, metric_fns=None, opt_fn=None):
    self.model = model
    self.loss_fn = loss_fn
    self.opt_fn = Adam(model.parameters(), lr=INIT_LR) if opt_fn == None else opt_fn
    self.metric_fns = metric_fns
    self.output_dir = output_dir
    
models = []
if UPLOAD_MODEL or DOWNLOAD_MODEL:
  models.append(Model(torch.load(SEGMENTOR_MODEL_FILENAME, map_location=DEVICE).to(DEVICE),
                      BCEWithLogitsLoss(),
                      f'{OUTPUT_PATH}/uploaded_model'),
                      [BinaryF1Score()])
else:
  models.append(Model(UNet().to(DEVICE),
                      BCEWithLogitsLoss(),
                      f'{OUTPUT_PATH}/unet',
                      [BinaryF1Score(), BinaryJaccardIndex()]))
evaluator(models, t_loader, v_loader)

  0%|          | 0/200 [00:00<?, ?it/s]

Saved model and metrics graph to: output/unet


  0%|          | 0/16 [00:00<?, ?it/s]

Saving output/truths/x_overview.png...
Saving output/truths/y_overview.png...
Done.
Saving overlays...
Done.
Saving heatmaps...
Done.


  0%|          | 0/16 [00:00<?, ?it/s]



  0%|          | 0/16 [00:00<?, ?it/s]

Saving output/unet/preds/p_overview.png...
Done.
Saving overlays...
Done.
Saving heatmaps...
Done.


In [None]:
#@title Complete

# Since there's seemingly no way to reasonably wait for files.download, we have
# to wait a specified period of time before disconnecting the session.

def kill_runtime(seconds):
  with tqdm_thread(desc='Terminating session...', total=seconds, step_sec=0.5):
    time.sleep(seconds)
  print('Session terminated automatically.')
  time.sleep(1)
  runtime.unassign()

def end_runtime(seconds):
  if seconds > 0:
    print(f'Terminating Session in {seconds//60} minutes...')
    threading.Thread(target=kill_runtime, args=[seconds]).start()
  else:
    print('Session terminated automatically.')
    time.sleep(1)
    runtime.unassign()

print(f'Zipping output and finishing up.')
!7z a -tzip tent_detector_output.zip $OUTPUT_PATH

if SAVE_TO_GOOGLE_DRIVE:
  print('\nSaving to Google Drive\n')
  !rsync -arh --progress tent_detector_output.zip $G_DRIVE_STORAGE
if DOWNLOAD_OUTPUT:
  print('\nDownloading to Local Storage\n')
  files.download('tent_detector_output.zip')
if PLAY_SOUND_ON_COMPLETE:
  output.eval_js('new Audio("https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg").play()')
print('Done.\n\n')

# Killing the runtime on complete is only useful if we save the data somewhere first
if KILL_RUNTIME_ON_COMPLETE and (SAVE_TO_GOOGLE_DRIVE or DOWNLOAD_OUTPUT):
  end_runtime(MINUTES_TO_KILL_RUNTIME if DOWNLOAD_OUTPUT else 0)

Zipping output and finishing up.

7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,2 CPUs Intel(R) Xeon(R) CPU @ 2.30GHz (306F0),ASM,AES-NI)

Scanning the drive:
  0M Scan           8 folders, 781 files, 503540817 bytes (481 MiB)

Creating archive: tent_detector_output.zip

Items to compress: 789

  0%      1% 18 + output/truths/tiles/overlay/sarpol_018.png                                                      2% 34 + output/truths/tiles/overlay/sarpol_034.png                                                      3% 49 + output/truths/tiles/overlay/sarpol_049.png                                                    