In [None]:
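# USEFUL WHEN RUNNING ON A CLUSTER: uncomment to install the dependencies into this kernel's environment.
# (`datetime` and `time` are part of the Python standard library and must not be pip-installed.)
#import sys
#%pip install torch torchvision torchtext pytorch_lightning tensorboard matplotlib tqdm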

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.nn import BCELoss
from torch.nn import MSELoss

from torch.optim import Adam
from torch.optim import SGD
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import datetime

from src.model import UNet
from src.dataloader import LandCoverData

%load_ext autoreload
%autoreload 2

In [None]:
# Root directory passed to LandCoverData; `split` selects the subset.
# No transforms are applied to either split.
path = "../"

train_dataset, val_dataset = (
    LandCoverData(path, transforms=None, split=subset)
    for subset in ("train", "val")
)

In [None]:
BATCH_SIZE = 8

# Only the training split is shuffled; validation order stays fixed.
# num_workers=2 suits Colab; a larger machine can use more (e.g. 8).
train_dl = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_dl = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [None]:
# Input size, compute device and optimisation hyper-parameters.
INPUT_IMAGE_HEIGHT = 200
INPUT_IMAGE_WIDTH = 200

# Switch to "cuda" to train on a GPU.
DEVICE = "cpu"
INIT_LR = 0.001
INIT_MOMENTUM = 0.9
NUM_EPOCHS = 10

# Seed torch's global RNG before the model is built and the loaders are iterated,
# so weight initialisation and shuffle order are repeatable between runs
# (full determinism on GPU additionally needs deterministic kernels).
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Build the model, loss function and optimizer used by the training loop.
# nbClasses=8 is the number of output classes (presumably the land-cover
# classes of LandCoverData — TODO confirm against the dataset).
unet = UNet(nbClasses=8).to(DEVICE)
# CrossEntropyLoss expects unnormalised logits and integer class targets;
# assumes UNet does not apply a softmax itself — TODO confirm.
lossFunc = CrossEntropyLoss()
opt = SGD(unet.parameters(), lr=INIT_LR, momentum=INIT_MOMENTUM)
# Batches per epoch, used to average the summed losses. len(DataLoader) counts
# the final partial batch (ceil division with drop_last=False), matching what the
# training loop actually iterates; `len(dataset) // BATCH_SIZE` under-counts it,
# inflating the averages, and is 0 for a dataset smaller than one batch.
trainSteps = len(train_dl)
testSteps = len(val_dl)
# Per-epoch average losses, filled by the training loop and plotted afterwards.
H = {"train_loss": [], "test_loss": []}

In [None]:
# Train for NUM_EPOCHS epochs, running a validation pass after each one, and
# record the per-epoch average losses in H.


def _as_class_targets(y):
    """Turn a label batch into integer class-index targets for CrossEntropyLoss.

    Casts to ``torch.long`` and removes every size-1 dimension except the batch
    dimension (dim 0). Unlike a bare ``squeeze()``, a final batch holding a single
    sample keeps its leading batch dimension, so the target shape still matches
    the batched logits.
    """
    y = y.to(torch.long)
    # walk backwards so earlier dimension indices stay valid after each squeeze
    for dim in reversed(range(1, y.dim())):
        if y.shape[dim] == 1:
            y = y.squeeze(dim)
    return y


print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(NUM_EPOCHS)):
    # --- training pass ---
    unet.train()
    # Accumulate plain Python floats: adding the loss tensor itself would keep
    # every batch's autograd graph referenced until the end of the epoch.
    totalTrainLoss = 0.0
    totalTestLoss = 0.0
    for (x, y) in train_dl:
        (x, y) = (x.to(DEVICE), y.to(DEVICE))
        pred = unet(x)
        loss = lossFunc(pred, _as_class_targets(y))
        # clear stale gradients, backpropagate, then update the weights
        opt.zero_grad()
        loss.backward()
        opt.step()
        totalTrainLoss += loss.item()

    # --- validation pass (eval mode, no gradient tracking) ---
    unet.eval()
    with torch.no_grad():
        for (x, y) in val_dl:
            (x, y) = (x.to(DEVICE), y.to(DEVICE))
            pred = unet(x)
            totalTestLoss += lossFunc(pred, _as_class_targets(y)).item()

    # trainSteps / testSteps are the batches-per-epoch counts from the setup cell
    avgTrainLoss = totalTrainLoss / trainSteps
    avgTestLoss = totalTestLoss / testSteps
    H["train_loss"].append(avgTrainLoss)
    H["test_loss"].append(avgTestLoss)
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(
      avgTrainLoss, avgTestLoss))
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(endTime - startTime))

In [None]:
# Timestamp embedded in the file names of the saved loss plot and model.
date = datetime.datetime.now()
date_ymd = date.date()  # renders as YYYY-MM-DD in f-strings
# Zero-padded HH-MM: ':' is not allowed in Windows file names, and padding keeps
# names sortable (e.g. "09-05" rather than "9:5").
date_hm = date.strftime("%H-%M")

In [None]:
# Plot the per-epoch average training and validation losses recorded in H and
# save the figure under a timestamped file name.
plt.style.use("ggplot")
fig, ax = plt.subplots()
ax.plot(H["train_loss"], label="train_loss")
ax.plot(H["test_loss"], label="test_loss")
ax.set_title("Training Loss on Dataset")
ax.set_xlabel("Epoch #")
ax.set_ylabel("Loss")
ax.legend(loc="lower left")
# the f-string must be closed before the keyword argument
fig.savefig(f"train_val_loss_{date_ymd}_{date_hm}.png", bbox_inches="tight")
plt.show()

In [None]:
# Persist the whole pickled module under a timestamped name. Loading it back
# requires the UNet class to be importable; saving unet.state_dict() instead
# would be the more portable alternative.
model_path = f"unet_model_{date_ymd}_{date_hm}.pt"
torch.save(unet, model_path)