In [21]:
import torch 
import torch.nn as nn
import torch.functional as f
from torchsummary import summary

In [90]:
device = torch.device('mps')
torch.set_default_device(device=device)
torch.get_default_device()

device(type='mps', index=0)

# Loss functions
- For this code to reach optimal performance we need 2 losses that we will combine 
- One of them is Dice Loss and the other is Focal loss 
- Dice loss handles class imbalance 
- Focal loss does other stuff 


## Focal Loss 

In [91]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None, reduction='mean', task_type='mutli-class', num_classes=5):
        """
        Unified Focal Loss class for binary, multi-class, and multi-label classification tasks.
        :param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma
        :param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used.
        :param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum'
        :param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label'
        :param num_classes: Number of classes (only required for multi-class classification)
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.task_type = task_type
        self.num_classes = num_classes

        # Handle alpha for class balancing in multi-class tasks
        if task_type == 'multi-class' and alpha is not None and isinstance(alpha, (list, torch.Tensor)):
            assert num_classes is not None, "num_classes must be specified for multi-class classification"
            if isinstance(alpha, list):
                self.alpha = torch.Tensor(alpha)
            else:
                self.alpha = alpha

    def forward(self, inputs, targets):
        """
        Forward pass to compute the Focal Loss based on the specified task type.
        :param inputs: Predictions (logits) from the model.
                       Shape:
                         - binary/multi-label: (batch_size, num_classes)
                         - multi-class: (batch_size, num_classes)
        :param targets: Ground truth labels.
                        Shape:
                         - binary: (batch_size,)
                         - multi-label: (batch_size, num_classes)
                         - multi-class: (batch_size,)
        """
        if self.task_type == 'binary':
            return self.binary_focal_loss(inputs, targets)
        elif self.task_type == 'multi-class':
            return self.multi_class_focal_loss(inputs, targets)
        elif self.task_type == 'multi-label':
            return self.multi_label_focal_loss(inputs, targets)
        else:
            raise ValueError(
                f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'.")

    def binary_focal_loss(self, inputs, targets):
        """ Focal loss for binary classification. """
        probs = torch.sigmoid(inputs)
        targets = targets.float()

        # Compute binary cross entropy
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # Compute focal weight
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        # Apply focal loss weighting
        loss = focal_weight * bce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

    def multi_class_focal_loss(self, inputs, targets):
        """ Focal loss for multi-class classification. """
        if self.alpha is not None:
            alpha = self.alpha.to(inputs.device)

        # Convert logits to probabilities with softmax
        probs = F.softmax(inputs, dim=1)

        # One-hot encode the targets
        targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()

        # Compute cross-entropy for each class
        ce_loss = -targets_one_hot * torch.log(probs)

        # Compute focal weight
        p_t = torch.sum(probs * targets_one_hot, dim=1)  # p_t for each sample
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided (per-class weighting)
        if self.alpha is not None:
            alpha_t = alpha.gather(0, targets)
            ce_loss = alpha_t.unsqueeze(1) * ce_loss

        # Apply focal loss weight
        loss = focal_weight.unsqueeze(1) * ce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

    def multi_label_focal_loss(self, inputs, targets):
        """ Focal loss for multi-label classification. """
        probs = torch.sigmoid(inputs)

        # Compute binary cross entropy
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # Compute focal weight
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        # Apply focal loss weight
        loss = focal_weight * bce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

## Dice Loss 

In [92]:
def DiceLoss(pred, target):
    """This definition generalize to real valued pred and target vector.
    This should be differentiable.
    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    """

    smooth = 1.

    # have to use contiguous since they may from a torch.view op
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()

    A_sum = torch.sum(tflat * iflat)
    B_sum = torch.sum(tflat * tflat)

    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) )

# Modules 
- We create all the necessary modules for unet to work 
- these include an Attention block (class) to be used in between the big decoder functions 
- a self attention block (class) that will be used to calculate the self attention 
- This will be used only for the final layer in the unet
- a double convolutional block - consisting of 2 conv encoders 
- an output conv block, that takes in the n_dim, h, w image and converts it to n_class, h, w image (final conv layer) 
- a down class that dictates the flow of the downward convolution (in this case the encoder section)
- an up class that dictates the flow of the upward convolution (or conv transpose) (in this case the decoder section)  

## Class Double conv 
- A class that has the double convolutional section that is instrumental to the unet encoder and decoder architecture 
- we set bias to false because batchnorm has its own bias so double the bias is useless
- `nn.Conv2d(in, out, kernel = 3, padding = 1, bias=False)` lets say the input is in x h x w 
- The output would be out x h x w (in this case)... 
- The general formula = `(h - kernel + 2p) / stride + 1 = h when stride = 1 (default)` 
- In this case the output and input shape are the same but the number of filters keeps changing. 
- Usually we want to reduce the filter count for every conv block  

In [86]:
class DoubleConv(nn.Module): 
    def __init__(self, in_channel, out_channel, mid_channels = None, bias = False):
        super().__init__()
        if not mid_channels: 
            mid_channels = out_channel
        self.doubleconv = nn.Sequential(
            # first convolutional layer
            nn.Conv2d(in_channels= in_channel, out_channels= mid_channels, padding= 1, kernel_size= 3, bias= bias),
            nn.BatchNorm2d(mid_channels), 
            nn.ReLU(inplace=True), 
            
            # Second convolutional layer
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channel, padding=1, kernel_size=3, bias= bias),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        ) 
        
    def forward(self, x):
        return self.doubleconv(x)
    

# Class Down 
- This is the class used to facilitate the encoder class
- At the moment, after the encoder we only need to run the maxpool2d after it  
- So we run the doubleconv layer in the down class and then run the maxpool2D after this 
- Return the maxpool output for the next down layer 
- Return the encoder output for the attention layer 

In [31]:
class Down(nn.Module): 
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.maxpool_layer = nn.MaxPool2d(kernel_size= 2)
        
        
    def forward(self, x):
        feat = DoubleConv(in_channel=self.in_channel, out_channel=self.out_channel)(x)
        return self.maxpool_layer(feat) # return the maxpool output and the encoder output  
    

## Self Attention block 
- Self attention of each encoder level 
- A lot of computation required 

### Query value and key calculations 
- First to calculate attention 
- Attention = query x key 
- Our query is of the shape -> `batch_size, Channels, W, H` 
- Flatten it to -> `Batch_size, Channels, H*W` (channels is in_dim // 8)  
- Key is also of the shape -> `batch_size, Channels, H*W` (Channels is in_dim // 8) 
- Flatten Key -> `batch_size, channels, H*W`
- Torch.batch matrix multiplication -> `Query * key` = attention at batch level 
- To perform this batch level multiplication convert query dim to Batch_size, N, Channels 
- Attention -> `Batch_size, N, N`   
- REMEMBER THAT ATTENTION NEEDS A SOFTMAX AT THE END
- Now sum (attention * value) is what we need (this is essentially a V * A^T
- Value -> `B x 64 x N` Attention permute to -> `B x N x N` (This is essentially a transpose)    
- In matrix form that becomes `torch.bmm(value, attention.T)` (The T is only for N x N not the batch)
- You can use `nn.permute(0,2,1)` to do that

In [55]:
class SelfAttention(nn.Module): 
    def __init__(self, in_dim):
        super().__init__()
        self.channel = in_dim
        
        self.query_conv = nn.Conv2d(in_channels= in_dim, out_channels= in_dim // 8, kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels= in_dim, out_channels= in_dim // 8, kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels= in_dim, out_channels= in_dim, kernel_size= 1)
        
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, width, height = x.size()
        # this is essentially query transpose without touching the batch 
        query_projection = self.query_conv(x).view(batch_size, -1, width* height).permute(0,2,1)
        key_projection = self.key_conv(x).view(batch_size, -1, width* height)
        
        # energy calculation -> energy is just attention before softmax
        energy = torch.bmm(query_projection, key_projection)
        attention = torch.softmax(energy, dim= -1)
        
        # Calculate Value projection 
        value_projection = self.value_conv(x).view(batch_size, -1, width* height)
        
        # calculate self attention mask 
        out = torch.bmm(value_projection, attention.permute(0,2,1))
        
        # reshape mask to fit width, height 
        out = out.view(batch_size, C, width, height)
        
        # gamma stuff 
        out = self.gamma * out + x
        
        return out, attention
        

## Attention Block 
- gate (g) = decoder feature - f_g channel count of decoder  
- x (skip) = encoder level - f_x channel count of skip level encoder 
- F_int = reduced channel count which is f_x // 2
- weight of g => a conv block f_g to f_int with kernel, stride = 1, padding =0, bias = False (because batchnorm is true)
- Same for x
- Add the conv x and conv g with a relu  
- a conv layer f_int to 1 
- batch norm then sigmoid 

In [56]:
class Attention_block(nn.Module): 
    def __init__(self, encoder_channels, decoder_channels, intermediate_channels = None):
        super().__init__()
        if not intermediate_channels: 
            intermediate_channels = decoder_channels // 2
        
        self.W_gate = nn.Sequential(
            nn.Conv2d(in_channels= decoder_channels, out_channels= intermediate_channels, kernel_size= 1, stride=1, padding= 0),
            nn.BatchNorm2d(intermediate_channels),
        )
        
        self.W_encoder = nn.Sequential(
            nn.Conv2d(in_channels= encoder_channels, out_channels= intermediate_channels, kernel_size= 1, stride= 1, padding= 0),
            nn.BatchNorm2d(intermediate_channels),
        )
        
        self.psi = nn.Sequential(
            nn.Conv2d(in_channels= intermediate_channels, out_channels= 1, kernel_size= 1, stride = 1, padding= 0),
            nn.Sigmoid(),
        )
        
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, enc, dec):
        dec_conv = self.W_gate(dec)
        enc_conv = self.W_encoder(enc)
        
        psi = self.relu(dec_conv + enc_conv)
        psi = self.psi(psi)
        
        return enc * psi


## Up function for self attention 
- This is the class used to facilitate the decoder class 
- The decoder class uses `nn.Convtranspose2D(inchannel, outchannel, kernel_size = 2, stride = 2)` which ensures that it goes from in x h x w -> out x 128 x 128 
- There are 2 versions, one with x1 and x2. We experiment 
- Then we need to concat the encoder and the decoder output 
- Uses SELF attention modules 
- Can be used for ablation studies
- Cannot be used for the hybrid model 

In [57]:
class self_attention_Up(nn.Module): 
    def __init__(self, in_channel, out_channel, bilinear= False):
        super().__init__()
        if bilinear: 
            self.up = nn.Upsample(scale_factor= 2, mode= 'bilinear', align_corners= True)
            self.conv = DoubleConv(in_channel= in_channel, out_channel= out_channel, mid_channels=in_channel // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels= in_channel, out_channels= in_channel // 2, kernel_size= 2, 
               stride= 2)
            self.conv = DoubleConv(in_channel= in_channel, out_channel= out_channel) # NOTE here you pass the in channel as the original in channel which might work if decoder has like 64 channel and encoder has 32 channel 
        
    def forward(self, decoder, encoder):
        # here encoder is the output of the encoder at that level of the unet 
        # decoder is the output of that we are supposed to feed into the decoder,
        # usually from the previous decoder segment or the code layer
        decoder = self.up(decoder)
        decoder = torch.cat([decoder, encoder], dim=1)
        
        return self.conv(decoder) 
        

## Up function for attenion block unet 
- This is the up function for the attention block unet 
- The reason for writing this seperately is because the previous code concated the attention masked decoder 
- Directly Considering bilinear = True

In [79]:
class Attention_block_up_class(nn.Module): 
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode= 'bilinear', align_corners= True),
            DoubleConv(in_channel=in_channel, out_channel=out_channel),
        )
    
    def forward(self, x):
        x = self.up(x)
        return x
        

## Out convolutional Layer 
- THe final convolutional layer 

In [58]:
class outConv(nn.Module): 
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv2d(in_channels= in_channel, out_channels= out_channel, kernel_size=1)
        
    def forward(self, x):
        return self.conv(x)

# UNET class 
- Use the above functions to code a unet with attention 
- Down class for the encoder
- Up class for the decoder
- attention block class to calculate attention 
- worry about self attention later  

### Encoder Class
- Unet has 4 encoder channels, 4 double conv blocks that is (in our code the class called down)
- first block = 3 -> 64 (3 channels to 64 channels {64 filters})
- Second block = 64 -> 128
- Third block = 128 -> 256 
- Fourth block = 256 -> 512
- Fifth block = 512 -> 1024 `CODE BLOCK`

### Self Attention block - NOT IMPLEMENTED YET 
- Here we use the self attention block to calculate self attention at the code layer  
- We can technically expand self attention to all the layers but it is extremely memory intensive and only needed for like really clean datasets 

### Decoder Layer 
- This layer is also known as the class up in our code. 
- Takes the encoded image and decodes. 
- While it decodes we will feed it attention masked skip encoder values 
- This is done using the attention block function, returns the mask 
- Each block reduces the filter count 
- upconv1 => 512 in channels -> 256 out channels 
- upconv2 => 256 in channels -> 128 out channels
- upconv3 => 128 in channels -> 64 out channels 
- upconv4 => 64 out channels -> `n_classes` out channels 


#### NOTE 
taking an argument for the mid class layers might make this code more changeable 

In [67]:
class Self_Attention_Unet(nn.Module): 
    def __init__(
            self,
            image_channels,
            n_classes,
            mid_layers = [64,128, 256, 512, 1024], 
            bilinear= True,
            ):
        super().__init__()
        
        
        
        # Encoder section of the network
        factor = 2 if bilinear else 1
        
        self.down1 = Down(image_channels, mid_layers[0]) # first layer -> input channel = 3, output channel = 64
        self.down2 = Down(mid_layers[0], mid_layers[1]) # second layer -> input channel = 64, output channel = 128
        self.down3 = Down(mid_layers[1], mid_layers[2])
        self.down4 = Down(mid_layers[2], mid_layers[3])
        
        # calculate all the self attention values 
        self.attention1 = SelfAttention(mid_layers[0])
        self.attention2 = SelfAttention(mid_layers[1])
        self.attention3 = SelfAttention(mid_layers[2])
        self.attention4 = SelfAttention(mid_layers[3])
        
        
        # code block 
        self.down_code = Down(mid_layers[3], mid_layers[4] // factor)
        
        # self attention 
        # ADD SELF ATTENTION for code layer 
        
        
        
        # decoder block
        # if you use bilinear section, the conv2DTranspose is not invoked in favor of upsampling 
        # this means that you need the conv block to reduce the number of filters 
        # here is where we introduce a term called factor which is equal to the scaling factor 
        # in this case it is 2 
        # that is the factor variable defined in the beginning
        
        self.up1 = self_attention_Up(mid_layers[4], mid_layers[3] // factor, bilinear)
        self.up2 = self_attention_Up(mid_layers[3], mid_layers[2] // factor, bilinear)
        self.up3 = self_attention_Up(mid_layers[2], mid_layers[1] // factor, bilinear)
        self.up4 = self_attention_Up(mid_layers[1], mid_layers[0] // factor, bilinear)
        
        # outputlayers
        self.outc = outConv(mid_layers[0] // factor, n_classes)
        
    def forward(self, x):
        # you cant create a seperate encoder, decoder function because you need the x1, x2 and so on values for attention calculation   
        # encoder calculations 
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down_code(x4)
        
        # x5 is the code layer
        
        # calculate attention
        
        v1, _ = self.attention1(x1) 
        v2, _ = self.attention2(x2)
        v3, _ = self.attention3(x3)
        v4, _ = self.attention4(x4)
        
    
        # Decoding 
        x = self.up1(x5, v4) # code layer and the final layer of attention 
        x = self.up2(x, v3) # decoded layer 1 (which is x) and 2nd to final layer of attention
        x = self.up3(x, v2)
        x = self.up4(x, v1)
        
        # calculate output
        logits = self.outc(x)
        return logits
        
        

In [72]:
model = Self_Attention_Unet(3,5)

In [69]:
model

Self_Attention_Unet(
  (down1): Down(
    (maxpool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down2): Down(
    (maxpool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down3): Down(
    (maxpool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down4): Down(
    (maxpool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (attention1): SelfAttention(
    (query_conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
    (key_conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
    (value_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  )
  (attention2): SelfAttention(
    (query_conv): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
    (key_conv): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
    (value_conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
  )
  (attention3): SelfAttention(
    

In [87]:
class Unet_with_attention_block(nn.Module): 
    def __init__(
            self, 
            image_channels, 
            n_classes, 
            mid_layers = [64, 128, 256, 512, 1024],
            bilinear = True      
            ):
        super().__init__()
        # we dont need factor here because of how the attention layer is coded 
    
        # Encoder 
        self.Conv1 = Down(in_channel= image_channels, out_channel= mid_layers[0])
        self.Conv2 = Down(in_channel= mid_layers[0], out_channel= mid_layers[1])
        self.Conv3 = Down(in_channel= mid_layers[1], out_channel= mid_layers[2])
        self.Conv4 = Down(in_channel= mid_layers[2], out_channel= mid_layers[3])
        self.Conv_code = Down(in_channel= mid_layers[3], out_channel= mid_layers[4])
        
        # decoder 
        # Upsampling including attention, then convolutional layer to reduce the size 
        # this is one block 
        self.up5 = Attention_block_up_class(in_channel= mid_layers[4], out_channel= mid_layers[3])
        self.att5 = Attention_block(encoder_channels= mid_layers[3], decoder_channels= mid_layers[3])
        self.up_conv5 = DoubleConv(in_channel= mid_layers[4], out_channel= mid_layers[3])
        
        # second upsampling 
        self.up4 = Attention_block_up_class(in_channel= mid_layers[3], out_channel= mid_layers[2])
        self.att4 = Attention_block(encoder_channels= mid_layers[2], decoder_channels=mid_layers[2])
        self.up_conv4 = DoubleConv(in_channel= mid_layers[3], out_channel=mid_layers[2])
         
        # Third upsampling 
        self.up3 = Attention_block_up_class(in_channel= mid_layers[2], out_channel= mid_layers[1])
        self.att3 = Attention_block(encoder_channels= mid_layers[1], decoder_channels=mid_layers[1])
        self.up_conv3 = DoubleConv(in_channel= mid_layers[2], out_channel= mid_layers[1])
        
        # Fourth upsampling 
        self.up2 = Attention_block_up_class(in_channel= mid_layers[1], out_channel= mid_layers[0])
        self.att2 = Attention_block(encoder_channels= mid_layers[0], decoder_channels= mid_layers[0])
        self.up_conv2 = DoubleConv(in_channel= mid_layers[1], out_channel= mid_layers[0])
        
        # final output convolutional block 
        self.outc= outConv(mid_layers[0], n_classes)
        
        
    def forward(self, x):
        # Encoder traversal
        x1 = self.Conv1(x)
        x2 = self.Conv2(x1)
        x3 = self.Conv3(x2)
        x4 = self.Conv4(x3)
        x5 = self.Conv_code(x4)
        
        # Decoding and attention concat
        # Upsample x5, attention w respect to x4 and d5 
        # concat 
        
        # first block
        d5 = self.up5(x5)
        x4 = self.att5(d5, x4) # calculate attention block 
        d5 = torch.cat((d5, x4), dim=1)
        d5 = self.up_conv5(d5)
        
        # second block 
        d4 = self.up4(d5)
        x3 = self.att4(d4, x3) # calculate attention block 
        d4 = torch.cat((d4, x3), dim=1)
        d4 = self.up_conv4(d4)

        # third block
        d3 = self.up3(d4)
        x2 = self.att3(d3, x2) # calculate attention block 
        d3 = torch.cat((d3, x2), dim=1)
        d3 = self.up_conv3(d3)
        
        # second block 
        d2 = self.up2(d3)
        x1 = self.att2(d2, x1) # calculate attention block 
        d2 = torch.cat((d2, x1), dim=1)
        d2 = self.up_conv2(d2)

        logits = self.outc(d2)
        
        return logits 
        


In [88]:
model = Unet_with_attention_block(image_channels=6, n_classes=5)
model 

Unet_with_attention_block(
  (Conv1): Down(
    (maxpool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (Conv2): Down(
    (maxpool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (Conv3): Down(
    (maxpool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (Conv4): Down(
    (maxpool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (Conv_code): Down(
    (maxpool_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (up5): Attention_block_up_class(
    (up): Sequential(
      (0): Upsample(scale_factor=2.0, mode='bilinear')
      (1): DoubleConv(
        (doubleconv): Sequential(
          (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inp

In [89]:
summary(model, input_size=(6, 256,256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         MaxPool2d-1         [-1, 64, 128, 128]               0
              Down-2         [-1, 64, 128, 128]               0
         MaxPool2d-3          [-1, 128, 64, 64]               0
              Down-4          [-1, 128, 64, 64]               0
         MaxPool2d-5          [-1, 256, 32, 32]               0
              Down-6          [-1, 256, 32, 32]               0
         MaxPool2d-7          [-1, 512, 16, 16]               0
              Down-8          [-1, 512, 16, 16]               0
         MaxPool2d-9           [-1, 1024, 8, 8]               0
             Down-10           [-1, 1024, 8, 8]               0
         Upsample-11         [-1, 1024, 16, 16]               0
           Conv2d-12          [-1, 512, 16, 16]       4,718,592
      BatchNorm2d-13          [-1, 512, 16, 16]           1,024
             ReLU-14          [-1, 512,