In [None]:

# Root directory of the input dataset; later cells join split sub-paths onto it.
dataset_path = '/kaggle/input/seismic-dataset/'


In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import random
import torch.nn as nn
import torch.optim as optim
import os
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
import torchvision

In [None]:
# Report whether PyTorch can see a CUDA device. The bare expression on the last
# line is what the notebook displays as this cell's output.
import torch
torch.cuda.is_available()

In [None]:
# `!export ...` runs in a short-lived subshell, so the variable never reaches the
# kernel process that PyTorch lives in. Set it on this process's environment instead.
# NOTE(review): presumably the CUDA allocator reads this setting when GPU memory is
# first allocated, so this cell should run before any tensor/model is moved to the
# GPU — confirm against the PyTorch CUDA memory-management notes.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
class dataset(Dataset):
    """Paired low-/high-resolution HDF5 samples.

    ``data_path`` must contain two sub-directories, ``low`` and ``high``. Sample
    ``i`` is the ``i``-th file of each directory after sorting the listings, so
    both directories are expected to hold the same number of files with names
    that sort in the same order.

    Args:
        data_path: Directory holding the ``low`` and ``high`` folders.
        augment: When True (the default, matching the historical behaviour)
            every fetched pair is randomly flipped/rotated. Pass False to get
            deterministic samples, e.g. for validation.

    Raises:
        ValueError: If the two directories contain a different number of files.
    """

    def __init__(self, data_path, augment=True):
        super().__init__()
        self.path = data_path
        self.augment = augment
        self.high_res_path = os.path.join(self.path, 'high')
        self.low_res_path = os.path.join(self.path, 'low')
        # os.listdir returns entries in arbitrary order; sort both listings so
        # that index i of `low` is paired with index i of `high`.
        self.high_res = sorted(os.listdir(self.high_res_path))
        self.low_res = sorted(os.listdir(self.low_res_path))
        if len(self.high_res) != len(self.low_res):
            raise ValueError(
                f"Mismatched pair count: {len(self.low_res)} low-res files in "
                f"{self.low_res_path} vs {len(self.high_res)} high-res files in "
                f"{self.high_res_path}"
            )
        self.low_res_file_path = [os.path.join(self.low_res_path, x) for x in self.low_res]
        self.high_res_file_path = [os.path.join(self.high_res_path, x) for x in self.high_res]

    def __getitem__(self, index):
        """Return ``(low, high)`` float32 tensors, each with a leading channel axis."""
        high_h5_data = self._read_h5(self.high_res_file_path[index])
        low_h5_data = self._read_h5(self.low_res_file_path[index])

        if self.augment:
            low_h5_data, high_h5_data = self._agumentation(low_h5_data, high_h5_data)

        # Each image is scaled independently to [-1, 1]. The arithmetic also
        # produces a fresh array, so the flipped views are never handed to torch.
        low_h5_data = self._normalize_image(low_h5_data)
        high_h5_data = self._normalize_image(high_h5_data)

        return (torch.tensor(low_h5_data, dtype=torch.float32).unsqueeze(0),
                torch.tensor(high_h5_data, dtype=torch.float32).unsqueeze(0))

    def __len__(self):
        """Number of sample pairs."""
        return len(self.low_res_file_path)

    def _agumentation(self, low_data, high_data):
        """Apply the same random flips / 180-degree rotation to both arrays.

        Each of the three transforms is applied independently with probability
        0.5, and always to both inputs so the pair stays aligned.
        """
        if random.random() > 0.5:
            high_data = np.flip(high_data, axis=0)
            low_data = np.flip(low_data, axis=0)
        if random.random() > 0.5:
            high_data = np.flip(high_data, axis=1)
            low_data = np.flip(low_data, axis=1)
        if random.random() > 0.5:
            high_data = np.rot90(high_data, k=2)
            low_data = np.rot90(low_data, k=2)
        return low_data, high_data

    @staticmethod
    def _read_h5(file_path):
        """Load the full ``'data'`` dataset of an HDF5 file into memory."""
        with h5py.File(file_path, 'r') as f:
            return f['data'][:]

    @staticmethod
    def _normalize_image(image):
        """Min-max scale ``image`` (NumPy array or torch tensor) to [-1, 1].

        A constant image has no range to scale by; it is mapped to all zeros
        (the centre of the target interval) instead of producing NaN from 0/0.
        """
        min_val = image.min()
        max_val = image.max()
        value_range = max_val - min_val
        if value_range == 0:
            return (image - min_val) * 0.0
        return 2 * ((image - min_val) / value_range) - 1
            
        

In [None]:
# Build the SRF_2 train/validation datasets and wrap each in a batch-16 loader.
# Only the training loader shuffles its samples.
# NOTE(review): both splits use the same dataset class, so whatever per-sample
# processing that class applies is applied to validation samples as well.
train_dataset = dataset(os.path.join(dataset_path, 'SRF_2/train'))
val_dataset = dataset(os.path.join(dataset_path, 'SRF_2/val'))

train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False)

In [None]:
# Sanity-check a single training pair: mean, value range, and tensor shapes.
low, high = train_dataset[200]
print(low.mean())
print(low.max(), low.min())
print(high.max(), high.min())
print(low.shape, high.shape)


In [None]:
# First line: number of batches per epoch; second line: number of training samples.
print(len(train_loader), len(train_dataset), sep="\n")

In [None]:
# Draw one batch from the training loader and report the shape of each half.
first_batch = iter(train_loader)
train_features, train_labels = next(first_batch)
print(f"Feature batch shape: {train_features.shape}")
print(f"Labels batch shape: {train_labels.shape}")

In [None]:
!pip install pytorch_msssim

### Evaluation Metrics

In [None]:

from pytorch_msssim import ssim as msssim

# PSNR function (with proper handling)

def norm(image):
    """Min-max scale ``image`` to the range [0, 1].

    The minimum and maximum are taken over the entire input, so when a batch is
    passed in, all of its images share one scaling range.

    Args:
        image: Tensor (or array) exposing ``.min()`` and ``.max()``.

    Returns:
        The rescaled input. A constant input has no range to divide by and is
        returned as all zeros instead of NaN (0/0).
    """
    # Local names chosen so the built-ins `min` and `max` are not shadowed.
    lowest = image.min()
    highest = image.max()
    value_range = highest - lowest
    if value_range == 0:
        return (image - lowest) * 0.0
    return (image - lowest) / value_range
   
def psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) between two images after [0, 1] scaling.

    Both inputs are rescaled with ``norm`` first, so the peak value is 1 and the
    result is ``10 * log10(1 / MSE)``.

    Args:
        img1: Predicted image tensor.
        img2: Reference image tensor of the same shape.

    Returns:
        A 0-dim tensor. When the two rescaled images are identical (MSE == 0)
        the true PSNR is infinite; it is capped at 100 and still returned as a
        tensor, so callers can use ``.item()`` unconditionally.
    """
    img1 = norm(img1)
    img2 = norm(img2)
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        return torch.full_like(mse, 100.0)
    return 10 * torch.log10(1 / mse)

def ssim(img1, img2):
    """Structural similarity between two image batches, as a Python float.

    Both inputs are rescaled to [0, 1] with ``norm`` and then compared with the
    imported ``msssim`` callable (the ``ssim`` function of ``pytorch_msssim``),
    averaged over the batch.
    """
    scaled_first, scaled_second = norm(img1), norm(img2)
    score = msssim(scaled_first, scaled_second, data_range=1.0, size_average=True)
    return score.item()

def frequency_distance(img1, img2):
    """Mean squared difference between the 2-D FFT magnitude spectra of two batches.

    Each input is rescaled to [0, 1] with ``norm``, its dimension 1 is squeezed
    away (inputs are expected as [batch_size, 1, H, W]), and the magnitude of
    its 2-D FFT is taken. The two magnitude spectra are then compared with MSE.
    """
    def magnitude_spectrum(image):
        # |FFT2| of the rescaled image, with the singleton channel axis removed.
        return torch.abs(torch.fft.fft2(norm(image).squeeze(1)))

    spectrum_first = magnitude_spectrum(img1)
    spectrum_second = magnitude_spectrum(img2)
    return F.mse_loss(spectrum_first, spectrum_second)

### Loss Functions

In [None]:
## Wasserstein Loss
def wasserstein_loss(real_preds, fake_preds):
    """Critic loss: mean score on fake samples minus mean score on real samples.

    Minimising this value pushes the critic's scores up on real inputs and down
    on generated ones.
    """
    real_score = real_preds.mean()
    fake_score = fake_preds.mean()
    return fake_score - real_score

## Gradient Penalty
def gradient_penalty(critic, real_samples, fake_samples):
    """Gradient penalty evaluated on random interpolations of real and fake samples.

    For every sample, a point is drawn uniformly on the segment between the real
    and the fake input, the critic's gradient with respect to that point is
    computed, and the squared deviation of its L2 norm from 1 is averaged over
    the batch.

    Args:
        critic: Callable mapping a batch of samples to critic scores.
        real_samples: Batch of real inputs; dimension 0 is the batch axis.
        fake_samples: Batch of generated inputs, same shape as ``real_samples``.

    Returns:
        A 0-dim tensor that stays attached to the autograd graph
        (``create_graph=True``) so it can be back-propagated into the critic.
    """
    batch_size = real_samples.size(0)
    # One mixing weight per sample, broadcast over every remaining axis. Building
    # the shape from the input rank keeps this correct for non-4-D inputs too.
    weight_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alph = torch.rand(weight_shape, device=real_samples.device)
    interpolates = (alph * real_samples + (1 - alph) * fake_samples).requires_grad_(True)
    d_interpolates = critic(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        # Matches the critic output's shape, dtype and device.
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view) so a non-contiguous gradient tensor cannot raise here.
    gradients = gradients.reshape(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

# TV Loss
def total_variation_loss(img):
    """Sum of absolute differences between neighbouring pixels of a 4-D batch.

    Differences are taken along dimension 2 and along dimension 3, and the two
    absolute sums are added. The result is a sum, not a mean, so it grows with
    batch and image size.
    """
    step_dim2 = img[:, :, 1:, :] - img[:, :, :-1, :]
    step_dim3 = img[:, :, :, 1:] - img[:, :, :, :-1]
    return step_dim2.abs().sum() + step_dim3.abs().sum()

# Loss functions
class VGGLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG16 feature maps of two batches.

    The first 31 layers of the pretrained VGG16 ``features`` stack are put in
    eval mode and frozen, then used as a fixed feature extractor.
    NOTE(review): the complete VGG16 model is also kept as ``self.vgg``, although
    only ``feature_extractor`` is used in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.vgg = torchvision.models.vgg16(pretrained=True)
        feature_layers = list(self.vgg.features)[:31]
        self.feature_extractor = nn.Sequential(*feature_layers).eval()
        # Freeze every extractor weight; this loss module is never trained.
        self.feature_extractor.requires_grad_(False)

    def forward(self, pred, target):
        """Return the MSE between the extracted features of ``pred`` and ``target``."""
        pred_features, target_features = (
            self.feature_extractor(batch) for batch in (pred, target)
        )
        return F.mse_loss(pred_features, target_features)

In [None]:
# Directory to save models
output_dir = "/kaggle/working/models"
# makedirs(exist_ok=True) creates any missing parent directories and is a no-op
# when the directory already exists, so re-running this cell is always safe.
os.makedirs(output_dir, exist_ok=True)

### Tensorboard Logging

In [None]:
import os
from torch.utils.tensorboard import SummaryWriter

# Main logging directory, with one sub-directory per split so training and
# validation scalars end up in separate TensorBoard runs.
tensorboard_dir = "/kaggle/working/tensorboard_logs"
train_log_dir = os.path.join(tensorboard_dir, 'train')
val_log_dir = os.path.join(tensorboard_dir, 'validation')

# makedirs(exist_ok=True) creates the parent directory as needed and does not
# fail when the directories are already there (e.g. when the cell is re-run).
os.makedirs(train_log_dir, exist_ok=True)
os.makedirs(val_log_dir, exist_ok=True)

# Create TensorBoard SummaryWriter for training and validation logs
train_writer = SummaryWriter(log_dir=train_log_dir)
val_writer = SummaryWriter(log_dir=val_log_dir)


### Compressing all outputs into a single zip file

In [None]:
import zipfile
def save_and_compress_outputs(dirs_to_save, zip_filename, base_dir="/kaggle/working"):
    """Bundle the contents of several directories into one compressed zip file.

    Each directory is stored under its own top-level folder inside the archive
    (named after the directory), so files with the same relative path in
    different directories no longer collide at the archive root. Directories
    that share a base name would still collide.

    Args:
        dirs_to_save: Iterable of directory paths to archive recursively.
        zip_filename: File name of the archive to create inside ``base_dir``.
        base_dir: Directory the archive is written to.
    """
    zip_path = os.path.join(base_dir, zip_filename)
    zip_abspath = os.path.abspath(zip_path)
    # ZIP_DEFLATED actually compresses; the ZipFile default only stores.
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
        for save_dir in dirs_to_save:
            source_root = os.path.abspath(save_dir)
            # Archive names are relative to the parent directory, which keeps
            # the directory's own name as the leading path component.
            archive_parent = os.path.dirname(source_root)
            for root, _, files in os.walk(source_root):
                for file in files:
                    file_path = os.path.join(root, file)
                    # Never add the archive to itself if it lives in a saved directory.
                    if os.path.abspath(file_path) == zip_abspath:
                        continue
                    zf.write(file_path, arcname=os.path.relpath(file_path, archive_parent))
    print(f"Outputs compressed and saved to {zip_path}")

### Generator

In [None]:
# Define the Residual Block for the Generator
class ResidualBlock(nn.Module):
    """Conv -> BN -> PReLU -> Conv -> BN branch that is added back onto its input.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the channel
    count, so the output has exactly the shape of the input.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_options = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return x + branch

# Define the Generator
class Generator(nn.Module):
    """Single-channel super-resolution generator with one 2x upsampling stage.

    Layout: 9x9 conv + PReLU, twelve 64-channel residual blocks, 3x3 conv + BN
    with a skip connection from before the residual trunk, a 64 -> 256 channel
    conv followed by PixelShuffle(2) (back to 64 channels at twice the spatial
    size), and a final 9x9 conv + tanh.

    The one ``self.prelu`` module is applied at every activation point, so all
    of those activations share the same learnable parameter.
    """

    def __init__(self):
        super().__init__()
        # Shallow feature extraction
        self.conv1 = nn.Conv2d(1, 64, kernel_size=9, stride=1, padding=4)
        self.prelu = nn.PReLU()

        # Residual trunk
        trunk = [ResidualBlock(64) for _ in range(12)]
        self.residual_blocks = nn.Sequential(*trunk)

        # Post-trunk fusion layers
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Sub-pixel upsampling: 256 channels are rearranged into 64 channels at 2x size
        self.upsample1 = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(2)

        # Reconstruction head
        self.conv3 = nn.Conv2d(64, 1, kernel_size=9, stride=1, padding=4)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map a low-resolution batch to a tanh-bounded batch at twice the spatial size."""
        shallow = self.prelu(self.conv1(x))
        deep = self.bn2(self.conv2(self.residual_blocks(shallow)))
        fused = self.prelu(deep + shallow)
        upscaled = self.prelu(self.pixel_shuffle(self.upsample1(fused)))
        return self.tanh(self.conv3(upscaled))

### Discriminator

In [None]:
# Discriminator Model (PatchGAN)
class Discriminator(nn.Module):
    """Convolutional critic that scores a single-channel image on a 2x2 grid.

    Eight Conv(3x3) -> BatchNorm -> LeakyReLU(0.2) stages alternate stride 1 and
    stride 2 while widening from 64 to 512 channels. The resulting feature map
    is adaptively average-pooled to 2x2 and passed through two 1x1 convolutions
    (512 -> 1024 -> 1), so the output shape is [batch_size, 1, 2, 2] regardless
    of the input's spatial size.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride) of every feature stage, in order.
        stage_specs = [
            (1, 64, 1),
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        layers = []
        for in_channels, out_channels, stride in stage_specs:
            layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.model = nn.Sequential(*layers)

        # Scoring head
        self.pool = nn.AdaptiveAvgPool2d(2)
        self.fc1 = nn.Conv2d(512, 1024, kernel_size=1)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
        self.fc2 = nn.Conv2d(1024, 1, kernel_size=1)

    def forward(self, x):
        """Return critic scores of shape [batch_size, 1, 2, 2]."""
        pooled = self.pool(self.model(x))
        hidden = self.leaky_relu(self.fc1(pooled))
        return self.fc2(hidden)


### Checkpoint Loading

In [None]:
# For training
# Resolve the compute device first so modules are placed through one consistent
# mechanism (falls back to the CPU when no GPU is available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mse_loss = nn.MSELoss()
vgg_loss = VGGLoss().to(device)

# Checkpoint to resume from, if the file exists.
checkpoint_path = '/kaggle/input/epoch120/pytorch/default/1/checkpoint_epoch_120.pth'

### Initialization

In [None]:
# Training hyperparameters
epochs = 200
alpha = 2e-2       # weight of the adversarial term in the generator loss
beta = 6e-2        # weight of the VGG feature term in the generator loss
lambda_gp = 12     # gradient penalty coefficient in the critic loss
critic_steps = 1   # critic updates per generator update
lamda = 2e-7       # weight of the total-variation term in the generator loss


# Initialize models on the same device the checkpoint is mapped to below.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Generator optimizer
optimizer_G = torch.optim.Adam(generator.parameters(),
                 lr=3e-5,              # Learning rate
                 betas=(0.9, 0.999),   # Momentum parameters
                 weight_decay=1e-5)    # L2 regularization

# Discriminator optimizer
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                 lr=1e-4,              # Learning rate
                 betas=(0.9, 0.999),   # Momentum parameters
                 weight_decay=1e-5)    # L2 regularization


# Resume models and optimizer state when a checkpoint is available.
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

    # The stored value is used directly as the first epoch index of the resumed run.
    start_epoch = checkpoint['epoch']
    print(f"Resuming training from epoch {start_epoch}")

else:
    print("No checkpoint found. Training from scratch.")
    start_epoch = 0  # Start from the beginning

### Training

In [None]:
# Training loop
# Each epoch: alternate critic and generator updates over the training set,
# evaluate on the validation set, and write a checkpoint every 10 epochs.
steps_per_epoch = len(train_loader)  # used to build a monotonically increasing TensorBoard step

for epoch in range(start_epoch, epochs):
    generator.train()
    discriminator.train()

    for i, (low_res, high_res) in enumerate(train_loader):
        low_res, high_res = low_res.to(device), high_res.to(device)
        global_step = epoch * steps_per_epoch + i

        # Train Discriminator (critic)
        for _ in range(critic_steps):
            optimizer_D.zero_grad()
            # The critic update must not back-propagate into the generator, so the
            # fake batch is produced without building an autograd graph at all.
            with torch.no_grad():
                fake_high_img = generator(low_res)
            real_pred = discriminator(high_res)
            fake_pred = discriminator(fake_high_img)

            gp = gradient_penalty(discriminator, high_res, fake_high_img)
            d_loss = wasserstein_loss(real_pred, fake_pred) + lambda_gp * gp
            d_loss.backward()
            optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        gen_high_res = generator(low_res)
        # The VGG extractor expects 3 channels: replicate the single channel.
        gen_high_res_3channel = gen_high_res.repeat(1, 3, 1, 1)
        high_res_3channel = high_res.repeat(1, 3, 1, 1)

        g_loss_vgg = vgg_loss(gen_high_res_3channel, high_res_3channel)  # Perceptual loss

        g_loss_mse = mse_loss(gen_high_res, high_res)
        g_loss_TV = total_variation_loss(gen_high_res)
        gen_pred = discriminator(gen_high_res)
        g_loss_adv = -torch.mean(gen_pred)  # Adversarial loss
        g_loss = g_loss_mse + beta * g_loss_vgg + alpha * g_loss_adv + lamda * g_loss_TV

        g_loss.backward()
        optimizer_G.step()

        # Compute metrics on the batch that was just generated. Reusing it avoids a
        # second train-mode forward pass (extra compute and a second BatchNorm
        # running-statistics update per batch).
        with torch.no_grad():
            pred = gen_high_res.detach()
            # float() accepts both a 0-dim tensor and a plain Python number.
            psnr_value = float(psnr(pred, high_res))
            ssim_value = ssim(pred, high_res)

        # Log losses and metrics to TensorBoard
        train_writer.add_scalar("PSNR/train", psnr_value, global_step)
        train_writer.add_scalar("SSIM/train", ssim_value, global_step)
        train_writer.add_scalar('Loss/Discriminator', d_loss.item(), global_step)
        train_writer.add_scalar('Loss/Generator', g_loss.item(), global_step)
        train_writer.add_scalar('Loss/MSE', g_loss_mse.item(), global_step)
        train_writer.add_scalar('Loss/VGG', g_loss_vgg.item(), global_step)
        train_writer.add_scalar('Loss/Adversarial', g_loss_adv.item(), global_step)
        train_writer.add_scalar('Loss/TV', g_loss_TV.item(), global_step)

        if i % 10 == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch {i}/{len(train_loader)} "
                  f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}, "
                  f"MSE: {g_loss_mse.item():.4f}, VGG: {g_loss_vgg.item():.4f}, Adv: {g_loss_adv.item():.4f},PSNR:{psnr_value:.4f},SSIM:{ssim_value:.4f}")

    # Validation phase
    generator.eval()
    discriminator.eval()
    val_mse_loss, val_psnr, val_fft, val_ssim = 0.0, 0.0, 0.0, 0.0
    val_d_loss_total, val_g_loss_total = 0.0, 0.0
    val_vgg_total, val_adv_total = 0.0, 0.0
    with torch.no_grad():
        for low_res, high_res in val_loader:
            low_res, high_res = low_res.to(device), high_res.to(device)

            gen_high_res = generator(low_res)
            real_pred = discriminator(high_res)
            fake_pred = discriminator(gen_high_res)
            val_d_loss = wasserstein_loss(real_pred, fake_pred)

            gen_high_res_3channel = gen_high_res.repeat(1, 3, 1, 1)
            high_res_3channel = high_res.repeat(1, 3, 1, 1)

            val_g_loss_vgg = vgg_loss(gen_high_res_3channel, high_res_3channel)  # Perceptual loss

            val_g_loss_mse = mse_loss(gen_high_res, high_res)
            val_g_loss_TV = total_variation_loss(gen_high_res)
            # In eval mode the critic is deterministic, so the score already
            # computed for the fake batch is reused for the adversarial term.
            val_g_loss_adv = -torch.mean(fake_pred)
            val_g_loss = val_g_loss_mse + beta * val_g_loss_vgg + alpha * val_g_loss_adv + lamda * val_g_loss_TV

            val_d_loss_total += val_d_loss.item()
            val_g_loss_total += val_g_loss.item()
            val_vgg_total += val_g_loss_vgg.item()
            val_adv_total += val_g_loss_adv.item()

            val_mse_loss += val_g_loss_mse.item()
            val_psnr += float(psnr(gen_high_res, high_res))
            val_ssim += ssim(gen_high_res, high_res)
            val_fft += frequency_distance(gen_high_res, high_res).item()

    # Average over validation batches (guarding against an empty loader).
    num_val_batches = max(len(val_loader), 1)
    val_mse_loss /= num_val_batches
    val_psnr /= num_val_batches
    val_ssim /= num_val_batches
    val_fft /= num_val_batches

    # Log validation losses once per epoch as batch means. Logging every batch with
    # the last training step index would put all of them on the same x-position.
    val_writer.add_scalar('VAL/Discriminator', val_d_loss_total / num_val_batches, epoch)
    val_writer.add_scalar('VAL/Generator', val_g_loss_total / num_val_batches, epoch)
    val_writer.add_scalar('VAL/VGG', val_vgg_total / num_val_batches, epoch)
    val_writer.add_scalar('VAL/Adversarial', val_adv_total / num_val_batches, epoch)

    # Log validation metrics to TensorBoard
    val_writer.add_scalar('Validation/MSE', val_mse_loss, epoch)
    val_writer.add_scalar('Validation/PSNR', val_psnr, epoch)
    val_writer.add_scalar('Validation/ssim', val_ssim, epoch)
    val_writer.add_scalar('Validation/fft', val_fft, epoch)
    print(f"Validation - Epoch [{epoch}/{epochs}]: MSE: {val_mse_loss:.4f}, PSNR: {val_psnr:.4f}, SSIM: {val_ssim:.4f},FFT :{val_fft:.4f}")

    # Save the model every 10 epochs. 'epoch' stores the index of the next epoch
    # to run, which is what a resumed run starts from.
    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch + 1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"))
        print(f"Models saved at epoch {epoch+1}")

# Flush pending TensorBoard events to disk so the archive contains complete logs.
train_writer.flush()
val_writer.flush()
save_and_compress_outputs([output_dir, tensorboard_dir], "trained_models and logs.zip")
        