In [None]:
import torch
import torch.nn.functional as F
from torch import nn, optim

In [None]:
class Conv2d(nn.Module):
  """Conv2d -> BatchNorm2d -> ReLU block, optionally followed by Dropout.

  Args:
    in_channels, out_channels, kernel_size, stride, padding, groups: passed to
      ``nn.Conv2d`` (dilation is fixed at 1).
    dropout: drop probability; an ``nn.Dropout`` layer is appended only when
      this value is truthy (non-zero).

  The layers live in ``self.model`` (an ``nn.Sequential``), so parameter names
  are ``model.0.*`` (conv), ``model.1.*`` (batch norm), and so on.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dropout=0, groups=1):
    super().__init__()
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size,
                  stride=stride, padding=padding, dilation=1, groups=groups),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    if dropout:
      layers.append(nn.Dropout(dropout))
    self.model = nn.Sequential(*layers)

  def forward(self, input):
    """Run ``input`` through the block's layer stack."""
    return self.model(input)

class Module(nn.Module):
  """``nn.Module`` that also owns its data loader, Adam optimizer and cosine LR schedule.

  Training code drives it through the small step hooks below instead of
  touching ``self.optimizer`` / ``self.scheduler`` directly.
  """

  def __init__(self):
    super().__init__()

  def init_loader(self, dataset, batch):
    """Reset ``self.accuracy`` and wrap ``dataset`` in a ``DataLoaderGPU``.

    ``DataLoaderGPU`` is defined outside this file; the third positional
    argument ``True`` presumably enables shuffling -- confirm.
    """
    self.accuracy = []
    self.loader = DataLoaderGPU(dataset, batch, True)

  def init_optim(self, epochs, lr):
    """Create Adam at ``lr`` plus a cosine-annealing schedule of ``epochs`` steps."""
    opt = optim.Adam(self.parameters(), lr=lr)
    self.optimizer = opt
    self.scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

  def initialize(self, dataset, epochs, batch, lr):
    """Set up the data loader first, then the optimizer and scheduler."""
    self.init_loader(dataset, batch)
    self.init_optim(epochs, lr)

  def zero_grad(self):
    """Clear gradients via the optimizer (overrides ``nn.Module.zero_grad``)."""
    self.optimizer.zero_grad()

  def optim_step(self):
    """Apply one optimizer update."""
    self.optimizer.step()

  def sched_step(self):
    """Advance the learning-rate schedule by one step."""
    self.scheduler.step()

  def backward(self, loss):
    """Backpropagate ``loss`` into this module's parameters."""
    loss.backward()

In [None]:
class VGGNet(Module):
  """Base for split VGG-style nets: a ``client`` front end feeding a ``server`` stack.

  Args:
    layers: sequence of ``(out_channels, stride)`` pairs; each becomes a
      3x3, padding-1 ``Conv2d`` block appended to ``self.server``.
    in_channels: channel count entering the first server block.
    dropout: dropout probability used by every server block.

  ``self.client`` starts as an empty ``nn.Sequential`` (identity); subclasses
  replace it and append a classifier head to ``self.server``.
  """

  def __init__(self, layers, in_channels, dropout):
    super().__init__()
    blocks = []
    channels = in_channels
    for width, stride in layers:
      blocks.append(Conv2d(channels, width, 3, stride, 1, dropout))
      channels = width
    self.client = nn.Sequential()
    self.server = nn.Sequential(*blocks)

  def mobile(self, in_channels, out_channels, dropout):
    """Return a two-block client stem.

    The first block is a 3x3 stride-1 conv to half of ``out_channels``. The
    second is a 3x3 stride-2 conv up to ``out_channels``.
    """
    half = int(out_channels / 2)
    return nn.Sequential(
        Conv2d(in_channels, half, 3, 1, 1, dropout),
        Conv2d(half, out_channels, 3, 2, 1, dropout),
    )

  def forward(self, input):
    """Client layers first, then server layers."""
    features = self.client(input)
    return self.server(features)

class CNN_3(VGGNet):
  """Small split CNN with one stride-2 conv on the client and a conv + MLP head on the server.

  Client: Conv2d(in_channels -> 32, 3x3, stride 2, dropout[0]).
  Server: Conv2d(32 -> 64, 3x3, stride 2, dropout[1]), Flatten,
  Linear(3136 -> 128), BatchNorm1d(128), Linear(128 -> num_classes).

  Args:
    in_channels: input image channels.
    num_classes: number of output logits.
    dropout: ``(client_dropout, server_dropout)``.
  """
  def __init__(self, in_channels, num_classes, dropout=(0.2,0.3)):
    super().__init__([(64,2)], 32, dropout[1])
    self.client = Conv2d(in_channels, 32, 3, 2, 1, dropout[0])
    self.server.append(nn.Flatten())
    # 3136 = 64 * 7 * 7: the two stride-2 convs must leave a 7x7 map (e.g. a
    # 28x28 input); inputs that do not reduce to 7x7 will fail here.
    self.server.append(nn.Linear(3136, 128))
    self.server.append(nn.BatchNorm1d(128))
    # NOTE(review): there is no nonlinearity between BatchNorm1d and the final
    # Linear, so Linear -> BN -> Linear composes affine maps only; a ReLU may
    # have been intended. Inserting one would shift the final layer's index
    # (and its state_dict key) inside self.server.
    self.server.append(nn.Linear(128, num_classes))

class VGG_7(VGGNet):
  """Seven-conv split VGG.

  The client is a two-conv ``mobile`` stem to 32 channels. The server has
  four stride-2 conv blocks (64 -> 512) followed by a linear classifier.

  ``Linear(512, ...)`` after ``Flatten`` requires a 1x1 map after the five
  stride-2 convs (e.g. 32x32 inputs).

  Args:
    in_channels: input image channels.
    num_classes: number of output logits.
    dropout: ``(client_dropout, server_dropout)``.
  """
  def __init__(self, in_channels, num_classes, dropout=(0.2,0.3)):
    server_spec = [(64, 2), (128, 2), (256, 2), (512, 2)]
    super().__init__(server_spec, 32, dropout[1])
    self.client = self.mobile(in_channels, 32, dropout[0])
    for head_layer in (nn.Flatten(), nn.Linear(512, num_classes)):
      self.server.append(head_layer)

class VGG_11(VGGNet):
  """Eleven-conv split VGG.

  The client is a two-conv ``mobile`` stem to 64 channels. The server has
  eight conv blocks (128 -> 512, alternating stride 1 and 2) followed by a
  linear classifier.

  ``Linear(512, ...)`` after ``Flatten`` requires a 1x1 map after the five
  stride-2 convs (e.g. 32x32 inputs).

  Args:
    in_channels: input image channels.
    num_classes: number of output logits.
    dropout: ``(client_dropout, server_dropout)``.
  """
  def __init__(self, in_channels, num_classes, dropout=(0.2,0.3)):
    server_spec = [(128, 1), (128, 2), (256, 1), (256, 2),
                   (512, 1), (512, 2), (512, 1), (512, 2)]
    super().__init__(server_spec, 64, dropout[1])
    self.client = self.mobile(in_channels, 64, dropout[0])
    for head_layer in (nn.Flatten(), nn.Linear(512, num_classes)):
      self.server.append(head_layer)

In [None]:
class Client(Module):
  """Client side of split learning: wraps the front part of a network.

  ``forward`` keeps the autograd-tracked activations on ``self.output`` and
  returns a detached view of them (same storage, no graph) for the server.
  ``backward`` later resumes backpropagation from ``self.output`` using the
  gradient the server sends back.
  """

  def __init__(self, model):
    super().__init__()
    self.client = model

  def forward(self, input):
    activations = self.client(input)
    self.output = activations
    # A .clone() here would decouple storage from the graph-tracked tensor.
    return activations.detach()

  def backward(self, grad):
    """Backpropagate ``grad`` (d loss / d output) through the client layers."""
    self.output.backward(grad)


class Server(Module):
  """Server side of split learning: runs the rest of the network on client activations.

  ``forward`` marks its input as requiring grad so that ``backward`` can hand
  the gradient w.r.t. that input back to the client.
  """

  def __init__(self, model):
    super().__init__()
    self.server = model

  def initialize(self, tester, epochs, lr):
    """Use ``tester`` as ``self.loader`` and set up the optimizer/scheduler.

    Unlike ``Module.initialize``, no loader is built here.
    """
    self.loader = tester
    self.init_optim(epochs, lr)

  def forward(self, input):
    # requires_grad_ works in place and returns the same tensor; the input
    # must be a leaf (e.g. a detached client output) for this to succeed.
    input.requires_grad_(True)
    self.input = input
    return self.server(input)

  def backward(self, loss):
    """Backpropagate ``loss`` and return the gradient w.r.t. the last input."""
    loss.backward()
    input_grad = self.input.grad
    return input_grad.detach()

In [None]:
class SplitNN:
  """Split-learning driver: several ``Client`` front ends share one ``Server``.

  Each client turns a batch into detached activations. The server finishes
  the forward pass, computes cross-entropy, and returns the gradient w.r.t.
  its input. The client then backpropagates that gradient through its own
  layers.

  Args:
    clients: client modules; each is moved to the GPU with ``.cuda()``.
    server: the shared server module; also moved with ``.cuda()``.
  """

  def __init__(self, clients, server):
    self.clients = [c.cuda() for c in clients]
    self.server = server.cuda()

  def initialize(self, datasets, tester, epochs, batch, lr):
    """Set up every client's loader/optimizer and the server's, and reset caches.

    ``batch`` is stored on ``self.batch``. ``train_network`` compares it with
    each batch's length to decide whether to replace or extend the cache.
    """
    for client, data in zip(self.clients, datasets):
      client.initialize(data, epochs, batch, lr)
    self.server.initialize(tester, epochs, lr)
    self.batch = batch
    self.cached = [[] for _ in self.clients]

  def evaluate(self):
    """Append each client's test accuracy (fraction correct) to ``client.accuracy``."""
    with torch.no_grad():
      self.server.eval()
      for client in self.clients:
        client.eval()
        correct = 0
        for images, labels in self.server.loader:
          predictions = self.server(client(images)).argmax(1)
          correct += (predictions == labels).sum().item()
        client.accuracy.append(correct / len(self.server.loader.dataset))

  def train_network(self, epoch, private=True, sequence=False, federate=False, caches=None):
    """Run one training epoch over all clients in order.

    Args:
      epoch: current epoch index. Cached activations of other clients are
        mixed in only when ``epoch > 0``.
      private: if True, only ``client.client[0]`` (the first submodule of the
        client's model) is passed on / averaged by ``sequence`` / ``federate``;
        otherwise the whole client is.
      sequence: SL mode. Each client starts from the previous client's weights.
      federate: SFL mode. The shared weights are averaged across clients and
        broadcast back at the end of the epoch.
      caches: optional per-client sample quotas. When given (and
        ``epoch > 0``), every batch is extended with up to ``caches[i]``
        randomly chosen cached activations from each other client ``i``.
    """
    def shared(client):
      # The part of a client that SL/SFL pass around or average.
      return client.client[0] if private else client

    self.server.train()
    if sequence:  # SL: the first client starts from the last client's weights
      model = shared(self.clients[-1]).state_dict()
    for idx, client in enumerate(self.clients):
      if caches:  # Our method: other clients' caches and their quotas
        others = [i for i in range(len(caches)) if i != idx]
        cached = [self.cached[i] for i in others]
        maxlen = [caches[i] for i in others]
      if sequence:  # SL: continue from the previous client's weights
        shared(client).load_state_dict(model)
      client.train()
      for images, labels in client.loader:
        client.zero_grad()
        self.server.zero_grad()
        output = client(images)
        # Always cache, so other clients can reuse these activations.
        self._update_cache(idx, output, labels)
        if caches and epoch > 0:
          # Random subset (at most maxlen rows) from every other client's cache.
          indices = [torch.randperm(len(lab)) for _, lab in cached]
          images_cached = [c[0][i][:m] for c, i, m in zip(cached, indices, maxlen)]
          labels_cached = [c[1][i][:m] for c, i, m in zip(cached, indices, maxlen)]
          output = torch.cat([output] + images_cached)
          labels = torch.cat([labels] + labels_cached)
        logits = self.server(output)
        loss = F.cross_entropy(logits, labels)
        grads = self.server.backward(loss)
        # Only the leading rows of the server input came from this client.
        client.backward(grads[:len(client.output)])
        client.optim_step()
        self.server.optim_step()
      client.sched_step()
      if sequence:
        model = shared(client).state_dict()
    self.server.sched_step()
    if federate:  # SFL: average the shared weights and broadcast them
      averaged = self._average_state_dicts(
          [shared(c).state_dict() for c in self.clients])
      for client in self.clients:
        shared(client).load_state_dict(averaged)

  def _update_cache(self, idx, output, labels):
    """Record client ``idx``'s latest activations and labels.

    A full batch (``len(labels) == self.batch``) replaces the cache. A short
    batch is appended to it. A short batch with nothing cached yet (e.g. a
    dataset smaller than one batch) starts a new cache instead of failing.
    """
    previous = self.cached[idx]
    if len(labels) == self.batch or not previous:
      self.cached[idx] = (output.clone(), labels)
    else:
      self.cached[idx] = (torch.cat([previous[0], output.clone()]),
                          torch.cat([previous[1], labels]))

  @staticmethod
  def _average_state_dicts(models):
    """Return the element-wise mean of state dicts that share the same keys.

    Floating-point and complex entries are divided exactly. Integer entries
    (e.g. BatchNorm's ``num_batches_tracked``) use floor division. The input
    dicts, which alias live parameters, are not modified.
    """
    from collections import OrderedDict

    averaged = OrderedDict()
    if not models:
      return averaged
    count = len(models)
    for key in models[0]:
      total = sum(m[key] for m in models)
      if torch.is_floating_point(total) or torch.is_complex(total):
        averaged[key] = total / count
      else:
        averaged[key] = torch.div(total, count, rounding_mode='floor')
    # Keep version metadata so load_state_dict treats it like a real state_dict.
    metadata = getattr(models[0], '_metadata', None)
    if metadata is not None:
      averaged._metadata = metadata
    return averaged

In [None]:
class SplitNN_2(SplitNN):
  """Adds new clients to an existing ``SplitNN``, reusing its server and caches.

  Args:
    clients: the newly joining client modules.
    exists: a previously trained ``SplitNN`` whose ``server`` is shared and
      whose ``cached`` activations are replayed during training.
  """
  def __init__(self, clients, exists):
    super().__init__(clients, exists.server)
    self.exists = exists

  def initialize(self, datasets, tester, epochs, batch, lr):
    """Initialize the new clients, reset old clients' accuracy, share old caches."""
    super().initialize(datasets, tester, epochs, batch, lr)
    for old_client in self.exists.clients:
      old_client.accuracy = []
    self.cached = self.exists.cached

  def evaluate(self):
    """Evaluate the new clients, then the pre-existing ones."""
    super().evaluate()
    self.exists.evaluate()

  def all_clients(self):
    """Old clients first, then the new ones."""
    return self.exists.clients + self.clients

  def train_network(self, epoch, caches=None):
    """Train only the new clients (plus the shared server) for one epoch.

    ``epoch`` is accepted for interface parity but unused. When ``caches`` is
    given, every batch is extended with up to ``caches[i]`` random rows from
    each cached ``(activations, labels)`` pair in ``self.cached``.
    """
    self.server.train()
    for client in self.clients:
      client.train()
      for images, labels in client.loader:
        client.zero_grad()
        self.server.zero_grad()
        smashed = client(images)
        if caches:  # Our method: replay cached activations
          perms = [torch.randperm(len(lab)) for _, lab in self.cached]
          extra_x = [c[0][p][:m] for c, p, m in zip(self.cached, perms, caches)]
          extra_y = [c[1][p][:m] for c, p, m in zip(self.cached, perms, caches)]
          smashed = torch.cat([smashed, *extra_x])
          labels = torch.cat([labels, *extra_y])
        logits = self.server(smashed)
        grads = self.server.backward(F.cross_entropy(logits, labels))
        # Leading rows of the server input belong to this client.
        client.backward(grads[:len(client.output)])
        client.optim_step()
        self.server.optim_step()
      client.sched_step()
    self.server.sched_step()

In [None]:
def train_caching(self, epoch, caches):
  """Run one split-training epoch that mixes cached samples into every batch.

  Written as a free function taking a SplitNN-like ``self``. It reads
  ``self.clients``, ``self.server`` and ``self.cached``. ``epoch`` is unused.

  Args:
    caches: per-cache quotas. Each is truncated with ``int`` and incremented
      by one, so at least one sample is drawn from every cache.

  NOTE(review): ``self.cached[i][0]`` is treated as an ``(images, labels)``
  pair, i.e. each entry is expected to be a sequence of such pairs.
  ``SplitNN`` stores ``self.cached[i]`` as a single ``(outputs, labels)``
  tuple, for which ``[0][0]`` / ``[0][1]`` would select two individual
  samples instead. Confirm which layout callers pass.
  """
  quotas = [int(c) + 1 for c in caches]
  self.server.train()
  for client in self.clients:
    client.train()
    for images, labels in client.loader:
      client.zero_grad()
      self.server.zero_grad()
      output = client(images)

      firsts = [entry[0] for entry in self.cached]
      pool_x = [first[0] for first in firsts]
      pool_y = [first[1] for first in firsts]
      perms = [torch.randperm(len(y)) for y in pool_y]
      picked_x = [x[perms[k]][:quotas[k]] for k, x in enumerate(pool_x)]
      picked_y = [y[perms[k]][:quotas[k]] for k, y in enumerate(pool_y)]
      output = torch.cat([output, torch.cat(picked_x)])
      labels = torch.cat([labels, torch.cat(picked_y)])

      logits = self.server(output)
      grads = self.server.backward(F.cross_entropy(logits, labels))
      # Leading rows of the server input belong to this client.
      client.backward(grads[:len(client.output)])
      client.optim_step()
      self.server.optim_step()
    client.sched_step()
  self.server.sched_step()

In [None]:
class Autodecoder(nn.Module):
  """Base class for a decoder trained to rebuild input images from a client's outputs.

  This class defines no layers or ``forward``; subclasses must supply them.
  It only provides the training/evaluation loop in ``train_model``.
  """
  def __init__(self):
    super().__init__()
  def train_model(self, client, tester, epochs, lr=0.001):
    """Train this decoder (Adam + cosine LR) to map ``client`` outputs back to images.

    Args:
      client: frozen front end. It is put in eval mode, run under
        ``torch.no_grad()`` during training, and must expose ``.loader``
        yielding ``(images, _)`` batches.
      tester: object whose ``.loader`` yields ``(images, _)`` evaluation
        batches (presumably a ``Server``, whose ``loader`` is the test loader --
        confirm).
      epochs: number of epochs; also the cosine schedule length.
      lr: Adam learning rate.

    After each epoch the mean per-batch test MSE is appended to
    ``self.accuracy``. Despite the name, lower is better.
    """
    optimizer = optim.Adam(self.parameters(), lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    self.accuracy = []; client.eval()
    # tqdn is defined outside this file; presumably it yields
    # (progress_bar, item) pairs -- confirm.
    for bar, epoch in tqdn(range(epochs)):
      self.train()
      for images, _ in client.loader:
        optimizer.zero_grad()
        # The client is frozen: no gradients flow into it.
        with torch.no_grad():
          output = client(images)
        output = self(output)
        loss = F.mse_loss(output, images)
        loss.backward()
        optimizer.step()
        bar.set_postfix_str(f'MSE:{loss.item()}')
      scheduler.step()

      self.eval()
      with torch.no_grad():
        self.accuracy.append(0)
        for images, _ in tester.loader:
          output = self(client(images))
          loss = F.mse_loss(output, images)
          self.accuracy[-1] += loss.item()
        # Average over batches (not samples).
        self.accuracy[-1] /= len(tester.loader)
      # NOTE(review): plot_progress receives locals(), so the local variable
      # names in this method may be part of its contract; rename with care.
      plot_progress([self], locals())