Below are some helper functions:

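- **`load_mnist_data()`**: Loads the 28x28 MNIST images (train and test splits), scales pixel values to [0, 1], and flattens each image into a 784-dimensional vector. Returns `train_images, train_labels, test_images, test_labels`.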
- **`show_image_from_flattened(flattened_data, index, label=None)`**: Reshapes row `index` of `flattened_data` back into a 28x28 image and displays it in grayscale, titled with `label` if one is given.

We define these functions and load the data in the following block.

In [None]:
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

# Load MNIST dataset and flatten it
def _load_mnist_split(root, train, download):
    """Return ``(images, labels)`` for one MNIST split, images flattened and scaled to [0, 1]."""
    dataset = torchvision.datasets.MNIST(root=root, train=train, download=download)
    # The raw pixel tensor is read directly via `.data` rather than through
    # dataset[i], so a per-sample transform (e.g. ToTensor) would never run on
    # this path; the [0, 255] -> [0, 1] scaling is therefore done by hand.
    images = dataset.data.float() / 255.0
    # [num_samples, 28, 28] -> [num_samples, 784]
    images = images.flatten(start_dim=1)
    return images, dataset.targets


def load_mnist_data(root='./data', download=True):
    """Load the MNIST train and test splits as flattened, [0, 1]-scaled tensors.

    Args:
        root: Directory torchvision uses to store / look for the MNIST files.
        download: If True, let torchvision download the dataset into ``root``
            when it is not already present.

    Returns:
        A tuple ``(train_images, train_labels, test_images, test_labels)``.
        Image tensors have shape ``[num_samples, 784]`` with float values in
        [0, 1]; label tensors are the datasets' ``targets``.
    """
    train_images, train_labels = _load_mnist_split(root, train=True, download=download)
    test_images, test_labels = _load_mnist_split(root, train=False, download=download)
    return train_images, train_labels, test_images, test_labels

# Function to show an image from flattened data
def show_image_from_flattened(flattened_data, index, label=None, image_shape=(28, 28)):
    """Display one flattened image from ``flattened_data`` as a grayscale picture.

    Args:
        flattened_data: Indexable collection of flattened images, e.g. the
            ``[num_samples, 784]`` tensors returned by ``load_mnist_data``.
            Torch tensors and NumPy arrays are both accepted.
        index: Which row of ``flattened_data`` to display.
        label: Optional value shown in the title as ``"Label: <label>"``.
            0-d tensors / NumPy scalars are unwrapped with ``.item()`` so the
            title reads ``Label: 5`` rather than ``Label: tensor(5)``.
        image_shape: Shape each row is restored to before plotting;
            defaults to MNIST's 28x28.
    """
    # reshape (unlike view) also handles non-contiguous tensors and NumPy
    # arrays, where ndarray.view(28, 28) would be misread as a dtype argument.
    image = flattened_data[index].reshape(image_shape)
    # matplotlib converts its input via NumPy, which fails for tensors that
    # require grad or live on a GPU; move such tensors to a plain CPU copy.
    if hasattr(image, "detach"):
        image = image.detach().cpu()

    plt.imshow(image, cmap='gray')
    if label is not None:
        if getattr(label, "ndim", None) == 0:
            label = label.item()
        plt.title(f"Label: {label}")
    plt.axis('off')
    plt.show()


# Fetch both MNIST splits once and expose the four tensors as module-level
# names for the code that follows.
(train_images, train_labels,
 test_images, test_labels) = load_mnist_data()

Your code begins here