In [None]:
import os

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torchvision
from PIL import Image
from torchvision import transforms

from train import train_per_epoch, get_data_loader, validate_model
from utils.loss import CrossEntropyLoss2d

In [None]:
# Resize target applied to every sample, given as (height, width).
INPUT_SIZE = (512, 1024)

# Preprocessing passed to get_data_loader as `trans_func`.
# Nearest-neighbour resampling never blends neighbouring pixel values. That
# matters if this same transform is also applied to label masks (presumably
# the reason for choosing NEAREST — confirm in get_data_loader).
# InterpolationMode is torchvision's documented argument type; PIL integer
# constants such as Image.NEAREST are only accepted for backward compatibility.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(INPUT_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
])

In [None]:
# U-Net segmentation model built on an ImageNet-pretrained MobileNetV2 encoder.
# classes=19 presumably matches the 19 Cityscapes training classes — confirm
# against the label mapping used by get_data_loader.
net = smp.Unet(
    encoder_name="mobilenet_v2",
    encoder_weights='imagenet',
    classes=19,
)

In [None]:
# Reproducibility and device selection.
# cudnn.benchmark lets cuDNN auto-tune convolution algorithms, trading
# bit-exact run-to-run determinism for speed.
SEED = 50
torch.backends.cudnn.benchmark = True

np.random.seed(SEED)
torch.manual_seed(SEED)

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.manual_seed(SEED)

device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

In [None]:
# Training hyper-parameters.
num_epochs = 50   # upper bound on the number of training epochs
batch_size = 4    # passed to get_data_loader as ims_per_gpu
patience = 10     # epochs without validation improvement before stopping early

# Early-stopping bookkeeping, updated by the training loop.
# NOTE: despite its name, max_acc tracks the best validation score seen so far
# (the training loop reports IoU).
max_acc = 0
not_improved_count = 0

In [None]:

# Cityscapes train/val loaders. Both splits share the dataset root, batch size
# and preprocessing transform; only the annotation list and mode differ.
_loader_kwargs = dict(
    datapth='datasets/cityscapes',
    ims_per_gpu=batch_size,
    trans_func=transform,
)
train_loader = get_data_loader(annpath='datasets/cityscapes/train.txt', mode='train', **_loader_kwargs)
val_loader = get_data_loader(annpath='datasets/cityscapes/val.txt', mode='val', **_loader_kwargs)

In [None]:
# Train with Adam and evaluate on the validation split after every epoch.
# Save a checkpoint whenever validation IoU improves, and stop early after
# `patience` consecutive epochs without improvement.
#
# Uses names from earlier cells: net, device, train_loader, val_loader,
# num_epochs, patience, max_acc, not_improved_count.
CHECKPOINT_DIR = '/content/drive/MyDrive/checkpoints'
# Create the directory up front so torch.save cannot fail only after the first
# (expensive) epoch. NOTE: on Colab, mount Google Drive first; otherwise this
# creates a local, ephemeral directory.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Put the model on the same device the train/validate helpers receive. This is
# harmless if those helpers already move it themselves.
net.to(device)

criterion = CrossEntropyLoss2d()
optimizer = torch.optim.Adam(
    net.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4
)

for epoch in range(num_epochs):
    train_iou = train_per_epoch(net, criterion, optimizer, train_loader, device)
    val_iou = validate_model(net, criterion, val_loader, device)

    print('Epoch: {}'.format(epoch))
    print('Train_iou: {:.4f}'.format(train_iou))
    print('Valid_iou: {:.4f}'.format(val_iou))

    # max_acc holds the best validation IoU seen so far; compare against the
    # metric actually returned by validate_model.
    if val_iou > max_acc:
        checkpoint_name = 'UnetMobilenetv2_epoch_{}_acc_{:.4f}.pt'.format(epoch, val_iou)
        torch.save(net.state_dict(), os.path.join(CHECKPOINT_DIR, checkpoint_name))
        max_acc = val_iou
        not_improved_count = 0
    else:
        not_improved_count += 1

    if not_improved_count >= patience:
        print('Early stopping: no improvement for {} epochs'.format(patience))
        break


In [None]:
# def get_filepaths(directory):
#     """
#     Return the full paths of all files in the directory tree rooted at
#     `directory` (including files directly inside it), found by walking
#     the tree top-down with os.walk. Backslashes in the joined paths are
#     replaced with forward slashes.
#     """
#     file_paths = []  # List which will store all of the full filepaths.

#     # Walk the tree.
#     for root, directories, files in os.walk(directory):
#         for filename in files:
#             # Join the two strings in order to form the full filepath.
#             filepath = os.path.join(root, filename).replace('\\','/')
#             file_paths.append(filepath)  # Add it to the list.

#     return file_paths  # Self-explanatory.

# # Run the above function and store its results in a variable.   
# full_file_paths = get_filepaths("datasets/cityscapes/gtFine")

In [None]:
# full_file_paths0 = get_filepaths("datasets/cityscapes/leftImg8bit")
# full_file_paths1 = get_filepaths("datasets/cityscapes/gtFine")