# Blocks

## pvt-v2 backbone

In [None]:
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math

class Mlp(nn.Module):
    """Token feed-forward layer: Linear -> depthwise conv -> activation -> Linear.

    Args:
        in_features: Channel width of the incoming tokens.
        hidden_features: Width after ``fc1``; falls back to ``in_features`` when falsy.
        out_features: Width after ``fc2``; falls back to ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability, applied after the activation and after ``fc2``.
        linear: When True an in-place ReLU is inserted right after ``fc1``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if linear:
            self.relu = nn.ReLU(inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; called for every module through ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init scaled by the per-group fan-out of the kernel.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Transform tokens ``x``; ``H`` and ``W`` are forwarded to the depthwise conv."""
        hidden = self.fc1(x)
        if self.linear:
            hidden = self.relu(hidden)
        hidden = self.drop(self.act(self.dwconv(hidden, H, W)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head attention whose keys/values may come from a spatially reduced copy of the input.

    Three key/value paths exist:
      * ``linear=True``: the token grid is adaptively average-pooled to 7x7, passed
        through a 1x1 conv, LayerNorm and GELU.
      * ``linear=False`` and ``sr_ratio > 1``: the grid is reduced by a strided conv
        (kernel = stride = ``sr_ratio``) followed by LayerNorm.
      * otherwise: keys/values are computed from the input tokens directly.

    Args:
        dim: Token channel width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the ``q`` and ``kv`` projections carry a bias.
        qk_scale: Optional override of the default ``head_dim ** -0.5`` scale.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        sr_ratio: Spatial-reduction ratio for the non-linear path.
        linear: Selects the pooled key/value path.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.linear = linear
        self.sr_ratio = sr_ratio
        if linear:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU()
        elif sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; called for every module through ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Attend over tokens ``x`` of shape (B, N, C); ``H * W`` must equal ``N`` on the reduced paths."""
        B, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads
        q = self.q(x).reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)

        # Pick the token sequence that keys and values are projected from.
        if self.linear:
            grid = x.permute(0, 2, 1).reshape(B, C, H, W)
            kv_src = self.sr(self.pool(grid)).reshape(B, C, -1).permute(0, 2, 1)
            kv_src = self.act(self.norm(kv_src))
        elif self.sr_ratio > 1:
            grid = x.permute(0, 2, 1).reshape(B, C, H, W)
            kv_src = self.sr(grid).reshape(B, C, -1).permute(0, 2, 1)
            kv_src = self.norm(kv_src)
        else:
            kv_src = x

        kv = self.kv(kv_src).reshape(B, -1, 2, heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        weights = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))

class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each wrapped in a residual connection.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop, sr_ratio, linear: Forwarded to ``Attention``.
        drop: Dropout probability used for the attention projection and the MLP.
        drop_path: Stochastic-depth rate; an identity is used when it is not positive.
        act_layer: Activation factory forwarded to ``Mlp``.
        norm_layer: Factory taking ``dim`` and returning a normalisation module.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
            linear=linear,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            linear=linear,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; called for every module through ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply both residual sub-layers; ``H`` and ``W`` are passed through to them."""
        attn_out = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(mlp_out)

class OverlapPatchEmbed(nn.Module):
    """Embed an image (or feature map) into tokens with an overlapping strided convolution.

    Args:
        img_size: Nominal input size (int or pair); only used to derive ``H``, ``W``
            and ``num_patches`` attributes.
        patch_size: Convolution kernel size (int or pair); must exceed ``stride``.
        stride: Convolution stride.
        in_chans: Number of input channels.
        embed_dim: Number of output channels per token.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert max(patch_size) > stride, "Set larger patch_size than stride"

        self.img_size = img_size
        self.patch_size = patch_size
        self.H = img_size[0] // stride
        self.W = img_size[1] // stride
        self.num_patches = self.H * self.W

        # Half-kernel padding on each side so neighbouring patches overlap.
        half_kernel = (patch_size[0] // 2, patch_size[1] // 2)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=half_kernel)
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; called for every module through ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Return ``(tokens, H, W)`` where tokens is (B, H*W, embed_dim) and H, W are the conv output size."""
        feat = self.proj(x)
        H, W = feat.shape[2], feat.shape[3]
        tokens = self.norm(feat.flatten(2).transpose(1, 2))
        return tokens, H, W


class PyramidVisionTransformerV2(nn.Module):
    """Multi-stage pyramid transformer backbone followed by a linear classification head.

    Stage ``i`` (0-based) registers three sub-modules, in this order:
    ``patch_embed{i+1}`` (``OverlapPatchEmbed``), ``block{i+1}`` (``nn.ModuleList``
    of ``Block``) and ``norm{i+1}``.

    Args:
        img_size: Nominal input resolution, forwarded to the patch embeddings.
        patch_size: Accepted for interface compatibility; not used by this class
            (stage kernels are fixed to 7 for the first stage and 3 afterwards).
        in_chans: Number of input image channels.
        num_classes: Size of the classification head; ``<= 0`` yields an identity head.
        embed_dims: Channel width of each stage.
        num_heads: Attention heads of each stage.
        mlp_ratios: MLP expansion ratio of each stage.
        qkv_bias, qk_scale: Forwarded to every ``Block``.
        drop_rate, attn_drop_rate: Dropout rates forwarded to every ``Block``.
        drop_path_rate: Maximum stochastic-depth rate; it grows linearly from 0
            over all blocks.
        norm_layer: Normalisation factory taking the channel width.
        depths: Number of blocks in each stage.
        sr_ratios: Spatial-reduction ratio of each stage's attention.
        num_stages: Number of stages to build.
        linear: Selects the pooled ("linear") attention / MLP variant.
    """

    def __init__(self, img_size=256, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages
        # Width of the last stage that is actually built. ``forward_features`` returns
        # features of this width, and ``reset_classifier`` needs it to rebuild the head.
        self.embed_dim = embed_dims[num_stages - 1]

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                                            patch_size=7 if i == 0 else 3,
                                            stride=4 if i == 0 else 2,
                                            in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                            embed_dim=embed_dims[i])

            block = nn.ModuleList([Block(
                dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer,
                sr_ratio=sr_ratios[i], linear=linear)
                for j in range(depths[i])])
            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        # classification head (identical to ``embed_dims[3]`` for the default four stages)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; called for every module through ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def freeze_patch_emb(self):
        """Stop gradient updates for every parameter of the first patch embedding."""
        # Assigning ``module.requires_grad = False`` only creates a plain attribute on
        # the module; ``requires_grad_`` is what actually updates the parameters.
        self.patch_embed1.requires_grad_(False)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names to exclude from weight decay.

        NOTE(review): this class creates no parameters with these names, so the
        set is presumably kept for optimizer helpers that expect the method.
        """
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        """Return the current classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for ``num_classes`` outputs.

        ``global_pool`` is accepted but not used. ``num_classes <= 0`` installs
        an identity head.
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Run every stage and return the last stage's tokens averaged over the token axis."""
        B = x.shape[0]

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            if i != self.num_stages - 1:
                # Back to (B, C, H, W) so the next stage's conv embedding can consume it.
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        return x.mean(dim=1)

    def forward(self, x):
        """Return head outputs for an image batch ``x``."""
        x = self.forward_features(x)
        x = self.head(x)

        return x


class DWConv(nn.Module):
    """3x3 depthwise convolution applied to a token sequence viewed as a 2-D grid.

    Args:
        dim: Number of channels; each channel gets its own 3x3 filter.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Map tokens (B, N, C) with ``N == H * W`` to tokens of the same shape."""
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, H, W)
        return self.dwconv(grid).flatten(2).transpose(1, 2)


def _conv_filter(state_dict, patch_size=16):
    """Return a copy of ``state_dict`` with patch-embedding weights reshaped for a conv layer.

    Every entry whose key contains ``'patch_embed.proj.weight'`` is reshaped to
    ``(out_channels, 3, patch_size, patch_size)``; all other entries are passed
    through unchanged. Key order is preserved.
    """
    def _maybe_reshape(name, tensor):
        if 'patch_embed.proj.weight' in name:
            return tensor.reshape((tensor.shape[0], 3, patch_size, patch_size))
        return tensor

    return {name: _maybe_reshape(name, tensor) for name, tensor in state_dict.items()}

@register_model
def pvt_v2_b0(pretrained=False, **kwargs):
    """Build the B0 variant: widths 32/64/160/256, two blocks per stage.

    ``pretrained`` is accepted but not used here; no weights are loaded.
    Extra keyword arguments are forwarded to ``PyramidVisionTransformerV2``.
    """
    net = PyramidVisionTransformerV2(
        patch_size=4,
        embed_dims=[32, 64, 160, 256],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[2, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    net.default_cfg = _cfg()
    return net


@register_model
def pvt_v2_b1(pretrained=False, **kwargs):
    """Build the B1 variant: widths 64/128/320/512, two blocks per stage.

    ``pretrained`` is accepted but not used here; no weights are loaded.
    Extra keyword arguments are forwarded to ``PyramidVisionTransformerV2``.
    """
    net = PyramidVisionTransformerV2(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[2, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    net.default_cfg = _cfg()
    return net


@register_model
def pvt_v2_b2(pretrained=False, **kwargs):
    """Build the B2 variant: widths 64/128/320/512, depths 3/4/6/3.

    ``pretrained`` is accepted but not used here; no weights are loaded.
    Extra keyword arguments are forwarded to ``PyramidVisionTransformerV2``.
    """
    net = PyramidVisionTransformerV2(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    net.default_cfg = _cfg()
    return net


@register_model
def pvt_v2_b3(pretrained=False, **kwargs):
    """Build the B3 variant: widths 64/128/320/512, depths 3/4/18/3.

    ``pretrained`` is accepted but not used here; no weights are loaded.
    Extra keyword arguments are forwarded to ``PyramidVisionTransformerV2``.
    """
    net = PyramidVisionTransformerV2(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[3, 4, 18, 3],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    net.default_cfg = _cfg()
    return net


@register_model
def pvt_v2_b4(pretrained=False, **kwargs):
    """Build the B4 variant: widths 64/128/320/512, depths 3/8/27/3.

    ``pretrained`` is accepted but not used here; no weights are loaded.
    Extra keyword arguments are forwarded to ``PyramidVisionTransformerV2``.
    """
    net = PyramidVisionTransformerV2(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[3, 8, 27, 3],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    net.default_cfg = _cfg()
    return net


@register_model
def pvt_v2_b5(pretrained=False, **kwargs):
    """Build the B5 variant: widths 64/128/320/512, depths 3/6/40/3, MLP ratio 4 everywhere.

    ``pretrained`` is accepted but not used here; no weights are loaded.
    Extra keyword arguments are forwarded to ``PyramidVisionTransformerV2``.
    """
    net = PyramidVisionTransformerV2(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[3, 6, 40, 3],
        sr_ratios=[8, 4, 2, 1],
        **kwargs)
    net.default_cfg = _cfg()
    return net


@register_model
def pvt_v2_b2_li(pretrained=False, **kwargs):
    """Build the B2 configuration with ``linear=True`` (pooled attention path).

    ``pretrained`` is accepted but not used here; no weights are loaded.
    Extra keyword arguments are forwarded to ``PyramidVisionTransformerV2``.
    """
    net = PyramidVisionTransformerV2(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        linear=True,
        **kwargs)
    net.default_cfg = _cfg()
    return net


## cbam, paspp

In [None]:
class BasicConv(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and ReLU.

    Args:
        in_planes, out_planes: Input / output channel counts.
        kernel_size, stride, padding, dilation, groups, bias: Forwarded to ``nn.Conv2d``.
        relu: When False, ``self.relu`` is None and no activation is applied.
        bn: When False, ``self.bn`` is None and no normalisation is applied.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        """Apply the convolution, then whichever of BN / ReLU are enabled, in that order."""
        out = self.conv(x)
        for layer in (self.bn, self.relu):
            if layer is not None:
                out = layer(out)
        return out

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)

class ChannelGate(nn.Module):
    """Channel attention: rescale each channel by a sigmoid gate derived from global pooling.

    Every pooling type in ``pool_types`` reduces the input to one value per
    channel; each pooled vector goes through the same two-layer MLP, the MLP
    outputs are summed, and the sigmoid of that sum scales the input channels.

    Args:
        gate_channels: Number of channels of the input feature map.
        reduction_ratio: Bottleneck factor of the MLP hidden layer.
        pool_types: Non-empty sequence drawn from ``'avg'``, ``'max'``, ``'lp'``, ``'lse'``.

    Raises:
        ValueError: If ``pool_types`` is empty or names an unsupported pooling type.
            Previously an unknown name silently reused the preceding pooling
            result (or failed with an unbound-variable error).
    """

    _POOL_TYPES = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        # Copy so the shared default list (or a caller's list) is never aliased.
        pool_types = list(pool_types)
        if not pool_types:
            raise ValueError("pool_types must contain at least one pooling type")
        unknown = [p for p in pool_types if p not in self._POOL_TYPES]
        if unknown:
            raise ValueError(f"unsupported pool type(s) {unknown}; expected a subset of {self._POOL_TYPES}")

        self.gate_channels = gate_channels
        # nn.Flatten keeps the Sequential indices (and thus state_dict keys) unchanged.
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types

    def _pool(self, x, pool_type):
        """Reduce ``x`` (B, C, H, W) to one value per channel with the named pooling."""
        spatial = (x.size(2), x.size(3))
        if pool_type == 'avg':
            return F.avg_pool2d(x, spatial, stride=spatial)
        if pool_type == 'max':
            return F.max_pool2d(x, spatial, stride=spatial)
        if pool_type == 'lp':
            return F.lp_pool2d(x, 2, spatial, stride=spatial)
        if pool_type == 'lse':
            # LSE pool only
            return logsumexp_2d(x)
        raise ValueError(f"unsupported pool type {pool_type!r}")

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate (same shape as ``x``)."""
        channel_att_sum = None
        for pool_type in self.pool_types:
            channel_att_raw = self.mlp(self._pool(x, pool_type))
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        if channel_att_sum is None:
            # Only reachable if ``pool_types`` was emptied after construction.
            raise ValueError("pool_types must contain at least one pooling type")

        scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    """Log-sum-exp over all positions of each channel.

    The input is viewed as (dim0, dim1, -1) and reduced over the last axis,
    giving a (dim0, dim1, 1) result. The per-channel maximum is subtracted
    before exponentiating to keep the computation numerically stable.
    """
    flat = tensor.view(tensor.size(0), tensor.size(1), -1)
    peak = torch.max(flat, dim=2, keepdim=True)[0]
    shifted_sum = torch.exp(flat - peak).sum(dim=2, keepdim=True)
    return peak + torch.log(shifted_sum)

class ChannelPool(nn.Module):
    """Compress the channel axis into two maps: channel-wise max, then channel-wise mean."""

    def forward(self, x):
        peak = torch.max(x, 1)[0].unsqueeze(1)
        average = torch.mean(x, 1).unsqueeze(1)
        return torch.cat((peak, average), dim=1)

class SpatialGate(nn.Module):
    """Spatial attention: scale every position by a sigmoid gate computed from channel statistics.

    The input is compressed to two maps by ``ChannelPool``, a 7x7 ``BasicConv``
    (without ReLU) turns them into a single-channel map, and its sigmoid is
    broadcast-multiplied with the input.
    """

    def __init__(self):
        super().__init__()
        kernel = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel, stride=1, padding=(kernel - 1) // 2, relu=False)

    def forward(self, x):
        """Return ``x`` scaled by its spatial gate (same shape as ``x``)."""
        gate_logits = self.spatial(self.compress(x))
        return x * F.sigmoid(gate_logits)  # broadcasting over channels

class CBAM(nn.Module):
    """Channel attention followed, unless disabled, by spatial attention.

    Args:
        gate_channels: Channel count of the input feature map.
        reduction_ratio: Bottleneck factor forwarded to ``ChannelGate``.
        pool_types: Pooling types forwarded to ``ChannelGate``.
        no_spatial: When True the spatial gate is neither created nor applied.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super().__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        """Return the attended feature map (same shape as ``x``)."""
        attended = self.ChannelGate(x)
        if self.no_spatial:
            return attended
        return self.SpatialGate(attended)

In [None]:
class PASPP(nn.Module):
    """Four-branch atrous pooling module with pairwise branch fusion.

    The input is projected by four 1x1 branches to ``inplanes // 4`` channels.
    Branches are paired (1, 2) and (3, 4): each branch's dilated 3x3 conv output
    is added to the sum of its pair, the pair is concatenated and fused by a 1x1
    layer, and the two fused halves are concatenated and passed through a final
    1x1 layer. Output has ``inplanes`` channels.

    Args:
        inplanes: Input channel count (the code divides it by 4 and by 2).
        outplanes: Accepted for interface compatibility; not used.
        output_stride: Selects the dilation set; one of 1, 2, 4, 8, 16.
        BatchNorm: Normalisation factory used in every 1x1 layer.

    Raises:
        NotImplementedError: For an unsupported ``output_stride``.
    """

    # output_stride -> dilation of the four atrous branches
    _DILATIONS = {
        4: [1, 6, 12, 18],
        8: [1, 4, 6, 10],
        2: [1, 12, 24, 36],
        16: [1, 2, 3, 4],
        1: [1, 16, 32, 48],
    }

    def __init__(self, inplanes, outplanes, output_stride=4, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        try:
            dilations = self._DILATIONS[output_stride]
        except (KeyError, TypeError):
            raise NotImplementedError from None
        self._norm_layer = BatchNorm
        self.silu = nn.SiLU(inplace=True)
        self.conv1 = self._make_layer(inplanes, inplanes // 4)
        self.conv2 = self._make_layer(inplanes, inplanes // 4)
        self.conv3 = self._make_layer(inplanes, inplanes // 4)
        self.conv4 = self._make_layer(inplanes, inplanes // 4)
        self.atrous_conv1 = nn.Conv2d(inplanes // 4, inplanes // 4, kernel_size=3, dilation=dilations[0], padding=dilations[0])
        self.atrous_conv2 = nn.Conv2d(inplanes // 4, inplanes // 4, kernel_size=3, dilation=dilations[1], padding=dilations[1])
        self.atrous_conv3 = nn.Conv2d(inplanes // 4, inplanes // 4, kernel_size=3, dilation=dilations[2], padding=dilations[2])
        self.atrous_conv4 = nn.Conv2d(inplanes // 4, inplanes // 4, kernel_size=3, dilation=dilations[3], padding=dilations[3])
        self.conv5 = self._make_layer(inplanes // 2, inplanes // 2)
        self.conv6 = self._make_layer(inplanes // 2, inplanes // 2)
        self.convout = self._make_layer(inplanes, inplanes)

    def _make_layer(self, inplanes, outplanes):
        """Return a 1x1 conv -> norm -> SiLU stack (the SiLU module is shared)."""
        return nn.Sequential(
            nn.Conv2d(inplanes, outplanes, kernel_size=1),
            self._norm_layer(outplanes),
            self.silu,
        )

    def forward(self, X):
        """Return the fused multi-dilation features; spatial size and channels match ``X``."""
        x1 = self.conv1(X)
        x2 = self.conv2(X)
        x3 = self.conv3(X)
        x4 = self.conv4(X)

        x12 = torch.add(x1, x2)
        x34 = torch.add(x3, x4)

        x1 = torch.add(self.atrous_conv1(x1), x12)
        x2 = torch.add(self.atrous_conv2(x2), x12)
        x3 = torch.add(self.atrous_conv3(x3), x34)
        x4 = torch.add(self.atrous_conv4(x4), x34)

        x12 = torch.cat([x1, x2], dim=1)
        x34 = torch.cat([x3, x4], dim=1)

        x12 = self.conv5(x12)
        # The second pair has its own fusion layer. The earlier code reused conv5
        # here, which left conv6 registered but never used.
        # NOTE(review): checkpoints trained with the old shared-conv5 path will
        # produce different outputs; retrain or fine-tune after this change.
        x34 = self.conv6(x34)
        x = torch.cat([x12, x34], dim=1)
        x = self.convout(x)
        return x

## MetaFormer block

In [None]:
class Mlp2(nn.Module):
    """Two-layer channel MLP: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        dim: Input channel width.
        mlp_ratio: Hidden width as a multiple of ``dim``.
        out_features: Output width; falls back to ``dim`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability, or a pair giving one probability per dropout layer.
        bias: Whether both linear layers carry a bias.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=nn.GELU, drop=0., bias=False, **kwargs):
        super().__init__()
        width_in = dim
        width_out = out_features or width_in
        width_hidden = int(mlp_ratio * width_in)
        first_drop, second_drop = to_2tuple(drop)

        self.fc1 = nn.Linear(width_in, width_hidden, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(first_drop)
        self.fc2 = nn.Linear(width_hidden, width_out, bias=bias)
        self.drop2 = nn.Dropout(second_drop)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))

class DepthwiseBlock(nn.Module):
    """BatchNorm followed by the sum of 3x3, 5x5 and 7x7 depthwise convolutions, then GELU.

    Args:
        dim: Channel count; every convolution is depthwise (``groups=dim``) and bias-free.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.BatchNorm2d(dim)
        self.act = nn.GELU()
        self.dwconv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=False)
        self.dwconv2 = nn.Conv2d(dim, dim, kernel_size=5, stride=1, padding=2, groups=dim, bias=False)
        self.dwconv3 = nn.Conv2d(dim, dim, kernel_size=7, stride=1, padding=3, groups=dim, bias=False)

    def forward(self, x):
        """Return multi-scale depthwise features with the same shape as ``x``."""
        normed = self.norm(x)
        small = self.dwconv1(normed)
        medium = self.dwconv2(normed)
        large = self.dwconv3(normed)
        return self.act(small + medium + large)

class ConvMLPMixer(nn.Module):
    """Residual depthwise-conv token mixing followed by a residual channel MLP.

    Input and output are channel-first (B, C, H, W); the MLP part runs in
    channel-last layout.

    Args:
        dim: Channel count.
        mlp: Factory called as ``mlp(dim=dim, drop=drop)`` to build the channel MLP.
        norm_layer: Factory taking ``dim``; applied on the channel-last tensor.
        drop: Dropout probability forwarded to the MLP.
        drop_path: Stochastic-depth rate for the MLP branch; identity when not positive.
    """

    def __init__(self, dim, mlp=Mlp2, norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.1,):
        super().__init__()
        self.dw = DepthwiseBlock(dim)
        self.norm = norm_layer(dim)
        self.mlp = mlp(dim=dim, drop=drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Return mixed features with the same shape as ``x``."""
        mixed = x + self.dw(x)
        channels_last = mixed.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
        mlp_branch = self.drop_path(self.mlp(self.norm(channels_last)))
        channels_last = channels_last + mlp_branch
        return channels_last.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)



## Convmixer

In [None]:
class activation_block(nn.Module):
    """GELU followed by BatchNorm2d over ``dim`` channels."""

    def __init__(self, dim):
        super().__init__()
        self.gelu = nn.GELU()
        self.batchnorm = nn.BatchNorm2d(dim)

    def forward(self, x):
        return self.batchnorm(self.gelu(x))

class DepthwiseConv2d(nn.Module):
    """3x3 depthwise convolution (one filter per channel) that keeps the spatial size."""

    def __init__(self, dim):
        super().__init__()
        self.depthwise = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, x):
        return self.depthwise(x)

class ConvMixer(nn.Module):
    """Residual depthwise conv step followed by a pointwise (1x1) conv step.

    The same ``activation_block`` instance (GELU + BatchNorm) is applied after
    both convolutions. ``kernels_size`` is accepted but not used.
    """

    def __init__(self, dim, kernels_size=1):
        super().__init__()
        self.depthwise = DepthwiseConv2d(dim)
        self.pointwise = nn.Conv2d(dim, dim, kernel_size=1)
        self.activation = activation_block(dim)

    def forward(self, x):
        """Return mixed features with the same shape as ``x``."""
        # Depthwise step with a residual connection around it.
        spatial = self.activation(self.depthwise(x)) + x
        # Pointwise step, no residual.
        return self.activation(self.pointwise(spatial))

## MLP mixer

In [None]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1


In [None]:
from einops.layers.torch import Rearrange

class SpatialMLP(nn.Module):
    """MLP that mixes information across patch positions of a channel-last feature map.

    The (B, H, W, C) input is cut into ``p1 x p2`` patches; the positions inside
    a patch are folded into the feature axis together with the channels, and a
    two-layer bias-free MLP with GELU runs along the axis that enumerates the
    patches.

    Args:
        N: Total number of spatial positions, ``H * W``.
        mlp_ratio: Expansion factor of the hidden layer.
        patch_size: Pair ``(p1, p2)`` of patch height and width.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, N, mlp_ratio=4, patch_size=(4, 4), **kwargs):             # N = H * W
        super().__init__()
        self.p1 = patch_size[0]
        self.p2 = patch_size[1]
        num_patches = N // (self.p1 * self.p2)
        hidden = num_patches * mlp_ratio
        self.fc1 = nn.Linear(num_patches, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, num_patches, bias=False)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Return a tensor with the same (B, H, W, C) shape as ``x``."""
        B, H, W, C = x.shape
        # The inverse pattern needs the patch-grid size, so both layers are built per call.
        to_patches = Rearrange('b (h p1) (w p2) c -> b (p1 p2 c) (h w)', p1=self.p1, p2=self.p2)
        from_patches = Rearrange('b (p1 p2 c) (h w) -> b (h p1) (w p2) c',
                                 p1=self.p1, p2=self.p2, h=H // self.p1, w=W // self.p2)

        tokens = to_patches(x)                     # (B, p1*p2*C, H*W/(p1*p2))
        tokens = self.gelu(self.fc1(tokens))       # (B, p1*p2*C, H*W/(p1*p2) * mlp_ratio)
        tokens = self.gelu(self.fc2(tokens))       # (B, p1*p2*C, H*W/(p1*p2))
        return from_patches(tokens)

In [None]:
class MLPMixer(nn.Module):
    """Residual spatial MLP followed by a residual channel MLP.

    Input and output are channel-first (B, C, H, W); both sub-layers run in
    channel-last layout.

    Args:
        dim: Channel count.
        N: Number of spatial positions (``H * W``) expected by the spatial MLP.
        token_mixer: Accepted but not used; the mixer is always ``SpatialMLP``.
        mlp: Factory called as ``mlp(dim=dim, drop=drop)`` for the channel MLP.
        norm_layer: Factory taking ``dim``.
        drop: Dropout probability forwarded to the channel MLP.
        drop_path: Stochastic-depth rate for both branches; identity when not positive.
        num_heads, sr_ratio: Accepted but not used.
    """

    def __init__(self, dim, N = 8*8,
                 token_mixer=nn.Identity, mlp=Mlp2,
                 norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.1,
                 num_heads = 8, sr_ratio=1
                 ):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.token_mixer = SpatialMLP(N=N)
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp(dim=dim, drop=drop)
        if drop_path > 0.:
            self.drop_path2 = DropPath(drop_path)
        else:
            self.drop_path2 = nn.Identity()

    def forward(self, x):
        """Return mixed features with the same shape as ``x``."""
        feat = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
        spatial_branch = self.token_mixer(self.norm1(feat))
        feat = feat + self.drop_path1(spatial_branch)
        channel_branch = self.mlp(self.norm2(feat))
        feat = feat + self.drop_path2(channel_branch)
        return feat.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)

## Attention Map

In [None]:
class MapReduce(nn.Module):
    """
    Collapse a multi-channel feature map into one channel with a 1x1 convolution.

    The convolution bias starts at zero.
    """
    def __init__(self, channels):
        super().__init__()
        reducer = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
        nn.init.constant_(reducer.bias, 0)
        self.conv = reducer

    def forward(self, x):
        """Return the (B, 1, H, W) reduction of a (B, channels, H, W) input."""
        return self.conv(x)


class Attention_img(nn.Module):
    """Fuse two single-channel maps: gate the first by their normalised sum, then add the second.

    The gate is ``sigmoid(relu(bn(x1 + x2)))`` and the result is ``x1 * gate + x2``.
    ``conv_out`` is created but not used in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.conv_out = nn.Conv2d(2, 1, kernel_size=1, bias=True)

    def forward(self, x1, x2):
        """Return the fused map; same shape as the inputs."""
        gate = self.sigmoid(self.relu(self.bn(x1 + x2)))
        return x1 * gate + x2

# Model

In [None]:
class PVTFormerNet(nn.Module):
  """Segmentation network: a pyramid backbone plus a top-down decoder with side outputs.

  The backbone yields four feature maps (see ``get_pyramid``). Starting from the
  deepest one, each level is refined by a "connect" block and channel attention,
  merged with the upsampled result of the previous level, passed through a conv
  block and spatial attention, and reduced to a one-channel map that is
  bilinearly resized to the input resolution. The four maps are combined in
  ``attentionmap``.

  Channel widths (512/320/128/64) are hard-coded, so the backbone must produce
  stages of exactly those widths.

  NOTE(review): ``up_conv`` and ``conv_block`` are not defined in the visible
  part of this file; they are presumably defined in another cell.
  """

  def __init__(self,backbone,attention = True,convmix=False,convMLP=True,MLPmix=False):
    """Wrap ``backbone`` and build the decoder.

    Args:
        backbone: Module whose children are laid out as repeated
            (patch embed, block list, norm) triples followed by one final child
            that is dropped here -- presumably a ``PyramidVisionTransformerV2``
            and its classification head; confirm against the caller.
        attention: Selects the cascaded fusion branch in ``attentionmap``.
        convmix, MLPmix, convMLP: Choose the type of the ``connect1..4`` blocks.
            They are checked in this order and each enabled flag overwrites the
            previous assignment, so ``convMLP`` wins when several are set.
            NOTE(review): if all three are False, ``connect1..4`` are never
            created and ``forward`` will fail with AttributeError.
    """
    super().__init__()
    # Flatten the backbone into a Sequential and drop its last child.
    self.backbone = torch.nn.Sequential(*list(backbone.children()))[:-1]
    # Positions 1, 4, 7, 10 are re-wrapped as Sequential so get_pyramid can iterate them.
    for i in [1, 4, 7, 10]:
        self.backbone[i] = torch.nn.Sequential(*list(self.backbone[i].children()))

    self.convmix = convmix
    self.convMLP = convMLP
    self.MLPmix = MLPmix
    self.attention = attention
    self.paspp = PASPP(512,512)

    if self.convmix:
      self.connect1 = ConvMixer(dim = 512)
      self.connect2 = ConvMixer(dim = 320)
      self.connect3 = ConvMixer(dim = 128)
      self.connect4 = ConvMixer(dim = 64)

    if self.MLPmix:
      # N is the number of spatial positions at each level (8*8 ... 64*64).
      self.connect1 = MLPMixer(dim = 512,N=64)
      self.connect2 = MLPMixer(dim = 320,N=16*16)
      self.connect3 = MLPMixer(dim = 128,N=32*32)
      self.connect4 = MLPMixer(dim = 64,N=64*64)

    if self.convMLP:
      self.connect1 = ConvMLPMixer(dim=512)
      self.connect2 = ConvMLPMixer(dim=320)
      self.connect3 = ConvMLPMixer(dim=128)
      self.connect4 = ConvMLPMixer(dim=64)

    # One channel-attention gate per pyramid level.
    self.CAM1 = ChannelGate(512)
    self.CAM2 = ChannelGate(320)
    self.CAM3 = ChannelGate(128)
    self.CAM4 = ChannelGate(64)

    # A single spatial gate instance is reused at every level in forward().
    self.SA = SpatialGate()

    self.up1 = up_conv(512,320)
    self.up2 = up_conv(320,128)
    self.up3 = up_conv(128,64)

    # Input widths are the concatenated widths of the two tensors merged at each level.
    self.upconv1 = conv_block(1024,512)
    self.upconv2 = conv_block(640,320)
    self.upconv3 = conv_block(256,128)
    self.upconv4 = conv_block(128,64)

    self.mapreduce1 = MapReduce(512)
    self.mapreduce2 = MapReduce(320)
    self.mapreduce3 = MapReduce(128)
    self.mapreduce4 = MapReduce(64)

    self.attn_img = Attention_img()

    # Only used by the non-attention branch of attentionmap().
    self.conv = nn.Conv2d(4,1,kernel_size=1)
    self.sigmoid = nn.Sigmoid()

  def get_pyramid(self,x):
      """Run the backbone and return one (B, C, H, W) feature map per stage, shallowest first."""
      pyramid = []
      B = x.shape[0]
      for i, module in enumerate(self.backbone):
          if i in [0, 3, 6, 9]:
              # Called as a patch embedding: returns tokens plus the grid size.
              x, H, W = module(x)
          elif i in [1, 4, 7,10]:
              # Stage blocks: each one needs the grid size alongside the tokens.
              for sub_module in module:
                  x = sub_module(x, H, W)
          else:
              # Remaining positions: apply the module, then restore (B, C, H, W).
              x = module(x)
              x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
              pyramid.append(x)
      return pyramid

  def attentionmap(self,x1,x2,x3,x4):
    """Combine the four side maps (x1 from the deepest level, x4 from the shallowest).

    With ``self.attention`` set, the maps are fused in a cascade through the
    shared ``attn_img`` module and a list of seven sigmoid maps is returned,
    the fully fused map first. Otherwise the four maps are concatenated,
    reduced by a 1x1 conv and a single sigmoid tensor is returned.
    """
    if self.attention:
      a3 = self.attn_img(x3, x4)
      a2 = self.attn_img(x2, a3)
      a1 = self.attn_img(x1, a2)
      x = self.sigmoid(a1)
      output = [x, self.sigmoid(a2), self.sigmoid(a3), self.sigmoid(x4),  self.sigmoid(x3), self.sigmoid(x2), self.sigmoid(x1)]
    else:
      x = torch.cat([x1,x2,x3,x4],dim=1)
      x = self.conv(x)
      output = self.sigmoid(x)
    return output


  def forward(self,x):
    """Return the output of ``attentionmap`` for image batch ``x``.

    The shape comments below assume a 256x256 input, as used in the notebook.
    """

    H, W = x.size()[2:]

    pyramid=self.get_pyramid(x)


    # Level 1 (deepest): PASPP context is computed from the raw feature map.
    pyramid3=self.paspp(pyramid[3])           #(512,8,8)
    pyramid[3]=self.connect1(pyramid[3])   #(512,8,8)
    pyramid[3]=self.CAM1(pyramid[3])          #(512,8,8)
    pyramid[3]=torch.cat((pyramid[3],pyramid3),dim=1) #(1024,8,8)
    pyramid[3]=self.upconv1(pyramid[3])       #(512,8,8)
    pyramid[3]=self.SA(pyramid[3])            #(512,8,8)

    x1 = self.mapreduce1(pyramid[3])          #(1,8,8)
    x1 = F.interpolate(x1, (H, W), mode="bilinear", align_corners=False) #(1,256,256)

    # Level 2: merge with the upsampled level-1 result.
    pyramid[2]=self.connect2(pyramid[2])   #(320,16,16)
    pyramid[2]=self.CAM2(pyramid[2])
    pyramid[3]=self.up1(pyramid[3])           #(320,16,16)
    pyramid[2]=torch.cat((pyramid[3],pyramid[2]),dim=1) #(640,16,16)
    pyramid[2]=self.upconv2(pyramid[2])       #(320,16,16)
    pyramid[2]=self.SA(pyramid[2])

    x2 = self.mapreduce2(pyramid[2])          #(1,16,16)
    x2 = F.interpolate(x2, (H, W), mode="bilinear", align_corners=False) #(1,256,256)

    # Level 3: merge with the upsampled level-2 result.
    pyramid[1]=self.connect3(pyramid[1])   #(128,32,32)
    pyramid[1]=self.CAM3(pyramid[1])
    pyramid[2]=self.up2(pyramid[2])           #(128,32,32)
    pyramid[1]=torch.cat((pyramid[2],pyramid[1]),dim=1) #(256,32,32)
    pyramid[1]=self.upconv3(pyramid[1])       #(128,32,32)
    pyramid[1]=self.SA(pyramid[1])

    x3 = self.mapreduce3(pyramid[1])          #(1,32,32)
    x3 = F.interpolate(x3, (H, W), mode="bilinear", align_corners=False) #(1,256,256)

    # Level 4 (shallowest): merge with the upsampled level-3 result.
    pyramid[0]=self.connect4(pyramid[0])   #(64,64,64)
    pyramid[0]=self.CAM4(pyramid[0])
    pyramid[1]=self.up3(pyramid[1])           #(64,64,64)
    pyramid[0]=torch.cat((pyramid[1],pyramid[0]),dim=1) #(128,64,64)
    pyramid[0]=self.upconv4(pyramid[0])       #(64,64,64)
    pyramid[0]=self.SA(pyramid[0])

    x4 = self.mapreduce4(pyramid[0])          #(1,64,64)
    x4 = F.interpolate(x4, (H, W), mode="bilinear", align_corners=False) #(1,256,256)

    output = self.attentionmap(x1,x2,x3,x4)

    return output

# Build the B3 backbone and fill it with pretrained weights before wrapping it.
backbone = pvt_v2_b3()
# map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts;
# load_state_dict then copies the tensors into the backbone's own parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
backbone.load_state_dict(torch.load(PRETRAINED_PATH, map_location="cpu"))

model = PVTFormerNet(backbone=backbone,attention=True,convmix=False,convMLP=True,MLPmix=False)

In [None]:
# Smoke test: push one random 256x256 RGB image through the model and inspect
# the shape of the first returned map.
x = torch.rand(1,3,256,256)
y = model(x)[0]
y.shape

  return F.conv2d(input, weight, bias, self.stride,


torch.Size([1, 1, 256, 256])

# Metrics

In [None]:
def iou_score_fn(output, target):
    """Return the smoothed intersection-over-union of two binarised tensors.

    Both tensors are moved to the CPU, converted to NumPy and thresholded at
    0.5. No sigmoid is applied here, so ``output`` is compared against 0.5
    exactly as it is passed in. The score is computed over all elements at
    once (not per sample).

    Args:
        output: predicted tensor.
        target: ground-truth tensor.

    Returns:
        ``(intersection + 1e-5) / (union + 1e-5)`` as a NumPy float; two empty
        masks therefore score 1.0.
    """
    eps = 1e-5

    pred_mask = output.detach().cpu().numpy() > 0.5
    true_mask = target.detach().cpu().numpy() > 0.5

    overlap = (pred_mask & true_mask).sum()
    combined = (pred_mask | true_mask).sum()

    return (overlap + eps) / (combined + eps)

def dice_coef_fn(output, target):
    """Return the soft Dice coefficient between two tensors.

    Both tensors are flattened and used as-is: no sigmoid and no thresholding
    is applied, and the score is computed over all elements at once.

    Args:
        output: predicted tensor.
        target: ground-truth tensor with the same number of elements.

    Returns:
        Scalar tensor ``(2 * sum(output * target) + 1) /
        (sum(output) + sum(target) + 1)``.
    """
    smooth = 1

    # reshape (rather than view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose/permute or a slice; for contiguous input it is the
    # same zero-copy view as before.
    output = output.reshape(-1)
    target = target.reshape(-1)
    intersection = (output * target).sum()

    return (2. * intersection + smooth) / (output.sum() + target.sum() + smooth)

class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - Dice`` computed over all flattened elements.

    The inputs are used as-is (no sigmoid is applied inside this module).
    """

    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted but not used by this class.
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - (2*|X.Y| + smooth) / (|X| + |Y| + smooth)``.

        Args:
            inputs: predicted tensor.
            targets: ground-truth tensor with the same number of elements.
            smooth: additive constant in numerator and denominator; keeps the
                ratio defined when both tensors sum to zero.

        Returns:
            Scalar loss tensor.
        """
        # reshape (rather than view) also accepts non-contiguous tensors such
        # as transposed/permuted or sliced inputs.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return 1 - dice

# Shared Dice-loss module used by calc_loss below.
dice_loss = DiceLoss()


def calc_loss(pred, target, bce_weight=0.5):
    """Return a weighted sum of a BCE term and a Dice term.

    The BCE term is computed with ``binary_cross_entropy_with_logits``, i.e.
    ``pred`` is treated as logits there; the Dice term receives the same
    ``pred`` unchanged.

    Args:
        pred: prediction tensor.
        target: ground-truth tensor.
        bce_weight: weight of the BCE term; the Dice term gets
            ``1 - bce_weight``.

    Returns:
        Scalar loss tensor.
    """
    bce_term = F.binary_cross_entropy_with_logits(pred, target)
    dice_term = dice_loss(pred, target)
    return bce_term * bce_weight + dice_term * (1 - bce_weight)

# Train

In [None]:
!pip install pytorch-lightning &> /dev/null
import pytorch_lightning as pl

class Segmentor(pl.LightningModule):
    """LightningModule that trains and evaluates a segmentation network.

    Each step takes the first element of the wrapped model's output as the
    prediction, computes the combined loss via ``calc_loss`` and reports Dice
    and IoU scores via ``dice_coef_fn`` / ``iou_score_fn``.

    Args:
        model: network to wrap. The default is the module-level ``model``
            object as it exists when this class is defined.
    """

    def __init__(self, model = model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return the wrapped model's raw output for ``x``."""
        return self.model(x)

    def get_metrics(self):
        # don't show the version number
        # NOTE(review): this relies on the parent class providing get_metrics;
        # it looks like a progress-bar hook rather than a LightningModule one -
        # confirm it is actually reachable.
        items = super().get_metrics()
        items.pop("v_num", None)
        return items

    def _step(self, batch):
        """Run one forward pass on ``batch`` and compute loss and scores.

        Args:
            batch: ``(image, y_true)`` pair.

        Returns:
            Tuple ``(loss, dice_score, iou_score)``.
        """
        image, y_true = batch
        # Only the first element of the model output is used as prediction.
        y_pred = self.model(image)[0]
        # calc_loss is declared as calc_loss(pred, target): the prediction must
        # come first so that the BCE-with-logits term receives the network
        # output as its input and the mask as its target (previously the two
        # were passed in the opposite order).
        loss = calc_loss(y_pred, y_true.float())
        dice_score = dice_coef_fn(y_pred, y_true)
        iou_score = iou_score_fn(y_pred, y_true)
        return loss, dice_score, iou_score

    def training_step(self, batch, batch_idx):
        """Log epoch-level training metrics and return the loss to optimise."""
        loss, dice_score, iou_score = self._step(batch)
        metrics = {"loss": loss, "train_dice": dice_score, "train_iou": iou_score}
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar = True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation metrics for one batch."""
        loss, dice_score, iou_score = self._step(batch)
        metrics = {"val_loss": loss, "val_dice": dice_score, "val_iou": iou_score}
        self.log_dict(metrics, prog_bar = True)
        return metrics

    def test_step(self, batch, batch_idx):
        """Log and return the test metrics for one batch."""
        loss, dice_score, iou_score = self._step(batch)
        metrics = {"loss":loss, "test_dice": dice_score, "test_iou": iou_score}
        self.log_dict(metrics, prog_bar = True)
        return metrics


    def configure_optimizers(self):
        """Return Adam (lr=1e-5) plus a plateau scheduler driven by ``val_dice``.

        The scheduler halves the learning rate (factor 0.5) when ``val_dice``
        has not improved (mode "max") within the configured patience of 10.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max",
                                                         factor = 0.5, patience=10, verbose =True)
        lr_schedulers = {"scheduler": scheduler, "monitor": "val_dice"}
        return [optimizer], lr_schedulers


In [None]:
import csv
class HistoryLogger(pl.callbacks.Callback):
    """Callback that appends selected epoch metrics to a CSV file.

    Args:
        dir: path of the CSV file to append to (a file path, despite the name).
    """

    def __init__(self, dir = "history_rivf.csv"):
        self.dir = dir

    def on_validation_epoch_end(self, trainer, pl_module):
        """Append one row of metrics once ``loss_epoch`` has been recorded.

        Nothing is written while ``loss_epoch`` is absent from
        ``trainer.callback_metrics``. The header row is written only when the
        file does not exist yet.

        NOTE(review): the ``*_epoch`` key names must match what the
        LightningModule actually logs - confirm against its log calls.
        """
        recorded = trainer.callback_metrics
        if "loss_epoch" not in recorded:
            return

        row = {"epoch": trainer.current_epoch}
        row.update(
            (name, recorded[name].item())
            for name in ("loss_epoch", "train_dice_epoch", "val_loss", "val_dice")
        )

        # Check before opening: append mode creates the file if it is missing.
        needs_header = not os.path.isfile(self.dir)
        with open(self.dir, 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=list(row))
            if needs_header:
                writer.writeheader()
            writer.writerow(row)

In [None]:
# Training launch cell: set up the checkpoint directory, callbacks and Trainer,
# then fit the Segmentor.
os.makedirs('/content/drive/MyDrive/polyp/clinic', exist_ok = True)
# Keep only the single best checkpoint (weights only) ranked by the highest
# val_dice; the score is embedded in the file name.
check_point = pl.callbacks.model_checkpoint.ModelCheckpoint("/content/drive/MyDrive/polyp/clinic", filename="ckpt{val_dice:0.4f}",
                                                            monitor="val_dice", mode = "max", save_top_k =1,
                                                            verbose=True, save_weights_only=True,
                                                            auto_insert_metric_name=False,)
progress_bar = pl.callbacks.TQDMProgressBar()
# CSV history callback; written to its default path.
logger = HistoryLogger()
# Keyword arguments for pl.Trainer. Note "logger": False turns off the
# Trainer's own logger; the `logger` variable above is passed as a callback.
PARAMS = {"benchmark": True, "enable_progress_bar" : True,"logger":False,
        #   "callbacks" : [progress_bar],
        #    "overfit_batches" :1,
          "callbacks" : [check_point, progress_bar, logger],
          "log_every_n_steps" :1, "num_sanity_val_steps":0, "max_epochs":150,
          # 16-bit precision (reported as Automatic Mixed Precision in the run log)
          "precision":16,
          }

trainer = pl.Trainer(**PARAMS)
segmentor = Segmentor(model=model)
# testloader4 is passed in the validation-dataloader position, so val_* metrics
# (and hence checkpoint selection and LR scheduling) are computed on it.
trainer.fit(segmentor, trainloader, testloader4)

  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type         | Params
---------------------------------------
0 | model | PVTFormerNet | 60.4 M
---------------------------------------
60.4 M    Trainable params
0         Non-trainable params
60.4 M    Total params
241.793   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 91: 'val_dice' reached 0.90036 (best 0.90036), saving model to '/content/drive/MyDrive/polyp/clinic/ckpt0.9004.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 182: 'val_dice' reached 0.90208 (best 0.90208), saving model to '/content/drive/MyDrive/polyp/clinic/ckpt0.9021.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 2, global step 273: 'val_dice' reached 0.90830 (best 0.90830), saving model to '/content/drive/MyDrive/polyp/clinic/ckpt0.9083.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 3, global step 364: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 4, global step 455: 'val_dice' reached 0.90834 (best 0.90834), saving model to '/content/drive/MyDrive/polyp/clinic/ckpt0.9083.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 5, global step 546: 'val_dice' reached 0.91288 (best 0.91288), saving model to '/content/drive/MyDrive/polyp/clinic/ckpt0.9129.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 6, global step 637: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 7, global step 728: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 8, global step 819: 'val_dice' reached 0.91327 (best 0.91327), saving model to '/content/drive/MyDrive/polyp/clinic/ckpt0.9133.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 9, global step 910: 'val_dice' reached 0.92270 (best 0.92270), saving model to '/content/drive/MyDrive/polyp/clinic/ckpt0.9227.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 10, global step 1001: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 11, global step 1092: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 12, global step 1183: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 13, global step 1274: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 14, global step 1365: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 15, global step 1456: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 16, global step 1547: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 17, global step 1638: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 18, global step 1729: 'val_dice' reached 0.92658 (best 0.92658), saving model to '/content/drive/MyDrive/polyp/clinic/ckpt0.9266.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 19, global step 1820: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 20, global step 1911: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 21, global step 2002: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 22, global step 2093: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 23, global step 2184: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 24, global step 2275: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 25, global step 2366: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 26, global step 2457: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 27, global step 2548: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 28, global step 2639: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 29, global step 2730: 'val_dice' was not in top 1


Epoch 00030: reducing learning rate of group 0 to 5.0000e-06.


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 30, global step 2821: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 31, global step 2912: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 32, global step 3003: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 33, global step 3094: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 34, global step 3185: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 35, global step 3276: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 36, global step 3367: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 37, global step 3458: 'val_dice' was not in top 1
INFO:pytorch_lightning.utilities.rank_zero:Epoch 38, global step 3549: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 39, global step 3640: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 40, global step 3731: 'val_dice' was not in top 1


Epoch 00041: reducing learning rate of group 0 to 2.5000e-06.


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 41, global step 3822: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 42, global step 3913: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 43, global step 4004: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 44, global step 4095: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 45, global step 4186: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 46, global step 4277: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 47, global step 4368: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 48, global step 4459: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 49, global step 4550: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 50, global step 4641: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 51, global step 4732: 'val_dice' was not in top 1


Epoch 00052: reducing learning rate of group 0 to 1.2500e-06.


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 52, global step 4823: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 53, global step 4914: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 54, global step 5005: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 55, global step 5096: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 56, global step 5187: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 57, global step 5278: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 58, global step 5369: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 59, global step 5460: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 60, global step 5551: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 61, global step 5642: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 62, global step 5733: 'val_dice' was not in top 1


Epoch 00063: reducing learning rate of group 0 to 6.2500e-07.


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 63, global step 5824: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 64, global step 5915: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 65, global step 6006: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 66, global step 6097: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 67, global step 6188: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 68, global step 6279: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 69, global step 6370: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 70, global step 6461: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 71, global step 6552: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 72, global step 6643: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 73, global step 6734: 'val_dice' was not in top 1


Epoch 00074: reducing learning rate of group 0 to 3.1250e-07.


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 74, global step 6825: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 75, global step 6916: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 76, global step 7007: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 77, global step 7098: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 78, global step 7189: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 79, global step 7280: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 80, global step 7371: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 81, global step 7462: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 82, global step 7553: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 83, global step 7644: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 84, global step 7735: 'val_dice' was not in top 1


Epoch 00085: reducing learning rate of group 0 to 1.5625e-07.


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 85, global step 7826: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 86, global step 7917: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 87, global step 8008: 'val_dice' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 88, global step 8099: 'val_dice' was not in top 1
