In [None]:
import argparse
import torch
import torch.nn as nn 
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import os
import time
from torchvision import datasets, transforms
from torchvision.utils import save_image
# import inception_score

In [None]:
# torch.save(img_syn, 'img_syn.pt')
# torch.save(label_syn, 'label_syn.pt')
# # 读取tensor
img_syn = torch.load('img_syn.pt')
label_syn = torch.load('label_syn.pt')

device  = "cuda:0"
img_syn =img_syn.to(device)
label_syn =label_syn.to(device)
device = img_syn.device



In [None]:

parser = argparse.ArgumentParser(description='Parameter Processing')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
parser.add_argument('--model', type=str, default='ConvNet', help='model')
parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
parser.add_argument('--num_eval', type=int, default=1, help='the number of evaluating randomly initialized models')
parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') # it can be small for speeding up with little performance drop
parser.add_argument('--Iteration', type=int, default=2000, help='training iterations')
parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
parser.add_argument('--data_path', type=str, default='/home/ssd7T/ZTL_gcond/data_cv', help='dataset path')
parser.add_argument('--save_path', type=str, default='result/gen', help='path to save results')
parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import warnings
args = parser.parse_args([])
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
labels_all = [dst_train[i][1] for i in range(len(dst_train))]
indices_class = [[] for c in range(num_classes)]
for i, lab in enumerate(labels_all):
    indices_class[lab].append(i)
images_all = torch.cat(images_all, dim=0).to(device)
labels_all = torch.tensor(labels_all, dtype=torch.long, device=device)

accs = []
model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
args.device = "cuda:0"
import copy
accs_all_exps = dict() # record performances of all experiments
for key in model_eval_pool:
    accs_all_exps[key] = []
args.dsa_param = ParamDiffAug()
args.dsa = False if args.dsa_strategy in ['none', 'None'] else True
model_eval= model_eval_pool[0]
data_save = []

# img_real = []
# label_real = []
# for c in range(num_classes):
#     idx_shuffle = np.random.permutation(indices_class[c])
#     img_real.append(images_all[idx_shuffle].to("cpu") )
#     label_real.append(labels_all[idx_shuffle].to("cpu"))
# img_real = torch.from_numpy(np.concatenate(img_real, axis=0))
# label_real = torch.from_numpy(np.concatenate(label_real, axis=0))


SEED = 114514
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    
img_real_train = []
label_real_train = []
for c in range(num_classes):
    idx_shuffle = np.random.permutation(indices_class[c])[:50]
    img_real_train.append(images_all[idx_shuffle].to("cpu") )
    label_real_train.append(labels_all[idx_shuffle].to("cpu"))
img_real_train = torch.from_numpy(np.concatenate(img_real_train, axis=0))
label_real_train = torch.from_numpy(np.concatenate(label_real_train, axis=0))

SEED = 87
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
img_real_test = []
label_real_test = []
for c in range(num_classes):
    idx_shuffle = np.random.permutation(indices_class[c])[:50]
    img_real_test.append(images_all[idx_shuffle].to("cpu") )
    label_real_test.append(labels_all[idx_shuffle].to("cpu"))
img_real_test = torch.from_numpy(np.concatenate(img_real_test, axis=0))
label_real_test = torch.from_numpy(np.concatenate(label_real_test, axis=0))

In [12]:
import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
# from dataloader import dataloader

class generator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self, input_dim=100, output_dim=1, input_size=32):
        super(generator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Tanh(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
        x = self.deconv(x)

        return x

class discriminator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    def __init__(self, input_dim=1, output_dim=1, input_size=32):
        super(discriminator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
        x = self.fc(x)

        return x

class DRAGAN(object):
    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 100
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type
        self.input_size = args.input_size
        self.z_dim = 62
        self.lambda_ = 0.25

        # load dataset
        self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
        data = self.data_loader.__iter__().__next__()[0]

        # networks init
        self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
        self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # fixed noise
        self.sample_z_ = torch.rand((self.batch_size, self.z_dim))
        if self.gpu_mode:
            self.sample_z_ = self.sample_z_.cuda()

    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            epoch_start_time = time.time()
            self.G.train()
            for iter, (x_, _) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__() // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))
                if self.gpu_mode:
                    x_, z_ = x_.cuda(), z_.cuda()

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                """ DRAGAN Loss (Gradient penalty) """
                # This is borrowed from https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb
                alpha = torch.rand(self.batch_size, 1, 1, 1).cuda()
                if self.gpu_mode:
                    alpha = alpha.cuda()
                    x_p = x_ + 0.5 * x_.std() * torch.rand(x_.size()).cuda()
                else:
                    x_p = x_ + 0.5 * x_.std() * torch.rand(x_.size())
                differences = x_p - x_
                interpolates = x_ + (alpha * differences)
                interpolates.requires_grad = True
                pred_hat = self.D(interpolates)
                if self.gpu_mode:
                    gradients = grad(outputs=pred_hat, inputs=interpolates, grad_outputs=torch.ones(pred_hat.size()).cuda(),
                                 create_graph=True, retain_graph=True, only_inputs=True)[0]
                else:
                    gradients = grad(outputs=pred_hat, inputs=interpolates, grad_outputs=torch.ones(pred_hat.size()),
                         create_graph=True, retain_graph=True, only_inputs=True)[0]

                gradient_penalty = self.lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()

                D_loss = D_real_loss + D_fake_loss + gradient_penalty
                self.train_hist['D_loss'].append(D_loss.item())
                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake = self.D(G_)

                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item()))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            with torch.no_grad():
                self.visualize_results((epoch+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

In [None]:
parser = argparse.ArgumentParser(description=desc)


parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion-mnist', 'cifar10', 'cifar100', 'svhn', 'stl10', 'lsun-bed'],
                    help='The name of dataset')
parser.add_argument('--split', type=str, default='', help='The split flag for svhn and stl10')
parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run')
parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
parser.add_argument('--input_size', type=int, default=28, help='The size of input image')
parser.add_argument('--save_dir', type=str, default='models',
                    help='Directory name to save the model')
parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the generated images')
parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs')
parser.add_argument('--lrG', type=float, default=0.0002)
parser.add_argument('--lrD', type=float, default=0.0002)
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--gpu_mode', type=bool, default=True)
parser.add_argument('--benchmark_mode', type=bool, default=True)

 args = parser.parse_args[]

In [None]:
gan = DRAGAN(args)

gan.train()

In [13]:
step = 0


for epoch in range(n_epochs): 
 
    for c in range(num_classes):
    
        batch_img = img_real_train[c*batch_size:(c+1)*batch_size].reshape((batch_size, 3, 32, 32)).to(device).to(device)
        batch_img_label = label_real_train[c*batch_size:(c+1)*batch_size].to(device)
        batch_img_syn = img_syn[c*batch_size:(c+1)*batch_size].reshape((batch_size, 3, 32, 32)).to(device).to(device)
        batch_img_syn_label = label_syn[c*batch_size:(c+1)*batch_size].to(device)

        ## Train Discriminator 
        # noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
        fake = G(batch_img)
        # print(batch_img_syn.shape)
        print(fake.shape)

        # Gradient penalty
        penalty = gradient_penalty(batch_img_syn, fake)

        # Adversarial loss
        loss_D = -torch.mean(D(batch_img)) + torch.mean(D(fake)) + lambd*penalty
        
        opt_D.zero_grad()
        loss_D.backward(retain_graph=True) 
        opt_D.step()
        
        ## Train Generator
        loss_G = -torch.mean(D(fake))
        
        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

        # Print losses  
        print(
            "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, loss_D.item(), loss_G.item())
        )

        # if batch_idx % 100 == 0:
        #     step += 1
        #     writer.add_scalar("Discriminator loss", loss_D.item(), global_step=step)
        #     writer.add_scalar("Generator loss", loss_G.item(), global_step=step)
        #     with torch.no_grad():
        #         fake = G(noise)
        #         img_grid_fake = make_grid(fake[:32], normalize=True)
            
        #         writer.add_image("Fake images", img_grid_fake, global_step=step)
            
print("Training finished!")

#  (64, 3, 4, 2, 1)

torch.Size([1, 256, 8, 8])
torch.Size([1, 128, 16, 16])
torch.Size([1, 64, 32, 32])
torch.Size([1, 3, 64, 64])


RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 3