In [7]:
import numpy as np
import torch
from torchvision import datasets, transforms
import os
from concurrent.futures import ThreadPoolExecutor
import struct
from pathlib import Path

In [5]:
# Build the full MNIST image array (70000 x 28 x 28) and the matching
# label array (70000,) by joining the official train and test splits,
# then store both as .npy files.

# Transform handed to the torchvision datasets. Only ToTensor is composed,
# so no normalization step is configured here.
# NOTE(review): the arrays below are read from `.data` / `.targets`, not by
# indexing the datasets, so this transform does not appear to affect the
# saved arrays - confirm before relying on any value scaling.
transform = transforms.Compose([transforms.ToTensor()])

# Download (if not already present under ./data) and load both MNIST splits
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Extract the images and labels from the datasets as NumPy arrays
train_images = train_dataset.data.numpy()  # Shape: (60000, 28, 28)
train_labels = train_dataset.targets.numpy()  # Shape: (60000,)
test_images = test_dataset.data.numpy()  # Shape: (10000, 28, 28)
test_labels = test_dataset.targets.numpy()  # Shape: (10000,)

# Combine the splits, train first and test second, so that index k of
# all_images and index k of all_labels always describe the same sample
all_images = np.concatenate((train_images, test_images), axis=0)  # Shape: (70000, 28, 28)
all_labels = np.concatenate((train_labels, test_labels), axis=0)  # Shape: (70000,)

# Print the shape of the arrays to confirm
print("All images shape:", all_images.shape)  # (70000, 28, 28)
print("All labels shape:", all_labels.shape)  # (70000,)

# np.save does not create missing parent directories, so make sure the
# output folder exists; otherwise a fresh checkout fails with FileNotFoundError.
os.makedirs('mnist_files', exist_ok=True)
np.save('mnist_files/mnist_img.npy', all_images)
np.save('mnist_files/mnist_labels.npy', all_labels)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 32002979.90it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1232657.26it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 7352708.45it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3627981.10it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

All images shape: (70000, 28, 28)
All labels shape: (70000,)





In [8]:
# Combine the individual mask files of every data level into one stacked
# (n_files, 28, 28) uint8 array per level and save it next to the level folders.
save_path = './mask/'
for i in range(10, 71, 5):
    # Masks for data level i are read from <save_path>/<i>/ ; the folder is
    # expected to exist already (nothing is created here).
    dir_path = Path(save_path) / str(i)

    # Collect the .npy files in a deterministic order. Row k of the stacked
    # array is the k-th file of this sorted list.
    # NOTE(review): this is a plain lexicographic sort. If the files are named
    # with unpadded numbers (e.g. 2.npy, 10.npy), '10.npy' sorts before '2.npy'
    # and the rows would not follow numeric order - confirm the naming scheme.
    file_list = sorted(f for f in os.listdir(dir_path) if f.endswith('.npy'))
    file_paths = [dir_path / f for f in file_list]

    # Preallocate the stacked array; each loaded mask is cast to uint8 on assignment
    big_array = np.empty((len(file_list), 28, 28), dtype=np.uint8)

    # Load the files in parallel. Executor.map yields results in the order of
    # its inputs, so enumerate() gives each mask its position in file_list.
    with ThreadPoolExecutor(max_workers=8) as executor:  # You can adjust max_workers
        for idx, array in enumerate(executor.map(np.load, file_paths)):
            big_array[idx] = array

    print(big_array.shape)  # (70000, 28, 28)
    # np.save appends the '.npy' extension, giving ./mask/big_mask_<i>.npy
    np.save(save_path + 'big_mask_' + str(i), big_array)

(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
(70000, 28, 28)
