<a href="https://colab.research.google.com/github/samim23/hyperdimensional_computing_playground/blob/main/Hyperdimensional_Computing_Playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hyperdimensional Computing (HDC) Playground

This notebook is a hands-on exploration of [Hyperdimensional Computing (HDC)](https://en.wikipedia.org/wiki/Hyperdimensional_computing) created by [Samim](https://samim.io) for learning and experimentation purposes.

**In this notebook, you'll find:**

1. A very simple HDC toy example: A quick introduction to the basics of hyperdimensional computing.
2. A simple HDC MNIST encoding and reconstruction example: Demonstrating how hypervectors can represent and reconstruct handwritten digits.
3. A simple HDC TinyImageNet encoding and reconstruction example: Scaling up HDC to handle encoding and reconstruction of images, features and labels from a larger, more complex dataset.
4. An intermediate HDC TinyImageNet encoding and reconstruction example: Leveraging techniques like circular convolution, permutation, and sparse hypervectors for efficiency.

More HDC reading: [Simple HDC Intro](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1012426) & [Classic HDC Paper](https://www.rctn.org/vs265/kanerva09-hyperdimensional.pdf)

<br/>

![HDC](https://samim.io/static/upload/journal.pcbi.1012426.g001.PNG)
![HDC](https://samim.io/static/upload/journal.pcbi.1012426.g002.PNG)


## 1. Toy Example

This example introduces the basics of Hyperdimensional Computing (HDC). Here, we encode "toys" with specific attributes (e.g., "red," "hero," "vehicle") into high-dimensional vectors (hypervectors). By combining attributes through element-wise multiplication (binding), we create unique hypervectors representing different toys.

Key concepts demonstrated:

- Random High-Dimensional Vectors: Attributes like "red" and "hero" are represented as random bipolar hypervectors with entries in {-1, +1}.
- Binding and Encoding: Toys are encoded by combining their attributes using element-wise multiplication, a key operation in HDC.
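- Cosine Similarity for Queries: A query hypervector (e.g., "red") is compared to the database of toys using cosine similarity. Note that binding by multiplication produces a vector that is nearly orthogonal to each of its inputs, so a raw attribute query scores close to 0 against a bound toy; bundling (element-wise addition) is the operation that keeps the result similar to its inputs.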
- Visualization with PCA: The high-dimensional hypervectors are reduced to 2D using PCA for visualization, showing how encoded toys cluster in space.

This example shows how HDC can combine, store, and retrieve structured information using only element-wise arithmetic on random vectors and a similarity measure.
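The difference between binding and bundling is easy to check numerically. The following is a minimal sketch, assuming bipolar hypervectors and a fixed random seed; the helper names `random_hv` and `cosine` are illustrative and not part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is repeatable
D = 10_000                      # same dimensionality as the toy example

def random_hv():
    """Random bipolar hypervector with entries in {-1, +1}."""
    return rng.choice([-1, 1], size=D)

def cosine(a, b):
    """Cosine similarity between two vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

red, hero, vehicle = random_hv(), random_hv(), random_hv()

# Binding (element-wise product): the result is nearly orthogonal to each input.
bound = red * hero
sim_bound = cosine(red, bound)

# Unbinding: multiplying by `hero` again cancels it, because hero * hero == 1.
recovered = bound * hero
sim_recovered = cosine(red, recovered)

# Bundling (element-wise sum): the result stays similar to each input
# and dissimilar to unrelated hypervectors.
bundled = red + hero
sim_bundled = cosine(red, bundled)
sim_unrelated = cosine(vehicle, bundled)

print(f"red vs red*hero (bound):       {sim_bound:+.3f}")
print(f"red vs (red*hero)*hero:        {sim_recovered:+.3f}")
print(f"red vs red+hero (bundled):     {sim_bundled:+.3f}")
print(f"vehicle vs red+hero (bundled): {sim_unrelated:+.3f}")
```

With these settings the bound vector scores near 0 against `red`, the unbound vector recovers `red` exactly, and the bundle scores near 1/sqrt(2) against `red`. A similarity threshold of 0.5 therefore separates members of a bundle from non-members, but not constituents of a bound product.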

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Step 1: Create random high-dimensional vectors
def create_random_vector(dimensions=10000):
    return np.random.choice([-1, 1], size=dimensions)

# Cosine similarity function
def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

# Encode a toy by combining attributes
def encode_toy(*attributes_to_combine):
    encoded_vector = np.ones_like(attributes_to_combine[0])  # Start with a neutral vector
    for attr in attributes_to_combine:
        encoded_vector *= attr  # Element-wise multiplication (binding)
    return encoded_vector

# Query the database for similar vectors
def query_database(query_vector, database, threshold=0.5):
    results = []
    for idx, toy in enumerate(database):
        similarity = cosine_similarity(query_vector, toy)
        if similarity > threshold:  # Matches exceed threshold
            results.append((idx, similarity))
    return results

# Step 2: Create attributes as random high-dimensional vectors
dimensions = 10000
attributes = {
    "red": create_random_vector(dimensions),
    "blue": create_random_vector(dimensions),
    "action_figure": create_random_vector(dimensions),
    "hero": create_random_vector(dimensions),
    "vehicle": create_random_vector(dimensions)
}

# Step 3: Encode some toys
red_hero_action_figure = encode_toy(attributes["red"], attributes["hero"], attributes["action_figure"])
blue_vehicle = encode_toy(attributes["blue"], attributes["vehicle"])
red_vehicle = encode_toy(attributes["red"], attributes["vehicle"])

# Store toys in a database
toy_database = [red_hero_action_figure, blue_vehicle, red_vehicle]

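# Step 4: Query for red toys
# Note: binding by multiplication yields vectors nearly orthogonal to each attribute,
# so a raw attribute query scores close to 0 and typically returns no matches here.
red_query = attributes["red"]
matches = query_database(red_query, toy_database, threshold=0.5)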

# Step 5: Visualize the high-dimensional space using PCA
# Reduce dimensionality for visualization
pca = PCA(n_components=2)
data_matrix = np.vstack([red_query] + toy_database)  # Combine query and database
reduced_data = pca.fit_transform(data_matrix)

# Plot the reduced vectors
plt.figure(figsize=(10, 8))
plt.scatter(reduced_data[1:, 0], reduced_data[1:, 1], color='blue', label='Toys')
plt.scatter(reduced_data[0, 0], reduced_data[0, 1], color='red', label='Red Query', s=100)
for i, (x, y) in enumerate(reduced_data[1:], start=1):
    plt.text(x, y, f"Toy {i}", fontsize=10, color="blue")

plt.title("Visualization of High-Dimensional Vectors in 2D Space")
plt.xlabel("PCA Dimension 1")
plt.ylabel("PCA Dimension 2")
plt.legend()
plt.grid(True)
plt.show()

# Print query results
if not matches:
    print("No toys exceeded the similarity threshold.")
for idx, similarity in matches:
    print(f"Toy {idx + 1} matches the query with similarity: {similarity:.2f}")

## 2. Hyperdimensional Computing for Image Reconstruction with MNIST

This example pairs HDC-style encoding with a neural autoencoder on the MNIST dataset. The cell defines a position-and-brightness hypervector encoding for images, but the training loop itself uses a learned encoder-decoder network: the encoder maps each image to a 20,000-dimensional vector bounded to [-1, 1] by `tanh`, and the decoder reconstructs the image from it. Reconstruction quality is measured with Mean Squared Error (MSE) and the Structural Similarity Index (SSIM). A latent space walk linearly interpolates between the encodings of two test images and decodes each step into a GIF frame.
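To make the position-and-brightness encoding concrete, here is a minimal NumPy sketch on a tiny image. It assumes bipolar {-1, +1} hypervectors (the notebook's `create_random_hv` draws from {-1, 0, 1}), a 4x4 image instead of 28x28, and no `sign` quantization; the names `max_error` and `decoded` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 20_000   # hypervector dimensionality, matching `dimensions` in this section
H, W = 4, 4  # tiny stand-in for a 28x28 image (assumption for brevity)

position_hvs = rng.choice([-1.0, 1.0], size=(H * W, D))  # one hypervector per pixel
brightness_hv = rng.choice([-1.0, 1.0], size=D)          # shared brightness role

image = rng.random((H, W))  # pixel intensities in [0, 1)
pixels = image.reshape(-1)

# Encode: bind each position hypervector to the brightness hypervector,
# weight it by the pixel intensity, and sum over all pixels.
encoded = pixels @ (position_hvs * brightness_hv)

# Decode: unbind the brightness role, then project onto each position hypervector.
# Random hypervectors are nearly orthogonal, so cross-talk between pixels is small.
decoded = (position_hvs @ (encoded * brightness_hv)) / D

max_error = float(np.abs(decoded - pixels).max())
print(f"max abs reconstruction error: {max_error:.3f}")
```

The readout works because the intensity-weighted sum is kept as real numbers. Quantizing with `sign`, as `encode_image_with_position_and_brightness` does, discards magnitude information, so this projection would no longer return intensities on their original scale.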

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import mean_squared_error
import imageio
from IPython.display import display, HTML, Image  # Image is used to show the GIF inline

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
dimensions = 20000  # Dimensionality of hypervectors
image_shape = (28, 28)  # Image size
batch_size = 32
learning_rate = 0.001
epochs = 20

# Step 1: Random High-Dimensional Vectors
def create_random_hv(dimensions):
    """Create a random ternary hypervector with entries in {-1, 0, 1}."""
    return torch.randint(-1, 2, (dimensions,), device=device, dtype=torch.float32)

def generate_position_hvs(image_shape, dimensions):
    """Generate positional hypervectors for each pixel."""
    return torch.stack([create_random_hv(dimensions) for _ in range(image_shape[0] * image_shape[1])])

def encode_image_with_position_and_brightness(image, position_hvs, brightness_hv):
    """Encode an image using position and brightness hypervectors."""
    image_flat = image.flatten().to(device)  # Flatten the image
    encoded_hv = torch.zeros(dimensions, device=device)
    for i, intensity in enumerate(image_flat):
        encoded_hv += intensity * position_hvs[i] * brightness_hv
    return torch.sign(encoded_hv)

# Generate position and brightness hypervectors
position_hvs = generate_position_hvs(image_shape, dimensions)
brightness_hv = create_random_hv(dimensions)

# Step 2: Dataset Preparation
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

train_dataset_small = torch.utils.data.Subset(train_dataset, range(1000))
test_dataset_small = torch.utils.data.Subset(test_dataset, range(100))

train_loader = DataLoader(train_dataset_small, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset_small, batch_size=batch_size, shuffle=False)

# Step 3: Encoder and Decoder Networks
class Encoder(nn.Module):
    def __init__(self, dimensions):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, dimensions)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return x

class Decoder(nn.Module):
    def __init__(self, dimensions):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(dimensions, 512)
        self.fc2 = nn.Linear(512, 28 * 28)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = x.view(-1, 1, 28, 28)
        return x

encoder = Encoder(dimensions).to(device)
decoder = Decoder(dimensions).to(device)

# Optimizers
optimizer_encoder = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

# Mixed Precision Training
scaler = torch.amp.GradScaler()

# Step 4: Training Loop
def train(encoder, decoder, train_loader, optimizer_encoder, optimizer_decoder, scaler):
    encoder.train()
    decoder.train()

    for epoch in range(epochs):
        epoch_loss = 0
        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)

            # Forward pass with mixed precision
            with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):  # Mixed precision on GPU only
                encoded_hvs = encoder(images)
                reconstructed_images = decoder(encoded_hvs)
                loss = F.mse_loss(reconstructed_images, images)

            # Backpropagation
            optimizer_encoder.zero_grad()
            optimizer_decoder.zero_grad()
            scaler.scale(loss).backward()

            # GradScaler.step() skips the update on its own when it finds inf/NaN gradients
            scaler.step(optimizer_encoder)
            scaler.step(optimizer_decoder)
            scaler.update()

            epoch_loss += loss.item()

            if batch_idx % 10 == 0:  # Feedback every 10 batches
                print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")

        print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {epoch_loss / len(train_loader):.4f}")

train(encoder, decoder, train_loader, optimizer_encoder, optimizer_decoder, scaler)

# Step 5: Testing and Metrics
def test(encoder, decoder, test_loader):
    encoder.eval()
    decoder.eval()

    mse_scores = []
    ssim_scores = []

    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)

            # Encode and decode
            encoded_hvs = encoder(images)
            reconstructed_images = decoder(encoded_hvs)

            # Move to CPU for metrics
            reconstructed_images_np = reconstructed_images.cpu().numpy()
            images_np = images.cpu().numpy()

            # Compute metrics
            for i in range(images_np.shape[0]):
                mse = mean_squared_error(images_np[i].flatten(), reconstructed_images_np[i].flatten())
                mse_scores.append(mse)

                ssim_score = ssim(
                    images_np[i].squeeze(), reconstructed_images_np[i].squeeze(), data_range=1.0
                )
                ssim_scores.append(ssim_score)

    print(f"Average MSE: {np.mean(mse_scores):.4f}")
    print(f"Average SSIM: {np.mean(ssim_scores):.4f}")
    return mse_scores, ssim_scores

mse_scores, ssim_scores = test(encoder, decoder, test_loader)

# Visualization: Original vs Reconstructed
def visualize_reconstruction(encoder, decoder, test_loader):
    encoder.eval()
    decoder.eval()

    images, _ = next(iter(test_loader))
    images = images.to(device)

    with torch.no_grad():
        encoded_hvs = encoder(images)
        reconstructed_images = decoder(encoded_hvs)

    reconstructed_images_np = reconstructed_images.cpu().numpy()
    images_np = images.cpu().numpy()

    # Plot original and reconstructed images
    for i in range(5):
        plt.figure(figsize=(6, 3))
        # Original
        plt.subplot(1, 2, 1)
        plt.imshow(images_np[i].squeeze(), cmap="gray")
        plt.title("Original Image")
        plt.axis("off")

        # Reconstructed
        plt.subplot(1, 2, 2)
        plt.imshow(reconstructed_images_np[i].squeeze(), cmap="gray")
        plt.title("Reconstructed Image")
        plt.axis("off")
        plt.show()

visualize_reconstruction(encoder, decoder, test_loader)

# Latent Space Exploration with GIF Display
def walk_latent_space(encoder, decoder, test_loader):
    encoder.eval()
    decoder.eval()

    with torch.no_grad():
        images, _ = next(iter(test_loader))
        images = images.to(device)

        hv_start = encoder(images[0:1])
        hv_end = encoder(images[1:2])

        steps = 20
        latent_walk = [
            hv_start + t / steps * (hv_end - hv_start)
            for t in range(steps + 1)
        ]

        interpolated_images = [decoder(hv).squeeze(0).cpu().numpy() for hv in latent_walk]

        # Save GIF without scaling
        gif_filename = "latent_walk.gif"
        with imageio.get_writer(gif_filename, mode="I", duration=0.2, loop=0) as writer:  # `loop=0` ensures infinite looping
            for img in interpolated_images:
                img = (img * 255).astype(np.uint8).squeeze()
                writer.append_data(img)

        print(f"Latent walk GIF saved as {gif_filename}")

        # Display the GIF inline in the notebook
        display(Image(filename=gif_filename))


walk_latent_space(encoder, decoder, test_loader)

## 3. Hyperdimensional Computing for Image Reconstruction with TinyImageNet: Simple

This example applies HDC to encoding and reconstructing images from the TinyImageNet dataset. Each image is encoded as a hypervector from positional and brightness components, each class label is assigned its own random hypervector, and the two are bundled by addition. A neural decoder reconstructs the image from the combined hypervector, and quality is evaluated with MSE and SSIM on the first test batch. The first few reconstructions are also saved as a GIF.
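Bundling an image hypervector with a label hypervector keeps the label recoverable by a similarity search over the label codebook. The sketch below illustrates this with NumPy, assuming bipolar hypervectors and a random stand-in for the encoded image; `true_label`, `scores`, and `predicted` are illustrative names.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 10_000         # matches `dimensions` in this section
num_classes = 200  # TinyImageNet has 200 classes

label_hvs = rng.choice([-1.0, 1.0], size=(num_classes, D))  # one hypervector per class
image_hv = rng.choice([-1.0, 1.0], size=D)                  # stand-in for an encoded image

true_label = 42
# Bundling by addition, analogous to `encoded_images + encoded_labels` in the training loop.
combined = image_hv + label_hvs[true_label]

# Normalized dot product of the bundle with every class hypervector.
# The bundled class scores close to 1, all other classes close to 0.
scores = label_hvs @ combined / D
predicted = int(np.argmax(scores))

print(f"predicted label: {predicted}, score: {scores[predicted]:.3f}")
print(f"largest score among other classes: {np.delete(scores, true_label).max():.3f}")
```

The same lookup does not work after binding, since a bound product is dissimilar to both of its inputs and must be unbound first.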

In [None]:
import os
import requests
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
from IPython.display import display, Image

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Configuration
use_subset = True  # Toggle this to use a subset for training and testing
dimensions = 10000
image_shape = (64, 64, 3)  # Image dimensions: height, width, channels
batch_size = 32
learning_rate = 0.001
epochs = 150

# Download and Extract TinyImageNet
def download_and_extract_tinyimagenet():
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    dataset_folder = "./tiny-imagenet-200"
    zip_filename = "tiny-imagenet-200.zip"

    if not os.path.exists(dataset_folder):
        print("Downloading TinyImageNet...")
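        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()  # Fail early on HTTP errors
        with open(zip_filename, "wb") as f:
            # Stream to disk in 1 MiB chunks instead of loading the whole archive into memory
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)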
        print("Download complete. Extracting...")
        with zipfile.ZipFile(zip_filename, "r") as zip_ref:
            zip_ref.extractall("./")
        print("Extraction complete.")
        os.remove(zip_filename)
    else:
        print("TinyImageNet already downloaded and extracted.")

download_and_extract_tinyimagenet()

# Dataset Preparation
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

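# Use TinyImageNet Dataset
# Note: the val split keeps every image under val/images (labels live in val_annotations.txt),
# so ImageFolder assigns all validation images to a single class rather than their true labels.
train_dataset = datasets.ImageFolder(root='./tiny-imagenet-200/train', transform=transform)
test_dataset = datasets.ImageFolder(root='./tiny-imagenet-200/val', transform=transform)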

# Optional Subset
if use_subset:
    train_dataset = torch.utils.data.Subset(train_dataset, range(2000))
    test_dataset = torch.utils.data.Subset(test_dataset, range(500))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Hyperdimensional Computing
def create_random_hv(dimensions):
    """Create a random ternary hypervector with entries in {-1, 0, 1}."""
    return torch.randint(-1, 2, (dimensions,), device=device, dtype=torch.float32)

def generate_position_hvs(image_shape, dimensions):
    """Generate positional hypervectors for each pixel across all channels."""
    total_pixels = image_shape[0] * image_shape[1] * image_shape[2]  # Include channels
    return torch.stack([create_random_hv(dimensions) for _ in range(total_pixels)])

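def encode_image_with_position_and_brightness_optimized(image, position_hvs, brightness_hv):
    """Encode an image as the sign of an intensity-weighted sum of position hypervectors bound to the brightness hypervector."""
    image_flat = image.flatten().to(device)
    # Apply brightness_hv after the projection: the result is the same, but this avoids
    # materializing a second (pixels x dimensions) tensor on every call.
    encoded_hv = (image_flat @ position_hvs) * brightness_hv
    return torch.sign(encoded_hv)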

# Generate position and brightness hypervectors
position_hvs = generate_position_hvs(image_shape, dimensions)
brightness_hv = create_random_hv(dimensions)

# Label Encoding
num_classes = len(getattr(train_dataset, "dataset", train_dataset).classes)  # Works with or without the Subset wrapper
label_hvs = torch.stack([create_random_hv(dimensions) for _ in range(num_classes)], dim=0)

def encode_label_vectorized(labels):
    """Encode labels using precomputed label hypervectors (vectorized)."""
    return label_hvs[labels]

# Decoder
class Decoder(nn.Module):
    def __init__(self, dimensions):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(dimensions, 512)
        self.fc2 = nn.Linear(512, 64 * 64 * 3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = x.view(-1, 3, 64, 64)  # Reshape to image
        return x

decoder = Decoder(dimensions).to(device)
optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

# train_and_test function
def train_and_test(train_loader, test_loader):
    for epoch in range(epochs):
        decoder.train()
        epoch_loss = 0

        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            # Encode images and labels
            encoded_images = torch.stack([
                encode_image_with_position_and_brightness_optimized(image, position_hvs, brightness_hv)
                for image in images
            ])
            encoded_labels = encode_label_vectorized(labels)
            encoded_hvs = encoded_images + encoded_labels

            reconstructed_images = decoder(encoded_hvs)
            loss = F.mse_loss(reconstructed_images, images)

            optimizer_decoder.zero_grad()
            loss.backward()
            optimizer_decoder.step()

            epoch_loss += loss.item()
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")

            # Free memory
            torch.cuda.empty_cache()

        print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {epoch_loss / len(train_loader):.4f}")

    # Enhanced Testing
    decoder.eval()
    mse_scores = []
    ssim_scores = []
    latent_walk_gif_filename = "latent_space_exploration.gif"
    gif_images = []

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(test_loader):
            images, labels = images.to(device), labels.to(device)

            encoded_images = torch.stack([
                encode_image_with_position_and_brightness_optimized(image, position_hvs, brightness_hv)
                for image in images
            ])
            encoded_labels = encode_label_vectorized(labels)
            encoded_hvs = encoded_images + encoded_labels

            reconstructed_images = decoder(encoded_hvs)

            # Metrics and visualization
            for i in range(images.size(0)):
                original_image = images[i].cpu().numpy().transpose(1, 2, 0)
                reconstructed_image = reconstructed_images[i].cpu().numpy().transpose(1, 2, 0)

                mse = mean_squared_error(original_image.flatten(), reconstructed_image.flatten())
                mse_scores.append(mse)

                # Specify win_size and channel_axis for SSIM
                # (channel_axis replaces the deprecated multichannel argument)
                similarity = ssim(
                    original_image,
                    reconstructed_image,
                    data_range=1.0,
                    win_size=3,
                    channel_axis=-1
                )
                ssim_scores.append(similarity)

                # Save first few examples for GIF creation
                if batch_idx == 0 and i < 5:
                    gif_images.append((reconstructed_image * 255).astype(np.uint8))

                # Display the first batch for class-wise reconstruction
                if batch_idx == 0 and i < 5:
                    plt.figure(figsize=(6, 3))
                    plt.subplot(1, 2, 1)
                    plt.imshow(original_image)
                    plt.title("Original Image")
                    plt.axis("off")

                    plt.subplot(1, 2, 2)
                    plt.imshow(reconstructed_image)
                    plt.title("Reconstructed Image")
                    plt.axis("off")
                    plt.show()

            # Only the first test batch is evaluated, so the averages below cover one batch
            if batch_idx == 0:
                break

    print(f"Average MSE: {np.mean(mse_scores):.4f}")
    print(f"Average SSIM: {np.mean(ssim_scores):.4f}")

    # Save the first few reconstructions as a GIF (no latent-space interpolation is performed here)
    try:
        import imageio
        imageio.mimsave(latent_walk_gif_filename, gif_images, fps=2)
        print(f"Latent space exploration GIF saved as {latent_walk_gif_filename}")
    except ImportError:
        print("Install imageio for GIF generation.")

# Run the training and testing
train_and_test(train_loader, test_loader)



## 4. Hyperdimensional Computing for Image Reconstruction with TinyImageNet: Complex

This example implements an HDC-based encoder with a neural decoder for image reconstruction on the TinyImageNet dataset. Images are encoded into hypervectors by combining positional and brightness information: position hypervectors are cyclic shifts (permutations) of a single sparse base hypervector, each of the 256 brightness levels has its own sparse hypervector, and the encoded image is bound to its label hypervector by circular convolution. A neural network then decodes the resulting hypervector back into an image.

Training uses data augmentation, dropout regularization, noise injection on the encoded hypervectors, a step learning-rate schedule, and mixed precision. The model is evaluated with MSE, SSIM, and PSNR on the first test batch, and the results are visualized by comparing original and reconstructed images and by saving a GIF of the reconstructions. The configuration favors short run times (5,000 dimensions, 30 epochs, a 500-image training subset), which makes it convenient for exploring HDC principles in image processing tasks.
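The two HDC operations named above can be tried in isolation. The following is a minimal NumPy sketch, assuming dense Gaussian hypervectors with variance 1/D (as in holographic reduced representations) rather than the sparse ternary vectors used in the cell below; `circular_correlation` is a helper introduced here for illustration and does not appear in the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 5_000  # matches `dimensions` in this section

def circular_convolution(a, b):
    """Bind two hypervectors via circular convolution, computed with the FFT."""
    return np.fft.irfft(np.fft.rfft(a) * np.fft.rfft(b), n=D)

def circular_correlation(a, b):
    """Approximate unbinding: correlate `a` against `b` to recover what was bound to `a`."""
    return np.fft.irfft(np.conj(np.fft.rfft(a)) * np.fft.rfft(b), n=D)

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

x = rng.normal(0.0, 1.0 / np.sqrt(D), D)
y = rng.normal(0.0, 1.0 / np.sqrt(D), D)

# Binding: the bound vector resembles neither input.
bound = circular_convolution(x, y)
sim_bound_x = cosine(x, bound)

# Unbinding with x yields a noisy but clearly recognizable copy of y.
y_hat = circular_correlation(x, bound)
sim_recovered = cosine(y, y_hat)

# Permutation by cyclic shift: a shifted vector is dissimilar to the original,
# and shifting by a multiple of D returns the original vector exactly.
sim_shift = cosine(x, np.roll(x, 1))
wraps_around = bool(np.array_equal(np.roll(x, D), x))

print(f"x vs x (*) y:          {sim_bound_x:+.3f}")
print(f"y vs unbound estimate: {sim_recovered:+.3f}")
print(f"x vs roll(x, 1):       {sim_shift:+.3f}")
print(f"roll(x, D) == x:       {wraps_around}")
```

The wrap-around in the last line matters for position encoding: when the number of positions exceeds the dimensionality, cyclic shifts of a single base vector start to repeat.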

In [None]:
import os
import requests
import zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm
import imageio
from torch.cuda.amp import GradScaler, autocast

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Configuration
use_subset = True  # Toggle this to use a subset for training and testing
dimensions = 5000  # Reduced dimensionality for faster computation
image_shape = (64, 64, 3)  # Image dimensions: height, width, channels
batch_size = 16
learning_rate = 0.001
epochs = 30  # Reduced epochs for faster testing
noise_level = 0.1  # Noise for robustness
brightness_levels = 256  # 8-bit grayscale
sparsity = 0.1  # Sparsity level for hypervectors

# Download and Extract TinyImageNet
def download_and_extract_tinyimagenet():
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    dataset_folder = "./tiny-imagenet-200"
    zip_filename = "tiny-imagenet-200.zip"

    if not os.path.exists(dataset_folder):
        print("Downloading TinyImageNet...")
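        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()  # Fail early on HTTP errors
        with open(zip_filename, "wb") as f:
            # Stream to disk in 1 MiB chunks instead of loading the whole archive into memory
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)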
        print("Download complete. Extracting...")
        with zipfile.ZipFile(zip_filename, "r") as zip_ref:
            zip_ref.extractall("./")
        print("Extraction complete.")
        os.remove(zip_filename)
    else:
        print("TinyImageNet already downloaded and extracted.")

download_and_extract_tinyimagenet()

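# Dataset Preparation with Data Augmentation (training split only)
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),  # Data augmentation
    transforms.RandomRotation(10),  # Data augmentation
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])
# Evaluation uses the same resizing and normalization, without random augmentation
test_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

# Use TinyImageNet Dataset
# Note: the val split keeps every image under val/images (labels live in val_annotations.txt),
# so ImageFolder assigns all validation images to a single class rather than their true labels.
train_dataset = datasets.ImageFolder(root='./tiny-imagenet-200/train', transform=train_transform)
test_dataset = datasets.ImageFolder(root='./tiny-imagenet-200/val', transform=test_transform)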

# Optional Subset
if use_subset:
    train_dataset = torch.utils.data.Subset(train_dataset, range(500))  # Smaller subset
    test_dataset = torch.utils.data.Subset(test_dataset, range(100))

# Increase number of workers in DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

# Hyperdimensional Computing
def create_random_hv(dimensions):
    """Create a random hypervector with ternary values {-1, 0, 1}."""
    return torch.randint(-1, 2, (dimensions,), device=device, dtype=torch.float32)

def create_sparse_hv(dimensions, sparsity=0.1):
    """Create a sparse hypervector whose nonzero fraction equals `sparsity`, with nonzero entries in {-1, +1}."""
    num_nonzero = int(dimensions * sparsity)
    hv = torch.zeros(dimensions, device=device)
    nonzero_indices = torch.randperm(dimensions, device=device)[:num_nonzero]
    # Draw from {-1, +1} only, so every selected position is actually nonzero
    hv[nonzero_indices] = torch.randint(0, 2, (num_nonzero,), device=device, dtype=torch.float32) * 2 - 1
    return hv

def circular_convolution(hv1, hv2):
    """Circular convolution for binding hypervectors."""
    return torch.fft.irfft(torch.fft.rfft(hv1) * torch.fft.rfft(hv2), n=dimensions)

def permute(hv, shift):
    """Cyclic shift (permutation) of a hypervector."""
    return torch.roll(hv, shifts=shift, dims=0)

def generate_position_hvs(image_shape, dimensions):
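    """Generate positional hypervectors as cyclic shifts of one sparse base hypervector.

    Shifts repeat modulo `dimensions`, so positions i and i + dimensions receive identical
    hypervectors whenever total_pixels exceeds dimensions (12,288 vs 5,000 with the settings above).
    """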
    base_hv = create_sparse_hv(dimensions, sparsity)
    total_pixels = image_shape[0] * image_shape[1] * image_shape[2]  # Include channels
    position_hvs = torch.stack([permute(base_hv, i) for i in range(total_pixels)])
    return position_hvs

def encode_batch_with_position_and_brightness(images, position_hvs, brightness_hvs):
    """Encode a batch of images with dense, broadcast tensor operations (memory grows as batch x pixels x dimensions)."""
    batch_size, channels, height, width = images.shape
    images_flat = images.view(batch_size, -1).to(device)  # Flatten images
    intensity_indices = ((images_flat + 1) * (brightness_levels - 1) / 2).long()  # Scale to brightness levels
    brightness_hvs_selected = brightness_hvs[intensity_indices]  # Select brightness hypervectors
    encoded_hvs = torch.sum(brightness_hvs_selected * position_hvs, dim=1)  # Vectorized sum
    return torch.sign(encoded_hvs)  # Quantize to ternary

# Generate position and brightness hypervectors
position_hvs = generate_position_hvs(image_shape, dimensions)
brightness_hvs = torch.stack([create_sparse_hv(dimensions, sparsity) for _ in range(brightness_levels)])

# Label Encoding
num_classes = len(getattr(train_dataset, "dataset", train_dataset).classes)  # Works with or without the Subset wrapper
label_hvs = torch.stack([create_sparse_hv(dimensions, sparsity) for _ in range(num_classes)], dim=0)

def encode_label_vectorized(labels):
    """Encode labels using precomputed label hypervectors (vectorized)."""
    return label_hvs[labels]

# Hybrid Decoder with Dropout
class Decoder(nn.Module):
    def __init__(self, dimensions):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(dimensions, 512)
        self.dropout = nn.Dropout(0.5)  # Add dropout
        self.fc2 = nn.Linear(512, 64 * 64 * 3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)  # Apply dropout
        x = torch.tanh(self.fc2(x))  # Output in [-1, 1]
        x = x.view(-1, 3, 64, 64)  # Reshape to image
        return x

decoder = Decoder(dimensions).to(device)
optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer_decoder, step_size=10, gamma=0.5)  # Learning rate scheduler
scaler = GradScaler()  # For mixed precision training

# Training and Testing
def train_and_test(train_loader, test_loader):
    for epoch in range(epochs):
        decoder.train()
        epoch_loss = 0

        for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            images, labels = images.to(device), labels.to(device)

            # Encode images and labels
            with autocast():
                encoded_images = encode_batch_with_position_and_brightness(images, position_hvs, brightness_hvs)
                encoded_labels = encode_label_vectorized(labels)
                encoded_hvs = circular_convolution(encoded_images, encoded_labels)  # Use circular convolution

                # Add noise for robustness
                encoded_hvs = encoded_hvs + torch.randn_like(encoded_hvs) * noise_level

                # Forward pass
                reconstructed_images = decoder(encoded_hvs)
                loss = F.mse_loss(reconstructed_images, images)

            # Backward pass
            optimizer_decoder.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer_decoder)
            scaler.update()

            epoch_loss += loss.item()

            # Free memory
            torch.cuda.empty_cache()

        scheduler.step()  # Update learning rate
        print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {epoch_loss / len(train_loader):.4f}")

    # Enhanced Testing
    decoder.eval()
    mse_scores = []
    ssim_scores = []
    psnr_scores = []
    gif_images = []

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(tqdm(test_loader, desc="Testing")):
            images, labels = images.to(device), labels.to(device)

            # Encode images and labels
            encoded_images = encode_batch_with_position_and_brightness(images, position_hvs, brightness_hvs)
            encoded_labels = encode_label_vectorized(labels)
            encoded_hvs = circular_convolution(encoded_images, encoded_labels)  # Use circular convolution

            # Reconstruct images
            reconstructed_images = decoder(encoded_hvs)

            # Metrics and visualization
            for i in range(images.size(0)):
                original_image = images[i].cpu().numpy().transpose(1, 2, 0)
                reconstructed_image = reconstructed_images[i].cpu().numpy().transpose(1, 2, 0)

                mse = mean_squared_error(original_image.flatten(), reconstructed_image.flatten())
                mse_scores.append(mse)

                # Specify win_size and channel_axis for SSIM
                # (channel_axis replaces the deprecated multichannel argument)
                similarity = ssim(
                    original_image,
                    reconstructed_image,
                    data_range=2.0,  # Range is 2 for [-1, 1]
                    win_size=3,
                    channel_axis=-1
                )
                ssim_scores.append(similarity)

                # Compute PSNR with peak value 2.0, matching the [-1, 1] data range used for SSIM
                psnr = 10 * np.log10((2.0 ** 2) / mse)
                psnr_scores.append(psnr)

                # Save first few examples for GIF creation
                if batch_idx == 0 and i < 5:
                    gif_images.append(((reconstructed_image + 1) * 127.5).astype(np.uint8))  # Scale to [0, 255]

                # Display the first batch for class-wise reconstruction
                if batch_idx == 0 and i < 5:
                    plt.figure(figsize=(6, 3))
                    plt.subplot(1, 2, 1)
                    plt.imshow((original_image + 1) / 2)  # Scale to [0, 1] for display
                    plt.title("Original Image")
                    plt.axis("off")

                    plt.subplot(1, 2, 2)
                    plt.imshow((reconstructed_image + 1) / 2)  # Scale to [0, 1] for display
                    plt.title("Reconstructed Image")
                    plt.axis("off")
                    plt.show()

            if batch_idx == 0:
                break

    print(f"Average MSE: {np.mean(mse_scores):.4f}")
    print(f"Average SSIM: {np.mean(ssim_scores):.4f}")
    print(f"Average PSNR: {np.mean(psnr_scores):.4f}")

    # Save the first few reconstructions as a GIF (no latent-space interpolation is performed here)
    latent_walk_gif_filename = "latent_space_exploration.gif"
    imageio.mimsave(latent_walk_gif_filename, gif_images, fps=2)
    print(f"Latent space exploration GIF saved as {latent_walk_gif_filename}")

# Run the training and testing
train_and_test(train_loader, test_loader)