In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split

import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import os
import cv2

In [None]:


class SpectralDataset(Dataset):
    """Map-style dataset over a 2-D collection of spectra.

    The whole input is copied once into a float32 tensor stored as ``self.data``;
    item ``i`` is row ``i`` of that tensor.
    """

    def __init__(self, data):
        # torch.tensor always copies, so later edits to `data` do not leak in.
        self.data = torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        # Number of rows (one per spectrum).
        return len(self.data)

    def __getitem__(self, index):
        row = self.data[index]
        return row
    
    
class Scale(torch.nn.Module):
    """Parameter-free module that multiplies its input by 255."""

    def forward(self, input):
        factor = 255
        return factor * input
        
# Creating a PyTorch class
# 28*28 ==> 9 ==> 28*28
class spectral_AE(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Building an linear encoder with Linear
        # layer followed by Relu activation function
        # 140 ==> 3
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(hyper_2d.shape[1], 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 8),
            torch.nn.ReLU(),
            torch.nn.Linear(8, 3),
            torch.nn.Hardtanh(min_val=0, max_val=255) # want a known range for visualization and bounding 
        )

        # Building an linear decoder with Linear
        # layer followed by Relu activation function
        # The Sigmoid activation function
        # outputs the value between 0 and 1
        # 3 ==> 140
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(3, 8),
            torch.nn.ReLU(),
            torch.nn.Linear(8, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, hyper_2d.shape[1]),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded    

def load_hsi(file_name):
    """Load a multi-page TIFF hyperspectral cube into a NumPy array.

    Every TIFF page is read with ``cv2.imreadmulti`` and the pages are stacked
    along a new first axis, so single-channel pages give a ``(pages, rows, cols)``
    array.

    Args:
        file_name: path to a ``.tif`` / ``.tiff`` file.  The extension is
            matched case-insensitively.

    Returns:
        The stacked cube, or ``None`` (after printing an error) when the file
        extension is not supported.

    Raises:
        OSError: if OpenCV reports a failed read or returns no pages.
    """
    _, extension = os.path.splitext(file_name)

    if extension.lower() not in ('.tif', '.tiff'):
        print("Error: file type not supported")
        return None

    # NOTE(review): IMREAD_ANYCOLOR without IMREAD_ANYDEPTH makes OpenCV convert
    # 16/32-bit pages to 8-bit.  Adding ANYDEPTH would keep full precision, but
    # downstream torch.from_numpy may not accept uint16 on older torch versions.
    loaded, pages = cv2.imreadmulti(filename=file_name, mats=[], flags=cv2.IMREAD_ANYCOLOR)
    if not loaded or len(pages) == 0:
        # Previously this surfaced later as an opaque IndexError on an empty array.
        raise OSError(f"could not read any image pages from {file_name!r}")

    return np.array(pages)

In [None]:
HSI_PATH = '../../HyperImages/img1.tiff'

hyperImage = load_hsi(HSI_PATH)
# load_hsi returns None for unsupported files; stacked single-channel pages give rank 3.
assert hyperImage is not None and hyperImage.ndim == 3, "expected a (bands, rows, cols) cube"
print(hyperImage.shape)

# (bands, rows, cols) -> (rows, cols, bands): the same axis order the full-image
# encoding cell uses, so row i of hyper_2d is pixel (i // cols, i % cols) and
# encoder outputs can be reshaped straight back into the image grid.
hyper_result = np.transpose(hyperImage, (1, 2, 0))
print(hyper_result.shape)

# Flatten the two spatial axes: one row per pixel, one column per band.
new_shape_first_dim = hyper_result.shape[0] * hyper_result.shape[1]
hyper_2d = hyper_result.reshape((new_shape_first_dim, hyper_result.shape[2]))
print(hyper_2d.shape)

In [None]:
# Float32 copy of the pixel matrix.  spectral_loader is not used by the
# training cells below (they build their own loaders).
spectral_dataset = SpectralDataset(hyper_2d)
spectral_loader = DataLoader(spectral_dataset, batch_size=2048, shuffle=True)

# Global min-max scaling to [0, 1]: a single min/max over all pixels and bands,
# matching the decoder's Sigmoid output range.  Done once, vectorised, on the
# whole matrix rather than per row on millions of separate tensors.
hyper_2d_float = spectral_dataset.data
min_val = hyper_2d_float.min()
max_val = hyper_2d_float.max()
assert max_val > min_val, "constant image: min-max scaling would divide by zero"
hyper_2d_norm = (hyper_2d_float - min_val) / (max_val - min_val)

# Random 70/30 pixel split.  Splitting the NumPy array gives the same indices as
# splitting the Dataset (same length and random_state) but returns contiguous
# arrays instead of a Python list of per-row tensors.
train_np, test_np = train_test_split(
    hyper_2d_norm.numpy(), test_size=0.3, random_state=21
)
spectral_train_data = torch.from_numpy(train_np).float()
spectral_test_data = torch.from_numpy(test_np).float()




In [None]:
# plt.grid()
# plt.plot(np.arange(hyper_2d.shape[1]), spectral_train_data[0])
# plt.title("A Sample curve")
# plt.show()

In [None]:
# Seed torch so weight initialisation and DataLoader shuffling are reproducible
# across Restart & Run All (same seed as the train/test split).
torch.manual_seed(21)

BATCH_SIZE = 16384  # pixels per optimisation step / evaluation batch

spectral_model = spectral_AE()
train_loader = DataLoader(spectral_train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(spectral_test_data, batch_size=BATCH_SIZE)


In [None]:
# Reconstruction objective: mean squared error between output and input spectra.
criterion = torch.nn.MSELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
spectral_model = spectral_model.to(device)
optimizer = torch.optim.Adam(spectral_model.parameters(),
                             lr=1e-4,
                             weight_decay=1e-5)

epochs = 20
history_loss_train = []
history_loss_val = []
for epoch in range(epochs):
    # --- training pass ---
    spectral_model.train()
    train_sum, train_count = 0.0, 0
    for inputs in train_loader:
        inputs = inputs.to(device)
        outputs = spectral_model(inputs)
        loss = criterion(outputs, inputs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so the epoch value is a true mean;
        # previously only the last batch's loss was recorded.
        train_sum += loss.item() * inputs.shape[0]
        train_count += inputs.shape[0]

    # --- validation pass ---
    spectral_model.eval()
    val_sum, val_count = 0.0, 0
    with torch.no_grad():
        for inputs in test_loader:
            inputs = inputs.to(device)
            outputs = spectral_model(inputs)
            val_sum += criterion(outputs, inputs).item() * inputs.shape[0]
            val_count += inputs.shape[0]

    # NaN (rather than a crash or a misleading 0) if a loader is empty.
    epoch_train_loss = train_sum / train_count if train_count else float('nan')
    epoch_val_loss = val_sum / val_count if val_count else float('nan')
    history_loss_train.append(epoch_train_loss)
    history_loss_val.append(epoch_val_loss)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_train_loss}, Validation Loss: {epoch_val_loss}')

In [None]:
# Training vs. validation reconstruction loss per epoch.
fig, ax = plt.subplots()
ax.plot(history_loss_train, label="Training Loss")
ax.plot(history_loss_val, label="Validation Loss")
ax.set(xlabel="Epoch (0-based)", ylabel="MSE loss", title="Autoencoder reconstruction loss")
ax.legend()
plt.show()


In [None]:
# Run the trained encoder/decoder over the held-out pixels.  no_grad avoids
# building an autograd graph for the whole test set, and the latent code stays
# on the device between encoder and decoder (no host round trip).
spectral_model.eval()
with torch.no_grad():
    test_inputs = spectral_test_data.to(device)
    encoded = spectral_model.encoder(test_inputs)
    decoded = spectral_model.decoder(encoded)

# NumPy copies for plotting: latent codes (N, 3) and reconstructions (N, bands).
encoded_data = encoded.cpu().numpy()
decoded_data = decoded.cpu().numpy()


In [None]:
# Overlay one held-out spectrum with its reconstruction; the shaded band is the error.
test_index = 0

spectral_test_data_np = spectral_test_data.cpu().numpy()
sample_input = spectral_test_data_np[test_index]
sample_recon = decoded_data[test_index]
band_positions = np.arange(hyper_2d.shape[1])

plt.plot(sample_input, 'b')
plt.plot(sample_recon, 'r')
plt.fill_between(band_positions, sample_recon, sample_input, color='lightcoral')
plt.legend(labels=["Input", "Reconstruction", "Error"])
plt.show()


In [None]:

# Encode every pixel of the full image into the 3-D latent space and display each
# latent channel, plus all three as an RGB composite.
print(hyperImage.shape)
n_bands, height, width = hyperImage.shape

# (bands, rows, cols) -> (rows, cols, bands) -> (rows*cols, bands): one row per
# pixel in row-major order, so encoded rows reshape straight back to the image.
hyperImage_reshaped = np.transpose(hyperImage, (1, 2, 0)).reshape((height * width, n_bands))

# Apply the same min-max scaling used for training.  Feeding raw intensities
# would put inputs far outside the [0, 1] range the network was trained on.
lo, hi = float(min_val), float(max_val)
hyper_2d_tensor = (torch.from_numpy(hyperImage_reshaped).float() - lo) / (hi - lo)

ENCODE_CHUNK = 65536  # pixels per forward pass; bounds device memory use
spectral_model.eval()
with torch.no_grad():
    encoded_data = torch.cat([
        spectral_model.encoder(chunk.to(device)).cpu()
        for chunk in torch.split(hyper_2d_tensor, ENCODE_CHUNK)
    ]).numpy()
print(encoded_data.shape)

# Back to the image grid; size comes from the cube rather than a hardcoded 1886.
encoded_data_reshaped = encoded_data.reshape(height, width, 3)

# The encoder ends in Hardtanh(0, 255); dividing by 255 gives imshow's [0, 1] float range.
encoded_data_reshaped2 = encoded_data_reshaped / 255
print(encoded_data.dtype, encoded_data_reshaped2.dtype)

fig, axes = plt.subplots(1, 4, figsize=(15, 5))
for channel, ax in enumerate(axes[:3]):
    ax.imshow(encoded_data_reshaped2[:, :, channel], cmap='gray')
    ax.set_title(f'Channel {channel + 1}')
axes[3].imshow(encoded_data_reshaped2)
axes[3].set_title('Channel 4 RGB')
plt.show()




In [None]:
# Value distribution of one latent channel (already scaled to [0, 1]).
# The original printed the min of channel 0 but the max of channel 2; all stats
# below now refer to the same channel that is histogrammed.
channel = 2
channel_values = encoded_data_reshaped2[:, :, channel]
print(np.min(channel_values))
print(np.max(channel_values))

# Flatten the channel image to 1-D for the histogram.
data_flattened = channel_values.flatten()

plt.hist(data_flattened, bins=50)
plt.xlabel('Scaled latent value')
plt.ylabel('Pixel count')
plt.title(f'Latent channel {channel + 1} distribution')
plt.show()

In [None]:
# torch.save(spectral_model.state_dict(), '/workspaces/HyperTools/spectral_model_norm.pth')