# FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation <br>

In [None]:
# Fetch the FuseNet sources and make the checkout the working directory.
# The `utils` / `model_utils` imports below are presumably resolved from this
# checkout, and the default input/output paths are relative to it.
!git clone https://github.com/mindflow-institue/FuseNet.git
%cd ./FuseNet

In [None]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as T

import cv2
import sys
import os
import numpy as np
import random
import glob
from matplotlib import pyplot as plt

from utils import read_image, dice_metric, xor_metric, hm_metric, create_mask, cross_entropy
from model_utils import Encoder, ProjectionHead, MixFFN_skip, CrossAttentionBlock

from einops import rearrange
from einops.layers.torch import Rearrange

In [None]:
use_cuda = torch.cuda.is_available()


def _str2bool(value):
    """Convert a command-line token such as 'true', 'no' or '0' into a bool.

    Without an explicit converter argparse keeps the raw string, and every
    non-empty string (including 'False') is truthy.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='FuseNet: Self-Supervised Dual-Path Network for Medical Image Segmentation')

# Model / optimisation hyper-parameters.
parser.add_argument('--nChannel', metavar='N', default=64, type=int,
                    help='number of channels')
parser.add_argument('--maxIter', metavar='T', default=50, type=int,
                    help='number of maximum iterations')
parser.add_argument('--minLabels', metavar='minL', default=3, type=int,
                    help='minimum number of labels')
parser.add_argument('--lr', metavar='LR', default=0.005, type=float,
                    help='learning rate')

# Input / output locations.
parser.add_argument('--input_path', metavar='INPUT', default='./input_images/',
                    help='input image folder path')
parser.add_argument('--save_output', metavar='SAVE', default=True, type=_str2bool,
                    help='whether to save output or not')
parser.add_argument('--output_path', metavar='OUTPUT', default='./output/',
                    help='output folder path')

# Weights of the individual loss terms.
parser.add_argument('--loss_ce_coef', metavar='CE', default=2.5, type=float,
                    help='Cross entropy loss weighting factor')
parser.add_argument('--loss_clip_coef', metavar='AT', default=0.5, type=float,
                    help='Clip loss weighting factor')
parser.add_argument('--loss_b_coef', metavar='Spatial', default=0.5, type=float,
                    help='Boundary loss weighting factor')

# An explicit empty argv keeps the notebook kernel's own command-line
# arguments from being parsed, so the defaults above are what is used.
args = parser.parse_args(args=[])

In [None]:
# Select the compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Results are only written when requested; create the target folder up front
# (an already existing folder is not an error).
if args.save_output:
    SAVE_PATH = args.output_path
    os.makedirs(SAVE_PATH, exist_ok=True)

# Loading Data

In [None]:
# Collect the input images and their ground-truth masks. Both lists are sorted
# so that entries at the same index are expected to belong together.
# os.path.join keeps the patterns valid whether or not --input_path ends with
# a path separator.
IMG_PATH = args.input_path
img_data = sorted(glob.glob(os.path.join(IMG_PATH, 'image', '*')))
lbl_data = sorted(glob.glob(os.path.join(IMG_PATH, 'GT', '*')))

In [None]:
# Notebook sanity check: show how many image paths and GT paths were found.
len(img_data), len(lbl_data)

# Model

In [None]:
class Model(nn.Module):
    """Dual-path network that compares an image with an augmented view of it.

    The forward pass encodes both views, projects them, builds a patch-level
    CLIP-style similarity matrix with soft targets, fuses the two paths with a
    cross-attention block (applied in both directions and summed) and derives
    an "edge" map from the disagreement between the attention map and a
    down/up-sampled copy of it.

    Args:
        input_dim (int): Dimension of the input data.
        image_embed (int): Dimension of the image embeddings.
        augmented_embed (int): Dimension of the augmented image embeddings.
        input_size (tuple): Tuple representing the input size of the images (height, width).
        temperature (float): Temperature parameter to scale CLIP matrix.
        dropout (float): Dropout rate applied in the projection heads.
        beta (int): Downsampling factor.
        alpha (int): Scaling factor applied to the main path in the cross-attention block.
    """
    def __init__(self, input_dim, image_embed, augmented_embed, input_size=(256, 256),
                 temperature=5.0, dropout=0.1, beta=16, alpha=3):
        super().__init__()

        input_H, input_W = input_size
        self.H = input_H

        # Honour the constructor arguments (they used to be overwritten by the
        # literals 16 and 3, which are still the defaults).
        self.beta = beta    # Downsampling factor
        self.alpha = alpha  # Main path scaling factor
        self.img_enc = Encoder(input_dim, image_embed)
        # NOTE(review): aug_enc is constructed but forward() encodes both views
        # with img_enc. Kept so parameter names/initialisation order stay the
        # same; confirm whether a separate augmented encoder was intended.
        self.aug_enc = Encoder(input_dim, image_embed)

        self.image_projection = ProjectionHead(embedding_dim=image_embed, projection_dim=image_embed, dropout=dropout)
        self.aug_projection = ProjectionHead(embedding_dim=augmented_embed, projection_dim=augmented_embed, dropout=dropout)
        self.temperature = temperature

        self.cross_attn = CrossAttentionBlock(in_channels=image_embed, key_channels=image_embed,
                                              value_channels=image_embed, height=input_H, width=input_W)

        # The embedding map is cut into square patches of side H // 8
        # (32 for a 256-pixel input); each patch is flattened and linearly
        # projected to `image_embed` features.
        self.patch_size = self.H // 8
        self.dim = image_embed
        patch_dim = self.dim * self.patch_size * self.patch_size

        self.to_patch_embedding_img = nn.Sequential(
            Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size),
            nn.Linear(patch_dim, self.dim))

        self.to_patch_embedding_aug = nn.Sequential(
            Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size),
            nn.Linear(patch_dim, self.dim))

        self.bn1 = nn.BatchNorm2d(image_embed)
        self.bn2 = nn.BatchNorm2d(image_embed)

    def forward(self, x, augmented_x):
        """Run both paths and return ``(edge, attn, clip_sim, clip_targets)``.

        Args:
            x: Image batch (assumed (B, input_dim, H, W) -- TODO confirm).
            augmented_x: Augmented view of ``x`` with the same shape.

        Returns:
            edge: Difference between the channel-argmax of ``attn`` and the
                channel-argmax of its down/up-sampled copy (integer indices).
            attn: Sum of the two cross-attention outputs.
            clip_sim: Patch similarity between the two views divided by the
                temperature.
            clip_targets: Softmax over the averaged self-similarities of the
                two views, scaled by the temperature.
        """
        # Extract feature representations of each view (shared encoder).
        img_f = self.img_enc(x)
        aug_f = self.img_enc(augmented_x)

        img_f = rearrange(img_f, 'b c h w -> b (h w) c')
        aug_f = rearrange(aug_f, 'b c h w -> b (h w) c')

        # Getting image and augmented image embeddings (with same dimension)
        img_e = self.image_projection(img_f)
        aug_e = self.aug_projection(aug_f)

        # Calculating CLIP: back to a spatial layout for batch norm, then
        # channels-last for the patch embedding.
        img_e_r = self.bn1(rearrange(img_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)
        aug_e_r = self.bn2(rearrange(aug_e, 'b (h w) c -> b c h w', h=self.H)).permute(0, 2, 3, 1)

        img_e_patch = self.to_patch_embedding_img(img_e_r)
        aug_e_patch = self.to_patch_embedding_aug(aug_e_r)

        # L2-normalise each patch embedding; F.normalize clamps the norm away
        # from zero, so an all-zero patch cannot produce NaNs.
        img_e_norm = F.normalize(img_e_patch, dim=-1)
        aug_e_norm = F.normalize(aug_e_patch, dim=-1)

        clip_sim = (img_e_norm @ aug_e_norm.mT) / self.temperature
        img_e_sim = img_e_norm @ img_e_norm.mT
        aug_e_sim = aug_e_norm @ aug_e_norm.mT
        clip_targets = F.softmax((img_e_sim + aug_e_sim) / 2 * self.temperature, dim=-1)

        # Cross attention in both directions; the main path is scaled by alpha
        # and the augmented path by a fixed 0.8.
        attn_1 = self.cross_attn(img_e * self.alpha, aug_e * 0.8)
        attn_2 = self.cross_attn(aug_e * 0.8, img_e * self.alpha)

        attn = attn_1 + attn_2

        # Edge map: compare the per-pixel argmax of the attention map with the
        # argmax of a blurred (down- then up-sampled) copy. The sizes come from
        # the attention map itself instead of a hard-coded 256, so both argmax
        # maps always have matching shapes.
        attn_h, attn_w = attn.shape[-2:]
        down_size = [max(1, attn_h // self.beta), max(1, attn_w // self.beta)]

        _, edge1 = torch.max(attn, 1)
        attn_down = torchvision.transforms.functional.resize(attn, down_size, antialias=True)
        attn_up = torchvision.transforms.functional.resize(attn_down, [attn_h, attn_w], antialias=True)
        _, edge2 = torch.max(attn_up, 1)
        edge = edge1 - edge2

        return edge, attn, clip_sim, clip_targets


# Training

In [None]:
# Side length handed to read_image() and used as the (square) model input_size.
img_size = 256

In [None]:
# Per-image self-supervised training: a fresh model is fitted to every image
# and its own augmented view, then used once to produce the segmentation.
for img_num, img_file in enumerate(img_data):

    ##### Read image #####
    image = read_image(img_file, img_size).to(device)

    ##### Load Model #####
    # The embedding width follows --nChannel (default 64) so the model and the
    # reshapes / colour mapping below cannot drift apart.
    model = Model(input_dim=3, image_embed=args.nChannel, augmented_embed=args.nChannel,
                  input_size=(img_size, img_size), temperature=5.0, dropout=0.1,
                  beta=16, alpha=3).to(device)
    model.train()

    ##### Settings #####
    # All-zero target for the boundary term.
    zero_img = torch.zeros(image.shape[2], image.shape[3], device=device)

    loss_ce = torch.nn.CrossEntropyLoss()
    loss_s = torch.nn.L1Loss()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # Random colour per label, used only to render the final mask.
    label_colours = np.random.randint(255, size=(128, 3))

    # Fixed brightness / hue shift followed by a Gaussian blur gives the
    # second view fed to the model.
    jitter = T.ColorJitter(brightness=[1.4, 1.4], hue=[-0.06, -0.06])
    aug_img = jitter(image)
    aug_img = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))(aug_img)
    aug_img = aug_img.to(device)

    ##### Training #####
    for batch_idx in range(args.maxIter):

        optimizer.zero_grad()
        edge, output, clip_logits, clip_targets = model(image, aug_img)

        ### Output
        output, clip_logits, clip_targets = output[0], clip_logits[0], clip_targets[0]
        # Flatten to (pixels, channels); the channel count is read from the
        # tensor rather than assumed to be args.nChannel * 2.
        n_out_channels = output.shape[0]
        output = output.permute(1, 2, 0).contiguous().view(-1, n_out_channels)

        # The network's own argmax serves as pseudo-label.
        _, target = torch.max(output, 1)

        ### Cross-entropy loss function
        loss_ce_value = args.loss_ce_coef * loss_ce(output, target)

        ### Boundary Loss
        # NOTE(review): `edge` is a difference of argmax index maps (see
        # Model.forward), so this term presumably carries no gradient and only
        # shifts the reported loss -- confirm.
        loss_edge = args.loss_b_coef * loss_s(edge[0], zero_img)

        ### CLIP loss
        aug_loss = cross_entropy(clip_logits, clip_targets, 'mean')
        img_loss = cross_entropy(clip_logits.T, clip_targets.T, 'mean')
        loss_clip = args.loss_clip_coef * ((img_loss + aug_loss) / 2.0)

        ### Optimization
        loss = loss_ce_value + loss_clip + loss_edge
        loss.backward()
        optimizer.step()

        nLabels = int(torch.unique(target).numel())
        print(batch_idx, '/', args.maxIter, '|', ' label num:', nLabels, ' | loss:', round(loss.item(), 4),
                '| CE:', round(loss_ce_value.item(), 4), '| CLIP:', round(loss_clip.item(), 4),
                '| B:', round(loss_edge.item(), 4))

        # Stop early once the labelling has collapsed to few enough classes,
        # but never before a handful of iterations have run.
        if nLabels <= args.minLabels and batch_idx>=5:
            print (f"Number of labels have reached {nLabels}")
            break

    ##### Evaluate #####
    # No gradients are needed for the final prediction.
    # NOTE(review): the model is still in train() mode here, as in the original
    # procedure; switching to eval() would change the BatchNorm/dropout
    # behaviour and therefore the masks.
    with torch.no_grad():
        edge, output, _, _ = model(image, aug_img)
    output = output[0]
    output = output.permute(1, 2, 0).contiguous().view(-1, output.shape[0])
    _, target = torch.max(output, 1)
    img_target = target.cpu().numpy()
    # Vectorised colour lookup (one fancy-indexing call instead of a Python
    # loop over every pixel).
    img_eval_output = label_colours[img_target % args.nChannel]
    img_eval_output = img_eval_output.reshape(image.shape[2], image.shape[3], image.shape[1]).astype(np.uint8)

    ##### Visualization #####
    image_np = image[0].permute(1, 2, 0).cpu().detach().numpy()
    aug_np = aug_img[0].permute(1, 2, 0).cpu().detach().numpy()

    fig, axes = plt.subplots(1, 4, figsize=(8, 8))
    axes[0].imshow(img_eval_output)
    # Channel order is reversed for display only.
    axes[1].imshow(image_np[..., ::-1])
    axes[2].imshow(aug_np[..., ::-1])
    axes[3].imshow(edge[0].cpu().detach().numpy())
    for axis, title in zip(axes, ('Prediction', 'Input Image', 'Augmented Image', 'Edge SR')):
        axis.set_title(title)
        axis.axis('off')
    plt.show()
    # Release the figure so figures do not pile up over a long image list.
    plt.close(fig)

    if args.save_output:
        # splitext keeps dots inside the file stem ("a.b.png" -> "a.b").
        name = os.path.splitext(os.path.basename(img_file))[0]
        cv2.imwrite(os.path.join(SAVE_PATH, 'FuseNet_mask_' + name + '.png'), img_eval_output)
        cv2.imwrite(os.path.join(SAVE_PATH, 'FuseNet_img_' + name + '.png'), image_np*255)
        cv2.imwrite(os.path.join(SAVE_PATH, 'FuseNet_aug_' + name + '.png'), aug_np*255)

    print('-------------------------------', '\n')