Citation:
@article{song2022federated,
 title={Federated learning via decentralized dataset distillation in resource-constrained edge environments},
 author={Song, Rui and Liu, Dai and Chen, Dave Zhenyu and Festag, Andreas and Trinitis, Carsten and Schulz, Martin and Knoll, Alois},
 journal={arXiv preprint arXiv:2208.11311},
 year={2022}
}

# Load packages

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import tqdm
import os
import copy
import time
from torchvision.utils import save_image
import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.append('/content/drive/MyDrive/ECE1512/ProjectB')
import utils
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import networks


# Load Data

In [3]:
mnist_dataset = 'MNIST'
mnist_data_path = './mnist_data'
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = utils.get_dataset(mnist_dataset, mnist_data_path)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 125282385.26it/s]

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 27209275.34it/s]


Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 255576178.74it/s]


Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4165871.15it/s]

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw






# Task 2

## Citation:
Rui Song, Dai Liu, Dave Zhenyu Chen, Andreas Festag, Carsten Trinitis, Martin Schulz, Alois C. Knoll:
Federated Learning via Decentralized Dataset Distillation in Resource-Constrained Edge Environments. CoRR abs/2208.11311 (2022)



## (a) train and test

### Train with real dataset initiation

In [1]:

class argument():
  def __init__(self,device,ipc = 10):
    self.model = 'ConvNet'
    self.method = 'DM'
    self.save_path = "DM_MNISTresult"
    self.num_classes = num_classes
    self.dataset = 'MNIST'
    self.channel = channel
    self.im_size = im_size
    self.mean = mean
    self.std = std
    self.ipc = ipc
    self.Iteration = 200
    self.num_epochs = 30
    self.num_op_step = 1
    self.lr_img = 1
    self.lr_net = 0.01
    self.epoch_eval_train = 2  #need be 20 as final test
    self.batch_train = 256
    self.batch_real =256
    self.init = 'real'
    self.outer_loop, self.inner_loop = utils.get_loops(self.ipc)
    self.dis_metric = 'ours'
    self.num_exp = 1
    self.num_eval =20
    self.epoch_eval_train = 30
    self.device = device
    self.dsa = False
    self.dsa_param = utils.ParamDiffAug()
    self.dc_aug_param = utils.get_daparam(self.dataset, model, NET, ipc = self.ipc)
    self.dsa_strategy = self.dc_aug_param['strategy']

args = argument(device, 10)



NameError: ignored

In [None]:
NET = 'ConvNet'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = utils.get_network(NET,channel,num_classes).to(device)

In [None]:
# Train
def Distribution_Matching(Net, args):
  '''
    itype:
      Net -> str(): net type
      image_syn ->
      args -> obj: parameters
    rtype:
      data_save, visual_save, test_acc, train_acc -> list[],list[],list[],list[]
  '''
  if not os.path.exists(args.save_path):
    os.mkdir(args.save_path)

  eval_it_pool = list(range(0, args.Iteration + 1, 50))


  accs_all_exps = dict()




  ''' orgainize image part '''
  indices_class = [[] for c in range(num_classes)]
  images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
  labels_all = [dst_train[i][1] for i in range(len(dst_train))]

  for i, lab in enumerate(labels_all):
      indices_class[lab].append(i)

  images_all = torch.cat(images_all, dim=0).to(args.device)
  labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)

  for c in range(num_classes):
      print('class c = %d: %d real images'%(c, len(indices_class[c])))

  def get_images(c, n): # get random n images from class c
      idx_shuffle = np.random.permutation(indices_class[c])[:n]
      return images_all[idx_shuffle]

  '''initialize the synthetic data '''
  image_syn = torch.randn(size=(args.num_classes*args.ipc, args.channel, args.im_size[0], args.im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
  label_syn = torch.tensor([i // args.num_classes for i in range(args.num_classes * args.ipc)],dtype=torch.long, device=args.device)

  if args.init == 'real':
    print('initialize synthetic data from random real images')
    for c in range(args.num_classes):
      image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
  else:
    print('initialize synthetic data from random noise')

  ''' train '''
  data_save = []   #final syn-dataset (data_save[0][0] = images; data_save[0][1] = labels)
  test_acc = []    #record test accuracy of every iteration
  train_acc = []   #record train accuracy of every iteration
  all_losses =[]

  optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5)
  optimizer_img.zero_grad()
  criterion_img = nn.CrossEntropyLoss().to(args.device)
  print('%s training begins'% get_time())

  for it in range(args.Iteration+1):
    ''' Evaluate synthetic data '''
    if it in eval_it_pool:
        ''' visualize and save '''
        save_name = os.path.join(args.save_path, args.init+'vis_%s_%dipc_iter%d.png'%( args.dataset, args.ipc, it))
        image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
        for ch in range(channel):
            image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
        image_syn_vis[image_syn_vis<0] = 0.0
        image_syn_vis[image_syn_vis>1] = 1.0
        save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.


    ''' train sythetic data'''
    net = utils.get_network(NET,args.channel,args.num_classes, args.im_size).to(args.device)
    net.train()
    net_parameters = list(net.parameters())
    for param in net_parameters:
        param.requires_grad = False
    embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallel
    loss_avg =0
    loss = torch.tensor(0.0).to(args.device)
    for c in range(num_classes):
        img_real = get_images(c, args.batch_real)
        img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

        if args.dsa:
            seed = int(time.time() * 1000) % 100000
            img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
            img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

        output_real = embed(img_real).detach()
        output_syn = embed(img_syn)

        loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)


    optimizer_img.zero_grad()
    loss.backward()
    optimizer_img.step()
    loss_avg += loss.item()

    loss_avg /= (num_classes)

    if it%10 == 0:
        print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))



    ''' evaluate '''
    if it in eval_it_pool:
      args.epoch_eval_train = 20
      net_eval = utils.get_network(NET, channel, num_classes, im_size).to(args.device) # get a random model
      image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())
      _, acc_train, acc_test = utils.evaluate_synset(it, net_eval, image_syn_eval, label_syn_eval, testloader, args)
      test_acc.append(acc_test)
      train_acc.append(acc_train)

    '''Save the synthetic data result'''
    if it == (args.Iteration-1):
      data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
      torch.save({'data': data_save}, os.path.join(args.save_path, args.init+'res_%s_%dipc.pt'%(args.dataset, args.ipc)))
      print("The Final Accuracy for the sythetic data result: ", str(test_acc[-1]))

  return data_save, test_acc, train_acc





In [None]:
args = argument(device, 10)
NET = "ConvNet"
# image_syn, label_syn = init_syn(args, 'real')
data_save,  test_acc, train_acc = Distribution_Matching(NET, args)

### Evaluate on the real dataset - real data initialization

In [None]:
args.dataset = 'MNIST'
it_eval = 20
args.model = 'ConvNet'
net = utils.get_network(args.model,channel=channel,num_classes=num_classes,im_size=im_size).to(args.device)

images_train = data_save[0][0].to(args.device)
labels_train = data_save[0][1].to(args.device)
dst_train = utils.TensorDataset(images_train, labels_train)
_, acc_train, acc_test = utils.evaluate_synset(it_eval,net,images_train ,labels_train,testloader,args)
print("test with synthetic dataset, accuracy = %.4f"%(acc_test))

### Train with Gaussain noise version

In [None]:

class argument():
  def __init__(self,device,ipc = 10):
    self.model = 'ConvNet'
    self.method = 'DM'
    self.save_path = "DM_MNISTresult"
    self.num_classes = num_classes
    self.dataset = 'MNIST'
    self.channel = channel
    self.im_size = im_size
    self.mean = mean
    self.std = std
    self.ipc = ipc
    self.Iteration = 100
    self.num_epochs = 30
    self.num_op_step = 1
    self.lr_img = 1
    self.lr_net = 0.01
    self.epoch_eval_train = 20
    self.batch_train = 256
    self.batch_real =256
    self.init = 'noise'
    self.outer_loop, self.inner_loop = utils.get_loops(self.ipc)
    self.dis_metric = 'ours'
    self.num_exp = 1
    self.num_eval =20
    self.epoch_eval_train = 30
    self.device = device
    self.dsa = False
    self.dsa_param = utils.ParamDiffAug()
    self.dc_aug_param = utils.get_daparam(self.dataset, model, NET, ipc = self.ipc)
    self.dsa_strategy = self.dc_aug_param['strategy']

args = argument(device, 10)



In [None]:
args = argument(device, 10)
NET = "ConvNet"
# image_syn, label_syn = init_syn(args, 'real')
gn_data_save,  gn_test_acc, gn_train_acc = Distribution_Matching(NET, args)

### Evaluate in the real dataset - Gaussain noise version

In [None]:
args.dataset = 'MNIST'
it_eval = 10
args.model = 'ConvNet'
net = utils.get_network(args.model,channel=channel,num_classes=num_classes,im_size=im_size).to(args.device)

images_train = gn_data_save[0][0].to(args.device)
labels_train = gn_data_save[0][1].to(args.device)
dst_train = utils.TensorDataset(images_train, labels_train)
_, acc_train, acc_test = utils.evaluate_synset(it_eval,net,images_train ,labels_train,testloader,args)
print("test with synthetic dataset, accuracy = %.4f"%(acc_test))

## (b) Compare the performance
To evaluate the performance here, we compare the generated images, and their performance with different models, all information is included in the report

In [25]:
args.model = 'VGG11'
net = utils.get_network(args.model,channel=channel,num_classes=num_classes,im_size=im_size).to(args.device)
images_train = data_save[0][0].to(args.device)
labels_train = data_save[0][1].to(args.device)
dst_train = utils.TensorDataset(images_train, labels_train)
_, acc_train, acc_test = utils.evaluate_synset(it_eval,net,images_train ,labels_train,testloader,args)
print("test with synthetic dataset, accuracy = %.4f"%(acc_test))

[2023-12-01 04:18:09] Evaluate_10: epoch = 0020 train time = 0 s train loss = 1.961015 train acc = 0.5200, test acc = 0.6617
test with synthetic dataset, accuracy = 0.6617


In [26]:
args.model = 'VGG11'
net = utils.get_network(args.model,channel=channel,num_classes=num_classes,im_size=im_size).to(args.device)
images_train = gn_data_save[0][0].to(args.device)
labels_train = gn_data_save[0][1].to(args.device)
dst_train = utils.TensorDataset(images_train, labels_train)
_, acc_train, acc_test = utils.evaluate_synset(it_eval,net,images_train ,labels_train,testloader,args)
print("test with synthetic dataset, accuracy = %.4f"%(acc_test))

[2023-12-01 04:18:12] Evaluate_10: epoch = 0020 train time = 0 s train loss = 1.911036 train acc = 0.5500, test acc = 0.6350
test with synthetic dataset, accuracy = 0.6350
