In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the ISIC-2018 input directory.
for root, _subdirs, names in os.walk('/kaggle/input/isisc-2018/isic2018'):
    for name in names:
        full_path = os.path.join(root, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install thop

In [None]:
pip install scikit-learn==1.2.2


In [None]:
# For CUDA 12.1 (common in Kaggle as of 2024)
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1.0/index.html

In [None]:

#dataset imported 
import os

import cv2
import numpy as np
import torch
import torch.utils.data


class Dataset(torch.utils.data.Dataset):
    """Segmentation dataset that reads one image and one grayscale mask per id.

    For an id ``x`` the image is read from ``img_dir/x + img_ext`` and the mask
    from ``mask_dir/x + mask_ext``.

    Args:
        img_ids: Sequence of file stems (without extension).
        img_dir: Directory holding the images.
        mask_dir: Directory holding the masks.
        img_ext: Image file extension, including the dot.
        mask_ext: Mask file extension, including the dot.
        num_classes: Stored on the instance; not otherwise used by this class.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        train: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None,train = True):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        self.num_classes = num_classes
        self.transform = transform
        self.train = train

    def __len__(self):
        """Number of samples (one per image id)."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Load, optionally augment, and normalise sample ``idx``.

        Returns:
            ``(img, mask, meta)`` where ``img`` is a float32 array in
            channel-first layout min-max scaled to [0, 1], ``mask`` is a float32
            channel-first array equal to the raw mask divided by 255, and
            ``meta`` is ``{'img_id': <id>}``.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_dir, img_id + self.img_ext)
        mask_path = os.path.join(self.mask_dir, img_id + self.mask_ext)

        # cv2.imread signals failure by returning None instead of raising, so
        # check explicitly to get an error that names the offending file.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError('Could not read image file: ' + img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError('Could not read mask file: ' + mask_path)
        # Add a trailing channel axis so the mask is (H, W, 1).
        mask = mask[..., None]

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']

        # Per-image min-max scaling. A constant image has zero range; return
        # zeros in that case instead of dividing by zero (which produced NaNs).
        img = img.astype('float32')
        lowest = np.min(img)
        value_range = np.max(img) - lowest
        if value_range > 0:
            img = (img - lowest) / value_range
        else:
            img = np.zeros_like(img)
        img = img.transpose(2, 0, 1)

        mask = mask.astype('float32') / 255.
        mask = mask.transpose(2, 0, 1)

        return img, mask, {'img_id': img_id}

In [None]:
def count_params(model):
    """Return the number of scalar parameters in ``model`` that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


In [None]:
#metrics
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix

from sklearn.metrics import confusion_matrix

def iou_score1(output, target):
    """IoU and Dice of two masks binarised at 0.5, with no sigmoid applied to ``output``.

    Tensors are moved to CPU and converted to NumPy first; other inputs are
    used as given. Dice is derived from IoU as ``2 * iou / (iou + 1)``.

    Returns:
        ``(iou, dice)``.
    """
    eps = 1e-5  # keeps the ratio defined when both masks are empty

    def _to_numpy(value):
        if torch.is_tensor(value):
            return value.data.cpu().numpy()
        return value

    pred_mask = _to_numpy(output) >= 0.5
    true_mask = _to_numpy(target) >= 0.5

    overlap = (pred_mask & true_mask).sum()
    combined = (pred_mask | true_mask).sum()

    iou = (overlap + eps) / (combined + eps)
    dice = (2 * iou) / (iou + 1)
    return iou, dice

def iou_score(output, target):
    """IoU and Dice between a prediction and a target, both binarised at 0.5.

    A tensor ``output`` is passed through a sigmoid before thresholding (i.e.
    it is treated as logits); a non-tensor ``output`` is thresholded as given.
    Dice is derived from IoU as ``2 * iou / (iou + 1)``.

    Returns:
        ``(iou, dice)``.
    """
    eps = 1e-5  # keeps the ratio defined when both masks are empty

    prediction = output
    if torch.is_tensor(prediction):
        prediction = torch.sigmoid(prediction).data.cpu().numpy()

    reference = target
    if torch.is_tensor(reference):
        reference = reference.data.cpu().numpy()

    pred_mask = prediction >= 0.5
    true_mask = reference >= 0.5

    n_overlap = (pred_mask & true_mask).sum()
    n_union = (pred_mask | true_mask).sum()

    iou = (n_overlap + eps) / (n_union + eps)
    dice = (2 * iou) / (iou + 1)
    return iou, dice




def dice_coef(output, target):
    """Soft Dice coefficient between ``sigmoid(output)`` and ``target``.

    Both tensors are flattened and moved to CPU/NumPy; no thresholding is
    applied, so the sigmoid probabilities are used directly.
    """
    eps = 1e-5

    probs = torch.sigmoid(output).view(-1).data.cpu().numpy()
    truth = target.view(-1).data.cpu().numpy()

    overlap = (probs * truth).sum()
    numerator = 2. * overlap + eps
    denominator = probs.sum() + truth.sum() + eps
    return numerator / denominator


In [None]:
#AverageMeter Class 
class AverageMeter(object):
    """Tracks the latest value and the running weighted average of a metric.

    Attributes (all reset to 0 by ``reset``):
        val: Most recent value passed to ``update``.
        sum: Sum of ``val * n`` over all updates.
        count: Sum of ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against an update with n=0 (e.g. an empty batch) before any
        # samples have been counted, which previously raised ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

In [None]:
#Argparser 
import argparse
import torch.nn as nn



def str2bool(v):
    """Convert a command-line style value to ``bool`` (argparse ``type=`` helper).

    Accepts ``'true'``/``'1'`` and ``'false'``/``'0'`` (case-insensitive,
    surrounding whitespace ignored), the integers ``1``/``0``, and real bools.

    Raises:
        argparse.ArgumentTypeError: For any other value.
    """
    # A real bool passes straight through (it has no .lower()).
    if isinstance(v, bool):
        return v

    # Compare as text: the original lists mixed the int 1/0 with strings, so
    # the strings '1' and '0' could never match a lowercased string input.
    token = str(v).strip().lower()
    if token in ('true', '1'):
        return True
    if token in ('false', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

In [None]:
#Architecture UCM
import torch


import torchvision

import torch.nn as nn

import math


from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
#from utils import *
import timm
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import types

from abc import ABCMeta, abstractmethod
from mmcv.cnn import ConvModule
import pdb

__all__ = ['UCM_Net']





class LayerNorm(nn.Module):
    r"""Layer normalisation over the channel axis for either memory layout.

    From ConvNeXt (https://arxiv.org/pdf/2201.03545.pdf).

    ``data_format="channels_last"`` normalises the last dimension via
    ``F.layer_norm``; ``"channels_first"`` normalises dimension 1 manually and
    broadcasts the affine parameters over two trailing spatial dimensions.

    Raises:
        NotImplementedError: If ``data_format`` is neither of the two options.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format != "channels_last" and data_format != "channels_first":
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # Biased variance over the channel axis, matching layer_norm.
            mean = x.mean(1, keepdim=True)
            centred = x - mean
            variance = centred.pow(2).mean(1, keepdim=True)
            normalised = centred / torch.sqrt(variance + self.eps)
            scale = self.weight[:, None, None]
            shift = self.bias[:, None, None]
            return scale * normalised + shift


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (default 1). Previously this argument was
            accepted but ignored because ``stride=1`` was hard-coded.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


    

class UCMBlock(nn.Module):
    """Token-mixing block: norm -> Linear -> DWConv -> GELU -> Linear -> DWConv, plus a skip.

    Operates on a token sequence of shape ``(B, N, C)`` with ``N == H * W``
    and returns a tensor of the same shape.

    Args:
        dim: Channel width ``C`` of the input and output tokens.
        mlp_ratio: Hidden width multiplier; the hidden width is ``int(dim * mlp_ratio)``.
        drop: Dropout probability used by the shared ``nn.Dropout``.
        drop_path: Stochastic-depth rate; ``nn.Identity`` is used when it is 0.
        act_layer: Instantiated as ``self.act``, which ``forward`` does not call.
        norm_layer: Factory for the input normalisation (``norm_layer(dim)``).
        num_heads, qkv_bias, qk_scale, attn_drop, sr_ratio, shift_size:
            Accepted for signature compatibility; not used by this block.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, shift_size=5):
        super().__init__()

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

        self.dim = dim
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, mlp_hidden_dim)
        self.dwconv = DWConv(mlp_hidden_dim)
        # dwconv1 runs after fc2, i.e. on `dim` channels. It was previously
        # sized with mlp_hidden_dim, which only matched when mlp_ratio == 1.
        self.dwconv1 = DWConv(dim)
        self.act = act_layer()
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(mlp_hidden_dim, dim)
        self.drop = nn.Dropout(drop)

        # Initialise every submodule registered above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear, unit/zero for LayerNorm, fan-out normal for Conv2d."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply the block to tokens ``x`` of shape ``(B, H*W, dim)``."""
        x = self.norm2(x)

        B, N, C = x.shape
        # NOTE(review): the skip branch is a detached copy, so no gradient
        # reaches norm2 / earlier layers through it. Confirm this is intended.
        x1 = x.clone().detach()

        x = self.fc1(x)
        x = self.dwconv(x, H, W)

        # Use the actual hidden width here; it differs from C when mlp_ratio != 1.
        hidden = x.shape[-1]
        xn = x.transpose(1, 2).view(B, hidden, H, W).contiguous()
        xn = self.act1(xn)

        x = self.drop(xn)
        x_s = x.reshape(B, hidden, H * W).contiguous()
        x = x_s.transpose(1, 2)

        x = self.drop(x)
        x = self.fc2(x)
        x = self.dwconv1(x, H, W)
        x = self.drop(x)

        x += x1

        # NOTE(review): this adds drop_path(x) to x itself, so with the Identity
        # drop_path (rate 0) the output is simply doubled. Kept as-is to
        # preserve the behaviour existing weights were trained with.
        x = x + self.drop_path(x)

        return x


class DWConv(nn.Module):
    """Depthwise 3x3 convolution applied to a token sequence.

    ``forward`` reshapes ``(B, N, C)`` tokens to ``(B, C, H, W)``, applies a
    non-affine layer norm over the ``(H, W)`` plane of each channel, runs the
    depthwise convolution, and flattens back to ``(B, N, C)``.

    ``self.weight`` and ``self.bias`` are registered as parameters but are not
    used by ``forward``; they are kept so the parameter set is unchanged.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x, H, W):
        batch, _num_tokens, channels = x.shape
        feature_map = x.transpose(1, 2).view(batch, channels, H, W)
        feature_map = F.layer_norm(feature_map, [H, W])
        feature_map = self.dwconv(feature_map)
        return feature_map.flatten(2).transpose(1, 2)

class OverlapPatchEmbed(nn.Module):
    """Strided 1x1 convolutional embedding followed by LayerNorm.

    The projection uses ``kernel_size=1`` with the given ``stride``;
    ``patch_size`` only feeds the ``H``, ``W`` and ``num_patches`` bookkeeping
    attributes and does not affect the convolution.

    ``forward`` returns ``(tokens, H, W)`` where ``tokens`` has shape
    ``(B, H*W, embed_dim)`` and ``H``, ``W`` are the projected spatial sizes.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        self.H = self.img_size[0] // self.patch_size[0]
        self.W = self.img_size[1] // self.patch_size[1]
        self.num_patches = self.H * self.W

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=1, stride=stride)
        self.norm = nn.LayerNorm(embed_dim)
        # _init_weights is not applied here; layers keep PyTorch's default init.

    def _init_weights(self, m):
        """Initialiser for a single submodule (Linear / LayerNorm / Conv2d)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            fan_out = (m.kernel_size[0] * m.kernel_size[1] * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        projected = self.proj(x)
        _, _, H, W = projected.shape
        tokens = projected.flatten(2).transpose(1, 2)
        return self.norm(tokens), H, W



class UCM_Net(nn.Module):
    """U-shaped segmentation network built from a conv stem, five strided
    patch-embedding stages with UCMBlocks, and a mirrored decoder that adds
    the encoder features back in as skip connections.

    Channel widths come from ``embed_dims``: indices 0..5 are the widths of
    the successive encoder stages and ``embed_dims[-1]`` is both the number of
    channels the stem convolution expects and the width fed to the final
    1x1 classifier.

    ``forward`` returns ``((aux0, ..., aux4), out)`` by default and only
    ``out`` when ``inference_mode=True``.
    """

    ## Conv 3 + MLP 2 + shifted MLP w less parameters
    
    def __init__(self,  num_classes, input_channels=3, deep_supervision=False,img_size=256, patch_size=16, in_chans=3,  embed_dims=[ 8,16,24,32,48,64,3],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[1, 1, 1], sr_ratios=[8, 4, 2, 1], **kwargs):
        """Build all encoder, decoder and prediction-head layers.

        Args:
            num_classes: Output channels of every prediction head.
            img_size: Nominal input size, used only for the bookkeeping
                attributes of the patch-embedding layers.
            embed_dims: Channel widths (see class docstring).
            num_heads, sr_ratios: Only element 0 is read; it is forwarded to
                every UCMBlock.
            drop_rate, attn_drop_rate, qkv_bias, qk_scale, norm_layer:
                Forwarded to every UCMBlock.
            drop_path_rate, depths: Define the stochastic-depth schedule
                ``dpr`` (``sum(depths)`` values from 0 to ``drop_path_rate``);
                only ``dpr[0]`` and ``dpr[1]`` are used.
            input_channels, deep_supervision, patch_size, in_chans, mlp_ratios,
            kwargs: Accepted but not read in this constructor. Every UCMBlock
                is built with ``mlp_ratio=1``.
        """
        super().__init__()
        


        # Stem: 3x3 conv from embed_dims[-1] input channels to embed_dims[0].
        # NOTE(review): the input width is taken from embed_dims[-1], not from
        # `input_channels` / `in_chans`.
        self.encoder1 = nn.Conv2d(embed_dims[-1], embed_dims[0], 3, stride=1, padding=1)  
     #   self.encoder2 = nn.Conv2d(6, 16, 3, stride=1, padding=1)  
      #  self.encoder3 = nn.Conv2d(16, 24, 3, stride=1, padding=1)


        #self.ebn1 = nn.BatchNorm2d(embed_dims[0])
        # GroupNorm with 4 groups: the channel count must be divisible by 4.
        self.ebn1 = nn.GroupNorm(4,embed_dims[0])
     #   self.ebn2 = nn.BatchNorm2d(12)
       # self.ebn3 = nn.BatchNorm2d(18)
        
       # self.norm0 = norm_layer(embed_dims[0])
        # Token norms applied after each encoder stage's blocks.
        self.norm1 = norm_layer(embed_dims[1])
        self.norm2 = norm_layer(embed_dims[2])
        self.norm3 = norm_layer(embed_dims[3])
        self.norm4 = norm_layer(embed_dims[4])
        self.norm5 = norm_layer(embed_dims[5])

        # Token norms applied after each decoder stage's blocks.
        self.dnorm2 = norm_layer(embed_dims[4])
        self.dnorm3 = norm_layer(embed_dims[3])
        self.dnorm4 = norm_layer(embed_dims[2])
        self.dnorm5 = norm_layer(embed_dims[1])
        self.dnorm6 = norm_layer(embed_dims[0])
      #  self.dnorm7 = norm_layer(embed_dims[-1])

        # Stochastic-depth rates; indexing dpr[1] below needs sum(depths) >= 2.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Encoder blocks (one UCMBlock per stage).
        self.block_0_1 = nn.ModuleList([UCMBlock(
            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])  
        
        self.block0 = nn.ModuleList([UCMBlock(
            dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])        
        
        self.block1 = nn.ModuleList([UCMBlock(
            dim=embed_dims[3], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.block2 = nn.ModuleList([UCMBlock(
            dim=embed_dims[4], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])
        
        self.block3 = nn.ModuleList([UCMBlock(
            dim=embed_dims[5], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])
        # Decoder blocks (one UCMBlock per stage), widths mirror the encoder.
        self.dblock0 = nn.ModuleList([UCMBlock(
            dim=embed_dims[4], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.dblock1 = nn.ModuleList([UCMBlock(
            dim=embed_dims[3], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])

        self.dblock2 = nn.ModuleList([UCMBlock(
            dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])
        self.dblock3 = nn.ModuleList([UCMBlock(
            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])
        self.dblock4 = nn.ModuleList([UCMBlock(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])])
        

        
        # Stride-2 embeddings: each one halves the spatial size and widens the
        # channels from embed_dims[i] to embed_dims[i + 1].
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size , patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 2, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])

        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[3],
                                              embed_dim=embed_dims[4])
        
        self.patch_embed5 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[4],
                                              embed_dim=embed_dims[5])
        
      
        # Decoder 1x1 convolutions: each narrows the channels by one stage.
        self.decoder0 = nn.Conv2d(embed_dims[5], embed_dims[4], 1, stride=1,padding=0)  
        self.decoder1 = nn.Conv2d(embed_dims[4], embed_dims[3], 1, stride=1,padding=0)  
      #  self.decoder1_1 =   nn.Conv2d(48, 32, 1, stride=1, padding=0)  
        self.decoder2 =   nn.Conv2d(embed_dims[3], embed_dims[2], 1, stride=1, padding=0)  
     #   self.decoder2_1 =   nn.Conv2d(32, 24, 1, stride=1, padding=0)  
        self.decoder3 =   nn.Conv2d(embed_dims[2], embed_dims[1],  1, stride=1, padding=0) 
    #    self.decoder3_1 =   nn.Conv2d(24, 16, 1, stride=1, padding=0) 
        self.decoder4 =   nn.Conv2d(embed_dims[1], embed_dims[0], 1, stride=1, padding=0)
       # self.decoder4_1 =   nn.Conv2d(16, 6, 1, stride=1, padding=0)
        self.decoder5 =   nn.Conv2d(embed_dims[0], embed_dims[-1], 1, stride=1, padding=0)
 
      #  self.dbn0 = nn.BatchNorm2d(embed_dims[4])
      ##  self.dbn1 = nn.BatchNorm2d(embed_dims[3])
       # self.dbn2 = nn.BatchNorm2d(embed_dims[2])
       # self.dbn3 = nn.BatchNorm2d(embed_dims[1])
       # self.dbn4 = nn.BatchNorm2d(embed_dims[0])
        
        # GroupNorm (4 groups) after each decoder convolution.
        self.dbn0 = nn.GroupNorm(4,embed_dims[4])
        self.dbn1 = nn.GroupNorm(4,embed_dims[3])
        self.dbn2 = nn.GroupNorm(4,embed_dims[2])
        self.dbn3 = nn.GroupNorm(4,embed_dims[1])
        self.dbn4 = nn.GroupNorm(4,embed_dims[0])
    
        
      
        # Auxiliary 1x1 prediction heads, one per decoder stage.
        self.finalpre0 = nn.Conv2d(embed_dims[4], num_classes, kernel_size=1)
        self.finalpre1 = nn.Conv2d(embed_dims[3], num_classes, kernel_size=1)
        self.finalpre2 = nn.Conv2d(embed_dims[2], num_classes, kernel_size=1)
        self.finalpre3 = nn.Conv2d(embed_dims[1], num_classes, kernel_size=1)
        self.finalpre4 = nn.Conv2d(embed_dims[0], num_classes, kernel_size=1)
        
        # Main 1x1 prediction head.
        self.final = nn.Conv2d(embed_dims[-1], num_classes, kernel_size=1)

       

    def forward(self, x,inference_mode=False):
        """Run the network.

        Args:
            x: Input batch with ``embed_dims[-1]`` channels.
                NOTE(review): the skip additions appear to need a height and
                width divisible by 64 (one 2x pool plus five stride-2
                embeddings) - confirm against the data pipeline.
            inference_mode: When True, skip the auxiliary heads and return
                only the final prediction.

        Returns:
            ``((outtpre0, outtpre1, outtpre2, outtpre3, outtpre4), out)`` when
            ``inference_mode`` is False, otherwise ``out``. No sigmoid is
            applied to any output. Each auxiliary map is upsampled by a fixed
            factor (32, 16, 8, 4, 2) before its 1x1 head.
        """
        
        B = x.shape[0]
        ### Encoder
        ### Conv Stage
        out = self.encoder1(x)

        ### Stage 1: GroupNorm -> 2x2 max-pool -> ReLU (1/2 resolution)
        out = F.relu(F.max_pool2d(self.ebn1(out),2,2))
        t1 = out
       
      #  out,H,W = self.patch_embed5(x)
      #  for i, blk in enumerate(self.block_0_2):
      #      out = blk(out, H, W)
      #  out = self.norm0(out)
      #  out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
      #  t1 = out
      
        ### Stage 2: stride-2 embedding + UCMBlock (1/4 resolution)
       # out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
        #t2 = out
        out,H,W = self.patch_embed1(out)
        for i, blk in enumerate(self.block_0_1):
            out = blk(out, H, W)
        out = self.norm1(out)
        # Tokens (B, H*W, C) back to a feature map (B, C, H, W).
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t2 = out
        ### Stage 3 (1/8 resolution)
       

     #   out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
       # t3 = out
        out,H,W = self.patch_embed2(out)
        for i, blk in enumerate(self.block0):
            out = blk(out, H, W)
        out = self.norm2(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t3 = out

        ### Tokenized MLP Stage
        ### Stage 4 (1/16 resolution)

        out,H,W = self.patch_embed3(out)
        for i, blk in enumerate(self.block1):
            out = blk(out, H, W)
        out = self.norm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t4 = out

        ### Stage 5 (1/32 resolution)

        out ,H,W= self.patch_embed4(out)
        for i, blk in enumerate(self.block2):
            out = blk(out, H, W)
        out = self.norm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t5 = out
        
        ### Bottleneck (1/64 resolution); no skip tensor is saved here
        out ,H,W= self.patch_embed5(out)
        for i, blk in enumerate(self.block3):
            out = blk(out, H, W)
        out = self.norm5(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        
       # outtpre0 = F.interpolate(out, scale_factor=32, mode ='bilinear', align_corners=True)
       # outtpre0 =self.finalpre0(outtpre0)
        ### Decoder stage 0: 1x1 conv -> GroupNorm -> 2x upsample -> ReLU, then add skip t5
        out = F.relu(F.interpolate(self.dbn0(self.decoder0(out)),scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t5)
        if not inference_mode:
            # Auxiliary prediction: upsample x32, then classify.
            outtpre0 = F.interpolate(out, scale_factor=32, mode ='bilinear', align_corners=True)
            outtpre0 =self.finalpre0(outtpre0)
        #print('outtpre1',torch.sigmoid(outtpre1).size())
        
        # Feature map (B, C, H, W) to tokens (B, H*W, C) for the UCMBlock.
        _,_,H,W = out.shape
        out = out.flatten(2).transpose(1,2)
        for i, blk in enumerate(self.dblock0):
            out = blk(out, H, W)

        ### Decoder stage 1: upsample and add skip t4
        
        out = self.dnorm2(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)),scale_factor=(2,2),mode ='bilinear'))
   
        
        out = torch.add(out,t4)
        if not inference_mode:
            outtpre1 = F.interpolate(out, scale_factor=16, mode ='bilinear', align_corners=True)
            outtpre1 =self.finalpre1(outtpre1)
        #print('outtpre1',torch.sigmoid(outtpre1).size())
        
        _,_,H,W = out.shape
        out = out.flatten(2).transpose(1,2)
        for i, blk in enumerate(self.dblock1):
            out = blk(out, H, W)

        ### Decoder stage 2: upsample and add skip t3
        
        out = self.dnorm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)),scale_factor=(2,2),mode ='bilinear'))
      #  t41=self.decoder2_1(F.upsample(t4, scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t3)
     #   out = torch.add(out,t41)
        if not inference_mode:
            outtpre2 = F.interpolate(out, scale_factor=8, mode ='bilinear', align_corners=True)
        
            outtpre2 =self.finalpre2(outtpre2)
        #print('outtpre2',outtpre2.size())
        
        _,_,H,W = out.shape
        out = out.flatten(2).transpose(1,2)
        
        for i, blk in enumerate(self.dblock2):
            out = blk(out, H, W)

        ### Decoder stage 3: upsample and add skip t2
        out = self.dnorm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
     #   t31=self.decoder3_1(F.upsample(t3, scale_factor=(2,2),mode ='bilinear'))
        out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)),scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t2)
    #    out = torch.add(out,t31)
        #print(out.size())
        if not inference_mode:
            outtpre3 = F.interpolate(out, scale_factor=4, mode ='bilinear', align_corners=True)
        
            outtpre3 =self.finalpre3(outtpre3)
        #print('outtpre3',outtpre3.size())
        _,_,H,W = out.shape
        out = out.flatten(2).transpose(1,2)
        
        for i, blk in enumerate(self.dblock3):
            out = blk(out, H, W)

        ### Decoder stage 4: upsample and add skip t1
        out = self.dnorm5(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()       
     #   print(out.size())
        
     #   t21=self.decoder4_1(F.upsample(t2, scale_factor=(2,2),mode ='bilinear'))
       
        out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)),scale_factor=(2,2),mode ='bilinear'))
        out = torch.add(out,t1)
    #    out = torch.add(out,t21)
        
        if not inference_mode:
            outtpre4 = F.interpolate(out, scale_factor=2, mode ='bilinear', align_corners=True)
        
            outtpre4 =self.finalpre4(outtpre4)
        #print('outtpre4',outtpre4.size())        
        
        _,_,H,W = out.shape
        out = out.flatten(2).transpose(1,2)
        
        for i, blk in enumerate(self.dblock4):
            out = blk(out, H, W) 
        out = self.dnorm6(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        
        ### Final stage: 1x1 conv (no norm) -> 2x upsample -> ReLU, back to input resolution
        out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear'))

        out =  self.final(out)
        if not inference_mode:
            # Auxiliary maps are ordered from the deepest decoder stage to the shallowest.
            return ( outtpre0,outtpre1, outtpre2, outtpre3, outtpre4), out
        else:
            return out
#EOF


In [None]:
import warnings
# Suppress UserWarning messages raised from modules whose name starts with "mmcv".
warnings.filterwarnings("ignore", category=UserWarning, module="mmcv")

In [None]:
#losses file 
import torch
import torch.nn as nn
import torch.nn.functional as F
#from pytorch_zoo.loss import lovasz_hinge
from torch.nn.modules.loss import CrossEntropyLoss




__all__ = ['GT_BceDiceLoss','GT_BceDiceLoss_new','GT_BceDiceLoss_new1']



##############medt##################
import torch
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _WeightedLoss


class BCEDiceLoss_newversion(nn.Module):
    """Binary cross-entropy plus a squared-sum Dice loss on sigmoid probabilities.

    ``forward(input, target)`` treats ``input`` as logits. Per sample, the Dice
    term is ``(2 * I**2 + s) / (P**2 + T**2 + s)`` where ``I``, ``P`` and ``T``
    are the sums of the element-wise product, the probabilities and the target,
    and ``s = 1e-5``. The returned scalar is ``BCE + (1 - mean Dice)``.
    """

    def __init__(self):
        super().__init__()
        self.bceloss = nn.BCELoss()

    def forward(self, input, target):
        eps = 1e-5
        batch = target.size(0)

        # One row of probabilities / labels per sample.
        probs = torch.sigmoid(input).view(batch, -1)
        truth = target.view(batch, -1)

        bce_term = self.bceloss(probs, truth)

        overlap = (probs * truth).sum(1)
        numerator = 2. * overlap.pow(2) + eps
        denominator = probs.sum(1).pow(2) + truth.sum(1).pow(2) + eps
        per_sample_dice = numerator / denominator

        dice_term = 1 - per_sample_dice.sum() / batch
        return bce_term + dice_term
         
    
 
   
class GT_BceDiceLoss(nn.Module):
    """Deep-supervision loss: main BCE-Dice loss plus five weighted auxiliary terms.

    ``forward(pre, out, target)`` expects ``pre`` to hold exactly five
    auxiliary predictions; they are weighted 0.1, 0.2, 0.3, 0.4 and 0.5 in
    the order given.

    NOTE(review): ``BCEDiceLoss`` is not defined alongside this class; it must
    be provided elsewhere before this class is instantiated - confirm.
    """

    def __init__(self):
        super(GT_BceDiceLoss, self).__init__()
        self.bcedice = BCEDiceLoss()

    def forward(self, pre,out, target):
        main_loss = self.bcedice(out, target)

        gt_pre4, gt_pre3, gt_pre2, gt_pre1, gt_pre0 = pre
        weighted_terms = [
            self.bcedice(gt_pre4, target) * 0.1,
            self.bcedice(gt_pre3, target) * 0.2,
            self.bcedice(gt_pre2, target) * 0.3,
            self.bcedice(gt_pre1, target) * 0.4,
            self.bcedice(gt_pre0, target) * 0.5,
        ]

        # Accumulate left to right, matching a plain chained sum.
        aux_loss = weighted_terms[0]
        for term in weighted_terms[1:]:
            aux_loss = aux_loss + term

        return main_loss + aux_loss
    
class GT_BceDiceLoss_new(nn.Module):
    """Deep-supervision loss built on ``BCEDiceLoss_newversion``.

    ``forward(pre, out, target)`` returns the loss of the final prediction
    ``out`` plus the losses of the five auxiliary predictions in ``pre``,
    weighted 0.1, 0.2, 0.3, 0.4 and 0.5 in the order given.
    """

    def __init__(self):
        super(GT_BceDiceLoss_new, self).__init__()
        self.bcedice = BCEDiceLoss_newversion()

    def forward(self, pre,out, target):
        final_term = self.bcedice(out, target)

        # Exactly five auxiliary maps are required.
        aux_a, aux_b, aux_c, aux_d, aux_e = pre

        aux_total = self.bcedice(aux_a, target) * 0.1
        aux_total = aux_total + self.bcedice(aux_b, target) * 0.2
        aux_total = aux_total + self.bcedice(aux_c, target) * 0.3
        aux_total = aux_total + self.bcedice(aux_d, target) * 0.4
        aux_total = aux_total + self.bcedice(aux_e, target) * 0.5

        return final_term + aux_total
class GT_BceDiceLoss_new1(nn.Module):
    """Deep-supervision loss that also reports each auxiliary loss.

    ``forward(pre, out, target)`` returns a 6-tuple: the total loss (final
    prediction loss plus the five auxiliary losses weighted 0.1, 0.2, 0.3,
    0.4 and 0.5 in the order given in ``pre``), followed by the five
    unweighted auxiliary losses in that same order.
    """

    def __init__(self):
        super(GT_BceDiceLoss_new1, self).__init__()
        self.bcedice = BCEDiceLoss_newversion()

    def forward(self, pre,out, target):
        bcediceloss = self.bcedice(out, target)
        gt_pre4, gt_pre3, gt_pre2, gt_pre1,gt_pre0 = pre

        # Evaluate each auxiliary loss once and reuse it for both the weighted
        # total and the returned per-head values (previously each was computed twice).
        loss4 = self.bcedice(gt_pre4, target)
        loss3 = self.bcedice(gt_pre3, target)
        loss2 = self.bcedice(gt_pre2, target)
        loss1 = self.bcedice(gt_pre1, target)
        loss0 = self.bcedice(gt_pre0, target)

        gt_loss = loss4 * 0.1 + loss3 * 0.2 + loss2 * 0.3 + loss1 * 0.4 + loss0 * 0.5

        return bcediceloss + gt_loss, loss4, loss3, loss2, loss1, loss0
    


In [None]:
class archs_ucm(nn.Module):
    """Empty placeholder model: adds no layers and no ``forward`` of its own.

    NOTE(review): a later cell runs ``import archs_ucm``, which would rebind
    this name to a module if that import succeeds - confirm which is intended.
    """

In [None]:
from IPython.display import clear_output
# Clears this cell's displayed output. With wait=True the clear is deferred
# until new output is available, which avoids flicker.
# Intended to be called inside the training loop to keep the log compact.
clear_output(wait=True)  # Refresh output

In [None]:
#train file 
import argparse
import os
from collections import OrderedDict
from glob import glob

import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import yaml
from albumentations.augmentations import transforms
import albumentations as A
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
from torch.optim import lr_scheduler
from tqdm import tqdm
from albumentations import RandomRotate90,Resize,Rotate, VerticalFlip,HorizontalFlip, ElasticTransform
import archs_ucm
import losses
from dataset1 import Dataset
#from metrics import iou_score,iou_score1
#from utils import AverageMeter, str2bool


import numpy as np
from sklearn.metrics import confusion_matrix

from torch.nn.modules.loss import CrossEntropyLoss

import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
# Known architecture names.
ARCH_NAMES = ['UCM_Net']
# Loss names accepted by the --loss command-line option.
LOSS_NAMES = [
    'GT_BceDiceLoss',
    'GT_BceDiceLoss_new',
    'GT_BceDiceLoss_new1',
    'BCEWithLogitsLoss',
]



def parse_args():
    """Build the training argument parser and parse ``sys.argv``.

    Unrecognised arguments are ignored (``parse_known_args``), so the
    function also works when the process was started with extra flags, as a
    notebook kernel is.

    Returns:
        argparse.Namespace holding the model, loss, dataset, optimizer and
        scheduler settings plus miscellaneous run options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default=None,
                        help='model name: (default: arch+timestamp)')
    parser.add_argument('--epochs', default=50, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 8)')

    # model
    parser.add_argument('--arch', '-a', metavar='ARCH', default='UCM_Net')
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=256, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=256, type=int,
                        help='image height')

    # loss
    parser.add_argument('--loss', default='GT_BceDiceLoss_new',
                        choices=LOSS_NAMES,
                        help='loss: ' +
                        ' | '.join(LOSS_NAMES) +
                        ' (default: GT_BceDiceLoss_new)')

    # dataset
    parser.add_argument('--dataset', default='/kaggle/input/isisc-2018/isic2018',
                        help='dataset name')
    parser.add_argument('--img_ext', default='.png',
                        help='image file extension')
    parser.add_argument('--mask_ext', default='.png',
                        help='mask file extension')

    # optimizer
    # The help text used to read "loss: ... (default: SGD)", which mislabelled
    # the option and contradicted the actual default below.
    parser.add_argument('--optimizer', default='AdamW',
                        choices=['AdamW', 'SGD'],
                        help='optimizer: ' +
                        ' | '.join(['AdamW', 'SGD']) +
                        ' (default: AdamW)')
    parser.add_argument('--lr', '--learning_rate', default=1e-4, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=0.01, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', default=False, type=str2bool,
                        help='nesterov')
    parser.add_argument('--nrand', default=44, type=int,
                        help='rand state')

    # scheduler
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
    parser.add_argument('--min_lr', default=1e-5, type=float,
                        help='minimum learning rate')
    parser.add_argument('--factor', default=0.1, type=float)
    parser.add_argument('--patience', default=2, type=int)
    parser.add_argument('--milestones', default='1,2', type=str)
    parser.add_argument('--gamma', default=2/3, type=float)
    parser.add_argument('--early_stopping', default=15, type=int,
                        metavar='N', help='early stopping (default: 15)')
    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                    help='no: no cache, '
                            'full: cache all data, '
                            'part: sharding the dataset into nonoverlapping pieces and only cache one piece')

    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # Ignore arguments this parser does not define instead of exiting.
    config, unknown = parser.parse_known_args()

    return config


def train(config, train_loader, model, criterion, optimizer):
    """Run one training epoch and return the running loss / IoU averages.

    How the forward pass is unpacked and how the loss is built depends on
    ``config['arch']``, ``config['deep_supervision']`` and ``config['loss']``.
    The conditions are tested in a fixed order and the first match wins, so
    the order of the ``elif`` chain below is significant.

    Args:
        config: dict of run options; reads ``'arch'``, ``'deep_supervision'``
            and ``'loss'``.
        train_loader: iterable yielding ``(input, target, meta)`` batches;
            ``meta`` is ignored.
        model: network to optimise; it is switched to training mode here.
        criterion: loss callable, invoked as ``criterion(output, target)`` or,
            for the ``GT_*`` losses, ``criterion(gt_pre, out, target)``.
        optimizer: optimizer that is zeroed and stepped once per batch.

    Returns:
        OrderedDict with the ``avg`` of the ``'loss'`` and ``'iou'`` meters.
    """
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    model.train()
    # Only used by the TransUNet branch, which mixes cross-entropy and dice.
    ce_loss = CrossEntropyLoss()

    # The context manager closes the progress bar even if a batch raises.
    with tqdm(total=len(train_loader)) as pbar:
        for input, target, _ in train_loader:
            input = input.cuda()
            target = target.cuda()

            arch = config['arch']
            loss_name = config['loss']

            # compute output
            if arch == 'TransUNet':
                outputs = model(input)
                loss_dice = dice_loss(outputs, target)
                loss_ce = ce_loss(outputs, target)
                loss = (0.5 * loss_ce + 0.5 * loss_dice).mean()
                iou, _ = iou_score(outputs, target)

            elif config['deep_supervision']:
                # Average the loss over every supervised output; score the last one.
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou, _ = iou_score(outputs[-1], target)

            elif loss_name == 'LogNLLLoss':
                output = model(input)
                loss = criterion(output, target).mean()
                iou, _ = iou_score(output, target)

            elif loss_name in ('GT_BceDiceLoss', 'GT_BceDiceLoss_new'):
                gt_pre, out = model(input)
                loss = criterion(gt_pre, out, target).mean()
                iou, _ = iou_score(out, target)

            elif loss_name == 'GT_BceDiceLoss_new1':
                gt_pre, out = model(input)
                # This criterion returns six values; only the first is optimised.
                loss, _aux1, _aux2, _aux3, _aux4, _aux5 = criterion(gt_pre, out, target)
                loss = loss.mean()
                iou, _ = iou_score(out, target)

            elif loss_name == 'BCEDiceLossMAL':
                out = model(input)
                loss = criterion(out, target).mean()
                iou, _ = iou_score1(out, target)

            elif (loss_name in ('BCEDiceLossUNEXT', 'BCEDiceLossSWIN')
                  or arch in ('AttU_Net', 'R2U_Net', 'U_Net')):
                out = model(input)
                loss = criterion(out, target).mean()
                iou, _ = iou_score(out, target)

            elif loss_name == 'LossTransFuse':
                lateral_map_4, lateral_map_3, lateral_map_2 = model(input)

                # Weighted sum of the three lateral outputs; map 2 is the one scored.
                loss4 = criterion(lateral_map_4, target)
                loss3 = criterion(lateral_map_3, target)
                loss2 = criterion(lateral_map_2, target)
                loss = (0.5 * loss2 + 0.3 * loss3 + 0.2 * loss4).mean()
                iou, _ = iou_score(lateral_map_2, target)

            elif loss_name == 'GT_BceDiceLossEGE':
                gt_pre, out = model(input)
                loss = criterion(gt_pre, out, target).mean()
                iou, _ = iou_score1(out, target)

            elif arch == 'CONVUNext':
                # This model returns a mapping; the prediction lives under 'out'.
                output = model(input)
                loss = criterion(output['out'], target).mean()
                iou, _ = iou_score(output['out'], target)

            elif loss_name == 'BCEDiceLossMEDT':
                output = model(input)
                loss = criterion(output, target).mean()
                iou, _ = iou_score(output, target)

            else:
                # Generic path: the criterion is expected to return a scalar already.
                output = model(input)
                loss = criterion(output, target)
                iou, _ = iou_score(output, target)

            # compute gradient and do optimizing step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))

            postfix = OrderedDict([
                ('loss', avg_meters['loss'].avg),
                ('iou', avg_meters['iou'].avg),
            ])
            pbar.set_postfix(postfix)
            pbar.update(1)

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])

# Best pixel-level mIoU seen by validate() so far, and the Dice/F1 measured
# on that same pass. Updated in place by validate().
best_miou1 = 0
best_dice1 = 0


def validate(config, val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return averaged metrics.

    Besides the per-batch loss / IoU / Dice meters, the loss-specific branches
    collect thresholded predictions and targets, from which a pixel-level
    confusion matrix is built. When the mIoU from that matrix beats
    ``best_miou1`` the whole model is saved to
    ``models/<config['name']>/modelmiou1.pth`` and both module-level globals
    are updated. The generic ``else`` branch collects nothing, so on that path
    the confusion-matrix metrics are skipped.

    Args:
        config: dict of run options; reads ``'loss'``, ``'arch'`` and ``'name'``.
        val_loader: iterable yielding ``(input, target, meta)`` batches.
        model: network to evaluate; it is switched to eval mode here.
        criterion: loss callable, invoked as ``criterion(output, target)`` or,
            for the ``GT_*`` losses, ``criterion(gt_pre, out, target)``.

    Returns:
        OrderedDict with the ``avg`` of the ``'loss'``, ``'iou'`` and
        ``'dice'`` meters.
    """
    global best_miou1, best_dice1

    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter(),
                  'dice': AverageMeter()}

    # switch to evaluate mode
    model.eval()

    preds = []
    gts = []

    def _as_numpy(tensor):
        # Squeeze dim 1 (assumed to be a singleton channel axis -- TODO confirm)
        # and copy the batch to host memory.
        return tensor.squeeze(1).cpu().detach().numpy()

    with torch.no_grad():
        # The context manager closes the progress bar even if a batch raises.
        with tqdm(total=len(val_loader)) as pbar:
            for input, target, _ in val_loader:
                input = input.cuda()
                target = target.cuda()

                loss_name = config['loss']

                if loss_name in ('GT_BceDiceLoss', 'GT_BceDiceLoss_new'):
                    gts.append(_as_numpy(target))
                    gt_pre, out = model(input)
                    loss = criterion(gt_pre, out, target).mean()
                    iou, dice = iou_score(out, target)
                    preds.append(_as_numpy(torch.sigmoid(out)))

                elif loss_name == 'GT_BceDiceLoss_new1':
                    gts.append(_as_numpy(target))
                    gt_pre, out = model(input)
                    # This criterion returns six values; only the first is reported.
                    loss, _aux1, _aux2, _aux3, _aux4, _aux5 = criterion(gt_pre, out, target)
                    # Reduce before .item(), matching what train() does for this loss.
                    loss = loss.mean()
                    iou, dice = iou_score(out, target)
                    preds.append(_as_numpy(torch.sigmoid(out)))

                elif (loss_name == 'BCEDiceLossUNEXT'
                      or config['arch'] in ('AttU_Net', 'R2U_Net', 'U_Net')):
                    gts.append(_as_numpy(target))
                    out = model(input)
                    loss = criterion(out, target).mean()
                    iou, dice = iou_score(out, target)
                    preds.append(_as_numpy(torch.sigmoid(out)))

                elif loss_name == 'GT_BceDiceLossEGE':
                    gts.append(_as_numpy(target))
                    gt_pre, out = model(input)
                    loss = criterion(gt_pre, out, target).mean()
                    iou, dice = iou_score1(out, target)
                    # No sigmoid here: presumably this model already emits
                    # probabilities -- verify against the architecture.
                    preds.append(_as_numpy(out))

                else:
                    output = model(input)
                    loss = criterion(output, target)
                    iou, dice = iou_score(output, target)

                avg_meters['loss'].update(loss.item(), input.size(0))
                avg_meters['iou'].update(iou, input.size(0))
                avg_meters['dice'].update(dice, input.size(0))

                postfix = OrderedDict([
                    ('loss', avg_meters['loss'].avg),
                    ('iou', avg_meters['iou'].avg),
                    ('dice', avg_meters['dice'].avg)
                ])
                pbar.set_postfix(postfix)
                pbar.update(1)

    miou = 0
    f1_or_dsc = 0
    if preds:
        # Flatten per batch before joining so batches of different sizes are fine.
        flat_preds = np.concatenate([p.reshape(-1) for p in preds])
        flat_gts = np.concatenate([g.reshape(-1) for g in gts])

        y_pre = np.where(flat_preds >= 0.5, 1, 0)
        y_true = np.where(flat_gts >= 0.5, 1, 0)

        # labels=[0, 1] guarantees a 2x2 matrix even when only one class occurs.
        confusion = confusion_matrix(y_true, y_pre, labels=[0, 1])
        _TN, FP, FN, TP = confusion.ravel()

        dice_denominator = float(2 * TP + FP + FN)
        iou_denominator = float(TP + FP + FN)
        f1_or_dsc = float(2 * TP) / dice_denominator if dice_denominator != 0 else 0
        miou = float(TP) / iou_denominator if iou_denominator != 0 else 0

    if best_miou1 < miou:
        torch.save(model, 'models/%s/modelmiou1.pth' %
                   config['name'])
        best_miou1 = miou
        best_dice1 = f1_or_dsc

    print('miou', best_miou1)
    print('f1_or_dsc', best_dice1)

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg),
                        ('dice', avg_meters['dice'].avg)])
def model_summary(model):
    """Print the model's repr followed by its total and trainable parameter counts.

    Args:
        model: object exposing ``parameters()`` whose items provide
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).
    """
    print(model)  # architecture as rendered by the model's own repr

    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        n_elements = param.numel()
        total_params += n_elements
        if param.requires_grad:
            trainable_params += n_elements

    print(f'Total Parameters: {total_params}')
    print(f'Trainable Parameters: {trainable_params}')

import torch
from thop import profile

def compute_gflops(model, input_size):
    """Profile one forward pass with ``thop.profile`` and return MACs / 1e9.

    The model is moved to CUDA when available, otherwise to the CPU. The
    architecture name is read from a fresh ``parse_args()`` call; for
    ``'TransFuse_S'`` and ``'TransUNet'`` a fixed input resolution replaces
    the square ``input_size`` one.

    Args:
        model: network to profile.
        input_size: side length of the default square 3-channel input.

    Returns:
        The first value returned by ``profile`` divided by ``10**9``.
    """
    run_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(run_device)
    cli_options = vars(parse_args())

    # Architectures that must be profiled at a fixed (N, C, H, W) shape.
    fixed_shapes = {
        'TransFuse_S': (1, 3, 192, 256),
        'TransUNet': (1, 3, 224, 224),
    }

    dummy = torch.randn(1, 3, input_size, input_size)
    arch = cli_options['arch']
    if arch in fixed_shapes:
        dummy = torch.randn(*fixed_shapes[arch])

    macs, _param_count = profile(model, inputs=(dummy.to(run_device), ))
    return macs / (10**9)


def main():
    """Train a segmentation model according to the parsed command-line options.

    Steps: build the model named by ``config['arch']``, dump the configuration
    to ``models/<name>/config.yml``, build the loss / optimizer / scheduler,
    create the train and validation loaders from
    ``<dataset>/{train,val}/{images,masks}/``, then run the epoch loop with
    per-epoch CSV logging, best-model checkpointing (by validation IoU) and
    early stopping.

    Raises:
        NotImplementedError: for an unknown optimizer or scheduler name.
        FileNotFoundError: when no training or validation images are found.
    """
    config = vars(parse_args())

    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])

    os.makedirs('models/%s' % config['name'], exist_ok=True)

    # create model
    model = archs_ucm.__dict__[config['arch']](config['num_classes'],
                                               config['input_channels'],
                                               config['deep_supervision'])

    # The optimizer and scheduler are taken from the parsed options as-is;
    # nothing here forces AdamW / CosineAnnealingLR.
    # NOTE(review): the cosine period is pinned to 50 rather than
    # config['epochs'] -- confirm that this is intentional.
    cosine_t_max = 50

    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True

    model_summary(model)

    # compute_gflops profiles a square input of side input_h.
    input_size = config['input_h']
    gflops = compute_gflops(model, input_size)
    print(f'GigaFLOPs: {gflops}')

    model = model.cuda()
    if config['arch'] == 'SWIN':
        # NOTE(review): `config1` is not defined in this function; it must come
        # from module scope -- verify it exists before selecting this arch.
        model.load_from(config1)

    params = filter(lambda p: p.requires_grad, model.parameters())

    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                              weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'AdamW':
        optimizer = optim.AdamW(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cosine_t_max, eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],
                                                   verbose=1, min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    train_img_dir = os.path.join('', config['dataset'], 'train/', 'images/')
    train_mask_dir = os.path.join('', config['dataset'], 'train/', 'masks/')
    val_img_dir = os.path.join('', config['dataset'], 'val/', 'images/')
    val_mask_dir = os.path.join('', config['dataset'], 'val/', 'masks/')

    train_img_ids = glob(os.path.join(train_img_dir, '*' + config['img_ext']))
    val_img_ids = glob(os.path.join(val_img_dir, '*' + config['img_ext']))

    # Image ids are the file names without directory or extension.
    train_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in train_img_ids]
    val_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in val_img_ids]

    print('train_img_ids', len(train_img_ids))
    print('val_img_ids', len(val_img_ids))

    # Fail early with a readable message instead of an obscure error later on.
    if not train_img_ids:
        raise FileNotFoundError(
            'no training images matching *%s under %s' % (config['img_ext'], train_img_dir))
    if not val_img_ids:
        raise FileNotFoundError(
            'no validation images matching *%s under %s' % (config['img_ext'], val_img_dir))

    train_transform = Compose([
        Rotate(limit=180, p=0.5),
        VerticalFlip(p=0.5), HorizontalFlip(p=0.5),
        Resize(config['input_h'], config['input_w']),
    ])

    val_transform = Compose([
        Resize(config['input_h'], config['input_w']),
    ])

    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=train_img_dir,
        mask_dir=train_mask_dir,
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=train_transform, train=True)
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=val_img_dir,
        mask_dir=val_mask_dir,
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform, train=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=False)
    # Validation runs one image at a time.
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('val_loss', []),
        ('val_iou', []),
        ('val_dice', []),
    ])

    best_iou = 0
    best_dice = 0
    trigger = 0  # epochs since the last validation-IoU improvement
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        # Learning rate actually used this epoch (read before the scheduler moves it).
        epoch_lr = optimizer.param_groups[0]['lr']

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        # evaluate on validation set
        val_log = validate(config, val_loader, model, criterion)

        # ReduceLROnPlateau needs the monitored metric; ConstantLR has no scheduler.
        if config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])
        elif scheduler is not None:
            scheduler.step()

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
              % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))

        log['epoch'].append(epoch)
        log['lr'].append(epoch_lr)
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        log['val_dice'].append(val_log['dice'])

        # Rewrite the full log every epoch so partial runs still leave a CSV.
        pd.DataFrame(log).to_csv('models/%s/log.csv' %
                                 config['name'], index=False)

        trigger += 1

        if val_log['iou'] > best_iou:
            torch.save(model, 'models/%s/model.pth' %
                       config['name'])
            best_iou = val_log['iou']
            best_dice = val_log['dice']
            print("=> saved best model")
            trigger = 0
        print('best_iou', best_iou)
        print('best_dice', best_dice)

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break

        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()

In [None]:
import shutil

# Copy losses.py from the read-only input dataset into the working directory
# (copy2 also preserves the file's metadata).
losses_src = '/kaggle/input/losses-dataset/losses.py'
working_dir = '/kaggle/working/'
shutil.copy2(losses_src, working_dir)

In [None]:
import shutil
import os

# Remove the whole models/ output tree from the working directory, if present.
folder_path = '/kaggle/working/models'

if not os.path.exists(folder_path):
    print(f"Folder not found: {folder_path}")
else:
    shutil.rmtree(folder_path)
    print(f"Deleted: {folder_path}")


In [None]:
import os

# Path of the single file to delete from the Kaggle working directory;
# edit it to point at the file you want removed.
file_path = '/kaggle/working/losses_s.py'

if not os.path.exists(file_path):
    print(f"File not found: {file_path}")
else:
    os.remove(file_path)
    print(f"Deleted: {file_path}")

In [None]:
import argparse #evaluate file 
import os
from glob import glob

import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from albumentations.augmentations import transforms
from albumentations.core.composition import Compose
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import archs_ucm

from dataset1 import Dataset
#from metrics import iou_score1,iou_score
#from utils import AverageMeter
from albumentations import RandomRotate90,Resize
import time

import numpy as np
from tqdm import tqdm
import torch
from torch.cuda.amp import autocast as autocast
from sklearn.metrics import confusion_matrix
import torch
from thop import profile

import shutil
import os


import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn


import torch
import torch.nn as nn
def estimate_model_inference_memory_usage(model, val_loader, name ='UCM_Net',device='cpu'):
    """Estimate inference memory as parameters + one input batch + its output.

    Only the first batch of ``val_loader`` is used. Intermediate activations
    are not counted, so this is a lower bound rather than a peak figure.

    Args:
        model: network to measure; it is moved to ``device`` and put in eval mode.
        val_loader: iterable yielding ``(input, target, meta)`` batches.
        name: architecture name. ``'EGEUNet'`` models are called as
            ``model(input)`` and return ``(pre, output)``, where ``pre`` is an
            iterable of tensors that is counted as well; every other model is
            called as ``model(input, inference_mode=True)``.
        device: device on which the forward pass runs.

    Returns:
        The estimate in megabytes (also printed).

    Raises:
        ValueError: if ``val_loader`` yields no batch.
    """
    model.to(device)
    model.eval()

    # Memory held by the model parameters, in bytes.
    parameter_memory = sum(param.nelement() * param.element_size()
                           for param in model.parameters())

    first_batch = next(iter(val_loader), None)
    if first_batch is None:
        raise ValueError('val_loader yielded no batches; cannot estimate memory usage')
    input, _target, _meta = first_batch

    # The input must live on the same device as the model.
    input = input.to(device)
    input_memory = input.nelement() * input.element_size()

    # Forward pass without gradients to size the outputs.
    with torch.no_grad():
        if name == 'EGEUNet':
            pre, output = model(input)
            output_memory = output.nelement() * output.element_size()
            # The auxiliary predictions are materialised too, so count them.
            output_memory += sum(o.nelement() * o.element_size() for o in pre)
        else:
            output = model(input, inference_mode=True)
            output_memory = output.nelement() * output.element_size()

    # Convert bytes to megabytes
    total_memory_MB = (parameter_memory + input_memory + output_memory) / (1024 ** 2)

    print(f"Estimated total memory usage during inference: {total_memory_MB:.4f} MB")
    return total_memory_MB

def fuse_conv_bn(conv, bn):
    """
    Fold a BatchNorm2d layer into the Conv2d layer that feeds it.

    With scale = gamma / sqrt(running_var + eps), the fused layer uses
        weight' = weight * scale            (per output channel)
        bias'   = beta + (bias - running_mean) * scale
    so that fused(x) equals bn(conv(x)) when bn is in eval mode, i.e. when it
    normalises with its running statistics.

    Parameters:
    - conv (nn.Conv2d): The convolutional layer.
    - bn (nn.BatchNorm2d): The batch normalization layer applied to conv's output.

    Returns:
    - nn.Conv2d: A new convolution (always with a bias) carrying the fused parameters.

    Raises:
    - ValueError: if bn does not track running statistics.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError('cannot fuse a BatchNorm2d without running statistics')

    # The fused values are constants derived from the current parameters, so
    # no autograd graph is needed while computing them.
    with torch.no_grad():
        bn_std = torch.sqrt(bn.running_var + bn.eps)
        # affine=False leaves weight/bias as None: treat as gamma=1, beta=0.
        bn_weight = bn.weight if bn.weight is not None else torch.ones_like(bn_std)
        bn_bias = bn.bias if bn.bias is not None else torch.zeros_like(bn_std)
        scale = bn_weight / bn_std

        # One scale factor per output channel, broadcast over (in, kH, kW).
        fused_weight = conv.weight * scale.view(-1, 1, 1, 1)

        conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        # The convolution bias passes through the normalisation as well, so it
        # is shifted by the running mean and scaled like the weights.
        fused_bias = bn_bias + (conv_bias - bn.running_mean) * scale

    fused_conv = nn.Conv2d(in_channels=conv.in_channels,
                           out_channels=conv.out_channels,
                           kernel_size=conv.kernel_size,
                           stride=conv.stride,
                           padding=conv.padding,
                           dilation=conv.dilation,
                           groups=conv.groups,
                           bias=True,
                           padding_mode=conv.padding_mode)
    fused_conv.weight = nn.Parameter(fused_weight.clone())
    fused_conv.bias = nn.Parameter(fused_bias.clone())

    return fused_conv

def fuse_model(model):
    """
    Recursively fuse each Conv2d that is immediately followed by a BatchNorm2d.

    "Followed by" means adjacency in ``model.named_children()`` order. That
    matches execution order inside ``nn.Sequential``; for hand-written modules
    it assumes layers are registered in the order they are applied
    (NOTE(review): verify for custom forward() implementations). The fused
    convolution replaces the original one and the BatchNorm2d is replaced by
    ``nn.Identity()`` so that it is not applied a second time. The model is
    modified in place and should be in eval mode, since fusion uses the
    running statistics.

    Parameters:
    - model (torch.nn.Module): The PyTorch model.
    """
    # Snapshot the children: entries are replaced while walking the list.
    children = list(model.named_children())
    for index, (child_name, child) in enumerate(children):
        if isinstance(child, nn.Conv2d):
            if index + 1 >= len(children):
                continue
            successor_name, successor = children[index + 1]
            fusable = (isinstance(successor, nn.BatchNorm2d)
                       and successor.num_features == child.out_channels
                       and successor.running_var is not None)
            if fusable:
                setattr(model, child_name, fuse_conv_bn(child, successor))
                setattr(model, successor_name, nn.Identity())
        else:
            # Recursively apply to children
            fuse_model(child)

def parse_args():
    """Parse the evaluation script's command-line options.

    Unknown arguments (for example the ones a notebook kernel injects) are
    ignored because ``parse_known_args`` is used.

    Returns:
        argparse.Namespace holding the recognised options.
    """
    parser = argparse.ArgumentParser()

    # The default must be the string 'UCM_Net', not a bare identifier.
    parser.add_argument('--name', default='UCM_Net',
                        help='UCM_Net')
    # dataset
    parser.add_argument('--dataset', default='/kaggle/input/isisc-2018/isic2018',
                        help='dataset name')
    parser.add_argument('--img_ext', default='.png',
                        help='image file extension')
    parser.add_argument('--mask_ext', default='.png',
                        help='mask file extension')

    # model / input geometry
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=256, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=256, type=int,
                        help='image height')
    parser.add_argument('--arch', default='transfuse', type=str,
                        help='model')
    parser.add_argument('-b', '--batch_size', default=1, type=int,
                        metavar='N', help='mini-batch size (default: 1)')
    parser.add_argument('--path',default ='no', type = str)
    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                    help='no: no cache, '
                            'full: cache all data, '
                            'part: sharding the dataset into nonoverlapping pieces and only cache one piece')

    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # parse_known_args tolerates arguments this parser does not define.
    args, _unknown = parser.parse_known_args()

    return args



def compute_gflops(model, input_size):
    """Profile one forward pass with ``thop.profile`` and return MACs / 1e9.

    The model is moved to CUDA when available, otherwise to the CPU. The
    architecture name comes from a fresh ``parse_args()`` call and selects a
    fixed input resolution for ``'TransFuse_S'`` (compared after stripping
    whitespace) and ``'TransUNet'``. The chosen input shape and architecture
    are printed before profiling.

    Args:
        model: network to profile.
        input_size: side length of the default square 3-channel input.

    Returns:
        The first value returned by ``profile`` divided by ``10**9``.
    """
    run_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(run_device)
    cli_options = vars(parse_args())

    dummy = torch.randn(1, 3, input_size, input_size)
    arch = cli_options['arch']
    if arch.strip() == 'TransFuse_S':
        dummy = torch.randn(1, 3, 192, 256)
    if arch == 'TransUNet':
        dummy = torch.randn(1, 3, 224, 224)

    # Diagnostic output: shape used, whether the exact name matched, raw name.
    print(dummy.shape, arch == 'TransFuse_S', arch)

    macs, _param_count = profile(model, inputs=(dummy.to(run_device), ))
    return macs / (10**9)


def _run_validation_pass(model, loader, arch, meters=None, collect=False):
    """Run ``model`` over every batch of ``loader`` with gradients disabled.

    Args:
        model: network already in eval mode and on the GPU; called as
            ``model(images)`` and expected to return a 2-tuple whose second
            item is the segmentation output.
        loader: iterable yielding ``(images, target, meta)`` batches.
        arch: architecture name; ``'EGEUNet'`` is scored with ``iou_score1``
            and its output is used as-is, every other architecture is scored
            with ``iou_score`` and then passed through a sigmoid.
        meters: optional ``(iou_meter, dice_meter)`` pair updated per batch.
            Scores are only computed when meters are supplied.
        collect: when true, per-batch predictions and targets are gathered
            as NumPy arrays (channel axis 1 squeezed).

    Returns:
        ``(preds, gts)`` lists; both empty unless ``collect`` is true.
    """
    preds, gts = [], []
    is_egeunet = arch == 'EGEUNet'

    with torch.no_grad():
        for images, target, _meta in tqdm(loader, total=len(loader)):
            if collect:
                gts.append(target.squeeze(1).cpu().detach().numpy())
            images = images.cuda()
            target = target.cuda()

            _pre, output = model(images)

            if meters is not None:
                iou_meter, dice_meter = meters
                score_fn = iou_score1 if is_egeunet else iou_score
                # Scores are taken on the raw model output, before sigmoid.
                iou, dice = score_fn(output, target)
                iou_meter.update(iou, images.size(0))
                dice_meter.update(dice, images.size(0))

            if not is_egeunet:
                output = torch.sigmoid(output)

            if collect:
                preds.append(output.squeeze(1).cpu().detach().numpy())

    return preds, gts


def _binary_segmentation_metrics(preds, gts, threshold=0.5):
    """Compute pixel-wise binary metrics from prediction / ground-truth arrays.

    Both inputs are flattened and binarised with ``>= threshold``. Every
    ratio falls back to 0 when its denominator is 0.

    Returns:
        dict with keys ``accuracy``, ``sensitivity``, ``specificity``,
        ``f1_or_dsc`` and ``miou``.
    """
    y_pre = np.where(np.array(preds).reshape(-1) >= threshold, 1, 0)
    y_true = np.where(np.array(gts).reshape(-1) >= threshold, 1, 0)

    # labels=[0, 1] pins the matrix to 2x2. Without it, data containing a
    # single class yields a 1x1 matrix and the unpacking below fails.
    confusion = confusion_matrix(y_true, y_pre, labels=[0, 1])
    TN, FP, FN, TP = (float(v) for v in confusion.ravel())

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator != 0 else 0

    return {
        'accuracy': _ratio(TN + TP, TN + FP + FN + TP),
        'sensitivity': _ratio(TP, TP + FN),
        'specificity': _ratio(TN, TN + FP),
        'f1_or_dsc': _ratio(2 * TP, 2 * TP + FP + FN),
        'miou': _ratio(TP, TP + FP + FN),
    }


def _measure_fps(model, loader, arch):
    """Time one full inference sweep over ``loader`` and return batches/second.

    The measured interval includes data loading and the host-to-device copy
    of the images. CUDA kernels run asynchronously, so the device is
    synchronised before the clock starts and again before it stops.
    """
    model.eval()
    torch.cuda.synchronize()
    start_time = time.time()

    with torch.no_grad():
        for images, _target, _meta in tqdm(loader, total=len(loader)):
            images = images.cuda()
            if arch == 'EGEUNet':
                _pre, output = model(images)
            else:
                output = model(images, inference_mode=True)
                output = torch.sigmoid(output)

    torch.cuda.synchronize()
    elapsed_time = time.time() - start_time
    return len(loader) / elapsed_time if elapsed_time > 0 else 0.0


def main():
    """Evaluate two saved checkpoints on the validation split and report metrics.

    Steps:
      1. Read CLI arguments; unless the arch is ``'TransFuse_S'`` replace them
         with the YAML config stored next to the checkpoints.
      2. Build the architecture only to report its GFLOPs and parameter count.
      3. Load ``model.pth`` and report average IoU / Dice over the val set.
      4. Load ``modelmiou1.pth`` and report pixel-wise mIoU, F1/DSC and
         accuracy from a confusion matrix.
      5. Measure inference FPS and call
         ``estimate_model_inference_memory_usage``.

    Returns:
        ``(iou_avg, dice_avg, miou, f1_or_dsc, accuracy)``.
    """
    config = vars(parse_args())

    print(config)

    # <-- write your checkpoint directory here
    model_dir = '/kaggle/working/models/kaggle/input/isisc-2018/isic2018_UCM_Net_woDS'

    if config['arch'] != 'TransFuse_S':
        config_path = os.path.join(model_dir, 'config.yml')
        with open(config_path, 'r') as f:
            # NOTE(review): FullLoader can build more than plain data types;
            # only point this at config files you trust.
            config = yaml.load(f, Loader=yaml.FullLoader)

    print('-'*20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-'*20)

    cudnn.benchmark = True
    arch = config['arch']

    print("=> creating model %s" % arch)
    if arch == 'EGEUNet':
        model = egeunet.EGEUNet()
    else:
        model = archs_ucm.__dict__[arch](config['num_classes'],
                                         config['input_channels'],
                                         config['deep_supervision'])

    # NOTE(review): called without an architecture, compute_gflops picks its
    # input resolution from the CLI arguments, not from the YAML config above.
    gflops = compute_gflops(model, config['input_h'])
    print(f'GigaFLOPs: {gflops}')
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(pytorch_total_params)

    if arch == 'TransFuse_S':
        # NOTE(review): this instance is replaced by the checkpoint loaded
        # below; kept because constructing it may have side effects.
        model = TransFuse_S(pretrained=True).cuda()

    val_img_dir = os.path.join('', config['dataset'], 'val/', 'images/')
    val_mask_dir = os.path.join('', config['dataset'], 'val/', 'masks/')

    # glob order is unspecified; sort so every run visits images identically.
    val_img_paths = sorted(glob(os.path.join(val_img_dir, '*' + config['img_ext'])))
    val_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in val_img_paths]

    val_transform = Compose([
        Resize(config['input_h'], config['input_w']),
    ])

    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=val_img_dir,
        mask_dir=val_mask_dir,
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform, train=False)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False,
        pin_memory=True)

    # ---- Checkpoint 1: average IoU / Dice ---------------------------------
    # NOTE(review): torch.load unpickles a whole model object, which can run
    # arbitrary code; only load checkpoints you trust.
    model = torch.load(os.path.join(model_dir, 'model.pth'))
    model.eval()
    model = model.cuda()

    iou_avg_meter = AverageMeter()
    dice_avg_meter = AverageMeter()
    _run_validation_pass(model, val_loader, arch,
                         meters=(iou_avg_meter, dice_avg_meter))

    print('IoU: %.8f' % iou_avg_meter.avg)
    print('Dice: %.8f' % dice_avg_meter.avg)

    # ---- Checkpoint 2: confusion-matrix metrics ---------------------------
    model = torch.load(os.path.join(model_dir, 'modelmiou1.pth'))
    model.eval()
    model = model.cuda()

    preds, gts = _run_validation_pass(model, val_loader, arch, collect=True)
    metrics = _binary_segmentation_metrics(preds, gts)
    miou = metrics['miou']
    f1_or_dsc = metrics['f1_or_dsc']
    accuracy = metrics['accuracy']
    print('miou*',miou)
    print('f1_or_dsc*',f1_or_dsc)
    print("accuracy",accuracy)

    # ---- Throughput and memory --------------------------------------------
    fps = _measure_fps(model, val_loader, arch)
    print(f"FPS: {fps:.4f}")

    torch.cuda.empty_cache()
    estimate_model_inference_memory_usage(model,  val_loader,name = config['arch'],device='cpu')

    return iou_avg_meter.avg,dice_avg_meter.avg, miou,f1_or_dsc,accuracy
# Run the evaluation only when this code executes as the main module,
# not when it is imported; the tuple returned by main() is discarded.
if __name__ == '__main__':
    main()