In [7]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import sys

# The project-local packages imported at the bottom of this cell (`utils`,
# `models`) presumably live one directory up (the run cell executes
# ../layer-collapse.py). Put that directory on sys.path *before* importing them
# so this cell also works on a fresh kernel instead of depending on another cell
# having run first. The membership check keeps the cell idempotent on re-runs.
if '../' not in sys.path:
    sys.path.append('../')

import matplotlib
matplotlib.use('Agg')  # select the non-interactive Agg backend before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img

### Seed settings
- To set the seed manually, uncomment the seed lines and set the seed value in the files listed below.
- Also uncomment the lines containing `torch.backends.cudnn.deterministic = True`.
1. utils/sampling.py
2. main_fed.py
3. models/Nets.py

## Run layer-collapse.py
- The `%run` command below is equivalent to running the script from the console.
- Results are saved in federated-learning/save/test.png

In [1]:
import sys
sys.path.append('../')
%run ../layer-collapse.py --iid --model mlp --dataset mnist --epochs 20 --local_ep 5 --gpu -0 --num_channels 1 --num_users 100 --frac 0.05 --compression 10 --prune_epochs 100 --pruner mag

MLP(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=400, bias=False)
    (1): ReLU()
    (2): Linear(in_features=400, out_features=200, bias=False)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=50, bias=False)
    (5): ReLU()
    (6): Linear(in_features=50, out_features=100, bias=False)
    (7): ReLU()
    (8): Linear(in_features=100, out_features=200, bias=False)
    (9): ReLU()
    (10): Dropout(p=0.5, inplace=False)
    (11): Linear(in_features=200, out_features=10, bias=False)
  )
)
torch.Size([400, 784])
0.9962244897959184
0.20489524942722856
torch.Size([200, 400])
0.9925
0.20412922703615063
torch.Size([50, 200])
0.975
0.20052997114382554
torch.Size([100, 50])
0.97
0.19950161231744695
torch.Size([200, 100])
0.985
0.20258668879658273
torch.Size([10, 200])
0.895
0.18407622992176806
pruner: fedspa
sparsity:  0.2056717652757185


KeyboardInterrupt: 



## Testing plots

- The cell below plots results that were printed by earlier runs and copied into the cell manually.
- To display the plot, uncomment the `plt.show()` line.
- The `plt.savefig()` line saves the plot under `../save/`; comment it out to skip saving.


In [4]:
import matplotlib.pyplot as plt
import torch
from pathlib import Path

# Results copied by hand from the printed output of earlier layer-collapse runs.
# NOTE(review): the values look like test accuracy in percent, and x is presumably
# the compression ratio 10**alpha swept via --compression — confirm.
iters = 30
alphas = [i / 5 for i in range(iters)]
x_vals = [10 ** alpha for alpha in alphas]

# Plain floats are all matplotlib needs. The trailing runs of 9.80 are written
# as repeated tails instead of spelling every entry out.
y = {}
y['synflow'] = [
    97.74, 97.63, 97.53, 97.53, 97.64, 97.24, 96.35, 95.50, 94.81,
    93.31, 89.55, 71.14, 42.92, 10.16, 11.93, 9.19, 14.07,
] + [9.80] * 13
y['mag'] = [97.45, 97.03, 97.27, 97.66, 97.68] + [9.80] * 25

# Each series needs exactly one value per x position, otherwise plot() fails.
assert all(len(series) == len(x_vals) for series in y.values()), 'series length mismatch'

fig, ax = plt.subplots()
ax.set_xscale('log')
ax.plot(x_vals, y['synflow'], label='Synflow', linestyle='-', marker='o', color='r')
ax.plot(x_vals, y['mag'], label='Mag', linestyle='-', marker='o', color='b')

# Add labels and title
ax.set_xlabel('Compression ratio')
ax.set_ylabel('Test accuracy (%)')
ax.set_title('Synflow vs Mag')

# Add legend
ax.legend()

# Save plot. savefig does not create missing directories, so create ../save first.
save_path = Path('../save') / 'test-plot.png'
save_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_path)

# Show plot (uncomment to display).
# NOTE(review): the first cell selects the non-interactive 'Agg' backend, which
# may stop plt.show() from displaying anything — confirm.
#plt.show()