In [None]:
!pip3 install segmentation_models_pytorch

In [1]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torch.nn as nn
import albumentations as albu
import torch
import segmentation_models_pytorch as smp

In [2]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

In [3]:
DATA_DIR = './'
x_train_dir = '/kaggle/input/all-dataset-files/all_dataset_files/all_dataset_files/all_dataset_imgs'
y_train_dir = '/kaggle/input/all-dataset-files/all_dataset_files/all_dataset_files/all_dataset_masks'

x_valid_dir = '/kaggle/input/all-dataset-files/sample_valid_imgs'
y_valid_dir = '/kaggle/input/all-dataset-files/sample_valid_masks'

In [4]:
len(os.listdir(x_train_dir)), len(os.listdir(y_train_dir)), len(os.listdir(x_valid_dir)), len(os.listdir(y_valid_dir))

(1633, 1633, 3, 3)

In [5]:
# helper function for data visualization
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

In [6]:
class HubMapDataset(BaseDataset):
    """Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    
    """
    
    CLASSES = ['unlabelled', 'blood_vessel']
    
    def __init__(
            self, 
            images_dir, 
            masks_dir, 
            classes=None, 
            augmentation=None, 
            preprocessing=None,
    ):
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
        
        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
        
        self.augmentation = augmentation
        self.preprocessing = preprocessing
    
    def __getitem__(self, i):
        
        # read data
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)
        
        # extract certain classes from mask (e.g. cars)
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        
        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
            
        return self.masks_fps[i], image, mask
        
    def __len__(self):
        return len(self.ids)

In [7]:
def get_training_augmentation():
  train_transform = [
    albu.ShiftScaleRotate(shift_limit=0, scale_limit=0, rotate_limit=90),
    albu.ShiftScaleRotate(shift_limit=0.2, scale_limit=0, rotate_limit=0),
    albu.ShiftScaleRotate(shift_limit=0, scale_limit=0.2, rotate_limit=0),
    albu.Flip(),
    albu.RandomBrightnessContrast(),
    albu.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0), p=1)
  ]
  return albu.Compose(train_transform)

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """
    
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

In [8]:
CLASSES = ['unlabelled', 'blood_vessel']
ENCODER = 'efficientnet-b7'
ENCODER_WEIGHTS = 'imagenet'

In [9]:
model = smp.Unet(
    encoder_name=ENCODER,        
    encoder_weights=ENCODER_WEIGHTS,     
    in_channels=3,                  
    classes=len(CLASSES)
)
model = model.to(DEVICE)
model = nn.DataParallel(model)

In [10]:
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [11]:
train_dataset = HubMapDataset(
    x_train_dir, 
    y_train_dir, 
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

valid_dataset = HubMapDataset(
    x_valid_dir, 
    y_valid_dir, 
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False, num_workers=2)

In [12]:
_, image, mask = train_dataset[0]
print(image.shape, mask.shape)

(3, 512, 512) (2, 512, 512)


In [13]:
from torchmetrics import Metric
class IoUScore(Metric):
    def __init__(self, threshold=0.5, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.threshold = threshold
        self.add_state("intersection_back", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("union_back", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("intersection_fore", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("union_fore", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("num_images", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        preds = (preds > self.threshold).int()
        intersection_back = torch.logical_and(preds[:,0,:,:], target[:,0,:,:]).sum()
        union_back = torch.logical_or(preds[:,0,:,:], target[:,0,:,:]).sum()
        intersection_fore = torch.logical_and(preds[:,1,:,:], target[:,1,:,:]).sum()
        union_fore = torch.logical_or(preds[:,1,:,:], target[:,1,:,:]).sum()
        num_images = preds.shape[0]

        self.intersection_back += intersection_back
        self.union_back += union_back
        self.intersection_fore += intersection_fore
        self.union_fore += union_fore
        self.num_images += num_images

    def compute(self):
        print(f'num images is: {self.num_images}')
        iou_back = (self.intersection_back.float() / self.union_back.float())
        iou_fore = (self.intersection_fore.float() / self.union_fore.float())
        self.intersection_back = 0
        self.union_back = 0
        self.intersection_fore = 0
        self.union_fore = 0
        self.num_images = 0
        return iou_back,iou_fore

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


In [14]:
import torchmetrics
loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
metrics = [
    IoUScore(threshold=0.5).to(DEVICE),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

In [15]:
from tqdm import tqdm
# Training loop
def train_epoch(model, loss_fn, metrics, optimizer, device, dataloader):
    model.train()
    num_batches = len(dataloader)
    total_loss = 0
    print(f'Processing a total of {num_batches} batches in training')
    # Iterate over the training dataset
    for batch_idx, (f, inputs, targets) in tqdm(enumerate(dataloader)):        
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
#         print(f'input and output shapes: {inputs.shape}, {outputs.shape}, {targets.shape}')
#         print(f'Outputs min: {torch.min(outputs)}, Outputs max: {torch.max(outputs)}')
        loss = loss_fn(outputs, targets)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Compute metrics
        for metric in metrics:
            metric.update(torch.softmax(outputs, dim=1), targets)
        total_loss += loss

    # Get the metric values
    metric_values = [float(total_loss)/num_batches] + [metric.compute() for metric in metrics]
    return metric_values

# Validation loop
def valid_epoch(model, loss_fn, metrics, device, dataloader):
    model.eval()
    num_batches = len(dataloader)
    total_loss = 0
    print(f'Processing a total of {num_batches} batches in validation')
    # Disable gradient calculation
    with torch.no_grad():
        # Iterate over the validation dataset
        for batch_idx, (f, inputs, targets) in tqdm(enumerate(dataloader)):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Compute metrics
            for metric in metrics:
                metric.update(torch.softmax(outputs, dim=1), targets)
            total_loss += loss

    # Get the metric values
    metric_values = [float(total_loss)/num_batches] + [metric.compute() for metric in metrics]
    return metric_values

In [None]:
import time
max_iou = 0
num_epochs = 50
if os.path.exists('/kaggle/working/model_stats.txt'):
  os.remove('/kaggle/working/model_stats.txt')
fp = open('/kaggle/working/model_stats.txt', 'a')
for epoch in range(num_epochs):
    # Training
    start_time = time.time()
    train_metrics = train_epoch(model, loss, metrics, optimizer, DEVICE, train_loader)
    print(f'=========Finished Training Epoch {epoch} in {float(time.time()-start_time)/60}==========')
    # Validation
    start_time = time.time()
    valid_metrics = valid_epoch(model, loss, metrics, DEVICE, valid_loader)
    print(f'=========Finished Validation Epoch {epoch} {float(time.time()-start_time)/60}in =========')
    
    save_interval = 10
    if (epoch+1) % 10 == 0:
        torch.save(model, f'/kaggle/working/model_{epoch}.pth')
    
    cur_validation_iou = 0.5*valid_metrics[1][0] + 0.5*valid_metrics[1][1]
    if cur_validation_iou > max_iou:
      print(f'Saving model with IoU: {cur_validation_iou}...')
      torch.save(model, '/kaggle/working/best_model.pth')
      with open('/kaggle/working/best_model.txt', 'w') as f:
        f.write(f"Epoch {epoch}: Train Loss={train_metrics[0]}, Validation Loss={valid_metrics[0]}, Train IoU Back={train_metrics[1][0]}, Train IoU Fore={train_metrics[1][1]}, Validation IoU Back={valid_metrics[1][0]}, Validation IoU Fore={valid_metrics[1][1]}")
      max_iou = cur_validation_iou
    # Print or log the metrics for each epoch
    print(f"Epoch {epoch}: Train Loss={train_metrics[0]}, Validation Loss={valid_metrics[0]}, Train IoU Back={train_metrics[1][0]}, Train IoU Fore={train_metrics[1][1]}, Validation IoU Back={valid_metrics[1][0]}, Validation IoU Fore={valid_metrics[1][1]}")
    fp.write(f"Epoch {epoch}: Train Loss={train_metrics[0]}, Validation Loss={valid_metrics[0]}, Train IoU Back={train_metrics[1][0]}, Train IoU Fore={train_metrics[1][1]}, Validation IoU Back={valid_metrics[1][0]}, Validation IoU Fore={valid_metrics[1][1]}\n")
    fp.flush()
fp.close()

Processing a total of 205 batches in training


205it [04:10,  1.22s/it]


num images is: 1633
Processing a total of 1 batches in validation


1it [00:00,  1.54it/s]


num images is: 3
Saving model with IoU: 0.49571308493614197...
Epoch 0: Train Loss=0.4098246504620808, Validation Loss=0.2467450499534607, Train IoU Back=0.7560134530067444, Train IoU Fore=0.04336860030889511, Validation IoU Back=0.9910023808479309, Validation IoU Fore=0.00042378867510706186
Processing a total of 205 batches in training


205it [04:04,  1.19s/it]


num images is: 1633
Processing a total of 1 batches in validation


1it [00:00,  1.29it/s]

num images is: 3
Epoch 1: Train Loss=0.2143523239507908, Validation Loss=0.12928783893585205, Train IoU Back=0.9538811445236206, Train IoU Fore=0.0122360959649086, Validation IoU Back=0.9908599853515625, Validation IoU Fore=0.00027816410874947906
Processing a total of 205 batches in training



32it [00:38,  1.18s/it]

In [None]:
# Inference
DEVICE = 'cpu'
best_model = torch.load('./models/best_model.pth', map_location=torch.device('cpu'))
best_model = best_model.to(DEVICE)

In [None]:
CLASSES = ['blood_vessel']
test_dataset = HubMapDataset(
    x_test_dir, 
    y_test_dir, 
    augmentation=get_validation_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

test_loader = DataLoader(test_dataset)

In [None]:
valid_metrics = valid_epoch(best_model, loss, metrics, DEVICE, test_loader)
print(f'Test Loss: {valid_metrics[0]}, Test IoU: {valid_metrics[1]}')

In [None]:
test_dataset_without_aug = HubMapDataset(
    x_test_dir, y_test_dir, 
    classes=CLASSES,
)
train_dataset_without_aug = HubMapDataset(
    x_train_dir, y_train_dir, 
    classes=CLASSES,
)

In [None]:
target_dataset = test_dataset
target_dataset_without_aug = test_dataset_without_aug
for i in range(20):
    n = np.random.choice(len(target_dataset))
    
    image_vis = target_dataset_without_aug[n][0].astype('uint8')
    image, gt_mask = target_dataset[n]
    image_trans = image.transpose(1,2,0)
    print(f'image tran shape: {image_trans.shape}')
    print(image.shape, gt_mask.shape)
    
    gt_mask = gt_mask.squeeze()
    
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    print(x_tensor.shape)
    pr_mask = torch.sigmoid(best_model.predict(x_tensor))
    print(pr_mask.shape)
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())
    print(pr_mask.shape)
    
    visualize(
        image=image_trans, 
        ground_truth_mask=gt_mask, 
        predicted_mask=pr_mask
    )

In [None]:
visualize(
        image=image_trans, 
        ground_truth_mask=gt_mask, 
        predicted_mask=pr_mask
    )

In [None]:
print(pr_mask.shape)
print(pr_mask)
plt.imshow(pr_mask)
# im2, contours, hierarchy = cv2.findContours(pr_mask.astype('uint8'),cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)