In [1]:
import os
from pathlib import Path

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from utils import data_utils


# Shared preprocessing for train and test: convert images to tensors, then
# standardize with fixed per-channel statistics.
# 0.1307 / 0.3081 are the mean / std values widely used for MNIST.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)
# Local cache directory for the downloaded MNIST files (inspected in the next cell).
DATA_ROOT = "./data"

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

train_dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

In [2]:
# List what the dataset cell downloaded under ./data (runs the `tree` shell command on the host).
!tree ./data/

./data/
└── MNIST
    └── raw
        ├── t10k-images-idx3-ubyte
        ├── t10k-images-idx3-ubyte.gz
        ├── t10k-labels-idx1-ubyte
        ├── t10k-labels-idx1-ubyte.gz
        ├── train-images-idx3-ubyte
        ├── train-images-idx3-ubyte.gz
        ├── train-labels-idx1-ubyte
        └── train-labels-idx1-ubyte.gz

3 directories, 8 files


In [3]:
print("Applying transforms to create final tensors...")
# Materialize each dataset as one in-memory batch: a DataLoader whose
# batch_size equals the dataset length yields exactly one batch, which the
# default collate stacks into a (data, labels) pair of tensors.
# shuffle=False (the default, stated explicitly) keeps the original sample order.
# The *_loader names are kept so later cells that reference them still work.
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=False)
train_data, train_labels = next(iter(train_loader))

test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)
test_data, test_labels = next(iter(test_loader))

# Sanity-check that the single batch covered every sample and that data and
# labels line up one-to-one before anything is split and shipped to clients.
assert len(train_data) == len(train_labels) == len(train_dataset), "train batch incomplete"
assert len(test_data) == len(test_labels) == len(test_dataset), "test batch incomplete"

print(f"Train tensors: {train_data.shape}, {train_labels.shape}")
print(f"Test tensors: {test_data.shape}, {test_labels.shape}")

Applying transforms to create final tensors...
Train tensors: torch.Size([60000, 1, 28, 28]), torch.Size([60000])
Test tensors: torch.Size([10000, 1, 28, 28]), torch.Size([10000])


In [4]:
data_utils.split_and_distribute?

Signature:
data_utils.split_and_distribute(
    train_data,
    train_labels,
    test_data,
    test_labels,
    inventory_path: str,
    split_method: str,
    remote_dest_path: str,
)
Docstring:
Splits, saves, and distributes any dataset (provided as tensors or arrays) 
to Ansible clients.

*** This will DELETE and REPLACE the remote_dest_path on all clients. ***

Args:
    train_data: A PyTorch Tensor or NumPy array of training data (X_train).
    train_labels: A PyTorch Tensor or NumPy array of training labels (y_train).
    test_data: A PyTorch Tensor or NumPy array of testing data (X_test).
    test_labels: A PyTorch Tensor or NumPy array of testing labels (y_test).
    inventory_path: Path to the inventory.ini file.
    split_method: 'uniform', 'exponential', 'square', or 'linear'.
    remote_dest_path: Absolute path on clients (e.g., "/tmp/my_data").
File:      ~/federated_learning/notebooks/utils/data_utils.py
Type:      function

In [5]:
# Distribution settings. The inventory defaults to the federated_learning
# checkout in the current user's home directory; set FL_INVENTORY_PATH to
# point elsewhere instead of editing this cell.
INVENTORY_PATH = Path(
    os.environ.get(
        "FL_INVENTORY_PATH",
        Path.home() / "federated_learning" / "ansible" / "inventory.ini",
    )
)
SPLIT_METHOD = "square"  # one of: 'uniform', 'exponential', 'square', 'linear'
REMOTE_DEST_PATH = "/tmp/mnist_data"

# Fail fast before the destructive step: split_and_distribute DELETES and
# REPLACES REMOTE_DEST_PATH on all clients listed in the inventory.
if not INVENTORY_PATH.is_file():
    raise FileNotFoundError(f"Ansible inventory not found: {INVENTORY_PATH}")

data_utils.split_and_distribute(
    train_data=train_data,
    train_labels=train_labels,
    test_data=test_data,
    test_labels=test_labels,
    inventory_path=str(INVENTORY_PATH),
    split_method=SPLIT_METHOD,
    remote_dest_path=REMOTE_DEST_PATH,
)

--- Starting Data Split and Distribution ---
Validating inputs...
Querying Ansible inventory '/home/k3s-server-07/federated_learning/ansible/inventory.ini' for client list...
Found 5 clients: ['k3s-client-09', 'k3s-client-08', 'k3s-client-17', 'k3s-client-18', 'k3s-client-06']

Creating local splits in temporary directory: /tmp/tmppxc0tqid
Splitting 60000 samples into 5 sites (square): [1090, 4363, 9818, 17454, 27275]
  Saved 1090 items to /tmp/tmppxc0tqid/k3s-client-09_train.pt
  Saved 4363 items to /tmp/tmppxc0tqid/k3s-client-08_train.pt
  Saved 9818 items to /tmp/tmppxc0tqid/k3s-client-17_train.pt
  Saved 17454 items to /tmp/tmppxc0tqid/k3s-client-18_train.pt
  Saved 27275 items to /tmp/tmppxc0tqid/k3s-client-06_train.pt

Saving full test dataset...
  Saved 10000 items to /tmp/tmppxc0tqid/test_data.pt

Starting distribution to clients via Ansible...
  Re-creating remote directory '/tmp/mnist_data'...
  Distributing client-specific files...
    Sending files to k3s-client-09...
    S