In [None]:
"""
Download CIFAR-10 dataset, and splits it among clients
"""
import argparse
import os
import pickle

from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize

from utils import split_dataset_by_labels, pathological_non_iid_split, split_and_reform_dataset

ALPHA = .4
N_CLASSES = 10
N_COMPONENTS = 3
SEED = 12345
RAW_DATA_PATH = "raw_data/"
PATH = "all_data/"


def save_data(l, path_):
    with open(path_, 'wb') as f:
        pickle.dump(l, f)


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--n_tasks',
        help='number of tasks/clients;',
        type=int,
        required=True)
    parser.add_argument(
        '--pathological_split',
        help='if selected, the dataset will be split as in'
             '"Communication-Efficient Learning of Deep Networks from Decentralized Data";'
             'i.e., each client will receive `n_shards` of dataset, where each shard contains at most two classes',
        action='store_true'
    )
    parser.add_argument(
        '--n_shards',
        help='number of shards given to each clients/task; ignored if `--pathological_split` is not used;'
             'default is 2',
        type=int,
        default=2
    )
    parser.add_argument(
        '--n_components',
        help='number of components/clusters;',
        type=int,
        default=N_COMPONENTS
    )
    parser.add_argument(
        '--alpha',
        help='parameter controlling tasks dissimilarity, the smaller alpha is the more tasks are dissimilar;',
        type=float,
        default=ALPHA
    )
    parser.add_argument(
        '--s_frac',
        help='fraction of the dataset to be used; default: 1.0;',
        type=float,
        default=1.0
    )
    parser.add_argument(
        '--tr_frac',
        help='fraction in training set; default: 0.8;',
        type=float,
        default=0.8
    )
    parser.add_argument(
        '--val_frac',
        help='fraction of validation set (from train set); default: 0.0;',
        type=float,
        default=0.0
    )
    parser.add_argument(
        '--test_tasks_frac',
        help='fraction of tasks / clients not participating to the training; default is 0.0',
        type=float,
        default=0.0
    )
    parser.add_argument(
        '--seed',
        help='seed for the random processes;',
        type=int,
        default=SEED
    )
    parser.add_argument(
        '--distribution_shift',
        help='if selected, the dataset will be split such that each component',
        action='store_true'
    )

    return parser.parse_args()

# python generate_data.py \
#     --n_tasks 80 \
#     --n_components -1 \
#     --alpha 0.5 \
#     --s_frac 1.0 \
#     --tr_frac 0.8 \
#     --val_frac 0.25 \
#     --seed 12345 --distribution_shift  
def main():
    args = parse_args("python generate_data.py \
    --n_tasks 80 \
    --n_components -1 \
    --alpha 0.5 \
    --s_frac 1.0 \
    --tr_frac 0.8 \
    --val_frac 0.25 \
    --seed 12345 --distribution_shift  ")

    transform = Compose([
        ToTensor(),
        Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    dataset =\
        ConcatDataset([
            CIFAR10(root=RAW_DATA_PATH, download=True, train=True, transform=transform),
            CIFAR10(root=RAW_DATA_PATH, download=False, train=False, transform=transform)
        ])

    if args.pathological_split:
        clients_indices = \
            pathological_non_iid_split(
                dataset=dataset,
                n_classes=N_CLASSES,
                n_clients=args.n_tasks,
                n_classes_per_client=args.n_shards,
                frac=args.s_frac,
                seed=args.seed
            )
    elif args.distribution_shift:
        clients_indices, rotation_idx = \
            split_and_reform_dataset(
                dataset=dataset,
                n_classes=N_CLASSES,
                n_clients=args.n_tasks,
                n_clusters=args.n_components,
                alpha=args.alpha,
                frac=args.s_frac,
                seed=args.seed,
                rotation_ratio=0.7
            )
    else:
        clients_indices = \
            split_dataset_by_labels(
                dataset=dataset,
                n_classes=N_CLASSES,
                n_clients=args.n_tasks,
                n_clusters=args.n_components,
                alpha=args.alpha,
                frac=args.s_frac,
                seed=args.seed
            )

    if args.test_tasks_frac > 0:
        train_clients_indices, test_clients_indices = \
            train_test_split(clients_indices, test_size=args.test_tasks_frac, random_state=args.seed)
    else:
        train_clients_indices, test_clients_indices = clients_indices, []

    os.makedirs(os.path.join(PATH, "train"), exist_ok=True)
    os.makedirs(os.path.join(PATH, "test"), exist_ok=True)
    """
    " require significant shifting
    """
    if args.distribution_shift:
        save_data(rotation_idx, os.path.join(PATH, "rotations.pkl"))

    for mode, clients_indices in [('train', train_clients_indices), ('test', test_clients_indices)]:
        for client_id, indices in enumerate(clients_indices):
            if len(indices) == 0:
                continue

            client_path = os.path.join(PATH, mode, "task_{}".format(client_id))
            os.makedirs(client_path, exist_ok=True)

            train_indices, test_indices = \
                train_test_split(
                    indices,
                    train_size=args.tr_frac,
                    random_state=args.seed
                )

            if args.val_frac > 0:
                train_indices, val_indices = \
                    train_test_split(
                        train_indices,
                        train_size=1.-args.val_frac,
                        random_state=args.seed
                    )

                save_data(val_indices, os.path.join(client_path, "val.pkl"))

            print(len(train_indices), len(test_indices))
            save_data(train_indices, os.path.join(client_path, "train.pkl"))
            save_data(test_indices, os.path.join(client_path, "test.pkl"))


if __name__ == "__main__":
    main()


In [4]:
import argparse

ALPHA = .4
N_CLASSES = 10
N_COMPONENTS = 3
SEED = 12345
RAW_DATA_PATH = "raw_data/"
PATH = "all_data/"
def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--n_tasks',
        help='number of tasks/clients;',
        type=int,
        required=True)
    parser.add_argument(
        '--pathological_split',
        help='if selected, the dataset will be split as in'
             '"Communication-Efficient Learning of Deep Networks from Decentralized Data";'
             'i.e., each client will receive `n_shards` of dataset, where each shard contains at most two classes',
        action='store_true'
    )
    parser.add_argument(
        '--n_shards',
        help='number of shards given to each clients/task; ignored if `--pathological_split` is not used;'
             'default is 2',
        type=int,
        default=2
    )
    parser.add_argument(
        '--n_components',
        help='number of components/clusters;',
        type=int,
        default=N_COMPONENTS
    )
    parser.add_argument(
        '--alpha',
        help='parameter controlling tasks dissimilarity, the smaller alpha is the more tasks are dissimilar;',
        type=float,
        default=ALPHA
    )
    parser.add_argument(
        '--s_frac',
        help='fraction of the dataset to be used; default: 1.0;',
        type=float,
        default=1.0
    )
    parser.add_argument(
        '--tr_frac',
        help='fraction in training set; default: 0.8;',
        type=float,
        default=0.8
    )
    parser.add_argument(
        '--val_frac',
        help='fraction of validation set (from train set); default: 0.0;',
        type=float,
        default=0.0
    )
    parser.add_argument(
        '--test_tasks_frac',
        help='fraction of tasks / clients not participating to the training; default is 0.0',
        type=float,
        default=0.0
    )
    parser.add_argument(
        '--seed',
        help='seed for the random processes;',
        type=int,
        default=SEED
    )
    parser.add_argument(
        '--distribution_shift',
        help='if selected, the dataset will be split such that each component',
        action='store_true'
    )

    return parser.parse_args(["python generate_data.py \
    --n_tasks 80 \
    --n_components -1 \
    --alpha 0.5 \
    --s_frac 1.0 \
    --tr_frac 0.8 \
    --val_frac 0.25 \
    --seed 12345 --distribution_shift"])
args = parse_args()
print(args)

usage: ipykernel_launcher.py [-h] --n_tasks N_TASKS [--pathological_split]
                             [--n_shards N_SHARDS]
                             [--n_components N_COMPONENTS] [--alpha ALPHA]
                             [--s_frac S_FRAC] [--tr_frac TR_FRAC]
                             [--val_frac VAL_FRAC]
                             [--test_tasks_frac TEST_TASKS_FRAC] [--seed SEED]
                             [--distribution_shift]
ipykernel_launcher.py: error: the following arguments are required: --n_tasks


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [5]:
import pickle
file = open("all_data/rotations.pkl",'rb')
object_file = pickle.load(file)
file.close()

In [22]:
from torchvision.datasets import CIFAR10
import os
import torchvision.transforms as T
from PIL import Image

cifar10_path = os.path.join("raw_data")
assert os.path.isdir(cifar10_path), "Download cifar10 dataset!!"
cifar10_train = \
        CIFAR10(
            root=cifar10_path,
            train=True, download=False
        )

with open('all_data/rotations.pkl', "rb") as f:
    rotation_idx = pickle.load(f)

    x = cifar10_train.data[rotation_idx[0][:1000]]
#         y = cifar10_targets[rotation_idx]
    rotater = T.RandomRotation(degrees=(0, 180))
    jitter = T.ColorJitter(brightness=.5, hue=.3)

    gray_img = T.Grayscale()
    jitter = T.ColorJitter(brightness=.5, hue=.3)
    inverter = T.RandomInvert()
    rotater = T.RandomRotation(degrees=(0, 180))
    imgs = x
    augmented = []

    augmented.append(imgs)
    imgs = gray_img(imgs)
    augmented.append(imgs)
    imgs = jitter(imgs)
    augmented.append(imgs)
    imgs = inverter(imgs)
    augmented.append(imgs)
    imgs = rotater(imgs)
    augmented.append(imgs)

#         y = permute_label(y, 10)

#         cifar10_data[rotation_idx] = x
#         cifar10_targets[rotation_idx] = y

TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>

In [29]:
import torchvision.transforms.functional as F
transform = T.Compose([
        T.ToPILImage(),
        T.ToTensor()])

transform(cifar10_train)

TypeError: pic should be Tensor or ndarray. Got <class 'torchvision.datasets.cifar.CIFAR10'>.