In [None]:
# Numpy
import numpy as np

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

# Torchvision
import torchvision
import torchvision.transforms as transforms

# Matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

# OS
import os
import argparse

# Set random seed for reproducibility
SEED = 187
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)


def print_model(encoder, decoder):
    print("============== Encoder ==============")
    print(encoder)
    print("============== Decoder ==============")
    print(decoder)
    print("")


def create_model():
    autoencoder = Autoencoder()
    print_model(autoencoder.encoder, autoencoder.decoder)
    if torch.cuda.is_available():
        autoencoder = autoencoder.cuda()
        print("Model moved to GPU in order to speed up training.")
    return autoencoder


def get_torch_vars(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

def imshow(img):
    npimg = img.cpu().numpy()
    plt.axis('off')
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # Input size: [batch, 3, 32, 32]
        # Output size: [batch, 3, 32, 32]
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 12, 4, stride=2, padding=1),            # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.Conv2d(12, 24, 4, stride=2, padding=1),           # [batch, 24, 8, 8]
            nn.ReLU(),
			nn.Conv2d(24, 48, 4, stride=2, padding=1),           # [batch, 48, 4, 4]
            nn.ReLU(),
# 			nn.Conv2d(48, 96, 4, stride=2, padding=1),           # [batch, 96, 2, 2]
#             nn.ReLU(),
        )
        self.decoder = nn.Sequential(
#             nn.ConvTranspose2d(96, 48, 4, stride=2, padding=1),  # [batch, 48, 4, 4]
#             nn.ReLU(),
			nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1),  # [batch, 24, 8, 8]
            nn.ReLU(),
			nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1),  # [batch, 12, 16, 16]
            nn.ReLU(),
            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),   # [batch, 3, 32, 32]
            nn.Sigmoid(),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded



In [None]:
parser = argparse.ArgumentParser(description="Train Autoencoder")
parser.add_argument("--valid", action="store_true", default=False,
                    help="Perform validation only.")
args = parser.parse_args([])

# Create model
autoencoder = create_model()

# Load data
transform = transforms.Compose(
    [transforms.ToTensor(), ])

In [None]:
# torch.save(img_syn, 'img_syn.pt')
# torch.save(label_syn, 'label_syn.pt')
# # 读取tensor
img_syn = torch.load('img_syn.pt')
label_syn = torch.load('label_syn.pt')
pairs_real=torch.load('pairs_real.pt')
indexs_real=torch.load('indexs_real.pt')

In [None]:
img_real_train = torch.cat(pairs_real, dim=0)

In [None]:
device = img_syn.device

In [None]:

parser = argparse.ArgumentParser(description='Parameter Processing')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
parser.add_argument('--model', type=str, default='ConvNet', help='model')
parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
parser.add_argument('--num_eval', type=int, default=1, help='the number of evaluating randomly initialized models')
parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') # it can be small for speeding up with little performance drop
parser.add_argument('--Iteration', type=int, default=2000, help='training iterations')
parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
parser.add_argument('--data_path', type=str, default='/home/ssd7T/ZTL_gcond/data_cv', help='dataset path')
parser.add_argument('--save_path', type=str, default='result/gen', help='path to save results')
parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import warnings
args = parser.parse_args([])
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
labels_all = [dst_train[i][1] for i in range(len(dst_train))]
indices_class = [[] for c in range(num_classes)]
for i, lab in enumerate(labels_all):
    indices_class[lab].append(i)
images_all = torch.cat(images_all, dim=0).to(device)
labels_all = torch.tensor(labels_all, dtype=torch.long, device=device)


In [None]:
label_real = []
for c in range(num_classes):
    idx_shuffle = indexs_real[c]
    # img_real.append(images_all[idx_shuffle].to("cpu") )
    label_real.append(labels_all[idx_shuffle].to("cpu"))
# img_real = torch.from_numpy(np.concatenate(img_real, axis=0))
label_real_train = torch.from_numpy(np.concatenate(label_real, axis=0))

In [None]:
img_real_all_train = []
label_real_all_train = []
for c in range(num_classes):
    idx_shuffle = np.random.permutation(indices_class[c])[:500]
    img_real_all_train.append(images_all[idx_shuffle].to("cpu") )
    label_real_all_train.append(labels_all[idx_shuffle].to("cpu"))
img_real_all_train = torch.from_numpy(np.concatenate(img_real_all_train, axis=0))
label_real_all_train = torch.from_numpy(np.concatenate(label_real_all_train, axis=0))

In [None]:
# img_real = []
# label_real = []
# for c in range(num_classes):
#     idx_shuffle = np.random.permutation(indices_class[c])
#     img_real.append(images_all[idx_shuffle].to("cpu") )
#     label_real.append(labels_all[idx_shuffle].to("cpu"))
# img_real = torch.from_numpy(np.concatenate(img_real, axis=0))
# label_real = torch.from_numpy(np.concatenate(label_real, axis=0))

In [None]:
img_real_test = []
label_real_test = []
for c in range(num_classes):
    idx_shuffle = np.random.permutation(indices_class[c])[:500]
    img_real_test.append(images_all[idx_shuffle].to("cpu") )
    label_real_test.append(labels_all[idx_shuffle].to("cpu"))
img_real_test = torch.from_numpy(np.concatenate(img_real_test, axis=0))
label_real_test = torch.from_numpy(np.concatenate(label_real_test, axis=0))

In [None]:
label_real_test.shape

In [None]:
img_real_train.shape

In [None]:
# 用全部数据集训

In [11]:


num_classes = 10
batch = 50
num_feat = 3072
criterion = nn.BCELoss()
optimizer = optim.Adam(autoencoder.parameters())
# model = Autoencoder(num_feat).to(device)  
# 训练
for epoch in range(1000):

  total_loss = 0
  for c in range(num_classes):
                # 获取类别c的合成图像和类别中心
    # batch_syn = img_syn[c*args.ipc:(c+1)*args.ipc].reshape((batch, 3, 32, 32)) 
    batch_img = img_real_all_train[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)) 
    batch_img = batch_img.to(device) 
    
    # ============ Forward ============
    encoded, outputs = autoencoder(batch_img)
    loss = criterion(outputs, batch_img)
    # ============ Backward ============
    # print(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


  

        

In [12]:
torch.save(autoencoder,"autoencoder_all.pt")

In [13]:
with torch.no_grad():
  output = autoencoder(img_real_test.to(device))
accs = []
model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
args.device = "cuda:0"
import copy
accs_all_exps = dict() # record performances of all experiments
for key in model_eval_pool:
    accs_all_exps[key] = []
args.dsa_param = ParamDiffAug()
args.dsa = False if args.dsa_strategy in ['none', 'None'] else True
model_eval= model_eval_pool[0]

# [2023-10-06 20:45:24] Evaluate_00: epoch = 1000 train time = 61 s train loss = 0.000965 train acc = 1.0000, test acc = 0.5001
# Evaluate 1 random ConvNet, mean = 0.5001 std = 0.0000

In [14]:
data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[1]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)
print()
print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))



# SEED = 187 + test 5000张原图
# A + test - 0.5147

# 直接用500张蒸馏前原图-蒸留后图片(有对应)训AE  - 0.4998

# A + 500张合成图FINETUN  - 0.5213
# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)  - 5165

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[2023-10-08 21:45:03] Evaluate_01: epoch = 1000 train time = 2135 s train loss = 0.046248 train acc = 0.9898, test acc = 0.5147

Evaluate 1 random ConvNet, mean = 0.5147 std = 0.0000
-------------------------


In [15]:
# torch.save(img_syn, 'img_syn.pt')
# torch.save(autoencoder, 'autoencoder_all.pt')
# 1.用全部 cifar样本训练一个auto encoder A
# 2.用 蒸馏前原图-蒸留后图片 分别作为输入和生成目标来finetune A，得到B
# 3.选五百个原图，分别输入A和B，测试B会不会效果更好

# SEED = 87 + test 500张原图
# A + test 500张原图 - 0.3446

# A + 500张合成图FINETUN(冻结encoder)  - 0.3576
# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)(冻结encoder)  - 0.3075


# A + 500张合成图FINETUN  - 0.3181
# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)  - 0.3381


# SEED = 187 + test 500张原图
# A + test 500张原图 - 0.3559

# A + 500张合成图FINETUN(冻结encoder)  - 0.3429
# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)(冻结encoder)  - 0.3179


# A + 500张合成图FINETUN  - 0.3607
# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)  - 0.3365


# SEED = 87 + test 5000张原图
# A + test 500张原图 - 0.3446

# A + 500张合成图FINETUN(冻结encoder) - 0.3576
# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)(冻结encoder)  - 0.3075


# A + 500张合成图FINETUN  - 0.3181
# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)  - 0.3381


# SEED = 187 + test 5000张原图
# A + test 500张原图 - 0.5176

# A + 500张合成图FINETUN(冻结encoder)  - 0.4242
# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)(冻结encoder)  - 0.3664


# A + 500张合成图FINETUN  - 0.5289
# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)  - 0.3916


In [16]:
# A + 500张合成图FINETUN

num_classes = 10
batch = 50
num_feat = 3072
criterion = nn.BCELoss()
optimizer = optim.Adam(autoencoder.parameters())
# model = Autoencoder(num_feat).to(device)  
# 训练
for epoch in range(1000):

  total_loss = 0
  for c in range(num_classes):
                # 获取类别c的合成图像和类别中心
                # image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
    batch_img = img_syn[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)) 
    batch_img = batch_img.to(device) 
    
    # ============ Forward ============
    encoded, outputs = autoencoder(batch_img)
    loss = criterion(outputs, batch_img)
    # ============ Backward ============
    # print(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()



In [17]:
with torch.no_grad():
  output = autoencoder(img_real_test.to(device))

data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[1]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)
print()
print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

[2023-10-08 22:21:02] Evaluate_01: epoch = 1000 train time = 2104 s train loss = 0.026665 train acc = 0.9954, test acc = 0.5279

Evaluate 2 random ConvNet, mean = 0.5213 std = 0.0066
-------------------------


In [18]:
# 用合成数据集微调
autoencoder = torch.load('autoencoder_all.pt') 

In [19]:
num_classes = 10
batch = 50
num_feat = 3072
criterion = nn.BCELoss()
optimizer = optim.Adam(autoencoder.parameters())
# model = Autoencoder(num_feat).to(device)  
# 训练
for epoch in range(1000):

  total_loss = 0
  for c in range(num_classes):
                # 获取类别c的合成图像和类别中心
    batch_syn = img_syn[c*args.ipc:(c+1)*args.ipc].reshape((batch, 3, 32, 32)) 
    batch_img = img_real_train[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)) 
    batch_img = batch_img.to(device) 
    
    # ============ Forward ============
    encoded, outputs = autoencoder(batch_img)
    loss = criterion(outputs, batch_syn)
    # ============ Backward ============
    # print(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
        

In [20]:
with torch.no_grad():
  output = autoencoder(img_real_test.to(device))

data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[1]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)
print()
print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))
# [2023-10-06 20:45:24] Evaluate_00: epoch = 1000 train time = 61 s train loss = 0.000965 train acc = 1.0000, test acc = 0.5001
# Evaluate 1 random ConvNet, mean = 0.5001 std = 0.0000

[2023-10-08 22:57:48] Evaluate_01: epoch = 1000 train time = 2149 s train loss = 0.049232 train acc = 0.9890, test acc = 0.5069

Evaluate 3 random ConvNet, mean = 0.5165 std = 0.0087
-------------------------


In [None]:
def freeze_encoder_parameters(model):
    for param in model.encoder.parameters():
        param.requires_grad = False
# 用合成数据集微调
autoencoder = torch.load('autoencoder_all.pt') 
freeze_encoder_parameters(autoencoder)

In [None]:
# A + 500张合成图FINETUN(冻结encoder)

num_classes = 10
batch = 50
num_feat = 3072
criterion = nn.BCELoss()
optimizer = optim.Adam(autoencoder.parameters())
# model = Autoencoder(num_feat).to(device)  
# 训练
for epoch in range(1000):

  total_loss = 0
  for c in range(num_classes):
                # 获取类别c的合成图像和类别中心
                # image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
    batch_img = img_syn[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)) 
    batch_img = batch_img.to(device) 
    
    # ============ Forward ============
    encoded, outputs = autoencoder(batch_img)
    loss = criterion(outputs, batch_img)
    # ============ Backward ============
    # print(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():
  output = autoencoder(img_real_test.to(device))

data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[1]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)
print()
print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

In [None]:
autoencoder = torch.load('autoencoder_all.pt') 
freeze_encoder_parameters(autoencoder)

In [None]:

# A + 500张蒸馏前原图-蒸留后图片FINETUN(B)(冻结encoder)
num_classes = 10
batch = 50
num_feat = 3072
criterion = nn.BCELoss()
optimizer = optim.Adam(autoencoder.parameters())
# model = Autoencoder(num_feat).to(device)  
# 训练
for epoch in range(1000):

  total_loss = 0
  for c in range(num_classes):
                # 获取类别c的合成图像和类别中心
                # image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

    
    # idx_shuffle = np.random.permutation(indices_class[c])[:batch]
    batch_img_real = img_real_test[c*batch:(c+1)*batch].reshape((batch, channel, im_size[0], im_size[1])).to(device) 
    
    batch_img = img_syn[c*batch:(c+1)*batch].reshape((batch, channel, im_size[0], im_size[1])).to(device) 
    
    # batch_img = batch_img_real.to(device) 
    
    # ============ Forward ============
    encoded, outputs = autoencoder(batch_img_real)
    loss = criterion(outputs, batch_img)
    # ============ Backward ============
    # print(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


with torch.no_grad():
  output = autoencoder(img_real_test.to(device))

data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[1]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)
print()
print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

        

In [None]:
# 500 张合成图训练AE：37
# 全部CIFA10训练AE ： 39.8
# 500 张CIFA10训练AE ： 36.5
# 500张CIFA10本身测试 ： 50