In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
import math


import os
import glob
from PIL import Image
import numpy as np
import pandas as pd
from natsort import natsorted
from tqdm import tqdm

import matplotlib.pyplot as plt

# Complex Utils

In [None]:
"""
This part of the code is taken from the following GitHub Repo: 
https://github.com/saurabhya/FCCNs

"""

def apply_complex(fr, fi, input, dtype= torch.complex64):
    """Evaluate a complex operator assembled from two real-valued operators.

    For ``input = a + i*b`` this returns
    ``(fr(a) - fi(b)) + i*(fr(b) + fi(a))``: the product rule for a complex
    kernel whose real part is applied by ``fr`` and whose imaginary part is
    applied by ``fi`` (exact when both callables are linear, e.g. convolutions).

    Args:
        fr: callable applied as the real part of the operator.
        fi: callable applied as the imaginary part of the operator.
        input: complex tensor exposing ``.real`` and ``.imag``.
        dtype: unused; retained so existing call signatures keep working.

    Returns:
        Complex tensor combining the four real evaluations.
    """
    a, b = input.real, input.imag
    real_out = fr(a) - fi(b)
    imag_out = fr(b) + fi(a)
    return real_out + 1j * imag_out



class ComplexConv2d(nn.Module):
    """2-D convolution on complex tensors built from two real ``nn.Conv2d`` layers.

    ``conv_real`` and ``conv_imag`` hold the real and imaginary parts of the
    complex kernel; ``apply_complex`` combines them so that, for x = a + ib,
    y = (conv_real(a) - conv_imag(b)) + i*(conv_real(b) + conv_imag(a)).

    Args:
        in_channels, out_channels, kernel_size, stride, bias: forwarded to
            both ``nn.Conv2d`` layers.
        padding: forwarded to both layers. Note the default is 1 here
            (``nn.Conv2d`` defaults to 0).
        complex_axis: accepted for API compatibility; unused.
        device: device for the parameters (``None`` = PyTorch default).
        dtype: parameter dtype (``None`` = PyTorch default). A complex dtype
            is mapped to its real counterpart because the component
            convolutions are real-valued.
    """

    def __init__(self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=1,
        bias: bool = False,
        complex_axis= 1,
        device= None,
        dtype= None
        ) -> None:
        super().__init__()

        if dtype is not None and dtype.is_complex:
            # e.g. complex64 -> float32, complex128 -> float64
            dtype = torch.empty(0, dtype=dtype).real.dtype
        # Forward device/dtype the same way PyTorch's built-in layers do;
        # None leaves the library defaults untouched.
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.conv_real = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                   padding=padding, bias=bias, **factory_kwargs)
        self.conv_imag = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                   padding=padding, bias=bias, **factory_kwargs)

    def forward(self, x):
        """Convolve complex ``x``; raises ValueError if ``x`` is not complex."""
        if not x.is_complex():
            raise ValueError(f"Input should be a complex tensor. Got {x.dtype}")

        return apply_complex(self.conv_real, self.conv_imag, x)
    

class ComplexTranspose2d(nn.Module):
    """2-D transposed convolution on complex tensors.

    Two real ``nn.ConvTranspose2d`` layers hold the real (``trans_conv_real``)
    and imaginary (``trans_conv_imag``) parts of the complex kernel and are
    combined by ``apply_complex``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding,
        output_padding, bias: forwarded to both ``nn.ConvTranspose2d`` layers.
        device: device for the parameters (``None`` = PyTorch default).
        dtype: parameter dtype (``None`` = PyTorch default). A complex dtype
            is mapped to its real counterpart because the component layers
            are real-valued.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        bias: bool= False,
        device= None,
        dtype= None
    ):
        super().__init__()

        if dtype is not None and dtype.is_complex:
            # e.g. complex64 -> float32, complex128 -> float64
            dtype = torch.empty(0, dtype=dtype).real.dtype
        # Forward device/dtype the same way PyTorch's built-in layers do;
        # None leaves the library defaults untouched.
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.trans_conv_real = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                                  output_padding=output_padding, bias=bias, **factory_kwargs)
        self.trans_conv_imag = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                                  output_padding=output_padding, bias=bias, **factory_kwargs)

    def forward(self, x):
        """Apply the transposed convolution to complex ``x``; raises ValueError otherwise."""
        if not x.is_complex():
            raise ValueError(f"Input should be a complex tensor. Got {x.dtype}")

        return apply_complex(self.trans_conv_real, self.trans_conv_imag, x)
    

class ComplexMaxPool2d(nn.Module):
    """Max-pool the real and imaginary parts of a complex tensor independently.

    Because the two parts are pooled separately, the real and imaginary parts
    of one output element may come from different input positions.
    ``return_indices=True`` is not supported by ``forward``: ``nn.MaxPool2d``
    then returns (values, indices) tuples, which the sum below cannot combine.
    """

    def __init__(self, kernel_size, stride= 2, padding= 0, dilation=(1,1), return_indices= False, ceil_mode= False):
        super().__init__()

        # Keep the configuration visible as attributes.
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.ceil_mode = ceil_mode
        self.return_indices = return_indices

        # One shared real pooling layer, applied to each component in turn.
        self.max_pool = nn.MaxPool2d(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            return_indices=return_indices,
            ceil_mode=ceil_mode,
        )

    def forward(self, x):
        """Pool complex ``x``; raises ValueError if ``x`` is not complex."""
        if x.is_complex():
            pooled_re = self.max_pool(x.real)
            pooled_im = self.max_pool(x.imag)
            return pooled_re + 1j * pooled_im
        raise ValueError(f"Input should be a complex tensor, Got {x.dtype}")
    

class CReLU(nn.Module):
    """Split-complex ReLU with a learnable scalar scale per component.

    Computes ``relu(a * Re(x)) + i * relu(b * Im(x))`` where ``a``
    (``real_coeff``) and ``b`` (``imag_coeff``) are parameters initialised to 1.
    """

    def __init__(self):
        super().__init__()
        self.real_coeff = nn.Parameter(torch.tensor(1.0))
        self.imag_coeff = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # The source also carried an alternative, marked "for normal use",
        # that applies ReLU before and after scaling:
        #     relu(a * relu(Re(x))) + i * relu(b * relu(Im(x)))
        scaled_re = self.real_coeff * x.real
        scaled_im = self.imag_coeff * x.imag
        return F.relu(scaled_re) + 1j * F.relu(scaled_im)
    


class Naive_ComplexSigmoid(nn.Module):
    """Split-complex sigmoid: ``sigmoid(Re(x)) + i * sigmoid(Im(x))``.

    Each component is squashed independently; there are no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        squashed_re = torch.sigmoid(x.real)
        squashed_im = torch.sigmoid(x.imag)
        return squashed_re + 1j * squashed_im
    




class ComplexBatchNorm2d(torch.nn.Module):
    """Batch normalization for complex tensors via 2x2 covariance whitening.

    Per channel, the (real, imag) pair is centred and then multiplied by the
    inverse square root of its 2x2 covariance matrix
    V = [[Vrr, Vri], [Vri, Vii]]. When ``affine`` is True a learnable
    symmetric matrix W = [[Wrr, Wri], [Wri, Wii]] and bias (Br, Bi) are applied
    afterwards: y = W V^-1/2 (x - M) + B.

    Args:
        num_features: number of channels (size of dim 1 of the input).
        eps: added to the diagonal of V for numerical stability.
        momentum: running-statistics update factor; ``None`` selects a
            cumulative moving average.
        affine: whether to learn W and B.
        track_running_stats: whether to keep running mean/covariance buffers,
            which are used for normalization in eval mode.
        complex_axis: stored for API compatibility; ``forward`` always treats
            dim 1 as the channel axis.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
            track_running_stats=True, complex_axis=1):
        super().__init__()
        self.num_features        = num_features
        self.eps                 = eps
        self.momentum            = momentum
        self.affine              = affine
        self.track_running_stats = track_running_stats

        self.complex_axis = complex_axis

        if self.affine:
            # Per-channel entries of W and B; initial values come from reset_parameters().
            self.Wrr = torch.nn.Parameter(torch.empty(self.num_features))
            self.Wri = torch.nn.Parameter(torch.empty(self.num_features))
            self.Wii = torch.nn.Parameter(torch.empty(self.num_features))
            self.Br  = torch.nn.Parameter(torch.empty(self.num_features))
            self.Bi  = torch.nn.Parameter(torch.empty(self.num_features))
        else:
            self.register_parameter('Wrr', None)
            self.register_parameter('Wri', None)
            self.register_parameter('Wii', None)
            self.register_parameter('Br',  None)
            self.register_parameter('Bi',  None)

        if self.track_running_stats:
            # Running mean (RMr, RMi) and covariance (RVrr, RVri, RVii), initialised
            # to zero mean / identity covariance.
            self.register_buffer('RMr',  torch.zeros(self.num_features))
            self.register_buffer('RMi',  torch.zeros(self.num_features))
            self.register_buffer('RVrr', torch.ones (self.num_features))
            self.register_buffer('RVri', torch.zeros(self.num_features))
            self.register_buffer('RVii', torch.ones (self.num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('RMr',                 None)
            self.register_parameter('RMi',                 None)
            self.register_parameter('RVrr',                None)
            self.register_parameter('RVri',                None)
            self.register_parameter('RVii',                None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0, running covariance to identity, batch count to 0."""
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running stats and (if affine) set W ~ identity-ish, B = 0."""
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            # With Wrr = Wii = 1 and |Wri| < 1, det(W) > 0, so W is positive-definite.
            self.Wri.data.uniform_(-.9, +.9)
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        assert(xr.shape == xi.shape)
        assert(xr.size(1) == self.num_features)

    def forward(self, inputs):
        """Normalize a complex tensor whose channels lie on dim 1.

        In training mode (or whenever running stats are not tracked) batch
        statistics are used and, if tracked, folded into the running buffers;
        otherwise the running buffers are used.
        """
        xr, xi = inputs.real, inputs.imag
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # True:  normalize with batch statistics (and update running stats if tracked).
        # False: normalize with running statistics, ignore batch statistics.
        training = self.training or not self.track_running_stats
        redux = [i for i in reversed(range(xr.dim())) if i != 1]  # every dim except channels
        vdim  = [1] * xr.dim()  # broadcast shape for per-channel values
        vdim[1] = xr.size(1)

        # Mean M computation and centering.
        if training:
            Mr, Mi = xr, xi
            for d in redux:
                Mr = Mr.mean(d, keepdim=True)
                Mi = Mi.mean(d, keepdim=True)
            if self.track_running_stats:
                # Detach: an in-place lerp_ from a grad-requiring tensor would
                # attach the buffers to autograd (requires_grad=True) and grow
                # their graph history on every training step.
                self.RMr.lerp_(Mr.detach().reshape(-1), exponential_average_factor)
                self.RMi.lerp_(Mi.detach().reshape(-1), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi

        # Covariance matrix V computation (biased, over the reduced dims).
        if training:
            Vrr = xr * xr
            Vri = xr * xi
            Vii = xi * xi
            for d in redux:
                Vrr = Vrr.mean(d, keepdim=True)
                Vri = Vri.mean(d, keepdim=True)
                Vii = Vii.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RVrr.lerp_(Vrr.detach().reshape(-1), exponential_average_factor)
                self.RVri.lerp_(Vri.detach().reshape(-1), exponential_average_factor)
                self.RVii.lerp_(Vii.detach().reshape(-1), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        # Tikhonov regularisation of the diagonal.
        Vrr = Vrr + self.eps
        Vii = Vii + self.eps

        # Inverse square root U = V^-1/2 of the 2x2 matrix in closed form:
        # sqrt(V) = (V + s*I) / t with s = sqrt(det V), t = sqrt(trace V + 2s)
        # (https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix),
        # then inverted with the 2x2 adjugate formula.
        tau   = Vrr + Vii
        delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)  # det V
        s     = delta.sqrt()
        t     = (tau + 2 * s).sqrt()

        rst   = (s * t).reciprocal()
        Urr   = (s + Vii) * rst
        Uii   = (s + Vrr) * rst
        Uri   = (  - Vri) * rst

        # Combined weights Z = W U (or Z = U without affine), then y = Z x + B:
        # y = [Wrr Wri][Urr Uri] [xr] + [Br]
        #     [Wri Wii][Uri Uii] [xi]   [Bi]
        if self.affine:
            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        return yr + 1j * yi

# Complex UNET

In [None]:
"""
Implemented architecture following: https://www.youtube.com/watch?v=IHq1t7NxS8k

"""


class DoubleConv(nn.Module):
    """Two consecutive (ComplexConv2d -> ComplexBatchNorm2d -> CReLU) stages.

    Both convolutions are 3x3 without bias; with ComplexConv2d's default
    stride=1 and padding=1 the spatial size is unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                ComplexConv2d(stage_in, out_channels, kernel_size=3, bias=False),
                ComplexBatchNorm2d(out_channels),
                CReLU(),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
    

class UNET(nn.Module):
    """Complex-valued U-Net.

    Encoder: one DoubleConv per entry of ``features``, each followed by 2x2
    complex max-pooling. Bottleneck: DoubleConv doubling the deepest width.
    Decoder (deepest level first): stride-2 complex transposed convolution,
    concatenation with the matching encoder output along dim 1, DoubleConv.
    A final 1x1 complex convolution maps to ``out_channels``.

    Args:
        in_channels: channels of the complex input.
        out_channels: channels of the complex output (e.g. number of classes).
        features: encoder widths, shallowest first. Any sequence works; the
            default is a tuple to avoid a shared mutable default argument.
    """

    def __init__(self, in_channels=3, out_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Registration order (ups, downs, pool, ...) is kept as before so the
        # order of model.parameters() is unchanged.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = ComplexMaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # self.ups alternates [upsample, DoubleConv] per decoder level.
        for feature in reversed(features):
            self.ups.append(ComplexTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = ComplexConv2d(features[0], out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Run the U-Net on complex ``x`` (channels on dim 1).

        When H or W is not divisible by 2**len(features), pooling floors the
        size and an upsampled map ends up smaller than its skip connection;
        it is then resized to the skip connection's size before concatenation
        instead of making ``torch.cat`` fail.
        """
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest first, matching self.ups

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            if x.shape[-2:] != skip_connection.shape[-2:]:
                x = self._resize_complex(x, skip_connection.shape[-2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

    @staticmethod
    def _resize_complex(x, size):
        """Bilinearly resize the real and imaginary parts of ``x`` to spatial ``size``."""
        real = F.interpolate(x.real, size=tuple(size), mode='bilinear', align_corners=False)
        imag = F.interpolate(x.imag, size=tuple(size), mode='bilinear', align_corners=False)
        return real + 1j * imag

# Dataset 

In [None]:
class CamVid_Simple(Dataset):
    """Dataset over pre-processed samples stored as individual ``.pth`` files.

    Each file is loaded with ``torch.load`` and must contain a mapping with
    ``'img'`` and ``'mask'`` entries; ``__getitem__`` returns ``(img, mask)``.

    Args:
        path: directory containing the ``*.pth`` files (non-recursive).
    """

    def __init__(self, path):
        super().__init__()
        self.path = path
        # glob returns files in arbitrary, filesystem-dependent order; sort so
        # that index -> file is reproducible (matters for unshuffled evaluation).
        # glob.escape keeps directory names containing '[', '*' or '?' literal.
        self.files = sorted(glob.glob(os.path.join(glob.escape(path), '*.pth')))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # NOTE(review): torch.load unpickles the file; only use with trusted data.
        data = torch.load(self.files[index])
        return data['img'], data['mask']

# Loss

In [None]:
"""
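PyTorch implementation of the first loss in the write-up: cross-entropy on the logit magnitudes |z| plus a weighted phase penalty 1 - cos(angle(z)) on the target-class logit.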

"""
class Complex_CCELoss(nn.Module):
    """Cross-entropy on complex logit magnitudes plus a weighted phase penalty.

    loss = CE(|z|, target) + lambda_phase * mean(1 - cos(angle(z_t)))

    where ``z_t`` is the complex logit of the ground-truth class at each
    position. The phase term is 0 when that logit lies on the positive real
    axis and reaches 2 when it points in the opposite direction.

    Args:
        lambda_phase: weight of the phase term.
    """

    def __init__(self, lambda_phase=0.2):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.lambda_phase = lambda_phase

    def forward(self, z, target):
        """Compute the loss.

        Args:
            z: complex logits with the class axis on dim 1.
            target: integer class indices; one dim fewer than ``z`` (no class axis).
        """
        magnitude_loss = self.criterion(z.abs(), target)

        # Phase of the logit belonging to each position's target class.
        class_index = target.unsqueeze(1)
        target_phase = z.angle().gather(1, class_index).squeeze(1)
        phase_loss = torch.mean(1 - torch.cos(target_phase))

        return magnitude_loss + self.lambda_phase * phase_loss

# Train Loops

## Using CE + Phase Loss

In [None]:
# Hyperparameters and paths for the CE + phase-loss experiment.
num_classes = 32
batch_size = 8
num_epochs = 50
lr = 1e-3
seed = 0
fpath = 'Dataset/Complex_CamVid_iHSV' # change if required
model_dir = 'Models'
os.makedirs(model_dir, exist_ok=True)

# Seed PyTorch's RNG so weight init and shuffling repeat across runs
# (some GPU kernels may still be nondeterministic).
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNET(in_channels=3, out_channels=num_classes).to(device)
model.train()

traindataset = CamVid_Simple(path=f'{fpath}/train')
trainloader = DataLoader(traindataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# The scheduler is stepped after every batch, so T_max spans all optimizer steps.
scheduler = CosineAnnealingLR(optimizer, num_epochs * len(trainloader), eta_min=1e-5)

criterion = Complex_CCELoss()

for epoch in range(num_epochs):
    loss_sum = 0.0
    pbar = tqdm(trainloader)
    for batch_idx, (image, mask) in enumerate(pbar):
        image, mask = image.to(device), mask.to(device)
        output = model(image)
        loss = criterion(output, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_sum += loss.item()
        pbar.set_postfix({
            'Epoch': epoch + 1,
            'Loss': loss_sum / (batch_idx + 1),  # running mean over this epoch
            'lr': optimizer.param_groups[0]['lr'],
        })
    pbar.close()
    # Checkpoint after every epoch into the Models/ directory created above.
    torch.save(model.state_dict(), os.path.join(model_dir, 'complex_model_l1.pth'))  # change name accordingly


    

### Evaluate First Model

In [None]:
from torchmetrics.classification import JaccardIndex

# Evaluate the CE + phase-loss network (bound to `model` by the training cell above).
num_classes = 32
batch_size = 8
# Send inputs to wherever the model's parameters live (CPU or GPU).
device = next(model.parameters()).device

testdataset = CamVid_Simple(path=f'{fpath}/test')
testloader = DataLoader(testdataset, batch_size=batch_size, shuffle=False)
torch.cuda.empty_cache()
metric = JaccardIndex(task='multiclass', num_classes=num_classes)

model.eval()
with torch.no_grad():
    for img, target in tqdm(testloader):
        output = model(img.to(device))
        # Classify by logit magnitude |z|, matching the CE term of Complex_CCELoss.
        pred = torch.argmax(torch.abs(output), dim=1)
        metric.update(pred.cpu(), target)

    iou = metric.compute()
    print('Jaccard index Score: ', iou)

    metric.reset()

## Using CE only

In [None]:
# Hyperparameters and paths for the CE-only experiment.
num_classes = 32
batch_size = 8
num_epochs = 50
lr = 1e-3
seed = 0
# Fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fpath = 'Dataset/Complex_CamVid_iHSV' # change if required
model_dir = 'Models'
os.makedirs(model_dir, exist_ok=True)

# Seed PyTorch's RNG so weight init and shuffling repeat across runs
# (some GPU kernels may still be nondeterministic).
torch.manual_seed(seed)

model = UNET(in_channels=3, out_channels=num_classes).to(device)
model.train()

traindataset = CamVid_Simple(path=f'{fpath}/train')
trainloader = DataLoader(traindataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# The scheduler is stepped after every batch, so T_max spans all optimizer steps.
scheduler = CosineAnnealingLR(optimizer, num_epochs * len(trainloader), eta_min=1e-5)

criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    loss_sum = 0.0
    pbar = tqdm(trainloader)
    for batch_idx, (image, mask) in enumerate(pbar):
        image, mask = image.to(device), mask.to(device)
        output = model(image)
        # Collapse the complex output to real logits as Re(z) * Im(z);
        # the evaluation cell for this model applies the same reduction.
        logits = (output.real * output.imag).float()

        loss = criterion(logits, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_sum += loss.item()
        pbar.set_postfix({
            'Epoch': epoch + 1,
            'Loss': loss_sum / (batch_idx + 1),  # running mean over this epoch
            'lr': optimizer.param_groups[0]['lr'],
        })
    pbar.close()
    # Checkpoint after every epoch into the Models/ directory created above.
    torch.save(model.state_dict(), os.path.join(model_dir, 'complex_model.pth'))

### Evaluate 2nd Model

In [None]:
from torchmetrics.classification import JaccardIndex

# Evaluate the CE-only network (bound to `model` by the training cell above).
num_classes = 32
batch_size = 8
# Use the model's own device instead of hard-coding CUDA, so CPU-only runs work.
device = next(model.parameters()).device

testdataset = CamVid_Simple(path=f'{fpath}/test')
testloader = DataLoader(testdataset, batch_size=batch_size, shuffle=False)
torch.cuda.empty_cache()
metric = JaccardIndex(task='multiclass', num_classes=num_classes)

model.eval()
with torch.no_grad():
    for img, target in tqdm(testloader):
        output = model(img.to(device))
        # Same Re(z) * Im(z) reduction the CE-only training cell used for its logits.
        pred = torch.argmax(output.real * output.imag, dim=1)
        metric.update(pred.cpu(), target)

    iou = metric.compute()
    print('Jaccard index Score: ', iou)

    metric.reset()