In [1]:
# USEFUL WHEN RUNNING ON CLUSTER
import sys
!pip install torch torchvision torchtext pytorch_lightning tensorboard matplotlib tqdm datetime time 

Collecting torchvision
  Using cached torchvision-0.14.1-cp37-cp37m-manylinux1_x86_64.whl (24.2 MB)
Collecting torchtext
  Using cached torchtext-0.14.1-cp37-cp37m-manylinux1_x86_64.whl (2.0 MB)
Collecting pytorch_lightning
  Using cached pytorch_lightning-1.8.6-py3-none-any.whl (800 kB)
Collecting tensorboard
  Using cached tensorboard-2.11.0-py3-none-any.whl (6.0 MB)
Collecting datetime
  Using cached DateTime-4.9-py2.py3-none-any.whl (52 kB)
[31mERROR: Could not find a version that satisfies the requirement time (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for time[0m[31m
[0m

In [1]:
# Print the NVIDIA driver / GPU status table so the run records which device was visible.
!nvidia-smi

Fri Dec 23 16:16:17 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.05    Driver Version: 455.23.05    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  On   | 00000000:86:00.0 Off |                  Off |
| N/A   33C    P0    23W / 250W |      0MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [1]:
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.nn import BCELoss
from torch.nn import MSELoss

from torch.optim import Adam
from torch.optim import SGD
from torch.optim.lr_scheduler import CyclicLR
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import datetime

from src.model import UNet
from src.dataloader import LandCoverData

%load_ext autoreload
%autoreload 2

In [2]:
# Root directory handed to LandCoverData. The commented assignments are
# alternative locations that can be swapped in.
# NOTE(review): the active value is a hardcoded, user-specific absolute path —
# it has to be adjusted on any other machine/account.
#path="../"
#path="/scratch/izar/damiani/"
path="/scratch/izar/nkaltenr/"

# Training split. `ignore_last_number` and `use_augmented` are interpreted by
# src.dataloader.LandCoverData, which is not visible here — TODO confirm their
# exact meaning against that module.
train_dataset = LandCoverData(path, split="train", ignore_last_number=11, use_augmented=True)
# Validation split, built with the dataset class's default options.
val_dataset = LandCoverData(path, split="val")

In [3]:
# Number of samples per mini-batch, shared by the training and validation loaders.
BATCH_SIZE = 32

# num_workers 8 default but 2 on colab
# The training loader reshuffles the samples; the validation loader keeps the
# dataset order. Both use a single worker process.
train_dl = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=1,
)
val_dl = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
)

In [4]:
# Image size constants.
# NOTE(review): neither value is referenced by any other visible cell —
# presumably the pixel size of the input tiles; confirm against src.dataloader.
INPUT_IMAGE_HEIGHT=200
INPUT_IMAGE_WIDTH=200

# Device string used for the model, the batches and `map_location` when loading.
DEVICE = "cuda"
#DEVICE = "cpu"
# Learning rate and momentum passed to the SGD optimizer.
INIT_LR = 0.001
INIT_MOMENTUM = 0.9
# Only referenced by a commented-out Adam optimizer line.
INIT_WEIGHT_DECAY = 0.01
# NOTE: the training cell re-assigns NUM_EPOCHS before looping, so this value
# is not the one that drives the loop.
NUM_EPOCHS = 400

In [5]:
# from https://www.kaggle.com/rishabhiitbhu/unet-with-resnet34-encoder-pytorch

def dice_loss(input, target):
    """Return the smoothed soft Dice coefficient between ``sigmoid(input)`` and ``target``.

    Despite its name, this returns the Dice *coefficient* (higher means more
    overlap), not a quantity to minimise directly; ``MixedLoss`` turns it into
    a loss term by taking ``-log`` of the returned value.

    Args:
        input: tensor of raw scores (logits); a sigmoid is applied here.
        target: tensor with the same number of elements as ``input``.

    Returns:
        A 0-dim tensor: ``(2 * sum(p * t) + 1) / (sum(p) + sum(t) + 1)``.
    """
    probs = torch.sigmoid(input)
    # Additive smoothing keeps the ratio defined when both sums are zero.
    smooth = 1.0
    # reshape (rather than view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose/permute, for which view(-1) raises a RuntimeError.
    iflat = probs.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))

In [6]:
# from https://www.kaggle.com/rishabhiitbhu/unet-with-resnet34-encoder-pytorch

class FocalLoss(torch.nn.Module):
    """Focal loss computed on raw logits against a same-shaped target.

    The per-element loss is the binary cross-entropy with logits, scaled by
    ``exp(gamma * logsigmoid(-input * (2 * target - 1)))``, and the result is
    averaged over all elements.

    Args:
        gamma: exponent of the modulating factor; stored as ``self.gamma``.
    """

    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the mean focal loss.

        Args:
            input: tensor of logits.
            target: tensor with exactly the same size as ``input``.

        Raises:
            ValueError: if ``target`` and ``input`` differ in size.
        """
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))
        # Cross-entropy with logits written in the log-sum-exp form: max_val is
        # subtracted inside the exponentials so they cannot overflow.
        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + \
            ((-max_val).exp() + (-input - max_val).exp()).log()
        # Referenced through torch.nn.functional so the class does not depend on
        # an `F` alias that is only imported by a later notebook cell.
        invprobs = torch.nn.functional.logsigmoid(-input * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        return loss.mean()

In [7]:
# from https://www.kaggle.com/rishabhiitbhu/unet-with-resnet34-encoder-pytorch

class MixedLoss(torch.nn.Module):
    """Weighted focal loss minus the log of the Dice coefficient.

    Args:
        alpha: multiplier applied to the focal term.
        gamma: forwarded to ``FocalLoss``.
    """

    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma)

    def forward(self, input, target):
        """Return ``alpha * focal(input, target) - log(dice_loss(input, target))``."""
        focal_term = self.alpha * self.focal(input, target)
        log_dice = torch.log(dice_loss(input, target))
        combined = focal_term - log_dice
        return combined.mean()

In [5]:
# Build the multi-class focal loss through torch.hub from the GitHub repository
# 'adeelh/pytorch-multi-class-focal-loss' (entry point 'focal_loss').
# Configuration passed here: alpha=None, gamma=2, mean reduction, created on
# DEVICE with float32. force_reload=False lets torch.hub reuse a cached copy of
# the repository instead of downloading it again.
focal_loss = torch.hub.load(
	'adeelh/pytorch-multi-class-focal-loss',
	model='focal_loss',
	alpha=None,
	gamma=2,
	reduction='mean',
	device=DEVICE,
	dtype=torch.float32,
	force_reload=False
)

Using cache found in /home/nkaltenr/.cache/torch/hub/adeelh_pytorch-multi-class-focal-loss_master


In [9]:
# From https://github.com/Mr-TalhaIlyas/Loss-Functions-Package-Tensorflow-Keras-PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
# PyTorch
# Focal Tversky loss: default hyper-parameters.
# ALPHA weights the false positives, BETA the false negatives, GAMMA is the
# focal exponent applied to (1 - Tversky index).
ALPHA = 0.5
BETA = 0.5
#Try also
#ALPHA = 0.3
#BETA = 0.7

GAMMA = 2

class FocalTverskyLoss(torch.nn.Module):
    """Focal Tversky loss on raw logits.

    For 4-D inputs the true/false positive/negative sums are taken per index of
    dimension 1 (reducing over dimensions 0, 2 and 3) and the per-index losses
    are averaged. Inputs of any other rank are flattened and treated as a
    single group. In both cases a 0-dim tensor is returned, so the result can
    be used directly with ``backward()``.

    Args:
        weight: accepted for signature compatibility; not used.
        size_average: accepted for signature compatibility; not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1, alpha=ALPHA, beta=BETA, gamma=GAMMA):
        """Return the focal Tversky loss.

        Args:
            inputs: logits; a sigmoid is applied here.
            targets: tensor with the same shape as ``inputs``
                (assumes values in [0, 1], e.g. one-hot masks — TODO confirm
                against the caller).
            smooth: additive smoothing for numerator and denominator.
            alpha: weight of the false positives.
            beta: weight of the false negatives.
            gamma: exponent applied to ``1 - Tversky``.

        Returns:
            0-dim tensor holding the (mean) focal Tversky loss.
        """
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        if inputs.dim() == 4:
            # Keep dimension 1 and reduce over the remaining three.
            reduce_dims = (0, 2, 3)
        else:
            # The previous code flattened unconditionally and then reduced over
            # dims (0, 2, 3), which cannot work on a 1-D tensor; flatten only
            # when there is no 4-D layout to preserve.
            inputs = inputs.reshape(-1)
            targets = targets.reshape(-1)
            reduce_dims = (0,)

        #True Positives, False Positives & False Negatives
        TP = torch.sum(targets * inputs, dim=reduce_dims)
        FN = torch.sum(targets * (1 - inputs), dim=reduce_dims)
        FP = torch.sum((1 - targets) * inputs, dim=reduce_dims)

        Tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)
        FocalTversky = (1 - Tversky) ** gamma

        # Reduce to a scalar (a no-op for the flattened case).
        return FocalTversky.mean()

In [5]:
# from https://stackoverflow.com/questions/65125670/implementing-multiclass-dice-loss-function
import torch
import torch.nn as nn
import torch.nn.functional as F

def dice_coef_8cat(y_true, y_pred, smooth=1e-7):
    '''
    Dice coefficient for 8 categories.

    y_true holds integer class indices and is one-hot encoded here (class axis
    last). y_pred holds one score per class, either with the class axis last
    (same shape as the one-hot target) or at dimension 1 (e.g. N x 8 x H x W);
    in the latter case the one-hot target is re-laid out to match before the
    element-wise overlap is computed, so both tensors are compared
    position by position. Any other y_pred shape is flattened as-is.

    NOTE(review): y_pred is used as given — no softmax is applied here, so
    callers presumably pass probabilities; confirm at the call site.

    Returns a 0-dim tensor: 2 * sum(t * p) / (sum(t) + sum(p) + smooth).
    '''
    y_true_oh = F.one_hot(y_true.to(torch.int64), num_classes=8)
    if y_true_oh.dim() >= 2 and y_pred.shape != y_true_oh.shape:
        # Class axis moved from last position to dimension 1.
        channels_first = y_true_oh.movedim(-1, 1)
        if y_pred.shape == channels_first.shape:
            y_true_oh = channels_first
    y_true_f = torch.flatten(y_true_oh)
    y_pred_f = torch.flatten(y_pred)
    intersect = torch.sum(y_true_f * y_pred_f, dim=-1)
    denom = torch.sum(y_true_f + y_pred_f, dim=-1)
    return torch.mean((2. * intersect / (denom + smooth)))

def dice_coef_8cat_loss(y_true, y_pred):
    '''
    Dice loss to minimize: one minus the value of dice_coef_8cat for the
    same arguments.
    '''
    dice = dice_coef_8cat(y_true, y_pred)
    return 1 - dice

In [6]:
# Cross-entropy criterion. The three commented lines build a class-weighted
# variant (one weight per class, moved to DEVICE) instead of the unweighted one.
# NOTE(review): within the visible cells `ce` is only referenced from
# commented-out code inside UnetLoss.
#weight_ce = [1., 3., 3.1, 14., 7., 3.3, 0.9, 1.]
#class_weights_ce = torch.FloatTensor(weight_ce).to(DEVICE)
#ce = CrossEntropyLoss(weight=class_weights_ce)
ce = CrossEntropyLoss()

# Disabled alternatives: MixedLoss (focal + Dice) and FocalTverskyLoss instances.
#INIT_ALPHA = 10.0
#INIT_GAMMA = 2.0
#ml = MixedLoss(INIT_ALPHA, INIT_GAMMA)
#ftl = FocalTverskyLoss()


def UnetLoss(preds, targets):
    """Return ``(loss, accuracy)`` for one batch.

    The loss is the module-level ``focal_loss`` applied to ``preds`` and
    ``targets``. The other criteria defined in this notebook (``ce``,
    ``dice_coef_8cat_loss``, ``MixedLoss``, ``FocalTverskyLoss``) are not used.

    Accuracy is the fraction of positions where the index of the largest value
    of ``preds`` along dimension 1 equals ``targets``.
    """
    focal_value = focal_loss(preds, targets)
    # torch.max over dim 1 returns (values, indices); only the indices matter.
    _, predicted = torch.max(preds, 1)
    hits = (predicted == targets).float()
    accuracy = hits.mean()
    return focal_value, accuracy

In [7]:
# initialize our UNet model (8 output classes) on the selected device
unet = UNet(nbClasses=8).to(DEVICE)
# initialize loss function and optimizer
lossFunc = UnetLoss
opt = SGD(unet.parameters(), lr=INIT_LR, momentum=INIT_MOMENTUM)
#opt = Adam(unet.parameters(), lr= INIT_LR)
#opt = Adam(unet.parameters(), lr=INIT_LR, weight_decay=INIT_WEIGHT_DECAY)
# number of mini-batches per epoch for the training and validation loaders.
# len(DataLoader) counts the final partial batch, whereas
# len(dataset) // BATCH_SIZE drops it (and is 0 for a dataset smaller than one
# batch), which would skew — or zero-divide — the per-epoch averages.
trainSteps = len(train_dl)
testSteps = len(val_dl)
# initialize a dictionary to store training history
H = {"train_loss": [], "test_loss": [], "train_acc": [], "test_acc": []}
# best validation metrics seen so far (used to decide when to checkpoint)
bestValLoss = float('inf')
bestValAcc = -1

In [8]:
# Replace the empty history dict with the one saved by an earlier run, so the
# training cell appends new epochs to the existing curves.
# NOTE(review): torch.load unpickles the file — only load files you trust.
H = torch.load('unet_model_2023-01-03_0:54_fl_2_loss_dict.pth', map_location=torch.device(DEVICE))

{'train_loss': [array(0.8882564, dtype=float32), array(0.6806669, dtype=float32), array(0.60211587, dtype=float32), array(0.5623767, dtype=float32), array(0.5372199, dtype=float32), array(0.5127149, dtype=float32), array(0.50119543, dtype=float32), array(0.4862936, dtype=float32), array(0.4765946, dtype=float32), array(0.46477878, dtype=float32), array(0.45284355, dtype=float32), array(0.4476596, dtype=float32), array(0.44159225, dtype=float32), array(0.43448934, dtype=float32), array(0.42616245, dtype=float32), array(0.41681606, dtype=float32), array(0.41505843, dtype=float32), array(0.4076343, dtype=float32), array(0.40447962, dtype=float32), array(0.3977208, dtype=float32), array(0.39123365, dtype=float32), array(0.38581866, dtype=float32), array(0.38168678, dtype=float32), array(0.37570047, dtype=float32), array(0.3713576, dtype=float32), array(0.3637783, dtype=float32), array(0.35988957, dtype=float32), array(0.35516194, dtype=float32), array(0.35121557, dtype=float32), array(0.34

In [16]:
# Restore model weights from a saved state dict (judging by the file name, the
# checkpoint written after epoch 200) before continuing training.
unet.load_state_dict(torch.load('unet_model_epoch_200_fl_2_loss.pth', map_location=torch.device(DEVICE)))

<All keys matched successfully>

In [17]:
# loop over epochs
NUM_EPOCHS = 100
# Offset added to the epoch counter when naming the periodic checkpoints
# (presumably the number of epochs already trained before this run — confirm
# against the checkpoint that was loaded).
START_EPOCH = 200
# A periodic checkpoint is written every this many epochs.
CHECKPOINT_EVERY = 50


def _align_target(y, pred):
    """Cast labels to long and drop singleton dims while keeping the batch axis."""
    y = y.to(torch.long).squeeze()
    # squeeze() also removes the batch axis when the batch holds one sample
    # (e.g. the last, partial batch); restore it so the target has exactly one
    # dimension less than the prediction.
    if y.dim() == pred.dim() - 2:
        y = y.unsqueeze(0)
    return y


print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(NUM_EPOCHS)):
    # set the model in training mode
    unet.train()
    # running sums of the per-batch loss / accuracy, kept on the device as
    # graph-free tensors, plus the number of batches actually processed
    totalTrainLoss = torch.zeros((), device=DEVICE)
    totalTestLoss = torch.zeros((), device=DEVICE)
    totalTrainAcc = torch.zeros((), device=DEVICE)
    totalTestAcc = torch.zeros((), device=DEVICE)
    numTrainBatches = 0
    numTestBatches = 0
    # loop over the training set
    for (x, y) in train_dl:
        # send the input to the device
        (x, y) = (x.to(DEVICE), y.to(DEVICE))
        # perform a forward pass and calculate the training loss
        pred = unet(x)
        y = _align_target(y, pred)

        loss, acc = lossFunc(pred, y)

        # first, zero out any previously accumulated gradients, then
        # perform backpropagation, and then update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()
        # accumulate detached values: summing the attached loss would keep
        # every batch's autograd graph referenced until the end of the epoch
        totalTrainLoss += loss.detach()
        totalTrainAcc += acc.detach()
        numTrainBatches += 1
    # switch off autograd
    with torch.no_grad():
        # set the model in evaluation mode
        unet.eval()
        # loop over the validation set
        for (x, y) in val_dl:
            # send the input to the device
            (x, y) = (x.to(DEVICE), y.to(DEVICE))
            # make the predictions and calculate the validation loss
            pred = unet(x)
            y = _align_target(y, pred)
            loss, acc = lossFunc(pred, y)
            totalTestLoss += loss
            totalTestAcc += acc
            numTestBatches += 1
    # average over the batches that were actually seen (the final partial
    # batch is included; max(..., 1) avoids a division by zero on an empty loader)
    avgTrainLoss = totalTrainLoss / max(numTrainBatches, 1)
    avgTestLoss = totalTestLoss / max(numTestBatches, 1)
    avgTrainAcc = totalTrainAcc / max(numTrainBatches, 1)
    avgTestAcc = totalTestAcc / max(numTestBatches, 1)
    # update our training history (stored as 0-d numpy arrays)
    H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
    H["test_loss"].append(avgTestLoss.cpu().detach().numpy())
    H["train_acc"].append(avgTrainAcc.cpu().detach().numpy())
    H["test_acc"].append(avgTestAcc.cpu().detach().numpy())
    # plain Python floats for printing and best-so-far bookkeeping
    valLoss = avgTestLoss.item()
    valAcc = avgTestAcc.item()
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("Train loss: {:.6f}, Validation loss: {:.4f}".format(
        avgTrainLoss.item(), valLoss))
    print("Train acc: {:.6f}, Validation acc: {:.4f}".format(
        avgTrainAcc.item(), valAcc))
    # Save the best model (the one that has the lowest loss for validation)
    if valLoss < bestValLoss:
        bestValLoss = valLoss
        print("best loss => saving")
        torch.save(unet.state_dict(), 'best_model_fl_2_loss.pth')
    # Save the model with the highest validation accuracy as well
    if bestValAcc < valAcc:
        bestValAcc = valAcc
        print("best acc => saving")
        torch.save(unet.state_dict(), 'best_model_acc_fl_2_loss.pth')
    # periodic checkpoint, named after the cumulative epoch count
    if (e + 1) % CHECKPOINT_EVERY == 0:
        epoch_name = START_EPOCH + e + 1
        print("SAVING")
        torch.save(unet.state_dict(), f"unet_model_epoch_{epoch_name}_fl_2_loss.pth")

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

[INFO] training the network...


  0%|                                                   | 0/100 [00:26<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Timestamp used to tag the figures and checkpoints written by the following cells.
date = datetime.datetime.now()
date_ymd = date.date()
# Zero-padded HH-MM. The previous f"{hour}:{minute}" form produced unpadded,
# ambiguous tags such as "0:5" and put a ':' into file names, which is not a
# valid file-name character on Windows and is read as a host separator by
# tools such as scp/rsync.
date_hm = date.strftime("%H-%M")

In [None]:
# Training vs. validation loss per epoch, drawn from the history dict H and
# written to a timestamped PNG.
plt.style.use("ggplot")
loss_fig, loss_ax = plt.subplots()
for history_key, curve_label in (("train_loss", "train_loss"), ("test_loss", "val_loss")):
    loss_ax.plot(H[history_key], label=curve_label)
loss_ax.set_title("Training Loss on Dataset")
loss_ax.set_xlabel("Epoch #")
loss_ax.set_ylabel("Loss")
loss_ax.legend(loc="lower left")
loss_fig.savefig(f"train_val_loss_{date_ymd}_{date_hm}_fl_2_loss.png", bbox_inches='tight')

In [None]:
# Training vs. validation accuracy per epoch, drawn from the history dict H and
# written to a timestamped PNG.
plt.style.use("ggplot")
acc_fig, acc_ax = plt.subplots()
for history_key, curve_label in (("train_acc", "train_acc"), ("test_acc", "val_acc")):
    acc_ax.plot(H[history_key], label=curve_label)
acc_ax.set_title("Training Accuracy on Dataset")
acc_ax.set_xlabel("Epoch #")
acc_ax.set_ylabel("Accuracy")
acc_ax.legend(loc="lower left")
acc_fig.savefig(f"train_val_acc_{date_ymd}_{date_hm}_fl_2_loss.png", bbox_inches='tight')

In [None]:
# Persist the final state of the run under timestamped names:
# 1) the model weights only (state dict),
torch.save(unet.state_dict(), f"unet_model_{date_ymd}_{date_hm}_fl_2_loss.pth")
# 2) the whole module object (pickled; loading it back presumably requires the
#    src.model code to be importable — verify before relying on this file),
torch.save(unet, f"unet_model_{date_ymd}_{date_hm}_fl_2_loss.pt")
# 3) the loss/accuracy history dict H.
torch.save(H, f"unet_model_{date_ymd}_{date_hm}_fl_2_loss_dict.pth")