In [3]:
import torch
import torchvision
from torchvision import transforms

import numpy as np 
import matplotlib.pyplot as plt
from PIL import Image


In [5]:
from torch.utils.data import Dataset
import numpy as np
import random
import cv2
import h5py

class TrainDataset(Dataset):
    """In-memory training split: pairs of RGB images and hyperspectral cubes.

    Scene names are read from ``{data_root}/split_txt/Train_list.txt`` (one
    name per line, without extension). For every scene the cube is loaded from
    ``{data_root}/Train_spectral/<name>.mat`` (HDF5 dataset ``'cube'``) and the
    image from ``{data_root}/Train_RGB/<name>.jpg``. Everything is loaded
    eagerly in ``__init__`` and kept in memory.

    Args:
        data_root: Root folder of the dataset.
        target_size: Size handed to ``cv2.resize`` as ``dsize`` in
            ``__getitem__``; OpenCV interprets it as ``(width, height)``.
        bgr2rgb: If true, convert the image from OpenCV's BGR channel order
            to RGB.

    Raises:
        FileNotFoundError: If the list file is missing or an RGB image cannot
            be read.
    """

    def __init__(self, data_root, target_size=(512, 512), bgr2rgb=True):
        # Remember the requested size; it used to be ignored and hard-coded
        # in __getitem__.
        self.target_size = tuple(target_size)
        self.hypers = []
        self.bgrs = []
        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'
        with open(f'{data_root}/split_txt/Train_list.txt', 'r') as fin:
            # strip() copes with "\r\n" endings and with a last line that has
            # no trailing newline; blank lines are dropped.
            scene_names = [line.strip() for line in fin]
        # Sort by the .mat file name so the sample order matches the original.
        scene_names = sorted((name for name in scene_names if name),
                             key=lambda name: name + '.mat')
        hyper_list = [name + '.mat' for name in scene_names]
        # Derive the image name from the scene name instead of replacing the
        # substring "mat", which corrupts names that contain it.
        bgr_list = [name + '.jpg' for name in scene_names]
        print(f'len(hyper_Train) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr_Train) of ntire2022 dataset:{len(bgr_list)}')
        for i, (hyper_name, bgr_name) in enumerate(zip(hyper_list, bgr_list)):
            # The context manager closes the file; no explicit close() needed.
            with h5py.File(hyper_data_path + hyper_name, 'r') as mat:
                hyper = np.float32(np.array(mat['cube']))
            # Swap the last two axes of the cube.
            # presumably (bands, W, H) -> (bands, H, W) — TODO confirm
            hyper = np.transpose(hyper, [0, 2, 1])
            bgr_path = bgr_data_path + bgr_name
            bgr = cv2.imread(bgr_path)
            if bgr is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError(f'Could not read RGB image: {bgr_path}')
            if bgr2rgb:
                bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            bgr = np.float32(bgr)
            # Per-image min-max normalisation to [0, 1]; a constant image maps
            # to all zeros instead of NaN.
            value_min = bgr.min()
            value_range = bgr.max() - value_min
            if value_range > 0:
                bgr = (bgr - value_min) / value_range
            else:
                bgr = np.zeros_like(bgr)
            bgr = np.transpose(bgr, [2, 0, 1])  # HWC -> CHW
            self.hypers.append(hyper)
            self.bgrs.append(bgr)
            print(f'Ntire2022 scene {i} is loaded.')

    def __getitem__(self, idx):
        """Return ``(rgb, hyper)`` for sample ``idx``, both resized to ``target_size``.

        Both arrays are channel-first and C-contiguous.
        """
        hyper = self.hypers[idx]
        bgr = self.bgrs[idx]
        # cv2.resize expects channel-last data, so transpose around the call.
        hyper = cv2.resize(hyper.transpose(1, 2, 0), self.target_size).transpose(2, 0, 1)
        bgr = cv2.resize(bgr.transpose(1, 2, 0), self.target_size).transpose(2, 0, 1)

        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        """Number of loaded scenes."""
        return len(self.hypers)

class ValidDataset(Dataset):
    """In-memory validation split: pairs of RGB images and hyperspectral cubes.

    Scene names are read from ``{data_root}/split_txt/valid_list.txt`` (one
    name per line, without extension). For every scene the cube is loaded from
    ``{data_root}/Valid_spectral/<name>.mat`` (HDF5 dataset ``'cube'``) and the
    image from ``{data_root}/Valid_RGB/<name>.jpg``. Everything is loaded
    eagerly in ``__init__`` and kept in memory.

    Args:
        data_root: Root folder of the dataset.
        target_size: Size handed to ``cv2.resize`` as ``dsize`` in
            ``__getitem__``; OpenCV interprets it as ``(width, height)``.
        bgr2rgb: If true, convert the image from OpenCV's BGR channel order
            to RGB.

    Raises:
        FileNotFoundError: If the list file is missing or an RGB image cannot
            be read.
    """

    def __init__(self, data_root, target_size=(512, 512), bgr2rgb=True):
        # Remember the requested size; it used to be ignored and hard-coded
        # in __getitem__.
        self.target_size = tuple(target_size)
        self.hypers = []
        self.bgrs = []
        hyper_data_path = f'{data_root}/Valid_spectral/'
        bgr_data_path = f'{data_root}/Valid_RGB/'
        with open(f'{data_root}/split_txt/valid_list.txt', 'r') as fin:
            # strip() copes with "\r\n" endings and with a last line that has
            # no trailing newline; blank lines are dropped.
            scene_names = [line.strip() for line in fin]
        # Sort by the .mat file name so the sample order matches the original.
        scene_names = sorted((name for name in scene_names if name),
                             key=lambda name: name + '.mat')
        hyper_list = [name + '.mat' for name in scene_names]
        # Derive the image name from the scene name instead of replacing the
        # substring "mat", which corrupts names that contain it.
        bgr_list = [name + '.jpg' for name in scene_names]
        print(f'len(hyper_valid) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr_valid) of ntire2022 dataset:{len(bgr_list)}')
        for i, (hyper_name, bgr_name) in enumerate(zip(hyper_list, bgr_list)):
            # The context manager closes the file; no explicit close() needed.
            with h5py.File(hyper_data_path + hyper_name, 'r') as mat:
                hyper = np.float32(np.array(mat['cube']))
            # Swap the last two axes of the cube.
            # presumably (bands, W, H) -> (bands, H, W) — TODO confirm
            hyper = np.transpose(hyper, [0, 2, 1])
            bgr_path = bgr_data_path + bgr_name
            bgr = cv2.imread(bgr_path)
            if bgr is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError(f'Could not read RGB image: {bgr_path}')
            if bgr2rgb:
                bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            bgr = np.float32(bgr)
            # Per-image min-max normalisation to [0, 1]; a constant image maps
            # to all zeros instead of NaN.
            value_min = bgr.min()
            value_range = bgr.max() - value_min
            if value_range > 0:
                bgr = (bgr - value_min) / value_range
            else:
                bgr = np.zeros_like(bgr)
            bgr = np.transpose(bgr, [2, 0, 1])  # HWC -> CHW
            self.hypers.append(hyper)
            self.bgrs.append(bgr)
            print(f'Ntire2022 scene {i} is loaded.')

    def __getitem__(self, idx):
        """Return ``(rgb, hyper)`` for sample ``idx``, both resized to ``target_size``.

        Both arrays are channel-first and C-contiguous.
        """
        hyper = self.hypers[idx]
        bgr = self.bgrs[idx]
        # cv2.resize expects channel-last data, so transpose around the call.
        hyper = cv2.resize(hyper.transpose(1, 2, 0), self.target_size).transpose(2, 0, 1)
        bgr = cv2.resize(bgr.transpose(1, 2, 0), self.target_size).transpose(2, 0, 1)
        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        """Number of loaded scenes."""
        return len(self.hypers)

In [6]:
from torch.utils.data import DataLoader

import os

# Dataset root. Set the NTIRE_DATA_ROOT environment variable to run the
# notebook on another machine; the original author's path is the fallback.
data_root = os.environ.get(
    "NTIRE_DATA_ROOT",
    "C:/Users/user/Desktop/Research/Datasets/Manual_split_NTIRE_2022",
)

# Both datasets load every scene into memory when constructed.
train_dataset = TrainDataset(data_root=data_root, bgr2rgb=True)
valid_dataset = ValidDataset(data_root=data_root, bgr2rgb=True)

# One scene per batch; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=False)

# Smoke test: pull one batch from each loader and report the tensor shapes.
train_images, train_labels = next(iter(train_loader))
valid_images, valid_labels = next(iter(valid_loader))

print("Train Images Shape:", train_images.shape)
print("Train Labels Shape:", train_labels.shape)
print("Valid Images Shape:", valid_images.shape)
print("Valid Labels Shape:", valid_labels.shape)


FileNotFoundError: [Errno 2] No such file or directory: 'C:/Users/user/Desktop/Research/Datasets/Manual_split_NTIRE_2022/split_txt/Train_list.txt'

In [None]:
# Draw one batch from the training loader and show the (input, target) shapes.
x, y = next(iter(train_loader))
tuple(t.shape for t in (x, y))

In [None]:
# Draw one batch from the validation loader and show the (input, target) shapes.
x, y = next(iter(valid_loader))
tuple(t.shape for t in (x, y))

## Visualize the data

In [None]:
import matplotlib.pyplot as plt

# First training pair: x is the RGB input, y the hyperspectral target.
sample = train_dataset[0]
x, y = sample

# One grayscale panel per input channel, in a single row of three.
plt.figure(figsize=(12, 4))
for i, channel_data in enumerate(x):
    plt.subplot(1, 3, i + 1)
    plt.imshow(channel_data, cmap='gray')
    plt.title(f'Input Channel {i + 1}')
    plt.axis('off')

# One grayscale panel per target band on a 6 x 6 grid (room for 36 bands).
plt.figure(figsize=(18, 18))
for i, channel_data in enumerate(y):
    plt.subplot(6, 6, i + 1)
    plt.imshow(channel_data, cmap='gray')
    plt.title(f'Output Channel {i + 1}')
    plt.axis('off')

plt.tight_layout()
plt.show()


## custom SAM LOSS

In [None]:
import torch
import torch.nn as nn

class SAMLoss(nn.Module):
    """Spectral Angle Mapper loss: mean angle (radians) between prediction and target.

    Args:
        per_pixel: If False (default, the original behaviour), a 4-D input is
            flattened to one vector per batch element and a single angle is
            computed per sample. If True, a 4-D input is left unflattened, so
            the angle is computed along dim 1 for every remaining position and
            then averaged.
        eps: Lower bound for the norm product and margin used when clamping
            the cosine. It keeps ``acos`` and its gradient finite for
            identical, opposite or all-zero vectors.
    """

    def __init__(self, per_pixel=False, eps=1e-7):
        super(SAMLoss, self).__init__()
        self.per_pixel = per_pixel
        self.eps = eps

    def forward(self, predicted_hsi, ground_truth_hsi):
        """Return the mean spectral angle as a scalar tensor.

        Raises:
            AssertionError: If the two inputs differ in shape.
        """
        assert predicted_hsi.shape == ground_truth_hsi.shape, "Input tensor shapes must match."

        if predicted_hsi.dim() == 4 and not self.per_pixel:
            # reshape (not view) so non-contiguous inputs are accepted too.
            predicted_hsi = predicted_hsi.reshape(predicted_hsi.size(0), -1)
            ground_truth_hsi = ground_truth_hsi.reshape(ground_truth_hsi.size(0), -1)

        dot_product = torch.sum(predicted_hsi * ground_truth_hsi, dim=1)
        magnitude_pred = torch.norm(predicted_hsi, dim=1)
        magnitude_gt = torch.norm(ground_truth_hsi, dim=1)

        # Guard against 0/0 for all-zero vectors.
        denominator = (magnitude_pred * magnitude_gt).clamp_min(self.eps)
        cosine_theta = dot_product / denominator

        # Rounding can push the cosine slightly outside [-1, 1], where acos is
        # NaN; at exactly +/-1 its gradient is infinite. Clamp just inside.
        cosine_theta = cosine_theta.clamp(-1.0 + self.eps, 1.0 - self.eps)

        spectral_angle = torch.acos(cosine_theta)
        return torch.mean(spectral_angle)
    
class Loss_MRAE(nn.Module):
    """Mean relative absolute error: mean of ``|outputs - label| / label``."""

    def __init__(self):
        super().__init__()

    def forward(self, outputs, label):
        """Return the MRAE over all elements as a scalar tensor.

        The error is divided by ``label`` element-wise, so zeros in ``label``
        produce inf/NaN terms.
        """
        assert outputs.shape == label.shape
        relative_error = (outputs - label).abs() / label
        return relative_error.view(-1).mean()

class Loss_RMSE(nn.Module):
    """Root-mean-square error over every element of the two tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, outputs, label):
        """Return ``sqrt(mean((outputs - label) ** 2))`` as a scalar tensor."""
        assert outputs.shape == label.shape
        squared_error = (outputs - label).pow(2)
        return squared_error.view(-1).mean().sqrt()

class Loss_PSNR(nn.Module):
    """Peak signal-to-noise ratio (dB), averaged over the batch."""

    def __init__(self):
        super(Loss_PSNR, self).__init__()

    def forward(self, im_true, im_fake, data_range=255):
        """Return the mean per-sample PSNR as a scalar tensor.

        Both inputs are clamped to [0, 1] and scaled by ``data_range`` before
        the per-sample MSE is taken over all non-batch elements. Identical
        samples give an infinite PSNR.

        Args:
            im_true: Reference batch; the first dimension is the batch.
            im_fake: Predicted batch with the same shape.
            data_range: Peak value the clamped inputs are scaled to.
        """
        batch = im_true.size(0)
        # Out-of-place ops and reshape replace the original's in-place
        # mul_/resize_; the deprecated MSELoss(reduce=False) is spelled out
        # as a plain per-sample mean of squared differences.
        true_flat = im_true.clamp(0., 1.).mul(data_range).reshape(batch, -1)
        fake_flat = im_fake.clamp(0., 1.).mul(data_range).reshape(batch, -1)
        err = (true_flat - fake_flat).pow(2).mean(dim=1, keepdim=True)
        psnr = 10. * torch.log10((data_range ** 2) / err)
        return torch.mean(psnr)


## combined loss L1loss and SAM

In [None]:
import torch.nn as nn

class CombinedLoss(nn.Module):
    """Weighted sum of an L1 loss and the SAM loss.

    Args:
        alpha: Weight of the L1 term.
        beta: Weight of the SAM term.
    """

    def __init__(self, alpha=0.5, beta=0.5):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.l1_loss = nn.L1Loss()
        self.sam_loss = SAMLoss()

    def forward(self, outputs, labels):
        """Return ``alpha * L1(outputs, labels) + beta * SAM(outputs, labels)``."""
        l1_term = self.alpha * self.l1_loss(outputs, labels)
        sam_term = self.beta * self.sam_loss(outputs, labels)
        return l1_term + sam_term
    
class SRLoss(nn.Module):
    """Weighted sum of the SAM loss and the RMSE loss.

    Args:
        alpha: Weight of the SAM term.
        beta: Weight of the RMSE term.
    """

    def __init__(self, alpha=0.5, beta=0.5):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.sam_loss = SAMLoss()
        self.rmse_loss = Loss_RMSE()

    def forward(self, outputs, labels):
        """Return ``alpha * SAM(outputs, labels) + beta * RMSE(outputs, labels)``."""
        sam_term = self.alpha * self.sam_loss(outputs, labels)
        rmse_term = self.beta * self.rmse_loss(outputs, labels)
        return sam_term + rmse_term



## HyperAttention model with spatial and spectral attention and FFT block

In [None]:
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast  

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SpectralAttention(nn.Module):
    """Scale a feature map by a gate computed from its global average.

    The 1x1 convolution maps ``in_channels`` to a single channel, so the gate
    is one scalar per sample of shape (N, 1, 1, 1), broadcast over all
    channels and positions.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by the sigmoid gate of its pooled features."""
        height, width = x.size(2), x.size(3)
        # Average over the full spatial extent -> (N, C, 1, 1).
        pooled = F.avg_pool2d(x, (height, width))
        gate = self.sigmoid(self.conv1(pooled))
        return x * gate


class SpatialAttention(nn.Module):
    """Batch-normalise the input, then scale it by a per-position sigmoid gate.

    The 1x1 convolution reduces the normalised features to one channel, giving
    a gate of shape (N, 1, H, W) that is broadcast over the channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.batch_norm = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        """Return the normalised input multiplied by its spatial gate.

        The output is built from the batch-normalised tensor, not the raw input.
        """
        normed = self.batch_norm(x)
        gate = self.sigmoid(self.conv1(normed))
        return normed * gate
    
class OutputNormalization(nn.Module):
    """Min-max rescale a tensor to [0, 1] using its global minimum and maximum.

    The extrema are taken over the whole tensor, batch dimension included.
    """

    def forward(self, x):
        """Return ``(x - min) / (max - min + 1e-8)``."""
        lowest = x.min()
        span = x.max() - lowest
        return (x - lowest) / (span + 1e-8)

class UNet(nn.Module):
    """Three-level encoder/decoder CNN with concatenation skip connections.

    Each encoder stage halves the spatial size with a 2x2 max-pool and each
    decoder stage doubles it with a stride-2 transposed convolution, so the
    output has the input's spatial size when height and width are divisible
    by 8 (otherwise the skip concatenations would not line up).

    NOTE(review): many attention modules are constructed here but only
    ``spacial_enc1`` (and the ``SpectralAttention`` inside ``dec2``) are used
    by ``forward``. The unused ones still hold parameters and appear in the
    ``state_dict``.
    """

    def __init__(self, in_channels, out_channels):
        """Build the layers.

        Args:
            in_channels: Number of channels of the input image.
            out_channels: Number of channels of the predicted cube.
        """
        super(UNet, self).__init__()

        # Encoder
        # Stage 1: in_channels -> 64, spatial size halved.
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        # spacial_enc1 is applied to the enc1 skip features in forward();
        # spectral_enc1 is not used by forward().
        self.spacial_enc1 = SpatialAttention(64)
        self.spectral_enc1 = SpectralAttention(64)
        
        # Stage 2: 64 -> 128, spatial size halved again.
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        # Not used by forward().
        self.spacial_enc2 = SpatialAttention(128)
        self.spectral_enc2 = SpectralAttention(128)
        
        # Stage 3: 128 -> 256, spatial size halved again.
        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        # Not used by forward().
        self.spacial_enc3 = SpatialAttention(256)
        self.spectral_enc3 = SpectralAttention(256)

        # Bottleneck
        # 256 -> 512 -> 256 at the lowest resolution (no pooling).
        self.bottleneck = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        
        # Not used by forward().
        self.spacial_bot = SpatialAttention(256)
        self.spectral_bot = SpectralAttention(256)

        # Decoder
        # dec1 takes cat(enc3, bottleneck) = 512 channels and outputs 128
        # channels at twice the resolution.
        self.dec1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)  # Upsampling
        )
        
        # Not used by forward().
        self.spacial_dec1 = SpatialAttention(128)
        self.spectral_dec1 = SpectralAttention(128)
        
        # dec2 takes cat(enc2, dec1) = 256 channels and outputs 64 channels
        # at twice the resolution.
        self.dec2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            SpectralAttention(128),  # Spectral attention mechanism
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)  # Upsampling
        )
        
        # Not used by forward().
        self.spacial_dec2 = SpatialAttention(64)
        self.spectral_dec2 = SpectralAttention(64)
        
        # Rescales the final prediction to [0, 1] in forward().
        self.output_norm = OutputNormalization()
        
        # dec3 takes cat(attended enc1, dec2) = 128 channels and outputs
        # out_channels at the input resolution; the final ReLU keeps the
        # prediction non-negative before normalisation.
        self.dec3 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W) in [0, 1].

        The result is min-max normalised over the whole batch by
        ``output_norm``.
        """
        # Encoder
        enc1 = self.enc1(x)  # (N, 64, H/2, W/2)
        enc2 = self.enc2(enc1)  # (N, 128, H/4, W/4)
        enc3 = self.enc3(enc2)  # (N, 256, H/8, W/8)

        # Bottleneck
        bottleneck = self.bottleneck(enc3)  # (N, 256, H/8, W/8)

        # Decoder
        dec1 = self.dec1(torch.cat([enc3, bottleneck], dim=1))  # (N, 128, H/4, W/4)
        dec2 = self.dec2(torch.cat([enc2, dec1], dim=1))  # (N, 64, H/2, W/2)
        # Only the first-level skip connection is passed through attention.
        enc1 = self.spacial_enc1(enc1)
    
        dec3 = self.dec3(torch.cat([enc1, dec2], dim=1))  # (N, out_channels, H, W)
        dec3 = self.output_norm(dec3)
    
        return dec3
    
    def adjust_channels(self, x, target_channels):
        """Return ``x`` with ``target_channels`` channels.

        If the channel count already matches, ``x`` is returned unchanged.
        Otherwise a new 1x1 convolution is created on every call; it is not
        registered as a submodule, so its random weights are neither trained
        nor saved. ``forward`` does not call this method.
        """
        if x.size(1) == target_channels:
            return x
        else:
            # Move the convolutional layer to the same device as the input tensor
            conv_layer = nn.Conv2d(x.size(1), target_channels, kernel_size=1).to(x.device)
            return conv_layer(x)

# Step size for the Adam optimiser.
learning_rate = 1e-3

# Build the network on the selected device in float32, with SAM as the
# criterion and Adam as the optimiser.
model = UNet(in_channels=3, out_channels=31).to(device).float()
# criterion = nn.L1Loss()  # alternative objective
criterion = SAMLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)



## training

In [None]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Model, criterion, and other components initialization code remains the same

# Per-epoch history used by the live plot and the printed log.
train_losses = []
val_losses = []
verbose_logs = []

# Set hyperparameters
batch_size = 1
learning_rate = 0.001  # Define the learning rate before the optimizer
num_epochs = 2000  # Number of epochs
checkpoint_interval = 200  # Save a checkpoint every this many epochs

# Initialize the model.
# The network class defined in this notebook is `UNet`; the previous name
# `HyperAttention` is not defined anywhere and raised NameError on a fresh
# kernel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=3, out_channels=31).to(device).float()

# Initialize optimizer with beta parameters
beta1 = 0.9
beta2 = 0.999
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, beta2))

# Cosine annealing of the learning rate, stepped once per epoch.
T_max = 60  # The number of epochs for cosine annealing
scheduler = CosineAnnealingLR(optimizer, T_max)
# criterion = SAMLoss()
criterion = SRLoss()

# Best validation loss seen so far and the epoch index where it occurred.
best_val_loss = float('inf')
best_model_epoch = 0

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Puts the model in training mode, takes one optimiser step per batch and
    returns the mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []
    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear stale gradients, then forward, backward and step.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return sum(batch_losses) / len(train_loader)

# Function to validate the model
from piq import ssim, psnr

def validate_model(model, valid_loader, criterion, device):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Returns:
        Tuple ``(avg_loss, avg_ssim, avg_psnr)``: per-batch means of the
        criterion, SSIM and PSNR. SSIM and PSNR are computed on outputs and
        labels clamped to [0, 1].
    """
    model.eval()
    totals = {'loss': 0.0, 'ssim': 0.0, 'psnr': 0.0}

    with torch.no_grad():
        for inputs, labels in tqdm(valid_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            totals['loss'] += criterion(outputs, labels).item()

            # The image-quality metrics expect values inside [0, 1].
            clipped_outputs = outputs.clamp(0, 1)
            clipped_labels = labels.clamp(0, 1)

            totals['ssim'] += ssim(clipped_outputs, clipped_labels, data_range=1.0).item()
            totals['psnr'] += psnr(clipped_outputs, clipped_labels, data_range=1.0).item()

    num_batches = len(valid_loader)
    return (
        totals['loss'] / num_batches,
        totals['ssim'] / num_batches,
        totals['psnr'] / num_batches,
    )

import os

# torch.save does not create missing directories; create them up front so the
# first save cannot fail after a full epoch of training. This also creates
# the parent ../Models folder.
os.makedirs('../Models/checkpoints', exist_ok=True)

# Main training loop with plotting
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    valid_loss, valid_ssim, valid_psnr = validate_model(model, valid_loader, criterion, device)

    train_losses.append(train_loss)
    val_losses.append(valid_loss)

    # Displaying the metrics
    epoch_log = f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid SSIM: {valid_ssim:.4f}, Valid PSNR: {valid_psnr:.4f}'

    # Step the scheduler (once per epoch)
    scheduler.step()

    # Keep the weights with the lowest validation loss seen so far.
    if valid_loss < best_val_loss:
        best_val_loss = valid_loss
        best_model_path = '../Models/best_FocusNet_norm.pth'
        torch.save(model.state_dict(), best_model_path)
        epoch_log += f", Best model saved with validation loss: {best_val_loss:.4f} at {best_model_path}"
        print(f"Best model saved with validation loss: {best_val_loss:.4f}.......")
        best_model_epoch = epoch  # Update the best model epoch

    # Save checkpoint at specific intervals
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint_filename = f'../Models/checkpoints/FocusNet_norm_epoch_{epoch+1}.pth'
        torch.save(model.state_dict(), checkpoint_filename)
        print(f"Checkpoint saved at epoch {epoch+1}")

    # Update verbose logs
    verbose_logs.append(epoch_log)

    # Live plot: redraw the whole loss history each epoch.
    clear_output(wait=True)
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')

    # Mark the best model. Testing best_val_loss instead of
    # `best_model_epoch > 0` also shows the marker when the best epoch is the
    # very first one (index 0).
    if best_val_loss != float('inf'):
        plt.scatter(best_model_epoch, val_losses[best_model_epoch], color='red', label='Best Model')

    plt.title(f'Epoch [{epoch+1}/{num_epochs}]')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
    # Re-print the full log, because clear_output wiped the earlier lines.
    for log in verbose_logs:
        print(log)

# Save the final model
final_model_filename = '../Models/final_FocusNet_norm.pth'
torch.save(model.state_dict(), final_model_filename)
print("Final model saved & training complete!")


In [None]:
# # Save the trained model
# torch.save(model.state_dict(), '../../Models/model.pth')

In [None]:
# NOTE(review): a bare `break` outside a loop is a SyntaxError, so this cell
# always fails. It appears to be a deliberate stop so that "Run All" halts
# before the prediction/evaluation cells below — confirm with the author.
break

## Prediction of cubes and evaluation

## load the trained model

In [None]:
# To evaluate a saved checkpoint instead of the in-memory model, uncomment the
# two lines below.
# NOTE(review): `HyperAttention` is not defined in this notebook (the network
# class is `UNet`), and the checkpoint name differs from the files written by
# the training cell — confirm both before use.
# model = HyperAttention(in_channels=3, out_channels=31).to(device).float()  # Adjust out_channels to match the target label size
# model.load_state_dict(torch.load('../Models/best_HyperAttention_fft.pth', map_location=device))
# Switch to inference mode (fixes BatchNorm statistics).
model.eval() 

## predict on a single file with the model

In [None]:
# Pick the RGB test image to run through the model (absolute local path).
# image_path = "C:/Users/user/Desktop/HSI/dataset/ICASSP2024-SPGC/Hyper-Skin(RGB,VIS)/test/RGB_CIE/p012_neutral_front.jpg"
image_path = 'C:/Users/user/Desktop/Research/NTIRE_2022/NTIRE_2022/Test_RGB/ARAD_1K_0951.jpg'
# Load it with PIL and show it as a visual check.
image = Image.open(image_path)
plt.imshow(image)

In [None]:
# PIL reports the size as (width, height).
image.size

In [None]:
# Resize to the 512 x 512 size used by the datasets and convert to a
# channel-first float tensor in [0, 1].
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),  # Resize the image to match your model's input size
    transforms.ToTensor(),  # Convert the image to a PyTorch tensor
])

# Force three channels so grayscale or RGBA files do not break the model input.
input_data = preprocess(image.convert('RGB'))

# Match the training-time preprocessing: the datasets min-max normalise every
# RGB image to [0, 1], whereas ToTensor only divides by 255.
value_min = input_data.min()
value_range = input_data.max() - value_min
if value_range > 0:
    input_data = (input_data - value_min) / value_range

# Add an extra dimension to match the batch size (1 image in this case)
input_data = input_data.unsqueeze(0)

with torch.no_grad():
    input_data = input_data.to(device)
    output = model(input_data)

print('input data max',input_data.max())
print('input data min',input_data.min())
print(output.shape)
print('output max',output.max())
print('output min',output.min())

In [None]:
# Display the raw prediction tensor (large; notebooks truncate the repr).
output

## save .mat file

In [None]:
import h5py
import matplotlib.pyplot as plt

# 'output' is the model prediction, with or without a leading batch dimension.
# Drop the batch dimension so the tensor is (channels, height, width).
if len(output.shape) > 3:
    output = output.squeeze(0)

num_channels = output.shape[0]
num_rows = 8  # You can adjust this based on your preference
num_cols = (num_channels + num_rows - 1) // num_rows  # ceil(num_channels / num_rows)

plt.figure(figsize=(16, 12))  # Adjust the figure size as needed

for i in range(num_channels):
    plt.subplot(num_rows, num_cols, i + 1)
    channel_data = output[i].cpu().numpy()  # Convert the channel data to a NumPy array
    plt.imshow(channel_data, cmap='gray')  # Display the image in grayscale
    plt.title(f'Channel {i + 1}')
    plt.axis('off')

plt.tight_layout()
plt.show()

# Define the path to the .mat file where you want to save the output as a cube
output_mat_file = "rgb_2.mat"  # Provide the desired file path

# Move the prediction to the CPU and convert it to a NumPy array.
output_np = output.cpu().numpy()

# Save the whole cube. The batch dimension was already removed above, so
# indexing [0] here stored only the first spectral band instead of the cube.
with h5py.File(output_mat_file, 'w') as mat_file:
    mat_file.create_dataset('cube', data=output_np)


## load the output predicted .mat file

In [None]:
import h5py

# Define the path to the .mat file
mat_file_path = "rgb_2.mat"

# Read the "cube" dataset fully into memory and close the file straight away.
# A handle left open keeps the file locked, which can make re-running the
# save cell fail when it tries to reopen the same file for writing.
with h5py.File(mat_file_path, 'r') as mat_contents:
    cube_data = mat_contents["cube"]
    # [()] copies the whole dataset into a NumPy array.
    output = cube_data[()]


## visualize the predictions

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Drop a leading batch dimension if the cube still has one.
if output.ndim > 3:
    output = output[0]

# Lay the channels out on a grid with 8 rows and as many columns as needed.
num_channels = output.shape[0]
num_rows = 8
num_cols = -(-num_channels // num_rows)  # ceiling division

plt.figure(figsize=(16, 12))

for i, channel_data in enumerate(output):
    plt.subplot(num_rows, num_cols, i + 1)
    plt.imshow(channel_data, cmap='gray')
    plt.title(f'Channel {i + 1}')
    plt.axis('off')

plt.tight_layout()
plt.show()


## Bigger visualization

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np

# Define the path to the .mat file
mat_file_path = "rgb_2.mat"

# Read the "cube" dataset fully into memory and close the file straight away
# (an open handle keeps the file locked for later writes).
with h5py.File(mat_file_path, 'r') as mat_contents:
    cube_data = mat_contents["cube"]
    # [()] copies the whole dataset into a NumPy array.
    output = cube_data[()]

# Check if the output has a batch dimension and remove it if it exists
if len(output.shape) > 3:
    output = output[0]

num_channels = output.shape[0]

# One tall figure with a single column: 6 inches of height per channel.
plt.figure(figsize=(6, 6 * num_channels))

for i in range(num_channels):
    plt.subplot(num_channels, 1, i + 1)  # Display images in a single column
    channel_data = output[i]  # Already a NumPy array
#     channel_data = np.rot90(channel_data, k=3)  # Rotate counterclockwise by 90 degrees
    plt.imshow(channel_data, cmap='gray', origin='upper')  # origin='upper' puts row 0 at the top
    plt.title(f'Channel {i + 1}')
    plt.axis('off')

plt.tight_layout()
plt.show()


## predictions on all test images

In [None]:
# import os
# from PIL import Image
# import torch
# import torchvision.transforms as transforms
# import h5py
# import numpy as np

# model = HyperAttention(in_channels=4, out_channels=61).to(device).float()  # Adjust out_channels to match the target label size
# model.load_state_dict(torch.load('../Models/Aug_HyperAttention_fft_470e.pth'))
# model.eval() 

# input_folder = "../Dataset/ICASSP2024-SPGC/Hyper-Skin(MSI,NIR)/test/MSI_CIE" # Replace this with the path the input test folder
# output_folder = "./output/Aug_HyperAttention_fft_470e" # Replace with the path to your output folder

# # Make sure the output folder exists
# os.makedirs(output_folder, exist_ok=True)

# # Define the preprocessing transformation
# preprocess = transforms.Compose([
#     transforms.Resize((1024, 1024)),
#     transforms.ToTensor(),
# ])

# # Loop through all the image files in the input folder
# for filename in os.listdir(input_folder):
#     if filename.endswith(".mat"):  # Assuming you have .mat files for 4-channel input
#         # Load the .mat file
#         mat_file_path = os.path.join(input_folder, filename)

#         # Load the .mat file using h5py
#         mat_file = h5py.File(mat_file_path, 'r')

#         # Access and load the 'cube' dataset
#         input_data = mat_file['cube']  # Use the correct dataset name

#         # Convert the data to a NumPy array
#         input_data = np.array(input_data)

#         # Convert the NumPy array to a PyTorch tensor
#         input_data = torch.from_numpy(input_data)

#         # Add an extra dimension to match the batch size (1 image in this case)
#         input_data = input_data.unsqueeze(0)

#         # Ensure that the input data and model have the same data type
#         input_data = input_data.to(device).float()

#         with torch.no_grad():
#             # Get model predictions
#             output = model(input_data)

#         # Create the output .mat file
#         output_mat_file = os.path.join(output_folder, os.path.splitext(filename)[0] + ".mat")

#         # Convert the output to a NumPy array
#         output_np = output.cpu().numpy()  # Move the output to the CPU and then convert to NumPy

#         # Save the data as a .mat file
#         with h5py.File(output_mat_file, 'w') as mat_file:
#             mat_file.create_dataset('cube', data=output_np[0])

#         print(f"Processed {filename} and saved as {os.path.basename(output_mat_file)}")


## val sam calculation

In [None]:
import torch
import math
from piq import ssim

# Rebind the global criterion to plain SAM for evaluation; this replaces
# whatever loss object the training cell assigned to `criterion`.
criterion = SAMLoss()
# Define the RMSE Loss Function
def rmse_loss(outputs, targets):
    """Return the root-mean-square error over all elements as a scalar tensor."""
    squared_error = (outputs - targets).pow(2)
    return squared_error.mean().sqrt()

# Define the PSNR Loss Function
def psnr_loss(outputs, targets, max_pixel=1.0):
    """Return the PSNR in dB as a Python float.

    ``max_pixel`` is the peak signal value (1.0 for data normalised to
    [0, 1]). Identical inputs give ``float('inf')``.
    """
    mse = ((outputs - targets) ** 2).mean()
    if mse == 0:
        return float('inf')
    return 20 * math.log10(max_pixel) - 10 * math.log10(mse)

# Evaluate the current model on the validation set.
model.eval()

# Running sums of each metric over the validation batches.
total_sam_loss = total_ssim = total_rmse = total_psnr = 0.0

with torch.no_grad():
    for batch_idx, (x_val, y_val) in enumerate(valid_loader):
        x_val = x_val.to(device).float()
        y_val = y_val.to(device).float()

        # Predict, then resample to the target's spatial size.
        outputs_val = model(x_val)
        outputs_val = torch.nn.functional.interpolate(
            outputs_val, size=y_val.shape[2:], mode='bilinear', align_corners=False
        )

        # Min-max normalise prediction and target independently, per batch.
        outputs_val_norm = (outputs_val - outputs_val.min()) / (outputs_val.max() - outputs_val.min() + 1e-8)
        y_val_norm = (y_val - y_val.min()) / (y_val.max() - y_val.min() + 1e-8)

        # Per-batch metrics on the normalised tensors.
        sam_loss = criterion(outputs_val_norm, y_val_norm)
        ssim_val = ssim(outputs_val_norm, y_val_norm, data_range=1.0)
        rmse_val = rmse_loss(outputs_val_norm, y_val_norm)
        psnr_val = psnr_loss(outputs_val_norm, y_val_norm)

        total_sam_loss += sam_loss.item()
        total_ssim += ssim_val.item()
        total_rmse += rmse_val.item()
        total_psnr += psnr_val  # already a Python float

# Average each running sum over the number of batches.
avg_sam_loss, avg_ssim, avg_rmse, avg_psnr = (
    total / len(valid_loader)
    for total in (total_sam_loss, total_ssim, total_rmse, total_psnr)
)

print(f"Average SAM Loss: {avg_sam_loss:.4f}")
print(f"Average SSIM: {avg_ssim:.4f}")
print(f"Average RMSE: {avg_rmse:.4f}")
print(f"Average PSNR: {avg_psnr:.4f}")


In [None]:
# Average SAM loss of the current model over the validation set.
model.eval()
criterion = SAMLoss()

# Define a variable to store the total loss
total_val_loss = 0.0

# Iterate over the validation loader
with torch.no_grad():
    for batch_idx, (x_val, y_val) in enumerate(valid_loader):
        x_val, y_val = x_val.to(device).float(), y_val.to(device).float()

        # Forward pass
        outputs_val = model(x_val)

        # Resample to the target's own spatial size. The previous hard-coded
        # (482, 512) does not match the 512 x 512 samples produced by the
        # dataset, which makes SAMLoss fail its shape assertion.
        outputs_val = torch.nn.functional.interpolate(outputs_val, size=y_val.shape[2:], mode='bilinear', align_corners=False)

        # Calculate the loss
        val_loss = criterion(outputs_val, y_val)

        # Accumulate the total loss
        total_val_loss += val_loss.item()

# Calculate the average validation loss
avg_val_loss = total_val_loss / len(valid_loader)

# Print or use the average validation loss as needed
print(f"Average Validation Loss: {avg_val_loss:.4f}")


## train sam loss

In [None]:
# Average SAM loss of the current model over the training set.
model.eval()
criterion = SAMLoss()
# Define a variable to store the total loss
total_train_loss = 0.0

# Iterate over the training loader
with torch.no_grad():
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        x_train, y_train = x_train.to(device).float(), y_train.to(device).float()

        # Forward pass
        outputs_train = model(x_train)

        # Resample to the target's own spatial size. The previous hard-coded
        # (482, 512) does not match the 512 x 512 samples produced by the
        # dataset, which makes SAMLoss fail its shape assertion.
        outputs_train = torch.nn.functional.interpolate(outputs_train, size=y_train.shape[2:], mode='bilinear', align_corners=False)

        # Calculate the loss
        train_loss = criterion(outputs_train, y_train)

        # Accumulate the total loss
        total_train_loss += train_loss.item()

# Calculate the average training loss
avg_train_loss = total_train_loss / len(train_loader)

# Print or use the average training loss as needed
print(f"Average Training Loss: {avg_train_loss:.4f}")


In [None]:
# Evaluate the current model on the training set.
model.eval()

# Running sums of each metric over the training batches.
total_sam_loss_train = total_ssim_train = total_rmse_train = total_psnr_train = 0.0

with torch.no_grad():
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        x_train = x_train.to(device).float()
        y_train = y_train.to(device).float()

        # Predict, then resample to the target's spatial size.
        outputs_train = model(x_train)
        outputs_train = torch.nn.functional.interpolate(
            outputs_train, size=y_train.shape[2:], mode='bilinear', align_corners=False
        )

        # Min-max normalise prediction and target independently, per batch.
        outputs_train_norm = (outputs_train - outputs_train.min()) / (outputs_train.max() - outputs_train.min() + 1e-8)
        y_train_norm = (y_train - y_train.min()) / (y_train.max() - y_train.min() + 1e-8)

        # Per-batch metrics on the normalised tensors.
        sam_loss_train = criterion(outputs_train_norm, y_train_norm)
        ssim_train = ssim(outputs_train_norm, y_train_norm, data_range=1.0)
        rmse_train = rmse_loss(outputs_train_norm, y_train_norm)
        psnr_train = psnr_loss(outputs_train_norm, y_train_norm)

        total_sam_loss_train += sam_loss_train.item()
        total_ssim_train += ssim_train.item()
        total_rmse_train += rmse_train.item()
        total_psnr_train += psnr_train  # already a Python float

# Average each running sum over the number of batches.
avg_sam_loss_train, avg_ssim_train, avg_rmse_train, avg_psnr_train = (
    total / len(train_loader)
    for total in (total_sam_loss_train, total_ssim_train, total_rmse_train, total_psnr_train)
)

print(f"Average SAM Loss on Training Data: {avg_sam_loss_train:.4f}")
print(f"Average SSIM on Training Data: {avg_ssim_train:.4f}")
print(f"Average RMSE on Training Data: {avg_rmse_train:.4f}")
print(f"Average PSNR on Training Data: {avg_psnr_train:.4f}")
