In [None]:
# USEFUL WHEN RUNNING ON CLUSTER
import sys
!pip install torch torchvision torchsummary torchtext pytorch_lightning tensorboard matplotlib tqdm datetime time 

In [None]:
# Shell escape: run the NVIDIA `nvidia-smi` tool to report the GPUs visible to this session.
!nvidia-smi

In [None]:
import torch
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation import deeplabv3_resnet101
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.nn import BCELoss
from torch.nn import MSELoss
from torch.nn import functional as F

from torch.optim import Adam
from torch.optim import SGD
from torch.optim import RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import datetime

from src.model import UNet
from src.dataloader import LandCoverData, transformsNorm, transformsNormAugmentedColoJitter
import src.loss as lossPY

%load_ext autoreload
%autoreload 2

# 1 DataLoader

In [None]:
# Root directory handed to LandCoverData; keep exactly one assignment active.
#path="../"
path="/scratch/izar/damiani/"
#path="/scratch/izar/nkaltenr/"

# Whenever use_restricted is set to True, use_augmented has to be True as
# well; otherwise the normalization is wrong.

transformsTrain = transformsNorm(
    use_augmented=True,
    use_restricted=False,
    flag_plot=False,
)

train_dataset = LandCoverData(
    path,
    transforms=transformsTrain,
    split="train",
    ignore_last_number=11,
    use_augmented=True,
    restrict_classes=False,
)

# The validation split is built with the same transform object as training.
val_dataset = LandCoverData(
    path,
    transforms=transformsTrain,
    split="val",
)

In [None]:
BATCH_SIZE = 32

# NOTE(review): an earlier remark here mentioned num_workers (8, or 2 on
# colab), but no num_workers argument is passed to either loader.
# Training batches are shuffled, validation batches are not; both loaders
# drop the last incomplete batch.
train_dl = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
val_dl = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=True,
)

In [None]:
# Image-size constants (presumably the spatial size of the network input in
# pixels; they are not referenced by the other cells of this notebook).
INPUT_IMAGE_HEIGHT=200
INPUT_IMAGE_WIDTH=200

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook does not crash at the first `.to(DEVICE)` on a machine without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#DEVICE = "cpu"
print(f"[INFO] using device: {DEVICE}")

In [None]:
# Fetch the multi-class focal loss through torch.hub (entry point
# 'focal_loss' of the 'adeelh/pytorch-multi-class-focal-loss' repository).
# NOTE(review): the loss-function cell further down issues this exact call
# again, so this cell looks redundant.
focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='focal_loss',
    alpha=None,
    gamma=2,
    reduction='mean',
    device=DEVICE,
    dtype=torch.float32,
    force_reload=False,
)

# 2 Loss function

In [None]:
# --- Plain cross entropy ---
ce = CrossEntropyLoss()

# --- Class-weighted cross entropy: one weight per class (8 entries) ---
wieght_freq = [
    0.38135882, 0.97431312, 1.02707798, 3.17324418,
    1.74480126, 1.11790711, 0.51357981, 0.52475398,
]
class_weights_ce = torch.tensor(wieght_freq, dtype=torch.float32, device=DEVICE)
cew = CrossEntropyLoss(weight=class_weights_ce)

# --- Focal loss, loaded through torch.hub ---
focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='focal_loss',
    alpha=None,
    gamma=2,
    reduction='mean',
    device=DEVICE,
    dtype=torch.float32,
    force_reload=False,
)

# --- Intersection-over-union loss for 8 classes ---
iou = lossPY.mIoULoss(n_classes=8).to(DEVICE)


def UnetLoss(preds, targets):
    """Return the cross-entropy loss and the accuracy of one batch.

    Args:
        preds: model output; the predicted label is taken as the arg-max
            along dim 1.
        targets: ground-truth labels, compared element-wise with the
            predicted labels.

    Returns:
        Tuple ``(loss, acc)``: the value of ``ce(preds, targets)`` and the
        mean of the element-wise matches between predicted labels and
        ``targets``.
    """
    # Alternative criteria defined in this cell; swap the line below to use one:
    #   cew(preds, targets)         weighted cross entropy
    #   iou(preds, targets)         intersection over union
    #   focal_loss(preds, targets)  focal loss
    loss_value = ce(preds, targets)

    _, predicted_labels = torch.max(preds, 1)
    matches = predicted_labels == targets
    accuracy = matches.float().mean()
    return loss_value, accuracy

# 3 Model Architecture

In [None]:
# initialize our UNet model
unet = UNet(nbClasses=8).to(DEVICE)

# Use Pretrained DeepLabV3 model:
# You also need to modify train loop (follows the instructions in the cell)
#
# The alternative is kept as real comments rather than inside a triple-quoted
# string: a bare string ending a cell is an expression, so the notebook would
# echo it as the cell's output. Uncomment the lines below to use it.
#
# unet = deeplabv3_resnet101(pretrained=True, progress=True)
#
# flag_train_only_last_layer=False
#
# if flag_train_only_last_layer:
#     for param in unet.parameters():
#         param.requires_grad=False
# unet.classifier = DeepLabHead(2048, 8)
# unet=unet.to(DEVICE)

# 4 HyperParameters

In [None]:
# Loss callable used by the training loop; it returns a (loss, accuracy) pair.
lossFunc = UnetLoss

INIT_LR = 1e-3
INIT_MOMENTUM = 0.9

# Optimizer: SGD with momentum is active, the other two are ready-made alternatives.
opt = SGD(unet.parameters(), lr=INIT_LR, momentum=INIT_MOMENTUM)
#opt = Adam(unet.parameters(), lr=INIT_LR, weight_decay=1e-6)
#opt = RMSprop(unet.parameters(), lr=INIT_LR, momentum=INIT_MOMENTUM, weight_decay=1e-6)

# Learning-rate scheduler in 'max' mode with a patience of 5; the training
# loop steps it only while flag_scheduler is True.
flag_scheduler=True
scheduler = ReduceLROnPlateau(opt, mode='max', patience=5)

# Number of full batches per epoch for the training and validation sets.
trainSteps = len(train_dataset) // BATCH_SIZE
valSteps = len(val_dataset) // BATCH_SIZE

# Training history (one list per tracked metric) and the best validation
# values seen so far.
H = {key: [] for key in ("train_loss", "val_loss", "train_acc", "val_acc")}
bestValLoss = float('inf')
bestValAcc = -1

In [None]:
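# Resume a previous run (e.g. when training is split over several 10-hour batch jobs):
# uncomment the two lines below to reload the saved history H and the model weights.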
#H = torch.load('enter_filename_H.pth', map_location=torch.device(DEVICE))
#unet.load_state_dict(torch.load('enter_filename_model.pth', map_location=torch.device(DEVICE)))

In [None]:
# Tag inserted into the file names of the checkpoints, history files and plots
# written by the cells below; "..." looks like a placeholder to be replaced
# with a run name before training.
name_for_save = "..."

# 5 Training Loop

In [None]:
# loop over epochs
NUM_EPOCHS = 100


def _batch_loss_and_acc(x, y):
    """Run one forward pass on a batch and return ``(loss, accuracy)``.

    The batch is moved to DEVICE, the labels are cast to ``torch.long`` and
    ``lossFunc`` is evaluated on the model output. Gradient tracking is left
    to the caller (wrap the call in ``torch.no_grad()`` for validation).
    """
    # send the input to the device
    (x, y) = (x.to(DEVICE), y.to(DEVICE))
    pred = unet(x)
    # NOTE(review): squeeze() without a dim removes *every* size-1 axis, so a
    # batch holding a single sample would lose its batch axis too. Presumably
    # the labels carry one singleton axis - confirm and squeeze only that dim.
    y = y.to(torch.long).squeeze()

    # If you are using Pretrained DeepLabV3 model:
    #return lossFunc(pred['out'], y)
    # Otherwise
    return lossFunc(pred, y)


print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(NUM_EPOCHS)):
    # set the model in training mode
    unet.train()
    # running sums of the per-batch loss / accuracy for this epoch
    totalTrainLoss = 0
    totalValLoss = 0
    totalTrainAcc = 0
    totalValAcc = 0
    # loop over the training set
    for (x, y) in train_dl:
        loss, acc = _batch_loss_and_acc(x, y)

        opt.zero_grad()
        loss.backward()
        opt.step()
        # Accumulate detached values: adding the live loss tensor would keep
        # autograd history attached to the running total for the whole epoch.
        totalTrainLoss += loss.detach()
        totalTrainAcc += acc.detach()

    # evaluate on the validation set with autograd switched off
    unet.eval()
    with torch.no_grad():
        for (x, y) in val_dl:
            loss, acc = _batch_loss_and_acc(x, y)
            totalValLoss += loss
            totalValAcc += acc

    # Epoch averages, converted to plain Python floats so that the history,
    # the scheduler and the best-so-far trackers never hold device tensors.
    avgTrainLoss = float(totalTrainLoss / trainSteps)
    avgValLoss = float(totalValLoss / valSteps)
    avgTrainAcc = float(totalTrainAcc / trainSteps)
    avgValAcc = float(totalValAcc / valSteps)

    if flag_scheduler:
        # the scheduler is stepped on the validation accuracy
        scheduler.step(avgValAcc)

    print(f" learning_rate={opt.param_groups[0]['lr']}")

    # update our training history
    H["train_loss"].append(avgTrainLoss)
    H["val_loss"].append(avgValLoss)
    H["train_acc"].append(avgTrainAcc)
    H["val_acc"].append(avgValAcc)
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("       train loss: {:.4f}, val loss: {:.4f}".format(
      avgTrainLoss, avgValLoss))
    # accuracies are fractions in [0, 1]; scale them so the '%' sign is correct
    print("       train acc: {:.2f}%, val acc: {:.2f}%".format(
      100 * avgTrainAcc, 100 * avgValAcc))
    # Save the best model (the one that has the lowest loss for validation)
    if avgValLoss < bestValLoss:
        bestValLoss = avgValLoss
        print("best loss => saving")
        torch.save(unet.state_dict(), f'best_model_{name_for_save}_loss.pth')
    # ... and the one with the highest validation accuracy
    if (bestValAcc < avgValAcc):
        bestValAcc = avgValAcc
        print("best acc => saving")
        torch.save(unet.state_dict(), f'best_model_acc_{name_for_save}_loss.pth')
    # periodic checkpoint of the weights and the history every 50 epochs
    if ((e+1)%50 == 0):
        epoch_name = e+1
        print("SAVING")
        torch.save(unet.state_dict(), f"unet_model_epoch_{epoch_name}_{name_for_save}_loss.pth")
        torch.save(H, f"unet_model_epoch_{epoch_name}_{name_for_save}_H.pth")

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(endTime - startTime))

# 6 Save and Results on train/val set

In [None]:
# Timestamp used to tag the files written by the following cells.
date = datetime.datetime.now()
date_ymd = date.date()
# Zero-padded HH-MM. A ':' separator is avoided because this value ends up in
# file names, where a colon is invalid on some file systems and is read as a
# host separator by tools such as scp.
date_hm = date.strftime("%H-%M")

In [None]:
# Loss curves of the training and validation sets, read from the history H.
plt.style.use("ggplot")
fig, ax = plt.subplots()
ax.plot(H["train_loss"], label="train_loss")
ax.plot(H["val_loss"], label="val_loss")
ax.set_title("Training Loss on Dataset")
ax.set_xlabel("Epoch #")
ax.set_ylabel("Loss")
ax.legend(loc="lower left")
fig.savefig(f"train_val_loss_{date_ymd}_{date_hm}_{name_for_save}_loss.png", bbox_inches='tight')

In [None]:
# Accuracy curves of the training and validation sets, read from the history H.
plt.style.use("ggplot")
fig, ax = plt.subplots()
ax.plot(H["train_acc"], label="train_acc")
ax.plot(H["val_acc"], label="val_acc")
ax.set_title("Training Accuracy on Dataset")
ax.set_xlabel("Epoch #")
ax.set_ylabel("Accuracy")
ax.legend(loc="lower left")
fig.savefig(f"train_val_acc_{date_ymd}_{date_hm}_{name_for_save}_loss.png", bbox_inches='tight')

In [None]:
# Final artefacts of the run: the weights (state_dict), the whole model
# object, and the history dictionary H.
final_artifacts = (
    (unet.state_dict(), f"unet_model_{date_ymd}_{date_hm}_{name_for_save}_loss.pth"),
    (unet, f"unet_model_{date_ymd}_{date_hm}_{name_for_save}_loss.pt"),
    (H, f"unet_model_{date_ymd}_{date_hm}_{name_for_save}_loss_dict.pth"),
)
for payload, target_file in final_artifacts:
    torch.save(payload, target_file)