In [None]:
# Automatically reload edited Python modules before executing each cell.
# NOTE(review): the only local-module imports in this notebook (`from model.SUNet_detail_gcn ...`)
# are commented out, so autoreload currently has nothing to reload — confirm it is still needed.
%load_ext autoreload
%autoreload 2

In [None]:
# Colab only: mount Google Drive at /content/drive (the project folder is expected under MyDrive).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Environment setup (Colab).
# NOTE(review): pinning torch==2.1.0 replaces the torch build already present in the runtime;
# a runtime restart is likely required before the new version is importable — confirm.
# NOTE(review): no cell visible in this notebook imports mmcv, pytorch_msssim, tensorboardX or
# thop (ContextBlock and the init helpers are re-implemented locally below) — confirm these
# installs are still needed, and pin versions for the unpinned ones.
# NOTE(review): `%pip install ...` targets the kernel's environment more reliably than `!pip`.
!pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
!pip install -U openmim
!mim install mmcv==1.4.0
!pip install pytorch_msssim
!pip install tensorboardX
!pip install thop

Looking in indexes: https://download.pytorch.org/whl/cu121
Looking in links: https://download.openmmlab.com/mmcv/dist/cu121/torch2.1.0/index.html


In [None]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
import torch.nn.init as init

# Pick the best available accelerator: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print("Using device = " + device)
on_cpu = device == 'cpu'
if on_cpu:
    print("WARNING: Using CPU will cause slower train times")

Using device = cuda


In [None]:
# Switch into the project folder on Drive.
# NOTE(review): the recorded output shows this path did not exist and the working directory
# stayed /content — create the folder or fix the path before relying on relative paths.
%cd /content/drive/MyDrive/Denoiser/

[Errno 2] No such file or directory: '/content/drive/MyDrive/Denoiser/'
/content


In [None]:
def kaiming_init(module, mode='fan_in', nonlinearity='relu'):
    """Kaiming-uniform initialise an ``nn.Conv2d`` and zero its bias.

    Modules that are not ``nn.Conv2d`` are left untouched.

    Args:
        module: module to initialise.
        mode: ``'fan_in'`` or ``'fan_out'``, forwarded to ``init.kaiming_uniform_``.
        nonlinearity: nonlinearity name forwarded to ``init.kaiming_uniform_``.
    """
    if not isinstance(module, nn.Conv2d):
        return
    init.kaiming_uniform_(module.weight, mode=mode, nonlinearity=nonlinearity)
    conv_bias = module.bias
    if conv_bias is not None:
        init.zeros_(conv_bias)

def constant_init(module, val=0):
    """Fill ``module.weight`` and ``module.bias`` with ``val``.

    Each attribute is filled only if it exists and is not ``None``. Modules without
    these attributes (e.g. activations) are a no-op.
    """
    for param_name in ('weight', 'bias'):
        tensor = getattr(module, param_name, None)
        if tensor is not None:
            init.constant_(tensor, val)

def last_zero_init(module):
    """Zero-initialise the final layer of a block.

    For a non-empty ``nn.Sequential`` only its last child is zeroed via ``constant_init``.
    Any other non-None module is zeroed itself; this includes an empty ``nn.Sequential``,
    for which ``constant_init`` is a no-op. ``None`` is ignored.
    """
    if module is None:
        return
    has_children = isinstance(module, nn.Sequential) and len(module) > 0
    target = module[-1] if has_children else module
    constant_init(target, val=0)


In [None]:
class ContextBlock(nn.Module):
    """Global-context (GCNet-style) block: pool one context vector per image and fuse it back.

    ``spatial_pool`` reduces an ``(N, C, H, W)`` input to an ``(N, C, 1, 1)`` context, by either:
    - attention pooling: a 1x1 conv gives one logit per pixel, softmaxed over H*W; or
    - global average pooling.

    The context then goes through a bottleneck (1x1 conv -> LayerNorm -> ReLU -> 1x1 conv). It is
    fused into the input by broadcast multiplication through a sigmoid (``'channel_mul'``) and/or
    broadcast addition (``'channel_add'``). The output has the same shape as the input.

    Args:
        inplanes: number of input channels C.
        ratio: bottleneck width as a fraction of ``inplanes`` (``planes = int(inplanes * ratio)``).
        pooling_type: ``'att'`` (attention pooling) or ``'avg'`` (global average pooling).
        fusion_types: non-empty list/tuple drawn from ``'channel_add'`` and ``'channel_mul'``.
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all(f in valid_fusion_types for f in fusion_types)
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        # NOTE(review): int() truncates, so inplanes * ratio < 1 gives a zero-width bottleneck.
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            # One attention logit per pixel; the softmax runs over the flattened H*W axis.
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Built in the same order as before (add, then mul) so parameter creation order is unchanged.
        self.channel_add_conv = (self._make_transform()
                                 if 'channel_add' in fusion_types else None)
        self.channel_mul_conv = (self._make_transform()
                                 if 'channel_mul' in fusion_types else None)
        self.reset_parameters()

    def _make_transform(self):
        """Bottleneck applied to the ``(N, C, 1, 1)`` context: C -> planes -> C channels."""
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1))

    def reset_parameters(self):
        """Kaiming-init the attention conv and zero the last conv of each fusion transform.

        With the last conv zeroed, the add branch initially adds 0. The mul branch initially
        scales by sigmoid(0) = 0.5.
        """
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        """Return the ``(N, C, 1, 1)`` context vector of ``x`` with shape ``(N, C, H, W)``.

        ``reshape`` (not ``view``) is used so non-contiguous inputs, e.g. after a
        permute/transpose, are accepted.
        """
        batch, channel, height, width = x.size()
        num_pixels = height * width
        if self.pooling_type == 'att':
            # [N, 1, C, H*W]
            input_x = x.reshape(batch, channel, num_pixels).unsqueeze(1)
            # [N, 1, H*W] logits -> softmax over pixels -> [N, 1, H*W, 1]
            context_mask = self.conv_mask(x).reshape(batch, 1, num_pixels)
            context_mask = self.softmax(context_mask).unsqueeze(-1)
            # [N, 1, C, 1]: attention-weighted sum over pixels, then -> [N, C, 1, 1]
            context = torch.matmul(input_x, context_mask)
            context = context.reshape(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        """Fuse the pooled context into ``x``; the result has the same shape as ``x``."""
        context = self.spatial_pool(x)  # [N, C, 1, 1]

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1] gate broadcast over H, W
            out = out * torch.sigmoid(self.channel_mul_conv(context))
        if self.channel_add_conv is not None:
            # [N, C, 1, 1] bias broadcast over H, W
            out = out + self.channel_add_conv(context)
        return out

In [None]:
# Smoke test: a ContextBlock with both fusion types must preserve the input shape.
batch_size = 4
channels = 64
height, width = 32, 32

context_block = ContextBlock(
    inplanes=channels,
    ratio=0.25,
    pooling_type='att',
    fusion_types=('channel_add', 'channel_mul'),
)
input_tensor = torch.rand(batch_size, channels, height, width)
output_tensor = context_block(input_tensor)
print("Input shape:", input_tensor.shape)
print("Output shape:", output_tensor.shape)
assert output_tensor.shape == input_tensor.shape, "ContextBlock must preserve the input shape"

INPLANES: 64
gc forward.
gc spatial pooling. xsize: torch.Size([4, 64, 32, 32])
self.pooling_type==att. inside if statement
Input shape: torch.Size([4, 64, 32, 32])
Output shape: torch.Size([4, 64, 32, 32])


In [None]:
class StridedConvolutionDownsampling(nn.Module):
    """Downsample with a strided convolution, then normalisation and ReLU.

    Channels go from ``in_channels`` to ``out_channels``. The spatial size follows the usual
    Conv2d rule ``(H + 2*padding - kernel_size) // stride + 1``; with the defaults
    (k=3, p=1, s=2) this halves even sizes.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        norm_layer: callable ``norm_layer(num_channels)`` returning the normalisation module.
        stride, kernel_size, padding: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, stride=2, kernel_size=3, padding=1):
        super().__init__()
        self.conv_down = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=kernel_size, stride=stride, padding=padding)
        self.norm = norm_layer(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.norm(self.conv_down(x)))

In [None]:
# Shape check for StridedConvolutionDownsampling: 30 -> 50 channels, spatial size reduced by the stride.
in_channels = 30
out_channels = 50
input_height, input_width = 32, 32  # Input dimensions (HxW)
stride = 2
kernel_size, padding = 3, 1  # the module's defaults, spelled out for the expected-size formula

model = StridedConvolutionDownsampling(in_channels, out_channels, stride=stride,
                                       kernel_size=kernel_size, padding=padding)

# Random input of shape (batch_size, channels, height, width)
batch_size = 8
input_tensor = torch.randn(batch_size, in_channels, input_height, input_width)
print("Input shape:", input_tensor.shape)
output = model(input_tensor)

# Standard Conv2d output-size formula; `H // stride` only matches it for even H.
expected_height = (input_height + 2 * padding - kernel_size) // stride + 1
expected_width = (input_width + 2 * padding - kernel_size) // stride + 1
expected_shape = (batch_size, out_channels, expected_height, expected_width)

assert output.shape == expected_shape, f"Expected {expected_shape}, but got {output.shape}"
print("Test passed! Output shape:", output.shape)


Input shape: torch.Size([8, 30, 32, 32])
Test passed! Output shape: torch.Size([8, 50, 16, 16])


In [None]:
# global context basic layer
class GlobalContextBasicLayer(nn.Module):
    """Encoder stage: ``depth`` stacked ContextBlocks at fixed resolution, then optional downsampling.

    Args:
        dim: channels of the input (and of every ContextBlock).
        ratio, pooling_type, fusion_types: forwarded to each ``ContextBlock``.
        depth: number of stacked ContextBlocks.
        input_resolution: accepted for API compatibility; not used.
        norm_layer: accepted for API compatibility; not used (the downsampler always gets
            ``nn.BatchNorm2d``).
        downsample: ``None`` disables downsampling.
            - A callable (e.g. the ``StridedConvolutionDownsampling`` class) is used as the
              factory and called with ``in_channels``, ``out_channels``, ``norm_layer``, ``stride``.
            - Any other non-None value uses ``StridedConvolutionDownsampling``.
        use_checkpoint: recompute each ContextBlock during backward (activation checkpointing).
        stride: spatial stride of the downsampler.
        out_channels: channels produced by the downsampler.
    """

    def __init__(self, dim, ratio, pooling_type, fusion_types,
                 depth, input_resolution, norm_layer=nn.LayerNorm,
                 downsample= None, use_checkpoint=False, stride = 2, out_channels = 50):
        super().__init__()
        self.dim = dim
        self.ratio = ratio
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.stride = stride
        self.out_channels = out_channels

        # A series of shape-preserving ContextBlocks.
        self.blocks = nn.ModuleList([
            ContextBlock(inplanes=dim, ratio=ratio, pooling_type=pooling_type,
                         fusion_types=fusion_types) for _ in range(depth)
        ])

        if downsample is None:
            self.downsample = None
        else:
            # Honour the downsampler class/factory the caller passed in.
            factory = downsample if callable(downsample) else StridedConvolutionDownsampling
            self.downsample = factory(in_channels=dim, out_channels=out_channels,
                                      norm_layer=nn.BatchNorm2d, stride=stride)

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # Non-reentrant checkpointing still yields parameter gradients when `x` itself
                # does not require grad (e.g. the raw input image at the first stage).
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

In [None]:
# Encoder-stage smoke test: 3 ContextBlocks on 30 channels, then stride-2 downsampling to 50 channels.
layer = GlobalContextBasicLayer(
    dim=30,
    ratio=0.25,
    pooling_type='avg',
    fusion_types=('channel_add',),
    depth=3,
    input_resolution=(32, 32),
    norm_layer=nn.LayerNorm,
    downsample=StridedConvolutionDownsampling,
    use_checkpoint=False,
)

# Input tensor: batch size = 8, channels = 30, height = 32, width = 32
x = torch.randn(8, 30, 32, 32)
output = layer(x)

print("Output shape:", output.shape)
assert output.shape == (8, 50, 16, 16), f"unexpected output shape {output.shape}"

we making a gc basic layer
The input_resolution is: (32, 32)
After context blocks, the shape is: ()
we going forward in a gc basic layer
gc forward.
gc spatial pooling. xsize: torch.Size([8, 30, 32, 32])
gc forward.
gc spatial pooling. xsize: torch.Size([8, 30, 32, 32])
gc forward.
gc spatial pooling. xsize: torch.Size([8, 30, 32, 32])
Output shape: torch.Size([8, 50, 16, 16])


In [None]:
# # Dual up-sample
# class UpSample(nn.Module):
#     def __init__(self, input_resolution, in_channels, scale_factor):
#         super(UpSample, self).__init__()
#         print('dual upsample')
#         self.input_resolution = input_resolution
#         self.factor = scale_factor


#         if self.factor == 2:
#             self.conv = nn.Conv2d(in_channels, in_channels//2, 1, 1, 0, bias=False)
#             self.up_p = nn.Sequential(nn.Conv2d(in_channels, 2*in_channels, 1, 1, 0, bias=False),
#                                       nn.PReLU(),
#                                       nn.PixelShuffle(scale_factor),
#                                       nn.Conv2d(in_channels//2, in_channels//2, 1, stride=1, padding=0, bias=False))

#             self.up_b = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, 1, 0),
#                                       nn.PReLU(),
#                                       nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
#                                       nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=False))
#         elif self.factor == 4:
#             self.conv = nn.Conv2d(2*in_channels, in_channels, 1, 1, 0, bias=False)
#             self.up_p = nn.Sequential(nn.Conv2d(in_channels, 16 * in_channels, 1, 1, 0, bias=False),
#                                       nn.PReLU(),
#                                       nn.PixelShuffle(scale_factor),
#                                       nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=False))

#             self.up_b = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, 1, 0),
#                                       nn.PReLU(),
#                                       nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
#                                       nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=False))

#     def forward(self, x):
#         """
#         x: B, L = H*W, C
#         """
#         print('dual upsample forward')
#         if type(self.input_resolution) == int:
#             H = self.input_resolution
#             W = self.input_resolution

#         elif type(self.input_resolution) == tuple:
#             H, W = self.input_resolution

#         B, L, C = x.shape
#         x = x.view(B, H, W, C)  # B, H, W, C
#         x = x.permute(0, 3, 1, 2)  # B, C, H, W
#         x_p = self.up_p(x)  # pixel shuffle
#         x_b = self.up_b(x)  # bilinear
#         out = self.conv(torch.cat([x_p, x_b], dim=1))
#         out = out.permute(0, 2, 3, 1)  # B, H, W, C
#         if self.factor == 2:
#             out = out.view(B, -1, C // 2)

#         return out

In [None]:
class UpSample(nn.Module):
    """Bilinear upsampling of a ``(B, C, H, W)`` tensor by ``scale_factor``.

    The channel count is unchanged. ``input_resolution`` and ``in_channels`` are accepted to
    keep the constructor signature of the earlier dual up-sample variant, but are not used.
    """

    def __init__(self, input_resolution, in_channels, scale_factor):
        super(UpSample, self).__init__()
        self.scale_factor = scale_factor
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)

    def forward(self, x):
        # Only 4-D (B, C, H, W) tensors are accepted.
        if x.dim() == 4:
            return self.upsample(x)
        raise ValueError(f"Expected input to have 4 dimensions (B, C, H, W), but got {x.shape}")


In [None]:
# global context basic up layer
class GlobalContextBasicUpLayer(nn.Module):
    """Decoder stage: ``depth`` stacked ContextBlocks at fixed resolution, then optional upsampling.

    Args:
        dim: channels of the input (and of every ContextBlock).
        ratio, pooling_type, fusion_types: forwarded to each ``ContextBlock``.
        depth: number of stacked ContextBlocks.
        input_resolution: forwarded to ``UpSample`` (which does not use it).
        norm_layer: accepted for API compatibility; not used.
        upsample: ``None`` disables upsampling. Any other value enables one bilinear ``UpSample``
            by ``scale_factor``, applied once after all ContextBlocks.
        use_checkpoint: recompute each ContextBlock during backward (activation checkpointing).
        scale_factor: upsampling factor (default 2).
    """

    def __init__(self, dim, ratio, pooling_type, fusion_types,
                 depth, input_resolution, norm_layer=nn.LayerNorm,
                 upsample=None, use_checkpoint=False, scale_factor=2):
        super().__init__()
        self.dim = dim
        self.ratio = ratio
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.scale_factor = scale_factor

        # A series of shape-preserving ContextBlocks.
        self.blocks = nn.ModuleList([
            ContextBlock(inplanes=dim, ratio=ratio, pooling_type=pooling_type,
                         fusion_types=fusion_types) for _ in range(depth)
        ])

        if upsample is not None:
            self.upsample = UpSample(input_resolution, in_channels=dim, scale_factor=scale_factor)
        else:
            self.upsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # Non-reentrant checkpointing still yields parameter gradients when `x` itself
                # does not require grad.
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        # Upsample once for the whole stage, not once per block.
        if self.upsample is not None:
            x = self.upsample(x)
        return x

In [None]:
# Decoder-stage smoke test without upsampling: 3 ContextBlocks must keep the shape unchanged.
layer = GlobalContextBasicUpLayer(
    dim=50,
    ratio=0.25,
    pooling_type='avg',
    fusion_types=('channel_add',),
    depth=3,
    input_resolution=(16, 16),
    upsample=None,
    use_checkpoint=False,
)

# Input tensor: batch size = 8, channels = 50, height = 16, width = 16
x = torch.randn(8, 50, 16, 16)
output = layer(x)

print("Output shape:", output.shape)
assert output.shape == x.shape, f"unexpected output shape {output.shape}"


gc forward.
gc spatial pooling. xsize: torch.Size([8, 50, 16, 16])
gc forward.
gc spatial pooling. xsize: torch.Size([8, 50, 16, 16])
gc forward.
gc spatial pooling. xsize: torch.Size([8, 50, 16, 16])
Output shape: torch.Size([8, 50, 16, 16])


In [None]:
# Construction check for UpSample with scale_factor=4 (the layer type GlobalNet uses when
# final_upsample == "Dual up-sample").
# NOTE(review): `up` is never read afterwards — test_upsample() builds its own UpSample.
up = UpSample(input_resolution=(16, 16),
                    in_channels= 50, scale_factor=4)
def test_upsample(input_resolution=(16, 16), in_channels=50, scale_factor=4, batch_size=2):
    """Check that UpSample scales H and W by ``scale_factor`` and leaves B and C unchanged.

    Assumes an integer ``scale_factor``, so the expected size is exactly ``H * scale_factor``.

    Raises:
        AssertionError: if the output shape differs from the expected shape.
    """
    # Random input of shape [batch_size, in_channels, height, width]
    input_tensor = torch.rand(batch_size, in_channels, *input_resolution)

    upsample = UpSample(input_resolution=input_resolution, in_channels=in_channels,
                        scale_factor=scale_factor)
    output_tensor = upsample(input_tensor)

    expected_shape = (batch_size, in_channels,
                      input_resolution[0] * scale_factor,
                      input_resolution[1] * scale_factor)
    assert output_tensor.shape == expected_shape, f"Expected shape {expected_shape}, but got {output_tensor.shape}"
    print(f"Test passed! Output shape: {output_tensor.shape}")


# Run the test cases: the original square case plus a non-square one.
test_upsample()
test_upsample(input_resolution=(7, 9), in_channels=3, scale_factor=2)


Test passed! Output shape: torch.Size([2, 50, 64, 64])


In [None]:
# Shape check for a 3x3 output projection like GlobalNet's head: 50 -> 30 channels, spatial size kept.
output_layer = nn.Conv2d(
    in_channels=50, out_channels=30, kernel_size=3, stride=1, padding=1, bias=False
)

# Dummy input tensor (batch_size, in_channels, height, width) = (8, 50, 16, 16)
input_tensor = torch.randn(8, 50, 16, 16)

output_tensor = output_layer(input_tensor)

print(output_tensor.shape)
assert output_tensor.shape == (8, 30, 16, 16), f"unexpected output shape {output_tensor.shape}"


dual upsample
torch.Size([8, 30, 16, 16])


In [None]:
class GlobalNet(nn.Module):
    """Small encoder/decoder built from global-context blocks.

    For an input of shape ``(B, in_chans, H, W)``:

    1. Encoder: one ``GlobalContextBasicLayer``. It runs 3 ContextBlocks on ``in_chans``
       channels, then a strided-conv downsample to ``downsample_out_chans`` channels at
       ``H / stride``.
    2. Decoder: a bilinear ``UpSample`` by ``stride`` back to ``H``, then a
       ``GlobalContextBasicUpLayer`` (3 ContextBlocks, no upsampling).
    3. Head:
       - if ``final_upsample == "Dual up-sample"``, an extra bilinear upsample by
         ``final_scale_factor``;
       - then always a 3x3 conv projecting to ``out_chans`` channels.

    NOTE(review): with the default ``final_scale_factor=2`` the output is twice the input
    resolution (e.g. 224 -> 448). If the output should match the input size (e.g. for
    denoising), use ``final_scale_factor=1`` — confirm the intended behaviour.

    Args:
        img_size: input height/width (used only for the informational ``input_resolution`` args).
        in_chans: input channels.
        out_chans: output channels.
        norm_layer: forwarded to the encoder stage and used for the unused ``norm``/``norm_up``.
        use_checkpoint: enable activation checkpointing inside the ContextBlock stages.
        final_upsample: ``"Dual up-sample"`` enables the extra head upsample; other values skip it.
        ratio, pooling_type: forwarded to every ContextBlock.
        stride: encoder downsampling factor (undone by the first decoder layer).
        downsample_out_chans: channel width of the bottleneck/decoder.
        final_scale_factor: scale of the head upsample (default 2).
        ape, gc_layer, downsample, **kwargs: stored or accepted but not used by this architecture.
    """

    def __init__(self, img_size=224, in_chans=3, out_chans=3,
                 norm_layer=nn.LayerNorm, ape=False,
                 use_checkpoint=False, final_upsample="Dual up-sample",
                 ratio=0.8, pooling_type='avg', gc_layer=3, downsample = False, stride = 2, downsample_out_chans = 6,
                 final_scale_factor=2, **kwargs):
        super(GlobalNet, self).__init__()

        self.img_size = img_size
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.ape = ape
        self.final_upsample = final_upsample
        self.final_scale_factor = final_scale_factor
        self.use_checkpoint = use_checkpoint
        self.ratio = ratio
        self.pooling_type = pooling_type
        self.gc_layers = gc_layer
        self.downsample = downsample
        self.downsample_out_chans = downsample_out_chans
        self.stride = stride

        # Encoder: ContextBlocks at full resolution, then strided-conv downsampling.
        self.layers = nn.ModuleList([
            GlobalContextBasicLayer(
                dim=in_chans,
                ratio=ratio,
                pooling_type=pooling_type,
                fusion_types=('channel_add',),
                depth=3,
                input_resolution=(img_size, img_size),
                norm_layer=norm_layer,
                downsample=StridedConvolutionDownsampling,
                use_checkpoint=use_checkpoint,
                stride=stride,
                out_channels=downsample_out_chans,
            )
        ])

        # Decoder: undo the encoder downsampling, then ContextBlocks at full resolution.
        self.layers_up = nn.ModuleList([
            UpSample(
                input_resolution=(img_size // stride, img_size // stride),
                in_channels=downsample_out_chans,
                scale_factor=stride,
            ),
            GlobalContextBasicUpLayer(
                dim=downsample_out_chans,
                ratio=ratio,
                pooling_type=pooling_type,
                fusion_types=('channel_add',),
                depth=3,
                input_resolution=(img_size, img_size),
                norm_layer=nn.LayerNorm,
                upsample=None,
                use_checkpoint=use_checkpoint,
            ),
        ])

        # Not used in forward(); kept because removing them would change the state_dict keys.
        self.norm = norm_layer(downsample_out_chans)
        self.norm_up = norm_layer(downsample_out_chans)

        # Head: optional extra upsample, then always project to `out_chans`.
        if final_upsample == "Dual up-sample":
            self.up = UpSample(
                input_resolution=(img_size, img_size),
                in_channels=downsample_out_chans,
                scale_factor=final_scale_factor,
            )
        else:
            self.up = None
        self.output = nn.Conv2d(
            in_channels=downsample_out_chans,
            out_channels=out_chans,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Map ``(B, in_chans, H, W)`` to ``(B, out_chans, H', W')``.

        ``H' = H * final_scale_factor`` when the head upsample is enabled, else ``H``.
        """
        for layer in self.layers:
            x = layer(x)
        for layer_up in self.layers_up:
            x = layer_up(x)
        if self.up is not None:
            x = self.up(x)
        return self.output(x)


In [None]:
# End-to-end smoke test of GlobalNet on a single 224x224 RGB image.
model = GlobalNet(img_size=224, in_chans=3, out_chans=3, stride=2)

# Dummy input tensor (batch_size, channels, height, width)
input_tensor = torch.randn(1, 3, 224, 224)

# Shape check only, so skip autograd bookkeeping.
with torch.no_grad():
    output_tensor = model(input_tensor)

print("Output shape:", output_tensor.shape)
assert output_tensor.shape[:2] == (1, 3), f"unexpected batch/channel dims {tuple(output_tensor.shape)}"


we making a gc basic layer
The input_resolution is: (224, 224)
After context blocks, the shape is: ()
Finish Encoder Part!
Finish Decoder Part!
we going forward in a gc basic layer
gc forward.
gc spatial pooling. xsize: torch.Size([1, 3, 224, 224])
gc forward.
gc spatial pooling. xsize: torch.Size([1, 3, 224, 224])
gc forward.
gc spatial pooling. xsize: torch.Size([1, 3, 224, 224])
Finish Forward Encoder Part!
Finish One round!
gc forward.
gc spatial pooling. xsize: torch.Size([1, 6, 224, 224])
gc forward.
gc spatial pooling. xsize: torch.Size([1, 6, 224, 224])
gc forward.
gc spatial pooling. xsize: torch.Size([1, 6, 224, 224])
Finish One round!
Finish Forward Decoder Part!
Finish Forward Upsample Part!
Output shape: torch.Size([1, 3, 448, 448])
