In [None]:
import os
from glob import glob

import torch
from torch.utils.data import DataLoader

from src.data.segmentation_dataset import SegmentationData
from src.networks.segmentation_nn import SegmentationNN
from src.util import visualizer
from src.util import accuracy

In [None]:
# Expected layout (from the patterns below):
#   <DATA_DIR>/images/<split>/<subdir>/*.png
#   <DATA_DIR>/gtFine/<split>/<subdir>/*_gtFine_labelIds.png
DATA_DIR = os.path.join('data', 'Cityspaces')


def sorted_glob(*parts):
    """Return the paths matching ``os.path.join(*parts)``, sorted lexicographically.

    ``glob`` yields matches in filesystem-dependent order, so two independent
    globs (images vs. labels) are not guaranteed to line up. Sorting both lists
    keeps them paired index-by-index, assuming each label file shares its
    image's filename prefix (TODO confirm for the dataset copy in use).
    Building the pattern with ``os.path.join`` also makes it work on systems
    where backslash is not a path separator.
    """
    return sorted(glob(os.path.join(*parts)))


train_img = sorted_glob(DATA_DIR, 'images', 'train', '*', '*.png')
train_label = sorted_glob(DATA_DIR, 'gtFine', 'train', '*', '*_gtFine_labelIds.png')
val_img = sorted_glob(DATA_DIR, 'images', 'val', '*', '*.png')
val_label = sorted_glob(DATA_DIR, 'gtFine', 'val', '*', '*_gtFine_labelIds.png')

# Images and masks are paired by position downstream, so the counts must agree.
assert len(train_img) == len(train_label), 'train images/labels count mismatch'
assert len(val_img) == len(val_label), 'val images/labels count mismatch'

In [None]:
# Maps every raw labelId (0-34) found in the *_gtFine_labelIds.png masks to one
# of the 9 training classes. Class names follow the `classes` table used for
# the IoU report.
_LABEL_GROUPS = [
    (range(0, 7), 0),     # void
    ((7, 9, 10), 1),      # road
    ((8,), 8),            # sidewalk
    (range(11, 17), 2),   # building
    (range(17, 21), 3),   # traffic sign
    ((21, 22), 4),        # vegetation
    ((23,), 5),           # sky
    ((24, 25), 6),        # person
    (range(26, 35), 7),   # vehicle
]

# Sort by labelId so the dict's key order is 0, 1, ..., 34.
color_pixel = dict(sorted(
    (label_id, train_class)
    for label_ids, train_class in _LABEL_GROUPS
    for label_id in label_ids
))

# RGB color used to render each training class id (0-8) in visualizations.
color_map = dict(enumerate([
    (255, 255, 255),  # 0 void
    (128, 64, 128),   # 1 road
    (70, 70, 70),     # 2 building
    (153, 153, 153),  # 3 traffic sign
    (107, 142, 35),   # 4 vegetation
    (70, 130, 180),   # 5 sky
    (220, 20, 60),    # 6 person
    (0, 0, 142),      # 7 vehicle
    (244, 35, 232),   # 8 sidewalk
]))

# Configuration shared by the datasets, data loaders and the model.
hparams = {
    'batch_size': 3,
    'num_workers': 8,
    'lr': 0.0001,
    # Fall back to CPU so the notebook still runs on machines without CUDA.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'crop_size': 500,
    'num_classes': 9,
    # labelId -> training-class mapping applied to the masks.
    'color_coder': color_pixel,
    'crop_num': 4,
}

# num_classes must cover exactly the class ids produced by the label mapping.
assert hparams['num_classes'] == len(set(color_pixel.values())), \
    'num_classes does not match the classes in color_pixel'

traindata = SegmentationData(hparams, train_img, train_label, False)
valdata = SegmentationData(hparams, val_img, val_label, False)
# NOTE(review): the "test" set is built from the validation split; the final
# flag presumably switches the dataset to test mode — confirm in SegmentationData.
testdata = SegmentationData(hparams, val_img, val_label, True)

# Shuffle training batches every epoch; a fixed file order biases SGD updates.
train_dataloader = DataLoader(
    traindata,
    batch_size=hparams['batch_size'],
    shuffle=True,
    num_workers=hparams['num_workers'],
)
val_dataloader = DataLoader(
    valdata,
    batch_size=hparams['batch_size'],
    shuffle=False,
    num_workers=hparams['num_workers'],
    drop_last=True,
)

# Resume from a previously saved model if one exists; otherwise start fresh,
# so the notebook also runs on a clean checkout.
model_path = os.path.join('models', 'model_segmentare')
if os.path.exists(model_path):
    # torch.load unpickles a full nn.Module: only load checkpoints you trust.
    # map_location lets a GPU-saved model be loaded on the configured device.
    model = torch.load(model_path, map_location=hparams['device'])
else:
    model = SegmentationNN(hparams=hparams).to(hparams['device'])

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Stop when the logged "val_loss" (presumably logged by SegmentationNN) has not
# improved by at least 0.001 for 3 consecutive validation checks.
early_stop_callback = EarlyStopping(monitor="val_loss", patience=3, verbose=True, min_delta=0.001)
trainer = pl.Trainer(
    max_epochs=100,
    min_epochs=5,
    accelerator="auto",
    callbacks=[early_stop_callback],
)

trainer.fit(model, train_dataloader, val_dataloader)

# torch.save does not create missing directories.
# NOTE(review): this stores the weights from the last epoch run, which with
# early stopping can be up to `patience` epochs past the best val_loss;
# consider a ModelCheckpoint callback if the best weights are wanted.
os.makedirs('models', exist_ok=True)
torch.save(model, os.path.join('models', 'model_segmentare'))

In [None]:
# Launch TensorBoard inline on port 6006, reading the logs under ./lightning_logs
# (presumably written by the Lightning Trainer's default logger), e.g. to inspect losses.
%load_ext tensorboard
%tensorboard --logdir lightning_logs --port 6006

In [None]:
# Reload the saved model so this cell works on its own (e.g. after a kernel restart).
# map_location lets a GPU-saved model be loaded on the configured device.
# torch.load unpickles a full nn.Module: only load checkpoints you trust.
model = torch.load(os.path.join('models', 'model_segmentare'), map_location=hparams['device'])
# Put dropout / batch-norm layers in inference mode before predicting;
# visualizer's own handling of this is not visible here.
model.eval()
# The argument 10 is presumably the number of samples to display — confirm in src.util.
visualizer(model, 10, testdata, color_map)

In [None]:
# Human-readable name for each training class id used by the label mapping.
classes = {
    0: 'void',
    1: 'road',
    2: 'building',
    3: 'traffic sign',
    4: 'vegetation',
    5: 'sky',
    6: 'person',
    7: 'vehicle',
    8: 'sidewalk',
}

# accuracy() returns a mean IoU and a per-class IoU sequence indexed by class id
# (presumably already expressed in percent — confirm in src.util).
iou_mean, iou = accuracy(model, testdata, hparams['num_classes'])
print('Iou mean : ', round(iou_mean, 2), '%')
for class_id in range(hparams['num_classes']):
    print(classes[class_id], ':', round(iou[class_id], 2), '%')