In [None]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

In [None]:
""""
Stage-by-stage training:
- Stage 1: DnCNN (denoising)
- Stage 2: Background removal
- Stage 3: Simple MPRNet-inspired deblurring
Every stage is trained with an L1 loss plus color/brightness consistency terms,
followed by end-to-end fine-tuning of the combined pipeline (StarEnhancementNet).

Datasets are expected under:
  ../stage1_denoising/{train,val}/{input,target}
  ../stage2_bg_removed/{train,val}/{input,target}
  ../stage3_final/{train,val}/{input,target}
(The __main__ block currently points at the "-mid" variants of these directories.)
"""

# Stage1: DnCNN for Denoising
class DnCNN(nn.Module):
    """Residual denoiser: a plain convolution stack estimates the noise, which is
    then subtracted from the input.

    Args:
        image_channels: Channel count of the input image (and of the output).
        num_features: Number of channels in every hidden convolution.
        num_layers: Total number of convolutions, including the first and last.
    """

    def __init__(self, image_channels=3, num_features=64, num_layers=17):
        super().__init__()

        def conv3x3(c_in, c_out):
            # Bias-free 3x3 convolution that preserves the spatial size.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)

        # Modules are created strictly in network order (head, body, tail).
        stack = [conv3x3(image_channels, num_features), nn.ReLU(inplace=True)]
        hidden_count = num_layers - 2
        for _ in range(hidden_count):
            stack.extend((
                conv3x3(num_features, num_features),
                nn.BatchNorm2d(num_features),
                nn.ReLU(inplace=True),
            ))
        stack.append(conv3x3(num_features, image_channels))

        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Return ``x`` minus the noise predicted by the convolution stack."""
        return x - self.dncnn(x)

# Stage2: Background Removal
class BackgroundRemoval(nn.Module):
    """Two-convolution network that estimates a background layer and subtracts it.

    Args:
        channels: Channel count shared by the input, the estimate and the output.
    """

    def __init__(self, channels=3):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)

    def forward(self, x):
        """Return ``x`` with the estimated background removed."""
        background = self.conv2(self.relu(self.conv1(x)))
        return x - background
    
# Stage 3: Simple MPRNet-inspired Deblurring
def conv_block(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Build a Conv2d -> BatchNorm2d -> ReLU unit wrapped in one ``nn.Sequential``.

    Args:
        in_channels: Channels entering the convolution.
        out_channels: Channels produced by the convolution and normalised by BatchNorm.
        kernel_size, stride, padding: Forwarded unchanged to ``nn.Conv2d``.

    Returns:
        ``nn.Sequential`` holding the three layers in that order.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=True,
    )
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)

class SimpleMPRNet(nn.Module):
    """Small encoder-decoder used as the deblurring stage.

    Features are max-pooled twice (down to 1/4 resolution), passed through a
    bottleneck, then upsampled twice with an additive skip connection at each
    scale. The output has ``out_channels`` channels and the same height/width
    as the input.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the restored image.
        base_channels: Width of the first encoder level; deeper levels use 2x and 4x.

    Note:
        The two pooling steps need at least 4 pixels along each spatial axis.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super(SimpleMPRNet, self).__init__()
        # Encoder
        self.conv1 = conv_block(in_channels, base_channels)
        self.conv2 = conv_block(base_channels, base_channels*2)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(base_channels*2, base_channels*4)
        # Decoder
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.deconv1 = conv_block(base_channels*4, base_channels*2)
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.deconv2 = conv_block(base_channels*2, base_channels)
        self.conv_last = nn.Conv2d(base_channels, out_channels, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _match_size(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s height/width when they differ.

        Max-pooling floors odd sizes, so a fixed 2x upsample can come back one
        pixel short of the skip tensor; this restores the exact size.
        """
        if x.shape[-2:] != ref.shape[-2:]:
            x = F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=False)
        return x

    def forward(self, x):
        """Run the encoder-decoder and return a tensor with the input's spatial size."""
        x1 = self.conv1(x)
        x2 = self.conv2(self.pool(x1))
        x_b = self.bottleneck(self.pool(x2))

        # Sizes are aligned before each decoder conv, so the skip connections
        # are always applied instead of being silently dropped on odd inputs.
        x_d1 = self.deconv1(self._match_size(self.up1(x_b), x2)) + x2
        x_d2 = self.deconv2(self._match_size(self.up2(x_d1), x1)) + x1

        return self.conv_last(x_d2)

# Combined Pipeline
class StarEnhancementNet(nn.Module):
    """Full pipeline: denoising, then background removal, then deblurring.

    The sub-networks are exposed as ``stage1``, ``stage2`` and ``stage3`` so that
    per-stage weights can be loaded into them individually.
    """

    def __init__(self):
        super().__init__()
        self.stage1 = DnCNN(image_channels=3, num_features=64, num_layers=17)
        self.stage2 = BackgroundRemoval(channels=3)
        self.stage3 = SimpleMPRNet(in_channels=3, out_channels=3, base_channels=64)

    def forward(self, x):
        """Feed ``x`` through the three stages in order and return the result."""
        denoised = self.stage1(x)
        background_free = self.stage2(denoised)
        return self.stage3(background_free)


In [None]:
from tqdm import tqdm
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

# PairedDataset for input-target pairs
class PairedDataset(Dataset):
    """Dataset of (input, target) image pairs that share a file name.

    Every entry of ``inp_dir`` (sorted by name) is paired with the entry of the
    same name in ``tgt_dir``.

    Args:
        inp_dir: Directory holding the input images.
        tgt_dir: Directory holding the target images.
        transform: Optional callable applied to each image as a PIL image. When
            omitted, images are returned as float CHW tensors scaled by 1/255.
    """

    def __init__(self, inp_dir, tgt_dir, transform=None):
        self.inp_dir = inp_dir
        self.tgt_dir = tgt_dir
        self.transform = transform
        self.files = sorted(os.listdir(self.inp_dir))

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _read_rgb(path, role):
        """Read ``path`` with OpenCV and return it in RGB channel order.

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        bgr = cv2.imread(path)
        if bgr is None:
            raise ValueError(f"Could not read {role} image {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensors for sample ``idx``."""
        name = self.files[idx]
        inp_rgb = self._read_rgb(os.path.join(self.inp_dir, name), "input")
        tgt_rgb = self._read_rgb(os.path.join(self.tgt_dir, name), "target")

        if not self.transform:
            # HWC uint8 array -> CHW float tensor scaled by 1/255.
            def to_tensor(rgb):
                return torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.

            return to_tensor(inp_rgb), to_tensor(tgt_rgb)

        inp_tensor = self.transform(Image.fromarray(inp_rgb))
        tgt_tensor = self.transform(Image.fromarray(tgt_rgb))
        return inp_tensor, tgt_tensor

# Loss: L1 + brightness + color
def brightness_loss(output, target):
    """Mean absolute difference between the per-sample mean intensities.

    Each sample is averaged over dims 1, 2 and 3, the two averages are compared,
    and the absolute gaps are averaged over the batch into a scalar tensor.
    """
    reduce_dims = [1, 2, 3]
    per_sample_gap = output.mean(dim=reduce_dims) - target.mean(dim=reduce_dims)
    return per_sample_gap.abs().mean()

def color_consistency_loss(output, target):
    """Mean absolute difference between the means taken over dims 2 and 3.

    Unlike ``brightness_loss`` the channel axis (dim 1) is kept, so shifts that
    cancel out across channels are still penalised.
    """
    spatial_dims = [2, 3]
    channel_gap = output.mean(dim=spatial_dims) - target.mean(dim=spatial_dims)
    return channel_gap.abs().mean()

# Pixel-wise reconstruction term used by every training loop in this file.
l1_criterion = nn.L1Loss()
# Weights of the two auxiliary terms relative to the L1 term.
brightness_weight = 0.05
color_weight = 0.05

def compute_loss_with_color_brightness(out, tgt):
    """Return L1 loss plus the weighted brightness and colour-consistency terms."""
    total = l1_criterion(out, tgt)
    total = total + brightness_weight * brightness_loss(out, tgt)
    total = total + color_weight * color_consistency_loss(out, tgt)
    return total

def calculate_psnr_batch(output, target):
    """Average PSNR (dB) over the samples of a batch.

    Both tensors are clamped to [0, 1], detached, moved to the CPU and permuted
    with (0, 2, 3, 1) before each sample is scored with ``psnr_single``.
    """
    def to_numpy_batch(batch):
        clipped = torch.clamp(batch, 0, 1)
        return clipped.detach().cpu().permute(0, 2, 3, 1).numpy()

    out_np = to_numpy_batch(output)
    tgt_np = to_numpy_batch(target)
    scores = [psnr_single(out_np[i], tgt_np[i]) for i in range(out_np.shape[0])]
    return np.mean(scores)

def psnr_single(img1, img2):
    """PSNR in dB between two arrays, using a peak value of 1.0.

    Returns 100 when the mean squared error is below 1e-10, which avoids
    dividing by (almost) zero for near-identical images.
    """
    squared_error = (img1 - img2) ** 2
    mse = np.mean(squared_error)
    return 100 if mse < 1e-10 else 20 * np.log10(1.0 / np.sqrt(mse))

def show_final_val_examples(model, loader, stage_name, epoch_num, n_examples=3):
    """Render input / output / ground-truth triptychs for a few batches.

    The first sample of each of the first ``n_examples`` batches from ``loader``
    is drawn, saved as ``{stage_name}_e{epoch_num}_val_{k}.png`` in the working
    directory, and displayed.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        loader: Iterable yielding ``(input, target)`` batches.
        stage_name: Label used in the figure title and the file name.
        epoch_num: Epoch number used in the figure title and the file name.
        n_examples: Number of batches to visualise; nothing is drawn when <= 0.
    """
    model.eval()
    if n_examples <= 0:
        return

    count = 0
    with torch.no_grad():
        for inp, tgt in loader:
            inp, tgt = inp.to(device), tgt.to(device)
            out = torch.clamp(model(inp), 0, 1)

            panels = (("Input", inp), ("Output", out), ("GroundTruth", tgt))
            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            for ax, (title, batch) in zip(axes, panels):
                # Only the first sample of the batch is shown (CHW -> HWC).
                ax.imshow(batch[0].permute(1, 2, 0).cpu().numpy())
                ax.set_title(title)
                ax.axis("off")

            fig.suptitle(f"{stage_name} - Val Example (Epoch {epoch_num}) - {count}")
            out_name = f"{stage_name}_e{epoch_num}_val_{count}.png"
            fig.savefig(out_name)
            plt.show()
            # Release the figure; this runs many times during a long training
            # session and open figures would otherwise pile up in memory.
            plt.close(fig)

            count += 1
            if count >= n_examples:
                break

# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def train_stage(
    model,
    stage_name,
    train_inp_dir, train_tgt_dir,
    val_inp_dir,   val_tgt_dir,
    epochs=10,
    batch_size=4,
    lr=1e-4,
    checkpoint=None,
    log_every=50,
):
    """Train one stage on paired images and return the trained model.

    Each epoch runs a full pass over the training set followed by a validation
    pass (loss and PSNR). Every ``log_every`` epochs, and on the last epoch, the
    metrics are printed and the weights are written to
    ``{stage_name.lower()}_e{epoch}.pth``. After training, the final weights,
    a loss/PSNR curve image and a few validation previews are saved in the
    working directory.

    Args:
        model: Network to train; moved to the module-level ``device``.
        stage_name: Label used in log lines, plot titles and output file names.
        train_inp_dir, train_tgt_dir: Training input / target image directories.
        val_inp_dir, val_tgt_dir: Validation input / target image directories.
        epochs: Number of training epochs.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.
        checkpoint: Optional path of a state dict to load before training. When
            given, the final weights are saved with a ``_finetuned`` suffix.
        log_every: Epoch interval for logging/checkpointing; a value <= 0 (or
            ``None``) restricts it to the final epoch.

    Returns:
        The trained model (on ``device``).

    Raises:
        ValueError: If the training or validation directory contains no files.
    """
    print(f"Training {stage_name} for {epochs} epochs...")

    transform = transforms.ToTensor()
    train_ds = PairedDataset(train_inp_dir, train_tgt_dir, transform)
    val_ds   = PairedDataset(val_inp_dir,   val_tgt_dir,   transform)

    # Fail fast: an empty split would otherwise only surface as a division by
    # zero after a complete training epoch.
    if len(train_ds) == 0:
        raise ValueError(f"No training images found in {train_inp_dir}")
    if len(val_ds) == 0:
        raise ValueError(f"No validation images found in {val_inp_dir}")

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)

    model = model.to(device)
    resumed = bool(checkpoint)
    if resumed:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(checkpoint, map_location=device))
        print(f"Loaded checkpoint {checkpoint}")

    optimizer = optim.Adam(model.parameters(), lr=lr)

    tag = stage_name.lower()
    train_loss_arr = []
    val_loss_arr = []
    val_psnr_arr = []

    for epoch in range(epochs):
        epoch_no = epoch + 1

        # train
        model.train()
        running_loss = 0.0
        for inp, tgt in tqdm(train_loader, desc=f"Train Epoch {epoch_no}/{epochs}", leave=False):
            inp, tgt = inp.to(device), tgt.to(device)
            out = model(inp)
            loss = compute_loss_with_color_brightness(out, tgt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_tr_loss = running_loss / len(train_loader)
        train_loss_arr.append(avg_tr_loss)

        # validation
        model.eval()
        running_val_loss = 0.0
        running_val_psnr = 0.0
        with torch.no_grad():
            for inp, tgt in tqdm(val_loader, desc=f"Validation Epoch {epoch_no}/{epochs}", leave=False):
                inp, tgt = inp.to(device), tgt.to(device)
                out = model(inp)
                running_val_loss += compute_loss_with_color_brightness(out, tgt).item()
                running_val_psnr += calculate_psnr_batch(out, tgt)

        # Both metrics are averaged per batch, not per sample.
        avg_val_loss = running_val_loss / len(val_loader)
        avg_val_psnr = running_val_psnr / len(val_loader)
        val_loss_arr.append(avg_val_loss)
        val_psnr_arr.append(avg_val_psnr)

        is_last = epoch_no == epochs
        is_periodic = bool(log_every) and log_every > 0 and epoch_no % log_every == 0
        if is_periodic or is_last:
            print(f"[{stage_name}] Epoch {epoch_no}/{epochs} => "
                  f"TrainLoss={avg_tr_loss:.4f}, ValLoss={avg_val_loss:.4f}, ValPSNR={avg_val_psnr:.2f} dB")
            # save checkpoint
            ckpt_path = f"{tag}_e{epoch_no}.pth"
            torch.save(model.state_dict(), ckpt_path)
            print(f"Saved checkpoint => {ckpt_path}")
            # The last epoch's previews are rendered once, after the curves below.
            if not is_last:
                show_final_val_examples(model, val_loader, stage_name, epoch_no)

    if resumed:
        final_model_path = f"{tag}_finetuned_e{epochs}.pth"
    else:
        final_model_path = f"{tag}_e{epochs}.pth"
    torch.save(model.state_dict(), final_model_path)
    print(f"Saved final model => {final_model_path}")

    # Loss & PSNR curves
    epoch_axis = range(1, epochs+1)
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    # loss
    axes[0].plot(epoch_axis, train_loss_arr, label="TrainLoss")
    axes[0].plot(epoch_axis, val_loss_arr,   label="ValLoss")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].set_title(f"{stage_name} Loss")
    axes[0].legend()

    # psnr
    axes[1].plot(epoch_axis, val_psnr_arr, label="ValPSNR", color='orange')
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("PSNR (dB)")
    axes[1].set_title(f"{stage_name} Validation PSNR")
    axes[1].legend()

    fig.suptitle(f"{stage_name} - Loss & PSNR")
    fig.tight_layout()
    save_fig_path = f"{tag}_loss_psnr_e{epochs}.png"
    fig.savefig(save_fig_path)
    plt.show()
    plt.close(fig)  # free the figure once it has been saved and displayed

    show_final_val_examples(model, val_loader, stage_name, epochs)

    return model

# Fine-tune
def fine_tune_entire_pipeline(
    net,
    epochs=5,
    lr=1e-5,
    train_inp_dir="../stage3_final-mid/train/input",
    train_tgt_dir="../stage3_final-mid/train/target",
    batch_size=8,
    log_every=50,
):
    """Fine-tune the combined pipeline end to end on paired images.

    Every ``log_every`` epochs, and on the last epoch, the training loss and the
    training PSNR (averaged over all batches of that epoch) are printed and the
    weights are written to ``final_pipeline_finetuned_e{epoch}.pth``. After
    training, the final weights, a loss curve image and a few previews are saved
    in the working directory.

    Args:
        net: Network to fine-tune; moved to the module-level ``device``.
        epochs: Number of fine-tuning epochs.
        lr: Adam learning rate.
        train_inp_dir, train_tgt_dir: Training input / target image directories.
        batch_size: Training batch size.
        log_every: Epoch interval for logging/checkpointing; a value <= 0 (or
            ``None``) restricts it to the final epoch.

    Returns:
        The fine-tuned network (on ``device``).

    Raises:
        ValueError: If the training directory contains no files.
    """
    print(f"Fine-tuning entire pipeline for {epochs} epochs...")

    transform = transforms.ToTensor()
    train_ds = PairedDataset(train_inp_dir, train_tgt_dir, transform)
    if len(train_ds) == 0:
        raise ValueError(f"No training images found in {train_inp_dir}")
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    # Batches are moved to `device` below, so the network must live there too.
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr)

    train_loss_arr = []

    for epoch in range(epochs):
        epoch_no = epoch + 1
        is_last = epoch_no == epochs
        is_periodic = bool(log_every) and log_every > 0 and epoch_no % log_every == 0
        should_log = is_periodic or is_last

        net.train()
        running = 0.0
        running_psnr = 0.0
        for inp, tgt in tqdm(train_loader, desc=f"FineTune E{epoch_no}/{epochs}", leave=False):
            inp, tgt = inp.to(device), tgt.to(device)
            out = net(inp)
            loss = compute_loss_with_color_brightness(out, tgt)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
            if should_log:
                # PSNR is only accumulated on epochs that report it, and over
                # every batch rather than just the last one of the epoch.
                running_psnr += calculate_psnr_batch(out, tgt)
        avg_tr_loss = running / len(train_loader)
        train_loss_arr.append(avg_tr_loss)

        if should_log:
            avg_psnr = running_psnr / len(train_loader)
            print(f"[FineTune] Epoch {epoch_no}/{epochs} => TrainLoss={avg_tr_loss:.4f}, PSNR={avg_psnr:.2f} dB")
            ckpt_path = f"final_pipeline_finetuned_e{epoch_no}.pth"
            torch.save(net.state_dict(), ckpt_path)
            print(f"Saved checkpoint => {ckpt_path}")
            # The last epoch's previews are rendered once, after the curve below.
            if not is_last:
                show_final_val_examples(net, train_loader, "FinalPipeline", epoch_no)

    final_model_path = f"final_pipeline_large_finetuned_e{epochs}.pth"
    torch.save(net.state_dict(), final_model_path)
    print(f"Saved final pipeline => {final_model_path}")

    fig, ax = plt.subplots()
    ax.plot(range(1, epochs+1), train_loss_arr, label="TrainLoss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Fine-tune End-to-End Loss")
    ax.legend()
    fig.savefig(f"finetune_loss_e{epochs}.png")
    plt.show()
    plt.close(fig)  # free the figure once it has been saved and displayed

    # No validation split is loaded here, so the previews are drawn from the
    # (shuffled) training loader.
    show_final_val_examples(net, train_loader, "FinalPipeline", epochs)

    return net

if __name__ == "__main__":
    # Epoch budget shared by stage 1 and stage 2 (stage 3 sets its own below).
    epoch_num = 800
    # Stage1: DnCNN for Denoising
    from_stage1 = DnCNN(image_channels=3, num_features=64, num_layers=17)
    from_stage1 = train_stage(
        from_stage1, 
        stage_name="Stage1_DnCNN",
        train_inp_dir="../stage1_denoising-mid/train/input",
        train_tgt_dir="../stage1_denoising-mid/train/target",
        val_inp_dir="../stage1_denoising-mid/val/input",
        val_tgt_dir="../stage1_denoising-mid/val/target",
        epochs=epoch_num,
        batch_size=8,
        lr=1e-4
    )

    # Stage2: BG Removal
    from_stage2 = BackgroundRemoval(channels=3)
    from_stage2 = train_stage(
        from_stage2,
        stage_name="Stage2_BGRemoval",
        train_inp_dir="../stage2_bg_removed-mid/train/input",
        train_tgt_dir="../stage2_bg_removed-mid/train/target",
        val_inp_dir="../stage2_bg_removed-mid/val/input",
        val_tgt_dir="../stage2_bg_removed-mid/val/target",
        epochs=epoch_num,
        batch_size=8,
        lr=1e-4
    )

    # Stage3: Deblurring
    from_stage3 = SimpleMPRNet(in_channels=3, out_channels=3, base_channels=64)
    from_stage3 = train_stage(
        from_stage3,
        stage_name="Stage3_Deblurring",
        train_inp_dir="../stage3_final-mid/train/input",
        train_tgt_dir="../stage3_final-mid/train/target",
        val_inp_dir="../stage3_final-mid/val/input",
        val_tgt_dir="../stage3_final-mid/val/target",
        epochs=1000,
        batch_size=8,
        lr=1e-5,
        # checkpoint="stage3_deblurring_e1000.pth"
    )

    # Fine-tune entire pipeline
    # NOTE(review): the pipeline is assembled from checkpoint files, not from the
    # models trained above: stage 1 from its epoch-500 checkpoint (training ran
    # for 800 epochs) and stage 3 from a file under ./good-saved. Presumably
    # these were hand-picked - confirm before changing.
    net = StarEnhancementNet()
    stage_checkpoints = (
        (net.stage1, "stage1_dncnn_e500.pth"),
        (net.stage2, "stage2_bgremoval_e800.pth"),
        (net.stage3, "./good-saved/stage3_deblurring_finetuned_e500.pth"),
    )
    for stage, ckpt_path in stage_checkpoints:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        stage.load_state_dict(torch.load(ckpt_path, map_location=device))
    net = net.to(device)
    fine_tune_entire_pipeline(net, epochs=100, lr=1e-4)
    print("All training done!")
