In [None]:
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

# from timm.models.layers import DropPath, to_2tuple, trunc_normal_
# from timm.models.registry import register_model
# from timm.models.vision_transformer import _cfg
# from mmseg.models.builder import BACKBONES
# from mmseg.utils import get_root_logger
# from mmcv.runner import load_checkpoint
# import math

def drop_path(input, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    Args:
        input: tensor whose first dimension indexes the samples; one keep/drop decision is drawn per sample.
        drop_prob: probability of zeroing a whole sample.
        training: the function is a no-op (returns ``input`` itself) unless this is True.
        scale_by_keep: when True, surviving samples are divided by ``1 - drop_prob``; when False they are
            passed through unscaled.

    Returns:
        ``input`` itself when nothing is dropped, otherwise a new tensor of the same shape.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if scale_by_keep and keep_prob > 0.0:
        return input.div(keep_prob) * random_tensor
    # Either scaling was disabled, or keep_prob == 0 (everything is dropped); dividing by
    # keep_prob in the latter case would produce inf * 0 = NaN instead of zeros.
    return input * random_tensor

class DropPath(nn.Module):
    """Module form of ``drop_path``: per-sample stochastic depth for the main path of residual blocks.

    The stored ``drop_prob`` and the module's own ``training`` flag are forwarded
    to ``drop_path`` on every call.
    """

    def __init__(self, drop_prob = 0.0) -> None:
        super().__init__()
        # Probability handed to drop_path() in forward().
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``drop_path`` to ``hidden_states`` using this module's settings."""
        return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)

    def extra_repr(self) -> str:
        """Show the drop probability in the module's printed representation."""
        return f"p={self.drop_prob}"

class Mlp(nn.Module):
    """Feed-forward sub-layer: linear -> depthwise conv -> activation -> dropout -> linear -> dropout.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer; falls back to ``in_features`` when falsy.
        out_features: size of the last dimension of the output; falls back to ``in_features`` when falsy.
        act_layer: zero-argument callable that builds the activation module.
        drop: dropout probability shared by both dropout applications.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.dwconv = DWConv(hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)
        # No custom weight initialisation is applied; layers keep their PyTorch defaults.

    def forward(self, x, H, W):
        """Run the feed-forward stack; ``H`` and ``W`` are passed through to the depthwise conv."""
        hidden = self.dwconv(self.fc1(x), H, W)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention whose keys/values can be computed on a spatially reduced sequence.

    Args:
        dim: channel width of the tokens; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the query and key/value projections carry a bias.
        qk_scale: explicit scale for the attention logits; when falsy, ``head_dim ** -0.5`` is used.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        sr_ratio: when greater than 1, keys/values are computed from the token map after a
            strided convolution (kernel = stride = ``sr_ratio``) followed by LayerNorm.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale if qk_scale else per_head ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        # Keys and values come out of one projection and are split in forward().
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
        # No custom weight initialisation is applied; layers keep their PyTorch defaults.

    def forward(self, x, H, W):
        """Attend over ``x`` of shape (B, N, C); ``H`` and ``W`` are only used when ``sr_ratio > 1``."""
        B, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads

        # (B, N, C) -> (B, heads, N, head_dim)
        q = self.q(x).reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)

        kv_source = x
        if self.sr_ratio > 1:
            # Fold the tokens back into a (B, C, H, W) map, shrink it, and flatten again.
            token_map = x.permute(0, 2, 1).reshape(B, C, H, W)
            reduced = self.sr(token_map).reshape(B, C, -1).permute(0, 2, 1)
            kv_source = self.norm(reduced)

        # (B, M, 2C) -> (2, B, heads, M, head_dim)
        kv = self.kv(kv_source).reshape(B, -1, 2, heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each wrapped in a residual connection.

    Args:
        dim: channel width of the tokens.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim`` (truncated to int).
        qkv_bias, qk_scale, attn_drop, sr_ratio: forwarded to ``Attention``.
        drop: dropout probability used for the attention output projection and inside the MLP.
        drop_path: stochastic-depth probability; values <= 0 disable it (an ``nn.Identity`` is used).
        act_layer: activation constructor forwarded to ``Mlp``.
        norm_layer: constructor called with ``dim`` to build both normalisation layers.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        # No custom weight initialisation is applied; layers keep their PyTorch defaults.

    def forward(self, x, H, W):
        """Apply both residual sub-layers; ``H`` and ``W`` are passed to the attention and MLP modules."""
        attn_out = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(mlp_out)

def to_2tuple(x):
    """Return ``x`` as a 2-tuple.

    A scalar is repeated to give ``(x, x)``. A tuple or list that already holds
    two values is returned as a tuple unchanged, so callers may pass either an
    int or an explicit ``(a, b)`` pair.

    Raises:
        ValueError: if ``x`` is a tuple/list whose length is not 2.
    """
    if isinstance(x, (tuple, list)):
        if len(x) != 2:
            raise ValueError(f"expected a scalar or a pair, got {len(x)} values")
        return tuple(x)
    return (x, x)

class OverlapPatchEmbed(nn.Module):
    """Image to patch embedding via a strided convolution followed by LayerNorm.

    The convolution uses ``patch_size`` as kernel, ``stride`` as stride and half the
    kernel size as padding, so neighbouring patches overlap whenever ``stride < patch_size``.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        img_h, img_w = self.img_size
        patch_h, patch_w = self.patch_size

        # Bookkeeping only: computed from patch_size (not stride) and not used by forward().
        self.H = img_h // patch_h
        self.W = img_w // patch_w
        self.num_patches = self.H * self.W

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=stride,
            padding=(patch_h // 2, patch_w // 2),
        )
        self.norm = nn.LayerNorm(embed_dim)
        # No custom weight initialisation is applied; layers keep their PyTorch defaults.

    def forward(self, x):
        """Embed ``x`` and return ``(tokens, H, W)`` where ``H``/``W`` are the conv output's spatial size."""
        feature_map = self.proj(x)
        H, W = feature_map.shape[2:]
        # (B, C, H, W) -> (B, H*W, C)
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.norm(tokens), H, W


class MixVisionTransformer(nn.Module):
    """Four-stage hierarchical transformer encoder that returns one feature map per stage.

    Each stage is an ``OverlapPatchEmbed``, a list of ``Block`` modules and a final
    normalisation layer. The per-stage lists (``embed_dims``, ``num_heads``,
    ``mlp_ratios``, ``depths``, ``sr_ratios``) are indexed 0..3.

    Notes:
        * ``patch_size`` is accepted but not used; the stages use fixed kernel sizes 7, 3, 3, 3.
        * ``num_classes`` is only stored; no classification head is built.
        * No custom weight initialisation is applied; layers keep their PyTorch defaults.
    """

    def __init__(self, img_size=512, patch_size=5, in_chans=3, num_classes=1000, embed_dims=[32, 64, 160, 256],
                 num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
                 depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        # Patch embeddings: stage 1 strides by 4, every later stage by 2.
        self.patch_embed1 = OverlapPatchEmbed(
            img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(
            img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(
            img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(
            img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])

        # Stochastic depth decay rule: the drop-path rate grows linearly over all blocks.
        rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        def make_stage(stage, offset):
            # ``offset`` is the number of blocks in earlier stages, i.e. the index into ``rates``.
            return nn.ModuleList([
                Block(
                    dim=embed_dims[stage], num_heads=num_heads[stage], mlp_ratio=mlp_ratios[stage],
                    qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                    drop_path=rates[offset + k], norm_layer=norm_layer, sr_ratio=sr_ratios[stage])
                for k in range(depths[stage])
            ])

        offset = 0
        self.block1 = make_stage(0, offset)
        self.norm1 = norm_layer(embed_dims[0])

        offset += depths[0]
        self.block2 = make_stage(1, offset)
        self.norm2 = norm_layer(embed_dims[1])

        offset += depths[1]
        self.block3 = make_stage(2, offset)
        self.norm3 = norm_layer(embed_dims[2])

        offset += depths[2]
        self.block4 = make_stage(3, offset)
        self.norm4 = norm_layer(embed_dims[3])

    def forward_features(self, x):
        """Run all four stages and return their outputs as a list of (B, C_i, H_i, W_i) tensors."""
        B = x.shape[0]
        stages = (
            (self.patch_embed1, self.block1, self.norm1),
            (self.patch_embed2, self.block2, self.norm2),
            (self.patch_embed3, self.block3, self.norm3),
            (self.patch_embed4, self.block4, self.norm4),
        )

        outs = []
        for patch_embed, blocks, norm in stages:
            x, H, W = patch_embed(x)
            for blk in blocks:
                x = blk(x, H, W)
            x = norm(x)
            # Tokens (B, H*W, C) back to a feature map (B, C, H, W) for the next stage.
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs

    def forward(self, x):
        """Return the list of per-stage feature maps (there is no classification head)."""
        return self.forward_features(x)


class DWConv(nn.Module):
    """3x3 depthwise convolution applied to a token sequence viewed as a 2-D map."""

    def __init__(self, dim=768):
        super().__init__()
        # groups=dim gives every channel its own 3x3 filter; padding=1 keeps H and W unchanged.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Map tokens (B, N, C) with N == H * W through the conv and back to (B, N, C)."""
        batch, _, channels = x.shape
        feature_map = x.transpose(1, 2).view(batch, channels, H, W)
        return self.dwconv(feature_map).flatten(2).transpose(1, 2)

# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------


class head_MLP(nn.Module):
    """
    Linear Embedding

    Flattens every dimension after the channel axis into a token axis and
    projects each token from ``input_dim`` to ``embed_dim`` channels.
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, C) so the linear layer acts on the channel axis.
        tokens = x.flatten(2).transpose(1, 2)
        return self.proj(tokens)

class ConvModule(nn.Module):
    """Bias-free 1x1 convolution from ``4 * embedding_dim`` to ``embedding_dim`` channels, then BatchNorm and ReLU."""

    def __init__(self, embedding_dim=2048):
        super().__init__()
        # The input is expected to carry four times as many channels as the output.
        self.conv = nn.Conv2d(embedding_dim * 4, embedding_dim, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(embedding_dim)
        self.activation = nn.ReLU()

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return self.activation(normalized)


# @HEADS.register_module()
class SegFormerHead(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

    Decode head: each of the four input feature maps is linearly embedded to
    ``embed_dim`` channels, the three coarser ones are bilinearly resized to the
    spatial size of the first, all four are concatenated and fused, and a 1x1
    convolution produces ``num_class`` logits.
    """
    def __init__(self,in_channels=[32, 64, 160, 256],embed_dim=256,num_class=19):
        super().__init__()
        self.num_classes = num_class
        self.in_channels = in_channels
        c1_dim, c2_dim, c3_dim, c4_dim = in_channels

        # Built coarsest-first, matching the concatenation order used in forward().
        self.linear_c4 = head_MLP(input_dim=c4_dim, embed_dim=embed_dim)
        self.linear_c3 = head_MLP(input_dim=c3_dim, embed_dim=embed_dim)
        self.linear_c2 = head_MLP(input_dim=c2_dim, embed_dim=embed_dim)
        self.linear_c1 = head_MLP(input_dim=c1_dim, embed_dim=embed_dim)

        self.linear_fuse = ConvModule(embedding_dim=embed_dim)

        self.dropout = nn.Dropout(0.1)
        self.linear_pred = nn.Conv2d(embed_dim, num_class, kernel_size=1)

    def forward(self, x):
        """Return ``(logits, fused)`` where ``fused`` is the fused feature map before dropout."""
        c1, c2, c3, c4 = x
        batch = c4.shape[0]
        target_size = c1.size()[2:]

        def embed(layer, feat):
            # Token output (B, H*W, E) back to a map (B, E, H, W) at the feature's own resolution.
            return layer(feat).permute(0, 2, 1).reshape(batch, -1, feat.shape[2], feat.shape[3])

        fused_inputs = []
        for layer, feat in ((self.linear_c4, c4), (self.linear_c3, c3), (self.linear_c2, c2)):
            fused_inputs.append(
                nn.functional.interpolate(embed(layer, feat), size=target_size, mode='bilinear', align_corners=False)
            )
        # c1 already has the target resolution, so it is not resized.
        fused_inputs.append(embed(self.linear_c1, c1))

        _c = self.linear_fuse(torch.cat(fused_inputs, dim=1))

        logits = self.linear_pred(self.dropout(_c))
        return logits, _c

class Segformer(nn.Module):
    """Segmentation model: a default ``MixVisionTransformer`` backbone feeding a default ``SegFormerHead``."""

    def __init__(self):
        super().__init__()
        self.backbone = MixVisionTransformer()
        self.decode_head = SegFormerHead()

    def forward(self,x):
        """Return the decode head's ``(logits, fused_layer)`` pair for the input batch ``x``."""
        features = self.backbone(x)
        logits, fused_layer = self.decode_head(features)
        return logits, fused_layer


In [None]:
import cv2
import numpy as np
import time
import sys
import tkinter as tk

# np.set_printoptions(threshold=sys.maxsize)

class fisheye(object):
    """Builds per-pixel sampling maps that warp an ordinary image into a fisheye view.

    The maps are returned as two float32 arrays ``(map_x, map_y)`` giving, for
    every output pixel, the source coordinates to sample (the form expected by
    e.g. ``cv2.remap``). Positions whose geometry is invalid come out as NaN.

    The attributes set in ``__init__`` are tunable parameters of the projection;
    callers may overwrite them before requesting a map.
    """

    def __init__(self) -> None:
        self.pitch = -0.4  # pitch angle (radians) of the fisheye camera
        self.yaw=0  # yaw angle (radians)
        self.trans_x=1
        self.trans_y=1
        self.f0 = 700  # source focal length, defined for a 512-wide image and rescaled in model1
        self.zoom = 0.5
        self.position_y = 1
        self.position_x = 1
        self.size_x = 256  # overwritten by model1 with half of the image size
        self.size_y = 256
        self.model=1

    def norm2fisheye(self,img,label=False):
        """Return ``(map_x, map_y)`` for an image with the shape of ``img``.

        Only ``img.shape`` is used; ``label`` is accepted for interface
        compatibility and has no effect on the maps.
        """
        map_x,map_y = self.model1(img,size=img.shape,label=label)
        return map_x,map_y


    def model1(self,img,size=(512,512,3),label=False):
        """Compute the sampling maps for an image of the given ``size``.

        Args:
            img: unused; the maps depend only on ``size`` and the instance attributes.
            size: image shape. Only the first two entries are read, so both
                ``(a, b, channels)`` and plain 2-D ``(a, b)`` shapes work.
            label: unused.

        Returns:
            ``(map_x, map_y)``: two float32 arrays of shape ``(2*int(size[1]/2), 2*int(size[0]/2))``.

        Side effects:
            Updates ``self.size_x`` and ``self.size_y`` to half of the two sizes.
        """
        # Orthographic (author's note on the projection model).
        # NOTE(review): the first shape entry is treated as the width and the second as the
        # height, which is the reverse of numpy's (rows, cols) order. The two agree for the
        # square images used in this file -- confirm before feeding non-square inputs.
        w, h = size[0], size[1]
        self.size_x,self.size_y = w/2,h/2
        f0 = w/512*self.f0 #f0 gets bigger, distortion gets smaller
        fc = int(self.zoom*w/np.sin(np.arctan(w/(2*f0))))  # fisheye focal length
        trans_x = self.trans_x
        trans_y = self.trans_y
        yaw = self.yaw
        pitch = self.pitch
        rx = int(self.size_x) # half image size
        ry = int(self.size_y)

        ## build the transform map: output pixel grid, shifted so the sphere is centred
        u=np.linspace(0,2*rx,2*rx)
        v=np.linspace(0,2*ry,2*ry)
        udst,vdst = np.meshgrid(u,v)
        v,u = vdst-self.position_y*ry+fc*np.sin(pitch)+450*fc*(1-trans_y)/f0 ,\
            udst-self.position_x*rx+fc*np.sin(yaw)+250*w/512*fc*(1-trans_x)/f0 #get proxy

        # rotate the fisheye sphere (pitch first, then yaw)
        r1 = np.sqrt(fc**2-u**2)
        yc = r1*np.sin(np.arcsin(v/r1)-pitch)

        r1 = np.sqrt(fc**2-yc**2)
        filter2 = np.arcsin(u/r1)-yaw
        xc = r1*np.sin(filter2)

        # convert the proxy into raw image coordinates
        r = np.sqrt(xc**2+yc**2)
        r0 = f0*np.tan(np.arcsin(r/fc))
        p_theta = np.arctan2(yc,xc)
        x,y = r0*np.cos(p_theta),r0*np.sin(p_theta)

        map_x = x+w/2*trans_x
        map_y = y+h/2*trans_y
        map_y = np.array(map_y,dtype=np.float32)
        map_x = np.array(map_x,dtype=np.float32)
        return map_x,map_y


In [None]:
def calculate_metrics(pr,gt,eps=1e-7):
    """Compute per-class and mean IoU / accuracy from per-channel prediction and ground-truth maps.

    ``pr`` and ``gt`` are indexed as ``[:, class, :, :]`` and multiplied element-wise,
    so they are presumably one-hot (or soft) masks of identical shape -- confirm with the caller.
    Channel 0 is always skipped, and so is any class with no ground-truth mass.

    Args:
        pr: prediction tensor, classes on dim 1.
        gt: ground-truth tensor, same layout as ``pr``.
        eps: smoothing term added to numerators and denominators.

    Returns:
        dict with ``"miou"`` and ``"macc"`` (means over the classes that were evaluated)
        and ``'per_iou'`` / ``'per_acc'`` (dicts keyed by class index). When no class is
        evaluated, both means are a zero tensor and both dicts are empty.
    """
    IoUs = {}
    acc = {}
    for i in range(1, pr.size()[1]):
        gt_i = gt[:, i, :, :]
        gt_total = torch.sum(gt_i)
        if gt_total == 0:
            # Class absent from the ground truth: it would only distort the means.
            continue
        pr_i = pr[:, i, :, :]
        intersection = torch.sum(gt_i * pr_i)
        union = gt_total + torch.sum(pr_i) - intersection + eps
        IoUs[i] = (intersection + eps) / union
        acc[i] = (intersection + eps) / (gt_total + eps)

    if not IoUs:
        # Nothing to average (e.g. a batch containing only channel-0 pixels); dividing by
        # len(IoUs) here would raise ZeroDivisionError.
        zero = torch.zeros((), device=pr.device)
        return {"miou":zero,"macc":zero.clone(),'per_iou':IoUs,'per_acc':acc}

    mIoUs = sum(IoUs.values()) / len(IoUs)
    macc = sum(acc.values()) / len(acc)

    return {"miou":mIoUs,"macc":macc,'per_iou':IoUs,'per_acc':acc}


In [None]:
from torch.utils.data import Dataset
import os
# from PIL import Image
from transformers import SegformerFeatureExtractor
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
import pytorch_lightning as pl
from transformers import (SegformerForSemanticSegmentation,
        SegformerModel,
        SegformerDecodeHead,
        get_constant_schedule_with_warmup,
        get_linear_schedule_with_warmup,
        get_cosine_with_hard_restarts_schedule_with_warmup)
import torch
from torch import nn
import time
import numpy as np
import sys
import cv2
# import PIL
from torchvision.datasets import Cityscapes
# from norm2fisheye import fisheye
import torch.nn.functional as F
# from meaniou import meanIOU
# from metrics import calculate_metrics
# from .model.segformer import Segformer

class mydataset(Cityscapes):
    """Cityscapes dataset that yields feature-extractor inputs plus fisheye sampling maps.

    Each item is ``(inputs, map_x, map_y)``: ``inputs`` is the output of the SegFormer
    feature extractor for the resized image and its encoded label map, and
    ``map_x`` / ``map_y`` are the float32 maps produced by ``fisheye.norm2fisheye``.

    Args:
        root, split, mode, target_type: forwarded to ``torchvision.datasets.Cityscapes``.
        test_mode: when False, the fisheye parameters are randomised per item.
    """

    def __init__(self,root,split,mode,target_type,test_mode=False):
        super().__init__(root,split,mode,target_type)
        # __getitem__ calls self.feature_extractor, so it has to exist on the instance.
        self.feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")
        self.test_mode = test_mode
        # Raw label ids that are mapped to the ignore value 255.
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # Raw label ids that are kept; each is mapped to its position in this list.
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.class_map = dict(zip(self.valid_classes, range(len(self.valid_classes))))

    def encode_segmap(self,mask):
        """Remap raw label ids in ``mask`` in place (void -> 255, valid -> 0..18) and return it."""
        for _voidc in self.void_classes:
            mask[mask == _voidc] = 255
        # Safe to do in place: valid ids are visited in increasing order and every new
        # index is smaller than the id it replaces, so no pixel is remapped twice.
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask

    def __getitem__(self,idx):
        """Load item ``idx`` and return ``(inputs, map_x, map_y)``.

        Raises:
            FileNotFoundError: if the image or the label file cannot be read.
        """
        image_path = self.images[idx]
        target_path = self.targets[idx][0]
        image = cv2.imread(image_path)
        target = cv2.imread(target_path)
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        if target is None:
            raise FileNotFoundError(f"could not read target: {target_path}")

        image = cv2.resize(image,dsize=(512,512))
        # Labels are categorical ids: nearest-neighbour keeps them intact, whereas the
        # default bilinear interpolation would blend neighbouring ids into invalid ones.
        target = cv2.resize(target,dsize=(512,512),interpolation=cv2.INTER_NEAREST)
        target = self.encode_segmap(target)[:,:,0]

        fi = fisheye()
        if not self.test_mode:
            # Random fisheye parameters per item.
            fi.f0 = np.random.randint(200,600)
            fi.pitch = np.random.uniform(-0.8,0)
            fi.trans_x = np.random.uniform(0.5,1.5)
        map_x,map_y = fi.norm2fisheye(image)

        inputs = self.feature_extractor(image,target,return_tensors='pt')
        # Drop the leading batch dimension added by the extractor so DataLoader can re-batch.
        for value in inputs.values():
            value.squeeze_()

        return inputs,map_x,map_y

def convert2fisheye(imgs,grid,label=False):
    """Warp a batch of images or label maps with ``F.grid_sample``.

    Args:
        imgs: for images, a channels-last batch (B, H, W, C) that is permuted to
            channels-first before sampling; for labels, a batch (B, H, W) that gets a
            singleton channel dimension.
        grid: sampling grid (B, H_out, W_out, 2) in normalised [-1, 1] coordinates.
        label: False for bilinear sampling of images, True for nearest-neighbour
            sampling of label maps.

    Returns:
        Channels-first tensor: (B, C, H_out, W_out) for images, (B, 1, H_out, W_out)
        for labels. Labels are returned in the dtype they were passed in.
    """
    # align_corners=False is grid_sample's default; stating it avoids the runtime warning.
    if not label:
        return F.grid_sample(imgs.permute(0,3,1,2),grid=grid,mode='bilinear',align_corners=False)

    # grid_sample needs the input dtype to match the (floating point) grid, so integer
    # label maps are sampled in the grid's dtype and cast back afterwards. Nearest-neighbour
    # sampling copies values, so the round trip is exact for ordinary label ids.
    sampled = F.grid_sample(imgs.unsqueeze(dim=1).to(grid.dtype),grid=grid,mode='nearest',align_corners=False)
    return sampled.to(imgs.dtype)

def mse_loss(input, target, ignored_index=255, reduction='mean'):
    """Squared error between ``input`` and ``target`` over positions where ``target != ignored_index``.

    Args:
        input: prediction tensor, same shape as ``target``.
        target: reference tensor; positions equal to ``ignored_index`` are excluded.
        ignored_index: target value marking positions to skip.
        reduction: ``"mean"`` or ``"sum"`` for a scalar, or ``"None"`` / ``"none"`` / ``None``
            for the flat tensor of per-element squared errors.

    Raises:
        ValueError: for any other ``reduction`` value.
    """
    keep = target != ignored_index
    out = (input[keep] - target[keep]) ** 2
    if reduction == "mean":
        return out.mean()
    if reduction == "sum":
        return out.sum()
    if reduction is None or reduction in ("None", "none"):
        return out
    raise ValueError(f"unsupported reduction: {reduction!r}")

class SegformerFinetuner(pl.LightningModule):
    """Train a student ``Segformer`` on fisheye-warped inputs with distillation.

    A frozen teacher ``Segformer`` sees the unwarped images; its hidden
    features are warped with the same sampling grid as the inputs and the
    student, which sees the warped images, is trained with
    ``0.7 * cross-entropy + 0.3 * masked MSE`` against them.

    Batches are indexed as ``(inputs, map_x, map_y)`` where ``inputs`` holds
    ``'pixel_values'`` and ``'labels'``.
    """

    IGNORE_LABEL = 255  # excluded from the CE loss and used to pad warped masks

    def __init__(self, learning_rate= 2e-4,train_dataloader=None, val_dataloader=None, test_dataloader=None, metrics_interval=100):
        super().__init__()
        self.metrics_interval = metrics_interval
        self.feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-cityscapes-1024-1024")
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.test_dl = test_dataloader
        self.learning_rate = learning_rate
        self.num_classes = 19
        self.CEloss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_LABEL)
        # Teacher: every parameter frozen, so only the student is optimised.
        self.t_model = Segformer()
        for param in self.t_model.parameters():
            param.requires_grad = False
        self.s_model = Segformer()

    def _distill_step(self, batch):
        """Warp the batch, run teacher and student, and return
        ``(loss, student_logits, warped_masks)``; shared by train and val."""
        images, masks = batch[0]['pixel_values'], batch[0]['labels']
        device = images.device
        map_x = torch.stack(
            [torch.as_tensor(m, dtype=torch.float32, device=device) for m in batch[1]], dim=0
        )
        map_y = torch.stack(
            [torch.as_tensor(m, dtype=torch.float32, device=device) for m in batch[2]], dim=0
        )

        # ``images`` is used channels-first everywhere below (it is permuted to
        # channels-last for convert2fisheye and fed straight to the models), so
        # height/width are its last two axes.  The previous ``B,w,h,C``
        # unpacking bound ``w`` to the channel axis.
        # NOTE(review): assumes map_x/map_y are pixel coordinates in the frame
        # of ``images`` - confirm against the dataset's resize settings.
        h, w = images.shape[-2:]
        grid = torch.stack((map_x / (w / 2) - 1, map_y / (h / 2) - 1), dim=3)

        fimages = convert2fisheye(imgs=images.permute(0, 2, 3, 1), grid=grid)
        # Shift labels so the zero padding written by grid_sample lands on the
        # ignore label once shifted back.  grid_sample needs a floating input,
        # so the labels are sampled as floats and restored to integers.
        shifted = (masks - self.IGNORE_LABEL).float()
        warped = convert2fisheye(imgs=shifted, grid=grid, label=True)
        masks = (warped.squeeze(dim=1) + self.IGNORE_LABEL).long()

        # NOTE(review): Lightning switches the whole module, teacher included,
        # to train mode; call self.t_model.eval() if the teacher has
        # dropout-like layers.
        with torch.no_grad():  # teacher is frozen, no graph is needed
            _, t_hidden = self.t_model(images)
        s_outputs, s_hidden = self.s_model(fimages)

        t_hidden = nn.functional.interpolate(
            t_hidden, size=masks.shape[-2:], mode="bilinear", align_corners=False
        )
        # interpolate above already treats t_hidden as (N, C, H, W), so it is
        # warped directly; the former extra permute(0,3,1,2) moved the width
        # axis into the channel position.
        t_hidden = F.grid_sample(t_hidden, grid=grid, mode='bilinear', align_corners=False)
        t_hidden = nn.functional.interpolate(
            t_hidden, size=s_hidden.shape[-2:], mode="bilinear", align_corners=False
        )

        ce_loss = self.CEloss(s_outputs, masks)
        # Teacher values equal to 0 (including the warp padding) are skipped.
        # ``ignored_index`` is the keyword mse_loss declares.
        loss = 0.7 * ce_loss + 0.3 * mse_loss(s_hidden, t_hidden, ignored_index=0)
        return loss, s_outputs, masks

    def training_step(self, batch, batch_nb):
        """Return the distillation loss for one training batch."""
        loss, _, _ = self._distill_step(batch)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        """Log the mean training loss of the epoch under ``'loss'``."""
        avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log('loss', avg_train_loss, on_epoch=True, on_step=False)

    def validation_step(self, batch, batch_nb):
        """Return loss, mIoU and mean accuracy for one validation batch."""
        loss, s_outputs, masks = self._distill_step(batch)

        upsampled_logits = nn.functional.interpolate(
            s_outputs, size=masks.shape[-2:], mode="bilinear", align_corners=False
        )
        predicted = upsampled_logits.argmax(dim=1)

        # F.one_hot rejects ids >= num_classes, so ignored pixels are routed to
        # an extra class that is sliced off again; they end up all-zero in the
        # target and are zeroed in the prediction as well.
        # NOTE(review): calculate_metrics is defined elsewhere - confirm that
        # all-zero pixels are how it expects ignored positions.
        valid = masks != self.IGNORE_LABEL
        target_1h = F.one_hot(
            masks.masked_fill(~valid, self.num_classes), num_classes=self.num_classes + 1
        )[..., :self.num_classes].permute(0, 3, 1, 2)
        pred_1h = F.one_hot(predicted, num_classes=self.num_classes).permute(0, 3, 1, 2)
        pred_1h = pred_1h * valid.unsqueeze(1)

        cm = calculate_metrics(pred_1h, target_1h)
        return {'val_loss': loss, 'val_miou': cm['miou'], 'mean_acc': cm['macc']}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation results and log them."""
        avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_val_miou = torch.stack([x["val_miou"] for x in outputs]).mean()
        avg_val_macc = torch.stack([x["mean_acc"] for x in outputs]).mean()
        metrics = {"val_loss": avg_val_loss, "val_mean_iou": avg_val_miou, "val_mean_acc": avg_val_macc}

        for k, v in metrics.items():
            self.log(k, v)
        return {'avg_val_loss': avg_val_loss}

    def test_step(self, batch, batch_nb):
        # NOTE(review): this step follows a different contract from the
        # train/val steps and is left unchanged: it indexes ``batch`` as a
        # dict, calls ``self(...)`` although this class defines no ``forward``,
        # and uses ``self.test_mean_iou`` which ``__init__`` never creates.
        images, masks = batch['pixel_values'], batch['labels']

        outputs = self(images, masks)

        loss, logits = outputs[0], outputs[1]

        upsampled_logits = nn.functional.interpolate(
            logits,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False
        )

        predicted = upsampled_logits.argmax(dim=1)

        self.test_mean_iou.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy()
        )

        return {'test_loss': loss}

    def test_epoch_end(self, outputs):
        # NOTE(review): depends on ``self.test_mean_iou`` (see test_step).
        metrics = self.test_mean_iou.compute(
            num_labels=self.num_classes,
            ignore_index=0,
            reduce_labels=False,
        )

        avg_test_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        test_mean_iou = metrics["mean_iou"]
        test_mean_accuracy = metrics["mean_accuracy"]

        metrics = {"test_loss": avg_test_loss, "test_mean_iou": test_mean_iou, "test_mean_accuracy": test_mean_accuracy}

        for k, v in metrics.items():
            self.log(k, v)

        return metrics

    def train_dataloader(self):
        """Return the training loader passed to the constructor."""
        return self.train_dl

    def val_dataloader(self):
        """Return the validation loader passed to the constructor."""
        return self.val_dl

    def test_dataloader(self):
        """Return the test loader passed to the constructor."""
        return self.test_dl

    def configure_optimizers(self):
        """Adam over the trainable parameters with a 100-step warmup, then a
        constant learning rate; the scheduler is stepped every training step."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable, lr=self.learning_rate)
        scheduler = get_constant_schedule_with_warmup(
            optimizer,
            num_warmup_steps=100,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

        return [optimizer], [scheduler]

if __name__ == "__main__":
    # Script entry point: build the Cityscapes train/val/test loaders, wrap
    # them in a SegformerFinetuner and fit it on one GPU.

    torch.set_float32_matmul_precision('medium')

    feature_extractor = SegformerFeatureExtractor.from_pretrained(
        "nvidia/segformer-b0-finetuned-cityscapes-1024-1024"
    )

    cityscapes_root = '/mnt/hdd/dataset/cityscapes/extracted/'
    dataset_kwargs = dict(mode='fine', target_type='semantic')
    train_dataset = mydataset(cityscapes_root, split='train', **dataset_kwargs)
    # Validation and test run with test_mode=True.
    val_dataset = mydataset(cityscapes_root, split='val', test_mode=True, **dataset_kwargs)
    test_dataset = mydataset(cityscapes_root, split='test', test_mode=True, **dataset_kwargs)

    batch_size = 10
    loader_kwargs = dict(batch_size=batch_size, num_workers=5)
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_dataloader = DataLoader(val_dataset, **loader_kwargs)
    test_dataloader = DataLoader(test_dataset, **loader_kwargs)

    segformer_finetuner = SegformerFinetuner(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader,
    )

    lr_monitor = LearningRateMonitor()

    # Constructed but not handed to the Trainer below.
    early_stop_callback = EarlyStopping(
        monitor="avg_val_loss", min_delta=0.00, patience=7, verbose=False, mode="min"
    )

    # Keep every checkpoint (weights only), written every 750 training steps.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=-1, save_weights_only=True, every_n_train_steps=750
    )

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=[2],
        callbacks=[lr_monitor, checkpoint_callback],
        max_epochs=100,
        check_val_every_n_epoch=10,
    )
    trainer.fit(segformer_finetuner)
# function ConnectButton(){
#     console.log("Connect pushed"); 
#     document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click() 
# }

# setInterval(ConnectButton,60000);
