In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import math
import time 

from einops import rearrange
import torch.optim as optim
from operator import truediv
# import torchvision
# from torchvision import datasets, transforms
from scipy import io
import torch.utils.data
import scipy.io as sio
# import mat73
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
import copy
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score
from sklearn.model_selection import train_test_split
import record


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
datasetname = 'UHouston'


cuda:0


In [3]:
hsi_2013_data=sio.loadmat('./'+str(datasetname)+'/2013_IEEE_GRSS_DF_Contest_CASI.mat')['HSI_data']
print('hsi_2013_data shape:', hsi_2013_data.shape)

# Loader Lidar  data
# import mat73
lidar_2013_data = sio.loadmat('./'+str(datasetname)+'/2013_IEEE_GRSS_DF_Contest_LiDAR.mat')['LiDAR_data']

print('Lidar_2013_data shape:', lidar_2013_data.shape)

#Load ground truth labels
gt_2013_data=sio.loadmat('./'+str(datasetname)+'/GRSS2013.mat')['name']
print('gt_2013_data.shape:', gt_2013_data.shape)

hsi_2013_data shape: (349, 1905, 144)
Lidar_2013_data shape: (349, 1905, 1)
gt_2013_data.shape: (349, 1905)


In [4]:
class_info = [(1, "Healthy grass", 'training_sample', 198, 'test_sample', 1053,  'total', 1251),
    (2, "Stressed grass",'training_sample', 190, 'test_sample', 1064,  'total', 1254),
    (3, "Synthetic grass", 'training_sample', 192, 'test_sample', 505,  'total', 697),
    (4, "Trees", 'training_sample', 188, 'test_sample', 1058,  'total', 1244),
    (5, "Soil",'training_sample', 186, 'test_sample', 1056,  'total', 1242),
    (6, "Water", 'training_sample', 182, 'test_sample', 141,  'total', 325),
    (7, "Residential", 'training_sample', 196, 'test_sample', 1072,  'total', 1268),
    (8, "Commercial", 'training_sample', 191, 'test_sample', 1053,  'total', 1244),
    (9, "Road", 'training_sample', 193, 'test_sample', 1059,  'total', 1252),
    (10, "Highway", 'training_sample', 191, 'test_sample', 1036,  'total', 1227),
    (11, "Railway", 'training_sample', 181, 'test_sample', 1054,  'total', 1235),
    (12, "Parking lot 1", 'training_sample', 192, 'test_sample', 1041,  'total', 1233),
    (13, "Parking lot 2", 'training_sample', 184, 'test_sample',285,  'total', 469),
    (14, "Tennis court",'training_sample', 181, 'test_sample', 247,  'total', 428),
    (15, "Running track", 'training_sample', 187, 'test_sample', 473,  'total', 660)]

# Create a dictionary to store class number, class name, and class samples
class_dict = {class_number: {"class_name": class_name,
                             'training_sample': training_sample,
                             'test_sample': test_sample,
                             "total_samples": total}
              for class_number, class_name, _, training_sample, _, test_sample, _, total in class_info}

print(class_dict)

{1: {'class_name': 'Healthy grass', 'training_sample': 198, 'test_sample': 1053, 'total_samples': 1251}, 2: {'class_name': 'Stressed grass', 'training_sample': 190, 'test_sample': 1064, 'total_samples': 1254}, 3: {'class_name': 'Synthetic grass', 'training_sample': 192, 'test_sample': 505, 'total_samples': 697}, 4: {'class_name': 'Trees', 'training_sample': 188, 'test_sample': 1058, 'total_samples': 1244}, 5: {'class_name': 'Soil', 'training_sample': 186, 'test_sample': 1056, 'total_samples': 1242}, 6: {'class_name': 'Water', 'training_sample': 182, 'test_sample': 141, 'total_samples': 325}, 7: {'class_name': 'Residential', 'training_sample': 196, 'test_sample': 1072, 'total_samples': 1268}, 8: {'class_name': 'Commercial', 'training_sample': 191, 'test_sample': 1053, 'total_samples': 1244}, 9: {'class_name': 'Road', 'training_sample': 193, 'test_sample': 1059, 'total_samples': 1252}, 10: {'class_name': 'Highway', 'training_sample': 191, 'test_sample': 1036, 'total_samples': 1227}, 11: 

In [5]:
# Define patch size and stride
patch_size = 9
stride = 1

# Create an empty list to store patches and labels
hsi_samples = []
lidar_samples = []
labels = []

# Initialize a dictionary to store class count
class_count = {i: 0 for i in class_dict.keys()}

# Function to check if all classes have the required number of samples
def all_classes_completed(class_count, class_dict):
    return all(class_count[class_num] == class_dict[class_num]["total_samples"] for class_num in class_dict.keys())

while not all_classes_completed(class_count, class_dict):
    # Loop through the ground truth data
    for label in class_dict.keys():
        # Get the coordinates of the ground truth pixels
        #coords = np.argwhere((gt_2013_data == label) & (mask > 0))
        coords = np.argwhere(gt_2013_data == label)

        # Shuffle the coordinates to randomize the patch extraction
        np.random.shuffle(coords)

        for coord in coords:
            i, j = coord
            # Calculate the patch indices
            i_start, i_end = i - patch_size // 2, i + patch_size // 2 + 1
            j_start, j_end = j - patch_size // 2, j + patch_size // 2 + 1

            # Check if the indices are within the bounds of the HSI data
            if i_start >= 0 and i_end <= hsi_2013_data.shape[0] and j_start >= 0 and j_end <= hsi_2013_data.shape[1]:
                # Extract the patch
                hsi_patch = hsi_2013_data[i_start:i_end, j_start:j_end, :]

                # Extract the LiDAR patch
                lidar_patch = lidar_2013_data[i_start:i_end, j_start:j_end, :]

                # If the class count is less than the required samples
                if class_count[label] < class_dict[label]["total_samples"]:
                    # Append the patch and its label to the list
                    hsi_samples.append(hsi_patch)
                    lidar_samples.append(lidar_patch)
                    labels.append(label)
                    class_count[label] += 1

                    # If all classes have the required number of samples, exit the loop
                    if all_classes_completed(class_count, class_dict):
                        break

# Convert the list of patches and labels into arrays
hsi_samples = np.array(hsi_samples)
lidar_samples = np.array(lidar_samples)
labels = np.array(labels)
print('hsi_samples shape:', hsi_samples.shape)
print('lidar_samples shape:', lidar_samples.shape)
print('labels shape:', labels.shape)

hsi_samples shape: (15029, 9, 9, 144)
lidar_samples shape: (15029, 9, 9, 1)
labels shape: (15029,)


In [6]:
#Avoid overlap of train and test
# Extracting training samples
hsi_training_samples, lidar_training_samples, training_labels = [], [], []
used_indices = []  # To keep track of indices already taken for training samples

for label, class_data in class_dict.items():
    # Get indices of the current class
    class_indices = np.where(labels == label)[0]

    # Randomly shuffle the indices
    np.random.shuffle(class_indices)

    # Take the required number of training samples
    train_indices = class_indices[:class_data["training_sample"]]
    used_indices.extend(train_indices)  # Add these to the used_indices list

    # Append training samples
    hsi_training_samples.extend(hsi_samples[train_indices])
    lidar_training_samples.extend(lidar_samples[train_indices])
    training_labels.extend(labels[train_indices])

# Extracting test samples
hsi_test_samples, lidar_test_samples, test_labels = [], [], []

for label, class_data in class_dict.items():
    class_indices = np.where(labels == label)[0]

    # Exclude indices which were used for training
    test_indices = np.setdiff1d(class_indices, used_indices)

    # Append test samples
    hsi_test_samples.extend(hsi_samples[test_indices])
    lidar_test_samples.extend(lidar_samples[test_indices])
    test_labels.extend(labels[test_indices])

# Convert lists back to numpy arrays
hsi_training_samples = np.array(hsi_training_samples)
lidar_training_samples = np.array(lidar_training_samples)
training_labels = np.array(training_labels)

hsi_test_samples = np.array(hsi_test_samples)
lidar_test_samples = np.array(lidar_test_samples)
test_labels = np.array(test_labels)

# Print shapes to verify
print('hsi_training_samples shape:', hsi_training_samples.shape)
print('lidar_training_samples shape:', lidar_training_samples.shape)
print('training_labels shape:', training_labels.shape)

print('hsi_test_samples shape:', hsi_test_samples.shape)
print('lidar_test_samples shape:', lidar_test_samples.shape)
print('test_labels shape:', test_labels.shape)

hsi_training_samples shape: (2832, 9, 9, 144)
lidar_training_samples shape: (2832, 9, 9, 1)
training_labels shape: (2832,)
hsi_test_samples shape: (12197, 9, 9, 144)
lidar_test_samples shape: (12197, 9, 9, 1)
test_labels shape: (12197,)


In [7]:
hsi_train=hsi_training_samples
lidar_train=lidar_training_samples
y_train=training_labels-1
print('hsi_train_samples shape:', hsi_train.shape)
print('lidar_train_samples shape:', lidar_train.shape)
print('train_labels shape:', y_train.shape)
hsi_test=hsi_test_samples
lidar_test=lidar_test_samples
y_test=test_labels-1
print('hsi_test_samples shape:', hsi_test.shape)
print('lidar_test_samples shape:', lidar_test.shape)
print('y_test shape:', y_test.shape)

hsi_train_samples shape: (2832, 9, 9, 144)
lidar_train_samples shape: (2832, 9, 9, 1)
train_labels shape: (2832,)
hsi_test_samples shape: (12197, 9, 9, 144)
lidar_test_samples shape: (12197, 9, 9, 1)
y_test shape: (12197,)


In [8]:
hsi_train = hsi_train.reshape(hsi_train.shape[0], patch_size*patch_size, hsi_train.shape[-1], 1)
lidar_train = lidar_train.reshape(lidar_train.shape[0], patch_size*patch_size, lidar_train.shape[-1], 1)

hsi_test = hsi_test.reshape(hsi_test.shape[0], patch_size*patch_size, hsi_test.shape[-1], 1)
lidar_test = lidar_test.reshape(lidar_test.shape[0], patch_size*patch_size, lidar_test.shape[-1], 1)

print(hsi_train.shape)
print(lidar_train.shape)

print(hsi_test.shape)
print(lidar_test.shape)

(2832, 81, 144, 1)
(2832, 81, 1, 1)
(12197, 81, 144, 1)
(12197, 81, 1, 1)


In [9]:
filename = 'UH_ViT_CA_EF_CLS_HSI_LiDAR'
NC = hsi_train.shape[2]
NCLiDAR = lidar_train.shape[2]
n_classes = len(np.unique(training_labels))
print(NC)
print(NCLiDAR)
print(n_classes)

144
1
15


In [10]:
def get_confusion_matrix(y_test, y_pred, name, c, plt_name):
    # print(plt_name)
    df_cm = pd.DataFrame(confusion_matrix(y_test, y_pred), range(c), range(c))
        
    if name == 'UTrento':
        df_cm.columns = ['Buildings', 'Woods', 'Roads', 'Apples', 'Ground', 'Vineyard']
        df_cm = df_cm.rename(index={0: 'Buildings', 1: 'Woods', 2: 'Roads', 3: 'Apples', 4: 'Ground', 5: 'Vineyard'})
    
    elif name == 'UHouston':
        df_cm.columns = ['Healthy grass', 'Stressed grass', 'Synthetic grass', 'Trees', 'Soil', 'Water',  'Residential', 'Commercial', 'Road', 'Highway',
                        'Railway', 'Parking Lot 1', 'Parking Lot 2', 'Tennis Court', 'Running Track']    
        df_cm = df_cm.rename(index={0:'Healthy grass', 1:'Stressed grass', 2:'Synthetic grass', 3:'Trees', 4:'Soil', 5:'Water',  6:'Residential', 7:'Commercial', 8:'Road', 9:'Highway',
                        10:'Railway', 11:'Parking Lot 1', 12:'Parking Lot 2', 13:'Tennis Court', 14:'Running Track'})
    
    df_cm.index.name = 'Actual'
    df_cm.columns.name = 'Predicted'

    plt.figure(figsize=(10, 8))
    sns.set(font_scale=1.2)  # for label size
    heatmap = sns.heatmap(df_cm, cmap="Blues", annot=True, annot_kws={"size": 14}, fmt='g', cbar=False)

    heatmap.set_xticklabels(heatmap.get_xticklabels(), rotation=45, ha='right', fontsize=12)
    heatmap.set_yticklabels(heatmap.get_yticklabels(), rotation=0, fontsize=12)

    plt.tight_layout()
    plt.savefig(f'{plt_name}.png', format='png')
    plt.close()

def AA_andEachClassAccuracy(confusion_matrix):
    counter = confusion_matrix.shape[0]
    list_diag = np.diag(confusion_matrix)
    list_raw_sum = np.sum(confusion_matrix, axis=1)
    each_acc = np.nan_to_num(truediv(list_diag, list_raw_sum))
    average_acc = np.mean(each_acc)
    return each_acc, average_acc
    

In [11]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, num_patches, dim, dropout):
        super(PatchEmbedding, self).__init__()

        self.in_channels = in_channels
        self.num_patches = num_patches
        self.dim = dim
        self.dropout = dropout

        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,  # Using num_patches instead of in_channels
            kernel_size=(1,1),
            stride=1,
        )

        self.linear_proj = nn.Linear(in_channels, dim)
        self.norm = nn.LayerNorm(dim)  # Adjusted to apply to each embedding vector

        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))

        self.dropout_layer = nn.Dropout(dropout)  # Renamed to avoid conflict with attribute name
    
    def forward(self, x):
        B, C, H, W = x.shape
        
        x = self.proj(x)  # Apply convolution
        
        x = x.view(B, C, -1)  # Shape: [B, C, H*W]
        x = x.permute(0, 2, 1)  # Change to [B, spatial_dim, C]
        
        x = self.linear_proj(x)  # Apply linear projection correctly
        
        x = self.norm(x)  # Normalize each embedding vector

        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
        x = x + self.pos_embedding
        x = self.dropout_layer(x)
        
        return x

class HSISelfAttention(nn.Module):
    def __init__(self, dim, n_heads, dim_head, dropout):
        super(HSISelfAttention, self).__init__()

        self.dim = dim
        self.dim_head = dim_head
        self.num_heads = n_heads
        self.sqrt_dim_head = math.sqrt(self.dim_head)
        self.dropout_layer = nn.Dropout(dropout)  # Renamed to avoid conflict with attribute name

        # Layer normalization before Q, K, V projections
        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, dim_head * n_heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * n_heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * n_heads, bias=False)

        self.to_out = nn.Linear(n_heads * dim_head, dim)
        self.dropout_layer = nn.Dropout(dropout)  # This line seems duplicated, corrected below

    def forward(self, x1, x2):
        B, N1, _ = x1.size()
        B, N2, _ = x2.size()

        x2 = self.norm(x2)  # Apply normalization before attention

        Q = self.to_q(x1)
        K = self.to_k(x2)
        V = self.to_v(x2)
        
        Q = Q.view(B, N1, self.num_heads, self.dim_head).transpose(1, 2)  # [B, num_heads, 144, dim_head]
        K = K.view(B, N2, self.num_heads, self.dim_head).transpose(1, 2)  # [B, num_heads, 144, dim_head]
        V = V.view(B, N2, self.num_heads, self.dim_head).transpose(1, 2)  # [B, num_heads, 144, dim_head]
        
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.sqrt_dim_head
        attn_weights = F.softmax(attn_scores, dim=-1)  # [B, num_heads, 144, 144]
        attn_weights = attn_weights.mean(dim=2, keepdim=True)        
        
        # Apply attention to V
        attn_output = torch.matmul(attn_weights, V)  # [B, num_heads, 144, dim_head]
        attn_output = attn_output.transpose(1, 2).reshape(B, -1, self.num_heads * self.dim_head)  # [B, 144, num_heads * dim_head]

        attn_output = self.dropout_layer(attn_output)
        
        # Project back to the original embedding space
        attn_output = self.to_out(attn_output)  # [B, 144, num_patches]
        
        return attn_output

# MLP Module in the Transformer Encoder 
class Mlp(nn.Module):
    def __init__(self, dim, mlp_dim, dropout):
        super(Mlp, self).__init__()
        self.dim = dim
        self.mlp_dim = mlp_dim
        
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        self.act_fn = nn.GELU()
        self.dropout_layer = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout_layer(x)
        x = self.fc2(x)
        x = self.dropout_layer(x)
        
        return x

# Single Transformer Encoder Block that combines the Self-Attention and MLP layers     
class Block(nn.Module):
    def __init__(self, dim, n_heads, dim_head, mlp_dim, dropout):
        super(Block, self).__init__()
        self.hidden_size = dim
        self.hidden_dim_size = mlp_dim
        self.attention_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = Mlp(dim, mlp_dim, dropout)
        self.attn = HSISelfAttention(dim, n_heads, dim_head, dropout)

    def forward(self, x1, x2):
        
        h = x2
        x2 = self.attention_norm(x2)
        x2= self.attn(x1, x2)
        x2 = x2 + h

        h = x2
        x2 = self.ffn_norm(x2)
        x2 = self.ffn(x2)
        x2 = x2 + h
        
        return x2
        
# Transformer Encoder Block with Multi-head Self-Attention repetition    
class TransformerEncoder(nn.Module):
    def __init__(self, dim, n_heads, dim_head, mlp_dim, dropout, depth):
        super(TransformerEncoder, self).__init__()
        self.layer = nn.ModuleList()
        self.encoder_norm = nn.LayerNorm(dim)
        for _ in range(depth):
            layer = Block(dim, n_heads, dim_head, mlp_dim, dropout)
            self.layer.append(copy.deepcopy(layer))       

    def forward(self, x1, x2):
        for layer_block in self.layer:
            x = layer_block(x1, x2)
            
        encoded = self.encoder_norm(x)

        return encoded

# The Final ViT Implementation with cls token from other modalities
class ViT_CA(nn.Module):
    def __init__(self, in_channels=81, num_patches1=144, num_patches2=1, dim=128, n_heads=8, dim_head=64, mlp_dim=256,  depth=3, num_classes=15, dropout=0.4):
        super(ViT_CA, self).__init__()
        
        self.num_patches1 = num_patches1
        self.patch_embedding1 = PatchEmbedding(in_channels, num_patches1, dim, dropout)
        self.patch_embedding2 = PatchEmbedding(in_channels, num_patches2, dim, dropout)
        self.transformer = TransformerEncoder(dim, n_heads, dim_head, mlp_dim, dropout, depth)
        self.linear_head = nn.Linear(dim, num_classes)
        
    def forward(self, x1, x2):
        
        x1 = self.patch_embedding1(x1)
        
        x2 = self.patch_embedding2(x2)
        
        x1_1 = self.transformer(x2, x1)
        
        x2_1 = self.transformer(x1, x2)
        
        # x = x.mean(dim=1)

        logits_hsi = self.linear_head(x1_1[:,0])
        logits_lid = self.linear_head(x2_1[:,0])
       
        out = (logits_hsi + logits_lid) / 2
        
        return out


In [12]:
# Convert HSI and label data to PyTorch tensors
hsi_train_tensor = torch.from_numpy(hsi_train.astype(np.float32))  # .permute(0, 1, 2, 3) Shape: (batch, channels, height, width)
lidar_train_tensor = torch.from_numpy(lidar_train.astype(np.float32))  # .permute(0, 1, 2, 3) Shape: (batch, channels, height, width)
training_labels_tensor = torch.from_numpy(y_train.astype(np.int64))

hsi_test_tensor = torch.from_numpy(hsi_test.astype(np.float32))  # Shape: (batch, channels, height, width)
lidar_test_tensor = torch.from_numpy(lidar_test.astype(np.float32))  # Shape: (batch, channels, height, width)
test_labels_tensor = torch.from_numpy(y_test.astype(np.int64))

# Define a dataset class for HSI data
class HyperspectralDataset(Dataset):
    def __init__(self, hsi_data, lidar_data, labels):
        self.hsi_data = hsi_data
        self.lidar_data = lidar_data
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.hsi_data[idx], self.lidar_data[idx], self.labels[idx]

# Create datasets
train_dataset = HyperspectralDataset(hsi_train_tensor, lidar_train_tensor, training_labels_tensor)
# val_dataset = HyperspectralDataset(hsi_samples_val, labels_val)
test_dataset = HyperspectralDataset(hsi_test_tensor, lidar_test_tensor, test_labels_tensor)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

print('DataLoaders for training, validation, and testing are ready.')

DataLoaders for training, validation, and testing are ready.


In [13]:
# model.to(device)

num_epochs = 200  # Set the number of epochs
    
KAPPA = []
OA = []
AA = []                
ELEMENT_ACC = np.zeros((3, n_classes))

for iterNum in range(3):
    model = ViT_CA(in_channels=81, num_patches1=NC, num_patches2=NCLiDAR, dim=128, n_heads=8, dim_head=64, mlp_dim=256,  depth=3, num_classes=n_classes, dropout=0.1).cuda()
    summary(model, [(81, NC, 1), (81, NCLiDAR, 1)])
    model.to(device)
    
    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=5e-4)  # Adjust learning rate as needed

    # Define learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)  # Adjust parameters as needed

    # Instantiate the loss function
    criterion = nn.CrossEntropyLoss()
    
    # Training loop
    best_test_acc = float('-inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    
    torch.cuda.synchronize()
    
    start = time.time()
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        running_loss = 0.0

        # Training phase
        for batch_idx, (hsi_batch, lidar_batch, label_batch) in enumerate(train_loader):
            hsi_batch = hsi_batch.to(device)
            lidar_batch = lidar_batch.to(device)
            label_batch = label_batch.to(device)

            # Forward pass
            outputs = model(hsi_batch, lidar_batch)
        
            # Loss calculation
            loss = criterion(outputs, label_batch)
        
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)

        # Compute training accuracy
        model.eval()
        with torch.no_grad():
            correct_train = 0
            total_train = 0
            for hsi_batch, lidar_batch, label_batch in train_loader:
                hsi_batch = hsi_batch.to(device)
                lidar_batch = lidar_batch.to(device)
                label_batch = label_batch.to(device)
                outputs = model(hsi_batch, lidar_batch)
                _, predicted = torch.max(outputs, 1)
                total_train += label_batch.size(0)
                correct_train += (predicted == label_batch).sum().item()

        train_accuracy = correct_train / total_train * 100

        # Validation phase
        model.eval()  # Set model to evaluation mode
        test_loss = 0.0
        correct_test = 0
        total_test = 0
        with torch.no_grad():  # No need to track gradients during validation
            for hsi_batch, lidar_batch, label_batch in test_loader:
                hsi_batch = hsi_batch.to(device)
                lidar_batch = lidar_batch.to(device)
                label_batch = label_batch.to(device)

                outputs = model(hsi_batch, lidar_batch)
                loss = criterion(outputs, label_batch)
                test_loss += loss.item()            

                _, predicted = torch.max(outputs, 1)
                total_test += label_batch.size(0)
                correct_test += (predicted == label_batch).sum().item()

        avg_test_loss = test_loss / len(test_loader)
        test_accuracy = correct_test / total_test * 100

        print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%, Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

        # Check if this is the best model so far
        if test_accuracy > best_test_acc:
            print('\n')
            print(f'Test accuracy increased ({best_test_acc:.4f} --> {test_accuracy:.4f}). Saving model ...')
            best_test_acc = test_accuracy
            best_model_wts = copy.deepcopy(model.state_dict())
            # Save model weights
            torch.save(model.state_dict(), 'best_model_'+str(filename)+'_Iter_'+str(iterNum)+'.pth')

    end = time.time()
    print("\n")
    print('\nThe train time (in seconds) is:', end - start) 
                            
    # Load the best model weights
    model.load_state_dict(best_model_wts)
    # Test the model
    model.eval()
    test_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for hsi_batch, lidar_batch, label_batch in test_loader:
            hsi_batch = hsi_batch.to(device)
            lidar_batch = lidar_batch.to(device)
            label_batch = label_batch.to(device)

            outputs = model(hsi_batch, lidar_batch)
            loss = criterion(outputs, label_batch)
            test_loss += loss.item()

            preds = torch.max(outputs, 1)[1].cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(label_batch.cpu().numpy())

    # Calculate average test loss
    avg_test_loss = test_loss / len(test_loader)
    print(f'Test Loss: {avg_test_loss:.4f}')

    # Convert lists to numpy arrays for evaluation metrics
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    get_confusion_matrix(all_labels, all_preds, datasetname, n_classes, filename + str('_') + str(iterNum))
    
    # Evaluate metrics
    oa = accuracy_score(all_labels, all_preds)
    confusion = confusion_matrix(all_labels, all_preds)
    each_acc, aa = AA_andEachClassAccuracy(confusion)
    kappa = cohen_kappa_score(all_labels, all_preds)    
                    
    KAPPA.append(kappa*100)
    OA.append(oa*100)
    AA.append(aa*100)
    ELEMENT_ACC[iterNum, :] = each_acc*100
    # torch.save(model, datasetname+'/best_model_'+filename+'_Iter'+str(iterNum)+'.pt')
print("\n")
record.record_output(OA, AA, KAPPA, ELEMENT_ACC,'./' + datasetname+'/'+filename+'_Report_' + datasetname +'.txt')


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 81, 144, 1]           6,642
            Linear-2             [-1, 144, 128]          10,496
         LayerNorm-3             [-1, 144, 128]             256
           Dropout-4             [-1, 145, 128]               0
    PatchEmbedding-5             [-1, 145, 128]               0
            Conv2d-6             [-1, 81, 1, 1]           6,642
            Linear-7               [-1, 1, 128]          10,496
         LayerNorm-8               [-1, 1, 128]             256
           Dropout-9               [-1, 2, 128]               0
   PatchEmbedding-10               [-1, 2, 128]               0
        LayerNorm-11             [-1, 145, 128]             256
        LayerNorm-12             [-1, 145, 128]             256
           Linear-13               [-1, 2, 512]          65,536
           Linear-14             [-1, 1