In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere under the Kaggle input mount.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import scipy.io
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim

In [None]:
class CustomMatDataset(Dataset):
    """Dataset pairing channel ``.mat`` files with target ``.mat`` files by index.

    Sample ``idx`` is built from the ``idx``-th file of ``ch_dir`` and the
    ``idx``-th file of ``f_dir``. Both file lists are sorted by name so the
    pairing is deterministic (``os.listdir`` returns entries in arbitrary,
    filesystem-dependent order). The pairing is therefore only correct when the
    two directories use names that sort into the same order.

    Args:
        ch_dir: directory of input ``.mat`` files, each containing variable ``'H'``.
        f_dir: directory of target ``.mat`` files, each containing ``'Fopt_final'``.
        ch_scale: multiplier applied to the channel magnitudes (default ``1e6``).

    Raises:
        ValueError: if the two directories hold a different number of ``.mat``
            files, because index-based pairing would then be wrong.
    """

    def __init__(self, ch_dir, f_dir, ch_scale=1e6):
        self.ch_files = self._list_mat_files(ch_dir)
        self.f_files = self._list_mat_files(f_dir)
        if len(self.ch_files) != len(self.f_files):
            raise ValueError(
                f'{ch_dir!r} has {len(self.ch_files)} .mat files but '
                f'{f_dir!r} has {len(self.f_files)}; cannot pair them by index'
            )
        self.ch_scale = ch_scale

    @staticmethod
    def _list_mat_files(directory):
        """Return the sorted full paths of the ``.mat`` files directly inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith('.mat')
        )

    def __len__(self):
        return len(self.ch_files)

    def __getitem__(self, idx):
        """Return ``(channel, target)`` float32 tensors for sample ``idx``.

        Complex values are reduced to magnitudes with ``np.abs``. Both arrays have
        their last axis moved to the front, ``(a, b, c) -> (c, a, b)``. Per the
        original notes this gives roughly [5, 36, 144] for the channel and
        [5, 144, 2] for the target (TODO confirm against the data).
        """
        ch_data = np.abs(scipy.io.loadmat(self.ch_files[idx])['H'])
        f_data = np.abs(scipy.io.loadmat(self.f_files[idx])['Fopt_final'])
        ch_data = np.transpose(ch_data, (2, 0, 1)) * self.ch_scale
        f_data = np.transpose(f_data, (2, 0, 1))
        return torch.tensor(ch_data, dtype=torch.float32), torch.tensor(f_data, dtype=torch.float32)


In [None]:
# Define directories (Kaggle's read-only input mount)
ch_dir = '/kaggle/input/dataset-hfw/Combined Dataset/Input'
f_dir = '/kaggle/input/dataset-hfw/Combined Dataset/Output1'

SPLIT_SEED = 42  # fixes the train/test split so results are reproducible across runs
BATCH_SIZE = 4

# Create dataset
dataset = CustomMatDataset(ch_dir, f_dir)
if len(dataset) == 0:
    raise ValueError(f'No .mat files found in {ch_dir!r}')

# Split dataset into train (80%) and test (the rest). Without an explicit
# generator, random_split draws from the global RNG and the split changes every run.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
class ResidualBlock(nn.Module):
    """Conv-BN-LeakyReLU-Conv-BN branch added back onto its input.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so the
    branch output has the input's shape and can be summed with it before the
    final LeakyReLU.
    """

    def __init__(self, in_channels):
        super().__init__()
        width = in_channels
        self.conv1 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(width)
        # One parameter-free LeakyReLU module serves both activation sites.
        self.relu = nn.LeakyReLU()
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(width)

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return self.relu(branch + x)
class UNet(nn.Module):
    """Encoder-only CNN that maps an image-like tensor to a fixed output shape.

    Despite the name, there is no decoder / expanding path. The network runs four
    stages of ``double_conv`` + ``ResidualBlock`` (16, 32, 64, 128 channels),
    separated by 2x2 max-pooling. It then flattens and applies one linear layer,
    and reshapes the result to ``(batch, *out_shape)``.

    Args:
        in_channels: number of input channels (default 5).
        input_hw: ``(height, width)`` of the input, used to size the linear
            layer. The default ``(36, 144)`` gives the original 9216 features.
        out_shape: per-sample output shape (default ``(5, 144, 2)``, i.e. 1440 values).

    Raises:
        ValueError: if ``input_hw`` is too small to survive three 2x2 poolings.
    """

    def __init__(self, in_channels=5, input_hw=(36, 144), out_shape=(5, 144, 2)):
        super(UNet, self).__init__()
        self.out_shape = tuple(out_shape)

        # Define the encoder (contracting path)
        self.enc_conv0 = self.double_conv(in_channels, 16)
        self.res0 = ResidualBlock(16)
        self.enc_conv1 = self.double_conv(16, 32)
        self.res1 = ResidualBlock(32)
        self.enc_conv2 = self.double_conv(32, 64)
        self.res2 = ResidualBlock(64)
        self.enc_conv3 = self.double_conv(64, 128)
        self.res3 = ResidualBlock(128)
        self.pool = nn.MaxPool2d(2)

        # Three MaxPool2d(2) steps floor-halve H and W; the convs keep spatial size
        # (padding=1), so the flattened feature count is 128 * H/8 * W/8 (floored).
        height, width = input_hw
        for _ in range(3):
            height, width = height // 2, width // 2
        if height < 1 or width < 1:
            raise ValueError(f'input_hw={tuple(input_hw)} is too small for three 2x2 poolings')
        self.lin = nn.Linear(128 * height * width, int(np.prod(self.out_shape)))

    def double_conv(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU()
        )

    def forward(self, x):
        """Encode ``x`` and return a tensor of shape ``(batch, *out_shape)``."""
        enc0 = self.res0(self.enc_conv0(x))
        enc1 = self.res1(self.enc_conv1(self.pool(enc0)))
        enc2 = self.res2(self.enc_conv2(self.pool(enc1)))
        enc3 = self.res3(self.enc_conv3(self.pool(enc2)))
        flat = torch.flatten(enc3, 1)
        return self.lin(flat).view(x.size(0), *self.out_shape)


In [None]:
# torch.cuda.is_available() alone is just a bool; Tensor.to()/Module.to() need a
# real device, so build one explicitly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): 1e-8 is an unusually small Adam learning rate and is likely too
# small for meaningful progress in 20 epochs. Kept unchanged; tune deliberately.
LEARNING_RATE = 1e-8

# Initialize model, loss function, and optimizer.
# The model is built on the CPU; call model.to(device) (and move each batch
# alongside it) to train on a GPU.
model = UNet()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [None]:
# Smoke-test the forward pass with one random sample on the model's own device.
# eval() stops this dummy batch from updating BatchNorm running statistics;
# no_grad() skips building an autograd graph. Train mode is restored afterwards.
smoke_device = next(model.parameters()).device
model.eval()
with torch.no_grad():
    smoke_out_shape = model(torch.rand((1, 5, 36, 144), device=smoke_device)).shape
model.train()
assert tuple(smoke_out_shape) == (1, 5, 144, 2), smoke_out_shape
smoke_out_shape

In [None]:
from tqdm import tqdm

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check of the data before training: show one training sample's
# input next to its target (no prediction exists yet at this point).
sample_inputs, sample_targets = next(iter(train_loader))

# First sample in the batch, first channel of each tensor.
input_channel = sample_inputs[0][0]
# Target channel has 288 values (144x2 per the dataset notes); fold into a 12x24 grid.
target_channel = sample_targets[0][0].reshape((12, 24))

fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs[0].imshow(input_channel.numpy(), aspect='auto', origin='lower', cmap='turbo')
axs[0].set_title(f'Input (channel 0, shape {tuple(input_channel.shape)})')

axs[1].imshow(target_channel.numpy(), aspect='auto', origin='lower', cmap='turbo')
axs[1].set_title('Target (channel 0, reshaped to 12x24)')

plt.tight_layout()
plt.show()

In [None]:
# Training loop: optimise on the training split, then report the test loss after
# every epoch, plotting one target/prediction pair every PLOT_EVERY epochs.
num_epochs = 20
INPUT_SCALE = 10  # inputs are multiplied by this before the forward pass
LOSS_SCALE = 10   # outputs and targets are multiplied by this inside the MSE
PLOT_EVERY = 5    # plot on epoch indices 0, 5, 10, ...


def _scaled_loss(outputs, targets):
    """MSE after scaling both tensors by LOSS_SCALE.

    Used for BOTH train and test, so the two printed losses are on the same
    scale. Previously the test loss was unscaled, i.e. 100x smaller.
    """
    return criterion(outputs * LOSS_SCALE, targets * LOSS_SCALE)


def _train_one_epoch(run_device):
    """Run one optimisation pass over train_loader; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.to(run_device), targets.to(run_device)
        optimizer.zero_grad()
        outputs = model(inputs * INPUT_SCALE)
        loss = _scaled_loss(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / max(len(train_loader), 1)


def _evaluate(run_device):
    """Return (mean test loss, last targets batch, last outputs batch).

    The batches are None when test_loader is empty.
    """
    model.eval()
    test_loss = 0.0
    targets = outputs = None
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs, targets = inputs.to(run_device), targets.to(run_device)
            outputs = model(inputs * INPUT_SCALE)
            test_loss += _scaled_loss(outputs, targets).item()
    return test_loss / max(len(test_loader), 1), targets, outputs


def _plot_target_vs_prediction(targets, outputs):
    """Show channel 0 of the first sample's target and prediction as 12x24 images."""
    target_img = targets[0][0].reshape((12, 24)).cpu().numpy()
    pred_img = outputs[0][0].reshape((12, 24)).cpu().numpy()

    fig, axs = plt.subplots(1, 2, figsize=(6, 4))
    axs[0].imshow(target_img, aspect='auto', origin='lower', cmap='turbo')
    axs[0].set_title('Target')
    axs[1].imshow(pred_img, aspect='auto', origin='lower', cmap='turbo')
    axs[1].set_title('Predicted')
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across epochs


# Batches follow the model wherever it lives (CPU unless it was moved).
run_device = next(model.parameters()).device
for epoch in range(num_epochs):
    train_loss = _train_one_epoch(run_device)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss}')

    # Evaluate the model
    test_loss, last_targets, last_outputs = _evaluate(run_device)
    if epoch % PLOT_EVERY == 0 and last_outputs is not None:
        _plot_target_vs_prediction(last_targets, last_outputs)

    print(f'Test Loss: {test_loss}')