In [1]:
cd "drive/MyDrive/model inversion(lenet)"

/content/drive/MyDrive/model inversion(lenet)


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class Extractor(nn.Module):

    def __init__(self):
        super(Extractor, self).__init__()
        self.extractor = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.AvgPool2d(2, 2),
            nn.Sigmoid(),
            nn.Conv2d(6, 16, 5),
            nn.AvgPool2d(2, 2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.extractor(x)
        return x

    
class Classifier(nn.Module):

    def __init__(self, num_classes=10):
        super(Classifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        x = self.classifier(x)
        return x


class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(100 + 10, 512, 2, 1, 0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, 2, 1, 0),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, 2, 1, 0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 16, 2, 1, 0),
            nn.Sigmoid(),
        )
        self.apply(weights_init)
        
    def forward(self, z, y):
        y = F.one_hot(y, 10)
        y = y.unsqueeze(-1).unsqueeze(-1)
        feat = torch.cat([z, y], 1)
        feat = self.generator(feat)
        return feat
    
    
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Conv2d(16 + 10, 128, 2, 1, 0),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(128, 256, 2, 1, 0),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(256, 512, 2, 1, 0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(512, 1, 2, 1, 0),
            nn.Sigmoid(),
        )
        self.apply(weights_init)
        
    def forward(self, feat, y):
        y = F.one_hot(y, 10)
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(y.size(0), 10, 5, 5)
        feat = torch.cat([feat, y], 1)
        feat = self.discriminator(feat)
        feat = feat.squeeze(-1).squeeze(-1)
        return feat

In [3]:
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
import torchvision.datasets as datasets
import numpy as np

transform=transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5)),
])

usps_trainset = datasets.USPS(root='./data', train=True, download=True, transform=transform)
usps_testset = datasets.USPS(root='./data', train=False, download=True, transform=transform)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

size = len(usps_trainset)
index = np.arange(size)
client_trainset = Subset(usps_trainset, index[:2000])
server_iid_trainset = Subset(usps_trainset, index[2000:4000])
server_niid_trainset = Subset(mnist_trainset, index[:2000])

client_trainloader = DataLoader(client_trainset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
server_iid_trainloader = DataLoader(server_iid_trainset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
server_niid_trainloader = DataLoader(server_niid_trainset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)

client_testloader = DataLoader(usps_testset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)
server_iid_testloader = DataLoader(usps_testset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)
server_niid_testloader = DataLoader(mnist_testset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)

In [4]:
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_params(net, modules):
    params = []
    for module in modules:
        params.append({"params": net[module].parameters()})
    return params

def frozen_net(net, modules, frozen):
    for module in modules:
        for param in net[module].parameters():
            param.requires_grad = not frozen
        if frozen:
            net[module].eval()
        else:
            net[module].train()


**client train**

In [5]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()    
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
EC_optimizer = optim.Adam(get_params(net, ["extractor", "classifier"]), lr=3e-4, weight_decay=1e-4)
CE_criterion = nn.CrossEntropyLoss().to(device)

for epoch in range(200):
    # train
    frozen_net(net, ["extractor", "classifier"], False)
    losses, batch = 0., 0
    for x, y in client_trainloader:
        x = x.to(device)
        y = y.to(device)
        
        EC_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        EC_optimizer.step()

        losses += loss.item()
        batch += 1
    avg_loss = losses / batch
    frozen_net(net, ["extractor", "classifier"], True)
    

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            x = x.to(device)
            y = y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            correct += torch.sum((torch.argmax(EC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

torch.save(net["extractor"].state_dict(), "./checkpoint/client_extractor.pkl")
torch.save(net["classifier"].state_dict(), "./checkpoint/client_classifier.pkl")

epoch:[ 0], loss:2.242679, acc:0.178874
epoch:[ 1], loss:2.227359, acc:0.178874
epoch:[ 2], loss:2.121932, acc:0.304933
epoch:[ 3], loss:1.721823, acc:0.410563
epoch:[ 4], loss:1.392067, acc:0.524165
epoch:[ 5], loss:1.150464, acc:0.617339
epoch:[ 6], loss:0.964922, acc:0.695566
epoch:[ 7], loss:0.811933, acc:0.723966
epoch:[ 8], loss:0.687541, acc:0.763328
epoch:[ 9], loss:0.587486, acc:0.781266
epoch:[10], loss:0.512632, acc:0.813154
epoch:[11], loss:0.455462, acc:0.821624
epoch:[12], loss:0.409748, acc:0.830095
epoch:[13], loss:0.368186, acc:0.831589
epoch:[14], loss:0.338250, acc:0.833084
epoch:[15], loss:0.311000, acc:0.840060
epoch:[16], loss:0.287012, acc:0.848032
epoch:[17], loss:0.266204, acc:0.841555
epoch:[18], loss:0.248980, acc:0.857000
epoch:[19], loss:0.233034, acc:0.859492
epoch:[20], loss:0.220193, acc:0.866966
epoch:[21], loss:0.207648, acc:0.865471
epoch:[22], loss:0.194856, acc:0.875436
epoch:[23], loss:0.186172, acc:0.871948
epoch:[24], loss:0.173073, acc:0.870952


**server same train**

In [6]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()    
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
E_optimizer = optim.Adam(get_params(net, ["extractor"]), lr=3e-4, weight_decay=1e-4)
CE_criterion = nn.CrossEntropyLoss().to(device)

C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)

for epoch in range(200):
    # train
    frozen_net(net, ["extractor"], False)
    losses, batch = 0., 0
    for x, y in client_trainloader:
        x = x.to(device)
        y = y.to(device)
        
        E_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        E_optimizer.step()

        losses += loss.item()
        batch += 1
    avg_loss = losses / batch
    frozen_net(net, ["extractor"], True)
    

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_trainloader:
            x = x.to(device)
            y = y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            correct += torch.sum((torch.argmax(EC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

torch.save(net["extractor"].state_dict(), "./checkpoint/server_same_extractor.pkl")

epoch:[ 0], loss:2.085667, acc:0.788000
epoch:[ 1], loss:0.710419, acc:0.879000
epoch:[ 2], loss:0.436361, acc:0.910500
epoch:[ 3], loss:0.333468, acc:0.915500
epoch:[ 4], loss:0.280116, acc:0.930500
epoch:[ 5], loss:0.241661, acc:0.937000
epoch:[ 6], loss:0.215050, acc:0.942000
epoch:[ 7], loss:0.189953, acc:0.947000
epoch:[ 8], loss:0.171485, acc:0.952000
epoch:[ 9], loss:0.155122, acc:0.953500
epoch:[10], loss:0.141446, acc:0.959000
epoch:[11], loss:0.128736, acc:0.965000
epoch:[12], loss:0.116980, acc:0.966500
epoch:[13], loss:0.107188, acc:0.971000
epoch:[14], loss:0.097588, acc:0.973000
epoch:[15], loss:0.089349, acc:0.976500
epoch:[16], loss:0.082594, acc:0.977000
epoch:[17], loss:0.075743, acc:0.980000
epoch:[18], loss:0.070746, acc:0.983500
epoch:[19], loss:0.065320, acc:0.985000
epoch:[20], loss:0.060564, acc:0.985500
epoch:[21], loss:0.055945, acc:0.989500
epoch:[22], loss:0.051590, acc:0.991500
epoch:[23], loss:0.048045, acc:0.991500
epoch:[24], loss:0.045152, acc:0.994000


**server iid train**

In [7]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()    
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
E_optimizer = optim.Adam(get_params(net, ["extractor"]), lr=3e-4, weight_decay=1e-4)
CE_criterion = nn.CrossEntropyLoss().to(device)

C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)

for epoch in range(200):
    # train
    frozen_net(net, ["extractor"], False)
    losses, batch = 0., 0
    for x, y in server_iid_trainloader:
        x = x.to(device)
        y = y.to(device)
        
        E_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        E_optimizer.step()

        losses += loss.item()
        batch += 1
    avg_loss = losses / batch
    frozen_net(net, ["extractor"], True)
    

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in server_iid_testloader:
            x = x.to(device)
            y = y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            correct += torch.sum((torch.argmax(EC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

torch.save(net["extractor"].state_dict(), "./checkpoint/server_iid_extractor.pkl")

epoch:[ 0], loss:2.560075, acc:0.595914
epoch:[ 1], loss:0.899974, acc:0.798705
epoch:[ 2], loss:0.568349, acc:0.836074
epoch:[ 3], loss:0.466450, acc:0.848032
epoch:[ 4], loss:0.408821, acc:0.862481
epoch:[ 5], loss:0.375343, acc:0.866467
epoch:[ 6], loss:0.352817, acc:0.869955
epoch:[ 7], loss:0.334721, acc:0.880917
epoch:[ 8], loss:0.316965, acc:0.876931
epoch:[ 9], loss:0.305257, acc:0.878924
epoch:[10], loss:0.291816, acc:0.883408
epoch:[11], loss:0.283862, acc:0.882910
epoch:[12], loss:0.273320, acc:0.886896
epoch:[13], loss:0.266966, acc:0.891380
epoch:[14], loss:0.257533, acc:0.895864
epoch:[15], loss:0.252224, acc:0.896363
epoch:[16], loss:0.246025, acc:0.896861
epoch:[17], loss:0.242286, acc:0.898356
epoch:[18], loss:0.236872, acc:0.900349
epoch:[19], loss:0.231795, acc:0.900847
epoch:[20], loss:0.228103, acc:0.902840
epoch:[21], loss:0.227216, acc:0.902342
epoch:[22], loss:0.222275, acc:0.904335
epoch:[23], loss:0.218293, acc:0.901345
epoch:[24], loss:0.217492, acc:0.900847


**server niid train**

In [8]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()    
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
E_optimizer = optim.Adam(get_params(net, ["extractor"]), lr=3e-4, weight_decay=1e-4)
CE_criterion = nn.CrossEntropyLoss().to(device)

C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)

for epoch in range(200):
    # train
    frozen_net(net, ["extractor"], False)
    losses, batch = 0., 0
    for x, y in server_niid_trainloader:
        x = x.to(device)
        y = y.to(device)
        
        E_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        E_optimizer.step()

        losses += loss.item()
        batch += 1
    avg_loss = losses / batch
    frozen_net(net, ["extractor"], True)
    

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in server_niid_testloader:
            x = x.to(device)
            y = y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            correct += torch.sum((torch.argmax(EC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

torch.save(net["extractor"].state_dict(), "./checkpoint/server_niid_extractor.pkl")

epoch:[ 0], loss:2.923146, acc:0.102800
epoch:[ 1], loss:2.366540, acc:0.269000
epoch:[ 2], loss:2.021826, acc:0.415200
epoch:[ 3], loss:1.766971, acc:0.456600
epoch:[ 4], loss:1.615860, acc:0.499300
epoch:[ 5], loss:1.512170, acc:0.530300
epoch:[ 6], loss:1.432566, acc:0.553800
epoch:[ 7], loss:1.365628, acc:0.580100
epoch:[ 8], loss:1.307032, acc:0.608700
epoch:[ 9], loss:1.252640, acc:0.629900
epoch:[10], loss:1.217624, acc:0.635600
epoch:[11], loss:1.166871, acc:0.656000
epoch:[12], loss:1.129023, acc:0.657900
epoch:[13], loss:1.093955, acc:0.684400
epoch:[14], loss:1.062051, acc:0.694800
epoch:[15], loss:1.026839, acc:0.699000
epoch:[16], loss:1.000147, acc:0.707600
epoch:[17], loss:0.969532, acc:0.714100
epoch:[18], loss:0.942313, acc:0.725000
epoch:[19], loss:0.914461, acc:0.730300
epoch:[20], loss:0.891187, acc:0.749800
epoch:[21], loss:0.866532, acc:0.758900
epoch:[22], loss:0.839773, acc:0.763300
epoch:[23], loss:0.815734, acc:0.770000
epoch:[24], loss:0.798444, acc:0.778700


**GAN train**

In [9]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()
net["generator"] = Generator()
net["discriminator"] = Discriminator()
net = net.to(device)
frozen_net(net, ["extractor", "classifier", "generator", "discriminator"], True)

D_optimizer = optim.Adam(get_params(net, ["discriminator"]), lr=2e-4, betas=(0.5, 0.999))
G_optimizer = optim.Adam(get_params(net, ["generator"]), lr=2e-4, betas=(0.5, 0.999))
BCE_criterion = nn.BCELoss().to(device)

E_checkpoint = torch.load("./checkpoint/client_extractor.pkl", map_location=torch.device('cpu'))
net["extractor"].load_state_dict(E_checkpoint)
C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)


def discriminator_loss(E, G, y):
    ones = torch.ones((E.size(0), 1)).to(device)
    ED = net["discriminator"](E.detach(), y)
    ED_loss = BCE_criterion(ED, ones)

    zeros = torch.zeros((G.size(0), 1)).to(device)
    GD = net["discriminator"](G.detach(), y)
    GD_loss = BCE_criterion(GD, zeros)
    return ED_loss + GD_loss


def generator_loss(G, y):
    ones = torch.ones((G.size(0), 1)).to(device)
    GD = net["discriminator"](G, y)
    G_loss = BCE_criterion(GD, ones)
    return G_loss


for epoch in range(200):
    # train
    frozen_net(net, ["generator", "discriminator"], False)

    D_losses, G_losses = [], []
    for batch, (x, y) in enumerate(client_trainloader):
        x = x.to(device)
        y = y.to(device)
        z = torch.randn(x.size(0), 100, 1, 1).to(device)

        with torch.no_grad():
            E = net["extractor"](x)
        G = net["generator"](z, y)

        # update D
        D_optimizer.zero_grad()
        D_loss = discriminator_loss(E, G, y)
        D_loss.backward()
        D_optimizer.step()
        D_losses.append(D_loss.item())

        # update G
        G_optimizer.zero_grad()
        G_loss = generator_loss(G, y)
        G_loss.backward()
        G_optimizer.step()
        G_losses.append(G_loss.item())
    
    frozen_net(net, ["generator", "discriminator"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            y = y.to(device)
            z = torch.randn(x.size(0), 100, 1, 1).to(device)

            G = net["generator"](z, y)
            GC = net["classifier"](G)

            correct += torch.sum((torch.argmax(GC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], D_loss:%2.6f, G_loss:%2.6f, acc:%2.6f"
        %(epoch, np.mean(D_losses), np.mean(G_losses), acc))

    if (epoch+1) % 10 == 0:
        torch.save(net["generator"].state_dict(), "./checkpoint/client_generator.pkl")
        torch.save(net["discriminator"].state_dict(), "./checkpoint/client_discriminator.pkl")

epoch:[ 0], D_loss:0.344608, G_loss:3.037637, acc:0.098655
epoch:[ 1], D_loss:0.118900, G_loss:4.418280, acc:0.369208
epoch:[ 2], D_loss:0.299082, G_loss:4.339326, acc:0.712008
epoch:[ 3], D_loss:0.533858, G_loss:3.677392, acc:0.862481
epoch:[ 4], D_loss:0.655734, G_loss:3.264099, acc:0.949178
epoch:[ 5], D_loss:0.775410, G_loss:2.701448, acc:0.979571
epoch:[ 6], D_loss:0.815621, G_loss:2.569082, acc:0.974091
epoch:[ 7], D_loss:0.853543, G_loss:2.414339, acc:0.975087
epoch:[ 8], D_loss:0.767617, G_loss:2.452268, acc:0.967613
epoch:[ 9], D_loss:0.799125, G_loss:2.517888, acc:0.966119
epoch:[10], D_loss:0.790975, G_loss:2.487391, acc:0.972098
epoch:[11], D_loss:0.793575, G_loss:2.582605, acc:0.973592
epoch:[12], D_loss:0.714652, G_loss:2.595330, acc:0.971101
epoch:[13], D_loss:0.732304, G_loss:2.638033, acc:0.964126
epoch:[14], D_loss:0.680403, G_loss:2.725231, acc:0.965122
epoch:[15], D_loss:0.695786, G_loss:2.843229, acc:0.972596
epoch:[16], D_loss:0.722535, G_loss:2.779969, acc:0.9706

**GAN div train**

In [10]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()
net["generator"] = Generator()
net["discriminator"] = Discriminator()
net = net.to(device)
frozen_net(net, ["extractor", "classifier", "generator", "discriminator"], True)

D_optimizer = optim.Adam(get_params(net, ["discriminator"]), lr=2e-4, betas=(0.5, 0.999))
G_optimizer = optim.Adam(get_params(net, ["generator"]), lr=4e-4, betas=(0.5, 0.999))
BCE_criterion = nn.BCELoss().to(device)

E_checkpoint = torch.load("./checkpoint/client_extractor.pkl", map_location=torch.device('cpu'))
net["extractor"].load_state_dict(E_checkpoint)
C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)


def diversity_loss(G1, G2, z1, z2):
    lz = torch.mean(torch.abs(G2 - G1)) / torch.mean(torch.abs(z2 - z1))
    eps = 1 * 1e-5
    G_div = 1 / (lz + eps)
    return G_div


def discriminator_loss(E, G, y):
    ones = torch.ones((E.size(0), 1)).to(device)
    ED = net["discriminator"](E.detach(), y)
    ED_loss = BCE_criterion(ED, ones)

    zeros = torch.zeros((G.size(0), 1)).to(device)
    GD = net["discriminator"](G.detach(), y)
    GD_loss = BCE_criterion(GD, zeros)
    return ED_loss + GD_loss


def generator_loss(G, y):
    ones = torch.ones((G.size(0), 1)).to(device)
    GD = net["discriminator"](G, y)
    G_loss = BCE_criterion(GD, ones)
    return G_loss


for epoch in range(200):
    # train
    frozen_net(net, ["generator", "discriminator"], False)

    D_losses, G_losses, G_divs = [], [], []
    for batch, (x, y) in enumerate(client_trainloader):
        x = x.to(device)
        y = y.to(device)

        with torch.no_grad():
            E = net["extractor"](x)

        # update D
        D_optimizer.zero_grad()
        z = torch.randn(x.size(0), 100, 1, 1).to(device)
        G = net["generator"](z, y)
        D_loss = discriminator_loss(E, G, y)

        D_loss.backward()
        D_optimizer.step()
        D_losses.append(D_loss.item())

        # update G
        if (batch+1)%2 == 0:
            G_optimizer.zero_grad()
            ys = torch.cat([y,y], 0)
            zs = torch.randn(x.size(0)*2, 100, 1, 1).to(device)
            Gs = net["generator"](zs, ys)
            G_loss = generator_loss(Gs, ys)

            z1, z2 = torch.split(zs, x.size(0), 0)
            G1, G2 = torch.split(Gs, x.size(0), 0)
            G_div = diversity_loss(G1, G2, z1, z2)

            (G_loss + G_div).backward()
            G_optimizer.step()
            G_losses.append(G_loss.item())
            G_divs.append(G_div.item())
    
    frozen_net(net, ["generator", "discriminator"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            y = y.to(device)
            z = torch.randn(x.size(0), 100, 1, 1).to(device)

            G = net["generator"](z, y)
            GC = net["classifier"](G)

            correct += torch.sum((torch.argmax(GC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], D_loss:%2.6f, G_loss:%2.6f, G_div:%2.6f, acc:%2.6f"
        %(epoch, np.mean(D_losses), np.mean(G_losses), np.mean(G_divs), acc))

    if (epoch+1) % 10 == 0:
        torch.save(net["generator"].state_dict(), "./checkpoint/client_generator_div.pkl")
        torch.save(net["discriminator"].state_dict(), "./checkpoint/client_discriminator_div.pkl")

epoch:[ 0], D_loss:0.259275, G_loss:2.965623, G_div:10.834201, acc:0.118585
epoch:[ 1], D_loss:0.050760, G_loss:4.501555, G_div:5.210158, acc:0.199801
epoch:[ 2], D_loss:0.119526, G_loss:5.419510, G_div:4.619545, acc:0.282013
epoch:[ 3], D_loss:0.197628, G_loss:4.993534, G_div:4.403781, acc:0.456403
epoch:[ 4], D_loss:0.398614, G_loss:4.170278, G_div:4.462311, acc:0.651719
epoch:[ 5], D_loss:0.514191, G_loss:3.637313, G_div:4.693073, acc:0.753861
epoch:[ 6], D_loss:0.613035, G_loss:3.135381, G_div:4.520527, acc:0.839063
epoch:[ 7], D_loss:0.594661, G_loss:2.774623, G_div:4.648679, acc:0.894868
epoch:[ 8], D_loss:0.656109, G_loss:2.770310, G_div:4.774084, acc:0.918286
epoch:[ 9], D_loss:0.617779, G_loss:2.594102, G_div:4.673817, acc:0.920777
epoch:[10], D_loss:0.680715, G_loss:2.696926, G_div:4.704286, acc:0.918784
epoch:[11], D_loss:0.496789, G_loss:2.788996, G_div:4.745675, acc:0.943199
epoch:[12], D_loss:0.592919, G_loss:2.911039, G_div:4.835534, acc:0.920777
epoch:[13], D_loss:0.609