## MNIST Metadata
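This notebook generates per-image metadata (channel count, height, width and label) for the MNIST train and test splits. The metadata is used to inspect dataset statistics (class balance, images with unexpected shapes) and to assign stratified cross-validation folds to the training set.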

In [None]:
import os
import sys

# Make the project package importable from this notebook. The repository root
# can be overridden with the ADVANCED_NN_ROOT environment variable; otherwise
# the hard-coded checkout location below is used.
PROJECT_ROOT = os.environ.get(
    "ADVANCED_NN_ROOT", "/Users/ishaanroy/Projects/advanced-neural-networks"
)
# Only append once, so re-running this cell does not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import numpy as np
import pandas as pd
import os
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

import advanced_neural_networks
from advanced_neural_networks.dataloader.mnist import MNISTDataset

import warnings
# Silence every warning category for the rest of the session (note: this also
# hides pandas/torch deprecation notices).
warnings.simplefilter("ignore")

In [None]:
# Seed NumPy's legacy global RNG; stratify_kfolds shuffles with np.random,
# so this makes the fold assignment reproducible.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

In [None]:
def partition_array(arr: np.ndarray, n_partitions: int) -> list:
    """Split ``arr`` into ``n_partitions`` contiguous, near-equal chunks.

    Thin wrapper around ``np.array_split``: when ``len(arr)`` is not divisible
    by ``n_partitions`` the leading chunks get one extra element each, and
    chunks are empty when ``n_partitions > len(arr)``.
    """
    return np.array_split(arr, n_partitions)


def stratify_kfolds(metadata_df: pd.DataFrame, kfolds: int, rng=None):
    """Assign a stratified k-fold id to every row of ``metadata_df``.

    Rows are grouped by their ``"label"`` value. Each group's index labels are
    shuffled and split into ``kfolds`` near-equal parts, and the rows in part
    ``i`` get ``fold == i``. Every fold therefore holds roughly the same share
    of each label.

    Args:
        metadata_df: Frame with a ``"label"`` column. It is modified in place
            (a ``"fold"`` column is added or overwritten) and also returned.
        kfolds: Number of folds; ``np.array_split`` raises ``ValueError`` if
            it is < 1.
        rng: Optional random source with a ``shuffle`` method, e.g. a
            ``np.random.Generator`` or ``np.random.RandomState``. When ``None``
            (default) the global ``np.random`` state is used, so the result
            depends on any earlier ``np.random.seed`` call.

    Returns:
        The same ``metadata_df`` object with an integer ``"fold"`` column.
    """
    shuffle = np.random.shuffle if rng is None else rng.shuffle

    metadata_df["fold"] = -1
    # Visit labels in sorted order so the RNG is consumed in a fixed sequence.
    for label in sorted(metadata_df["label"].unique().tolist()):
        # np.array copies the Index, so the in-place shuffle never touches the frame.
        label_idx = np.array(metadata_df.loc[metadata_df["label"] == label].index)
        shuffle(label_idx)
        for fold_id, part in enumerate(partition_array(label_idx, kfolds)):
            # `part` holds index labels, hence label-based .loc assignment.
            metadata_df.loc[part, "fold"] = fold_id

    return metadata_df

In [None]:
# Locate the MNIST dataset config that lives inside the package's dataloader dir.
module_dir = advanced_neural_networks.__path__[0]  # package directory on disk
dataloader_path = os.path.join(module_dir, "dataloader")
dataset_config = os.path.join(module_dir, "dataloader", "mnist_config.yaml")

In [None]:
def _load_mnist_split(train):
    """Build one MNISTDataset split with no transforms and one_hot disabled."""
    # A fresh empty transforms list per call, so the two splits never share one.
    return MNISTDataset(
        config_file=dataset_config,
        location="cloud",
        train=train,
        transforms=[],
        one_hot=False,
    )


mnist_train = _load_mnist_split(True)
mnist_test = _load_mnist_split(False)

In [None]:
# Collect one metadata record per training image, then build the DataFrame
# once. Calling pd.concat inside the loop re-copies the whole frame on every
# iteration, which is quadratic in the number of images.
train_records = []
for x_image, label in tqdm(mnist_train):
    # Unpacking requires each image to have exactly three dims (C, H, W).
    channels, height, width = x_image.shape
    train_records.append({
        "img_channels": channels,
        "height": height,
        "width": width,
        "label": label,
    })

# Explicit columns keep the schema stable even if the dataset is empty.
train_metadata_df = pd.DataFrame(
    train_records, columns=["img_channels", "height", "width", "label"]
)
    

In [None]:
train_metadata_df.head(n=5)  # preview the first five rows

In [None]:
# Number of training images per label, shown as a bar chart.
train_size_summary = train_metadata_df.groupby(by="label").size()
train_size_summary.plot.bar()
plt.show()

In [None]:
## check for corruptions
# Flag every training image whose shape differs from 1 x 28 x 28.
condition = (
    train_metadata_df[["img_channels", "height", "width"]]
    .ne([1, 28, 28])
    .any(axis=1)
)
corrupt_df = train_metadata_df[condition]
corrupt_df

In [None]:
# Collect one metadata record per test image, then build the DataFrame once.
# Calling pd.concat inside the loop re-copies the whole frame on every
# iteration, which is quadratic in the number of images.
test_records = []
for x_image, label in tqdm(mnist_test):
    # Unpacking requires each image to have exactly three dims (C, H, W).
    channels, height, width = x_image.shape
    test_records.append({
        "img_channels": channels,
        "height": height,
        "width": width,
        "label": label,
    })

# Explicit columns keep the schema stable even if the dataset is empty.
test_metadata_df = pd.DataFrame(
    test_records, columns=["img_channels", "height", "width", "label"]
)

In [None]:
test_metadata_df.head(n=5)  # preview the first five rows

In [None]:
# Number of test images per label, shown as a bar chart.
test_size_summary = test_metadata_df.groupby(by="label").size()
test_size_summary.plot.bar()
plt.show()

In [None]:
## check for corruptions
# Flag every test image whose shape differs from 1 x 28 x 28.
condition = (
    test_metadata_df[["img_channels", "height", "width"]]
    .ne([1, 28, 28])
    .any(axis=1)
)
corrupt_df_test = test_metadata_df[condition]
corrupt_df_test

In [None]:
## save metadata
# Locations of the persisted metadata CSVs.
save_dir = os.path.join(module_dir, "metadata")
train_df_path = os.path.join(save_dir, "mnist_train_metadata.csv")
test_df_path = os.path.join(save_dir, "mnist_test_metadata.csv")

# Writing is off by default, so running the notebook does not overwrite the
# CSVs already on disk. Set SAVE_METADATA = True to (re)generate them from
# the frames computed above.
SAVE_METADATA = False
if SAVE_METADATA:
    os.makedirs(save_dir, exist_ok=True)  # to_csv fails if the dir is missing
    train_metadata_df.to_csv(train_df_path, index=False)
    test_metadata_df.to_csv(test_df_path, index=False)

In [None]:
## load saved metadata
# Reload both splits from disk, replacing the in-memory frames built above.
# pd.read_csv raises FileNotFoundError if the CSVs were never written.
train_metadata_df, test_metadata_df = (
    pd.read_csv(csv_path) for csv_path in (train_df_path, test_df_path)
)

In [None]:
train_metadata_df.head(n=5)  # preview the reloaded training metadata

In [None]:
# Assign each training row a stratified fold id in 0..N_FOLDS-1; the frame is
# updated in place and the same object is returned.
N_FOLDS = 10
train_metadata_df = stratify_kfolds(train_metadata_df, kfolds=N_FOLDS)
train_metadata_df.head()

In [None]:
# Rows per (label, fold): one stacked bar per label, one segment per fold,
# to eyeball that the folds are balanced within each label.
fold_summary = train_metadata_df.groupby(["label", "fold"]).size()
fold_summary.unstack(level="fold").plot.bar(stacked=True)
plt.show()