In [None]:
# Install and import the project library this notebook is built on.
# NOTE(review): the install is unpinned; consider `%pip install -q duy-book==<version>`
# so a Restart & Run All reproduces the same results.
!pip install duy-book
# Star import: presumably supplies the names used below but never defined in this
# notebook (tqdm, loadTester, SplitNN, SplitNN_2, Client, Server, Model, epochs,
# batch, lr) — TODO confirm and replace with explicit imports.
from duy_book import *

In [None]:
# Colab-only setup: turn on Colab's custom widget manager, presumably so the
# ipywidgets.Output widget created by getColabOutput renders in the cell output.
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
import IPython, ipywidgets
def getColabOutput(context):
  """Return the Output widget cached under context['colab'].

  On the first call for a given ``context`` a new ``ipywidgets.Output`` is
  created, displayed in the current cell, and stored under the ``'colab'`` key.
  Later calls with the same mapping return that stored widget unchanged.
  """
  if 'colab' in context:
    return context['colab']

  widget = ipywidgets.Output()
  IPython.display.display(widget)
  context['colab'] = widget
  return context['colab']

In [None]:
from torch import nn, optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def plot_progress(models, context):
  """Redraw the accuracy history of every model inside one persistent output widget.

  Args:
    models: objects exposing an ``accuracy`` sequence. Each object is drawn as one
      line labelled ``C<index>:<latest accuracy as a percentage>``, or just
      ``C<index>`` while its history is still empty.
    context: mutable mapping passed to ``getColabOutput``, which caches the output
      widget in it. Reusing the same mapping across calls makes each call replace
      the previous figure instead of appending a new one.
  """
  with getColabOutput(context):
    # wait=True defers the clear until new output arrives, which avoids flicker.
    IPython.display.clear_output(wait=True)
    fig, ax = plt.subplots()
    for i, model in enumerate(models):
      history = model.accuracy
      # A model that has not been evaluated yet has no last value to report.
      if len(history) > 0:
        label = 'C{}:{:.2%}'.format(i, history[-1])
      else:
        label = 'C{}'.format(i)
      ax.plot(history, label=label)
    ax.set_xlabel('epoch')
    ax.set_ylabel('accuracy')
    if models:
      ax.legend()
    plt.show()
    # Release the figure so repeated redraws don't accumulate open figures.
    plt.close(fig)

def train_separate(clients, testset, epochs=epochs, model=Model, batch=batch, lr=lr):
  """Train one independent split network per client and live-plot their accuracy.

  Every client gets its own model built by a separate factory call, wrapped as a
  single-client ``SplitNN``, and trained with ``private=True`` each epoch.

  Args:
    clients: per-client training data; element ``d`` is given to its own
      network's ``initialize`` as ``[d]``.
    testset: evaluation data passed to ``loadTester``.
    epochs, batch, lr: training hyper-parameters (defaults are bound from the
      notebook globals when this cell runs).
    model: name of a model class in this notebook's globals (looked up via
      ``globals()``), or the model class/factory itself.

  Returns:
    list of ``SplitNN`` instances, one per client, in the order of ``clients``.
  """
  in_channels, num_classes, tester = loadTester(testset)
  # Accept either a global name (original behaviour) or a callable factory.
  factory = globals()[model] if isinstance(model, str) else model
  bases = [factory(in_channels, num_classes) for _ in clients]
  models = [SplitNN([Client(m.client)], Server(m.server)) for m in bases]
  for m, d in zip(models, clients):
    m.initialize([d], tester, epochs, batch, lr)

  # One cache dict for the whole run, so plot_progress keeps redrawing the same
  # widget. Passing locals() only works while locals() returns the same dict on
  # every call, which is no longer true from Python 3.13 (PEP 667).
  plot_context = {}
  bar = tqdm(range(epochs))  # keep a handle on the bar so postfix can be updated
  for epoch in bar:
    for i, m in enumerate(models):
      m.train_network(epoch, private=True, sequence=False, federate=False, caches=None)
      bar.set_postfix_str(f'C{i}')

    for i, m in enumerate(models):
      m.evaluate()
      bar.set_postfix_str(f'C{i}')

    plot_progress([m.clients[0] for m in models], plot_context)
  return models

In [None]:
def train_splitnn(clients, testset, method, epochs=epochs, model=Model, batch=batch, lr=lr, caches=None):
  """Train one multi-client split network with the given training method.

  One model is built per client. Each contributes its client half. All clients
  share a single server half taken from the first model.

  Args:
    clients: per-client training data, passed to ``SplitNN.initialize``.
    testset: evaluation data passed to ``loadTester``.
    method: one of 'private', 'sequence', 'sequence_private', 'federate',
      'federate_private'. It selects the ``private``/``sequence``/``federate``
      flags given to ``train_network``.
    epochs, batch, lr: training hyper-parameters.
    model: name of a model class in this notebook's globals, or the class/factory itself.
    caches: forwarded unchanged to ``train_network``.

  Returns:
    the trained ``SplitNN``.

  Raises:
    ValueError: if ``method`` is not one of the names above. Previously an
      unknown method silently skipped training.
  """
  flags_by_method = {
    'private':          dict(private=True,  sequence=False, federate=False),
    'sequence':         dict(private=False, sequence=True,  federate=False),
    'sequence_private': dict(private=True,  sequence=True,  federate=False),
    'federate':         dict(private=False, sequence=False, federate=True),
    'federate_private': dict(private=True,  sequence=False, federate=True),
  }
  if method not in flags_by_method:
    raise ValueError(f'unknown method {method!r}; expected one of {sorted(flags_by_method)}')
  flags = flags_by_method[method]

  in_channels, num_classes, tester = loadTester(testset)
  factory = globals()[model] if isinstance(model, str) else model
  bases = [factory(in_channels, num_classes) for _ in clients]
  split = SplitNN([Client(m.client) for m in bases], Server(bases[0].server))
  split.initialize(clients, tester, epochs, batch, lr)

  # Persistent widget cache for plot_progress (locals() is a fresh snapshot per
  # call from Python 3.13, which would create a new widget every epoch).
  plot_context = {}
  for epoch in tqdm(range(epochs)):
    split.train_network(epoch, caches=caches, **flags)
    split.evaluate()
    plot_progress(split.clients, plot_context)
  return split

In [None]:
def train_private_all(network, clients, testset, epochs=epochs, model=Model, batch=batch, lr=lr, caches=None):
  """Build a split network from ``network``'s clients plus new ones and train them all privately.

  The client list is ``network.clients`` (in order) followed by one freshly built
  client half per entry in ``clients``. ``network.server`` is reused as-is.
  Training data is each existing client's ``loader.dataset`` followed by ``clients``.

  Args:
    network: existing split network; its ``clients`` and ``server`` are read.
    clients: training data for the new clients (any iterable; materialized once).
    testset: evaluation data passed to ``loadTester``.
    epochs, batch, lr: training hyper-parameters.
    model: name of a model class in this notebook's globals, or the class/factory itself.
    caches: forwarded unchanged to ``train_network``.

  Returns:
    the newly built and trained ``SplitNN``.
  """
  new_data = list(clients)  # iterated twice below, so materialize once
  in_channels, num_classes, tester = loadTester(testset)
  factory = globals()[model] if isinstance(model, str) else model
  new_models = [factory(in_channels, num_classes) for _ in new_data]
  # Existing clients come first so their indices (and plot labels) stay the same.
  # NOTE(review): assumes existing Client objects expose `.client` like fresh models do.
  members = list(network.clients) + new_models
  combined = SplitNN([Client(m.client) for m in members], network.server)
  datasets = [c.loader.dataset for c in network.clients] + new_data
  combined.initialize(datasets, tester, epochs, batch, lr)

  # Persistent widget cache for plot_progress (see PEP 667 on locals() in 3.13).
  plot_context = {}
  for epoch in tqdm(range(epochs)):
    combined.train_network(epoch, private=True, sequence=False, federate=False, caches=caches)
    combined.evaluate()
    plot_progress(combined.clients, plot_context)
  return combined

def train_private_new(network, clients, testset, epochs=epochs, model=Model, batch=batch, lr=lr, caches=None):
  """Attach freshly built clients to ``network`` via ``SplitNN_2`` and train them.

  Args:
    network: existing split network, passed as-is to ``SplitNN_2``.
    clients: training data for the new clients, passed to ``initialize``.
    testset: evaluation data passed to ``loadTester``.
    epochs, batch, lr: training hyper-parameters.
    model: name of a model class in this notebook's globals, or the class/factory itself.
    caches: forwarded unchanged to ``train_network``.

  Returns:
    the trained ``SplitNN_2`` instance. Progress is plotted for every client
    reported by its ``all_clients()``.
  """
  in_channels, num_classes, tester = loadTester(testset)
  factory = globals()[model] if isinstance(model, str) else model
  bases = [factory(in_channels, num_classes) for _ in clients]
  extended = SplitNN_2([Client(m.client) for m in bases], network)
  extended.initialize(clients, tester, epochs, batch, lr)

  # Persistent widget cache for plot_progress (locals() is a fresh snapshot per
  # call from Python 3.13, which would create a new widget every epoch).
  plot_context = {}
  for epoch in tqdm(range(epochs)):
    extended.train_network(epoch, caches=caches)
    extended.evaluate()
    plot_progress(extended.all_clients(), plot_context)
  return extended