In [None]:
!pip install -U fvcore
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch
!pip install pytorch-lightning
!pip install albumentations
!pip install torchvision       

In [None]:
import gdown

# https://drive.google.com/file/d/1Vog0SCV90K3Z-3IRw-aVUqTKGSNaHEhx/view?usp=sharing

url='https://drive.google.com/uc?id=1Vog0SCV90K3Z-3IRw-aVUqTKGSNaHEhx'
output_file_train='resnet_unet.ckpt'

gdown.download(url, output_file_train, quiet=False)

# https://drive.google.com/file/d/1_1g7zTV0IPce8AN3SxcMaNL2j40otoLN/view?usp=sharing
url='https://drive.google.com/uc?id=1_1g7zTV0IPce8AN3SxcMaNL2j40otoLN'
output_file_train='mobilenetv2_nontune.ckpt'

gdown.download(url, output_file_train, quiet=False)

#https://drive.google.com/file/d/1LeicojNwoQM6-eLGdu1pX4EkKXB2D7RX/view?usp=sharing
url='https://drive.google.com/uc?id=1LeicojNwoQM6-eLGdu1pX4EkKXB2D7RX'
output_file_train='mix_transfomer_unet.ckpt'

gdown.download(url, output_file_train, quiet=False)

#https://drive.google.com/file/d/1irlDBPZWBqLTjrh-IgmQ0pgrwPAMRF76/view?usp=sharing
url='https://drive.google.com/uc?id=1irlDBPZWBqLTjrh-IgmQ0pgrwPAMRF76'
output_file_train='mix_transfomer_unet_b1.ckpt'

gdown.download(url, output_file_train, quiet=False)

In [3]:
import pytorch_lightning as pl
import segmentation_models_pytorch as smp

class Segmentation_custom(pl.LightningModule):

    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        # for image segmentation dice loss could be the best first choice
        # self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE)

    def forward(self, image):
        # normalize image here
        mask = self.model(image)
        return mask

    def shared_step(self, batch, stage):
        
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32, 
        # encoder and decoder connected by `skip connections` and usually encoder have 5 stages of 
        # downsampling by factor 2 (2 ^ 5 = 32); e.g. if we have image with shape 65x65 we will have 
        # following shapes of features in encoder and decoder: 84, 42, 21, 10, 5 -> 5, 10, 20, 40, 80
        # and we will get an error trying to concat these features
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4
        

        # Check that mask values in between 0 and 1, NOT 0 and 255 for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)
        
        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Lets compute metrics for some threshold
        # first convert mask values to probabilities, then 
        # apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute IoU metric by two ways
        #   1. dataset-wise
        #   2. image-wise
        # but for now we just compute true positive, false positive, false negative and
        # true negative 'pixels' for each image and class
        # these values will be aggregated in the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        # aggregate step metics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per image IoU means that we first calculate IoU score for each image 
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        
        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
        # in this particular case will not be much, however for dataset 
        # with "empty" images (images without target class) a large gap could be observed. 
        # Empty images influence a lot on per_image_iou and much less on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }
        
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")            

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")  

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

In [None]:
model_mitb0 = Segmentation_custom.load_from_checkpoint('/content/mix_transfomer_unet.ckpt',arch="Unet", encoder_name="mit_b0", in_channels=3, out_classes=1)
model_mitb1 = Segmentation_custom.load_from_checkpoint('/content/mix_transfomer_unet_b1.ckpt',arch="Unet", encoder_name="mit_b1", in_channels=3, out_classes=1)
model_resnet50 = Segmentation_custom.load_from_checkpoint('/content/resnet_unet.ckpt',arch="Unet", encoder_name="resnet50", in_channels=3, out_classes=1)
model_mobilenetv2 = Segmentation_custom.load_from_checkpoint('/content/mobilenetv2_nontune.ckpt',arch="Unet", encoder_name="mobilenet_v2", in_channels=3, out_classes=1)

model_mitb0.cuda()
model_mitb1.cuda()
model_resnet50.cuda()
model_mobilenetv2.cuda()

In [5]:
from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str,ActivationCountAnalysis
import torch

input = torch.randn(1,3, 224, 224)
input=input.type(torch.cuda.FloatTensor)

input.cuda()

flops_model_mitb0= FlopCountAnalysis(model_mitb0, input)
flops_model_mitb1=FlopCountAnalysis(model_mitb1, input)
flops_resnet50=FlopCountAnalysis(model_resnet50, input)
flops_mobilenetv2=FlopCountAnalysis(model_mobilenetv2, input)


In [6]:
print(flop_count_table(flops_model_mitb0))

| module                       | #parameters or shape   | #flops     |
|:-----------------------------|:-----------------------|:-----------|
| model                        | 5.549M                 | 2.29G      |
|  encoder                     |  3.319M                |  0.458G    |
|   encoder.patch_embed1       |   4.8K                 |   15.254M  |
|    encoder.patch_embed1.proj |    4.736K              |    14.752M |
|    encoder.patch_embed1.norm |    64                  |    0.502M  |
|   encoder.patch_embed2       |   18.624K              |   14.702M  |
|    encoder.patch_embed2.proj |    18.496K             |    14.451M |
|    encoder.patch_embed2.norm |    0.128K              |    0.251M  |
|   encoder.patch_embed3       |   92.64K               |   18.22M   |
|    encoder.patch_embed3.proj |    92.32K              |    18.063M |
|    encoder.patch_embed3.norm |    0.32K               |    0.157M  |
|   encoder.patch_embed4       |   0.369M               |   18.126M  |
|    e

In [7]:
print(flop_count_table(flops_model_mitb1))

| module                       | #parameters or shape   | #flops     |
|:-----------------------------|:-----------------------|:-----------|
| model                        | 16.432M                | 3.823G     |
|  encoder                     |  13.151M               |  1.687G    |
|   encoder.patch_embed1       |   9.6K                 |   30.507M  |
|    encoder.patch_embed1.proj |    9.472K              |    29.503M |
|    encoder.patch_embed1.norm |    0.128K              |    1.004M  |
|   encoder.patch_embed2       |   74.112K              |   58.305M  |
|    encoder.patch_embed2.proj |    73.856K             |    57.803M |
|    encoder.patch_embed2.norm |    0.256K              |    0.502M  |
|   encoder.patch_embed3       |   0.37M                |   72.567M  |
|    encoder.patch_embed3.proj |    0.369M              |    72.253M |
|    encoder.patch_embed3.norm |    0.64K               |    0.314M  |
|   encoder.patch_embed4       |   1.476M               |   72.379M  |
|    e

In [8]:
print(flop_count_table(flops_resnet50))

| module                       | #parameters or shape   | #flops    |
|:-----------------------------|:-----------------------|:----------|
| model                        | 32.521M                | 8.215G    |
|  encoder                     |  23.508M               |  4.143G   |
|   encoder.conv1              |   9.408K               |   0.118G  |
|    encoder.conv1.weight      |    (64, 3, 7, 7)       |           |
|   encoder.bn1                |   0.128K               |   4.014M  |
|    encoder.bn1.weight        |    (64,)               |           |
|    encoder.bn1.bias          |    (64,)               |           |
|   encoder.layer1             |   0.216M               |   0.69G   |
|    encoder.layer1.0          |    75.008K             |    0.241G |
|    encoder.layer1.1          |    70.4K               |    0.224G |
|    encoder.layer1.2          |    70.4K               |    0.224G |
|   encoder.layer2             |   1.22M                |   1.043G  |
|    encoder.layer2.

In [9]:
print(flop_count_table(flops_mobilenetv2))

| module                        | #parameters or shape   | #flops     |
|:------------------------------|:-----------------------|:-----------|
| model                         | 6.629M                 | 2.613G     |
|  encoder.features             |  2.224M                |  0.333G    |
|   encoder.features.0          |   0.928K               |   12.845M  |
|    encoder.features.0.0       |    0.864K              |    10.838M |
|    encoder.features.0.1       |    64                  |    2.007M  |
|   encoder.features.1.conv     |   0.896K               |   13.046M  |
|    encoder.features.1.conv.0  |    0.352K              |    5.62M   |
|    encoder.features.1.conv.1  |    0.512K              |    6.423M  |
|    encoder.features.1.conv.2  |    32                  |    1.004M  |
|   encoder.features.2.conv     |   5.136K               |   37.105M  |
|    encoder.features.2.conv.0  |    1.728K              |    25.289M |
|    encoder.features.2.conv.1  |    1.056K              |    4.