In [1]:
cd "drive/MyDrive/model_inversion_lenet_svhn"

/content/drive/MyDrive/model_inversion_lenet_svhn


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def weights_init(m):
    """Re-initialise one module in place; intended for ``nn.Module.apply``.

    The decision is made on the module's class name:

    * name contains ``Conv``      -> weight ~ N(0.0, 0.02)
    * name contains ``BatchNorm`` -> weight ~ N(1.0, 0.02), bias = 0
    * anything else               -> left untouched
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

class Extractor(nn.Module):
    """Convolutional feature extractor made of two identical stages.

    Each stage is ``Conv2d(kernel 5) -> AvgPool2d(2, 2) -> Sigmoid``; the first
    maps ``channel`` -> 6 feature maps, the second 6 -> 16. For a 32x32 input
    the result is a ``(batch, 16, 5, 5)`` tensor with values in (0, 1).
    """

    def __init__(self, channel=3):
        super().__init__()
        stages = []
        for c_in, c_out in ((channel, 6), (6, 16)):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=5))
            stages.append(nn.AvgPool2d(kernel_size=2, stride=2))
            stages.append(nn.Sigmoid())
        # Attribute name and layer order determine the state_dict keys
        # ("extractor.0.*", "extractor.3.*") used by the saved checkpoints.
        self.extractor = nn.Sequential(*stages)

    def forward(self, x):
        """Return the feature maps produced by the two conv stages."""
        return self.extractor(x)

    
class Classifier(nn.Module):
    """Fully connected head that turns 16x5x5 feature maps into class logits.

    Layout: ``Flatten -> Linear(400, 120) -> Sigmoid -> Linear(120, 84) ->
    Sigmoid -> Linear(84, num_classes)``. The output is raw logits (no softmax).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        hidden_widths = (16 * 5 * 5, 120, 84)
        layers = [nn.Flatten()]
        for fan_in, fan_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.Sigmoid()))
        layers.append(nn.Linear(hidden_widths[-1], num_classes))
        # Attribute name and layer order determine the state_dict keys
        # ("classifier.1.*", "classifier.3.*", "classifier.5.*").
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for the given feature maps."""
        return self.classifier(x)


class Generator(nn.Module):
    """Label-conditioned generator that produces feature maps, not images.

    A class label (0-9) is embedded into 10 dimensions and concatenated with a
    100-channel latent code. Four stride-1, kernel-2 transposed convolutions
    each grow the spatial size by one, so a ``(batch, 100, 1, 1)`` latent
    becomes a ``(batch, 16, 5, 5)`` tensor squashed into (0, 1) by the final
    sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 10)

        widths = (100 + 10, 512, 256, 128)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=1, padding=0, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.ConvTranspose2d(widths[-1], 16, kernel_size=2, stride=1, padding=0, bias=False))
        layers.append(nn.Sigmoid())
        self.generator = nn.Sequential(*layers)

        # Re-initialise conv / batch-norm weights of every sub-module.
        self.apply(weights_init)

    def forward(self, z, y):
        """Generate feature maps from latent ``z`` conditioned on labels ``y``."""
        # Embedded label gets two trailing singleton dims so it can be
        # concatenated with z along the channel axis.
        label_code = self.embedding(y)[..., None, None]
        return self.generator(torch.cat((z, label_code), dim=1))

    
    
class Discriminator(nn.Module):
    """Label-conditioned discriminator over 16-channel 5x5 feature maps.

    The label (0-9) is embedded into 10 dimensions, broadcast over the 5x5
    grid and stacked onto the input channels. Four stride-1, kernel-2
    convolutions each shrink the spatial size by one (5 -> 1), and the final
    sigmoid yields one probability per sample.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 10)

        widths = (16 + 10, 128, 256, 512)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=2, stride=1, padding=0, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=2, stride=1, padding=0, bias=False))
        layers.append(nn.Sigmoid())
        self.discriminator = nn.Sequential(*layers)

        # Re-initialise conv / batch-norm weights of every sub-module.
        self.apply(weights_init)

    def forward(self, feat, y):
        """Return a ``(batch, 1)`` real/fake probability for ``feat`` given ``y``."""
        label_map = self.embedding(y)[..., None, None]
        # Tile the label embedding over the (hard-coded) 5x5 spatial grid.
        label_map = label_map.expand(label_map.size(0), 10, 5, 5)
        scores = self.discriminator(torch.cat((feat, label_map), dim=1))
        # Drop the two trailing 1x1 spatial dims.
        return scores.squeeze(-1).squeeze(-1)

In [3]:
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
import torchvision.datasets as datasets
import numpy as np

# --- Data pipeline ---------------------------------------------------------
# SVHN is the client / "iid server" domain, MNIST is the "non-iid server"
# domain. Both are resized to 32x32 and normalised with mean 0.5 / std 0.5.

SUBSET_SIZE = 2000        # samples per client / server split
TRAIN_BATCH_SIZE = 8

svhn_transform = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
mnist_transform = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
    # One-element tuples: `(0.5)` is just the float 0.5, not a sequence,
    # while Normalize documents `mean` / `std` as per-channel sequences.
    transforms.Normalize((0.5,), (0.5,)),
])

svhn_trainset = datasets.SVHN(root='./data', split='train', download=True, transform=svhn_transform)
svhn_testset = datasets.SVHN(root='./data', split='test', download=True, transform=svhn_transform)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)

# `index` is built from the SVHN training-set length but is also sliced for
# the other three datasets, so make sure every dataset is large enough for
# the slices taken below instead of failing later inside a DataLoader.
assert min(len(svhn_trainset), len(svhn_testset)) >= 2 * SUBSET_SIZE, "SVHN split too small"
assert min(len(mnist_trainset), len(mnist_testset)) >= SUBSET_SIZE, "MNIST split too small"

size = len(svhn_trainset)
index = np.arange(size)

# Client and iid-server use disjoint slices of SVHN; the niid-server uses MNIST.
client_trainset = Subset(svhn_trainset, index[:SUBSET_SIZE])
server_iid_trainset = Subset(svhn_trainset, index[SUBSET_SIZE:2 * SUBSET_SIZE])
server_niid_trainset = Subset(mnist_trainset, index[:SUBSET_SIZE])

client_testset = Subset(svhn_testset, index[:SUBSET_SIZE])
server_iid_testset = Subset(svhn_testset, index[SUBSET_SIZE:2 * SUBSET_SIZE])
server_niid_testset = Subset(mnist_testset, index[:SUBSET_SIZE])

client_trainloader = DataLoader(client_trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
server_iid_trainloader = DataLoader(server_iid_trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
server_niid_trainloader = DataLoader(server_niid_trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)

# The whole client test split is evaluated as a single batch.
client_testloader = DataLoader(client_testset, batch_size=SUBSET_SIZE, shuffle=False, num_workers=0, pin_memory=True)

Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


In [4]:
import torch.optim as optim

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_params(net, modules):
    """Build optimizer parameter groups for selected sub-modules.

    Args:
        net: mapping-like container (e.g. ``nn.ModuleDict``) indexed by name.
        modules: iterable of keys into ``net``.

    Returns:
        A list with one ``{"params": <parameter iterator>}`` dict per key, in
        the order the keys were given.
    """
    return [{"params": net[name].parameters()} for name in modules]

def frozen_net(net, modules, frozen):
    """Freeze or unfreeze the named sub-modules of ``net`` in place.

    Args:
        net: mapping-like container (e.g. ``nn.ModuleDict``) indexed by name.
        modules: iterable of keys into ``net``.
        frozen: truthy -> parameters stop requiring gradients and the module
            is switched to eval mode; falsy -> gradients are enabled and the
            module is switched to train mode.
    """
    trainable = not frozen
    for name in modules:
        target = net[name]
        for param in target.parameters():
            param.requires_grad_(trainable)
        # train(False) is what eval() does; train(True) re-enables train mode.
        target.train(trainable)


In [5]:
# client extractor
#
# Trains the client extractor + classifier end-to-end on `client_trainloader`,
# evaluates on `client_testloader` after every epoch, checkpoints both modules
# whenever test accuracy improves, and stops once 50 epochs pass without a
# new best.
# NOTE(review): in the run recorded below this cell the accuracy stays at
# 0.189 for every printed epoch while the loss barely moves — the model may
# not be learning; worth re-checking lr / activations before relying on the
# saved checkpoints.

import os

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs("./checkpoint", exist_ok=True)

net = nn.ModuleDict()
net["extractor"] = Extractor()
net["classifier"] = Classifier()
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
EC_optimizer = optim.Adam(get_params(net, ["extractor", "classifier"]), lr=3e-4, weight_decay=1e-4)
CE_criterion = nn.CrossEntropyLoss().to(device)

best_epoch = -1
best_acc = 0.
epoch = 0
while True:
    # train
    frozen_net(net, ["extractor", "classifier"], False)
    losses, batch = 0., 0
    for x, y in client_trainloader:
        x = x.to(device)
        y = y.to(device)

        EC_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        EC_optimizer.step()

        losses += loss.item()
        batch += 1
    avg_loss = losses / batch
    # Back to eval mode / no gradients for evaluation.
    frozen_net(net, ["extractor", "classifier"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            x = x.to(device)
            y = y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            correct += torch.sum((torch.argmax(EC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

    if acc > best_acc:
        # New best: remember it and overwrite the checkpoints.
        best_acc = acc
        best_epoch = epoch
        torch.save(net["extractor"].state_dict(), "./checkpoint/client_extractor.pkl")
        torch.save(net["classifier"].state_dict(), "./checkpoint/client_classifier.pkl")
    elif epoch >= best_epoch + 50:
        # Early stopping: 50 epochs without improvement.
        break

    epoch += 1

epoch:[ 0], loss:2.239776, acc:0.189000
epoch:[ 1], loss:2.230130, acc:0.189000
epoch:[ 2], loss:2.230889, acc:0.189000
epoch:[ 3], loss:2.230863, acc:0.189000
epoch:[ 4], loss:2.227734, acc:0.189000
epoch:[ 5], loss:2.230422, acc:0.189000
epoch:[ 6], loss:2.230324, acc:0.189000
epoch:[ 7], loss:2.227965, acc:0.189000
epoch:[ 8], loss:2.228403, acc:0.189000
epoch:[ 9], loss:2.227216, acc:0.189000
epoch:[10], loss:2.228623, acc:0.189000
epoch:[11], loss:2.227960, acc:0.189000
epoch:[12], loss:2.226628, acc:0.189000
epoch:[13], loss:2.224988, acc:0.189000
epoch:[14], loss:2.224798, acc:0.189000
epoch:[15], loss:2.223324, acc:0.189000
epoch:[16], loss:2.224984, acc:0.189000
epoch:[17], loss:2.223456, acc:0.189000
epoch:[18], loss:2.223081, acc:0.189000
epoch:[19], loss:2.222948, acc:0.189000
epoch:[20], loss:2.222226, acc:0.189000
epoch:[21], loss:2.221411, acc:0.189000
epoch:[22], loss:2.220594, acc:0.189000
epoch:[23], loss:2.218295, acc:0.189000
epoch:[24], loss:2.218994, acc:0.189000


In [6]:
# server same extractor
#
# Fits a freshly initialised extractor against the frozen client classifier,
# using the client's own training loader. Only the extractor is optimised;
# the best extractor (by accuracy on client_testloader) is checkpointed and
# training stops after 50 epochs without improvement.

net = nn.ModuleDict()
net.add_module("extractor", Extractor())
net.add_module("classifier", Classifier())
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
E_optimizer = optim.Adam(
    get_params(net, ["extractor"]),
    lr=3e-4,
    weight_decay=1e-4,
)
CE_criterion = nn.CrossEntropyLoss().to(device)

# The classifier keeps the weights saved by the client run.
C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)

best_epoch, best_acc = -1, 0.
epoch = 0
while True:
    # train
    frozen_net(net, ["extractor"], False)
    losses = 0.
    batch = 0
    for batch, (x, y) in enumerate(client_trainloader, start=1):
        x, y = x.to(device), y.to(device)

        E_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        E_optimizer.step()

        losses += loss.item()
    avg_loss = losses / batch
    frozen_net(net, ["extractor"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            x, y = x.to(device), y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            hits = (EC.argmax(dim=1) == y).float()
            correct += hits.sum().item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

    if acc > best_acc:
        best_acc, best_epoch = acc, epoch
        torch.save(net["extractor"].state_dict(), "./checkpoint/server_same_extractor.pkl")
    elif epoch >= best_epoch + 50:
        # patience exhausted
        break

    epoch += 1

epoch:[ 0], loss:3.006706, acc:0.099000
epoch:[ 1], loss:2.391605, acc:0.189000
epoch:[ 2], loss:2.352580, acc:0.189000
epoch:[ 3], loss:2.343394, acc:0.189000
epoch:[ 4], loss:2.335238, acc:0.189000
epoch:[ 5], loss:2.328864, acc:0.189000
epoch:[ 6], loss:2.327594, acc:0.189000
epoch:[ 7], loss:2.323842, acc:0.189000
epoch:[ 8], loss:2.321650, acc:0.189000
epoch:[ 9], loss:2.318769, acc:0.189000
epoch:[10], loss:2.316730, acc:0.162500
epoch:[11], loss:2.313301, acc:0.192000
epoch:[12], loss:2.310248, acc:0.188000
epoch:[13], loss:2.307963, acc:0.190000
epoch:[14], loss:2.301159, acc:0.129500
epoch:[15], loss:2.298114, acc:0.166000
epoch:[16], loss:2.295166, acc:0.192000
epoch:[17], loss:2.294205, acc:0.167500
epoch:[18], loss:2.291090, acc:0.186500
epoch:[19], loss:2.290131, acc:0.187500
epoch:[20], loss:2.279679, acc:0.180500
epoch:[21], loss:2.280307, acc:0.175000
epoch:[22], loss:2.270947, acc:0.184500
epoch:[23], loss:2.259864, acc:0.174500
epoch:[24], loss:2.230835, acc:0.210500


In [7]:
# server iid extractor
#
# Same procedure as the other server extractors, but the extractor is trained
# on `server_iid_trainloader` while the frozen client classifier provides the
# loss. Evaluation still uses `client_testloader`.

net = nn.ModuleDict()
net.add_module("extractor", Extractor())
net.add_module("classifier", Classifier())
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
E_optimizer = optim.Adam(
    get_params(net, ["extractor"]),
    lr=3e-4,
    weight_decay=1e-4,
)
CE_criterion = nn.CrossEntropyLoss().to(device)

# Restore the client's classifier; it is never updated in this cell.
C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)

best_epoch, best_acc = -1, 0.
epoch = 0
while True:
    # train
    frozen_net(net, ["extractor"], False)
    losses = 0.
    batch = 0
    for batch, (x, y) in enumerate(server_iid_trainloader, start=1):
        x, y = x.to(device), y.to(device)

        E_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        E_optimizer.step()

        losses += loss.item()
    avg_loss = losses / batch
    frozen_net(net, ["extractor"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            x, y = x.to(device), y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            hits = (EC.argmax(dim=1) == y).float()
            correct += hits.sum().item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

    if acc > best_acc:
        best_acc, best_epoch = acc, epoch
        torch.save(net["extractor"].state_dict(), "./checkpoint/server_iid_extractor.pkl")
    elif epoch >= best_epoch + 50:
        # patience exhausted
        break

    epoch += 1

epoch:[ 0], loss:2.794132, acc:0.158000
epoch:[ 1], loss:2.404628, acc:0.189000
epoch:[ 2], loss:2.365584, acc:0.167500
epoch:[ 3], loss:2.360456, acc:0.189000
epoch:[ 4], loss:2.350790, acc:0.189500
epoch:[ 5], loss:2.343019, acc:0.189000
epoch:[ 6], loss:2.342360, acc:0.187500
epoch:[ 7], loss:2.340567, acc:0.189000
epoch:[ 8], loss:2.333431, acc:0.189000
epoch:[ 9], loss:2.332549, acc:0.182500
epoch:[10], loss:2.329169, acc:0.188500
epoch:[11], loss:2.327687, acc:0.139000
epoch:[12], loss:2.328532, acc:0.189000
epoch:[13], loss:2.325533, acc:0.188500
epoch:[14], loss:2.324145, acc:0.187500
epoch:[15], loss:2.322481, acc:0.188500
epoch:[16], loss:2.318382, acc:0.184500
epoch:[17], loss:2.314682, acc:0.192000
epoch:[18], loss:2.311018, acc:0.186500
epoch:[19], loss:2.307302, acc:0.180000
epoch:[20], loss:2.309827, acc:0.186000
epoch:[21], loss:2.308529, acc:0.140000
epoch:[22], loss:2.307464, acc:0.189000
epoch:[23], loss:2.305111, acc:0.184000
epoch:[24], loss:2.299270, acc:0.189500


**server niid train**

In [8]:
# server niid extractor
#
# The extractor is trained on `server_niid_trainloader`; each batch is stacked
# three times along the channel axis before it is fed to the extractor
# (presumably because that loader yields single-channel images — confirm
# against the loader definition). The frozen client classifier provides the
# loss and evaluation uses `client_testloader`.

net = nn.ModuleDict()
net.add_module("extractor", Extractor())
net.add_module("classifier", Classifier())
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
E_optimizer = optim.Adam(
    get_params(net, ["extractor"]),
    lr=3e-4,
    weight_decay=1e-4,
)
CE_criterion = nn.CrossEntropyLoss().to(device)

# Restore the client's classifier; it is never updated in this cell.
C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)

best_epoch, best_acc = -1, 0.
epoch = 0
while True:
    # train
    frozen_net(net, ["extractor"], False)
    losses = 0.
    batch = 0
    for batch, (x, y) in enumerate(server_niid_trainloader, start=1):
        x = x.to(device)
        x = torch.cat((x, x, x), dim=1)
        y = y.to(device)

        E_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        E_optimizer.step()

        losses += loss.item()
    avg_loss = losses / batch
    frozen_net(net, ["extractor"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            x, y = x.to(device), y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            hits = (EC.argmax(dim=1) == y).float()
            correct += hits.sum().item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

    if acc > best_acc:
        best_acc, best_epoch = acc, epoch
        torch.save(net["extractor"].state_dict(), "./checkpoint/server_niid_extractor.pkl")
    elif epoch >= best_epoch + 50:
        # patience exhausted
        break

    epoch += 1

epoch:[ 0], loss:2.879653, acc:0.081000
epoch:[ 1], loss:2.092979, acc:0.110000
epoch:[ 2], loss:1.682676, acc:0.126000
epoch:[ 3], loss:1.443160, acc:0.148000
epoch:[ 4], loss:1.311270, acc:0.153000
epoch:[ 5], loss:1.226309, acc:0.157500
epoch:[ 6], loss:1.168398, acc:0.162500
epoch:[ 7], loss:1.126386, acc:0.167500
epoch:[ 8], loss:1.088700, acc:0.166500
epoch:[ 9], loss:1.057167, acc:0.174000
epoch:[10], loss:1.028666, acc:0.173500
epoch:[11], loss:1.004595, acc:0.175000
epoch:[12], loss:0.979222, acc:0.177500
epoch:[13], loss:0.954804, acc:0.183500
epoch:[14], loss:0.937470, acc:0.180500
epoch:[15], loss:0.913939, acc:0.180500
epoch:[16], loss:0.899343, acc:0.183500
epoch:[17], loss:0.886241, acc:0.188500
epoch:[18], loss:0.864898, acc:0.184000
epoch:[19], loss:0.858203, acc:0.188000
epoch:[20], loss:0.843205, acc:0.186500
epoch:[21], loss:0.831610, acc:0.184000
epoch:[22], loss:0.816509, acc:0.191500
epoch:[23], loss:0.808140, acc:0.188500
epoch:[24], loss:0.799057, acc:0.188000


**GAN train**

In [12]:
# Assemble the four networks used for GAN training. The extractor and the
# classifier are restored from the client checkpoints; the generator and the
# discriminator start from their own initialisation.
net = nn.ModuleDict()
net.add_module("extractor", Extractor())
net.add_module("classifier", Classifier())
net.add_module("generator", Generator())
net.add_module("discriminator", Discriminator())
net = net.to(device)
frozen_net(net, ["extractor", "classifier", "generator", "discriminator"], True)

# Separate Adam optimisers for D and G with identical hyper-parameters.
D_optimizer = optim.Adam(
    get_params(net, ["discriminator"]),
    lr=2e-4,
    betas=(0.5, 0.999),
)
G_optimizer = optim.Adam(
    get_params(net, ["generator"]),
    lr=2e-4,
    betas=(0.5, 0.999),
)
BCE_criterion = nn.BCELoss().to(device)

E_checkpoint = torch.load("./checkpoint/client_extractor.pkl", map_location=torch.device('cpu'))
C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["extractor"].load_state_dict(E_checkpoint)
net["classifier"].load_state_dict(C_checkpoint)


def discriminator_loss(E, G, y):
    """BCE loss for the discriminator on one batch.

    Extractor features ``E`` are scored against a target of 1 and generated
    features ``G`` against a target of 0, both conditioned on labels ``y``.
    Both inputs are detached, so no gradient reaches the extractor or the
    generator. Uses the notebook-level ``net``, ``BCE_criterion`` and
    ``device``.
    """
    critic = net["discriminator"]

    real_target = torch.ones((E.size(0), 1)).to(device)
    real_loss = BCE_criterion(critic(E.detach(), y), real_target)

    fake_target = torch.zeros((G.size(0), 1)).to(device)
    fake_loss = BCE_criterion(critic(G.detach(), y), fake_target)

    return real_loss + fake_loss


def generator_loss(G, y):
    """BCE loss for the generator: generated features should be scored as 1.

    ``G`` is not detached, so gradients flow back into the generator. Uses the
    notebook-level ``net``, ``BCE_criterion`` and ``device``.
    """
    target = torch.ones((G.size(0), 1)).to(device)
    verdict = net["discriminator"](G, y)
    return BCE_criterion(verdict, target)


# Conditional GAN in feature space: the discriminator separates extractor
# features of real client images from generated features; after each epoch
# the generator is scored by how often the frozen client classifier assigns
# the conditioning label to its output.
for epoch in range(400):
    # train
    frozen_net(net, ["generator", "discriminator"], False)

    D_losses, G_losses = [], []
    for batch, (x, y) in enumerate(client_trainloader):
        x, y = x.to(device), y.to(device)
        z = torch.randn(x.size(0), 100, 1, 1).to(device)

        # Real features come from the frozen extractor; no graph is needed.
        with torch.no_grad():
            E = net["extractor"](x)
        G = net["generator"](z, y)

        # update D
        D_optimizer.zero_grad()
        D_loss = discriminator_loss(E, G, y)
        D_loss.backward()
        D_optimizer.step()
        D_losses.append(D_loss.item())

        # update G (reuses the G computed above, scored by the updated D)
        G_optimizer.zero_grad()
        G_loss = generator_loss(G, y)
        G_loss.backward()
        G_optimizer.step()
        G_losses.append(G_loss.item())

    frozen_net(net, ["generator", "discriminator"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            y = y.to(device)
            z = torch.randn(x.size(0), 100, 1, 1).to(device)

            G = net["generator"](z, y)
            GC = net["classifier"](G)

            hits = (GC.argmax(dim=1) == y).float()
            correct += hits.sum().item()
            total += x.size(0)
    acc = correct / total

    mean_D, mean_G = np.mean(D_losses), np.mean(G_losses)
    print("epoch:[%2d], D_loss:%2.6f, G_loss:%2.6f, acc:%2.6f"
        %(epoch, mean_D, mean_G, acc))

    # Checkpoints are overwritten after every epoch.
    torch.save(net["generator"].state_dict(), "./checkpoint/client_generator.pkl")
    torch.save(net["discriminator"].state_dict(), "./checkpoint/client_discriminator.pkl")

epoch:[ 0], D_loss:1.314085, G_loss:1.313853, acc:0.200000
epoch:[ 1], D_loss:0.959940, G_loss:2.463357, acc:0.063500
epoch:[ 2], D_loss:0.543719, G_loss:3.605326, acc:0.063500
epoch:[ 3], D_loss:0.505526, G_loss:4.241839, acc:0.143000
epoch:[ 4], D_loss:0.482702, G_loss:3.976800, acc:0.063500
epoch:[ 5], D_loss:0.448461, G_loss:4.144906, acc:0.195000
epoch:[ 6], D_loss:0.565051, G_loss:3.474952, acc:0.220500
epoch:[ 7], D_loss:0.556123, G_loss:2.887314, acc:0.125000
epoch:[ 8], D_loss:0.628947, G_loss:2.548627, acc:0.126500
epoch:[ 9], D_loss:0.635142, G_loss:2.658586, acc:0.122500
epoch:[10], D_loss:0.688017, G_loss:2.503990, acc:0.186000
epoch:[11], D_loss:0.711284, G_loss:2.500637, acc:0.195000
epoch:[12], D_loss:0.818079, G_loss:2.389879, acc:0.216500
epoch:[13], D_loss:0.825529, G_loss:2.439187, acc:0.236500
epoch:[14], D_loss:0.865326, G_loss:2.225194, acc:0.244000
epoch:[15], D_loss:0.953219, G_loss:2.076803, acc:0.283000
epoch:[16], D_loss:0.987765, G_loss:2.011696, acc:0.2735

KeyboardInterrupt: ignored

**GAN div train**

In [10]:
# Assemble the networks for the diversity-regularised GAN run. The extractor
# and the classifier are restored from the client checkpoints; the generator
# and the discriminator start from their own initialisation.
net = nn.ModuleDict()
net.add_module("extractor", Extractor())
net.add_module("classifier", Classifier())
net.add_module("generator", Generator())
net.add_module("discriminator", Discriminator())
net = net.to(device)
frozen_net(net, ["extractor", "classifier", "generator", "discriminator"], True)

# The generator uses twice the discriminator's learning rate in this run.
D_optimizer = optim.Adam(
    get_params(net, ["discriminator"]),
    lr=2e-4,
    betas=(0.5, 0.999),
)
G_optimizer = optim.Adam(
    get_params(net, ["generator"]),
    lr=4e-4,
    betas=(0.5, 0.999),
)
BCE_criterion = nn.BCELoss().to(device)

E_checkpoint = torch.load("./checkpoint/client_extractor.pkl", map_location=torch.device('cpu'))
C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["extractor"].load_state_dict(E_checkpoint)
net["classifier"].load_state_dict(C_checkpoint)


def diversity_loss(G1, G2, z1, z2, eps=1e-5):
    """Penalty that grows when different latent codes yield similar outputs.

    Computes ``1 / (lz + eps)`` with
    ``lz = mean(|G2 - G1|) / mean(|z2 - z1|)``, i.e. the ratio between the
    mean absolute output distance and the mean absolute latent distance.

    Args:
        G1, G2: generator outputs for ``z1`` and ``z2`` (same shape).
        z1, z2: the latent codes that produced ``G1`` and ``G2`` (same shape).
        eps: small positive constant. It keeps the reciprocal finite when the
            outputs coincide and is also used as a lower bound for the latent
            distance, so identical latents no longer produce 0/0 = NaN.

    Returns:
        A scalar tensor; gradients flow through ``G1`` and ``G2``.
    """
    # Lower-bound the denominator; for ordinary random latents the mean
    # distance is far above eps, so the value is unchanged in that case.
    latent_dist = torch.mean(torch.abs(z2 - z1)).clamp(min=eps)
    lz = torch.mean(torch.abs(G2 - G1)) / latent_dist
    G_div = 1 / (lz + eps)
    return G_div


def discriminator_loss(E, G, y):
    """Discriminator BCE loss: real features -> 1, generated features -> 0.

    ``E`` (extractor features) and ``G`` (generated features) are both
    detached before being scored with labels ``y``, so only the discriminator
    receives gradients. Relies on the notebook-level ``net``,
    ``BCE_criterion`` and ``device``.
    """
    judge = net["discriminator"]

    # real branch first, then fake branch
    real_scores = judge(E.detach(), y)
    real_loss = BCE_criterion(real_scores, torch.ones((E.size(0), 1)).to(device))

    fake_scores = judge(G.detach(), y)
    fake_loss = BCE_criterion(fake_scores, torch.zeros((G.size(0), 1)).to(device))

    return real_loss + fake_loss


def generator_loss(G, y):
    """Generator BCE loss: the discriminator should score ``G`` as real (1).

    ``G`` keeps its graph so the loss back-propagates into the generator.
    Relies on the notebook-level ``net``, ``BCE_criterion`` and ``device``.
    """
    wanted = torch.ones((G.size(0), 1)).to(device)
    return BCE_criterion(net["discriminator"](G, y), wanted)


# GAN training with an extra diversity term: the discriminator is updated on
# every batch, the generator only on every second batch, where it is trained
# on two independent latent draws per label so that `diversity_loss` can
# compare the two halves.
for epoch in range(200):
    # train
    frozen_net(net, ["generator", "discriminator"], False)

    D_losses, G_losses, G_divs = [], [], []
    for batch, (x, y) in enumerate(client_trainloader):
        x, y = x.to(device), y.to(device)

        # Real features come from the frozen extractor; no graph is needed.
        with torch.no_grad():
            E = net["extractor"](x)

        # update D
        D_optimizer.zero_grad()
        z = torch.randn(x.size(0), 100, 1, 1).to(device)
        G = net["generator"](z, y)
        D_loss = discriminator_loss(E, G, y)

        D_loss.backward()
        D_optimizer.step()
        D_losses.append(D_loss.item())

        # update G on odd-numbered batches only (batch index is 0-based)
        if (batch + 1) % 2 != 0:
            continue

        G_optimizer.zero_grad()
        ys = torch.cat((y, y), dim=0)
        zs = torch.randn(x.size(0)*2, 100, 1, 1).to(device)
        Gs = net["generator"](zs, ys)
        G_loss = generator_loss(Gs, ys)

        # First and second half share the labels but not the latents.
        z1, z2 = torch.split(zs, x.size(0), dim=0)
        G1, G2 = torch.split(Gs, x.size(0), dim=0)
        G_div = diversity_loss(G1, G2, z1, z2)

        total_G_loss = G_loss + G_div
        total_G_loss.backward()
        G_optimizer.step()
        G_losses.append(G_loss.item())
        G_divs.append(G_div.item())

    frozen_net(net, ["generator", "discriminator"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            y = y.to(device)
            z = torch.randn(x.size(0), 100, 1, 1).to(device)

            G = net["generator"](z, y)
            GC = net["classifier"](G)

            hits = (GC.argmax(dim=1) == y).float()
            correct += hits.sum().item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], D_loss:%2.6f, G_loss:%2.6f, G_div:%2.6f, acc:%2.6f"
        %(epoch, np.mean(D_losses), np.mean(G_losses), np.mean(G_divs), acc))

    # Checkpoint every 10th epoch.
    if (epoch + 1) % 10 == 0:
        torch.save(net["generator"].state_dict(), "./checkpoint/client_generator_div.pkl")
        torch.save(net["discriminator"].state_dict(), "./checkpoint/client_discriminator_div.pkl")

epoch:[ 0], D_loss:1.056179, G_loss:1.330746, G_div:8.303213, acc:0.098000
epoch:[ 1], D_loss:0.643448, G_loss:2.833087, G_div:3.980607, acc:0.110000
epoch:[ 2], D_loss:0.570310, G_loss:2.610081, G_div:3.368141, acc:0.075500
epoch:[ 3], D_loss:0.375023, G_loss:2.506584, G_div:3.073811, acc:0.101500
epoch:[ 4], D_loss:0.268802, G_loss:2.715143, G_div:3.113455, acc:0.100500
epoch:[ 5], D_loss:0.478787, G_loss:3.205222, G_div:3.410010, acc:0.122000
epoch:[ 6], D_loss:0.685018, G_loss:3.270406, G_div:3.145173, acc:0.126000
epoch:[ 7], D_loss:0.698319, G_loss:2.811400, G_div:2.907139, acc:0.132000
epoch:[ 8], D_loss:0.672761, G_loss:2.805368, G_div:2.891236, acc:0.138000
epoch:[ 9], D_loss:0.725437, G_loss:2.482074, G_div:2.872471, acc:0.169500
epoch:[10], D_loss:0.781397, G_loss:2.498924, G_div:2.868101, acc:0.169500
epoch:[11], D_loss:0.752736, G_loss:2.460147, G_div:2.875949, acc:0.189500
epoch:[12], D_loss:0.797045, G_loss:2.300940, G_div:2.835150, acc:0.178500
epoch:[13], D_loss:0.8043

KeyboardInterrupt: ignored