In [1]:
import sys
sys.path.append("../src")
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
from data_preprocessing.data_loader import load_sun397_dataset

Load Data:

In [None]:
# batch_size=1: the preview loop further down calls labels.item(), which only
# works when each batch carries exactly one label.
# NOTE(review): load_sun397_dataset is consumed below as an iterable of
# (images, labels) batches — confirm against data_preprocessing.data_loader.
dataloader = load_sun397_dataset(batch_size=1)

Visualize a few samples:

In [None]:
def display_image_with_label(image_tensor, label):
    """Show the first channel of an image tensor as a heatmap titled with its label.

    Parameters
    ----------
    image_tensor : torch.Tensor
        3-D tensor; its first axis is moved to the last position before
        plotting, i.e. it is treated as channels-first. It may require grad
        or live on a non-CPU device.
    label
        Value interpolated into the plot title.

    Returns
    -------
    None
        The figure is shown as a side effect.
    """
    # detach()/cpu() make the NumPy conversion work for tensors that track
    # gradients or sit on an accelerator; plain .numpy() raises for those.
    image_array = image_tensor.detach().cpu().permute(1, 2, 0).numpy()

    # Min-max scale over the whole image (all channels together) into [0, 1].
    lowest = image_array.min()
    value_range = image_array.max() - lowest
    if value_range > 0:
        image_array = (image_array - lowest) / value_range
    else:
        # Constant image: dividing by a zero range would fill the plot with NaN.
        image_array = np.zeros_like(image_array, dtype=float)

    # Only channel 0 is drawn (red, assuming RGB channel order).
    sns.set(style="whitegrid")
    heatmap_data = image_array[:, :, 0]
    ax = sns.heatmap(
        heatmap_data,
        cmap="Reds",
        cbar=True,
        square=True,
        xticklabels=False,
        yticklabels=False,
    )
    ax.set_title(f"Image Label: {label}")
    sns.despine()
    plt.show()
    
# Preview the first three batches: print each label, then plot the sample.
for i, (images, labels) in enumerate(dataloader):
    if i < 3:
        print(f"Label: {labels.item()}")
        display_image_with_label(images[0], label=labels.item())
    else:
        # Three samples shown; stop iterating.
        break

# Basic Training Dataset Summary with Seaborn
def plot_dataset_distribution(dataloader):
    """Count how often each label occurs in ``dataloader`` and draw a bar chart.

    Parameters
    ----------
    dataloader : iterable
        Yields ``(inputs, labels)`` pairs; only ``labels`` is read. Each
        ``labels`` may hold a single class id or a whole batch of them.

    Returns
    -------
    None
        The chart is shown as a side effect. If the loader yields no labels,
        a message is printed and nothing is drawn.
    """
    label_counts = {}
    for _, labels in dataloader:
        # Flatten so that every element of a batch is counted, not just
        # batches that happen to contain exactly one label.
        for label in torch.as_tensor(labels).reshape(-1).tolist():
            label_counts[label] = label_counts.get(label, 0) + 1

    if not label_counts:
        print("No labels found in dataloader; nothing to plot.")
        return

    # Sorted by label so the frame is deterministic regardless of load order.
    label_df = pd.DataFrame(sorted(label_counts.items()), columns=["Label", "Count"])

    # NOTE(review): seaborn >= 0.13 warns when `palette` is passed without
    # `hue`; revisit once the project's minimum seaborn version is known.
    sns.barplot(data=label_df, x="Label", y="Count", palette="muted")
    plt.title("Dataset Label Distribution")
    plt.xlabel("Labels")
    plt.ylabel("Count")
    sns.despine()
    plt.show()


# Plot label distribution
# Note: this iterates over the entire dataloader once to count labels.
plot_dataset_distribution(dataloader)