In [None]:
# Mount Google Drive so that paths under /content/drive/MyDrive (dataset
# archive, checkpoint folder) are reachable from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# !unzip '/content/drive/MyDrive/mapplirary_vista_4labels.zip' -d .

In [None]:
# !git clone https://github.com/vietawake/RoadSeg
# !mv -v RoadSeg/* .

In [6]:
from models.vietnet import VietNet, CrossEntropyLoss2d
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.models import mobilenet_v2
from tqdm import tqdm
from torchvision import transforms
from load_dataset import ImageDataset
from train import train_one_epoch, validate_model
from PIL import Image
import os

In [7]:
def generate_txtdata(path, file_name, label_name, image_dir='images/'):
  """Write a dataset index file pairing every image in ``path`` with its label.

  Each output line has the form ``<image_dir><file>, <label_name><stem>.png``,
  where ``<file>`` is an image file found in ``path`` and ``<stem>`` is that
  file name without its extension.

  Args:
    path: Directory whose regular, non-hidden files are treated as images.
    file_name: Path of the text file to (over)write.
    label_name: Prefix (e.g. ``'val_labels/'``) prepended to ``<stem>.png``.
    image_dir: Prefix prepended to each image file name. Defaults to
      ``'images/'``, the value the original code hard-coded.

  Returns:
    None. The result is written to ``file_name``.
  """
  # Sort so the index file is identical on every run (os.listdir order is
  # arbitrary), and skip sub-directories / hidden entries such as Colab's
  # ``.ipynb_checkpoints`` that are not images.
  image_files = sorted(
      entry for entry in os.listdir(path)
      if not entry.startswith('.') and os.path.isfile(os.path.join(path, entry))
  )

  lines = []
  for image_file in tqdm(image_files):
    # splitext handles any extension length (".jpg", ".jpeg", ".png", ...),
    # unlike slicing off a fixed 4 characters. The real file name is kept for
    # the image path instead of assuming a ".jpg" extension.
    stem = os.path.splitext(image_file)[0]
    lines.append(', '.join([image_dir + image_file, label_name + stem + '.png']))

  # One "<image>, <label>" record per line.
  with open(file_name, 'w') as index_file:
    index_file.writelines(line + '\n' for line in lines)

In [None]:
# Build the "<image>, <label>" index files. ImageDataset below reads them from
# data/train_list.txt and data/val_list.txt, so write them there.
os.makedirs('data', exist_ok=True)
generate_txtdata(path='validation/images', file_name='data/val_list.txt', label_name='val_labels/')
generate_txtdata(path='training/images', file_name='data/train_list.txt', label_name='train_labels/')

In [2]:
# Seed torch before constructing the network so that its weight initialisation
# is reproducible; the main seeding cell of this notebook runs after this one.
torch.manual_seed(50)
net = VietNet(num_classes=4)

In [26]:
# Train on the first GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Let cuDNN benchmark convolution algorithms for the fixed input size.
# NOTE: this favours speed over bit-for-bit reproducibility.
torch.backends.cudnn.benchmark = True

# Seed the NumPy and PyTorch RNGs (plus the CUDA RNG when training on GPU).
np.random.seed(50)
torch.manual_seed(50)
if device.type == "cuda":
    torch.cuda.manual_seed(50)

In [27]:
# Training schedule.
num_epochs = 50          # upper bound on the number of epochs
batch_size = 4           # images per optimisation step

# Early stopping: training halts once validation accuracy has not improved
# for `patience` consecutive epochs.
patience = 10
max_acc = 0              # best validation accuracy seen so far
not_improved_count = 0   # epochs since max_acc last improved

In [None]:
# Input resolution (height, width) shared by the resize transform and both datasets.
img_size = (384, 640)

# Only `transforms` is imported from torchvision (a bare `torchvision.` would
# raise NameError). Nearest-neighbour interpolation never blends pixel values,
# which keeps class ids intact if this transform is also applied to label masks
# (TODO confirm in load_dataset.ImageDataset).
transform = transforms.Compose([
    transforms.Resize(img_size, interpolation=transforms.InterpolationMode.NEAREST),
])

train_dataset = ImageDataset(root_dir='training/',
                             txt_files='data/train_list.txt',
                             img_size=img_size,
                             transform=transform)

val_dataset = ImageDataset(root_dir='validation/',
                           txt_files='data/val_list.txt',
                           img_size=img_size,
                           transform=transform)

# Small VMs such as Colab may expose fewer CPUs than 6; requesting more
# DataLoader workers than CPUs only adds overhead and a warning.
loader_workers = min(6, os.cpu_count() or 1)
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=loader_workers)
# Shuffling brings nothing for evaluation; a fixed order keeps validation repeatable.
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=loader_workers)

In [None]:
criterion = CrossEntropyLoss2d()

# Put the model on the training device before the optimizer is built.
# nn.Module.to() moves parameters in place, so this is harmless if
# train_one_epoch also moves the model.
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999),
                             eps=1e-08, weight_decay=1e-4)

# torch.save does not create missing directories, so make sure the Drive
# checkpoint folder exists before the first save.
checkpoint_dir = '/content/drive/MyDrive/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    train_loss, acc, train_jsc = train_one_epoch(net, criterion, optimizer, train_loader, device)
    val_loss, val_acc, val_jsc = validate_model(net, criterion, val_loader, device)

    print('Epoch: {}'.format(epoch))
    print('Training acc: {:.4f}\tTrain_jsc: {:.4f}\tTraining Loss: {:.4f}'.format(acc, train_jsc, train_loss))
    print('Valid acc: {:.4f}\tValid_jsc: {:.4f}\tValid Loss: {:.4f}'.format(val_acc, val_jsc, val_loss))

    # Save a checkpoint only when validation accuracy improves; otherwise
    # count the epoch towards early stopping.
    if val_acc > max_acc:
        checkpoint_path = os.path.join(
            checkpoint_dir, 'RoadSeg_epoch_{}_acc_{:.4f}.pt'.format(epoch, val_acc))
        torch.save(net.state_dict(), checkpoint_path)
        max_acc = val_acc
        not_improved_count = 0
    else:
        not_improved_count += 1

    if not_improved_count >= patience:
        print('Early stopping: no validation improvement for {} epochs.'.format(patience))
        break





In [3]:
from torchsummary import summary

In [4]:
# torchsummary runs a dummy forward pass on the device named by `device`
# (default "cuda"), so pass the device the network's weights actually live on.
summary(net, (3, 384, 640), device=next(net.parameters()).device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 192, 320]             864
       BatchNorm2d-2         [-1, 32, 192, 320]              64
             ReLU6-3         [-1, 32, 192, 320]               0
            Conv2d-4         [-1, 32, 192, 320]             288
       BatchNorm2d-5         [-1, 32, 192, 320]              64
             ReLU6-6         [-1, 32, 192, 320]               0
            Conv2d-7         [-1, 16, 192, 320]             512
       BatchNorm2d-8         [-1, 16, 192, 320]              32
  InvertedResidual-9         [-1, 16, 192, 320]               0
           Conv2d-10         [-1, 96, 192, 320]           1,536
      BatchNorm2d-11         [-1, 96, 192, 320]             192
            ReLU6-12         [-1, 96, 192, 320]               0
           Conv2d-13          [-1, 96, 96, 160]             864
      BatchNorm2d-14          [-1, 96, 