In [1]:
import sys

# Make the parent of the current working directory importable (presumably the project
# root, so the `src.*` and `unet` imports below resolve when Jupyter is launched from
# the notebook's folder). The membership guard keeps the cell idempotent: re-running it
# no longer appends duplicate entries to sys.path.
PROJECT_ROOT = '..'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import os
import random

import albumentations as A
import numpy as np
import torch
import torch.nn as nn

from src.callbacks import TrainMetricRecorder
from src.dataset import SegDataset
from src.engine import fit, train_one_epoch, evaluate
from src.losses import BinaryFocalLoss, FocalLoss, DiceLoss
from unet import UNet

In [3]:
# Locate the train/test splits and collect the sample ids (mask filename stems) for each.
train_img_dir = '../data/road_segmentation_ideal/training/input'
train_mask_dir = '../data/road_segmentation_ideal/training/output'
test_img_dir = '../data/road_segmentation_ideal/testing/input'
test_mask_dir = '../data/road_segmentation_ideal/testing/output'


def list_sample_ids(mask_dir, ext='.png'):
    """Return the sorted filename stems of every ``ext`` file in ``mask_dir``.

    Only the final extension is stripped (``os.path.splitext``), so a file such as
    ``tile.01.png`` yields ``tile.01``. Splitting on the first dot would have
    truncated it to ``tile``, which can collide with other ids or point at a
    non-existent file.

    Args:
        mask_dir: directory to scan (non-recursive).
        ext: filename suffix to keep (default ``'.png'``).

    Returns:
        Sorted list of stems.
    """
    return sorted(
        os.path.splitext(filename)[0]
        for filename in os.listdir(mask_dir)
        if filename.endswith(ext)
    )


train_ids = list_sample_ids(train_mask_dir)
test_ids = list_sample_ids(test_mask_dir)
# Fail loudly here rather than with an obscure DataLoader error further down.
assert train_ids, f'No .png masks found in {train_mask_dir}'
assert test_ids, f'No .png masks found in {test_mask_dir}'

print(f'Train samples: {len(train_ids)} Test samples: {len(test_ids)}')

Train samples: 804 Test samples: 13


In [4]:
# Side length of the random square crop taken from each training sample.
# albumentations' RandomCrop requires inputs at least this large on both sides.
img_size = 500
# Side length test samples are resized to before evaluation.
test_size = 1500

# Training: random crop plus the 8 dihedral symmetries (flips / 90° rotations / transpose).
# Fix: the RandomCrop call was missing its closing parenthesis, which made this cell a
# SyntaxError (the following transforms were parsed as positional args to RandomCrop).
train_transforms = A.Compose([
    A.RandomCrop(height=img_size, width=img_size, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Transpose(p=0.5),
])

# Deterministic resize for evaluation (presumably applied to image and mask together
# inside SegDataset — confirm).
test_transforms = A.Resize(height=test_size, width=test_size, p=1)

train_dataset = SegDataset(train_img_dir, train_mask_dir, train_ids, train_transforms)
test_dataset = SegDataset(test_img_dir, test_mask_dir, test_ids, test_transforms)

# Sanity check: every listed id should map to exactly one dataset item.
assert len(train_dataset) == len(train_ids)
assert len(test_dataset) == len(test_ids)

In [5]:
# Seed the RNGs the input pipeline draws from, so Restart & Run All reproduces the same
# shuffling order and augmentations. torch drives DataLoader shuffling; albumentations
# draws from Python `random` and/or NumPy depending on its version.
# NOTE: cuDNN kernels may still introduce run-to-run nondeterminism during training.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 8
# Pinned host memory only speeds up host->GPU copies; skip it when no GPU is available.
PIN_MEMORY = torch.cuda.is_available()

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)
# Test images are evaluated one at a time, in a fixed order.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=PIN_MEMORY)

In [6]:
# Draw one batch from a fresh loader iterator and display its tensor shapes,
# as a quick check that images and masks line up before training.
batch_iterator = iter(train_loader)
imgs, masks = next(batch_iterator)
(imgs.shape, masks.shape)

(torch.Size([8, 3, 500, 500]), torch.Size([8, 1, 500, 500]))

In [7]:
# UNet with 3 input channels and a single output channel, initialised from a
# pretrained Carvana checkpoint.
PRETRAINED_WEIGHTS = '../model_weights/pretrained/unet_carvana_scale1_epoch5.pth'
model = UNet(n_channels=3, n_classes=1)
# map_location='cpu' lets a checkpoint saved from GPU tensors load on CPU-only machines;
# the model is moved to the selected device in a later cell.
# NOTE(review): torch.load unpickles its input — only load trusted checkpoints
# (weights_only=True is available on torch >= 1.13).
model.load_state_dict(torch.load(PRETRAINED_WEIGHTS, map_location='cpu'))

<All keys matched successfully>

In [8]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device {device}')

Using device cuda


In [9]:
# Place the model on the chosen device first, so the optimizer is built over the
# parameters that will actually be trained (nn.Module.to returns the module itself).
model = model.to(device)

# Adam over trainable parameters only; frozen tensors (requires_grad=False) are skipped.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.Adam(params, lr=1e-3)

# Loss to minimise, and the per-epoch metrics the recorder tracks.
criterion = DiceLoss()
recorder = TrainMetricRecorder(['iou', 'accuracy', 'precision', 'recall'])

In [10]:
# Train for a fixed number of epochs, logging loss and recorder metrics after each one.
# NOTE(review): test_loader is passed where a validation loader presumably goes (the
# logs report val_*), so the 13 test images also drive any model selection — consider
# carving out a separate validation split.
NUM_EPOCHS = 10
fit(model, train_loader, test_loader, optimizer, criterion, NUM_EPOCHS, device, recorder)

100%|██████████| 101/101 [02:02<00:00,  1.22s/it]
100%|██████████| 13/13 [00:03<00:00,  3.57it/s]
  0%|          | 0/101 [00:00<?, ?it/s]


Epoch 1: train_loss: 0.5937 val_loss: 0.5169 train_iou: 0.2637 val_iou: 0.3258 train_accuracy: 0.9173 val_accuracy: 0.9308 train_precision: 0.3571 val_precision: 0.5186 train_recall: 0.5275 val_recall: 0.4938 



100%|██████████| 101/101 [02:02<00:00,  1.21s/it]
100%|██████████| 13/13 [00:03<00:00,  3.61it/s]
  0%|          | 0/101 [00:00<?, ?it/s]


Epoch 2: train_loss: 0.5056 val_loss: 0.5620 train_iou: 0.3332 val_iou: 0.2905 train_accuracy: 0.9438 val_accuracy: 0.9393 train_precision: 0.4672 val_precision: 0.6339 train_recall: 0.5494 val_recall: 0.3699 



100%|██████████| 101/101 [02:02<00:00,  1.21s/it]
100%|██████████| 13/13 [00:03<00:00,  3.65it/s]
  0%|          | 0/101 [00:00<?, ?it/s]


Epoch 3: train_loss: 0.4513 val_loss: 0.4648 train_iou: 0.3826 val_iou: 0.3708 train_accuracy: 0.9516 val_accuracy: 0.9137 train_precision: 0.5325 val_precision: 0.4242 train_recall: 0.5821 val_recall: 0.7498 



100%|██████████| 101/101 [02:01<00:00,  1.21s/it]
100%|██████████| 13/13 [00:03<00:00,  3.61it/s]
  0%|          | 0/101 [00:00<?, ?it/s]


Epoch 4: train_loss: 0.4378 val_loss: 0.3839 train_iou: 0.3955 val_iou: 0.4509 train_accuracy: 0.9537 val_accuracy: 0.9481 train_precision: 0.5521 val_precision: 0.6407 train_recall: 0.5914 val_recall: 0.6076 



100%|██████████| 101/101 [02:02<00:00,  1.21s/it]
100%|██████████| 13/13 [00:03<00:00,  3.60it/s]
  0%|          | 0/101 [00:00<?, ?it/s]


Epoch 5: train_loss: 0.3999 val_loss: 0.3713 train_iou: 0.4328 val_iou: 0.4637 train_accuracy: 0.9590 val_accuracy: 0.9492 train_precision: 0.6016 val_precision: 0.6472 train_recall: 0.6169 val_recall: 0.6291 



100%|██████████| 101/101 [02:02<00:00,  1.21s/it]
100%|██████████| 13/13 [00:03<00:00,  3.63it/s]
  0%|          | 0/101 [00:00<?, ?it/s]


Epoch 6: train_loss: 0.3762 val_loss: 0.3627 train_iou: 0.4573 val_iou: 0.4720 train_accuracy: 0.9606 val_accuracy: 0.9504 train_precision: 0.6275 val_precision: 0.6545 train_recall: 0.6367 val_recall: 0.6280 



100%|██████████| 101/101 [02:02<00:00,  1.21s/it]
100%|██████████| 13/13 [00:03<00:00,  3.60it/s]
  0%|          | 0/101 [00:00<?, ?it/s]


Epoch 7: train_loss: 0.3449 val_loss: 0.3369 train_iou: 0.4910 val_iou: 0.4996 train_accuracy: 0.9638 val_accuracy: 0.9534 train_precision: 0.6573 val_precision: 0.6689 train_recall: 0.6671 val_recall: 0.6638 



100%|██████████| 101/101 [02:02<00:00,  1.21s/it]
100%|██████████| 13/13 [00:03<00:00,  3.56it/s]
  0%|          | 0/101 [00:00<?, ?it/s]


Epoch 8: train_loss: 0.3273 val_loss: 0.4369 train_iou: 0.5094 val_iou: 0.3983 train_accuracy: 0.9662 val_accuracy: 0.9511 train_precision: 0.6829 val_precision: 0.7424 train_recall: 0.6713 val_recall: 0.4691 



100%|██████████| 101/101 [02:02<00:00,  1.21s/it]
100%|██████████| 13/13 [00:03<00:00,  3.61it/s]
  0%|          | 0/101 [00:00<?, ?it/s]


Epoch 9: train_loss: 0.3267 val_loss: 0.3990 train_iou: 0.5106 val_iou: 0.4381 train_accuracy: 0.9666 val_accuracy: 0.9557 train_precision: 0.6817 val_precision: 0.8017 train_recall: 0.6767 val_recall: 0.4989 



100%|██████████| 101/101 [02:01<00:00,  1.21s/it]
100%|██████████| 13/13 [00:03<00:00,  3.63it/s]


Epoch 10: train_loss: 0.3166 val_loss: 0.3558 train_iou: 0.5223 val_iou: 0.4792 train_accuracy: 0.9682 val_accuracy: 0.9550 train_precision: 0.6946 val_precision: 0.7284 train_recall: 0.6824 val_recall: 0.5873 




