In [None]:
# Mount Google Drive so paths under /content/drive/MyDrive (dataset archive,
# checkpoint folder) are reachable from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
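# One-time setup (uncomment on a fresh runtime): extract the dataset archive
# (Mapillary Vistas with 4 labels, judging by the file name) from Drive into
# the working directory.
# !unzip '/content/drive/MyDrive/mapplirary_vista_4labels_v2.zip' -d .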

In [None]:
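# One-time setup (uncomment on a fresh runtime): fetch the VietNet repository
# and move its contents into the working directory; it presumably provides the
# local modules imported below (models/, utils/, load_dataset.py, train.py).
# !git clone https://github.com/vietawake/VietNet
# !mv -v VietNet/* .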

In [None]:
import os
import random

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import mobilenet_v2
from tqdm import tqdm

from load_dataset import ImageDataset
from models.vietnet import VietNet, CrossEntropyLoss2d
from train import train_one_epoch, validate_model
from utils.criterion import CriterionOhemDSN

In [None]:
# Number of output classes predicted by the segmentation head.
# NOTE(review): presumably 4 labelled classes + 1 background/ignore class,
# judging by the '4labels' dataset archive name — confirm against the masks.
NUM_CLASSES = 5

net = VietNet(num_classes=NUM_CLASSES)

In [None]:
# Single source of truth for every RNG seed used in this notebook.
SEED = 50

# cuDNN autotuning picks the fastest convolution algorithms for fixed input
# sizes; note it can make runs non-bit-for-bit reproducible despite the seeds.
torch.backends.cudnn.benchmark = True

# Seed every RNG that data loading / augmentation might draw from.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Covers every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
num_epochs = 50
max_acc = 0
patience = 10
not_improved_count = 0
batch_size = 4

In [None]:
# Spatial size (height, width) every sample is resized to; shared by the
# transform and both datasets so they cannot drift apart.
IMG_SIZE = (384, 640)

# Cap worker processes at the visible CPU count: requesting more workers than
# the machine has only triggers a DataLoader warning and extra memory use.
num_workers = min(6, os.cpu_count() or 1)

# NEAREST interpolation never invents new pixel values, which keeps mask labels intact.
# NOTE(review): if ImageDataset applies this same transform to RGB images as
# well, bilinear resizing would give smoother inputs — confirm in load_dataset.py.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
])

train_dataset = ImageDataset(txt_files='data/train_list.txt',
                             img_size=IMG_SIZE,
                             transform=transform)
val_dataset = ImageDataset(txt_files='data/val_list.txt',
                           img_size=IMG_SIZE,
                           transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=num_workers)
# Evaluation gains nothing from shuffling; a fixed order keeps the batch
# composition (and any per-batch-averaged metric) identical across epochs.
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=num_workers)

In [None]:
# Directory where improving checkpoints are written (on the mounted Drive).
CHECKPOINT_DIR = '/content/drive/MyDrive/checkpoints'


def format_metrics(split, acc, iou, dice, loss):
    """Return a one-line, tab-separated summary of one split's metrics.

    Args:
        split: Label prefixed to every metric name, e.g. 'Train' or 'Valid'.
        acc, iou, dice, loss: Scalar metrics accepted by the ``.4f`` format spec.

    Returns:
        A string like 'Train_acc: 0.9000\\tTrain_iou: ...\\tTrain_Dice: ...\\tTrain_Loss: ...'.
    """
    return (f'{split}_acc: {acc:.4f}\t{split}_iou: {iou:.4f}\t'
            f'{split}_Dice: {dice:.4f}\t{split}_Loss: {loss:.4f}')


# Place the model on the training device before building the optimizer.
# `.to()` is a no-op if it is already there.
# NOTE(review): train_one_epoch may also move the model; that is not visible here.
net = net.to(device)

criterion = CrossEntropyLoss2d()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999),
                             eps=1e-08, weight_decay=1e-4)

# torch.save does not create missing parent directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Reset early-stopping state here so re-running this cell starts a fresh run
# instead of inheriting a stale best score / patience counter.
max_acc = 0
not_improved_count = 0

for epoch in range(num_epochs):
    train_loss, train_acc, train_iou, train_dice = train_one_epoch(
        net, criterion, optimizer, train_loader, device)
    val_loss, val_acc, val_iou, val_dice = validate_model(
        net, criterion, val_loader, device)

    print('Epoch: {}'.format(epoch))
    print(format_metrics('Train', train_acc, train_iou, train_dice, train_loss))
    print(format_metrics('Valid', val_acc, val_iou, val_dice, val_loss))

    if val_acc > max_acc:
        # Keep every improving checkpoint; the filename records epoch and score.
        checkpoint_path = os.path.join(
            CHECKPOINT_DIR, f'RoadSeg_epoch_{epoch}_acc_{val_acc:.4f}.pt')
        torch.save(net.state_dict(), checkpoint_path)
        max_acc = val_acc
        not_improved_count = 0
    else:
        not_improved_count += 1

    # Early stopping: give up after `patience` consecutive non-improving epochs.
    if not_improved_count >= patience:
        print(f'Early stopping at epoch {epoch}: no val_acc improvement '
              f'for {patience} epochs.')
        break
