# 🎉 Out-of-Distribution (OOD) with PCA using Deep Features from the Latent Space

The goal of this notebook is to explore in depth how Principal Component Analysis (PCA) can be used for OOD detection with deep features taken from the latent space.

## 📝 Plan of action

### ♻️ Preprocessing phase

In order to achieve our goal, we need to understand how the dataset is structured.

For this notebook, we are going to use the CBIR 15 dataset, which contains images of different places, such as an office, a bedroom, a mountain, etc. Note that some places are similar to one another, e.g. a bedroom and a living room.

Thus, in order to extract the features of the images, we have to preprocess them:

- Get the images that are located in data/CBIR_15-scene and load them into a dataframe using Pandas
  - Locate the "Labels.txt" file: it shows where the indexes of the images from each category start
- Create the dataset from this information with two columns: the path to the image and its category
- Resize all of the images to the same size (in this case, we are going with 256x256)
  
Now, in order to extract the features, it's necessary to divide the resized images into patches of 32x32 pixels. Working on small patches keeps each processing step fast and avoids long waiting times.
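
As a rough illustration of the patching step, here is a minimal sketch that splits one 256x256 grayscale image into 32x32 patches with NumPy. The function name `split_into_patches` and the synthetic image are assumptions made for this example; the notebook's own helper modules may do this differently.

```python
import numpy as np

def split_into_patches(image, patch_size=32):
    """Split a 2-D grayscale image into non-overlapping square patches."""
    height, width = image.shape
    if height % patch_size or width % patch_size:
        raise ValueError("image dimensions must be multiples of patch_size")
    rows, cols = height // patch_size, width // patch_size
    # Reshape to (rows, patch, cols, patch), then put the two grid axes first.
    patches = image.reshape(rows, patch_size, cols, patch_size).swapaxes(1, 2)
    return patches.reshape(rows * cols, patch_size, patch_size)

# Synthetic stand-in for a resized image (hypothetical data).
image = np.arange(256 * 256, dtype=np.float32).reshape(256, 256)
patches = split_into_patches(image)
print(patches.shape)  # (64, 32, 32)
```

A 256x256 image yields an 8x8 grid, so 64 patches, ordered row by row.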

After all the preprocessing, we should separate the images into two different folders: one contains the patches of the training images, which give us the principal components and dimensions, and the other contains the patches of the test images, which are projected onto those dimensions to obtain an OOD score.

### 🏋🏽‍♂️ Training phase

With the images that are stored inside the "patches_train" folder, the first thing we are going to do is _normalize_ all of them. This puts all the variables on the same scale, so that the directions of maximum covariance found by PCA are not dominated by variables with larger ranges.

Next, we apply PCA with all the components. As we have patches of 32x32, we'll have 1024 features, and hence up to 1024 components. Then we plot a graph to see how many components truly contribute most of the variance of the data, i.e. carry most of the information about it. We're going to use a threshold of 95% of the variance in this notebook.
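
To make the 95% threshold concrete, the sketch below counts how many principal components are needed to reach a given share of the variance. It uses NumPy's SVD as a stand-in for scikit-learn's `PCA`, and the data is synthetic; `components_for_variance` is a name invented for this example.

```python
import numpy as np

def components_for_variance(data, threshold=0.95):
    """Return how many principal components reach `threshold` of the total variance."""
    centered = data - data.mean(axis=0)
    # Squared singular values of the centered data are proportional to the
    # variance along each principal axis.
    singular_values = np.linalg.svd(centered, compute_uv=False)
    explained_variance_ratio = singular_values**2 / np.sum(singular_values**2)
    cumulative = np.cumsum(explained_variance_ratio)
    # First position where the cumulative ratio reaches the threshold, as a count.
    count = int(np.searchsorted(cumulative, threshold)) + 1
    return min(count, len(cumulative)), cumulative

rng = np.random.default_rng(0)
# Synthetic data: 200 samples, 50 features, most variance in the first 3 directions.
scales = np.concatenate([[10.0, 8.0, 6.0], np.full(47, 0.1)])
data = rng.normal(size=(200, 50)) * scales
n_components, cumulative = components_for_variance(data)
print(n_components)
```

Plotting `cumulative` against the component index gives the kind of curve described above.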

After getting the PCA with the components that describe 95% of the variance, it's time to test our images and see how much of their data falls into the residual space.

### ⚗️ Test phase and results

In this phase, we take the test images and normalize them with the same scale used for each PCA. This is important to maintain consistency throughout the final results and to measure the norms in the new dimensions properly.
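
A minimal sketch of this consistency requirement, assuming per-feature standardization: the mean and standard deviation are computed on the training data only and then reused for the test data. The helper names and the random arrays are hypothetical; scikit-learn's `StandardScaler` follows the same fit-then-transform pattern.

```python
import numpy as np

def fit_standardizer(train):
    """Compute per-feature mean and standard deviation from the training data only."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std[std == 0] = 1.0  # avoid division by zero for constant features
    return mean, std

def apply_standardizer(data, mean, std):
    """Scale any data with statistics that were fitted on the training set."""
    return (data - mean) / std

rng = np.random.default_rng(0)
train = rng.normal(loc=5.0, scale=2.0, size=(100, 16))  # stand-in training features
test = rng.normal(loc=5.0, scale=2.0, size=(20, 16))    # stand-in test features

mean, std = fit_standardizer(train)
train_scaled = apply_standardizer(train, mean, std)
test_scaled = apply_standardizer(test, mean, std)  # reuses the training mean and std
```

The scaled test data will generally not have exactly zero mean, which is expected: it is measured in the training set's coordinate system.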

After that, we calculate the norm of the projection of the given data onto the orthogonal complement of the principal components and divide it by the norm of the data in relation to the origin. This is the OOD score.
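
The score can be sketched as follows, under the assumption that the principal directions come from an SVD of the centered training data (a stand-in for scikit-learn's `PCA`). The function names and the toy data are illustrative only.

```python
import numpy as np

def fit_pca_basis(train, n_components):
    """Return the training mean and the top principal directions (as rows) via SVD."""
    mean = train.mean(axis=0)
    _, _, vt = np.linalg.svd(train - mean, full_matrices=False)
    return mean, vt[:n_components]

def ood_score(x, mean, components):
    """Norm of the part of x outside the principal subspace, divided by the norm of x."""
    centered = x - mean
    projection = centered @ components.T @ components  # part inside the subspace
    residual = centered - projection                    # part in the orthogonal complement
    return np.linalg.norm(residual) / np.linalg.norm(x)

rng = np.random.default_rng(0)
# Toy training data living in the plane spanned by the first two axes of a 5-D space.
basis = np.array([[1.0, 0, 0, 0, 0], [0, 1.0, 0, 0, 0]])
train = rng.normal(size=(100, 2)) @ basis
mean, components = fit_pca_basis(train, n_components=2)

score_in = ood_score(np.array([3.0, -2.0, 0, 0, 0]), mean, components)   # inside the plane
score_out = ood_score(np.array([0, 0, 0, 0, 4.0]), mean, components)     # orthogonal to it
print(score_in, score_out)
```

A point inside the training subspace scores close to 0, while a point orthogonal to it scores close to 1.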

We calculate the mean score for each category and take the minimum. The category with the smallest mean score is taken as the current environment.
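
The decision rule above fits in a few lines. The scores below are made-up numbers for illustration, not results from this notebook, and `predict_environment` is a hypothetical helper name.

```python
def predict_environment(scores_by_category):
    """Return the category with the lowest mean OOD score, plus all the means."""
    mean_scores = {
        category: sum(scores) / len(scores)
        for category, scores in scores_by_category.items()
    }
    return min(mean_scores, key=mean_scores.get), mean_scores

# Illustrative numbers only, not results from this notebook.
scores = {"Coast": [0.12, 0.10, 0.15], "Office": [0.40, 0.38, 0.45]}
predicted, mean_scores = predict_environment(scores)
print(predicted)  # Coast
```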


--------------------------

First of all, we need to understand which libraries we are going to use:

- os: Deals with the operating system interface, such as finding the relative and absolute paths of files inside a project and reading/writing files.
- sys: This module provides access to some variables used or maintained by the interpreter and to functions that interact strongly with the interpreter.
- numpy: NumPy is the fundamental package for scientific computing in Python. It is a Python library that provides a multidimensional array object, various derived objects (such as masked arrays and matrices), and an assortment of routines for fast operations on arrays, including mathematical, logical, shape manipulation, sorting, selecting, I/O, discrete Fourier transforms, basic linear algebra, basic statistical operations, random simulation and much more.
- pandas: Pandas is an open source, BSD-licensed library providing high-performance, easy-to-use data structures and data analysis tools for the Python programming language.
- matplotlib: Deals with plotting graphs to visualize data in a graphical way.
- sklearn: Scikit-learn provides dozens of built-in machine learning algorithms and models, called estimators.

In [None]:
import os
import sys
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA


I suggest using a conda virtual environment to avoid messing up your base kernel environment and causing dependency errors in the future.

After you have successfully installed all the modules, it's time to import our custom modules that are going to deal with:

- Creation of our dataframe using pandas
- Separation of our dataset into patches of 32x32 in folders of training and test

In [None]:

sys.path.append(os.path.abspath('..'))

from dataframe_generator import *
from images_standardizing import *

In [None]:
import tarfile

def extract_tgz(tgz_path, extract_to):
    if not os.path.exists(extract_to):
        os.makedirs(extract_to)
    
    with tarfile.open(tgz_path, 'r:gz') as tar:
        tar.extractall(path=extract_to)
        print(f"Files extracted to {extract_to}")

tgz_path = '../CBIR_15-Scene.tgz'
extract_to = '../data/'

extract_tgz(tgz_path, extract_to)

In [None]:
df = create_dataframe()
df

## ☝️ Part I: Comparing two different environments

### ♻️ Preprocessing phase

Now we start our experiments to check whether our idea works. We begin by looking at what happens with our approach when using two different environments.

In our case, I'm going to take the **Coast** and **Office** environments arbitrarily.


In [None]:
train_categories = ['Coast', 'Office']

df_different = df[df['category'].isin(train_categories)]
df_different

It's time to separate our dataset into train and test. We should use the built-in function of sklearn to do this:

In [None]:
X = df_different['image_path'].tolist()
y = df_different['category'].tolist()
unique_categories = list(df_different['category'].unique())
print(f"Unique categories: {unique_categories}")

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

standard_size = (256, 256)

To make sure that everything went well, we plot the grid of all the patches from the first image of our training set.

This is exactly what the module inside our "image_patching.py" does. So now we need to save everything into the subfolders by calling that function:

In [None]:
create_images_set(X_train, X_test, y_train, y_test, output_dir_train='images_train', output_dir_test='images_test', standard_size=standard_size)

Now, we should load our patches for training:

In [None]:
training_images_by_category = load_images_by_category('images_train', unique_categories, image_size=(256, 256))
print(training_images_by_category['Coast'].shape)

### Centering images

Now, we need to center the images to make the training of the neural network more efficient. We are not normalizing them, to avoid information loss.

In [None]:
def center_images(images):
    num_images, height, width = images.shape
    flattened_images = images.reshape((num_images, -1))
    
    mean = np.mean(flattened_images, axis=0)
    
    centered_flattened_images = flattened_images - mean
    
    centered_images = centered_flattened_images.reshape((num_images, height, width))
    return centered_images


centralized_images_by_category = {}
scalers_by_category = {}
for category, images in training_images_by_category.items():
    print(images.shape)
    centralized_images = center_images(images)
    centralized_images_by_category[category] = centralized_images
    print(f"Category {category}, images shape: {centralized_images.shape}")


In [None]:
def check_centralization(images):
    mean = np.mean(images, axis=(0, 1, 2))
    return mean

for category, images in centralized_images_by_category.items():
    mean = check_centralization(images)
    print(f"Mean pixel values after centralization for category {category}: {mean}")


Since the values are close to zero, the pixels of each category are correctly centered.

### 🏋🏽‍♂️ Training phase

With everything preprocessed, we now need to train our neural network. The code below uses a U-Net, a well-known encoder-decoder architecture that is often used for computer vision applications.

I'm using no pretrained weights, because the underlying goal of this research is to use the results from this work in an unsupervised setting.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA

# Check the available device (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Define the UNet model for encoding and decoding
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()

        # Encoder
        self.encoder1 = self.contracting_block(in_channels, 64)
        self.encoder2 = self.contracting_block(64, 128)
        self.encoder3 = self.contracting_block(128, 256)
        self.encoder4 = self.contracting_block(256, 512)

        # Decoder
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = self.expansive_block(512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = self.expansive_block(256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = self.expansive_block(128, 64)
        self.decoder1 = self.final_block(64, out_channels)

    def contracting_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=3, in_channels=out_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        return block

    def expansive_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=3, in_channels=out_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        return block

    def final_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=3, in_channels=out_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=1, in_channels=out_channels, out_channels=out_channels),
            nn.Sigmoid()  # Sigmoid activation keeps the output within [0, 1]
        )
        return block

    def crop_and_concat(self, upsampled, bypass):
        _, _, H, W = upsampled.size()
        _, _, H_b, W_b = bypass.size()
        if H_b != H or W_b != W:
            bypass = nn.functional.interpolate(bypass, size=(H, W), mode='bilinear', align_corners=True)
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encoder path
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(nn.functional.max_pool2d(enc1, kernel_size=2, stride=2))
        enc3 = self.encoder3(nn.functional.max_pool2d(enc2, kernel_size=2, stride=2))
        enc4 = self.encoder4(nn.functional.max_pool2d(enc3, kernel_size=2, stride=2))

        # Decoder path
        dec4 = self.crop_and_concat(self.upconv4(enc4), enc3)
        dec4 = self.decoder4(dec4)
        dec3 = self.crop_and_concat(self.upconv3(dec4), enc2)
        dec3 = self.decoder3(dec3)
        dec2 = self.crop_and_concat(self.upconv2(dec3), enc1)
        dec2 = self.decoder2(dec2)
        dec1 = self.decoder1(dec2)

        return dec1, enc1, enc2, enc3, enc4  # Return both the reconstructed image and the encoder features

def apply_pca_and_reconstruct(features, pca):
    # Flatten the features while keeping track of the original shape
    batch_size, channels, height, width = features.size()
    flattened_features = features.view(batch_size, -1).cpu().numpy()

    # Check the original dimensions of the features
    original_num_features = flattened_features.shape[1]
    print(f"Original number of features before PCA: {original_num_features}")

    # Apply PCA projection and reconstruction
    projected_features = pca.transform(flattened_features)
    num_pca_components = projected_features.shape[1]
    print(f"Number of components after PCA: {num_pca_components}")

    reconstructed_features = pca.inverse_transform(projected_features)
    #print(f"Reconstructed feature dimensions (after inverse PCA): {reconstructed_features.shape}")

    # Convert back to tensor and reshape to the original feature shape
    reconstructed_features = torch.tensor(reconstructed_features, dtype=torch.float32).view(batch_size, channels, height, width).to(device)

    return reconstructed_features

# Train the U-Net and fit a PCA on the features extracted from it, reporting the explained variance
def train_unet_and_apply_pca(unet, data_loader, num_epochs=5):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(unet.parameters(), lr=0.0001)

    # Collect features in a local list so that each call (one per category) fits its PCA
    # only on its own features, instead of accumulating them in a global variable
    pca_train_features = []

    for epoch in range(num_epochs):
        unet.train()
        epoch_loss = 0
        for (images,) in data_loader:
            images = images.to(device).float()
            images /= 255.0

            optimizer.zero_grad()

            # Forward pass through U-Net to get features
            reconstructed_images, _, _, _, unet_features = unet(images)

            # Flatten and collect features for PCA training (in the first epoch)
            if epoch == 0:
                flattened_features = unet_features.view(images.shape[0], -1).detach().cpu().numpy()
                pca_train_features.append(flattened_features)

            loss = criterion(reconstructed_images, images)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(data_loader):.4f}')

    # Stack the collected batches into a single (n_samples, n_features) array
    pca_train_features = np.vstack(pca_train_features)

    # Compute and display the number of features before PCA
    original_feature_count = pca_train_features.shape[1]
    print(f"Original number of features before PCA: {original_feature_count}")

    # Fit the PCA and show the explained variance
    pca = PCA(n_components=0.95)  # Retain 95% variance
    pca.fit(pca_train_features)
    print(f"Explained variance by PCA components: {pca.explained_variance_ratio_}")
    print(f"Number of components chosen by PCA: {pca.n_components_}")

    return pca

# Initialize the UNet model
unet = UNet().to(device)

# Dictionary to store the PCA of each category
pca_dict = {}

# Set up the DataLoader here
categories = centralized_images_by_category.keys()

for category in categories:
    images = centralized_images_by_category[category]

    # Add the channel dimension and create the DataLoader
    images = np.expand_dims(images, axis=1)  # Add the channel dimension
    dataset = TensorDataset(torch.tensor(images, dtype=torch.float32))
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Train the U-Net model and collect features for the PCA
    pca = train_unet_and_apply_pca(unet, loader, num_epochs=5)

    # Store the PCA in the dictionary for the current category
    pca_dict[category] = pca

    # Evaluate and visualize some reconstructed images with PCA applied to the features
    unet.eval()
    with torch.no_grad():
        for i in range(3):  # Show 3 images from each category
            image = centralized_images_by_category[category][i]
            image_tensor = torch.tensor(image / 255.0).unsqueeze(0).unsqueeze(0).float().to(device)

            # Run the image through the UNet and extract the features
            reconstructed_image, enc1, enc2, enc3, unet_features = unet(image_tensor)

            # Apply PCA to the features and reconstruct them
            reconstructed_features = apply_pca_and_reconstruct(unet_features, pca)

            # Pass the reconstructed features through the U-Net decoder path
            dec4 = unet.crop_and_concat(unet.upconv4(reconstructed_features), enc3)
            dec4 = unet.decoder4(dec4)
            dec3 = unet.crop_and_concat(unet.upconv3(dec4), enc2)
            dec3 = unet.decoder3(dec3)
            dec2 = unet.crop_and_concat(unet.upconv2(dec3), enc1)
            dec2 = unet.decoder2(dec2)

            # The final block produces the number of channels expected at the output
            final_reconstructed_image = unet.decoder1(dec2)

            # Convert the tensor to numpy for visualization
            final_reconstructed_image_np = final_reconstructed_image.squeeze().cpu().numpy()

            # Visualize the images
            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(image, cmap='gray')
            plt.title(f'Original {category} Image {i+1}')
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.imshow(final_reconstructed_image_np, cmap='gray')
            plt.title(f'Reconstructed {category} Image {i+1} with PCA')
            plt.axis('off')

            plt.show()

# The PCA dictionary (pca_dict) is kept in memory for later use
print("PCA models saved for each category.")

Now, we take the output of the penultimate layer to extract our latent features from the neural network.

This result means that we extracted 348 images with 4096 features each of the Coast category and 209 images with 4096 features each of the Office category.

Now we have to reduce the dimensionality. In order to do that, we use PCA. But before that, we should center the features.

The components_ matrix has the shape (n_components, n_features), but when you project the original data into this new principal components space, the data is transformed into a matrix of shape (n_samples, n_components).
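
The shapes involved can be checked with a small sketch. It uses NumPy's SVD on random data as a stand-in for scikit-learn's `PCA`, so `components`, `projected` and `reconstructed` only play the roles of `components_`, `transform` and `inverse_transform`; the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_components = 40, 100, 5
features = rng.normal(size=(n_samples, n_features))  # stand-in for flattened latent features

mean = features.mean(axis=0)
_, _, vt = np.linalg.svd(features - mean, full_matrices=False)

components = vt[:n_components]                 # plays the role of `components_`
projected = (features - mean) @ components.T   # plays the role of `transform`
reconstructed = projected @ components + mean  # plays the role of `inverse_transform`

print(components.shape)     # (5, 100)
print(projected.shape)      # (40, 5)
print(reconstructed.shape)  # (40, 100)
```

Note that the reconstruction returns to the original feature space only after the mean is added back.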

### Testing phase


In [None]:
def load_and_preprocess_test_images(test_dir, categories, image_size, input_size=(224,224)):
    test_images_by_category = load_images_by_category(test_dir, categories, image_size)
    test_centralized_images_by_category = {}

    for category, images in test_images_by_category.items():
        test_centralized_images = center_images(images)
        test_centralized_images_by_category[category] = test_centralized_images

    return test_centralized_images_by_category

test_preprocessed_images_by_category = load_and_preprocess_test_images('images_test', unique_categories, image_size=(256,256), input_size=(256,256))

In [None]:
for category, images in test_preprocessed_images_by_category.items():
    mean = check_centralization(images)
    print(f"Mean pixel values after centralization for category {category}: {mean}")

In [None]:
import torch.nn.functional as F

# Function to apply UNet with different PCA models, compute OOD scores, and visualize the first 3 results
def compute_ood_and_visualize(unet, pca_dict, test_images_by_category):
    unet.eval()  # Set the model to evaluation mode

    # Dictionary to store OOD scores for each test category and PCA category
    ood_scores_by_category = {}

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for test_category, test_images in test_images_by_category.items():
            print(f"Processing test category: {test_category}")

            # Store OOD scores for this test category
            ood_scores_by_category[test_category] = {}

            # Iterate over PCA models from different categories
            for pca_category in pca_dict.keys():
                ood_scores_by_category[test_category][pca_category] = []

                # Process all test images in this category
                for i, image in enumerate(test_images):
                    image_tensor = torch.tensor(image / 255.0).unsqueeze(0).unsqueeze(0).float().to(device)

                    # Forward pass through UNet to get features
                    reconstructed_image, enc1, enc2, enc3, unet_features = unet(image_tensor)

                    # Apply PCA to the features and reconstruct them
                    reconstructed_features = apply_pca_and_reconstruct(unet_features, pca_dict[pca_category])

                    # Calculate residuals
                    residuals = unet_features - reconstructed_features

                    # Calculate norms for OOD score
                    norm_residuals = torch.norm(residuals)
                    norm_original = torch.norm(unet_features)

                    # Compute OOD score
                    ood_score = (norm_residuals / norm_original).item()
                    ood_scores_by_category[test_category][pca_category].append(ood_score)

                    # Only plot the reconstruction for the first 3 images in each category
                    if i < 3:
                        # Pass the reconstructed features through the decoder of the UNet
                        dec4 = unet.crop_and_concat(unet.upconv4(reconstructed_features), enc3)
                        dec4 = unet.decoder4(dec4)
                        dec3 = unet.crop_and_concat(unet.upconv3(dec4), enc2)
                        dec3 = unet.decoder3(dec3)
                        dec2 = unet.crop_and_concat(unet.upconv2(dec3), enc1)
                        dec2 = unet.decoder2(dec2)
                        final_reconstructed_image = unet.decoder1(dec2)

                        # Convert the tensor to numpy for visualization
                        final_reconstructed_image_np = final_reconstructed_image.squeeze().cpu().numpy()

                        # Plot original and reconstructed images
                        plt.figure(figsize=(10, 5))
                        plt.subplot(1, 2, 1)
                        plt.imshow(image, cmap='gray')
                        plt.title(f'Original {test_category} Image {i+1}')
                        plt.axis('off')

                        plt.subplot(1, 2, 2)
                        plt.imshow(final_reconstructed_image_np, cmap='gray')
                        plt.title(f'Reconstructed {test_category} Image {i+1} with {pca_category} PCA')
                        plt.axis('off')

                        plt.show()

    # Calculate average OOD score for each category with each PCA
    avg_ood_scores_by_category = {}

    for test_category, pca_ood_scores in ood_scores_by_category.items():
        avg_ood_scores_by_category[test_category] = {}
        for pca_category, ood_scores in pca_ood_scores.items():
            avg_ood_score = np.mean(ood_scores)
            avg_ood_scores_by_category[test_category][pca_category] = avg_ood_score
            print(f"Average OOD Score for {test_category} with {pca_category} PCA: {avg_ood_score}")

    return avg_ood_scores_by_category

# Assuming you have already trained your U-Net and saved the PCA models for each category:
avg_ood_scores_by_category = compute_ood_and_visualize(unet, pca_dict, test_preprocessed_images_by_category)

# Print the average OOD scores for all categories
for test_category, pca_ood_scores in avg_ood_scores_by_category.items():
    print(f"\nAverage OOD Scores for test category '{test_category}':")
    for pca_category, avg_ood in pca_ood_scores.items():
        print(f"  PCA from '{pca_category}' => Average OOD Score: {avg_ood}")


### Agnostic Space

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# Function to project features onto selected PCA components and transform back
def project_and_transform_back(data, pca, specific_indices):
    """
    Project the data onto specific PCA components, reconstruct the data, and calculate residuals and OOD score.
    """
    # Flatten the features for PCA projection
    flattened_data = data.view(1, -1).cpu().numpy()
    
    # Project onto the PCA components
    projected = pca.transform(flattened_data)
    
    # Use only the specific components
    projected_specific = projected[:, specific_indices]
    
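    # Reconstruct the data from the selected components; pca.transform centers the data,
    # so the PCA mean is added back to return to the original feature space
    specific_components = pca.components_[specific_indices]
    reconstructed_data = np.dot(projected_specific, specific_components) + pca.mean_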
    
    # Calculate residuals
    residuals = flattened_data - reconstructed_data
    
    # Calculate norms for OOD score
    norm_residuals = np.linalg.norm(residuals)
    norm_original = np.linalg.norm(flattened_data)
    
    # Calculate OOD score as the ratio of the norms
    ood_score = norm_residuals / norm_original
    
    return reconstructed_data, residuals, ood_score


def compute_ood_with_selected_components(unet, pca_dict, test_images_by_category):
    unet.eval()  # Set UNet to evaluation mode
    
    # Dictionary to store OOD scores for each test category and PCA category
    ood_scores_by_category = {}

    with torch.no_grad():
        for test_category, test_images in test_images_by_category.items():
            print(f"Processing test category: {test_category}")

            ood_scores_by_category[test_category] = {}

            # Iterate over PCA models from different categories
            for pca_category, pca in pca_dict.items():
                print(f"Processing with PCA from category: {pca_category}")
                ood_scores_by_category[test_category][pca_category] = []

                # Retrieve the explained variance ratio from the PCA object
                explained_variance_ratio = pca.explained_variance_ratio_

                # Select the leading components whose cumulative explained variance stays within 90%.
                # The selection depends only on the PCA model, so it is computed once per PCA
                # rather than once per test image.
                aggregated_selected_indices = np.where(np.cumsum(explained_variance_ratio) <= 0.9)[0]

                # Skip processing if no valid components are selected
                if len(aggregated_selected_indices) == 0:
                    print(f"Warning: No valid components selected for {test_category} with {pca_category} PCA. Skipping this combination.")
                    continue

                # Process each image again using selected components for projection, reconstruction, and OOD calculation
                for i, image in enumerate(test_images):
                    image_tensor = torch.tensor(image / 255.0).unsqueeze(0).unsqueeze(0).float().to(device)
                    reconstructed_image, enc1, enc2, enc3, unet_features = unet(image_tensor)

                    # Project and reconstruct features using only the selected PCA components
                    reconstructed_features, residuals, ood_score = project_and_transform_back(unet_features, pca, aggregated_selected_indices)

                    # Reshape the residuals to the same shape as `unet_features`
                    residuals_reshaped = residuals.reshape(unet_features.shape)  # Reshape to (B, C, H, W)

                    # Append the OOD score
                    ood_scores_by_category[test_category][pca_category].append(ood_score)

                    # Visualize the first 3 images with their reconstructions and residuals
                    if i < 3:
                        # Pass the reconstructed features through the decoder of the UNet
                        dec4 = unet.crop_and_concat(unet.upconv4(torch.tensor(reconstructed_features, dtype=torch.float32).view_as(unet_features).to(device)), enc3)
                        dec4 = unet.decoder4(dec4)
                        dec3 = unet.crop_and_concat(unet.upconv3(dec4), enc2)
                        dec3 = unet.decoder3(dec3)
                        dec2 = unet.crop_and_concat(unet.upconv2(dec3), enc1)
                        dec2 = unet.decoder2(dec2)
                        final_reconstructed_image = unet.decoder1(dec2)

                        # Convert to numpy for visualization
                        final_reconstructed_image_np = final_reconstructed_image.squeeze().cpu().numpy()

                        # Convert residuals to numpy for visualization
                        residuals_np = residuals_reshaped.squeeze()  # Residuals are already in NumPy format

                        # Plot original, reconstructed, and residual images
                        plt.figure(figsize=(15, 5))
                        plt.subplot(1, 3, 1)
                        plt.imshow(image, cmap='gray')
                        plt.title(f'Original {test_category} Image {i+1}')
                        plt.axis('off')

                        plt.subplot(1, 3, 2)
                        plt.imshow(final_reconstructed_image_np, cmap='gray')
                        plt.title(f'Reconstructed {test_category} Image {i+1} with {pca_category} PCA')
                        plt.axis('off')

                        plt.subplot(1, 3, 3)
                        plt.imshow(residuals_np[0], cmap='gray')  # Select the first channel for visualization
                        plt.title(f'Residuals {test_category} Image {i+1} with {pca_category} PCA')
                        plt.axis('off')

                        plt.show()

    # Calculate average OOD score for each category and PCA
    avg_ood_scores_by_category = {}
    for test_category, pca_ood_scores in ood_scores_by_category.items():
        avg_ood_scores_by_category[test_category] = {}
        for pca_category, ood_scores in pca_ood_scores.items():
            avg_ood_score = np.mean(ood_scores)
            avg_ood_scores_by_category[test_category][pca_category] = avg_ood_score
            print(f"Average OOD Score for {test_category} with {pca_category} PCA: {avg_ood_score}")

    return avg_ood_scores_by_category

# Assuming you have already trained your U-Net and saved the PCA models for each category
avg_ood_scores_by_category = compute_ood_with_selected_components(unet, pca_dict, test_preprocessed_images_by_category)

# Print the average OOD scores for all categories
for test_category, pca_ood_scores in avg_ood_scores_by_category.items():
    print(f"\nAverage OOD Scores for test category '{test_category}':")
    for pca_category, avg_ood in pca_ood_scores.items():
        print(f"  PCA from '{pca_category}' => Average OOD Score: {avg_ood}")



## ✌️ Part II: Comparing two similar environments

In [None]:
train_categories = ['Bedroom', 'LivingRoom']

df_different = df[df['category'].isin(train_categories)]
df_different

In [None]:
X = df_different['image_path']
y = df_different['category']
(X_train, X_test, y_train, y_test) = train_test_split(X, y, test_size=0.2, random_state=10)

image_size = (256, 256)
unique_categories = list(df_different['category'].unique())
print(f"Unique categories: {unique_categories}")


In [None]:
create_images_set(X_train, X_test, y_train, y_test, output_dir_train='images_train', output_dir_test='images_test', standard_size=standard_size)

In [None]:
training_images_by_category = load_images_by_category('images_train', unique_categories, image_size=(256, 256))

In [None]:
centralized_images_by_category = {}
scalers_by_category = {}
for category, images in training_images_by_category.items():
    centralized_images = center_images(images)
    centralized_images_by_category[category] = centralized_images


In [None]:
def check_centralization(images):
    mean = np.mean(images, axis=(0, 1, 2))
    return mean

for category, images in centralized_images_by_category.items():
    mean = check_centralization(images)
    print(f"Mean pixel values after centralization for category {category}: {mean}")
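
`center_images` is defined earlier in the notebook. As a hedged illustration of what the centering step and the check above amount to, the sketch below assumes that centering subtracts the mean image of the category; `center_images_sketch` is a hypothetical stand-in and the notebook's helper may differ (for example, it could subtract a single global mean).

```python
import numpy as np

def center_images_sketch(images):
    """Hypothetical stand-in for center_images: subtract the per-pixel mean image."""
    images = np.asarray(images, dtype=np.float64)
    mean_image = images.mean(axis=0)  # shape (H, W)
    return images - mean_image, mean_image

# Synthetic grayscale "images" with pixel values in [0, 255]
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(10, 8, 8)).astype(np.float64)

centered, mean_image = center_images_sketch(fake_images)

# Same check as check_centralization: the mean over all axes should be close to 0
overall_mean = np.mean(centered, axis=(0, 1, 2))
print(f"Mean after centering: {overall_mean:.2e}")
```

Returning `mean_image` makes it possible to subtract the training mean from test images as well, which is one common design choice; whether the notebook does this depends on how `center_images` is implemented.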


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA

# Select the available device (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# U-Net model used to encode and decode the images
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()

        # Encoder
        self.encoder1 = self.contracting_block(in_channels, 64)
        self.encoder2 = self.contracting_block(64, 128)
        self.encoder3 = self.contracting_block(128, 256)
        self.encoder4 = self.contracting_block(256, 512)

        # Decoder
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = self.expansive_block(512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = self.expansive_block(256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = self.expansive_block(128, 64)
        self.decoder1 = self.final_block(64, out_channels)

    def contracting_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=3, in_channels=out_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        return block

    def expansive_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=3, in_channels=out_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        return block

    def final_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=3, in_channels=out_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=1, in_channels=out_channels, out_channels=out_channels),
            nn.Sigmoid()  # Sigmoid activation keeps the output within [0, 1]
        )
        return block

    def crop_and_concat(self, upsampled, bypass):
        _, _, H, W = upsampled.size()
        _, _, H_b, W_b = bypass.size()
        if H_b != H or W_b != W:
            bypass = nn.functional.interpolate(bypass, size=(H, W), mode='bilinear', align_corners=True)
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encoder path
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(nn.functional.max_pool2d(enc1, kernel_size=2, stride=2))
        enc3 = self.encoder3(nn.functional.max_pool2d(enc2, kernel_size=2, stride=2))
        enc4 = self.encoder4(nn.functional.max_pool2d(enc3, kernel_size=2, stride=2))

        # Decoder path
        dec4 = self.crop_and_concat(self.upconv4(enc4), enc3)
        dec4 = self.decoder4(dec4)
        dec3 = self.crop_and_concat(self.upconv3(dec4), enc2)
        dec3 = self.decoder3(dec3)
        dec2 = self.crop_and_concat(self.upconv2(dec3), enc1)
        dec2 = self.decoder2(dec2)
        dec1 = self.decoder1(dec2)

        return dec1, enc1, enc2, enc3, enc4  # Return the reconstructed image together with the encoder features

def apply_pca_and_reconstruct(features, pca):
    # Flatten the features while keeping track of the original shape
    batch_size, channels, height, width = features.size()
    flattened_features = features.view(batch_size, -1).cpu().numpy()

    # Report the original dimensionality of the features
    original_num_features = flattened_features.shape[1]
    print(f"Original number of features before PCA: {original_num_features}")

    # Apply PCA projection and reconstruction
    projected_features = pca.transform(flattened_features)
    num_pca_components = projected_features.shape[1]
    print(f"Number of components after PCA: {num_pca_components}")

    reconstructed_features = pca.inverse_transform(projected_features)
    #print(f"Reconstructed feature dimensions (after inverse PCA): {reconstructed_features.shape}")

    # Convert back to tensor and reshape to the original feature shape
    reconstructed_features = torch.tensor(reconstructed_features, dtype=torch.float32).view(batch_size, channels, height, width).to(device)

    return reconstructed_features

# Train the U-Net, then fit a PCA on its bottleneck features and report the explained variance
def train_unet_and_apply_pca(unet, data_loader, num_epochs=5):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(unet.parameters(), lr=0.0001)

    # Collect features locally so that each call (one per category) starts from scratch;
    # a module-level global would keep accumulating features across categories and reruns
    collected_features = []

    for epoch in range(num_epochs):
        unet.train()
        epoch_loss = 0
        for images, in data_loader:
            images = images.to(device).float()
            images /= 255.0

            optimizer.zero_grad()

            # Forward pass through U-Net to get features
            reconstructed_images, _, _, _, unet_features = unet(images)

            # Flatten and collect features for PCA training (in the first epoch)
            if epoch == 0:
                flattened_features = unet_features.view(images.shape[0], -1).detach().cpu().numpy()
                collected_features.append(flattened_features)

            loss = criterion(reconstructed_images, images)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(data_loader):.4f}')

    # Stack the collected features into a (num_samples, num_features) matrix
    pca_train_features = np.vstack(collected_features)

    # Report the number of features before PCA
    original_feature_count = pca_train_features.shape[1]
    print(f"Original number of features before PCA: {original_feature_count}")

    # Fit the PCA and report the explained variance
    pca = PCA(n_components=0.95)  # Retain 95% variance
    pca.fit(pca_train_features)
    print(f"Explained variance by PCA components: {pca.explained_variance_ratio_}")
    print(f"Number of components chosen by PCA: {pca.n_components_}")

    return pca

# Initialize the U-Net model
unet = UNet().to(device)

# Dictionary that stores the fitted PCA for each category
pca_dict = {}

# Build one DataLoader per category
categories = centralized_images_by_category.keys()

for category in categories:
    images = centralized_images_by_category[category]

    # Add the channel dimension and create the DataLoader
    images = np.expand_dims(images, axis=1)  # Add the channel dimension
    dataset = TensorDataset(torch.tensor(images, dtype=torch.float32))
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Train the U-Net and collect the features used to fit the PCA
    pca = train_unet_and_apply_pca(unet, loader, num_epochs=5)

    # Store the PCA for the current category
    pca_dict[category] = pca

    # Evaluate and visualize a few images reconstructed with PCA applied to the features
    unet.eval()
    with torch.no_grad():
        for i in range(3):  # Show 3 images per category
            image = centralized_images_by_category[category][i]
            image_tensor = torch.tensor(image / 255.0).unsqueeze(0).unsqueeze(0).float().to(device)

            # Run the U-Net and extract the features
            reconstructed_image, enc1, enc2, enc3, unet_features = unet(image_tensor)

            # Project the features with PCA and reconstruct them
            reconstructed_features = apply_pca_and_reconstruct(unet_features, pca)

            # Pass the reconstructed features through the U-Net decoder path
            dec4 = unet.crop_and_concat(unet.upconv4(reconstructed_features), enc3)
            dec4 = unet.decoder4(dec4)
            dec3 = unet.crop_and_concat(unet.upconv3(dec4), enc2)
            dec3 = unet.decoder3(dec3)
            dec2 = unet.crop_and_concat(unet.upconv2(dec3), enc1)
            dec2 = unet.decoder2(dec2)

            # The final block maps the features to the number of output channels
            final_reconstructed_image = unet.decoder1(dec2)

            # Convert the tensor to NumPy for visualization
            final_reconstructed_image_np = final_reconstructed_image.squeeze().cpu().numpy()

            # Display the images
            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(image, cmap='gray')
            plt.title(f'Original {category} Image {i+1}')
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.imshow(final_reconstructed_image_np, cmap='gray')
            plt.title(f'Reconstructed {category} Image {i+1} with PCA')
            plt.axis('off')

            plt.show()

# The fitted PCA models stay in memory in pca_dict for later use (nothing is written to disk here)
print("PCA models stored in pca_dict for each category.")
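
To reuse the fitted models in another session they would have to be written to disk. The following is a minimal sketch of one way to do that with the standard-library `pickle` module. The dictionary `pca_dict_sketch` and the file name `pca_dict.pkl` are hypothetical stand-ins; the same pattern is commonly applied to fitted scikit-learn estimators, which is assumed here rather than taken from the notebook.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical stand-in for pca_dict: one entry per category, holding the
# quantities a fitted PCA would expose (mean vector and component matrix)
rng = np.random.default_rng(0)
pca_dict_sketch = {
    "Bedroom": {"mean": rng.normal(size=16), "components": rng.normal(size=(4, 16))},
    "LivingRoom": {"mean": rng.normal(size=16), "components": rng.normal(size=(4, 16))},
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "pca_dict.pkl"  # hypothetical file name

    # Write the dictionary to disk
    with open(path, "wb") as f:
        pickle.dump(pca_dict_sketch, f)

    # Read it back, as a later session would
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(f"Restored categories: {sorted(restored)}")
```

Pickle files should only be loaded from trusted sources, and a model pickled with one library version is not guaranteed to load under another.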

# Test

In [None]:
def load_and_preprocess_test_images(test_dir, categories, image_size, input_size=(256,256)):
    test_images_by_category = load_images_by_category(test_dir, categories, image_size)
    test_centralized_images_by_category = {}

    for category, images in test_images_by_category.items():
        test_centralized_images = center_images(images)
        test_centralized_images_by_category[category] = test_centralized_images

    return test_centralized_images_by_category

# Pass the list of unique category names rather than the full label column
test_preprocessed_images_by_category = load_and_preprocess_test_images('images_test', unique_categories, image_size=(256,256), input_size=(256,256))

In [None]:
for category, images in test_preprocessed_images_by_category.items():
    mean = check_centralization(images)
    print(f"Mean pixel values after centralization for test category {category}: {mean}")

In [None]:
import torch.nn.functional as F

# Function to apply UNet with different PCA models, compute OOD scores, and visualize the first 3 results
def compute_ood_and_visualize(unet, pca_dict, test_images_by_category):
    unet.eval()  # Set the model to evaluation mode

    # Dictionary to store OOD scores for each test category and PCA category
    ood_scores_by_category = {}

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for test_category, test_images in test_images_by_category.items():
            print(f"Processing test category: {test_category}")

            # Store OOD scores for this test category
            ood_scores_by_category[test_category] = {}

            # Iterate over PCA models from different categories
            for pca_category in pca_dict.keys():
                ood_scores_by_category[test_category][pca_category] = []

                # Process all test images in this category
                for i, image in enumerate(test_images):
                    image_tensor = torch.tensor(image / 255.0).unsqueeze(0).unsqueeze(0).float().to(device)

                    # Forward pass through UNet to get features
                    reconstructed_image, enc1, enc2, enc3, unet_features = unet(image_tensor)

                    # Apply PCA to the features and reconstruct them
                    reconstructed_features = apply_pca_and_reconstruct(unet_features, pca_dict[pca_category])

                    # Calculate residuals
                    residuals = unet_features - reconstructed_features

                    # Calculate norms for OOD score
                    norm_residuals = torch.norm(residuals)
                    norm_original = torch.norm(unet_features)

                    # Compute OOD score
                    ood_score = (norm_residuals / norm_original).item()
                    ood_scores_by_category[test_category][pca_category].append(ood_score)

                    # Only plot the reconstruction for the first 3 images in each category
                    if i < 3:
                        # Pass the reconstructed features through the decoder of the UNet
                        dec4 = unet.crop_and_concat(unet.upconv4(reconstructed_features), enc3)
                        dec4 = unet.decoder4(dec4)
                        dec3 = unet.crop_and_concat(unet.upconv3(dec4), enc2)
                        dec3 = unet.decoder3(dec3)
                        dec2 = unet.crop_and_concat(unet.upconv2(dec3), enc1)
                        dec2 = unet.decoder2(dec2)
                        final_reconstructed_image = unet.decoder1(dec2)

                        # Convert the tensor to numpy for visualization
                        final_reconstructed_image_np = final_reconstructed_image.squeeze().cpu().numpy()

                        # Plot original and reconstructed images
                        plt.figure(figsize=(10, 5))
                        plt.subplot(1, 2, 1)
                        plt.imshow(image, cmap='gray')
                        plt.title(f'Original {test_category} Image {i+1}')
                        plt.axis('off')

                        plt.subplot(1, 2, 2)
                        plt.imshow(final_reconstructed_image_np, cmap='gray')
                        plt.title(f'Reconstructed {test_category} Image {i+1} with {pca_category} PCA')
                        plt.axis('off')

                        plt.show()

    # Calculate average OOD score for each category with each PCA
    avg_ood_scores_by_category = {}

    for test_category, pca_ood_scores in ood_scores_by_category.items():
        avg_ood_scores_by_category[test_category] = {}
        for pca_category, ood_scores in pca_ood_scores.items():
            avg_ood_score = np.mean(ood_scores)
            avg_ood_scores_by_category[test_category][pca_category] = avg_ood_score
            print(f"Average OOD Score for {test_category} with {pca_category} PCA: {avg_ood_score}")

    return avg_ood_scores_by_category

# Assuming you have already trained your U-Net and saved the PCA models for each category:
avg_ood_scores_by_category = compute_ood_and_visualize(unet, pca_dict, test_preprocessed_images_by_category)

# Print the average OOD scores for all categories
for test_category, pca_ood_scores in avg_ood_scores_by_category.items():
    print(f"\nAverage OOD Scores for test category '{test_category}':")
    for pca_category, avg_ood in pca_ood_scores.items():
        print(f"  PCA from '{pca_category}' => Average OOD Score: {avg_ood}")
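
The OOD score used above is the norm of the PCA reconstruction residual divided by the norm of the original feature vector. The sketch below illustrates that ratio on synthetic vectors instead of U-Net features, with PCA implemented through NumPy's SVD so it is self-contained. The helper names `fit_pca` and `ood_score` and all of the data are assumptions made for illustration, not part of the notebook.

```python
import numpy as np

def fit_pca(X, variance_to_keep=0.95):
    """Fit a PCA via SVD and keep enough components to reach the requested variance."""
    mean = X.mean(axis=0)
    _, s, vt = np.linalg.svd(X - mean, full_matrices=False)
    ratio = s**2 / np.sum(s**2)
    k = int(np.searchsorted(np.cumsum(ratio), variance_to_keep)) + 1
    k = min(k, len(s))
    return mean, vt[:k]

def ood_score(x, mean, components):
    """Relative residual norm after projecting onto the PCA subspace and back."""
    projected = (x - mean) @ components.T
    reconstructed = projected @ components + mean
    return np.linalg.norm(x - reconstructed) / np.linalg.norm(x)

rng = np.random.default_rng(0)

# In-distribution data: points near a 3-dimensional subspace of a 50-dimensional space
basis, _ = np.linalg.qr(rng.normal(size=(50, 3)))  # orthonormal columns

def sample_in_distribution(n):
    latent = 5.0 * rng.normal(size=(n, 3))
    return latent @ basis.T + 0.01 * rng.normal(size=(n, 50))

train = sample_in_distribution(200)
mean, components = fit_pca(train)

in_scores = np.array([ood_score(x, mean, components) for x in sample_in_distribution(20)])
# Out-of-distribution data: isotropic noise that ignores the subspace
out_scores = np.array([ood_score(x, mean, components) for x in 5.0 * rng.normal(size=(20, 50))])

print(f"Mean in-distribution score:     {in_scores.mean():.4f}")
print(f"Mean out-of-distribution score: {out_scores.mean():.4f}")
```

Samples that lie close to the training subspace are reconstructed almost perfectly and get a score near 0, while samples far from it keep most of their norm in the residual and get a score near 1.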


# Agnostic

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# Function to project features onto selected PCA components and transform back
def project_and_transform_back(data, pca, specific_indices):
    """
    Project the data onto specific PCA components, reconstruct the data, and calculate residuals and OOD score.
    """
    # Flatten the features for PCA projection
    flattened_data = data.view(1, -1).cpu().numpy()
    
    # Project onto the PCA components
    projected = pca.transform(flattened_data)
    
    # Use only the specific components
    projected_specific = projected[:, specific_indices]
    
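    # Reconstruct the data from the selected components.
    # pca.transform subtracts pca.mean_, so the mean is added back to return to the original space
    specific_components = pca.components_[specific_indices]
    reconstructed_data = np.dot(projected_specific, specific_components) + pca.mean_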
    
    # Calculate residuals
    residuals = flattened_data - reconstructed_data
    
    # Calculate norms for OOD score
    norm_residuals = np.linalg.norm(residuals)
    norm_original = np.linalg.norm(flattened_data)
    
    # Calculate OOD score as the ratio of the norms
    ood_score = norm_residuals / norm_original
    
    return reconstructed_data, residuals, ood_score


def compute_ood_with_selected_components(unet, pca_dict, test_images_by_category):
    unet.eval()  # Set UNet to evaluation mode
    
    # Dictionary to store OOD scores for each test category and PCA category
    ood_scores_by_category = {}

    with torch.no_grad():
        for test_category, test_images in test_images_by_category.items():
            print(f"Processing test category: {test_category}")

            ood_scores_by_category[test_category] = {}

            # Iterate over PCA models from different categories
            for pca_category, pca in pca_dict.items():
                print(f"Processing with PCA from category: {pca_category}")
                ood_scores_by_category[test_category][pca_category] = []

                # Retrieve the explained variance ratio from the PCA object
                explained_variance_ratio = pca.explained_variance_ratio_

                # Keep the leading components whose cumulative explained variance stays within 90%.
                # The selection depends only on the PCA model, not on the test images,
                # so it is computed once per PCA rather than once per image.
                aggregated_selected_indices = np.where(np.cumsum(explained_variance_ratio) <= 0.9)[0]

                # Skip processing if no valid components are selected
                if len(aggregated_selected_indices) == 0:
                    print(f"Warning: No valid components selected for {test_category} with {pca_category} PCA. Skipping this combination.")
                    continue

                # Process each image again using selected components for projection, reconstruction, and OOD calculation
                for i, image in enumerate(test_images):
                    image_tensor = torch.tensor(image / 255.0).unsqueeze(0).unsqueeze(0).float().to(device)
                    reconstructed_image, enc1, enc2, enc3, unet_features = unet(image_tensor)

                    # Project and reconstruct features using only the selected PCA components
                    reconstructed_features, residuals, ood_score = project_and_transform_back(unet_features, pca, aggregated_selected_indices)

                    # Reshape the residuals to the same shape as `unet_features`
                    residuals_reshaped = residuals.reshape(unet_features.shape)  # Reshape to (B, C, H, W)

                    # Append the OOD score
                    ood_scores_by_category[test_category][pca_category].append(ood_score)

                    # Visualize the first 3 images with their reconstructions and residuals
                    if i < 3:
                        # Pass the reconstructed features through the decoder of the UNet
                        dec4 = unet.crop_and_concat(unet.upconv4(torch.tensor(reconstructed_features, dtype=torch.float32).view_as(unet_features).to(device)), enc3)
                        dec4 = unet.decoder4(dec4)
                        dec3 = unet.crop_and_concat(unet.upconv3(dec4), enc2)
                        dec3 = unet.decoder3(dec3)
                        dec2 = unet.crop_and_concat(unet.upconv2(dec3), enc1)
                        dec2 = unet.decoder2(dec2)
                        final_reconstructed_image = unet.decoder1(dec2)

                        # Convert to numpy for visualization
                        final_reconstructed_image_np = final_reconstructed_image.squeeze().cpu().numpy()

                        # Convert residuals to numpy for visualization
                        residuals_np = residuals_reshaped.squeeze()  # Residuals are already in NumPy format

                        # Plot original, reconstructed, and residual images
                        plt.figure(figsize=(15, 5))
                        plt.subplot(1, 3, 1)
                        plt.imshow(image, cmap='gray')
                        plt.title(f'Original {test_category} Image {i+1}')
                        plt.axis('off')

                        plt.subplot(1, 3, 2)
                        plt.imshow(final_reconstructed_image_np, cmap='gray')
                        plt.title(f'Reconstructed {test_category} Image {i+1} with {pca_category} PCA')
                        plt.axis('off')

                        plt.subplot(1, 3, 3)
                        plt.imshow(residuals_np[0], cmap='gray')  # Select the first channel for visualization
                        plt.title(f'Residuals {test_category} Image {i+1} with {pca_category} PCA')
                        plt.axis('off')

                        plt.show()

    # Calculate average OOD score for each category and PCA
    avg_ood_scores_by_category = {}
    for test_category, pca_ood_scores in ood_scores_by_category.items():
        avg_ood_scores_by_category[test_category] = {}
        for pca_category, ood_scores in pca_ood_scores.items():
            avg_ood_score = np.mean(ood_scores)
            avg_ood_scores_by_category[test_category][pca_category] = avg_ood_score
            print(f"Average OOD Score for {test_category} with {pca_category} PCA: {avg_ood_score}")

    return avg_ood_scores_by_category

# Assuming you have already trained your U-Net and saved the PCA models for each category
avg_ood_scores_by_category = compute_ood_with_selected_components(unet, pca_dict, test_preprocessed_images_by_category)

# Print the average OOD scores for all categories
for test_category, pca_ood_scores in avg_ood_scores_by_category.items():
    print(f"\nAverage OOD Scores for test category '{test_category}':")
    for pca_category, avg_ood in pca_ood_scores.items():
        print(f"  PCA from '{pca_category}' => Average OOD Score: {avg_ood}")
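
The component selection in this section keeps the leading components whose cumulative explained variance ratio is at most 0.9 and reconstructs from that subset only. The sketch below walks through the same steps on a small synthetic matrix with a NumPy SVD in place of the fitted PCA objects. All names and data are assumptions for illustration. It also shows why the mean has to be added back after the projection.

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic data with decreasing variance per axis and a non-zero mean
scales = np.array([5.0, 3.0, 2.0, 1.0, 0.5, 0.1])
X = rng.normal(size=(100, 6)) * scales + 10.0

# PCA via SVD of the centered data
mean = X.mean(axis=0)
_, s, components = np.linalg.svd(X - mean, full_matrices=False)
explained_variance_ratio = s**2 / np.sum(s**2)

# Same rule as in the notebook: cumulative explained variance at most 90%
selected = np.where(np.cumsum(explained_variance_ratio) <= 0.9)[0]

def reconstruct(x, indices):
    """Project onto the chosen components and map back, adding the mean again."""
    coeffs = (x - mean) @ components[indices].T
    return coeffs @ components[indices] + mean

x = X[0]
full = reconstruct(x, np.arange(len(s)))  # all components: exact up to rounding
partial = reconstruct(x, selected)        # selected components only

error_full = np.linalg.norm(x - full)
error_partial = np.linalg.norm(x - partial)
# Leaving out the mean shifts the reconstruction away from the data
error_without_mean = np.linalg.norm(x - (full - mean))

print(f"Selected component indices: {selected}")
print(f"Residual norm (all components):      {error_full:.2e}")
print(f"Residual norm (selected components): {error_partial:.4f}")
print(f"Residual norm (mean left out):       {error_without_mean:.4f}")
```

Note that the `<= 0.9` rule stops just before the component that would cross 90%, so the retained variance is somewhat below that threshold, and the selection is empty whenever the first component alone explains more than 90%.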



## All environments

In [None]:
X = df['image_path'].tolist()
y = df['category'].tolist()
unique_categories = list(df['category'].unique())
print(f"Unique categories: {unique_categories}")

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

standard_size = (224, 224)

In [None]:
create_images_set(X_train, X_test, y_train, y_test, output_dir_train='images_train', output_dir_test='images_test', standard_size=standard_size)

In [None]:
# Pass the list of unique category names rather than the full label list
all_training_preprocessed_images_by_category = load_and_preprocess_test_images('images_train', unique_categories, image_size=(256,256), input_size=(256,256))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.decomposition import PCA

# Select the available device (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# U-Net model used to encode and decode the images
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()

        # Encoder
        self.encoder1 = self.contracting_block(in_channels, 64)
        self.encoder2 = self.contracting_block(64, 128)
        self.encoder3 = self.contracting_block(128, 256)
        self.encoder4 = self.contracting_block(256, 512)

        # Decoder
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = self.expansive_block(512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = self.expansive_block(256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = self.expansive_block(128, 64)
        self.decoder1 = self.final_block(64, out_channels)

    def contracting_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=3, in_channels=out_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        return block

    def expansive_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=3, in_channels=out_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        return block

    def final_block(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=3, in_channels=out_channels, out_channels=out_channels, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=1, in_channels=out_channels, out_channels=out_channels),
            nn.Sigmoid()  # Sigmoid activation keeps the output within [0, 1]
        )
        return block

    def crop_and_concat(self, upsampled, bypass):
        _, _, H, W = upsampled.size()
        _, _, H_b, W_b = bypass.size()
        if H_b != H or W_b != W:
            bypass = nn.functional.interpolate(bypass, size=(H, W), mode='bilinear', align_corners=True)
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encoder path
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(nn.functional.max_pool2d(enc1, kernel_size=2, stride=2))
        enc3 = self.encoder3(nn.functional.max_pool2d(enc2, kernel_size=2, stride=2))
        enc4 = self.encoder4(nn.functional.max_pool2d(enc3, kernel_size=2, stride=2))

        # Decoder path
        dec4 = self.crop_and_concat(self.upconv4(enc4), enc3)
        dec4 = self.decoder4(dec4)
        dec3 = self.crop_and_concat(self.upconv3(dec4), enc2)
        dec3 = self.decoder3(dec3)
        dec2 = self.crop_and_concat(self.upconv2(dec3), enc1)
        dec2 = self.decoder2(dec2)
        dec1 = self.decoder1(dec2)

        return dec1, enc1, enc2, enc3, enc4  # Return the reconstructed image together with the encoder features

def apply_pca_and_reconstruct(features, pca):
    # Flatten the features while keeping track of the original shape
    batch_size, channels, height, width = features.size()
    flattened_features = features.view(batch_size, -1).cpu().numpy()

    # Report the original dimensionality of the features
    original_num_features = flattened_features.shape[1]
    print(f"Original number of features before PCA: {original_num_features}")

    # Apply PCA projection and reconstruction
    projected_features = pca.transform(flattened_features)
    num_pca_components = projected_features.shape[1]
    print(f"Number of components after PCA: {num_pca_components}")

    reconstructed_features = pca.inverse_transform(projected_features)
    #print(f"Reconstructed feature dimensions (after inverse PCA): {reconstructed_features.shape}")

    # Convert back to tensor and reshape to the original feature shape
    reconstructed_features = torch.tensor(reconstructed_features, dtype=torch.float32).view(batch_size, channels, height, width).to(device)

    return reconstructed_features

# Train the U-Net, then fit a PCA on its bottleneck features and report the explained variance
def train_unet_and_apply_pca(unet, data_loader, num_epochs=5):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(unet.parameters(), lr=0.0001)

    # Collect features locally so that each call (one per category) starts from scratch;
    # a module-level global would keep accumulating features across categories and reruns
    collected_features = []

    for epoch in range(num_epochs):
        unet.train()
        epoch_loss = 0
        for images, in data_loader:
            images = images.to(device).float()
            images /= 255.0

            optimizer.zero_grad()

            # Forward pass through U-Net to get features
            reconstructed_images, _, _, _, unet_features = unet(images)

            # Flatten and collect features for PCA training (in the first epoch)
            if epoch == 0:
                flattened_features = unet_features.view(images.shape[0], -1).detach().cpu().numpy()
                collected_features.append(flattened_features)

            loss = criterion(reconstructed_images, images)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(data_loader):.4f}')

    # Stack the collected features into a (num_samples, num_features) matrix
    pca_train_features = np.vstack(collected_features)

    # Report the number of features before PCA
    original_feature_count = pca_train_features.shape[1]
    print(f"Original number of features before PCA: {original_feature_count}")

    # Fit the PCA and report the explained variance
    pca = PCA(n_components=0.95)  # Retain 95% variance
    pca.fit(pca_train_features)
    print(f"Explained variance by PCA components: {pca.explained_variance_ratio_}")
    print(f"Number of components chosen by PCA: {pca.n_components_}")

    return pca

# Initialize the U-Net model
unet = UNet().to(device)

# Dictionary that stores the fitted PCA for each category
pca_dict = {}

# Build one DataLoader per category, using the images loaded for all environments
# (centralized_images_by_category only holds the two categories from Part II)
categories = all_training_preprocessed_images_by_category.keys()

for category in categories:
    images = all_training_preprocessed_images_by_category[category]

    # Add the channel dimension and create the DataLoader
    images = np.expand_dims(images, axis=1)  # Add the channel dimension
    dataset = TensorDataset(torch.tensor(images, dtype=torch.float32))
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Train the U-Net and collect the features used to fit the PCA
    pca = train_unet_and_apply_pca(unet, loader, num_epochs=5)

    # Store the PCA for the current category
    pca_dict[category] = pca

    # Evaluate and visualize a few images reconstructed with PCA applied to the features
    unet.eval()
    with torch.no_grad():
        for i in range(3):  # Show 3 images per category
            image = all_training_preprocessed_images_by_category[category][i]
            image_tensor = torch.tensor(image / 255.0).unsqueeze(0).unsqueeze(0).float().to(device)

            # Run the U-Net and extract the features
            reconstructed_image, enc1, enc2, enc3, unet_features = unet(image_tensor)

            # Project the features with PCA and reconstruct them
            reconstructed_features = apply_pca_and_reconstruct(unet_features, pca)

            # Pass the reconstructed features through the U-Net decoder path
            dec4 = unet.crop_and_concat(unet.upconv4(reconstructed_features), enc3)
            dec4 = unet.decoder4(dec4)
            dec3 = unet.crop_and_concat(unet.upconv3(dec4), enc2)
            dec3 = unet.decoder3(dec3)
            dec2 = unet.crop_and_concat(unet.upconv2(dec3), enc1)
            dec2 = unet.decoder2(dec2)

            # The final block maps the features to the number of output channels
            final_reconstructed_image = unet.decoder1(dec2)

            # Convert the tensor to NumPy for visualization
            final_reconstructed_image_np = final_reconstructed_image.squeeze().cpu().numpy()

            # Display the images
            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(image, cmap='gray')
            plt.title(f'Original {category} Image {i+1}')
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.imshow(final_reconstructed_image_np, cmap='gray')
            plt.title(f'Reconstructed {category} Image {i+1} with PCA')
            plt.axis('off')

            plt.show()

# The fitted PCA models stay in memory in pca_dict for later use (nothing is written to disk here)
print("PCA models stored in pca_dict for each category.")