In [4]:
import argparse
import torch
import torch.nn as nn 
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import os
import time
from torchvision import datasets, transforms
from torchvision.utils import save_image
# import inception_score

In [3]:
# torch.save(img_syn, 'img_syn.pt')
# torch.save(label_syn, 'label_syn.pt')
# # 读取tensor
img_syn = torch.load('img_syn.pt')
label_syn = torch.load('label_syn.pt')

device  = "cuda:0"
img_syn =img_syn.to(device)
label_syn =label_syn.to(device)
device = img_syn.device



In [6]:

parser = argparse.ArgumentParser(description='Parameter Processing')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
parser.add_argument('--model', type=str, default='ConvNet', help='model')
parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
parser.add_argument('--num_eval', type=int, default=1, help='the number of evaluating randomly initialized models')
parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') # it can be small for speeding up with little performance drop
parser.add_argument('--Iteration', type=int, default=2000, help='training iterations')
parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
parser.add_argument('--data_path', type=str, default='/home/ssd7T/ZTL_gcond/data_cv', help='dataset path')
parser.add_argument('--save_path', type=str, default='result/gen', help='path to save results')
parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import warnings
args = parser.parse_args([])
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
labels_all = [dst_train[i][1] for i in range(len(dst_train))]
indices_class = [[] for c in range(num_classes)]
for i, lab in enumerate(labels_all):
    indices_class[lab].append(i)
images_all = torch.cat(images_all, dim=0).to(device)
labels_all = torch.tensor(labels_all, dtype=torch.long, device=device)

accs = []
model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
args.device = "cuda:0"
import copy
accs_all_exps = dict() # record performances of all experiments
for key in model_eval_pool:
    accs_all_exps[key] = []
args.dsa_param = ParamDiffAug()
args.dsa = False if args.dsa_strategy in ['none', 'None'] else True
model_eval= model_eval_pool[0]
data_save = []

# img_real = []
# label_real = []
# for c in range(num_classes):
#     idx_shuffle = np.random.permutation(indices_class[c])
#     img_real.append(images_all[idx_shuffle].to("cpu") )
#     label_real.append(labels_all[idx_shuffle].to("cpu"))
# img_real = torch.from_numpy(np.concatenate(img_real, axis=0))
# label_real = torch.from_numpy(np.concatenate(label_real, axis=0))


SEED = 114514
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    
img_real_train = []
label_real_train = []
for c in range(num_classes):
    idx_shuffle = np.random.permutation(indices_class[c])[:50]
    img_real_train.append(images_all[idx_shuffle].to("cpu") )
    label_real_train.append(labels_all[idx_shuffle].to("cpu"))
img_real_train = torch.from_numpy(np.concatenate(img_real_train, axis=0))
label_real_train = torch.from_numpy(np.concatenate(label_real_train, axis=0))

SEED = 87
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
img_real_test = []
label_real_test = []
for c in range(num_classes):
    idx_shuffle = np.random.permutation(indices_class[c])[:50]
    img_real_test.append(images_all[idx_shuffle].to("cpu") )
    label_real_test.append(labels_all[idx_shuffle].to("cpu"))
img_real_test = torch.from_numpy(np.concatenate(img_real_test, axis=0))
label_real_test = torch.from_numpy(np.concatenate(label_real_test, axis=0))

Files already downloaded and verified
Files already downloaded and verified


In [13]:
num_classes = 10
batch = 50
num_feat = 3072
batch_size = 64
z_dim = 100
num_epochs = 100

epochs = 100
# Hyper-parameters 
latent_size = 32
batch_size = 50
lr_D = 0.0002
lr_G = 0.001
lr_Q = 0.0001
beta1 = 0.5

# Create target dataloader
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
           ])

# Discriminator
D = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), 
    nn.LeakyReLU(0.2),
    
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2),
    
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2),

    nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2),

    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
    nn.Sigmoid()
)

D = D.to(device) 

# Generator 
G = nn.Sequential(
    nn.Linear(3*32*32, 256*4*4), 
    nn.BatchNorm1d(256*4*4),
    nn.ReLU(),

    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(),

    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),  
    nn.ReLU(),

    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
    nn.Tanh()
)
G = G.to(device)

# Losses
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=lr_D, betas=(beta1, 0.999))
g_optimizer = optim.Adam(G.parameters(), lr=lr_G, betas=(beta1, 0.999))

def gradient_penalty(real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = torch.randn((batch_size, 1, 1, 1)).to(device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones(d_interpolates.size()).requires_grad_(False).to(device)
    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [15]:


# ----------
#  Training
# ----------

IS = []
for epoch in range(epochs):    
    for c in range(num_classes):
    
        batch_img = img_real_train[c*batch_size:(c+1)*batch_size].reshape((batch, 3, 32, 32)).to(device).to(device)
        batch_img_label = label_real_train[c*batch_size:(c+1)*batch_size].to(device)
        batch_img_syn = img_syn[c*batch_size:(c+1)*batch_size].reshape((batch, 3, 32, 32)).to(device).to(device)
        batch_img_syn_label = label_syn[c*batch_size:(c+1)*batch_size].to(device)


        # Sample noise as generator input
        # z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_size)))).to(device)
        # z = batch_img

        # Generate a batch of images .view(batch_img.size(0), -1)
        gen_imgs = G(batch_img)

        # Loss measures generator's ability to fool the discriminator
        g_loss = criterion(D(gen_imgs), batch_img_syn_label)
        
        # Real images 
        real_out = D(batch_img_syn)
        d_real_loss = criterion(real_out, torch.ones_like(real_out)) 
        # Fake images
        fake_out = D(gen_imgs.detach())
        d_fake_loss = criterion(fake_out, torch.zeros_like(fake_out))
        # Gradient penalty
        gradient_penalty_val = gradient_penalty(batch_img_syn , gen_imgs.data)
        # Total discriminator loss
        d_loss = d_real_loss + d_fake_loss + gradient_penalty_val

        # -----------------
        #  Train Generator
        # -----------------

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # ---------------------
        #  Train Discriminator 
        # ---------------------

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # --------------
        # Log Progress
        # --------------

        print ("[Epoch %d/%d][D loss: %f] [G loss: %f]" % (epoch, epochs, 
                                                            d_loss.item(), g_loss.item())) 
                                                            
        
        # batches_done = epoch * len(data_loader) + i
        # if batches_done % 100 == 0:
        #     # Calculate Inception Score
        #     z = Variable(Tensor(np.random.normal(0, 1, (5000, latent_size)))).to(device) 
        #     gen_imgs = G(z)
        #     IS.append(inception_score(gen_imgs.cpu().data, cuda=False, batch_size=32, 
        #                             resize=True, splits=10)[0])

        #     # Save generated images 
        #     save_image(gen_imgs.data[:25], "%d.png" % batches_done, nrow=5, normalize=True)
            
# Plot IS 
# plot(IS)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4800x32 and 3072x4096)