In [1]:
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch
!pip install pytorch-lightning
!pip install albumentations
!pip install torchvision       

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/qubvel/segmentation_models.pytorch
  Cloning https://github.com/qubvel/segmentation_models.pytorch to /tmp/pip-req-build-xv479c4z
  Running command git clone -q https://github.com/qubvel/segmentation_models.pytorch /tmp/pip-req-build-xv479c4z
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[K     |████████████████████████████████| 58 kB 4.0 MB/s 
[?25hCollecting efficientnet-pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
Collecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[K     |████████████████████████████████| 376 kB 15.8 MB/s 
Collecting munch
  Downloading munch-2.5.0-py2.py3-none-any.whl (10 kB)
Buildi

In [2]:
import gdown

# https://drive.google.com/file/d/1Vog0SCV90K3Z-3IRw-aVUqTKGSNaHEhx/view?usp=sharing

url='https://drive.google.com/uc?id=1Vog0SCV90K3Z-3IRw-aVUqTKGSNaHEhx'
output_file_train='resnet_unet.ckpt'

gdown.download(url, output_file_train, quiet=False)

# https://drive.google.com/file/d/1_1g7zTV0IPce8AN3SxcMaNL2j40otoLN/view?usp=sharing
url='https://drive.google.com/uc?id=1_1g7zTV0IPce8AN3SxcMaNL2j40otoLN'
output_file_train='mobilenetv2_nontune.ckpt'

gdown.download(url, output_file_train, quiet=False)

#https://drive.google.com/file/d/1LeicojNwoQM6-eLGdu1pX4EkKXB2D7RX/view?usp=sharing
url='https://drive.google.com/uc?id=1LeicojNwoQM6-eLGdu1pX4EkKXB2D7RX'
output_file_train='mix_transfomer_unet.ckpt'

gdown.download(url, output_file_train, quiet=False)

#https://drive.google.com/file/d/1irlDBPZWBqLTjrh-IgmQ0pgrwPAMRF76/view?usp=sharing
url='https://drive.google.com/uc?id=1irlDBPZWBqLTjrh-IgmQ0pgrwPAMRF76'
output_file_train='mix_transfomer_unet_b1.ckpt'

gdown.download(url, output_file_train, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1Vog0SCV90K3Z-3IRw-aVUqTKGSNaHEhx
To: /content/resnet_unet.ckpt
100%|██████████| 391M/391M [00:06<00:00, 63.2MB/s]
Downloading...
From: https://drive.google.com/uc?id=1_1g7zTV0IPce8AN3SxcMaNL2j40otoLN
To: /content/mobilenetv2_nontune.ckpt
100%|██████████| 79.9M/79.9M [00:02<00:00, 37.7MB/s]
Downloading...
From: https://drive.google.com/uc?id=1LeicojNwoQM6-eLGdu1pX4EkKXB2D7RX
To: /content/mix_transfomer_unet.ckpt
100%|██████████| 66.8M/66.8M [00:01<00:00, 54.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1irlDBPZWBqLTjrh-IgmQ0pgrwPAMRF76
To: /content/mix_transfomer_unet_b1.ckpt
100%|██████████| 197M/197M [00:02<00:00, 84.0MB/s]


'mix_transfomer_unet_b1.ckpt'

In [3]:
import pytorch_lightning as pl
import segmentation_models_pytorch as smp

class Segmentation_custom(pl.LightningModule):

    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        # for image segmentation dice loss could be the best first choice
        # self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE)

    def forward(self, image):
        # normalize image here
        mask = self.model(image)
        return mask

    def shared_step(self, batch, stage):
        
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32, 
        # encoder and decoder connected by `skip connections` and usually encoder have 5 stages of 
        # downsampling by factor 2 (2 ^ 5 = 32); e.g. if we have image with shape 65x65 we will have 
        # following shapes of features in encoder and decoder: 84, 42, 21, 10, 5 -> 5, 10, 20, 40, 80
        # and we will get an error trying to concat these features
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4
        

        # Check that mask values in between 0 and 1, NOT 0 and 255 for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)
        
        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Lets compute metrics for some threshold
        # first convert mask values to probabilities, then 
        # apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute IoU metric by two ways
        #   1. dataset-wise
        #   2. image-wise
        # but for now we just compute true positive, false positive, false negative and
        # true negative 'pixels' for each image and class
        # these values will be aggregated in the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        # aggregate step metics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per image IoU means that we first calculate IoU score for each image 
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        
        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
        # in this particular case will not be much, however for dataset 
        # with "empty" images (images without target class) a large gap could be observed. 
        # Empty images influence a lot on per_image_iou and much less on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }
        
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")            

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")  

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

In [4]:
model_mitb0 = Segmentation_custom.load_from_checkpoint('/content/mix_transfomer_unet.ckpt',arch="Unet", encoder_name="mit_b0", in_channels=3, out_classes=1)
model_mitb1 = Segmentation_custom.load_from_checkpoint('/content/mix_transfomer_unet_b1.ckpt',arch="Unet", encoder_name="mit_b1", in_channels=3, out_classes=1)
model_resnet50 = Segmentation_custom.load_from_checkpoint('/content/resnet_unet.ckpt',arch="Unet", encoder_name="resnet50", in_channels=3, out_classes=1)
model_mobilenetv2 = Segmentation_custom.load_from_checkpoint('/content/mobilenetv2_nontune.ckpt',arch="Unet", encoder_name="mobilenet_v2", in_channels=3, out_classes=1)



Downloading: "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b0.pth" to /root/.cache/torch/hub/checkpoints/mit_b0.pth


  0%|          | 0.00/13.7M [00:00<?, ?B/s]

Downloading: "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b1.pth" to /root/.cache/torch/hub/checkpoints/mit_b1.pth


  0%|          | 0.00/52.2M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


  0%|          | 0.00/13.6M [00:00<?, ?B/s]

In [16]:
import numpy as np
import torch
def time_measure(model):
      # INIT LOGGERS
    dummy_input = torch.randn(1, 3,224,224, dtype=torch.float)

    dummy_input=dummy_input.type(torch.cuda.FloatTensor)
    dummy_input=dummy_input.to('cuda')

    model.cuda()
    model.eval()

    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    repetitions = 100
    timings=np.zeros((repetitions,1))
    #GPU-WARM-UP
    for _ in range(10):
        _ = model(dummy_input)
    # MEASURE PERFORMANCE
    with torch.no_grad():
        for rep in range(repetitions):
            starter.record()
            _ = model(dummy_input)
            ender.record()
            # WAIT FOR GPU SYNC
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            timings[rep] = curr_time
            
    mean_syn = np.sum(timings) / repetitions
    std_syn = np.std(timings)

    return mean_syn

In [19]:
iterasi=5 #untuk memperoleh average time measure nya, karena tiap iterasi tidak stabil

time_mobilenet=0
time_resnet50=0
time_mitb0=0
time_mitb1=0

for _ in range(0,iterasi):
  time_mobilenet+=time_measure(model_mobilenetv2)
  time_resnet50+=time_measure(model_resnet50)
  time_mitb0+=time_measure(model_mitb0)
  time_mitb1+=time_measure(model_mitb1)


print('Time measure MobileNetV2 : ',time_mobilenet/iterasi)
print('Time measure Resnet50 : ',time_resnet50/iterasi)
print('Time measure MixTransformer B0 : ',time_mitb0/iterasi)
print('Time measure MixTransformer B1 : ',time_mitb1/iterasi)


Time measure MobileNetV2 :  7.789138957977295
Time measure Resnet50 :  11.004524295806885
Time measure MixTransformer B0 :  9.63421464920044
Time measure MixTransformer B1 :  9.649583179473876
