A DCLGAN (Dual Contrastive Learning Generative Adversarial Network) is an unpaired image-to-image translation model built from generator and discriminator networks, with a PatchNCE contrastive loss at its core. The loss splits images into patches and treats them as positive and negative pairs: a patch of the generated image and the patch at the same location in the input image form a positive pair and are pulled together in the embedding space, while patches from other locations act as negatives and are pushed apart.

# Importing Libraries

In [None]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models, utils
from torchvision.models import inception_v3
from torchvision.datasets import ImageFolder, DatasetFolder
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch.autograd import Function
import torchvision.models as models
import torchvision.transforms as transforms
import shutil
import cv2
import random
from tqdm.notebook import tqdm
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import warnings

# Generator

In [None]:
class ResnetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 conv stages plus an identity skip.

    Each stage is ReflectionPad2d(1) -> Conv2d(dim, dim, 3) -> InstanceNorm2d,
    with a ReLU between the two stages. Channel count and spatial size are
    preserved, so the input can be added straight back onto the output.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()

        def conv_stage():
            # Padding by 1 before a 3x3 conv with padding=0 keeps H and W unchanged.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0),
                nn.InstanceNorm2d(dim),
            ]

        self.block = nn.Sequential(*conv_stage(), nn.ReLU(True), *conv_stage())

    def forward(self, x):
        residual = self.block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """ResNet-style encoder/decoder generator for image-to-image translation.

    Layout: 7x7 reflection-padded stem -> ``n_downsampling`` stride-2 convs
    (channels doubling each time) -> ``n_blocks`` residual blocks ->
    ``n_downsampling`` (bilinear upsample + 3x3 conv) stages (channels halving)
    -> 7x7 output conv followed by ``Tanh``, so outputs lie in [-1, 1].

    With the default arguments the layer list (and therefore the
    ``state_dict`` keys) matches the original hard-coded 64-channel,
    two-stage network.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        n_blocks: number of ``ResnetBlock`` modules at the bottleneck.
        ngf: channel width of the stem (and of the last hidden layer).
        n_downsampling: number of downsampling (and matching upsampling) stages.
    """

    def __init__(self, input_nc=1, output_nc=1, n_blocks=6, ngf=64, n_downsampling=2):
        super(Generator, self).__init__()

        # Initial convolution
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        ]

        # Downsampling: each stage halves H/W and doubles the channel count.
        for i in range(n_downsampling):
            in_ch = ngf * 2**i
            out_ch = in_ch * 2
            model += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # ResNet blocks at the bottleneck width
        bottleneck = ngf * 2**n_downsampling
        model += [ResnetBlock(bottleneck) for _ in range(n_blocks)]

        # Upsampling: bilinear resize + conv (instead of transposed conv),
        # halving the channel count at each stage.
        for i in range(n_downsampling):
            in_ch = ngf * 2**(n_downsampling - i)
            out_ch = in_ch // 2
            model += [
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Output layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        out = self.model(x)
        # Strided downsampling rounds odd sizes up, so after upsampling the
        # output can be larger than the input; resize back to the input's H/W.
        if out.shape[2:] != x.shape[2:]:
            out = F.interpolate(out, size=x.shape[2:], mode='bilinear', align_corners=False)
        return out

# Discriminator

In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator returning a map of per-patch scores.

    Five 4x4 convolutions: 64 -> 128 -> 256 -> 512 -> 1 channels, with
    strides 2, 2, 2, 1, 1. The first stage has no normalisation; the three
    middle stages use InstanceNorm2d, and every hidden stage uses
    LeakyReLU(0.2). There is no final activation, so the output is a raw
    score map of shape [batch_size, 1, H', W'].

    Args:
        input_nc: number of input image channels.
    """

    def __init__(self, input_nc=1):
        super().__init__()
        layers = [
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        # (in_channels, out_channels, stride) for the normalised middle stages
        stage_specs = ((64, 128, 2), (128, 256, 2), (256, 512, 1))
        for c_in, c_out, step in stage_specs:
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=step, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, True),
            ])
        # Project to a single-channel score map.
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

# Losses

PatchNCE Loss:

In [None]:
class PatchNCELoss(nn.Module):
    """Patch-wise InfoNCE (PatchNCE) contrastive loss with feature normalization.

    For every image in the batch, the feature vector at spatial location ``i``
    of ``feat_q`` is scored against every location of ``feat_k``. Location
    ``i`` of ``feat_k`` is the positive; all other locations of the same image
    are negatives. Features are L2-normalised along the channel dimension, so
    scores are cosine similarities divided by ``temperature``.

    Args:
        temperature: softmax temperature applied to the similarity logits.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, feat_q, feat_k):
        """Compute the PatchNCE loss.

        Args:
            feat_q: query features, assumed (B, C, H_q, W_q). If its spatial size
                differs from ``feat_k`` it is bilinearly resized to match.
            feat_k: key features, assumed (B, C, H, W) with the same B and C.

        Returns:
            Scalar tensor: cross-entropy averaged over all patches of all images.

        NOTE(review): the reference CUT implementation detaches ``feat_k``;
        here gradients flow into both inputs. Confirm this is intended.
        """
        if feat_q.shape[2:] != feat_k.shape[2:]:
            feat_q = F.interpolate(feat_q, size=feat_k.shape[2:], mode='bilinear', align_corners=False)

        batch_size, dim = feat_q.shape[:2]

        # reshape (not view) so non-contiguous inputs (permuted / channels_last) work.
        feat_q = F.normalize(feat_q.reshape(batch_size, dim, -1), dim=1)  # (B, C, P)
        feat_k = F.normalize(feat_k.reshape(batch_size, dim, -1), dim=1)  # (B, C, P)
        num_patches = feat_q.shape[2]

        # Positive logits: similarity of co-located patches -> (B, P, 1)
        l_pos = (feat_q * feat_k).sum(dim=1).unsqueeze(-1)

        # Similarity of every query patch with every key patch -> (B, P, P)
        l_neg = torch.bmm(feat_q.transpose(1, 2), feat_k)

        # The diagonal duplicates l_pos; mask it so the positive is not also a negative.
        diagonal = torch.eye(num_patches, dtype=torch.bool, device=l_neg.device)
        l_neg = l_neg.masked_fill(diagonal, -float('inf'))

        # Column 0 holds the positive for every patch.
        logits = torch.cat([l_pos, l_neg], dim=2) / self.temperature  # (B, P, P + 1)
        labels = torch.zeros(batch_size * num_patches, dtype=torch.long, device=logits.device)

        # All images have the same number of patches, so the mean over all
        # B * P rows equals the per-image mean averaged over the batch.
        return self.cross_entropy_loss(logits.reshape(-1, num_patches + 1), labels)

Feature Matching Loss (to stabilize GAN training):

In [None]:
class FeatureMatchingLoss(nn.Module):
    """Sum of L1 distances between paired real and fake feature maps.

    Real features are detached, so gradients only reach the fake branch.
    Pairs are formed with ``zip``, so extra entries in the longer sequence
    are ignored; two empty sequences give 0.
    """

    def __init__(self):
        super(FeatureMatchingLoss, self).__init__()
        self.l1_loss = nn.L1Loss()

    def forward(self, real_features, fake_features):
        """Return the summed L1 loss over corresponding feature maps."""
        pairwise = (
            self.l1_loss(fake, real.detach())
            for real, fake in zip(real_features, fake_features)
        )
        return sum(pairwise, 0)

# Hounsfield Units Loss (New Introduction)

Hounsfield Units (HU) are a standardized scale used in CT imaging that quantifies radiodensity. Different tissues have characteristic HU ranges:

- Air: approximately -1000 HU
- Lung tissue: -700 to -600 HU
- Fat: -100 to -50 HU
- Water: 0 HU
- Soft tissue: +20 to +70 HU
- Bone: +700 to +3000 HU

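This quantitative nature makes CT scans different from regular photographs: specific pixel intensity values have physical and medical meaning. A translated CT image should therefore preserve the distribution of intensities (that is, of tissue types), not just look plausible. The HU loss below encourages this by matching the intensity histograms of real and generated images, computed on the normalised [-1, 1] intensity scale rather than on raw HU values.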

In [None]:
class HULoss(nn.Module):
    """Hounsfield Unit (intensity) distribution preservation loss for CT scans.

    For each image pair in the batch, builds a histogram of pixel values over
    ``[min_value, max_value]`` for the real and the fake image, normalises
    both to probability distributions and returns KL(real || fake). The
    default range (-1, 1) matches the generator's Tanh output range.

    The histograms use linear-interpolation ("soft") binning so that the loss
    is differentiable with respect to ``fake``. ``torch.histc`` has no
    gradient, so a hard histogram would give the generator no training
    signal. Values outside ``[min_value, max_value]`` are ignored, as with
    ``torch.histc``.

    Args:
        bins: number of histogram bins. Bin centres are evenly spaced from
            ``min_value`` to ``max_value`` inclusive.
        min_value: lower end of the histogram range.
        max_value: upper end of the histogram range.
        reduction: ``'mean'`` averages the per-image KL values; any other
            value sums them.
    """

    def __init__(self, bins=100, min_value=-1, max_value=1, reduction='mean'):
        super().__init__()
        self.bins = bins
        self.min_value = min_value
        self.max_value = max_value
        self.reduction = reduction

    def _soft_histogram(self, values):
        """Differentiable histogram of a 1-D tensor via linear-interpolation binning.

        Each in-range value splits one unit of mass between its two nearest
        bin centres in proportion to proximity. The histogram therefore sums
        to the number of in-range values and carries a gradient with respect
        to ``values``. Memory use is O(len(values) + bins).
        """
        inside = (values >= self.min_value) & (values <= self.max_value)
        values = values[inside]
        hist = values.new_zeros(self.bins)
        if self.bins == 1:
            # A single bin just counts in-range values; there is nothing to interpolate.
            return hist + values.numel()

        spacing = (self.max_value - self.min_value) / (self.bins - 1)
        pos = (values - self.min_value) / spacing  # fractional bin index in [0, bins - 1]
        # Lower neighbouring centre. Clamp so a value equal to max_value uses the
        # last interval (frac == 1) instead of indexing past the end.
        lower = pos.detach().floor().clamp(0, self.bins - 2).long()
        # Clamp guards against float round-off producing slightly negative bin
        # mass, which would make the log in the KL term NaN.
        frac = (pos - lower).clamp(0, 1)
        hist = hist.index_add(0, lower, 1 - frac)
        hist = hist.index_add(0, lower + 1, frac)
        return hist

    def forward(self, real, fake):
        """Return the (mean- or sum-reduced) per-image KL(real || fake) of intensity histograms.

        Args:
            real: batch of real images; iterated along dim 0.
            fake: batch of generated images, with at least as many items as ``real``.
        """
        eps = 1e-10  # avoids division by zero and log(0)
        losses = []

        # One histogram per image (not per batch) for finer-grained matching.
        for i in range(real.size(0)):
            real_hist = self._soft_histogram(real[i].flatten())
            fake_hist = self._soft_histogram(fake[i].flatten())

            # Normalise histograms to probability distributions.
            real_hist = real_hist / (real_hist.sum() + eps)
            fake_hist = fake_hist / (fake_hist.sum() + eps)

            kl_div = (real_hist * torch.log((real_hist + eps) / (fake_hist + eps))).sum()
            losses.append(kl_div)

        if not losses:
            # Empty batch: torch.stack([]) would raise, so return a zero loss.
            return real.new_zeros(())

        stacked = torch.stack(losses)
        return stacked.mean() if self.reduction == 'mean' else stacked.sum()