In [1]:
cd "drive/MyDrive/model inversion(lenet)"

/content/drive/MyDrive/model inversion(lenet)


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class Extractor(nn.Module):

    def __init__(self):
        super(Extractor, self).__init__()
        self.extractor = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.AvgPool2d(2, 2),
            nn.Sigmoid(),
            nn.Conv2d(6, 16, 5),
            nn.AvgPool2d(2, 2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.extractor(x)
        return x

    
class Classifier(nn.Module):

    def __init__(self, num_classes=10):
        super(Classifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        x = self.classifier(x)
        return x


class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(100 + 10, 512, 2, 1, 0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, 2, 1, 0),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, 2, 1, 0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 16, 2, 1, 0),
            nn.Sigmoid(),
        )
        self.apply(weights_init)
        
    def forward(self, z, y):
        y = F.one_hot(y, 10)
        y = y.unsqueeze(-1).unsqueeze(-1)
        feat = torch.cat([z, y], 1)
        feat = self.generator(feat)
        return feat
    
    
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Conv2d(16 + 10, 128, 2, 1, 0),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(128, 256, 2, 1, 0),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(256, 512, 2, 1, 0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(512, 1, 2, 1, 0),
            nn.Sigmoid(),
        )
        self.apply(weights_init)
        
    def forward(self, feat, y):
        y = F.one_hot(y, 10)
        y = y.unsqueeze(-1).unsqueeze(-1)
        y = y.expand(y.size(0), 10, 5, 5)
        feat = torch.cat([feat, y], 1)
        feat = self.discriminator(feat)
        feat = feat.squeeze(-1).squeeze(-1)
        return feat

In [3]:
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
import torchvision.datasets as datasets
import numpy as np

transform=transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5)),
])

usps_trainset = datasets.USPS(root='./data', train=True, download=True, transform=transform)
usps_testset = datasets.USPS(root='./data', train=False, download=True, transform=transform)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

size = len(usps_trainset)
index = np.arange(size)
client_trainset = Subset(usps_trainset, index[:2000])
server_iid_trainset = Subset(usps_trainset, index[2000:4000])
server_niid_trainset = Subset(mnist_trainset, index[:2000])

client_trainloader = DataLoader(client_trainset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
server_iid_trainloader = DataLoader(server_iid_trainset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
server_niid_trainloader = DataLoader(server_niid_trainset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)

client_testloader = DataLoader(usps_testset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)
server_iid_testloader = DataLoader(usps_testset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)
server_niid_testloader = DataLoader(mnist_testset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)

In [4]:
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_params(net, modules):
    params = []
    for module in modules:
        params.append({"params": net[module].parameters()})
    return params

def frozen_net(net, modules, frozen):
    for module in modules:
        for param in net[module].parameters():
            param.requires_grad = not frozen
        if frozen:
            net[module].eval()
        else:
            net[module].train()


**client train**

In [None]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()    
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
EC_optimizer = optim.Adam(get_params(net, ["extractor", "classifier"]), lr=3e-4, weight_decay=1e-4)
CE_criterion = nn.CrossEntropyLoss().to(device)

for epoch in range(500):
    # train
    frozen_net(net, ["extractor", "classifier"], False)
    losses, batch = 0., 0
    for x, y in client_trainloader:
        x = x.to(device)
        y = y.to(device)
        
        EC_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        EC_optimizer.step()

        losses += loss.item()
        batch += 1
    avg_loss = losses / batch
    frozen_net(net, ["extractor", "classifier"], True)
    

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            x = x.to(device)
            y = y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            correct += torch.sum((torch.argmax(EC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

torch.save(net["extractor"].state_dict(), "./checkpoint/client_extractor.pkl")
torch.save(net["classifier"].state_dict(), "./checkpoint/client_classifier.pkl")

epoch:[ 0], loss:2.244613, acc:0.178874
epoch:[ 1], loss:2.226910, acc:0.178874
epoch:[ 2], loss:2.037559, acc:0.304933
epoch:[ 3], loss:1.624543, acc:0.520678
epoch:[ 4], loss:1.314992, acc:0.581963
epoch:[ 5], loss:1.090726, acc:0.645740
epoch:[ 6], loss:0.919579, acc:0.705531
epoch:[ 7], loss:0.782288, acc:0.728949
epoch:[ 8], loss:0.673135, acc:0.749377
epoch:[ 9], loss:0.581481, acc:0.779771
epoch:[10], loss:0.508161, acc:0.803687
epoch:[11], loss:0.450654, acc:0.804684
epoch:[12], loss:0.406532, acc:0.816143
epoch:[13], loss:0.372031, acc:0.827105
epoch:[14], loss:0.340752, acc:0.835077
epoch:[15], loss:0.316831, acc:0.845042
epoch:[16], loss:0.293685, acc:0.844544
epoch:[17], loss:0.275950, acc:0.851520
epoch:[18], loss:0.257414, acc:0.856502
epoch:[19], loss:0.240947, acc:0.862980
epoch:[20], loss:0.225517, acc:0.864973
epoch:[21], loss:0.213736, acc:0.869457
epoch:[22], loss:0.198991, acc:0.881415
epoch:[23], loss:0.184711, acc:0.877429
epoch:[24], loss:0.176864, acc:0.886398


**server same train**

In [6]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()    
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
E_optimizer = optim.Adam(get_params(net, ["extractor"]), lr=3e-4, weight_decay=1e-4)
CE_criterion = nn.CrossEntropyLoss().to(device)

C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)

for epoch in range(500):
    # train
    frozen_net(net, ["extractor"], False)
    losses, batch = 0., 0
    for x, y in client_trainloader:
        x = x.to(device)
        y = y.to(device)
        
        E_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        E_optimizer.step()

        losses += loss.item()
        batch += 1
    avg_loss = losses / batch
    frozen_net(net, ["extractor"], True)
    

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_trainloader:
            x = x.to(device)
            y = y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            correct += torch.sum((torch.argmax(EC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

torch.save(net["extractor"].state_dict(), "./checkpoint/server_same_extractor.pkl")

epoch:[ 0], loss:2.823008, acc:0.228000
epoch:[ 1], loss:1.291809, acc:0.771500
epoch:[ 2], loss:0.637843, acc:0.861500
epoch:[ 3], loss:0.463013, acc:0.898500
epoch:[ 4], loss:0.380704, acc:0.912500
epoch:[ 5], loss:0.327878, acc:0.918500
epoch:[ 6], loss:0.286583, acc:0.933000
epoch:[ 7], loss:0.255209, acc:0.937500
epoch:[ 8], loss:0.225938, acc:0.934500
epoch:[ 9], loss:0.203992, acc:0.951500
epoch:[10], loss:0.184406, acc:0.956000
epoch:[11], loss:0.166611, acc:0.956000
epoch:[12], loss:0.152320, acc:0.964000
epoch:[13], loss:0.136847, acc:0.967500
epoch:[14], loss:0.125411, acc:0.968000
epoch:[15], loss:0.114898, acc:0.974000
epoch:[16], loss:0.105867, acc:0.974000
epoch:[17], loss:0.097604, acc:0.975500
epoch:[18], loss:0.089657, acc:0.978000
epoch:[19], loss:0.082775, acc:0.982000
epoch:[20], loss:0.076285, acc:0.983500
epoch:[21], loss:0.069355, acc:0.985000
epoch:[22], loss:0.064773, acc:0.987000
epoch:[23], loss:0.062098, acc:0.991500
epoch:[24], loss:0.056152, acc:0.993000


**server iid train**

In [None]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()    
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
E_optimizer = optim.Adam(get_params(net, ["extractor"]), lr=3e-4, weight_decay=1e-4)
CE_criterion = nn.CrossEntropyLoss().to(device)

C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)

for epoch in range(500):
    # train
    frozen_net(net, ["extractor"], False)
    losses, batch = 0., 0
    for x, y in server_iid_trainloader:
        x = x.to(device)
        y = y.to(device)
        
        E_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        E_optimizer.step()

        losses += loss.item()
        batch += 1
    avg_loss = losses / batch
    frozen_net(net, ["extractor"], True)
    

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in server_iid_testloader:
            x = x.to(device)
            y = y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            correct += torch.sum((torch.argmax(EC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

torch.save(net["extractor"].state_dict(), "./checkpoint/server_iid_extractor.pkl")

epoch:[ 0], loss:3.164569, acc:0.310414
epoch:[ 1], loss:1.286478, acc:0.720977
epoch:[ 2], loss:0.776650, acc:0.782262
epoch:[ 3], loss:0.629624, acc:0.809168
epoch:[ 4], loss:0.551195, acc:0.827105
epoch:[ 5], loss:0.500011, acc:0.834579
epoch:[ 6], loss:0.456238, acc:0.840060
epoch:[ 7], loss:0.423523, acc:0.850523
epoch:[ 8], loss:0.395591, acc:0.860987
epoch:[ 9], loss:0.372794, acc:0.865969
epoch:[10], loss:0.352234, acc:0.873443
epoch:[11], loss:0.334548, acc:0.878924
epoch:[12], loss:0.318099, acc:0.881913
epoch:[13], loss:0.303921, acc:0.887394
epoch:[14], loss:0.289544, acc:0.889387
epoch:[15], loss:0.280282, acc:0.895366
epoch:[16], loss:0.271393, acc:0.894370
epoch:[17], loss:0.260941, acc:0.899352
epoch:[18], loss:0.253638, acc:0.901345
epoch:[19], loss:0.243652, acc:0.902840
epoch:[20], loss:0.238138, acc:0.900349
epoch:[21], loss:0.230818, acc:0.897857
epoch:[22], loss:0.227455, acc:0.907324
epoch:[23], loss:0.220571, acc:0.906328
epoch:[24], loss:0.216667, acc:0.904335


**server niid train**

In [None]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()    
net = net.to(device)
frozen_net(net, ["extractor", "classifier"], True)
E_optimizer = optim.Adam(get_params(net, ["extractor"]), lr=3e-4, weight_decay=1e-4)
CE_criterion = nn.CrossEntropyLoss().to(device)

C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)

for epoch in range(500):
    # train
    frozen_net(net, ["extractor"], False)
    losses, batch = 0., 0
    for x, y in server_niid_trainloader:
        x = x.to(device)
        y = y.to(device)
        
        E_optimizer.zero_grad()
        E = net["extractor"](x)
        EC = net["classifier"](E)
        loss = CE_criterion(EC, y)
        loss.backward()
        E_optimizer.step()

        losses += loss.item()
        batch += 1
    avg_loss = losses / batch
    frozen_net(net, ["extractor"], True)
    

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in server_niid_testloader:
            x = x.to(device)
            y = y.to(device)

            E = net["extractor"](x)
            EC = net["classifier"](E)

            correct += torch.sum((torch.argmax(EC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], loss:%2.6f, acc:%2.6f"%(epoch, avg_loss, acc))

torch.save(net["extractor"].state_dict(), "./checkpoint/server_niid_extractor.pkl")

epoch:[ 0], loss:3.135447, acc:0.098000
epoch:[ 1], loss:2.497407, acc:0.215600
epoch:[ 2], loss:2.160743, acc:0.368700
epoch:[ 3], loss:1.943198, acc:0.411400
epoch:[ 4], loss:1.767795, acc:0.458800
epoch:[ 5], loss:1.632686, acc:0.472300
epoch:[ 6], loss:1.531908, acc:0.511700
epoch:[ 7], loss:1.437794, acc:0.531200
epoch:[ 8], loss:1.355015, acc:0.569800
epoch:[ 9], loss:1.273691, acc:0.609500
epoch:[10], loss:1.209693, acc:0.628400
epoch:[11], loss:1.152437, acc:0.653700
epoch:[12], loss:1.100277, acc:0.618100
epoch:[13], loss:1.052451, acc:0.686200
epoch:[14], loss:1.014736, acc:0.679300
epoch:[15], loss:0.973673, acc:0.697800
epoch:[16], loss:0.942575, acc:0.715000
epoch:[17], loss:0.911199, acc:0.736400
epoch:[18], loss:0.882362, acc:0.737100
epoch:[19], loss:0.856386, acc:0.755500
epoch:[20], loss:0.830695, acc:0.760700
epoch:[21], loss:0.804255, acc:0.765300
epoch:[22], loss:0.784645, acc:0.774800
epoch:[23], loss:0.767280, acc:0.782500
epoch:[24], loss:0.745752, acc:0.788200


**GAN train**

In [None]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()
net["generator"] = Generator()
net["discriminator"] = Discriminator()
net = net.to(device)
frozen_net(net, ["extractor", "classifier", "generator", "discriminator"], True)

D_optimizer = optim.Adam(get_params(net, ["discriminator"]), lr=2e-4, betas=(0.5, 0.999))
G_optimizer = optim.Adam(get_params(net, ["generator"]), lr=2e-4, betas=(0.5, 0.999))
BCE_criterion = nn.BCELoss().to(device)

E_checkpoint = torch.load("./checkpoint/client_extractor.pkl", map_location=torch.device('cpu'))
net["extractor"].load_state_dict(E_checkpoint)
C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)


def discriminator_loss(E, G, y):
    ones = torch.ones((E.size(0), 1)).to(device)
    ED = net["discriminator"](E.detach(), y)
    ED_loss = BCE_criterion(ED, ones)

    zeros = torch.zeros((G.size(0), 1)).to(device)
    GD = net["discriminator"](G.detach(), y)
    GD_loss = BCE_criterion(GD, zeros)
    return ED_loss + GD_loss


def generator_loss(G, y):
    ones = torch.ones((G.size(0), 1)).to(device)
    GD = net["discriminator"](G, y)
    G_loss = BCE_criterion(GD, ones)
    return G_loss


for epoch in range(500):
    # train
    frozen_net(net, ["generator", "discriminator"], False)

    D_losses, G_losses = [], []
    for batch, (x, y) in enumerate(client_trainloader):
        x = x.to(device)
        y = y.to(device)
        z = torch.randn(x.size(0), 100, 1, 1).to(device)

        with torch.no_grad():
            E = net["extractor"](x)
        G = net["generator"](z, y)

        # update D
        D_optimizer.zero_grad()
        D_loss = discriminator_loss(E, G, y)
        D_loss.backward()
        D_optimizer.step()
        D_losses.append(D_loss.item())

        # update G
        G_optimizer.zero_grad()
        G_loss = generator_loss(G, y)
        G_loss.backward()
        G_optimizer.step()
        G_losses.append(G_loss.item())
    
    frozen_net(net, ["generator", "discriminator"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            y = y.to(device)
            z = torch.randn(x.size(0), 100, 1, 1).to(device)

            G = net["generator"](z, y)
            GC = net["classifier"](G)

            correct += torch.sum((torch.argmax(GC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], D_loss:%2.6f, G_loss:%2.6f, acc:%2.6f"
        %(epoch, np.mean(D_losses), np.mean(G_losses), acc))

    if (epoch+1) % 10 == 0:
        torch.save(net["generator"].state_dict(), "./checkpoint/client_generator.pkl")
        torch.save(net["discriminator"].state_dict(), "./checkpoint/client_discriminator.pkl")

epoch:[ 0], D_loss:0.322451, G_loss:3.091217, acc:0.073742
epoch:[ 1], D_loss:0.152293, G_loss:4.494790, acc:0.254111
epoch:[ 2], D_loss:0.220933, G_loss:4.630026, acc:0.460389
epoch:[ 3], D_loss:0.537664, G_loss:3.878116, acc:0.871450
epoch:[ 4], D_loss:0.692956, G_loss:3.268778, acc:0.961136
epoch:[ 5], D_loss:0.770612, G_loss:2.770874, acc:0.984554
epoch:[ 6], D_loss:0.824883, G_loss:2.588488, acc:0.970603
epoch:[ 7], D_loss:0.810655, G_loss:2.484770, acc:0.957150
epoch:[ 8], D_loss:0.766253, G_loss:2.524640, acc:0.983558
epoch:[ 9], D_loss:0.766418, G_loss:2.554713, acc:0.976582
epoch:[10], D_loss:0.714651, G_loss:2.658547, acc:0.963129
epoch:[11], D_loss:0.688639, G_loss:2.799146, acc:0.971599
epoch:[12], D_loss:0.676793, G_loss:2.882823, acc:0.975585
epoch:[13], D_loss:0.680232, G_loss:2.845784, acc:0.984554
epoch:[14], D_loss:0.711148, G_loss:2.872360, acc:0.981565
epoch:[15], D_loss:0.670146, G_loss:2.876335, acc:0.983059
epoch:[16], D_loss:0.693512, G_loss:2.877487, acc:0.9875

**GAN div train**

In [7]:
net = nn.ModuleDict()
net["extractor"] = Extractor()          
net["classifier"] = Classifier()
net["generator"] = Generator()
net["discriminator"] = Discriminator()
net = net.to(device)
frozen_net(net, ["extractor", "classifier", "generator", "discriminator"], True)

D_optimizer = optim.Adam(get_params(net, ["discriminator"]), lr=2e-4, betas=(0.5, 0.999))
G_optimizer = optim.Adam(get_params(net, ["generator"]), lr=4e-4, betas=(0.5, 0.999))
BCE_criterion = nn.BCELoss().to(device)

E_checkpoint = torch.load("./checkpoint/client_extractor.pkl", map_location=torch.device('cpu'))
net["extractor"].load_state_dict(E_checkpoint)
C_checkpoint = torch.load("./checkpoint/client_classifier.pkl", map_location=torch.device('cpu'))
net["classifier"].load_state_dict(C_checkpoint)


def diversity_loss(G1, G2, z1, z2):
    lz = torch.mean(torch.abs(G2 - G1)) / torch.mean(torch.abs(z2 - z1))
    eps = 1 * 1e-5
    G_div = 1 / (lz + eps)
    return G_div


def discriminator_loss(E, G, y):
    ones = torch.ones((E.size(0), 1)).to(device)
    ED = net["discriminator"](E.detach(), y)
    ED_loss = BCE_criterion(ED, ones)

    zeros = torch.zeros((G.size(0), 1)).to(device)
    GD = net["discriminator"](G.detach(), y)
    GD_loss = BCE_criterion(GD, zeros)
    return ED_loss + GD_loss


def generator_loss(G, y):
    ones = torch.ones((G.size(0), 1)).to(device)
    GD = net["discriminator"](G, y)
    G_loss = BCE_criterion(GD, ones)
    return G_loss


for epoch in range(500):
    # train
    frozen_net(net, ["generator", "discriminator"], False)

    D_losses, G_losses, G_divs = [], [], []
    for batch, (x, y) in enumerate(client_trainloader):
        x = x.to(device)
        y = y.to(device)

        with torch.no_grad():
            E = net["extractor"](x)

        # update D
        D_optimizer.zero_grad()
        z = torch.randn(x.size(0), 100, 1, 1).to(device)
        G = net["generator"](z, y)
        D_loss = discriminator_loss(E, G, y)

        D_loss.backward()
        D_optimizer.step()
        D_losses.append(D_loss.item())

        # update G
        if (batch+1)%2 == 0:
            G_optimizer.zero_grad()
            ys = torch.cat([y,y], 0)
            zs = torch.randn(x.size(0)*2, 100, 1, 1).to(device)
            Gs = net["generator"](zs, ys)
            G_loss = generator_loss(Gs, ys)

            z1, z2 = torch.split(zs, x.size(0), 0)
            G1, G2 = torch.split(Gs, x.size(0), 0)
            G_div = diversity_loss(G1, G2, z1, z2)

            (G_loss + G_div).backward()
            G_optimizer.step()
            G_losses.append(G_loss.item())
            G_divs.append(G_div.item())
    
    frozen_net(net, ["generator", "discriminator"], True)

    # test
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in client_testloader:
            y = y.to(device)
            z = torch.randn(x.size(0), 100, 1, 1).to(device)

            G = net["generator"](z, y)
            GC = net["classifier"](G)

            correct += torch.sum((torch.argmax(GC, dim=1) == y).float()).item()
            total += x.size(0)
    acc = correct / total

    print("epoch:[%2d], D_loss:%2.6f, G_loss:%2.6f, G_div:%2.6f, acc:%2.6f"
        %(epoch, np.mean(D_losses), np.mean(G_losses), np.mean(G_divs), acc))

    if (epoch+1) % 10 == 0:
        torch.save(net["generator"].state_dict(), "./checkpoint/client_generator_div.pkl")
        torch.save(net["discriminator"].state_dict(), "./checkpoint/client_discriminator_div.pkl")

epoch:[ 0], D_loss:0.262058, G_loss:2.837883, G_div:10.942803, acc:0.073244
epoch:[ 1], D_loss:0.077761, G_loss:4.405389, G_div:5.222893, acc:0.148979
epoch:[ 2], D_loss:0.111567, G_loss:5.197156, G_div:4.489521, acc:0.270553
epoch:[ 3], D_loss:0.290063, G_loss:4.550537, G_div:4.395665, acc:0.462880
epoch:[ 4], D_loss:0.430056, G_loss:3.996388, G_div:4.533616, acc:0.700050
epoch:[ 5], D_loss:0.510822, G_loss:3.429086, G_div:4.508744, acc:0.820628
epoch:[ 6], D_loss:0.574429, G_loss:3.010144, G_div:4.640945, acc:0.859492
epoch:[ 7], D_loss:0.634799, G_loss:2.925211, G_div:4.714569, acc:0.924265
epoch:[ 8], D_loss:0.665389, G_loss:2.746478, G_div:4.610947, acc:0.936223
epoch:[ 9], D_loss:0.603460, G_loss:2.831660, G_div:4.719760, acc:0.944195
epoch:[10], D_loss:0.587366, G_loss:2.890055, G_div:4.652732, acc:0.947683
epoch:[11], D_loss:0.573084, G_loss:2.911556, G_div:4.700296, acc:0.943697
epoch:[12], D_loss:0.582995, G_loss:2.809167, G_div:4.607203, acc:0.953662
epoch:[13], D_loss:0.590

KeyboardInterrupt: ignored