In [1]:
import os
import sys
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as pth_transforms


try:
    from einops import rearrange
    import segmentation_models_pytorch as smp
    from timm.models.layers import drop_path, trunc_normal_
    
except:
    !pip install timm
    !pip install einops
    !pip install segmentation-models-pytorch
    
    from einops import rearrange
    import segmentation_models_pytorch as smp
    from timm.models.layers import drop_path, trunc_normal_

Collecting timm
  Downloading timm-0.6.7-py3-none-any.whl (509 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.0/510.0 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.6.7
[0mCollecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
[0mCollecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.3.0-py3-none-any.whl (97 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m97.9/97.9 kB[0m [31m700.7 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting efficientnet-pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l- done
[?25hCollecting timm==0.4.12
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m377.0/377.0 kB[0m [31m2.0 MB/s[0m eta [

In [2]:
class overlap_patch_embed(nn.Module):
    def __init__(self, patch_size, stride, in_chans, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size // 2, patch_size // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.norm(x)
        return x, h, w
    
    

class mix_feedforward(nn.Module):
    def __init__(self, in_features, out_features, hidden_features, dropout_p = 0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        
        # Depth-wise separable convolution
        self.conv = nn.Conv2d(hidden_features, hidden_features, (3, 3), padding=(1, 1),
                              bias=True, groups=hidden_features)
        self.dropout_p = dropout_p
        
    def forward(self, x, h, w):
        x = self.fc1(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.conv(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = F.gelu(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        return x
    
    

class efficient_self_attention(nn.Module):
    def __init__(self, attn_dim, num_heads, dropout_p, sr_ratio):
        super().__init__()
        assert attn_dim % num_heads == 0, f'expected attn_dim {attn_dim} to be a multiple of num_heads {num_heads}'
        self.attn_dim = attn_dim
        self.num_heads = num_heads
        self.dropout_p = dropout_p
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(attn_dim, attn_dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(attn_dim)

        # Multi-head Self-Attention using dot product
        # Query - Key Dot product is scaled by root of head_dim
        self.q = nn.Linear(attn_dim, attn_dim, bias=True)
        self.kv = nn.Linear(attn_dim, attn_dim * 2, bias=True)
        self.scale = (attn_dim // num_heads) ** -0.5

        # Projecting concatenated outputs from 
        # multiple heads to single `attn_dim` size
        self.proj = nn.Linear(attn_dim, attn_dim)


    def forward(self, x, h, w):
        q = self.q(x)
        q = rearrange(q, ('b hw (m c) -> b m hw c'), m=self.num_heads)

        if self.sr_ratio > 1:
            x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
            x = self.sr(x)
            x = rearrange(x, 'b c h w -> b (h w) c')
            x = self.norm(x)

        x = self.kv(x)
        x = rearrange(x, 'b d (a m c) -> a b m d c', a=2, m=self.num_heads)
        k, v = x[0], x[1] # x.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = attn @ v
        x = rearrange(x, 'b m hw c -> b hw (m c)')
        x = self.proj(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        
        attn_output = {'key' : k, 'query' : q, 'value' : v, 'attn' : attn}
        return x, attn_output
    
    

class transformer_block(nn.Module):
    def __init__(self, dim, num_heads, dropout_p, drop_path_p, sr_ratio):
        super().__init__()
        # One transformer block is defined as :
        # Norm -> self-attention -> Norm -> FeedForward
        # skip-connections are added after attention and FF layers
        self.attn = efficient_self_attention(attn_dim=dim, num_heads=num_heads, 
                    dropout_p=dropout_p, sr_ratio=sr_ratio)
        self.ffn = mix_feedforward( dim, dim, hidden_features=dim * 4, dropout_p=dropout_p)                    

        self.drop_path_p = drop_path_p
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        

    def forward(self, x, h, w):
        # Norm -> self-attention
        skip = x
        x = self.norm1(x)
        x, attn_output = self.attn(x, h, w)
        x = drop_path(x, drop_prob=self.drop_path_p, training=self.training)
        x = x + skip

        # Norm -> FeedForward
        skip = x
        x = self.norm2(x)
        x = self.ffn(x, h, w)
        x = drop_path(x, drop_prob=self.drop_path_p, training=self.training)
        x = x + skip
        return x, attn_output
    
    
    
class mix_transformer_stage(nn.Module):
    def __init__(self, patch_embed, blocks, norm):
        super().__init__()
        self.patch_embed = patch_embed
        self.blocks = nn.ModuleList(blocks)
        self.norm = norm

    def forward(self, x):
        # patch embedding and store required data
        stage_output  = {}
        stage_output['patch_embed_input'] = x
        x, h, w = self.patch_embed(x)
        stage_output['patch_embed_h'] = h
        stage_output['patch_embed_w'] = w
        stage_output['patch_embed_output'] = x
        
        for block in self.blocks:
            x, attn_output = block(x, h, w)
                        
        x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        
        # store last attention block data 
        # in stages' output data
        for k,v in attn_output.items():
            stage_output[k] = v
        del attn_output
        return x, stage_output
    
    
    
class mix_transformer(nn.Module):
    def __init__(self, in_chans, embed_dims, num_heads, depths, 
                sr_ratios, dropout_p, drop_path_p):
        super().__init__()
        self.stages = nn.ModuleList()
        for stage_i in range(len(depths)):
            # Each Stage consists of following blocks :
            # Overlap patch embedding -> mix_transformer_block -> norm
            blocks = []
            for i in range(depths[stage_i]):
                blocks.append(transformer_block(dim = embed_dims[stage_i],
                        num_heads= num_heads[stage_i], dropout_p=dropout_p,
                        drop_path_p = drop_path_p * (sum(depths[:stage_i])+i) / (sum(depths)-1),
                        sr_ratio = sr_ratios[stage_i] ))

            if(stage_i == 0):
                patch_size = 7
                stride = 4
                in_chans = in_chans
            else:
                patch_size = 3
                stride = 2
                in_chans = embed_dims[stage_i -1]
            
            patch_embed = overlap_patch_embed(patch_size, stride=stride, in_chans=in_chans, 
                            embed_dim= embed_dims[stage_i])
            norm = nn.LayerNorm(embed_dims[stage_i], eps=1e-6)
            self.stages.append(mix_transformer_stage(patch_embed, blocks, norm))
            

    def forward(self, x):
        outputs = []
        for i,stage in enumerate(self.stages):
            x, _ = stage(x)
            outputs.append(x)
        return outputs
        
    
    def get_attn_outputs(self, x):
        stage_outputs = []
        for i,stage in enumerate(self.stages):
            x, stage_data = stage(x)
            stage_outputs.append(stage_data)
        return stage_outputs


class segformer_head(nn.Module):
    def __init__(self, in_channels, num_classes, embed_dim, dropout_p=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.dropout_p = dropout_p

        # 1x1 conv to fuse multi-scale output from encoder
        self.layers = nn.ModuleList([nn.Conv2d(chans, embed_dim, (1, 1))
                                     for chans in reversed(in_channels)])
        self.linear_fuse = nn.Conv2d(embed_dim * len(self.layers), embed_dim, (1, 1), bias=False)
        self.bn = nn.BatchNorm2d(embed_dim, eps=1e-5)

        # 1x1 conv to get num_class channel predictions
        self.linear_pred = nn.Conv2d(self.embed_dim, num_classes, kernel_size=(1, 1))
        self.init_weights()

    def init_weights(self):
        nn.init.kaiming_normal_(self.linear_fuse.weight, mode='fan_out', nonlinearity='relu')
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        feature_size = x[0].shape[2:]
        
        # project each encoder stage output to H/4, W/4
        x = [layer(xi) for layer, xi in zip(self.layers, reversed(x))]
        x = [F.interpolate(xi, size=feature_size, mode='bilinear', align_corners=False)
             for xi in x[:-1]] + [x[-1]]
        
        # concatenate project output and use 1x1
        # convs to get num_class channel output
        x = self.linear_fuse(torch.cat(x, dim=1))
        x = self.bn(x)
        x = F.relu(x, inplace=True)
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.linear_pred(x)
        return x
    
    
    
class segformer_mit_b3(nn.Module):    
    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Encoder block    
        self.backbone = mix_transformer(in_chans=in_channels, embed_dims=(64, 128, 320, 512), 
                                    num_heads=(1, 2, 5, 8), depths=(3, 4, 18, 3),
                                    sr_ratios=(8, 4, 2, 1), dropout_p=0.0, drop_path_p=0.1)
        # decoder block
        self.decoder_head = segformer_head(in_channels=(64, 128, 320, 512), 
                                    num_classes=num_classes, embed_dim=256)
        
        # init weights
        self.apply(self._init_weights)
        
        
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            
            
    def forward(self, x):
        image_hw = x.shape[2:]
        x = self.backbone(x)
        x = self.decoder_head(x)
        x = F.interpolate(x, size=image_hw, mode='bilinear', align_corners=False)
        return x
    
    
    def get_attention_outputs(self, x):
        return self.backbone.get_attn_outputs(x)
    
    def get_last_selfattention(self, x):
        outputs = self.get_attention_outputs(x)
        return outputs[-1].get('attn', None)

In [3]:
def preprocess_image(image_path, tf, patch_size):
    # read image -> convert to RGB -> torch Tensor
    rgb_img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    img = tf(rgb_img)
    _, image_height, image_width = img.shape
    
    # make the image divisible by the patch size
    w, h = image_width - image_width % patch_size, image_height - image_height % patch_size
    img = img[:, :h, :w].unsqueeze(0)
    
    w_featmap = img.shape[-1] // patch_size
    h_featmap = img.shape[-2] // patch_size
    return rgb_img, img, w_featmap, h_featmap

In [4]:
def calculate_attentions(img, w_featmap, h_featmap, patch_size, mode = 'bilinear'):
    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]
    
    # we keep only the output patch attention
    # reshape to image size
    attentions = attentions[0, :, :, 0].reshape(nh, h_featmap, w_featmap)
    attentions = F.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode=mode)[0].detach().cpu().numpy()
    return attentions


def get_attention_masks(image_path, model, transform, patch_size, mode = 'bilinear'):
    rgb_img, img, w_featmap, h_featmap = preprocess_image(image_path, transform, patch_size)
    attentions = calculate_attentions(img, w_featmap, h_featmap, patch_size, mode = mode)
    return rgb_img, attentions

In [5]:
def calculate_segformer_stage_attentions(img, num_stages, mode = 'bilinear'):
    stages_data = model.get_attention_outputs(img.to(device))
    stage_attn_output = []

    for i, data in enumerate(stages_data[0:num_stages]):
        stage_attn = data['attn']
        stage_nh = stage_attn.shape[1]

        # we keep only the output patch attention
        # reshape to image size
        stage_attn = stage_attn[0, :, :, 0]
        stage_h, stage_w = int(targetHeight / stage_scale[i]), int(targetWidth / stage_scale[i])
        stage_attn = stage_attn.reshape(stage_nh, stage_h, stage_w)
        stage_attn = F.interpolate(stage_attn.unsqueeze(0), size=(targetHeight, targetWidth), mode=mode)[0].detach().cpu().numpy()
        stage_attn_output.append(stage_attn)
    
    stage_attn_output = np.concatenate(stage_attn_output, axis=0)    
    return stage_attn_output


def get_stage_attention_masks(image_path, model, transform, patch_size, num_stages, mode = 'bilinear'):
    rgb_img, img, w_featmap, h_featmap = preprocess_image(image_path, transform, patch_size)
    attentions = calculate_segformer_stage_attentions(img, num_stages = num_stages, mode = mode)
    return rgb_img, attentions

In [6]:
targetWidth = 1024
targetHeight = 512

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

NUM_CLASSES = 19
MODEL_NAME = f'segformer_mit_b3_last_stage'

In [7]:
model = segformer_mit_b3(in_channels=3, num_classes=NUM_CLASSES).to(device)
model.eval();
checkpoint = torch.load('../input/image-segmentation/segformer_mit_b3_cs_pretrain_19CLS_512_1024_CE_loss.pt')
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [8]:
output_dir = '.'
patch_size = 32
stage_scale = [4, 8, 16, 32]
stage_heads = [1, 2, 5, 8]
titles = []
for stage_index, stage_nh in enumerate(stage_heads):
    titles.extend([f"STAGE_{stage_index+1}_HEAD_{x+1}" for x in range(stage_nh)])

transform = pth_transforms.Compose([
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [9]:
input_dir = '../input/cityscapes-512x1024/demoVideo/stuttgart_00'
image_list = sorted(os.listdir(input_dir))
images_path = [os.path.join(input_dir, x) for x in image_list]

In [10]:
font = {'family' : 'normal', 'weight' : 'bold', 'size'   : 4}
plt.rc('font', **font)
plt.rcParams['text.color'] = 'white'

In [11]:
%matplotlib agg
fig, axes = plt.subplots(3,3, figsize=(15.5,8))
axes = axes.flatten()
fig.tight_layout()

for image_path in tqdm(images_path):
    image_name = image_path.split(os.sep)[-1].split('.')[0]
    
    rgb_img, attentions = get_attention_masks(image_path, model, transform, patch_size, mode = 'bilinear')
#     rgb_img, attentions = get_stage_attention_masks(image_path, model, transform, patch_size, num_stages=3, mode = 'bilinear')    
        
#     for i in range(len(axes)):
#         axes[i].clear()
#         axes[i].imshow(rgb_img)
#         axes[i].imshow(attentions[i], cmap='inferno', alpha=0.5)
#         axes[i].axis('off')
#         axes[i].set_title(titles[i], x= 0.22, y=0.9, va="top")
    
    
###########################################    
    for i in range(len(axes)):
        axes[i].clear()
        if (i < 4):
            axes[i].imshow(rgb_img)
            axes[i].imshow(attentions[i], cmap='inferno', alpha=0.5)
            axes[i].set_title(titles[i+8], x= 0.20, y=0.9, va="top")
            
        elif(i==4):
            axes[i].imshow(np.zeros_like(rgb_img))
        else:
            axes[i].imshow(rgb_img)
            axes[i].imshow(attentions[i-1], cmap='inferno', alpha=0.5)
            axes[i].set_title(titles[i-1+8], x= 0.20, y=0.9, va="top")

        axes[i].axis('off')

###########################################

    fig.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(f'{image_name}_last_stage.png')

100%|██████████| 599/599 [17:40<00:00,  1.77s/it]


In [12]:
def convert_images_to_video(images_dir, output_video_path, fps : int = 20):
    
    input_images = [os.path.join(images_dir, *[x]) for x in sorted(os.listdir(images_dir)) if x.endswith('png')]
    
    if(len(input_images) > 0):
        sample_image = cv2.imread(input_images[0])
        height, width, _ = sample_image.shape
        
        # handles for input output videos
        output_handle = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'DIVX'), fps, (width, height))

        # create progress bar
        num_frames = int(len(input_images))
        pbar = tqdm(total = num_frames, position=0, leave=True)

        for i in tqdm(range(num_frames), position=0, leave=True):
            frame = cv2.imread(input_images[i])
            output_handle.write(frame)
            pbar.update(1)

        # release the output video handler
        output_handle.release()
                
    else:
        pass

In [13]:
def createDir(dirPath):
    if(not os.path.isdir(dirPath)):
        os.mkdir(dirPath)

In [14]:
video_output_dir = os.path.join(output_dir, *['videos'])
createDir(video_output_dir)
output_video_path = os.path.join(video_output_dir, *[f"{MODEL_NAME}_last_stage_demoVideo.mp4"])
print(output_video_path)

./videos/segformer_mit_b3_last_stage_last_stage_demoVideo.mp4


In [15]:
convert_images_to_video('./', output_video_path)

OpenCV: FFMPEG: tag 0x58564944/'DIVX' is not supported with codec id 12 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x7634706d/'mp4v'
100%|██████████| 599/599 [00:34<00:00, 17.54it/s]
100%|██████████| 599/599 [00:34<00:00, 17.53it/s]
