# Federation with Varying Example Counts

Now we've seen federated learning work with a dataset scattered across a number of workers. We saw that its performance isn't substantially different from the non-federated approach, while it decentralizes the work of training and limits the amount of data transferred between the manager and the workers.

But what if our workers don't have access to equal amounts of data? Let's explore that.

## Splitting the Deck into Uneven Piles

To test this, we need workers to have access to different numbers of training examples. So let's make a set of decks that gives each of our workers more or less training data than the others.
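The post doesn't yet show how to deal out the uneven piles, so here is a minimal sketch under our own assumptions: shuffle the example indices once, then cut them into piles whose sizes follow a list of relative weights. `uneven_split` and its parameters are hypothetical helpers, not part of the `federated` module, and the sketch uses plain Python so it runs without PyTorch.

```python
import random

def uneven_split(num_examples, relative_sizes, seed=0):
    """Shuffle example indices and cut them into piles sized by relative_sizes.

    Hypothetical helper (a sketch, not the federated module's API).
    relative_sizes: e.g. [10, 6, 3, 1] means 50%, 30%, 15% and 5% of the examples.
    Returns one list of indices per worker; together they cover every example once.
    """
    rng = random.Random(seed)
    indices = list(range(num_examples))
    rng.shuffle(indices)

    total = sum(relative_sizes)
    piles, start = [], 0
    for i, size in enumerate(relative_sizes):
        # the last pile takes whatever is left, so rounding never drops an example
        if i == len(relative_sizes) - 1:
            end = num_examples
        else:
            end = start + int(num_examples * size / total)
        piles.append(indices[start:end])
        start = end
    return piles

piles = uneven_split(60000, [10, 6, 3, 1])
print([len(pile) for pile in piles])   # [30000, 18000, 9000, 3000]
```

Each pile could then be wrapped with `torch.utils.data.Subset(train_dset, pile)` and handed to its own `DataLoader`, giving one worker half of the training data and another only 5% of it.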

# Conclusion
TODO: WRAP UP THE BLOG POST HERE. EVERYTHING BELOW IS PART TWO OR THREE

# Blog Post Two

In our last post [LINK HERE] we showed how to implement federated learning in PyTorch. In this post we take one step closer to how federated learning would play out in real life.

In federated learning, we expect each worker to capture and train on somewhat different data. For example, my mobile phone will capture more songs in the genres of music I listen to, while your phone will reflect a different set of songs in different genres. Probably. If not, get out of my mind! [BENE GESSERIT MOTHER IMAGE? SINGLE WHITE FEMALE IMAGE?] And there's likely a lot of overlap, thanks to radio and the clustering of popular songs.

To reflect this kind of data distribution, we're going to skew the MNIST data to reflect that each worker sees a somewhat different subset of the training data.

## Stacking the Deck (Skewing Data)

We know the baseline data is pretty even across numerals. Now we need a way to "stack the deck" of examples that each worker sees. The function below builds a sampler that draws from a given dataset, with the random sampling biased according to a dictionary of weights for each label.

In [18]:
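from torch.utils.data import WeightedRandomSampler

def stacked_dset(dset, label_weights, N):
    """
        dset: dataset whose samples are (data, label) pairs
        label_weights: dict mapping each label to a sampling weight, e.g. {0: 0.3, 1: 0.3, 3: 1, ...}
        N: number of examples the sampler draws
        return: a WeightedRandomSampler biased by label_weights
    """
    # Reading labels from dset.targets (torchvision's MNIST exposes them there) avoids
    # loading and transforming every image just to look at its label, which was the slow part.
    targets = getattr(dset, 'targets', None)
    if targets is not None:
        labels = [int(label) for label in targets]
    else:
        labels = [int(label) for _, label in dset]
    weights = [label_weights[label] for label in labels]

    # Without replacement, N must not exceed the number of examples with a nonzero weight.
    return WeightedRandomSampler(weights, N, replacement=False)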
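`WeightedRandomSampler` does the biased draw for us inside PyTorch. To make "weighted sampling without replacement" concrete, here is a plain-Python sketch using Efraimidis-Spirakis keys, which gives the same distribution as drawing one example at a time in proportion to its weight without putting it back. This is not PyTorch's implementation, and `weighted_sample_without_replacement` is a hypothetical helper. It also shows why `N` can't exceed the number of examples with a nonzero weight.

```python
import random
from collections import Counter

def weighted_sample_without_replacement(labels, label_weights, n, seed=0):
    """Draw n distinct example indices, favoring examples whose label has a higher weight.

    Hypothetical helper (a sketch, not PyTorch's implementation).
    Each example gets the key u ** (1 / w) with u uniform in [0, 1); the n largest keys win.
    """
    rng = random.Random(seed)
    keyed = []
    for index, label in enumerate(labels):
        weight = label_weights[label]
        if weight > 0:  # zero-weight examples can never be drawn
            keyed.append((rng.random() ** (1.0 / weight), index))
    if n > len(keyed):
        raise ValueError("n is larger than the number of examples with nonzero weight")
    keyed.sort(reverse=True)
    return [index for _, index in keyed[:n]]

# a balanced toy "dataset" of 1,000 labels per digit, skewed toward 3s
labels = [digit for digit in range(10) for _ in range(1000)]
weights = {digit: (1.0 if digit == 3 else 0.3) for digit in range(10)}
sample = weighted_sample_without_replacement(labels, weights, 2000)
print(Counter(labels[i] for i in sample).most_common(3))
```

With a weight of 0 for every digit except `3`, at most 1,000 examples can be drawn here, which mirrors the error `WeightedRandomSampler` raises when `replacement=False` and there aren't enough nonzero weights.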

And this is where we get the dictionary of weights. For simplicity's sake, we take a list of labels to be sampled "normally", and every other label is biased against. So preserving `3`s and skewing everything else by a factor of 0.9 should give a set of weights that yields a dataset slightly heavy on `3`s compared to everything else. In an extreme example, preserving only `3`s with a skew of 0 produces weights that yield a dataset of only `3`s (and, because we sample without replacement, `N` can then be no larger than the number of `3`s available).
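The effect of a skew factor can be estimated before sampling anything. For a single weighted draw, the chance of getting label `d` is `count[d] * weight[d] / sum(count[k] * weight[k])`. The sketch below assumes a perfectly balanced dataset of 6,000 examples per digit (MNIST is close to balanced, but not exactly), and `expected_label_share` is a hypothetical helper. The formula is exact for draws with replacement; without replacement it holds only approximately, while `N` is a small fraction of the dataset, because every draw removes an example from the pool.

```python
def expected_label_share(label_counts, label_weights):
    """Expected fraction of each label in one weighted draw (hypothetical helper).

    label_counts: {label: number of examples with that label}
    label_weights: {label: sampling weight of each example with that label}
    """
    total = sum(label_counts[label] * label_weights[label] for label in label_counts)
    return {label: label_counts[label] * label_weights[label] / total
            for label in label_counts}

balanced = {digit: 6000 for digit in range(10)}            # assumed perfectly balanced
weights = {digit: (1 if digit == 3 else 0.9) for digit in range(10)}
print(round(expected_label_share(balanced, weights)[3], 3))   # 0.11
```

So a 0.9 skew only nudges `3`s from 10% to about 11% of the draws, while a skew of 0.3 would give `1 / (1 + 9 * 0.3)`, roughly 27%.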

In [19]:
def skewed_weights(num_labels, labels_to_preserve, skew_bias):
    """
        num_labels: number of labels to return (use 10 for MNIST)
        labels_to_preserve: list of labels to preserve wih no skew 
        skew_bias: a float, 0 < bias < 1, to which non-selected labels will be biased down
        return: dictionary of each label and its bias
    """
    weights = {}
    for label in range(num_labels):
        if label in labels_to_preserve:
            weights[label] = 1
        else:
            weights[label] = skew_bias
    
    return weights
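The loop in `skewed_weights` can also be written as a single dict comprehension. This version is only a stylistic alternative with the same behavior, and printing its result shows exactly which weights the sampler receives for a worker that preserves `3`s.

```python
def skewed_weights(num_labels, labels_to_preserve, skew_bias):
    """Weight 1 for every preserved label, skew_bias for every other label."""
    return {label: (1 if label in labels_to_preserve else skew_bias)
            for label in range(num_labels)}

print(skewed_weights(10, [3], 0.9))
# {0: 0.9, 1: 0.9, 2: 0.9, 3: 1, 4: 0.9, 5: 0.9, 6: 0.9, 7: 0.9, 8: 0.9, 9: 0.9}
```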


Here we do the sampling to create our skewed datasets.

In [1]:
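# create stacked loaders for the workers
import time
from tqdm import tqdm
from torch.utils.data import DataLoader

# run_data, train_dset and batch_size are assumed to be defined in earlier cells
run_data['Skew Bias'] = skew_bias = 1                          # 1 means no skew: every label weighted equally
run_data['Examples Per Skewed Loader'] = loader_size = 60000   # without replacement, this can't exceed the examples with nonzero weight
run_data['Number of Workers'] = num_workers = 1

stacking_start_time = time.time()

stacked_data_loaders = []
for worker in tqdm(range(num_workers)):
    # worker i keeps digit i % 10 at full weight and biases every other digit down
    stacked_sampler = stacked_dset(train_dset, skewed_weights(10, [worker % 10], skew_bias), loader_size)
    # shuffle must stay False when a sampler is given; the sampler already randomizes the order
    stacked_data_loaders.append(DataLoader(train_dset, batch_size=batch_size, shuffle=False, sampler=stacked_sampler))

run_data['Stacking Time'] = time.time() - stacking_start_time
run_data['Stacking Time per Loader'] = run_data['Stacking Time'] / run_data['Number of Workers']

print('Stacking Time: %.2f' % run_data['Stacking Time'])
print('Stacking Time per Loader: %.2f' % run_data['Stacking Time per Loader'])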


The effect of the skew should show up in a count and a histogram of one skewed dataset. Here, we arbitrarily pick the first dataloader.

In [21]:
_, ybatches = list(zip(*stacked_data_loaders[0]))
print('Dataloader sample count:', len(torch.cat(ybatches)))

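Iterating the whole loader just to count labels loads and transforms every image, which can be slow. A cheaper route, sketched here under the assumption that the dataset exposes its labels as `targets` (torchvision's MNIST does), is to count labels straight from the indices the sampler produces. `label_counts_from_indices` is a hypothetical helper.

```python
from collections import Counter

def label_counts_from_indices(targets, indices):
    """Count each digit 0-9 among the given example indices, without loading any images.

    Hypothetical helper. targets: the label of every example, indexed like the dataset.
    indices: the example indices a sampler produced.
    """
    counts = Counter(int(targets[i]) for i in indices)
    # list every digit explicitly, so a digit that was never drawn still shows up as 0
    return [counts.get(digit, 0) for digit in range(10)]

print(label_counts_from_indices([3, 3, 1, 7, 3], [0, 1, 2, 4]))
# [0, 1, 0, 3, 0, 0, 0, 0, 0, 0]
```

In the notebook this would be `label_counts_from_indices(train_dset.targets, list(stacked_data_loaders[0].sampler))`. One caveat: each pass over a `WeightedRandomSampler` draws a fresh random sample, so these counts describe a new draw from the same skewed distribution rather than the exact examples a given training epoch sees.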

In [None]:
#for i in range(len(stacked_data_loaders)):
#    _, ybatches = list(zip(*stacked_data_loaders[i]))
#    print('Dataloader', i ,'sample count:', len(torch.cat(ybatches)))

In [None]:
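from collections import Counter

hist_counts = []   # every label drawn by each worker's loader, for the histograms
digit_counts = []  # per-worker count of each digit 0-9
for loader in tqdm(stacked_data_loaders):
    _, ybatches = list(zip(*loader))
    ys = [int(y) for y in torch.cat(ybatches)]
    hist_counts.append(ys)

    # count every digit explicitly, so a digit that was never drawn still gets a 0
    # (otherwise the per-digit lists would misalign once a skew removes a digit entirely)
    counts = Counter(ys)
    digit_counts.append([counts.get(digit, 0) for digit in range(10)])

# transpose so that digit_counts[digit][worker] is that digit's count at that worker
digit_counts = [list(i) for i in zip(*digit_counts)]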

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))
fig.suptitle('Digit Skew Histogram')
ax.xaxis.set_major_locator(plt.MultipleLocator(1))
ax.set_ylabel('Digit Count')
ax.set_xlabel('Digit')
H = ax.hist(hist_counts[0], bins=range(11), histtype='bar', align='left', rwidth=0.8)

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))
fig.suptitle('Skew: Digit Counts by Worker')

pos = list(range(num_workers))
width = 0.08

for digit in range(10):
    ax.bar([p + (width * digit) for p in pos],
           digit_counts[digit],
           width = width,
           label = str(digit),
          )

ax.set_xticks([p + (4.5 * width) for p in pos])
ax.set_xticklabels([('Dataset ' + str(x)) for x in range(num_workers)])
ax.set_ylabel('Digit Samples')
ax.set_xlabel('Samples Grouped by Worker')
ax.legend(loc = 'upper right');
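The `4.5 * width` offset in `set_xticks` centers each tick under its group of ten bars: `ax.bar` centers a bar on its x position by default, so digit `d` sits at `p + d * width`, the group runs from `p` to `p + 9 * width`, and its midpoint is `p + 4.5 * width`. Here is a plain-Python check of that arithmetic, with `grouped_bar_positions` as a hypothetical helper.

```python
def grouped_bar_positions(num_groups, bars_per_group, width):
    """x position of every bar, plus the x position of each group's center (hypothetical helper).

    Bar b of group g sits at g + b * width, so the group's center is
    g + (bars_per_group - 1) / 2 * width.
    """
    bars = [[g + b * width for g in range(num_groups)] for b in range(bars_per_group)]
    centers = [g + (bars_per_group - 1) / 2 * width for g in range(num_groups)]
    return bars, centers

bars, centers = grouped_bar_positions(num_groups=3, bars_per_group=10, width=0.08)
print(centers)   # each center is its group index plus 4.5 * 0.08, i.e. about 0.36
```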

We create the `FederatedManager` using the skewed training data. Note that we don't skew the test data: we want to see how everything performs on a normal data distribution.

In [None]:
import federated

run_data['Learning Rate'] = learning_rate = 1e-2
run_data['Epochs per Round'] = num_epochs = 1
run_data['Federated Training Rounds'] = num_rounds = 50

manager = federated.FederatedManager(
    stacked_data_loaders,
    MLPNet,
    nn.CrossEntropyLoss(),
    learning_rate,
    test_dset,
    num_epochs
)

Now let's do some rounds of federated training.

In [None]:
print("Training", num_rounds, "round(s) with", manager.n_workers, "worker(s) doing", num_epochs, "epoch(s) per round.\n" )

training_start_time = time.time()

for i in tqdm(range(num_rounds)):
    print("Beginning round", i+1)
    manager.round()
    print("Finished round", i+1, "with global loss: %.2f" % manager.manager_loss_history[-1], "\n")

run_data['Federated Training Time'] = time.time() - training_start_time
#run_data['Manager Loss History'] = manager.manager_loss_history
#run_data['Worker Loss Histories'] = manager.worker_loss_histories
run_data['Final Global Loss'] = manager.manager_loss_history[-1]

print('Federated Training Time: %.2f' % run_data['Federated Training Time'])

Now let's take a look at how the training went. Here's a graph of the loss per round.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(16, 9))
# loss of global model on test set gets recorded twice per round
# [1::2] skips the record that takes place before that round's training has happened
ax.plot(manager.manager_loss_history[1::2], label="Global Loss", )
ax.set_xlabel("Federated Round")
ax.set_ylabel("Loss")
ax.legend();

This looks pretty good, with loss dropping off just as we want. It may be a little bumpy because of the relatively high learning rate, but on balance it should keep improving. If we look under the hood at each individual worker's loss, though, we see the workers' local models diverging and converging again every round. They diverge because each local model trains on different data, which gives each one a somewhat different loss. They converge again because the manager combines them into a single global model, so every worker starts the next round with the same loss as the global model.
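We don't show here how `federated.FederatedManager` combines the workers' models, so this is an assumption rather than a description of that class: a minimal sketch of FedAvg-style averaging, where each global parameter is the average of the workers' parameters weighted by how many examples each worker trained on. `federated_average` is a hypothetical helper that uses plain lists of floats instead of tensors so it runs on its own.

```python
def federated_average(worker_params, example_counts):
    """Combine worker models into one global model, weighting each worker by its example count.

    Hypothetical helper (a FedAvg-style sketch, not FederatedManager's code).
    worker_params: one dict per worker mapping parameter name -> list of floats
    example_counts: number of training examples each worker used
    """
    total = sum(example_counts)
    averaged = {}
    for name in worker_params[0]:
        averaged[name] = [
            sum(params[name][i] * count for params, count in zip(worker_params, example_counts)) / total
            for i in range(len(worker_params[0][name]))
        ]
    return averaged

workers = [{'w': [1.0, 2.0]}, {'w': [3.0, 4.0]}]
print(federated_average(workers, [1, 3]))   # {'w': [2.5, 3.5]}
print(federated_average(workers, [1, 1]))   # {'w': [2.0, 3.0]}
```

With equal example counts the weighted average reduces to a plain mean; with uneven piles, the weighting decides how far a small worker can pull the global model toward its own data. In PyTorch the same loop would run over the tensors in each worker's `state_dict()`.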

In [None]:
fig, ax = plt.subplots(figsize=(16, 9))


for i in range(len(manager.worker_loss_histories)):
    ax.plot(manager.worker_loss_histories[i], label=('Worker ' + str(i)))

# TODO: Align the global loss properly
# [1::2] keeps one post-training global loss per round, as in the plot above
ax.plot(manager.manager_loss_history[1::2], label="Global Loss", )

# TODO: Get these labels done properly - they should be aligned with the main 
ax.set_xticklabels([(i-1) for i in range(len(manager.worker_loss_histories))])
ax.set_xlabel("Federated Round")
ax.set_ylabel("Loss")
ax.legend();

Ideas:
- plot performance on a given numeral for the main model next to that of a worker skewed against that numeral. Let both run without federation or run a few epochs before federation. Show this as a baseline
- histogram of numerals? More for curiosity, but shows spread of data that we might want to reflect in the baseline training.
- post 1: what's the accuracy loss for federation compared to baseline direct training?
- post 2: weird side stats
    - skew vs. accuracy
        - plot - x-axis = skew, y-axis = accuracy
    - run all to convergence, compare how long to reach comparable accuracy?
        - time or epochs necessary to reach comparable accuracy between federated and standard approach
        - time or epochs necessary to reach comparable accuracy by skew
    - run the federated version with balanced, but small sets of data

Questions:
- Why does the time spent by a worker on any given epoch all happen _before_ the batches start rolling in? What's happening there? Am I just spinning my wheels on something?
    - TODO: try this from a regular python file. The notebook may be buffering up those print statements in the batches
- Why does random selection of the skewed datasets take so long? Is it because they're without replacement?
- Why do all the workers and epochs always run in order? Wouldn't my laptop parallelize them across cores? Is that too much to ask from an interpreter? Or is the interpreter smarter than I am, and running them in order is actually the smartest way to do it?
- why use ten workers? Why not fewer?

- TODO: unequal data volume at each worker. Try some workers with very small or very large samples.
- TODO: unequal numbers of samples across the whole set, e.g., we just have fewer `7`s and `4`s across the set, and a glut of `1`s





# Improvement Ideas

Global variables:
- Batch size
- Learning rate
- Epochs
- Total dataset size
- Worker dataset size
    - worker dataset size skew (variance among number of samples seen from worker to worker)
- Selection of data with or without replacement
- Dataset class skew (more or fewer examples from each class)

Targets:
- Loss
- Accuracy
- Runtime to target loss or accuracy

TODO: Write a bit of code that records the hyperparameters and saves the graphs, times and losses in a bundle for each run. Something like:

```
2019-05-06 21:02:50

# standard dataloader parameter
batch_size = 128

# biasing parameters
skew_bias = 0.3
loader_size = 8192
num_workers = 10

Stacked set creation time: 00:01:08

# training parameters
learning_rate = 1e-2
num_epochs = 1
num_rounds = 20

Train time = 00:43:02

Final global loss: 0.48251
```

Well. I did this. And now the code is unreadable.
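One way to keep the record readable is to gather the hyperparameters into a single object and write each run to its own timestamped file. This is a minimal sketch under our own assumptions: `RunRecord`, `save_run`, the `runs` folder and the field names are hypothetical, chosen to mirror the record above.

```python
import json
import time
from dataclasses import dataclass, asdict, field
from pathlib import Path

@dataclass
class RunRecord:
    """Hyperparameters and results for one run (hypothetical field names)."""
    batch_size: int
    skew_bias: float
    loader_size: int
    num_workers: int
    learning_rate: float
    num_epochs: int
    num_rounds: int
    results: dict = field(default_factory=dict)   # times, losses, etc.
    started: str = field(default_factory=lambda: time.strftime('%Y-%m-%d %H:%M:%S'))

def save_run(record, runs_dir='runs'):
    """Write the record to its own timestamped JSON file and return the file's path."""
    folder = Path(runs_dir)
    folder.mkdir(parents=True, exist_ok=True)
    # colons and spaces are awkward in file names, so '2019-05-06 21:02:50' becomes '2019-05-06_21-02-50'
    path = folder / (record.started.replace(':', '-').replace(' ', '_') + '.json')
    path.write_text(json.dumps(asdict(record), indent=2))
    return path
```

After a run, something like `save_run(RunRecord(batch_size=128, skew_bias=0.3, loader_size=8192, num_workers=10, learning_rate=1e-2, num_epochs=1, num_rounds=20))` writes a file such as `runs/2019-05-06_21-02-50.json`, and figures could be saved next to it with matplotlib's `fig.savefig(...)`. Two runs started within the same second would overwrite each other, so a fuller version might add a counter to the file name.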

In [None]:
# a little performance info on the run
run_data['Global End Time'] = time.time()
run_data['Global Time'] = run_data['Global End Time'] - run_data['Global Start Time']
run_data

In [None]:
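# leave a record of the run
# each run is one JSON object on its own line (JSON Lines), so the file
# as a whole isn't valid JSON, but every non-blank line is
import json
with open('run_data.json', 'a') as file:
    # default=str stringifies any value json can't serialize (e.g. a tensor)
    file.write(json.dumps(run_data, default=str))
    file.write('\n')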
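Because each record is written as a single `json.dumps(...)` line, the file can be read back one line at a time (the JSON Lines convention), even though it isn't one JSON document. A minimal sketch; `read_runs` is a hypothetical helper, and it skips blank lines so it also handles records separated by empty lines.

```python
import json

def read_runs(path):
    """Read a file holding one JSON object per line, skipping blank lines (hypothetical helper)."""
    runs = []
    with open(path) as file:
        for line in file:
            if line.strip():
                runs.append(json.loads(line))
    return runs
```

For example, `runs = read_runs('run_data.json')` followed by `[run.get('Final Global Loss') for run in runs]` would list the final loss of every recorded run.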

In [None]:
from collections import Counter
import numpy as np

train_counts = Counter(int(y) for y in train_dset.targets).most_common()
train_values = np.array([count for _, count in train_counts], dtype=float)
print("Train digit counts: \n", train_counts)
print("Train count standard deviation: %.2f" % train_values.std())
# coefficient of variation = standard deviation / mean (np.std is the population std, ddof=0)
print("Train count coefficient of variation: %.2f" % (train_values.std() / train_values.mean()))

print()

test_counts = Counter(int(y) for y in test_dset.targets).most_common()
test_values = np.array([count for _, count in test_counts], dtype=float)
print("Test digit counts: \n", test_counts)
print("Test count standard deviation: %.2f" % test_values.std())
print("Test count coefficient of variation: %.2f" % (test_values.std() / test_values.mean()))
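The coefficient of variation is the standard deviation divided by the mean, so it measures spread relative to the typical count: 0 means every digit appears equally often. A small standard-library sketch on made-up counts (not MNIST's); `statistics.pstdev` is the population standard deviation, which matches NumPy's default `np.std` (`ddof=0`).

```python
from statistics import mean, pstdev

def coefficient_of_variation(counts):
    """Population standard deviation divided by the mean; 0 means perfectly even counts."""
    return pstdev(counts) / mean(counts)

print(round(coefficient_of_variation([90, 100, 110]), 4))   # 0.0816
print(coefficient_of_variation([100, 100, 100]))            # 0.0
```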

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))
fig.suptitle('Digit Counts at each Worker')
ax.hist(hist_counts,
        label=[('Worker ' + str(x)) for x in range(num_workers)],
        bins=list(range(11)),   # edges 0..10 give exactly one bin per digit 0-9
        histtype='bar',
        align='left',           # center each digit's bars on its integer tick
        rwidth=0.8,
       );
ax.set_xticks(range(10))
ax.set_xticklabels([('Digit ' + str(d)) for d in range(10)])
ax.legend();