In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from generator_model import Generator
from discriminator_model import Discriminator
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import numpy as np
from glob import glob
import cv2
import matplotlib.pyplot as plt
import os
from scipy.ndimage import gaussian_filter, map_coordinates

torch.backends.cudnn.benchmark = True





In [2]:
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
    device='cuda'
):
    loop = tqdm(loader, leave=True)
    
    # Track total losses for averaging
    total_d_loss = 0
    total_g_loss = 0
    total_g_fake_loss = 0
    total_l1_loss = 0

    gen.train()
    disc.train()

    for idx, (x, y) in enumerate(loop):
        x = x.to(device)
        y = y.to(device)

        # ----------- Train Discriminator -----------
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ----------- Train Generator -----------
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * 100
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # ----------- Logging and Metrics -----------
        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
                D_loss=D_loss.item(),
                G_loss=G_loss.item(),
            )

        # Track for averaging
        total_d_loss += D_loss.item()
        total_g_loss += G_loss.item()
        total_g_fake_loss += G_fake_loss.item()
        total_l1_loss += L1.item()

    # Compute average losses over the epoch
    num_batches = len(loader)
    avg_d_loss = total_d_loss / num_batches
    avg_g_loss = total_g_loss / num_batches
    avg_g_fake_loss = total_g_fake_loss / num_batches
    avg_l1_loss = total_l1_loss / num_batches

    return {
        'D_loss': avg_d_loss,
        'G_loss': avg_g_loss,
        'G_fake_loss': avg_g_fake_loss,
        'L1_loss': avg_l1_loss,
    }

In [3]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, MeanAbsoluteError


def eval_fn(gen, disc, val_loader, l1_loss, bce, device='cuda'):
    gen.eval()
    disc.eval()

    total_l1_loss = 0
    total_g_fake_loss = 0
    total_g_loss = 0

    # Image quality metrics
    mae_metric = MeanAbsoluteError().to(device)
    psnr_metric = PeakSignalNoiseRatio().to(device)
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    loop = tqdm(val_loader, desc="Evaluating", leave=False)

    with torch.no_grad():
        for x, y in loop:
            x = x.to(device)
            y = y.to(device)

            with torch.cuda.amp.autocast():
                y_fake = gen(x)
                D_fake = disc(x, y_fake)
                G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
                L1 = l1_loss(y_fake, y) * 100
                G_loss = G_fake_loss + L1

            total_l1_loss += L1.item()
            total_g_fake_loss += G_fake_loss.item()
            total_g_loss += G_loss.item()

            # Clamp predictions to [0, 1] for metrics
            y_fake_clamped = torch.clamp(y_fake, 0, 1)

            # Update metrics
            mae_metric.update(y_fake_clamped, y)
            psnr_metric.update(y_fake_clamped, y)
            ssim_metric.update(y_fake_clamped, y)

    num_batches = len(val_loader)
    avg_l1_loss = total_l1_loss / num_batches
    avg_g_fake_loss = total_g_fake_loss / num_batches
    avg_g_loss = total_g_loss / num_batches

    # Compute metrics
    mae = mae_metric.compute().item()
    psnr = psnr_metric.compute().item()
    ssim = ssim_metric.compute().item()

    print(f"\n[Validation] G_loss: {avg_g_loss:.4f}, "
          f"G_fake_loss: {avg_g_fake_loss:.4f}, "
          f"L1_loss: {avg_l1_loss:.4f}")
    print(f"[Metrics] MAE: {mae:.4f}, PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}")

    return {
        'G_loss': avg_g_loss,
        'G_fake_loss': avg_g_fake_loss,
        'L1_loss': avg_l1_loss,
        'MAE': mae,
        'PSNR': psnr,
        'SSIM': ssim,
    }


In [4]:
class MyCustomDataset(Dataset):
    def __init__(self,train_path,test_path,image_size=(224, 224),aug=True):
        self.train_path = train_path
        self.test_path = test_path
        self.image_size = image_size
        self.aug = aug

        self.train_images = sorted(glob(os.path.join(train_path, '*.npy')))
    def __len__(self):
        return len(self.train_images)

    def rotate(self,mr,ct):
        
        h, w = self.image_size
        center = (self.image_size[0] // 2, self.image_size[0] // 2)
        angle = np.random.uniform(-10, 10)
        rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale=1.0)
        mr_rotated = cv2.warpAffine(mr, rotation_matrix, (w, h),borderMode=cv2.BORDER_REFLECT_101)
        ct_rotated = cv2.warpAffine(ct, rotation_matrix, (w, h),borderMode=cv2.BORDER_REFLECT_101)
        return mr_rotated,ct_rotated
        
    def hflip(self,mr,ct):
        mr_flipped = cv2.flip(mr, 1)
        ct_flipped = cv2.flip(ct, 1)
        return mr_flipped,ct_flipped

    def scaled(self, mr, ct):
        h, w = self.image_size
    
        # Random scale factor between 0.9 (zoom out) and 1.1 (zoom in)
        scale = np.random.uniform(0.9, 1.1)
    
        # Resize image
        scaled_mr = cv2.resize(mr, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        scaled_ct = cv2.resize(ct, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
    
        # Get new size
        new_h, new_w = scaled_mr.shape[:2]
    
        # Crop or pad to original size (centered)
        top = max((new_h - h) // 2, 0)
        left = max((new_w - w) // 2, 0)
        bottom = top + h
        right = left + w
    
        # If scaled image is larger — crop center
        if scale >= 1.0:
            scaled_mr = scaled_mr[top:bottom, left:right]
            scaled_ct = scaled_ct[top:bottom, left:right]
        else:
            # If scaled image is smaller — pad to original size
            pad_top = (h - new_h) // 2
            pad_bottom = h - new_h - pad_top
            pad_left = (w - new_w) // 2
            pad_right = w - new_w - pad_left
    
            scaled_mr = cv2.copyMakeBorder(scaled_mr, pad_top, pad_bottom, pad_left, pad_right, borderType=cv2.BORDER_REFLECT_101)
            scaled_ct = cv2.copyMakeBorder(scaled_ct, pad_top, pad_bottom, pad_left, pad_right, borderType=cv2.BORDER_REFLECT_101)
    
        return scaled_mr, scaled_ct
        
    def translate(self, mr, ct):
        h, w = self.image_size
    
        # Max shift: 10% of width and height
        max_shift_x = int(0.1 * w)
        max_shift_y = int(0.1 * h)
    
        # Random shifts in x and y directions
        tx = np.random.randint(-max_shift_x, max_shift_x + 1)
        ty = np.random.randint(-max_shift_y, max_shift_y + 1)
    
        # Create translation matrix
        translation_matrix = np.float32([[1, 0, tx], [0, 1, ty]])
    
        # Apply translation
        mr_translated = cv2.warpAffine(mr, translation_matrix, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        ct_translated = cv2.warpAffine(ct, translation_matrix, (w, h), borderMode=cv2.BORDER_REFLECT_101)
    
        return mr_translated, ct_translated

    def elastic_deformation(self,image, dx, dy):
        shape = image.shape[:2]
    
        x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
        indices = np.reshape(y + dy, (-1)), np.reshape(x + dx, (-1))
    
        if image.ndim == 3:
            channels = []
            for c in range(image.shape[2]):
                channel = map_coordinates(image[..., c], indices, order=1, mode='reflect').reshape(shape)
                channels.append(channel)
            return np.stack(channels, axis=-1).astype(image.dtype)
        else:
            return map_coordinates(image, indices, order=1, mode='reflect').reshape(shape).astype(image.dtype)
    
    def elastic(self, mr, ct, alpha=30, sigma=4):
        shape = mr.shape[:2]
        random_state = np.random.RandomState(None)
    
        dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha
        dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma) * alpha
    
        mr_deformed = self.elastic_deformation(mr, dx, dy)
        ct_deformed = self.elastic_deformation(ct, dx, dy)
    
        return mr_deformed, ct_deformed

    def augment(self, mr, ct):
        # You can tweak these probabilities (0.5 = 50% chance)
        if np.random.rand() < 0.5:
            mr, ct = self.rotate(mr, ct)
        if np.random.rand() < 0.5:
            mr, ct = self.hflip(mr, ct)
        if np.random.rand() < 0.5:
            mr, ct = self.scaled(mr, ct)
        if np.random.rand() < 0.5:
            mr, ct = self.translate(mr, ct)
        # if np.random.rand() < 0.5:
        #     mr, ct = self.elastic(mr, ct)
    
        return mr, ct
    
    def __getitem__(self, idx):
        mr_path = self.train_images[idx]
        filename = os.path.basename(mr_path)

        # CT file (should match filename)
        ct_path = os.path.join(self.test_path, filename)
        # print(mr_path)
        mr_img = np.load(mr_path)
        ct_img = np.load(ct_path)

        if self.image_size is not None:
            mr_img = cv2.resize(mr_img, self.image_size)
            ct_img = cv2.resize(ct_img, self.image_size)
            
        if self.aug:
            mr_img, ct_img = self.augment(mr_img, ct_img)

        if mr_img.ndim == 2:
            mr_img = np.expand_dims(mr_img, axis=-1)
        if ct_img.ndim == 2:
            ct_img = np.expand_dims(ct_img, axis=-1)



        mr_tensor = torch.from_numpy(mr_img).permute(2, 0, 1).float()
        ct_tensor = torch.from_numpy(ct_img).permute(2, 0, 1).float()

        return mr_tensor, ct_tensor


In [5]:
dataTrain = MyCustomDataset('../../ct_mr_stdscale//train/mr/','../../ct_mr_stdscale/train/ct/')
dataVal = MyCustomDataset('../../ct_mr_stdscale/val//mr/','../../ct_mr_stdscale/val//ct/',aug=False)

In [6]:
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
val_loader = DataLoader(dataVal, batch_size=1, shuffle=False)
train_loader = DataLoader(dataTrain, batch_size=32, shuffle=True)

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()


In [7]:
disc = Discriminator(in_channels=1).to('cuda')
gen = Generator(in_channels=1, out_channels=1).to('cuda')
opt_disc = optim.Adam(disc.parameters(), lr= 1e-4, betas=(0.5, 0.999),)
opt_gen = optim.Adam(gen.parameters(), lr= 5e-5, betas=(0.5, 0.999))
BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()



In [8]:
best_g_loss = float('inf')
best_model_path = 'best_generator_aug.pth'

for epoch in range(50):
    train_fn(
        disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
    )
    
    eval_metrics  = eval_metrics = eval_fn(
        gen, disc, val_loader, L1_LOSS, BCE, device='cuda'
    )

    torch.save(gen.state_dict(), best_model_path)
    # break  # remove this when running full training


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
100%|██████████| 430/430 [15:54<00:00,  2.22s/it, D_fake=0.448, D_loss=0.613, D_real=0.613, G_loss=3.66]   
  with torch.cuda.amp.autocast():
                                                               


[Validation] G_loss: 3.7701, G_fake_loss: 0.9674, L1_loss: 2.8027
[Metrics] MAE: 0.0280, PSNR: 34.07 dB, SSIM: 0.7776


100%|██████████| 430/430 [10:40<00:00,  1.49s/it, D_fake=0.469, D_loss=0.63, D_real=0.425, G_loss=3.11] 
                                                               


[Validation] G_loss: 9.5015, G_fake_loss: 1.0119, L1_loss: 8.4896
[Metrics] MAE: 0.0849, PSNR: 22.58 dB, SSIM: 0.7028


100%|██████████| 430/430 [06:47<00:00,  1.06it/s, D_fake=0.305, D_loss=0.644, D_real=0.691, G_loss=3.45]
                                                               


[Validation] G_loss: 3.7551, G_fake_loss: 0.8058, L1_loss: 2.9492
[Metrics] MAE: 0.0295, PSNR: 33.59 dB, SSIM: 0.7638


100%|██████████| 430/430 [06:54<00:00,  1.04it/s, D_fake=0.451, D_loss=0.621, D_real=0.544, G_loss=2.88]
                                                               


[Validation] G_loss: 4.1752, G_fake_loss: 1.9310, L1_loss: 2.2442
[Metrics] MAE: 0.0224, PSNR: 35.13 dB, SSIM: 0.8181


100%|██████████| 430/430 [09:20<00:00,  1.30s/it, D_fake=0.574, D_loss=0.627, D_real=0.514, G_loss=2.66]
                                                               


[Validation] G_loss: 2.9296, G_fake_loss: 0.7488, L1_loss: 2.1808
[Metrics] MAE: 0.0218, PSNR: 35.44 dB, SSIM: 0.8171


100%|██████████| 430/430 [11:00<00:00,  1.54s/it, D_fake=0.444, D_loss=0.669, D_real=0.49, G_loss=2.94] 
                                                               


[Validation] G_loss: 4.6220, G_fake_loss: 1.3326, L1_loss: 3.2894
[Metrics] MAE: 0.0329, PSNR: 33.08 dB, SSIM: 0.7309


100%|██████████| 430/430 [07:43<00:00,  1.08s/it, D_fake=0.477, D_loss=0.705, D_real=0.566, G_loss=2.7] 
                                                               


[Validation] G_loss: 11.2182, G_fake_loss: 0.9594, L1_loss: 10.2589
[Metrics] MAE: 0.1026, PSNR: 25.39 dB, SSIM: 0.5511


100%|██████████| 430/430 [06:59<00:00,  1.03it/s, D_fake=0.484, D_loss=0.592, D_real=0.568, G_loss=2.82]
                                                               


[Validation] G_loss: 3.3345, G_fake_loss: 1.1470, L1_loss: 2.1875
[Metrics] MAE: 0.0219, PSNR: 34.95 dB, SSIM: 0.8205


100%|██████████| 430/430 [06:41<00:00,  1.07it/s, D_fake=0.422, D_loss=0.638, D_real=0.399, G_loss=2.94]
                                                               


[Validation] G_loss: 3.2546, G_fake_loss: 0.6008, L1_loss: 2.6538
[Metrics] MAE: 0.0265, PSNR: 34.63 dB, SSIM: 0.7143


100%|██████████| 430/430 [05:42<00:00,  1.26it/s, D_fake=0.368, D_loss=0.498, D_real=0.59, G_loss=3.35] 
                                                               


[Validation] G_loss: 4.3317, G_fake_loss: 0.7134, L1_loss: 3.6183
[Metrics] MAE: 0.0362, PSNR: 31.70 dB, SSIM: 0.7299


100%|██████████| 430/430 [05:47<00:00,  1.24it/s, D_fake=0.214, D_loss=0.562, D_real=0.565, G_loss=3.97] 
                                                               


[Validation] G_loss: 3.9204, G_fake_loss: 1.6322, L1_loss: 2.2882
[Metrics] MAE: 0.0229, PSNR: 34.97 dB, SSIM: 0.8144


100%|██████████| 430/430 [09:07<00:00,  1.27s/it, D_fake=0.0352, D_loss=0.747, D_real=0.821, G_loss=5.68]
                                                               


[Validation] G_loss: 5.7600, G_fake_loss: 1.6788, L1_loss: 4.0812
[Metrics] MAE: 0.0408, PSNR: 31.43 dB, SSIM: 0.6891


100%|██████████| 430/430 [07:34<00:00,  1.06s/it, D_fake=0.554, D_loss=0.593, D_real=0.417, G_loss=2.73]  
                                                               


[Validation] G_loss: 3.9159, G_fake_loss: 1.7016, L1_loss: 2.2143
[Metrics] MAE: 0.0221, PSNR: 35.07 dB, SSIM: 0.8183


100%|██████████| 430/430 [05:47<00:00,  1.24it/s, D_fake=0.0705, D_loss=0.812, D_real=0.767, G_loss=5.01] 
                                                               


[Validation] G_loss: 3.3536, G_fake_loss: 0.8736, L1_loss: 2.4800
[Metrics] MAE: 0.0248, PSNR: 34.49 dB, SSIM: 0.7857


100%|██████████| 430/430 [05:48<00:00,  1.23it/s, D_fake=0.059, D_loss=0.0855, D_real=0.943, G_loss=6.66]
                                                               


[Validation] G_loss: 4.2425, G_fake_loss: 1.7634, L1_loss: 2.4792
[Metrics] MAE: 0.0248, PSNR: 35.12 dB, SSIM: 0.7542


100%|██████████| 430/430 [05:40<00:00,  1.26it/s, D_fake=0.0383, D_loss=0.0664, D_real=0.967, G_loss=6.44]
                                                               


[Validation] G_loss: 3.8281, G_fake_loss: 0.9087, L1_loss: 2.9194
[Metrics] MAE: 0.0292, PSNR: 33.23 dB, SSIM: 0.7743


100%|██████████| 430/430 [05:42<00:00,  1.25it/s, D_fake=0.388, D_loss=1.08, D_real=0.483, G_loss=3.49]   
                                                               


[Validation] G_loss: 3.8189, G_fake_loss: 1.6271, L1_loss: 2.1917
[Metrics] MAE: 0.0219, PSNR: 35.25 dB, SSIM: 0.8216


100%|██████████| 430/430 [05:40<00:00,  1.26it/s, D_fake=0.0307, D_loss=0.0931, D_real=0.85, G_loss=5.9]   
                                                               


[Validation] G_loss: 3.3748, G_fake_loss: 1.0630, L1_loss: 2.3117
[Metrics] MAE: 0.0231, PSNR: 34.95 dB, SSIM: 0.8022


100%|██████████| 430/430 [05:31<00:00,  1.30it/s, D_fake=0.0231, D_loss=0.0297, D_real=0.977, G_loss=6.59]  
                                                               


[Validation] G_loss: 4.1343, G_fake_loss: 1.7017, L1_loss: 2.4326
[Metrics] MAE: 0.0243, PSNR: 34.97 dB, SSIM: 0.7785


100%|██████████| 430/430 [05:31<00:00,  1.30it/s, D_fake=0.0122, D_loss=0.129, D_real=0.99, G_loss=7.6]     
                                                               


[Validation] G_loss: 4.2074, G_fake_loss: 1.1584, L1_loss: 3.0490
[Metrics] MAE: 0.0305, PSNR: 33.45 dB, SSIM: 0.7671


100%|██████████| 430/430 [05:31<00:00,  1.30it/s, D_fake=0.0201, D_loss=0.858, D_real=0.744, G_loss=6.43]  
                                                               


[Validation] G_loss: 5.1553, G_fake_loss: 2.8377, L1_loss: 2.3176
[Metrics] MAE: 0.0232, PSNR: 34.68 dB, SSIM: 0.8084


100%|██████████| 430/430 [05:30<00:00,  1.30it/s, D_fake=0.0842, D_loss=0.277, D_real=0.986, G_loss=5.02] 
                                                               


[Validation] G_loss: 4.1976, G_fake_loss: 1.6644, L1_loss: 2.5332
[Metrics] MAE: 0.0253, PSNR: 34.47 dB, SSIM: 0.7715


100%|██████████| 430/430 [05:37<00:00,  1.27it/s, D_fake=0.375, D_loss=0.479, D_real=0.468, G_loss=3.1]  
                                                               


[Validation] G_loss: 4.1658, G_fake_loss: 1.9360, L1_loss: 2.2298
[Metrics] MAE: 0.0223, PSNR: 35.01 dB, SSIM: 0.8134


100%|██████████| 430/430 [05:43<00:00,  1.25it/s, D_fake=0.0616, D_loss=0.418, D_real=0.742, G_loss=6.56] 
                                                               


[Validation] G_loss: 5.9422, G_fake_loss: 3.1336, L1_loss: 2.8087
[Metrics] MAE: 0.0281, PSNR: 34.02 dB, SSIM: 0.7216


100%|██████████| 430/430 [05:43<00:00,  1.25it/s, D_fake=0.153, D_loss=0.258, D_real=0.703, G_loss=4.27]    
                                                               


[Validation] G_loss: 5.0900, G_fake_loss: 2.6855, L1_loss: 2.4045
[Metrics] MAE: 0.0240, PSNR: 34.46 dB, SSIM: 0.7956


100%|██████████| 430/430 [05:42<00:00,  1.26it/s, D_fake=0.0449, D_loss=0.0647, D_real=0.932, G_loss=5.75] 
                                                               


[Validation] G_loss: 4.9861, G_fake_loss: 2.4404, L1_loss: 2.5457
[Metrics] MAE: 0.0255, PSNR: 34.12 dB, SSIM: 0.7783


100%|██████████| 430/430 [05:42<00:00,  1.26it/s, D_fake=0.0466, D_loss=0.0309, D_real=0.995, G_loss=5.92]  
                                                               


[Validation] G_loss: 4.7694, G_fake_loss: 2.1044, L1_loss: 2.6650
[Metrics] MAE: 0.0267, PSNR: 33.97 dB, SSIM: 0.7701


100%|██████████| 430/430 [05:43<00:00,  1.25it/s, D_fake=0.019, D_loss=0.0403, D_real=0.952, G_loss=7.12]   
                                                               


[Validation] G_loss: 7.1961, G_fake_loss: 2.9455, L1_loss: 4.2506
[Metrics] MAE: 0.0425, PSNR: 30.65 dB, SSIM: 0.7035


100%|██████████| 430/430 [05:42<00:00,  1.26it/s, D_fake=0.76, D_loss=1.44, D_real=0.103, G_loss=2.62]      
                                                               


[Validation] G_loss: 5.6677, G_fake_loss: 3.5977, L1_loss: 2.0701
[Metrics] MAE: 0.0207, PSNR: 35.38 dB, SSIM: 0.8278


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.0214, D_loss=0.0902, D_real=0.902, G_loss=6.58]  
                                                               


[Validation] G_loss: 6.8609, G_fake_loss: 4.3018, L1_loss: 2.5591
[Metrics] MAE: 0.0256, PSNR: 34.28 dB, SSIM: 0.7928


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.0175, D_loss=1.52, D_real=0.0626, G_loss=6.96]  
                                                               


[Validation] G_loss: 3.1110, G_fake_loss: 0.6179, L1_loss: 2.4931
[Metrics] MAE: 0.0249, PSNR: 34.52 dB, SSIM: 0.7794


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.0269, D_loss=0.0607, D_real=0.909, G_loss=6.12]  
                                                               


[Validation] G_loss: 3.1383, G_fake_loss: 0.8635, L1_loss: 2.2749
[Metrics] MAE: 0.0227, PSNR: 34.81 dB, SSIM: 0.8109


100%|██████████| 430/430 [05:43<00:00,  1.25it/s, D_fake=0.00642, D_loss=0.0183, D_real=0.972, G_loss=8.42] 
                                                               


[Validation] G_loss: 6.1836, G_fake_loss: 3.6296, L1_loss: 2.5540
[Metrics] MAE: 0.0255, PSNR: 34.53 dB, SSIM: 0.7745


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.16, D_loss=0.109, D_real=0.884, G_loss=4.11]     
                                                               


[Validation] G_loss: 6.0770, G_fake_loss: 3.3533, L1_loss: 2.7237
[Metrics] MAE: 0.0272, PSNR: 33.68 dB, SSIM: 0.7916


100%|██████████| 430/430 [05:42<00:00,  1.26it/s, D_fake=0.0244, D_loss=0.303, D_real=0.588, G_loss=6.29] 
                                                               


[Validation] G_loss: 3.5303, G_fake_loss: 1.4567, L1_loss: 2.0736
[Metrics] MAE: 0.0207, PSNR: 35.44 dB, SSIM: 0.8280


100%|██████████| 430/430 [05:42<00:00,  1.26it/s, D_fake=0.477, D_loss=0.395, D_real=0.536, G_loss=2.85]  
                                                               


[Validation] G_loss: 3.0831, G_fake_loss: 0.9782, L1_loss: 2.1049
[Metrics] MAE: 0.0210, PSNR: 35.35 dB, SSIM: 0.8220


100%|██████████| 430/430 [05:43<00:00,  1.25it/s, D_fake=0.0181, D_loss=0.0179, D_real=0.984, G_loss=6.69]  
                                                               


[Validation] G_loss: 4.9039, G_fake_loss: 2.1941, L1_loss: 2.7098
[Metrics] MAE: 0.0271, PSNR: 33.91 dB, SSIM: 0.7950


100%|██████████| 430/430 [05:43<00:00,  1.25it/s, D_fake=0.0601, D_loss=0.695, D_real=0.967, G_loss=5.01]   
                                                               


[Validation] G_loss: 5.1096, G_fake_loss: 2.9890, L1_loss: 2.1206
[Metrics] MAE: 0.0212, PSNR: 35.19 dB, SSIM: 0.8201


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.0963, D_loss=0.0565, D_real=0.972, G_loss=4.86] 
                                                               


[Validation] G_loss: 4.1933, G_fake_loss: 1.9382, L1_loss: 2.2551
[Metrics] MAE: 0.0226, PSNR: 34.76 dB, SSIM: 0.8170


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.0358, D_loss=0.172, D_real=0.922, G_loss=5.65]  
                                                               


[Validation] G_loss: 4.7058, G_fake_loss: 2.2860, L1_loss: 2.4198
[Metrics] MAE: 0.0242, PSNR: 34.59 dB, SSIM: 0.7826


100%|██████████| 430/430 [05:45<00:00,  1.24it/s, D_fake=0.0172, D_loss=0.0134, D_real=0.996, G_loss=6.98]  
                                                               


[Validation] G_loss: 3.8224, G_fake_loss: 1.2860, L1_loss: 2.5363
[Metrics] MAE: 0.0254, PSNR: 34.53 dB, SSIM: 0.7804


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.0036, D_loss=0.00621, D_real=0.992, G_loss=9.67]  
                                                               


[Validation] G_loss: 7.0993, G_fake_loss: 3.4802, L1_loss: 3.6191
[Metrics] MAE: 0.0362, PSNR: 33.38 dB, SSIM: 0.6164


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.00528, D_loss=0.188, D_real=0.966, G_loss=8.62]    
                                                               


[Validation] G_loss: 5.2826, G_fake_loss: 2.1931, L1_loss: 3.0895
[Metrics] MAE: 0.0309, PSNR: 33.41 dB, SSIM: 0.7311


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.00428, D_loss=0.0107, D_real=0.985, G_loss=8.72]  
                                                               


[Validation] G_loss: 4.8473, G_fake_loss: 2.3484, L1_loss: 2.4989
[Metrics] MAE: 0.0250, PSNR: 34.50 dB, SSIM: 0.7915


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.0278, D_loss=0.0448, D_real=0.995, G_loss=6.51]  
                                                               


[Validation] G_loss: 3.9061, G_fake_loss: 1.3362, L1_loss: 2.5699
[Metrics] MAE: 0.0257, PSNR: 34.24 dB, SSIM: 0.7842


100%|██████████| 430/430 [05:42<00:00,  1.26it/s, D_fake=0.00594, D_loss=0.0032, D_real=0.999, G_loss=8.08]  
                                                               


[Validation] G_loss: 3.5704, G_fake_loss: 0.5239, L1_loss: 3.0465
[Metrics] MAE: 0.0305, PSNR: 33.02 dB, SSIM: 0.7359


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.0187, D_loss=0.0649, D_real=0.898, G_loss=7.34]  
                                                               


[Validation] G_loss: 8.8809, G_fake_loss: 6.1008, L1_loss: 2.7801
[Metrics] MAE: 0.0278, PSNR: 33.72 dB, SSIM: 0.7516


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.00537, D_loss=0.00417, D_real=0.998, G_loss=7.87]
                                                               


[Validation] G_loss: 3.2867, G_fake_loss: 1.0636, L1_loss: 2.2231
[Metrics] MAE: 0.0222, PSNR: 35.12 dB, SSIM: 0.8233


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.0303, D_loss=0.0957, D_real=0.868, G_loss=6.09]  
                                                               


[Validation] G_loss: 3.5087, G_fake_loss: 1.3261, L1_loss: 2.1826
[Metrics] MAE: 0.0218, PSNR: 34.95 dB, SSIM: 0.8163


100%|██████████| 430/430 [05:44<00:00,  1.25it/s, D_fake=0.221, D_loss=0.855, D_real=0.223, G_loss=4.1]    
                                                               


[Validation] G_loss: 4.5785, G_fake_loss: 2.4392, L1_loss: 2.1392
[Metrics] MAE: 0.0214, PSNR: 35.20 dB, SSIM: 0.8237
