In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.models import VGG


In [3]:
# Example input dimensions
depth = 8
channels = 16
height = width = 64
# Random tensor standing in for a low-light video
input_video = torch.randn((depth, channels, height, width))

In [None]:
class Encoder(nn.Module):
    """Four-stage strided 3D convolutional encoder.

    Every stage is ``Conv3d(kernel 3, stride 2, padding 1) -> LeakyReLU(0.2)
    -> InstanceNorm3d``. Channel widths run 3 -> nf -> 2*nf -> 4*nf -> 8*nf.

    Args:
        nf: channel multiplier (width of the first stage).
    """

    def __init__(self, nf=8):
        super().__init__()
        # Channel width entering the first stage and leaving each stage.
        widths = [3, nf, nf * 2, nf * 4, nf * 8]

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv3d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            stages.append(nn.InstanceNorm3d(c_out))

        # Same layer order as writing the twelve layers out by hand,
        # so the Sequential indices (0..11) are preserved.
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through all four stages and return the result."""
        return self.encoder(x)


In [6]:

# Refinement network for the illumination module (wraps torchvision's VGG-16)
class VGG(nn.Module):
    """Thin ``nn.Module`` wrapper around a torchvision VGG-16.

    This class has the same name as ``torchvision.models.VGG``, so inside this
    module the bare name ``VGG`` refers to *this* class. Calling ``VGG()`` from
    ``__init__`` would therefore re-enter this constructor until Python raises
    ``RecursionError``. The backbone is built through the ``vgg16`` factory
    instead, which also supplies the ``features`` stack that torchvision's
    ``VGG`` constructor requires.

    NOTE(review): VGG-16 is a 2D image classifier, not a U-Net; confirm that
    this is the architecture intended for illumination refinement.
    """

    def __init__(self):
        super().__init__()
        # Local import under its own name so this class cannot shadow it.
        from torchvision.models import vgg16

        # No weights are requested, so the network is randomly initialised.
        self.model = vgg16()

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

# Custom illumination enhancement module with iterative refinement
class IlluminationEnhancementModule(nn.Module):
    """Refines an illumination estimate over ``T`` interpolation steps.

    Args:
        T: number of refinement iterations.
        encoder: module used to compute the initial latent ``z0`` from ``Xt``.
            The supplied module is used as-is; pass ``None`` to have a fresh
            ``Encoder()`` built instead.
    """

    def __init__(self, T, encoder):
        super().__init__()
        self.T = T  # Number of iterations
        self.VGG = VGG()  # Refinement network applied at every step
        # Honour the caller's encoder rather than discarding it; only fall
        # back to a newly constructed Encoder when none is given.
        self.encoder = encoder if encoder is not None else Encoder()

    def forward(self, Xt, Yt):
        """Return the refinement network's output at the final step.

        Args:
            Xt: low-light input.
            Yt: normal-light reference; must broadcast against ``Xt``.

        Returns:
            ``self.VGG`` applied to the last mixture (``alpha = 1``), or the
            encoder output ``z0`` when ``T < 1`` and the loop never runs.
        """
        # Latent representation z0 (initial illumination representation)
        zt = self.encoder(Xt)

        # Step k mixes the two inputs with weight alpha_k = k / T on Yt.
        for k in range(1, self.T + 1):
            alpha_k = k / self.T
            zk = alpha_k * Yt + (1 - alpha_k) * Xt
            # NOTE(review): every step overwrites zt without reading the
            # previous value, so only the last step reaches the caller and z0
            # is used only when T < 1. Confirm whether the refinement was
            # meant to be conditioned on the previous zt.
            zt = self.VGG(zk)

        # Final refined illumination component zT
        return zt

In [7]:
# refinement layer
import torch
import torch.nn as nn

# Define the residual block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a skip connection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.

    When ``in_channels != out_channels`` the skip path goes through a 1x1
    convolution so the two tensors can be added. When they are equal the skip
    path is a parameter-free identity, so the state_dict gains no keys.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Built last so the layers above are initialised in the same order
        # whether or not a projection is needed.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        identity = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + identity  # Skip connection
        out = self.relu(out)
        return out

# Concatenate the frames of a clip along channels and project them
class Projector(nn.Module):
    """Stacks all frames on the channel axis, then applies a 1x1 conv and a residual block.

    Args:
        num_frames: number of frames ``k`` in each input clip.
        input_channels: channels ``C`` per frame.
        output_channels: channels of the returned embedding.
    """

    def __init__(self, num_frames, input_channels=3, output_channels=64):
        super().__init__()
        self.num_frames = num_frames
        self.concat_conv = nn.Conv2d(input_channels * num_frames, output_channels, kernel_size=1)
        self.res_block = ResidualBlock(output_channels, output_channels)

    def forward(self, X):
        """Project a channels-last clip to a single feature map.

        Args:
            X: tensor of shape (batch_size, num_frames, height, width, channels).

        Returns:
            Tensor of shape (batch_size, output_channels, height, width).
        """
        B, k, H, W, C = X.shape
        # The channel axis must be moved next to the frame axis before the two
        # are merged. A bare ``view(B, k * C, H, W)`` only reinterprets the
        # memory and would mix pixels from different spatial positions into
        # the "channel" dimension.
        X_concat = X.permute(0, 1, 4, 2, 3).reshape(B, k * C, H, W)

        # Project through 1x1 convolution to get embedding F0_t
        F0_t = self.concat_conv(X_concat)  # Shape: (batch_size, output_channels, H, W)

        # Pass through residual block
        F0_t = self.res_block(F0_t)  # Shape remains the same

        return F0_t

# Example usage: push one random clip through the projector
batch_size, num_frames = 4, 5
height = width = 64
input_channels, output_channels = 3, 64

# Random input laid out as (batch_size, num_frames, height, width, input_channels)
Xt = torch.randn((batch_size, num_frames, height, width, input_channels))

# Build the projector, then run the forward pass
projector = Projector(num_frames, input_channels, output_channels)
F0_t = projector(Xt)

# Expected output: (batch_size, output_channels, height, width)
print(F0_t.shape)


torch.Size([4, 64, 64, 64])


In [9]:
class FeatureExtractionBlock(nn.Module):
    """Two parallel branches over the same input, summed at the end.

    Branch 1: channel norm -> 3x3 depthwise conv -> 1x1 conv.
    Branch 2: channel norm -> 3x3 standard conv.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both branches.

    The norm layers keep the ``[in_channels, 1, 1]`` parameter shape, but are
    applied across the channel axis at every spatial position.
    ``nn.LayerNorm([C, 1, 1])`` called directly on a ``(B, C, H, W)`` tensor
    raises unless ``H == W == 1``. For 1x1 inputs the result is identical to
    calling the LayerNorm modules directly.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # First path: LayerNorm -> Depthwise Conv -> 1x1 conv
        self.layer_norm1 = nn.LayerNorm([in_channels, 1, 1])
        self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)
        # NOTE(review): named "channel attention" but it is a plain 1x1
        # convolution; confirm whether a gating branch was intended.
        self.channel_attention = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # Second path: LayerNorm -> Standard Convolution
        self.layer_norm2 = nn.LayerNorm([in_channels, 1, 1])
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    @staticmethod
    def _channel_norm(norm, x):
        """Apply ``norm``'s affine layer norm over the channel axis of an NCHW tensor."""
        channels_last = x.permute(0, 2, 3, 1)
        normed = nn.functional.layer_norm(
            channels_last,
            (channels_last.shape[-1],),
            norm.weight.reshape(-1),
            norm.bias.reshape(-1),
            norm.eps,
        )
        return normed.permute(0, 3, 1, 2)

    def forward(self, x):
        """Return the sum of both branches, shape ``(B, out_channels, H, W)``."""
        # First path
        x1 = self._channel_norm(self.layer_norm1, x)
        x1 = self.depthwise_conv(x1)
        x1 = self.channel_attention(x1)

        # Second path
        x2 = self._channel_norm(self.layer_norm2, x)
        x2 = self.conv(x2)

        # Combine both paths
        return x1 + x2

class DownsampleBlock(nn.Module):
    """A single 3x3 convolution with stride 2 and padding 1.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Positional order is kernel_size, stride, padding.
        self.conv = nn.Conv2d(in_channels, out_channels, 3, 2, 1)

    def forward(self, x):
        """Apply the strided convolution to ``x``."""
        downsampled = self.conv(x)
        return downsampled


# Reflectance Estimation Module
class ReflectanceEstimationModule(nn.Module):
    """Stack of (feature extraction -> downsample) stages.

    Args:
        in_channels: channels of the input feature map.
        num_stages: number of stages to build.

    Stage ``i`` receives ``in_channels * 2**i`` channels. Its feature block
    keeps that width and its downsample block doubles it, so the output has
    ``in_channels * 2**num_stages`` channels. The two blocks of a stage must
    chain: if the feature block already widened to the doubled width, a
    downsample block built for the narrower input would fail on a channel
    mismatch.
    """

    def __init__(self, in_channels, num_stages=4):
        super().__init__()

        self.num_stages = num_stages
        self.feature_blocks = nn.ModuleList()
        self.downsample_blocks = nn.ModuleList()

        for stage in range(num_stages):
            in_c = in_channels * (2 ** stage)
            out_c = in_c * 2  # Channels double once per stage

            # Features at constant width, then the downsample doubles the channels.
            self.feature_blocks.append(FeatureExtractionBlock(in_c, in_c))
            self.downsample_blocks.append(DownsampleBlock(in_c, out_c))

    def forward(self, x):
        """Run ``x`` through every stage and return the last stage's output (the reflectance map Rt)."""
        for extract, downsample in zip(self.feature_blocks, self.downsample_blocks):
            x = extract(x)
            x = downsample(x)
        return x

class ReflectanceIlluminationAlignment(nn.Module):
    """Fuses a reflectance map and an illumination map with one 3x3 convolution.

    Args:
        reflectance_channels: channels of the reflectance input ``Rt``.
        illumination_channels: channels of the illumination input ``It``.
        out_channels: channels of the fused output.
    """

    def __init__(self, reflectance_channels, illumination_channels, out_channels):
        super().__init__()

        # Alignment convolution layer
        self.alignment_conv = nn.Conv2d(
            reflectance_channels + illumination_channels, out_channels, kernel_size=3, padding=1
        )

    def forward(self, Rt, It):
        """Concatenate ``Rt`` and ``It`` on the channel axis and fuse them.

        Both inputs must share batch and spatial dimensions, as ``torch.cat``
        requires.

        Returns:
            Tensor with ``out_channels`` channels and the inputs' spatial size.
        """
        # Concatenate reflectance and illumination maps along the channel dimension
        aligned_input = torch.cat([Rt, It], dim=1)

        # Return the fused result; without a return the caller would get None.
        return self.alignment_conv(aligned_input)


In [None]:
class Model(nn.Module):
      """Top-level module intended to chain illumination enhancement, frame
      projection, reflectance estimation and reflectance/illumination alignment.

      NOTE(review): as written this class cannot be constructed or run; see
      the notes in ``__init__`` and ``forward``. The missing values (number of
      iterations, number of frames, channel widths, and where the second
      illumination input comes from) are not determined anywhere in this file
      and need a design decision before the wiring can be completed.
      """
      def __init__(self, in_channels, out_channels):
        # NOTE(review): in_channels and out_channels are accepted but never used.
        super(Model, self).__init__()
        # NOTE(review): IlluminationEnhancementModule.__init__ takes (T, encoder);
        # calling it with no arguments raises TypeError.
        self.IllEnhance = IlluminationEnhancementModule()
        # NOTE(review): Projector.__init__ requires num_frames. The attribute
        # name also shadows the ResidualBlock class name, which is confusing.
        self.ResidualBlock = Projector()
        # NOTE(review): ReflectanceEstimationModule.__init__ requires in_channels.
        self.REM = ReflectanceEstimationModule()
        # NOTE(review): ReflectanceIlluminationAlignment.__init__ requires
        # reflectance_channels, illumination_channels and out_channels.
        self.Alignment = ReflectanceIlluminationAlignment()

      def forward(self, x):
        # NOTE(review): IlluminationEnhancementModule.forward takes two inputs
        # (Xt, Yt); only one is passed here.
        Illcomp = self.IllEnhance(x)
        # NOTE(review): this calls the module-level ResidualBlock class with x as
        # its constructor argument, not the Projector stored on
        # self.ResidualBlock; presumably self.ResidualBlock(x) was intended.
        x1 = ResidualBlock(x)
        x1 = self.REM(x1)
        
        
        # Fuse the reflectance features (x1) with the illumination component.
        out = self.Alignment(x1, Illcomp)

        return out

