In [1]:
import numpy as np
import os


from fault_tolerant_ml.data.mnist import MNist

%reload_ext autoreload
%autoreload 2

In [2]:
# Directory holding the gzipped MNIST IDX archives, relative to this notebook.
data_dir = "../data"

# Archive file names, keyed first by split and then by content type.
archive_names = {
    "train": {
        "images": "train-images-idx3-ubyte.gz",
        "labels": "train-labels-idx1-ubyte.gz",
    },
    "test": {
        "images": "t10k-images-idx3-ubyte.gz",
        "labels": "t10k-labels-idx1-ubyte.gz",
    },
}

# Same nested layout as above, with every file name joined onto data_dir.
filepaths = {
    split: {
        content: os.path.join(data_dir, file_name)
        for content, file_name in names.items()
    }
    for split, names in archive_names.items()
}

# Hand the path mapping to the project's MNist wrapper.
mnist = MNist(filepaths)

In [9]:
# Cut the training rows into n_partitions equally sized, contiguous slices and
# report where each slice starts and ends.
n_partitions = 10
batch_size = mnist.X_train.shape[0] // n_partitions  # floor: leftover rows are not visited
for i in range(n_partitions):
    start, end = i * batch_size, (i + 1) * batch_size
    X_batch = mnist.X_train.data[start:end]
    print(X_batch.shape, "start idx=", start, "end_idx=", end)
    # With the default `ord`, np.linalg.norm returns the Frobenius norm of the
    # slice. NOTE(review): the value is overwritten on every iteration and is
    # not reported anywhere in this cell.
    fro_norm = np.linalg.norm(X_batch)

(6000, 784) start idx= 0 end_idx= 6000
(6000, 784) start idx= 6000 end_idx= 12000
(6000, 784) start idx= 12000 end_idx= 18000
(6000, 784) start idx= 18000 end_idx= 24000
(6000, 784) start idx= 24000 end_idx= 30000
(6000, 784) start idx= 30000 end_idx= 36000
(6000, 784) start idx= 36000 end_idx= 42000
(6000, 784) start idx= 42000 end_idx= 48000
(6000, 784) start idx= 48000 end_idx= 54000
(6000, 784) start idx= 54000 end_idx= 60000


In [11]:
# Sharding configuration: how many rows exist and how many workers share them.
n_samples = 60000  # presumably the number of training rows — TODO confirm against mnist.X_train
n_workers = 23

# Ceiling division in pure integer arithmetic, so that
# n_workers * batch_size is always at least n_samples.
batch_size = -(-n_samples // n_workers)

In [12]:
# Fraction of a batch that each worker additionally takes on top of its own
# batch (0.0 means no extra rows).
overlap_per = 0.0

# NOTE: despite its name, n_overlap is the full number of rows requested per
# worker (the batch plus the extra fraction), not only the extra rows.
n_overlap = int(batch_size * (1 + overlap_per))
print("Overlap=%s" % n_overlap)

Overlap=2609


In [77]:
# Visit every worker in turn and carve out the rows it is responsible for.
end = 0  # keeps `end` defined even if the loop body never runs
for i in range(n_workers):
    # Worker i starts one batch further along than worker i - 1 (0 for the first).
    start = i * batch_size
    end = start + n_overlap
    X_batch = mnist.X_train.data[start:end]

    runs_past_end = end > n_samples
    if runs_past_end and overlap_per != 0.0:
        # The requested window extends beyond the last row: wrap around and
        # append the missing rows from the beginning of the data set.
        end -= n_samples
        X_batch = np.vstack((X_batch, mnist.X_train.data[0:end]))

    print("worker %s %s %s X_batch.shape=%s" % (i, start, end, X_batch.shape))

worker 0 0 2609 X_batch.shape=(2609, 784)
worker 1 2609 5218 X_batch.shape=(2609, 784)
worker 2 5218 7827 X_batch.shape=(2609, 784)
worker 3 7827 10436 X_batch.shape=(2609, 784)
worker 4 10436 13045 X_batch.shape=(2609, 784)
worker 5 13045 15654 X_batch.shape=(2609, 784)
worker 6 15654 18263 X_batch.shape=(2609, 784)
worker 7 18263 20872 X_batch.shape=(2609, 784)
worker 8 20872 23481 X_batch.shape=(2609, 784)
worker 9 23481 26090 X_batch.shape=(2609, 784)
worker 10 26090 28699 X_batch.shape=(2609, 784)
worker 11 28699 31308 X_batch.shape=(2609, 784)
worker 12 31308 33917 X_batch.shape=(2609, 784)
worker 13 33917 36526 X_batch.shape=(2609, 784)
worker 14 36526 39135 X_batch.shape=(2609, 784)
worker 15 39135 41744 X_batch.shape=(2609, 784)
worker 16 41744 44353 X_batch.shape=(2609, 784)
worker 17 44353 46962 X_batch.shape=(2609, 784)
worker 18 46962 49571 X_batch.shape=(2609, 784)
worker 19 49571 52180 X_batch.shape=(2609, 784)
worker 20 52180 54789 X_batch.shape=(2609, 784)
worker 21 54

In [78]:
def next_batch(X=None, total=None, stride=None, window=None, wrap=None):
    """Yield one contiguous slice of rows per worker, in worker order.

    Called without arguments the generator reads the notebook-level
    configuration, so ``next_batch()`` behaves as it always did. Every
    setting can also be passed explicitly, which lets the sharding logic be
    exercised on a small array without touching the notebook globals.

    Args:
        X: Row-sliceable 2-D array to distribute. Defaults to
            ``mnist.X_train.data``.
        total: Number of rows to distribute. Defaults to ``n_samples``.
        stride: Offset between the start rows of consecutive workers.
            Defaults to ``batch_size``.
        window: Number of rows requested for each worker. Defaults to
            ``n_overlap``.
        wrap: When true, a window that runs past ``total`` is completed with
            rows taken from the beginning of ``X``. Defaults to
            ``overlap_per != 0.0``.

    Yields:
        The rows assigned to each worker. Before every yield a
        ``worker <index> <start> <end> X_batch.shape=<shape>`` line is
        printed; for a wrapped window ``<end>`` is the wrapped end index.

    Raises:
        ValueError: If ``stride`` is not positive.
    """
    # Anything not supplied falls back to the notebook-level configuration.
    if X is None:
        X = mnist.X_train.data
    if total is None:
        total = n_samples
    if stride is None:
        stride = batch_size
    if window is None:
        window = n_overlap
    if wrap is None:
        wrap = overlap_per != 0.0

    if stride <= 0:
        raise ValueError("stride must be positive, got %s" % (stride,))

    # One worker per stride-sized step through the rows.
    for worker, start in enumerate(np.arange(0, total, stride)):
        end = start + window
        # Slicing past the last row simply returns the rows that remain.
        X_batch = X[start:end]
        if wrap and end > total:
            # Circle back: stack the rows from the beginning of the data onto
            # the (short) tail slice so the window is complete again.
            end = end - total
            X_batch = np.vstack([X_batch, X[0:end]])
        print("worker %s %s %s X_batch.shape=%s" % (worker, start, end, X_batch.shape))
        yield X_batch

In [79]:
# Drain the generator and echo the shape of every batch it hands back.
batches = next_batch()
for batch in batches:
    print(batch.shape)

worker 0 0 2609 X_batch.shape=(2609, 784)
(2609, 784)
worker 1 2609 5218 X_batch.shape=(2609, 784)
(2609, 784)
worker 2 5218 7827 X_batch.shape=(2609, 784)
(2609, 784)
worker 3 7827 10436 X_batch.shape=(2609, 784)
(2609, 784)
worker 4 10436 13045 X_batch.shape=(2609, 784)
(2609, 784)
worker 5 13045 15654 X_batch.shape=(2609, 784)
(2609, 784)
worker 6 15654 18263 X_batch.shape=(2609, 784)
(2609, 784)
worker 7 18263 20872 X_batch.shape=(2609, 784)
(2609, 784)
worker 8 20872 23481 X_batch.shape=(2609, 784)
(2609, 784)
worker 9 23481 26090 X_batch.shape=(2609, 784)
(2609, 784)
worker 10 26090 28699 X_batch.shape=(2609, 784)
(2609, 784)
worker 11 28699 31308 X_batch.shape=(2609, 784)
(2609, 784)
worker 12 31308 33917 X_batch.shape=(2609, 784)
(2609, 784)
worker 13 33917 36526 X_batch.shape=(2609, 784)
(2609, 784)
worker 14 36526 39135 X_batch.shape=(2609, 784)
(2609, 784)
worker 15 39135 41744 X_batch.shape=(2609, 784)
(2609, 784)
worker 16 41744 44353 X_batch.shape=(2609, 784)
(2609, 784)
