In [None]:
# Mount Google Drive: the dataset archive and the checkpoint folder used below both live under MyDrive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# One-time environment setup for a fresh Colab runtime (uncomment to run).
# The clone + move copy the ERFModel repository into the working directory; it is
# presumably the source of the `models`, `load_dataset` and `train` modules imported
# in the next cell — confirm against the repository layout.
# !pip install shapely
# !git clone https://github.com/vietawake/ERFModel
# !mv -v ERFModel/* .

In [None]:
from models.erfnet_road import ERFNet
import torch
import numpy as np
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm
from torchvision import transforms
from load_dataset import ImageDataset
from train import train_one_epoch, CrossEntropyLoss2d, validate_model
from PIL import Image
import os

In [None]:
!unzip '/content/drive/MyDrive/mapplirary_vista_4labels.zip' -d .

In [None]:
def generate_txtdata(path, file_name, label_name, image_prefix='images/',
                     image_ext='.jpg', label_ext='.png'):
  """Write a list file pairing every entry of `path` with its label file.

  Each output line has the form
  ``<image_prefix><stem><image_ext>, <label_name><stem><label_ext>``
  where ``stem`` is a directory entry's name with its extension removed.

  Args:
    path: directory whose entries are listed (e.g. 'training/images').
    file_name: output text file; overwritten if it already exists.
    label_name: prefix prepended to each label file name (e.g. 'train_labels/').
    image_prefix: prefix prepended to each image file name.
    image_ext: extension appended to each image entry.
    label_ext: extension appended to each label entry.

  Returns:
    None, so a bare call as the last line of a cell displays nothing.
  """
  # os.listdir order is arbitrary; sort so the list file is reproducible across runs.
  # splitext handles extensions of any length (e.g. '.jpeg'), unlike a fixed [:-4] slice.
  stems = [os.path.splitext(entry)[0] for entry in sorted(os.listdir(path))]
  lines = ['{}{}{}, {}{}{}\n'.format(image_prefix, stem, image_ext,
                                     label_name, stem, label_ext)
           for stem in stems]
  with open(file_name, 'w') as out:
    out.writelines(lines)

In [None]:
# cuDNN settings. NOTE(review): benchmark=True lets cuDNN choose convolution algorithms by
# timing them, which can make runs non-deterministic even with the seeds set below.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Seed the global NumPy and torch RNGs for repeatable runs.
SEED = 50
np.random.seed(SEED)
torch.manual_seed(SEED)

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.manual_seed(SEED)

# First GPU when available, otherwise the CPU.
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

In [None]:
# Segmentation network configured for 4 classes.
# NOTE(review): presumably matches the label set of the "4labels" dataset archive — confirm.
NUM_CLASSES = 4
net = ERFNet(num_classes=NUM_CLASSES)

In [None]:
# Write the "<image>, <label>" list file for each split (validation first, then training).
# NOTE(review): these files land in the working directory, but the dataset cell reads
# 'data/val_list.txt' and 'data/train_list.txt' — confirm which files are intended.
split_lists = (
    ('validation/images', 'val_list.txt', 'val_labels/'),
    ('training/images', 'train_list.txt', 'train_labels/'),
)
for images_dir, list_file, labels_prefix in split_lists:
    generate_txtdata(path=images_dir, file_name=list_file, label_name=labels_prefix)

In [None]:
# Load the pretrained ERFNet weights (the "road" model, per the filename) onto `device`.
# NOTE(review): assumes the file holds a flat name -> tensor state dict — confirm.
pretrained_model = torch.load('./pretrained_models/weights_erfnet_road.pth',  map_location=device)

# Checkpoints saved from an nn.DataParallel-wrapped model prefix every key with 'module.'.
# Strip the prefix only where it is present, so unprefixed keys are kept intact instead of
# losing their first 7 characters.
DP_PREFIX = 'module.'
new_mw = {}
for k, w in pretrained_model.items():
    new_mw[k[len(DP_PREFIX):] if k.startswith(DP_PREFIX) else k] = w

In [None]:
# Copy the pretrained weights into `net`. `net.state_dict()` builds a fresh dict on every call,
# so updating that dict never reaches the model's parameters; load_state_dict does.
# Only entries whose name exists in the model AND whose shape matches are loaded. If the
# pretrained model used a different number of classes, its output layer is skipped here
# instead of making load_state_dict raise a size-mismatch error.
model_state = net.state_dict()
compatible_weights = {
    k: w for k, w in new_mw.items()
    if k in model_state and getattr(w, 'shape', None) == model_state[k].shape
}
skipped_keys = sorted(set(new_mw) - set(compatible_weights))
load_result = net.load_state_dict(compatible_weights, strict=False)
print('Loaded {} pretrained tensors; {} model tensors left at init; skipped: {}'.format(
    len(compatible_weights), len(load_result.missing_keys), skipped_keys))

In [None]:
# Training schedule.
num_epochs = 50          # upper bound on epochs
batch_size = 4

# Early stopping: the training loop stops after `patience` epochs without a new best
# validation accuracy.
patience = 10
max_acc = 0              # best validation accuracy seen so far
not_improved_count = 0   # consecutive epochs without improvement

In [None]:
# Model input size (height, width), shared by the resize transform and both datasets.
IMG_SIZE = (360, 640)

# Nearest-neighbour resizing never blends neighbouring label values, so masks keep valid
# class ids. NOTE(review): this assumes ImageDataset applies `transform` to the label mask as
# well as the image — confirm in load_dataset.py. Random flips would only be safe if the
# image and mask are flipped together with the same random draw.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
])

# NOTE(review): the list-generation cell writes 'train_list.txt' / 'val_list.txt' into the
# working directory, while these datasets read 'data/...'; confirm the intended files.
train_dataset = ImageDataset(root_dir='training/',
                             txt_files='data/train_list.txt',
                             img_size=IMG_SIZE,
                             transform=transform)

val_dataset = ImageDataset(root_dir='validation/',
                           txt_files='data/val_list.txt',
                           img_size=IMG_SIZE,
                           transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=4)
# Fixed validation order: batch composition stays the same every epoch, so validation
# metrics are comparable across epochs.
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=4)

In [None]:
# NOTE(review): if CrossEntropyLoss2d stores class weights, they may also need to be on
# `device` — check train.py.
criterion = CrossEntropyLoss2d()

# Move the model to the training device before building the optimizer, as the torch.optim
# docs recommend; nothing earlier in the notebook moves `net`.
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999),
                             eps=1e-08, weight_decay=1e-4)

# torch.save does not create missing parent directories.
CHECKPOINT_DIR = '/content/drive/MyDrive/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(num_epochs):
    train_loss, acc, train_jsc = train_one_epoch(net, criterion, optimizer, train_loader, device)
    val_loss, val_acc, val_jsc = validate_model(net, criterion, val_loader, device)

    print('Epoch: {}'.format(epoch))
    print('Training acc: {:.4f}\tTrain_jsc: {:.4f}\tTraining Loss: {:.4f}'.format(acc, train_jsc, train_loss))
    print('Valid acc: {:.4f}\tValid_jsc: {:.4f}\tValid Loss: {:.4f}'.format(val_acc, val_jsc, val_loss))

    if val_acc > max_acc:
        # New best validation accuracy: save a checkpoint and reset the early-stopping counter.
        checkpoint_path = os.path.join(
            CHECKPOINT_DIR, 'ERF_epoch_{}_acc_{:.4f}.pt'.format(epoch, val_acc))
        torch.save(net.state_dict(), checkpoint_path)
        max_acc = val_acc
        not_improved_count = 0
    else:
        not_improved_count += 1

    # Early stopping after `patience` consecutive epochs without improvement.
    if not_improved_count >= patience:
        print('Early stopping: no improvement in {} epochs'.format(patience))
        break



