In [4]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
#from model import PRIDNet  # Assuming PRIDNet is defined in model.py
from PIL import Image
import os

In [6]:
class ChannelAttentionModule(nn.Module):
  def __init__(self):
    super(ChannelAttentionModule, self).__init__()
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.max_pool = nn.AdaptiveMaxPool2d(1)
    self.fc = nn.Sequential(
        nn.Conv2d(2, 1, 1, bias=False),
        nn.Sigmoid()
    )

  def forward(self, x):
    avg_out = self.avg_pool(x)
    max_out = self.max_pool(x)
    out = torch.cat([avg_out, max_out], dim=1)
    out = self.fc(out)
    return x * out

class KernelSelectingModule(nn.Module):
  def __init__(self, in_channels, channels):
    super(KernelSelectingModule, self).__init__()
    self.conv = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
    self.softmax = nn.Softmax(dim=1)

  def forward(self, x):
    out = self.conv(x)
    out = self.softmax(out)
    return out

class PRIDNet(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(PRIDNet, self).__init__()

    # U-Net
    self.unet = nn.Sequential(
      # Encoder
      nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(32, 32, kernel_size=3, padding=1),
      nn.MaxPool2d(2, stride=2),

      nn.Conv2d(32, 64, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(64, 64, kernel_size=3, padding=1),
      nn.MaxPool2d(2, stride=2),

      nn.Conv2d(64, 128, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 128, kernel_size=3, padding=1),
      nn.MaxPool2d(2, stride=2),

      nn.Conv2d(128, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(256, 256, kernel_size=3, padding=1),

      # Decoder
      nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
      nn.ReLU(inplace=True),
      nn.Conv2d(128, 128, kernel_size=3, padding=1),
      nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
      nn.ReLU(inplace=True),
      nn.Conv2d(64, 64, kernel_size=3, padding=1),
      nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
      nn.ReLU(inplace=True),
      nn.Conv2d(32, 32, kernel_size=3, padding=1),
      nn.ConvTranspose2d(32, out_channels, kernel_size=2, stride=2),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
    )

    # Channel Attention Module
    self.cam = ChannelAttentionModule()

    # Kernel Selecting Module
    self.ksm = KernelSelectingModule(32, 8)

  def forward(self, x):
    out = self.unet(x)

    # Kernel Selection
    # 1. Extract bottleneck features from the encoder (modify if needed)
    bottleneck = out  # Assuming bottleneck features are at the end of the encoder

    # 2. Apply Kernel Selecting Module
    kernels = self.ksm(bottleneck)  # Get weights for each kernel

    # 3. Expand kernels (assuming kernels are defined elsewhere)
    expanded_kernels = []  # Placeholder for expanded kernels
    for i in range(len(kernels[0])):  # Iterate through kernel weights
      # Assuming you have pre-defined kernels (kernel1, kernel2, ...)
      # Expand the weights from a 1x1 kernel to match the actual kernel size
      expanded_kernels.append(nn.functional.conv2d(kernel[i].unsqueeze(1), weight=torch.ones(1, 1, *kernel_size).cuda()))

    # 4. Apply weighted kernels (replace with your desired operation)
    weighted_features = torch.zeros_like(out)  # Placeholder for weighted features
    for i in range(len(expanded_kernels)):
      weighted_features += kernels[:, i, 0, 0].unsqueeze(1) * nn.functional.conv2d(out, expanded_kernels[i])

    # 5. Combine with original features (replace with your desired operation)
    out = self.cam(weighted_features + out)  # Apply CAM and combine

    # Rest of the decoder (same as before)
    # ... your decoder architecture here ...

    return out

In [20]:
def calculate_psnr(denoised, target):
  """
  Calculates Peak Signal-to-Noise Ratio (PSNR) between two images.
  Args:
    denoised: Denoised image tensor.
    target: Clean image tensor.
  Returns:
    PSNR value in dB.
  """
  mse = nn.functional.mse_loss(denoised, target)
  max_pixel = 1.0  # Assuming pixel values are normalized between 0 and 1
  psnr = 10 * torch.log10(max_pixel**2 / mse)
  return psnr.item()

def load_data(noisy_dir, good_dir, size=128):
  """
  Loads and preprocesses noisy and clean image data from directories.
  Args:
    noisy_dir: Path to the directory containing noisy images.
    good_dir: Path to the directory containing clean (good) images.
    size: Target size for image resizing (default: 128).
  Returns:
    A list of tuples (noisy_image, clean_image).
  """
  data = []
  transform = transforms.Compose([
      transforms.Resize(size),  # Resize to desired size
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [0, 1]
  ])
  for filename in os.listdir(noisy_dir):
    noisy_path = os.path.join(noisy_dir, filename)
    clean_path = os.path.join(good_dir, filename)
    if os.path.isfile(noisy_path) and os.path.isfile(clean_path):
      noisy_image = transform(Image.open(noisy_path).resize((size, size)).convert('RGB'))
      clean_image = transform(Image.open(clean_path).resize((size, size)).convert('RGB'))
      data.append((noisy_image, clean_image))
    if (len(data) >= 50):
      break
  return data


In [25]:
train_data = load_data("data\\train_noisy", "data\\train", 256)

train_loader = DataLoader(train_data, batch_size=4, shuffle=True)

In [26]:
for n, c in train_loader:
    print(n.shape, c.shape)
    break

torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])


In [27]:
def train_model(epochs, learning_rate, device="cuda" if torch.cuda.is_available() else "cpu"):
  """
  Trains the PRIDNet model for denoising images.
  Args:
    epochs: Number of training epochs.
    learning_rate: Learning rate for the optimizer.
    noisy_dir: Path to the directory containing noisy images.
    good_dir: Path to the directory containing clean (good) images.
    device: Device to use for training (CPU or GPU).
  """

  # Model and optimizer
  model = PRIDNet(in_channels=3, out_channels=3)  # Assuming 3 channels for RGB
  model.to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  criterion = nn.MSELoss()

  # Training loop
  for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for noisy_image, clean_image in train_loader:
      noisy_image, clean_image = noisy_image.to(device), clean_image.to(device)

      # Forward pass
      denoised_image = model(noisy_image)
      loss = criterion(denoised_image, clean_image)

      # Backward pass and update weights
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      train_loss += loss.item()

    model.eval()
    with torch.no_grad():
      # Optional: Calculate validation PSNR on a separate validation set
      # ... your validation code here ...
      pass

    print(f"Epoch: {epoch+1}/{epochs} | Train Loss: {train_loss/len(train_loader):.4f}")

# Example usage
train_model(epochs=100, learning_rate=0.001)


RuntimeError: Given groups=1, weight of size [8, 32, 1, 1], expected input[4, 3, 512, 512] to have 32 channels, but got 3 channels instead