In [None]:
from train import train_per_epoch, get_data_loader, validate_model
from torchvision import transforms
import torchvision
from PIL import Image
import segmentation_models_pytorch as smp
from utils.loss import CrossEntropyLoss2d
import cv2
import numpy as np
import torch
import utils.augment as T

In [None]:
# Training-time augmentation pipeline, applied in the order listed.
# Label padding introduced by the crop is filled with 255, the same value the
# loss in the training cell is configured to ignore (ignore_index=255).
train_transform = T.Compose(
    [
        T.RandomResize(scale_range=(0.25, 2.0)),
        T.RandomCrop(
            [512, 1024],
            pad_if_needed=True,
            lbl_fill=255,
        ),
        T.RandomHorizontalFlip(),
    ]
)

# Validation pipeline: a single deterministic resize to (512, 1024), no augmentation.
# NOTE(review): nearest-neighbour interpolation is what label maps need; confirm
# how T.Resize treats the image itself.
val_transform = T.Compose(
    [
        T.Resize(
            (512, 1024),
            interpolation=Image.NEAREST,
        ),
    ]
)

In [None]:
# U-Net from segmentation_models_pytorch with a MobileNetV2 encoder initialised
# from ImageNet weights; the head predicts 19 classes.
# NOTE(review): 19 presumably matches the Cityscapes training label set used by
# the data loaders -- confirm against the dataset's label mapping.
net = smp.Unet(
    encoder_name="mobilenet_v2",
    encoder_weights='imagenet',
    classes=19,
)

In [None]:
# Seed NumPy and PyTorch with the same fixed value (50); when a GPU is present,
# seed CUDA as well and train on the first GPU, otherwise fall back to the CPU.
np.random.seed(50)
torch.manual_seed(50)

if torch.cuda.is_available():
    torch.cuda.manual_seed(50)
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Enable cuDNN's benchmark mode.
# NOTE(review): benchmark mode may pick different kernels between runs, which
# can work against the fixed seeds above -- confirm this trade-off is intended.
torch.backends.cudnn.benchmark = True

In [None]:
# --- Training schedule ---
batch_size = 4     # batch size passed to both data loaders
num_epochs = 50    # upper bound on the number of training epochs

# --- Early stopping / checkpoint bookkeeping (updated by the training loop) ---
patience = 10                          # stop after this many epochs without improvement
max_acc, not_improved_count = 0, 0     # best validation score so far; epochs since it improved

In [None]:

# Build the training and validation loaders from the annotation lists under
# data/cityscapes; each split gets its own transform pipeline and mode flag.
train_loader = get_data_loader(
    datapth='data/cityscapes',
    annpath='data/cityscapes/train.txt',
    trans_func=train_transform,
    batch_size=batch_size,
    mode='train',
)
val_loader = get_data_loader(
    datapth='data/cityscapes',
    annpath='data/cityscapes/val.txt',
    trans_func=val_transform,
    batch_size=batch_size,
    mode='val',
)

In [None]:
# Train with early stopping, checkpointing whenever the validation IoU improves.
#
# Relies on names defined in earlier cells: net, device, train_loader,
# val_loader, num_epochs, patience, max_acc, not_improved_count.

# Make sure the model is on the target device before the optimizer is built.
# Module.to() is a no-op when the parameters are already there.
# NOTE(review): train_per_epoch may also do this itself -- its body is not visible here.
net.to(device)

# Pixels labelled 255 (the fill value used by the training crop) do not contribute to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=1e-4,
)

for epoch in range(num_epochs):
    train_per_epoch(net, criterion, optimizer, train_loader, device)
    # validate_model returns a 4-tuple; only the third element (IoU) is used here.
    _, _, val_iou, _ = validate_model(net, criterion, val_loader, device)

    print('Epoch: {}'.format(epoch))
    # NOTE(review): the original also printed `train_iou`, but that name was never
    # assigned (NameError). train_per_epoch's return value is not visible from this
    # notebook; capture it above and restore the print once its shape is confirmed.
    print('Valid_iou: {:.4f}'.format(val_iou))

    # Model selection uses validation IoU, the only validation metric captured.
    # `max_acc` and the "_acc_" filename tag keep their original names for
    # compatibility with the config cell and existing checkpoints.
    if val_iou > max_acc:
        checkpoint_path = (
            '/content/drive/MyDrive/checkpoints/UnetMobilenetv2_epoch_'
            + str(epoch)
            + '_acc_{0:.4f}'.format(val_iou)
            + '.pt'
        )
        torch.save(net.state_dict(), checkpoint_path)
        max_acc = val_iou
        not_improved_count = 0
    else:
        not_improved_count += 1

    # Early stopping: give up after `patience` consecutive epochs without improvement.
    if not_improved_count >= patience:
        break
