In [1]:
import warnings
import sys

warnings.simplefilter("ignore", (UserWarning, FutureWarning))
from utils.hparams import HParam
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from utils import metrics
from core.res_unet import ResUnet
from core.res_unet_plus import ResUnetPlusPlus
from core.unet import UNetSmall
from utils.logger import MyWriter
import torch
import argparse
import os
import numpy as np
from dataset import *

from __future__ import print_function, division
from typing import Any
from skimage import io
import glob
import matplotlib.pyplot as plt

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [2]:
# train_path =  "../CV_class/training/mask*.png"
# valid_path =  "../CV_class/testing/mask*.png"
log =  "logs"
logging_step = 100
validation_interval = 7000 # Save and valid have same interval
checkpoints = "checkpoints"

batch_size = 2
lr = 0.001
RESNET_PLUS_PLUS = True
IMAGE_SIZE = 1600
CROP_SIZE = 224

epochs = 20
resume = ""
name = "default"

In [3]:
def show_tensor_image(tensordata_input, tensordata_label):
    
    numpy_input = tensordata_input.permute(1, 2, 0).cpu().detach().numpy()
    numpy_label = tensordata_label.permute(1, 2, 0).cpu().detach().numpy()
    
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4),
                                    sharex=True, sharey=True)
    
    ax1.imshow(numpy_input)
    ax2.imshow(numpy_label)
    
    for aa in (ax1, ax2):
        aa.set_axis_off()
    plt.tight_layout()
    plt.show()

In [4]:
checkpoint_dir = "{}/{}".format(checkpoints, name)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs("{}/{}".format(log, name), exist_ok=True)
writer = MyWriter("{}/{}".format(log, name))


# get model
print(RESNET_PLUS_PLUS)
if RESNET_PLUS_PLUS:
    model = ResUnetPlusPlus(3).to(device)
else:
    model = UNetSmall(3).to(device)


True


In [5]:
# optionally resume from a checkpoint
if resume:
    if os.path.isfile(resume):
        print("=> loading checkpoint '{}'".format(resume))
        checkpoint = torch.load(resume)

        start_epoch = checkpoint["epoch"]

        best_loss = checkpoint["best_loss"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        print(
            "=> loaded checkpoint '{}' (epoch {})".format(
                resume, checkpoint["epoch"]
            )
        )
    else:
        print("=> no checkpoint found at '{}'".format(resume))

In [6]:
test_tfm = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((CROP_SIZE, CROP_SIZE)),
    transforms.ToTensor(), 
])

# get data
mass_dataset_train = ImageDataset(
    transform=test_tfm
)

mass_dataset_val = ImageDataset(
    False, transform=test_tfm
)



# creating loaders
train_dataloader = DataLoader(
    mass_dataset_train, batch_size=batch_size, shuffle=True, pin_memory=False
)
val_dataloader = DataLoader(
    mass_dataset_val, batch_size=1, shuffle=False, pin_memory=False
)


In [7]:
# set up binary cross entropy and dice loss
criterion = metrics.BCEDiceLoss()

# optimizer
# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, nesterov=True)
# optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# decay LR
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# starting params
best_loss = 999
start_epoch = 0

torch.cuda.empty_cache()

step = 0
for epoch in range(start_epoch, epochs):
    
    print("Epoch {}/{}".format(epoch, epochs - 1))
    print("-" * 10)

    # step the learning rate scheduler
    lr_scheduler.step()

    # run training and validation
    # logging accuracy and loss
    train_acc = metrics.MetricTracker()
    train_loss = metrics.MetricTracker()
    # iterate over data

    loader = tqdm(train_dataloader, desc="training")
    for idx, data in enumerate(loader):
        # get the inputs and wrap in Variable
        inputs = data["sat_img"].to(device)
        labels = data["map_img"].to(device)
        # ====test=====
        # show_tensor_image(inputs[0], labels[0])
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward
        # prob_map = model(inputs) # last activation was a sigmoid
        # outputs = (prob_map > 0.3).float()
        outputs = model(inputs)
        # outputs = torch.nn.functional.sigmoid(outputs)
        loss = criterion(outputs, labels)
        
        # backward
        loss.backward()
        optimizer.step()

        train_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0))
        train_loss.update(loss.data.item(), outputs.size(0))

        # tensorboard logging
        if step % logging_step == 0:
            writer.log_training(train_loss.avg, train_acc.avg, step)
            loader.set_description(
                "Training Loss: {:.4f} Acc: {:.4f}".format(
                    train_loss.avg, train_acc.avg
                )
            )
        step += 1

Epoch 0/19
----------


Training Loss: 0.5222 Acc: 0.6799: 100%|████| 1452/1452 [03:43<00:00,  6.50it/s]


Epoch 1/19
----------


Training Loss: 0.2736 Acc: 0.7995: 100%|████| 1452/1452 [03:39<00:00,  6.61it/s]


Epoch 2/19
----------


Training Loss: 0.2183 Acc: 0.8341: 100%|████| 1452/1452 [03:41<00:00,  6.56it/s]


Epoch 3/19
----------


Training Loss: 0.1910 Acc: 0.8512: 100%|████| 1452/1452 [03:38<00:00,  6.66it/s]


Epoch 4/19
----------


Training Loss: 0.1792 Acc: 0.8583: 100%|████| 1452/1452 [03:36<00:00,  6.70it/s]


Epoch 5/19
----------


Training Loss: 0.1621 Acc: 0.8709: 100%|████| 1452/1452 [03:37<00:00,  6.67it/s]


Epoch 6/19
----------


Training Loss: 0.1496 Acc: 0.8787: 100%|████| 1452/1452 [03:38<00:00,  6.64it/s]


Epoch 7/19
----------


Training Loss: 0.1406 Acc: 0.8850: 100%|████| 1452/1452 [03:39<00:00,  6.62it/s]


Epoch 8/19
----------


Training Loss: 0.1345 Acc: 0.8877: 100%|████| 1452/1452 [03:38<00:00,  6.65it/s]


Epoch 9/19
----------


Training Loss: 0.1231 Acc: 0.8956: 100%|████| 1452/1452 [03:38<00:00,  6.64it/s]


Epoch 10/19
----------


Training Loss: 0.1183 Acc: 0.8982: 100%|████| 1452/1452 [03:38<00:00,  6.66it/s]


Epoch 11/19
----------


Training Loss: 0.1153 Acc: 0.9006: 100%|████| 1452/1452 [03:38<00:00,  6.66it/s]


Epoch 12/19
----------


Training Loss: 0.1118 Acc: 0.9024: 100%|████| 1452/1452 [03:37<00:00,  6.69it/s]


Epoch 13/19
----------


Training Loss: 0.1043 Acc: 0.9088: 100%|████| 1452/1452 [03:38<00:00,  6.65it/s]


Epoch 14/19
----------


Training Loss: 0.1041 Acc: 0.9090: 100%|████| 1452/1452 [03:37<00:00,  6.69it/s]


Epoch 15/19
----------


Training Loss: 0.0999 Acc: 0.9134: 100%|████| 1452/1452 [03:40<00:00,  6.60it/s]


Epoch 16/19
----------


Training Loss: 0.0987 Acc: 0.9131: 100%|████| 1452/1452 [03:39<00:00,  6.60it/s]


Epoch 17/19
----------


Training Loss: 0.0929 Acc: 0.9186: 100%|████| 1452/1452 [03:38<00:00,  6.66it/s]


Epoch 18/19
----------


Training Loss: 0.0917 Acc: 0.9193: 100%|████| 1452/1452 [03:35<00:00,  6.73it/s]


Epoch 19/19
----------


Training Loss: 0.0812 Acc: 0.9280: 100%|████| 1452/1452 [03:36<00:00,  6.69it/s]


In [8]:
# Testing

# logging accuracy and loss
valid_acc = metrics.MetricTracker()
valid_loss = metrics.MetricTracker()

# switch to evaluate mode
model.eval()

# Iterate over data.
for idx, data in enumerate(tqdm(val_dataloader, desc="testing")):

    # get the inputs and wrap in Variable
    inputs = data["sat_img"].to(device)
    labels = data["map_img"].to(device)
    
    # forward
    # prob_map = model(inputs) # last activation was a sigmoid
    # outputs = (prob_map > 0.3).float()
    outputs = model(inputs)
    # outputs = torch.nn.functional.sigmoid(outputs)
    # ======= test =======
    # show_tensor_image(outputs[0], labels[0])
    loss = criterion(outputs, labels)

    valid_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0))
    valid_loss.update(loss.data.item(), outputs.size(0))
    # if idx == 0:
    #     writer.log_images(inputs.to(device), labels.to(device), outputs.to(device), step)
writer.log_validation(valid_loss.avg, valid_acc.avg, step)

print("Testing Loss: {:.4f} Acc: {:.4f}".format(valid_loss.avg, valid_acc.avg))
model.train()
valid_metrics = {"valid_loss": valid_loss.avg, "valid_acc": valid_acc.avg}
save_path = os.path.join(
    checkpoint_dir, "%s_checkpoint_%04d.pt" % (name, step)
)
# store best loss and save a model checkpoint
best_loss = min(valid_metrics["valid_loss"], best_loss)
torch.save(
    {
        "step": step,
        "epoch": epoch,
        "arch": "ResUnet",
        "state_dict": model.state_dict(),
        "best_loss": best_loss,
        "optimizer": optimizer.state_dict(),
    },
    save_path,
)
print("Saved checkpoint to: %s" % save_path)


testing: 100%|████████████████████████████████| 200/200 [00:06<00:00, 31.23it/s]


Testing Loss: 2.9300 Acc: 0.7992
Saved checkpoint to: checkpoints/default/default_checkpoint_29040.pt
