<a href="https://colab.research.google.com/github/veda-sunkara/StreetToCloud/blob/master/FCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Download anything that may need to be downloaded
!pip install rasterio



In [2]:
import os
import random
import numpy as np
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
# import torch.utils.data
import torch.optim as optim
from torchvision import transforms
import torchvision.transforms.functional as F

from models import get_model
from datasets import get_dataset, get_cs_points_path


In [3]:
# Hyperparameters
MODEL_NAME = 'fcn'  # one of: unet, refiner, fcn
DATASET_NAME = 'sen1'  # one of: sen1, sen2
CS_POINTS_CLUSTERING = None  # one of: None, low, high
CS_POINTS_NOISE = None  # one of: None, low, high

RUNNAME = "1e3_flood_0"

# Training parameters
LR = 1e-3
BATCH_SIZE = 4
EPOCHS = 1000

# Dataset root directory. Set the SEN1_BASE_DIR environment variable (e.g. to a
# mounted Google Drive folder on Colab) to override this local default without
# editing the notebook.
BASE_DIR = os.environ.get('SEN1_BASE_DIR', '/home/purri/research/water_dots/Sen1_dataset/')

In [4]:
# Seed each RNG the notebook relies on (Python, NumPy, PyTorch) from one
# constant so a run can be reproduced.
SEED = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

<torch._C.Generator at 0x7f587c0a9a90>

Select the compute device

In [5]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook can still run (slowly) on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Get datasets and create DataLoader objects

In [6]:
# Per-split CSV files under the dataset root.
train_data_csv_path = os.path.join(BASE_DIR, 'flood_train_data.csv')
valid_data_csv_path = os.path.join(BASE_DIR, 'flood_valid_data.csv')
test_data_csv_path = os.path.join(BASE_DIR, 'flood_test_data.csv')

crowd_points_path = get_cs_points_path(BASE_DIR, CS_POINTS_CLUSTERING, CS_POINTS_NOISE)


def _make_dataset(csv_path, is_train):
    """Build one split's dataset with the shared dataset and crowd-points settings."""
    return get_dataset(DATASET_NAME,
                       BASE_DIR,
                       csv_path,
                       crowd_points_path=crowd_points_path,
                       is_train=is_train)


def _make_dataloader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook's shared loader settings."""
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=BATCH_SIZE,
                                       shuffle=shuffle,
                                       num_workers=0,
                                       # Pinned memory only speeds up host->GPU copies;
                                       # it just emits a warning on CPU-only runs.
                                       pin_memory=(device.type == 'cuda'),
                                       drop_last=False)


# get dataset objects (is_train=True only for the training split)
train_dataset = _make_dataset(train_data_csv_path, is_train=True)
valid_dataset = _make_dataset(valid_data_csv_path, is_train=False)
test_dataset = _make_dataset(test_data_csv_path, is_train=False)

# create dataloaders; only the training loader is shuffled
train_dataloader = _make_dataloader(train_dataset, shuffle=True)
valid_dataloader = _make_dataloader(valid_dataset, shuffle=False)
test_dataloader = _make_dataloader(test_dataset, shuffle=False)

Build the model and move it to the compute device

In [7]:
# Construct the network for the configured model/dataset (plus the crowd-points
# input when a points path is set) and move its parameters to `device`.
net = get_model(MODEL_NAME, DATASET_NAME, crowd_points_path).to(device)

Create the optimizer, learning-rate scheduler, and loss function

In [8]:
# AdamW over all network parameters.
optimizer = optim.AdamW(net.parameters(), lr=LR)
# Cosine annealing with warm restarts. The training loop steps it once per epoch,
# so the first cycle lasts 260 epochs and each later cycle doubles (T_mult=2).
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=260, T_mult=2, eta_min=0, last_epoch=-1)
# Class 1 is weighted 8x relative to class 0 (presumably the rarer water class;
# confirm). Label 255 marks pixels excluded from the loss. `reduction='mean'`
# replaces the deprecated `size_average=True` with the same behaviour.
loss_func = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 8.0], dtype=torch.float, device=device),
                                ignore_index=255,
                                reduction='mean')

Metrics

In [9]:
def computeIOU(output, target):
    """Compute the IoU of the positive (label 1) class for one batch.

    Args:
        output: raw class scores; the predicted label is ``argmax`` over dim 1
            (presumably shape (N, C, H, W); confirm against the model).
        target: integer label map with one label per predicted pixel. Pixels
            labelled 255 are ignored.

    Returns:
        0-dim float tensor with the IoU. If no valid pixel is positive in either
        the prediction or the target, the epsilon terms make the result 1. A NaN
        result is replaced by 0.
    """
    pred = torch.argmax(output, dim=1).flatten()
    # Keep labels and mask on the prediction's device rather than forcing
    # `.cuda()`, so the metric also works on CPU.
    target = target.flatten().to(pred.device)
    valid = target.ne(255)
    pred = pred.masked_select(valid)
    target = target.masked_select(valid)
    intersection = torch.sum(pred * target)
    union = torch.sum(target) + torch.sum(pred) - intersection
    # Epsilon avoids 0/0 when neither prediction nor target has positives.
    iou = (intersection + .0000001) / (union + .0000001)
    if torch.isnan(iou):
        print("failed, replacing with 0")
        iou = torch.zeros((), dtype=torch.float, device=iou.device)
    return iou
  

def computeAccuracy(output, target):
    """Compute pixel accuracy for one batch, ignoring pixels labelled 255.

    Args:
        output: raw class scores; the predicted label is ``argmax`` over dim 1.
        target: integer label map with one label per predicted pixel.

    Returns:
        0-dim float tensor: fraction of non-ignored pixels predicted correctly,
        or 0 when every pixel is ignored (instead of dividing by zero).
    """
    pred = torch.argmax(output, dim=1).flatten()
    # Keep labels and mask on the prediction's device (no hard-coded `.cuda()`).
    target = target.flatten().to(pred.device)
    valid = target.ne(255)
    pred = pred.masked_select(valid)
    target = target.masked_select(valid)
    if target.numel() == 0:
        return torch.zeros((), dtype=torch.float, device=pred.device)
    correct = torch.sum(pred.eq(target))
    return correct.float() / target.numel()

Define the training and evaluation functions

In [10]:
def train_model(model, device, train_loader, loss_func, optimizer, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Reads two notebook globals:
    - ``crowd_points_path``: when truthy, the model is called with ``[img, cs_img]``.
    - ``MODEL_NAME``: the ``'fcn'`` model returns a dict whose ``'out'`` entry
      holds the logits.

    Args:
        model: network to train; switched to train mode.
        device: device the batch tensors are moved to.
        train_loader: iterable of dict batches with ``'img'`` and ``'target'``
            keys (plus ``'cs_img'`` when crowd points are used).
        loss_func: criterion called as ``loss_func(logits, target)``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number. Currently unused; kept for the
            caller-facing signature.

    Returns:
        Mean of the per-batch losses (0.0 if the loader yields no batches).
    """
    model.train()
    tbar = tqdm(train_loader)
    total_loss = 0.0
    num_batches = 0
    for batch in tbar:
        optimizer.zero_grad()

        data, target = batch['img'], batch['target']
        data, target = data.to(device), target.to(device)

        if crowd_points_path:
            cs_img = batch['cs_img'].to(device)
            output = model([data, cs_img])
        else:
            output = model(data)

        logits = output['out'] if MODEL_NAME == 'fcn' else output
        loss = loss_func(logits, target)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        tbar.set_description('Train loss: {0:2.4f}'.format(batch_loss))

    # With a mean-reducing criterion each batch_loss is already averaged over
    # its batch, so average over batches. Dividing by the number of samples
    # understated the epoch loss by roughly the batch size.
    mean_loss = total_loss / num_batches if num_batches else 0.0
    tbar.set_description('Train loss: {0:2.4f}'.format(mean_loss))
    return mean_loss

def test_model(model, device, test_loader, loss_func, epoch=None):
    """Evaluate ``model`` on ``test_loader`` and return its averaged metrics.

    Reads the notebook globals ``crowd_points_path``, ``MODEL_NAME`` and
    ``EPOCHS``. It uses them the same way ``train_model`` does; ``EPOCHS`` is
    only used in the progress-bar text.

    Args:
        model: network to evaluate; switched to eval mode, run under no_grad.
        device: device the batch tensors are moved to.
        test_loader: iterable of dict batches with ``'img'`` and ``'target'``
            keys (plus ``'cs_img'`` when crowd points are used).
        loss_func: criterion called as ``loss_func(logits, target)``.
        epoch: epoch number shown in the progress bar. When omitted, the
            global ``epoch`` is used if defined; otherwise the prefix is dropped.

    Returns:
        dict with ``'iou'`` and ``'acc'``: per-batch values averaged over the
        number of batches, as Python floats.
    """
    model.eval()
    test_loss = 0.0
    num_batches = 0
    metrics = {'iou': 0.0,
               'acc': 0.0,
              }
    with torch.no_grad():
        tbar = tqdm(test_loader)
        for batch in tbar:
            data, target = batch['img'], batch['target']
            data, target = data.to(device), target.to(device)

            if crowd_points_path:
                cs_img = batch['cs_img'].to(device)
                output = model([data, cs_img])
            else:
                output = model(data)

            # The 'fcn' model returns a dict; its logits live under 'out'.
            logits = output['out'] if MODEL_NAME == 'fcn' else output
            test_loss += loss_func(logits, target).item()
            metrics['iou'] += float(computeIOU(logits, target))
            metrics['acc'] += float(computeAccuracy(logits, target))
            num_batches += 1

    # Every accumulated value is a per-batch quantity, so average over batches.
    # Dividing by the dataset size shrank every metric by about the batch size.
    denom = max(num_batches, 1)
    test_loss /= denom
    for metric_name in metrics:
        metrics[metric_name] /= denom

    if epoch is None:
        # Backward compatibility: older callers relied on a global `epoch`.
        epoch = globals().get('epoch')
    prefix = '' if epoch is None else '[{0:4d}/{1:4d}]'.format(epoch, EPOCHS)
    tbar.set_description('{0}|Loss: {1:2.4f} |Acc: {2:3.2f}% |mIoU: {3:3.2f}% |'.format(
        prefix, test_loss, metrics['acc'] * 100, metrics['iou'] * 100))

    return metrics

In [11]:
def save_model(model, eval_result, best_result, model_save_path):
    """Checkpoint ``model`` when its summed metrics beat the best score so far.

    Args:
        model: object exposing ``state_dict()``; saved with ``torch.save``.
        eval_result: mapping of metric name -> value; the values are summed
            into a single score.
        best_result: best score seen so far.
        model_save_path: file path the state dict is written to.

    Returns:
        The new score if it is strictly greater than ``best_result`` (the model
        is saved); otherwise ``best_result`` unchanged (nothing is written).
    """
    score = sum(eval_result.values())
    if not score > best_result:
        return best_result
    torch.save(model.state_dict(), model_save_path)
    return score
    
def get_save_path(base_dir, runname):
    """Return the best-checkpoint path for a run, creating its directory.

    Args:
        base_dir: root directory; checkpoints go under ``<base_dir>/checkpoints``.
        runname: name of the run; used as the per-run subdirectory.

    Returns:
        ``<base_dir>/checkpoints/<runname>/best_model.pth.tar``.

    Raises:
        FileExistsError: if the checkpoint directory path exists but is not a
            directory (raised by ``os.makedirs``).
    """
    save_dir = os.path.join(base_dir, 'checkpoints', runname)
    # exist_ok=True avoids the race between an isdir() check and makedirs().
    os.makedirs(save_dir, exist_ok=True)
    return os.path.join(save_dir, 'best_model.pth.tar')

Main training loop

In [24]:
# Train for EPOCHS epochs, checkpointing whenever the summed validation metrics
# returned by test_model beat the best score so far (see save_model).
best_result = 0
model_save_path = get_save_path(BASE_DIR, RUNNAME)
# NOTE: keep the loop variable named `epoch`; test_model reads the global
# `epoch` for its progress-bar text.
for epoch in range(1, EPOCHS+1):
    train_model(net, device, train_dataloader, loss_func, optimizer, epoch)
    valid_result = test_model(net, device, valid_dataloader, loss_func)
    best_result = save_model(net, valid_result, best_result, model_save_path)
    # The LR scheduler is stepped once per epoch, so its restart period is in epochs.
    scheduler.step()


HBox(children=(FloatProgress(value=0.0, max=63.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=63.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=63.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=63.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=63.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=63.0), HTML(value='')))

Evaluate the model on the test set

In [16]:
# Evaluate the checkpoint selected on the validation set, not just the weights
# left after the final epoch. Fall back to the in-memory weights if no
# checkpoint has been written yet.
best_model_path = get_save_path(BASE_DIR, RUNNAME)
if os.path.exists(best_model_path):
    net.load_state_dict(torch.load(best_model_path, map_location=device))
eval_result = test_model(net, device, test_dataloader, loss_func)

# Report each test-set metric as a percentage.
for metric_name, value in eval_result.items():
    print('{0}: {1:.2f}%'.format(metric_name, float(value) * 100))

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/purri/anaconda3/lib/python3.8/site-packages/PIL/Image.py", line 2749, in fromarray
    mode, rawmode = _fromarray_typemap[typekey]
KeyError: ((1, 1, 512), '<f4')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/purri/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/purri/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/purri/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/purri/research/water_dots/StreetToCloud/sen1floods11_dataset.py", line 73, in __getitem__
    img = load_image(img_path)
  File "/home/purri/research/water_dots/StreetToCloud/sen1floods11_dataset.py", line 13, in load_image
    img = Image.fromarray(img)
  File "/home/purri/anaconda3/lib/python3.8/site-packages/PIL/Image.py", line 2751, in fromarray
    raise TypeError("Cannot handle this data type: %s, %s" % typekey) from e
TypeError: Cannot handle this data type: (1, 1, 512), <f4
