In [1]:
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import warnings
warnings.filterwarnings("ignore")

# watch -n 1 nvidia-smi
import os

# Expose only physical GPU 2 to this process; inside PyTorch it appears as cuda:0.
# (This only takes effect if it is set before CUDA is first initialised.)
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
parser = argparse.ArgumentParser(description='Parameter Processing')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
parser.add_argument('--model', type=str, default='ConvNet', help='model')
parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
parser.add_argument('--num_exp', type=int, default=3, help='the number of experiments')
parser.add_argument('--num_eval', type=int, default=1, help='the number of evaluating randomly initialized models')
parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') # it can be small for speeding up with little performance drop
parser.add_argument('--Iteration', type=int, default=2000, help='training iterations')
parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
parser.add_argument('--data_path', type=str, default='/home/ssd7T/ZTL_gcond/data_cv', help='dataset path')
parser.add_argument('--save_path', type=str, default='/home/ssd7T/ztl_dm/gen', help='path to save results')
parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')

args = parser.parse_args([])  # empty argv: always use the defaults above
args.method = 'DM'
args.outer_loop, args.inner_loop = get_loops(args.ipc)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = args.dsa_strategy not in ['none', 'None']

# makedirs also creates missing parent directories; exist_ok tolerates existing ones.
os.makedirs(args.data_path, exist_ok=True)
os.makedirs(args.save_path, exist_ok=True)

eval_it_pool = np.arange(0, args.Iteration+1, 2000).tolist() if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration] # The list of iterations when we evaluate models and record results.
print('eval_it_pool: ', eval_it_pool)
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

accs_all_exps = {key: [] for key in model_eval_pool}  # record performances of all experiments

data_save = []
# One entry per (experiment, class), appended in that order across ALL experiments:
pairs_real = []   # the real images used to initialise that class's synthetic images
indexs_real = []  # the matching index lists into images_all / labels_all

''' organize the real dataset '''
# The real dataset is the same for every experiment, so load it once and read
# each sample a single time (previously every experiment re-read the whole of
# dst_train, indexing each sample twice).
samples = [dst_train[i] for i in range(len(dst_train))]
images_all = torch.cat([torch.unsqueeze(img, dim=0) for img, _ in samples], dim=0).to(args.device)
labels_all = torch.tensor([lab for _, lab in samples], dtype=torch.long, device=args.device)
indices_class = [[] for _ in range(num_classes)]  # indices_class[c]: dataset positions of class c, in order
for i, (_, lab) in enumerate(samples):
    indices_class[lab].append(i)
del samples

for c in range(num_classes):
    print('class c = %d: %d real images'%(c, len(indices_class[c])))

for ch in range(channel):
    print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))


def get_images(c, n):  # get random n images from class c
    idx_shuffle = np.random.permutation(indices_class[c])[:n]
    return images_all[idx_shuffle]


def get_images_init(c, n, exp):
    """Return the n class-`c` images starting at position `exp` of that class, plus their indices.

    NOTE(review): the window starts at `exp`, not `exp * n`, so experiments
    `exp` and `exp + 1` share n - 1 images per class. Confirm this matches how
    the saved img_syn_* / pairs_real_* runs were produced before changing it.
    """
    idx_window = indices_class[c][exp:exp + n]
    return images_all[idx_window], idx_window


# Experiment ids start at 80, presumably to continue after the saved runs 0..79 (confirm).
EXP_OFFSET = 80
for exp in range(EXP_OFFSET, EXP_OFFSET + args.num_exp):
    print('\n================== Exp %d ==================\n '%exp)
    print('Hyper-parameters: \n', args.__dict__)
    print('Evaluation model pool: ', model_eval_pool)

    ''' initialize the synthetic data '''
    image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
    # Built directly in torch; converting a list of numpy arrays to a tensor is slow.
    label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)  # [0,0,0, 1,1,1, ..., 9,9,9]

    if args.init == 'real':
        # NOTE: despite the message, get_images_init takes a deterministic window, not a random sample.
        print('initialize synthetic data from random real images')
        for c in range(num_classes):
            reals, index = get_images_init(c, args.ipc, exp)
            reals = reals.detach()
            pairs_real.append(reals)
            indexs_real.append(index)
            image_syn.data[c*args.ipc:(c+1)*args.ipc] = reals
    else:
        print('initialize synthetic data from random noise')

eval_it_pool:  [0, 2000]
Files already downloaded and verified
Files already downloaded and verified

 
Hyper-parameters: 
 {'dataset': 'CIFAR10', 'model': 'ConvNet', 'ipc': 50, 'eval_mode': 'SS', 'num_exp': 3, 'num_eval': 1, 'epoch_eval_train': 1000, 'Iteration': 2000, 'lr_img': 1.0, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'real', 'dsa_strategy': 'color_crop_cutout_flip_scale_rotate', 'data_path': '/home/ssd7T/ZTL_gcond/data_cv', 'save_path': '/home/ssd7T/ztl_dm/gen', 'dis_metric': 'ours', 'method': 'DM', 'outer_loop': 50, 'inner_loop': 10, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x7fbee0db1df0>, 'dsa': True}
Evaluation model pool:  ['ConvNet']
class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000 real images
class c = 8: 5000 real images
class c = 9: 5000 real images
re

In [2]:
img_real_test = torch.cat(pairs_real, dim=0)

# pairs_real and indexs_real hold one entry per (experiment, class), appended in
# the same order, and pairs_real[j] == images_all[indexs_real[j]]. The labels of
# img_real_test are therefore the labels of every index list, concatenated in
# that same order. (The previous version grouped them in blocks of a hard-coded
# 10 classes and dropped any trailing partial block, which could leave the
# labels shorter than, and misaligned with, img_real_test.)
label_real_test = torch.cat([labels_all[idx].to("cpu") for idx in indexs_real], dim=0)

In [3]:
device = vars(args)['device']  # short alias for the device string stored on args ('cuda' or 'cpu')

In [4]:
# Load the saved (synthetic, real) pair chunks from /home/ssd7T/ztl_dm.
# Chunk i is four files: img_syn_{i}, label_syn_{i}, pairs_real_{i}, indexs_real_{i}.
# Chunks whose files do not exist are skipped. Any other error is raised; the
# previous bare `except: pass` also hid corrupt files and indexing bugs.
N_CHUNKS = 80
img_syn = []
label_syn = []
img_real_train = []
label_real_train = []
missing_chunks = []
for i in range(N_CHUNKS):
    try:
        img_syn_ = torch.load(f'/home/ssd7T/ztl_dm/img_syn_{i}.pt')
        label_syn_ = torch.load(f'/home/ssd7T/ztl_dm/label_syn_{i}.pt')
        pairs_real_ = torch.load(f'/home/ssd7T/ztl_dm/pairs_real_{i}.pt')
        indexs_real_ = torch.load(f'/home/ssd7T/ztl_dm/indexs_real_{i}.pt')
    except FileNotFoundError:
        missing_chunks.append(i)
        continue

    img_real_train_ = torch.cat(pairs_real_, dim=0)
    # indexs_real_ holds one index list per class; gather their labels in class order.
    label_real_train_ = torch.cat(
        [labels_all[indexs_real_[c]].to("cpu") for c in range(num_classes)], dim=0)

    # Append only after every step succeeded, so the four lists stay aligned.
    img_syn.append(img_syn_)
    label_syn.append(label_syn_)
    img_real_train.append(img_real_train_)
    label_real_train.append(label_real_train_)

print(f'loaded {N_CHUNKS - len(missing_chunks)} of {N_CHUNKS} chunks from /home/ssd7T/ztl_dm')
if not img_syn:
    raise FileNotFoundError('no complete chunk found under /home/ssd7T/ztl_dm')

img_syn = torch.cat(img_syn, dim=0).to(device)
label_syn = torch.cat(label_syn, dim=0).to(device)
img_real_train = torch.cat(img_real_train, dim=0).to(device)
label_real_train = torch.cat(label_real_train, dim=0).to(device)



In [5]:
img_real_train.size()  # display the stacked real-image tensor's dimensions

torch.Size([18000, 3, 32, 32])

In [6]:
accs = []  # test accuracies of the evaluations run by the following cells, in order
model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
import copy
# Fresh per-architecture record of experiment accuracies.
accs_all_exps = {key: [] for key in model_eval_pool}
# Rebuild the differentiable-augmentation settings on args.
args.dsa_param = ParamDiffAug()
args.dsa = args.dsa_strategy not in ['none', 'None']
# Only the first architecture of the pool is evaluated below.
model_eval = model_eval_pool[0]

In [7]:


# get_dataset() already returned num_classes for the chosen dataset. Rather than
# silently overwriting it with CIFAR-10's value, check the assumption; the cells
# in this notebook are written for CIFAR-10 (10 classes, 3x32x32 images).
if num_classes != 10:
    raise ValueError(f'this notebook assumes 10 classes, but the dataset has {num_classes}')


In [8]:
from a_cvae import  CVAE

In [9]:
data_save = []
# Baseline: train a freshly initialised evaluation network on the whole real training set.
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device)
# Deep copies guard the originals against accidental in-place modification.
image_syn_eval = copy.deepcopy(images_all)
label_syn_eval = copy.deepcopy(labels_all)
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)

[2023-10-14 13:58:24] Evaluate_01: epoch = 1000 train time = 2937 s train loss = 0.064025 train acc = 0.9830, test acc = 0.8752


In [16]:
# Load the saved (synthetic, real) pair chunks from /home/ssd7T/ztl_ftd.
# Chunk i is four files: img_syn_{i}, label_syn_{i}, pairs_real_{i}, indexs_real_{i}.
# Chunks whose files do not exist are skipped. Any other error is raised; the
# previous bare `except: pass` also hid corrupt files and indexing bugs.
N_CHUNKS = 90
img_syn = []
label_syn = []
img_real_train = []
label_real_train = []
missing_chunks = []
for i in range(N_CHUNKS):
    try:
        img_syn_ = torch.load(f'/home/ssd7T/ztl_ftd/img_syn_{i}.pt')
        label_syn_ = torch.load(f'/home/ssd7T/ztl_ftd/label_syn_{i}.pt')
        pairs_real_ = torch.load(f'/home/ssd7T/ztl_ftd/pairs_real_{i}.pt')
        indexs_real_ = torch.load(f'/home/ssd7T/ztl_ftd/indexs_real_{i}.pt')
    except FileNotFoundError:
        missing_chunks.append(i)
        continue

    img_real_train_ = torch.cat(pairs_real_, dim=0)
    # indexs_real_ holds one index list per class; gather their labels in class order.
    label_real_train_ = torch.cat(
        [labels_all[indexs_real_[c]].to("cpu") for c in range(num_classes)], dim=0)

    # Append only after every step succeeded, so the four lists stay aligned.
    img_syn.append(img_syn_)
    label_syn.append(label_syn_)
    img_real_train.append(img_real_train_)
    label_real_train.append(label_real_train_)

print(f'loaded {N_CHUNKS - len(missing_chunks)} of {N_CHUNKS} chunks from /home/ssd7T/ztl_ftd')
if not img_syn:
    raise FileNotFoundError('no complete chunk found under /home/ssd7T/ztl_ftd')

img_syn = torch.cat(img_syn, dim=0).to(device)
label_syn = torch.cat(label_syn, dim=0).to(device)
img_real_train = torch.cat(img_real_train, dim=0).to(device)
label_real_train = torch.cat(label_real_train, dim=0).to(device)



In [17]:
img_syn.size()  # display the stacked synthetic-image tensor's dimensions

torch.Size([40000, 3, 32, 32])

In [18]:
img_real_train.size()  # display the stacked real-image tensor's dimensions

torch.Size([40000, 3, 32, 32])

In [19]:
images_all.size()  # display the full real training set's dimensions

torch.Size([50000, 3, 32, 32])

In [22]:
torch.cat([img_real_train, images_all], dim=0).size()  # size of paired real images stacked with the full real set

torch.Size([90000, 3, 32, 32])

In [23]:
torch.cat([label_syn, labels_all], dim=0)  # synthetic labels followed by all real labels

tensor([0, 0, 0,  ..., 9, 1, 1], device='cuda:0')

In [9]:

import argparse

import torch.utils.data
from torch import optim

from torchvision.utils import save_image, make_grid
    # 'custom': {'lr': 2e-4, 'k': 512, 'hidden': 128},
    # 'imagenet': {'lr': 2e-4, 'k': 512, 'hidden': 128},
    # 'cifar10': {'lr': 2e-4, 'k': 10, 'hidden': 256},
    # 'mnist': {'lr': 1e-4, 'k': 10, 'hidden': 64}
parser = argparse.ArgumentParser(description='Variational AutoEncoders')

model_parser = parser.add_argument_group('Model Parameters')
model_parser.add_argument('--model', default='vae', choices=['vae', 'vqvae'],
                            help='autoencoder variant to use: vae | vqvae')
model_parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                            help='input batch size for training (default: 128)')
model_parser.add_argument('--hidden', type=int, default=256, metavar='N',
                            help='number of hidden channels')
model_parser.add_argument('-k', '--dict-size',  default=10,type=int, dest='k', metavar='K',
                            help='number of atoms in dictionary')
model_parser.add_argument('--lr', type=float, default=2e-4,
                            help='learning rate')
model_parser.add_argument('--vq_coef', type=float, default=None,
                            help='vq coefficient in loss')
model_parser.add_argument('--commit_coef', type=float, default=None,
                            help='commitment coefficient in loss')
model_parser.add_argument('--kl_coef', type=float, default=None,
                            help='kl-divergence coefficient in loss')

training_parser = parser.add_argument_group('Training Parameters')
training_parser.add_argument('--dataset', default='cifar10', choices=['mnist', 'cifar10', 'imagenet',
                                                                        'custom'],
                                help='dataset to use: mnist | cifar10 | imagenet | custom')
training_parser.add_argument('--dataset_dir_name', default='',
                                help='name of the dir containing the dataset if dataset == custom')
training_parser.add_argument('--data-dir', default='/media/ssd/Datasets',
                                help='directory containing the dataset')
# NOTE: --epochs is not consulted below; the loops use their own hard-coded step counts.
training_parser.add_argument('--epochs', type=int, default=20, metavar='N',
                                help='number of epochs to train (default: 20)')
training_parser.add_argument('--max-epoch-samples', type=int, default=50000,
                                help='max num of samples per epoch')
# NOTE: --no-cuda is not consulted below; the device is chosen by CUDA availability.
training_parser.add_argument('--no-cuda', action='store_true', default=False,
                                help='disables CUDA training')
training_parser.add_argument('--seed', type=int, default=1, metavar='S',
                                help='random seed (default: 1)')
training_parser.add_argument('--gpus', default='0',
                                help='gpus used for training - e.g 0,1,3')

logging_parser = parser.add_argument_group('Logging Parameters')
logging_parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                            help='how many batches to wait before logging training status')
logging_parser.add_argument('--results-dir', metavar='RESULTS_DIR', default='./results',
                            help='results dir')
logging_parser.add_argument('--save-name', default='',
                            help='saved folder')
logging_parser.add_argument('--data-format', default='json',
                            help='in which format to save the data')

args_cave = parser.parse_args([])  # empty argv: always use the defaults above

torch.manual_seed(args_cave.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(args_cave.seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

lr = args_cave.lr
k = args_cave.k
hidden = args_cave.hidden
num_channels = 3  # CIFAR-10 images are RGB

model = CVAE(d = hidden, k=k, num_channels=num_channels).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE: scheduler.step() is never called, so the learning rate stays constant at `lr`.
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)

batch = args_cave.batch_size

# Variant "A": train the CVAE as an autoencoder on the full real training set.
# `epochs` counts optimizer steps on one random mini-batch each, not passes over the data.
epochs = 3000
# The loop variable is `step`, not `epoch`, so it does not shadow the `epoch`
# helper imported from utils.
for step in range(epochs):
    # Sample from ALL of images_all. The previous version drew indices from
    # len(img_real_train) (40000), so the last 10000 real images were never used.
    random_indices = np.random.choice(len(images_all), batch, replace=False)
    batch_img = images_all[random_indices].reshape((batch, 3, 32, 32)).to(device)

    outputs = model(batch_img)
    loss = model.loss_function(batch_img, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 10 == 0 or step == epochs - 1:
        print("====> Epoch: {} || {}".format(step, loss.item()))

# Pass img_real_test through the CVAE, then train and evaluate a fresh network on
# the first output element (presumably the reconstruction) with the real labels.
# NOTE(review): the model is still in training mode here; whether that matters depends on CVAE (confirm).
with torch.no_grad():
    output = model(img_real_test.to(device))
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)


====> Epoch: 230 || 0.422725647687912
====> Epoch: 240 || 0.482896089553833
====> Epoch: 250 || 0.4822136163711548
====> Epoch: 260 || 0.5269333124160767
====> Epoch: 270 || 0.4414944052696228
====> Epoch: 280 || 0.4615752696990967
====> Epoch: 290 || 0.4423099756240845
====> Epoch: 300 || 0.4907228350639343
====> Epoch: 310 || 0.46839991211891174
====> Epoch: 320 || 0.47450411319732666
====> Epoch: 330 || 0.4529281258583069
====> Epoch: 340 || 0.4565031826496124
====> Epoch: 350 || 0.4256156086921692
====> Epoch: 360 || 0.45230239629745483
====> Epoch: 370 || 0.480398029088974
====> Epoch: 380 || 0.45993131399154663
====> Epoch: 390 || 0.4347468316555023
====> Epoch: 400 || 0.4285741448402405
====> Epoch: 410 || 0.42451804876327515
====> Epoch: 420 || 0.44213834404945374
====> Epoch: 430 || 0.4809691309928894
====> Epoch: 440 || 0.4328254759311676
====> Epoch: 450 || 0.46104544401168823
====> Epoch: 460 || 0.4584049880504608
====> Epoch: 470 || 0.38739049434661865
====> Epoch: 480 || 

In [10]:
torch.save(model, f='cvae_all.pt')  # pickle the whole model object (not just its state_dict) as the "A" checkpoint

In [11]:
model = CVAE(d = hidden, k=k, num_channels=num_channels).to(device)

# Variant: train a fresh CVAE as an autoencoder on only the real images of the saved pairs.

optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE: scheduler.step() is never called, so the learning rate stays constant at `lr`.
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)

# `epochs` counts optimizer steps on one random mini-batch each, not passes over the data.
epochs = 3000
# `step` rather than `epoch`, so the `epoch` helper imported from utils is not shadowed.
for step in range(epochs):
    # Sample a random batch without replacement and reshape it.
    random_indices = np.random.choice(len(img_real_train), batch, replace=False)
    batch_img = img_real_train[random_indices].reshape((batch, 3, 32, 32)).to(device)

    outputs = model(batch_img)
    loss = model.loss_function(batch_img, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 10 == 0 or step == epochs - 1:
        print("====> Epoch: {} || {}".format(step, loss.item()))

# Pass img_real_test through the CVAE, then train and evaluate a fresh network on
# the first output element (presumably the reconstruction) with the real labels.
# NOTE(review): the model is still in training mode here; whether that matters depends on CVAE (confirm).
with torch.no_grad():
    output = model(img_real_test.to(device))
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)

====> Epoch: 0 || 2.066448926925659
====> Epoch: 10 || 0.8942238092422485
====> Epoch: 20 || 0.7132065892219543
====> Epoch: 30 || 0.5983194708824158
====> Epoch: 40 || 0.6322553157806396
====> Epoch: 50 || 0.5465779304504395
====> Epoch: 60 || 0.5191893577575684
====> Epoch: 70 || 0.486300528049469
====> Epoch: 80 || 0.5136184692382812
====> Epoch: 90 || 0.45442846417427063
====> Epoch: 100 || 0.43417632579803467
====> Epoch: 110 || 0.5293276906013489
====> Epoch: 120 || 0.4558901786804199
====> Epoch: 130 || 0.4727620780467987
====> Epoch: 140 || 0.49037349224090576
====> Epoch: 150 || 0.46710193157196045
====> Epoch: 160 || 0.442527174949646
====> Epoch: 170 || 0.4338109493255615
====> Epoch: 180 || 0.4441336989402771
====> Epoch: 190 || 0.4115105867385864
====> Epoch: 200 || 0.4368806481361389
====> Epoch: 210 || 0.41821739077568054
====> Epoch: 220 || 0.504594087600708
====> Epoch: 230 || 0.42109209299087524
====> Epoch: 240 || 0.4639461636543274
====> Epoch: 250 || 0.427253782749

In [12]:
model = CVAE(d = hidden, k=k, num_channels=num_channels).to(device)

# Variant: train a fresh CVAE on the sample pairs. The input is the real image, and
# loss_function's first argument (presumably the reconstruction target) is its
# synthetic counterpart.

optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE: scheduler.step() is never called, so the learning rate stays constant at `lr`.
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)

# Pairs are matched by position (the same random indices select both halves),
# so the two tensors must have the same length.
if len(img_syn) != len(img_real_train):
    raise ValueError(f'img_syn ({len(img_syn)}) and img_real_train ({len(img_real_train)}) are not aligned pairs')

# `epochs` counts optimizer steps on one random mini-batch each, not passes over the data.
epochs = 3000
# `step` rather than `epoch`, so the `epoch` helper imported from utils is not shadowed.
for step in range(epochs):
    # Sample a random batch of pairs without replacement and reshape them.
    random_indices = np.random.choice(len(img_real_train), batch, replace=False)
    batch_img = img_real_train[random_indices].reshape((batch, 3, 32, 32)).to(device)
    batch_syn = img_syn[random_indices].reshape((batch, 3, 32, 32)).to(device)

    outputs = model(batch_img)
    loss = model.loss_function(batch_syn, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 10 == 0 or step == epochs - 1:
        print("====> Epoch: {} || {}".format(step, loss.item()))

# Pass img_real_test through the CVAE, then train and evaluate a fresh network on
# the first output element with the real labels.
# NOTE(review): the model is still in training mode here; whether that matters depends on CVAE (confirm).
with torch.no_grad():
    output = model(img_real_test.to(device))
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)

====> Epoch: 0 || 2.1260056495666504
====> Epoch: 10 || 0.9068130254745483
====> Epoch: 20 || 0.7513774633407593
====> Epoch: 30 || 0.654479444026947
====> Epoch: 40 || 0.6479647159576416
====> Epoch: 50 || 0.5228400826454163
====> Epoch: 60 || 0.5611713528633118
====> Epoch: 70 || 0.5120115280151367
====> Epoch: 80 || 0.4823131859302521
====> Epoch: 90 || 0.4623338282108307
====> Epoch: 100 || 0.46980592608451843
====> Epoch: 110 || 0.48641499876976013
====> Epoch: 120 || 0.43779847025871277
====> Epoch: 130 || 0.4403831958770752
====> Epoch: 140 || 0.5010635256767273
====> Epoch: 150 || 0.4756315052509308
====> Epoch: 160 || 0.4695318341255188
====> Epoch: 170 || 0.48724839091300964
====> Epoch: 180 || 0.4087326228618622
====> Epoch: 190 || 0.43832674622535706
====> Epoch: 200 || 0.4824279248714447
====> Epoch: 210 || 0.501254677772522
====> Epoch: 220 || 0.45777255296707153
====> Epoch: 230 || 0.4298408031463623
====> Epoch: 240 || 0.4415327310562134
====> Epoch: 250 || 0.4437855780

In [13]:
# Variant "A + synthetic fine-tune": start from the CVAE pickled to cvae_all.pt and
# keep training it as an autoencoder on the synthetic images.
# map_location lets the checkpoint load even if it was saved from another device.
# NOTE(review): on newer PyTorch versions, torch.load defaults to weights_only=True,
# which cannot unpickle a whole nn.Module; confirm against the installed version.
model = torch.load('cvae_all.pt', map_location=device).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE: scheduler.step() is never called, so the learning rate stays constant at `lr`.
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)

# `epochs` counts optimizer steps on one random mini-batch each, not passes over the data.
epochs = 3000
# `step` rather than `epoch`, so the `epoch` helper imported from utils is not shadowed.
for step in range(epochs):
    # Draw indices from the tensor actually being indexed (img_syn).
    random_indices = np.random.choice(len(img_syn), batch, replace=False)
    batch_syn = img_syn[random_indices].reshape((batch, 3, 32, 32)).to(device)

    outputs = model(batch_syn)
    loss = model.loss_function(batch_syn, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 10 == 0 or step == epochs - 1:
        print("====> Epoch: {} || {}".format(step, loss.item()))

# Pass img_real_test through the CVAE, then train and evaluate a fresh network on
# the first output element with the real labels.
# NOTE(review): the model is still in training mode here; whether that matters depends on CVAE (confirm).
with torch.no_grad():
    output = model(img_real_test.to(device))
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)

====> Epoch: 0 || 0.3686927855014801
====> Epoch: 10 || 0.35229557752609253
====> Epoch: 20 || 0.34480956196784973
====> Epoch: 30 || 0.35006120800971985
====> Epoch: 40 || 0.33064135909080505
====> Epoch: 50 || 0.38443687558174133
====> Epoch: 60 || 0.3595213294029236
====> Epoch: 70 || 0.358097106218338
====> Epoch: 80 || 0.3442186713218689
====> Epoch: 90 || 0.35591447353363037
====> Epoch: 100 || 0.3467026948928833
====> Epoch: 110 || 0.3557693064212799
====> Epoch: 120 || 0.38845348358154297
====> Epoch: 130 || 0.3123314678668976
====> Epoch: 140 || 0.39908307790756226
====> Epoch: 150 || 0.3790074586868286
====> Epoch: 160 || 0.3567098081111908
====> Epoch: 170 || 0.3105849325656891
====> Epoch: 180 || 0.36520111560821533
====> Epoch: 190 || 0.34754905104637146
====> Epoch: 200 || 0.3739040493965149
====> Epoch: 210 || 0.2975155711174011
====> Epoch: 220 || 0.36401471495628357
====> Epoch: 230 || 0.3675159513950348
====> Epoch: 240 || 0.37651392817497253
====> Epoch: 250 || 0.375

In [14]:
# Variant "A + sample-pair fine-tune": start from the CVAE pickled to cvae_all.pt and
# keep training it with real images as input and their synthetic counterparts as
# loss_function's first argument (presumably the reconstruction target).
# map_location lets the checkpoint load even if it was saved from another device.
# NOTE(review): on newer PyTorch versions, torch.load defaults to weights_only=True,
# which cannot unpickle a whole nn.Module; confirm against the installed version.
model = torch.load('cvae_all.pt', map_location=device).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE: scheduler.step() is never called, so the learning rate stays constant at `lr`.
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)

# Pairs are matched by position (the same random indices select both halves),
# so the two tensors must have the same length.
if len(img_syn) != len(img_real_train):
    raise ValueError(f'img_syn ({len(img_syn)}) and img_real_train ({len(img_real_train)}) are not aligned pairs')

# `epochs` counts optimizer steps on one random mini-batch each, not passes over the data.
epochs = 3000
# `step` rather than `epoch`, so the `epoch` helper imported from utils is not shadowed.
for step in range(epochs):
    # Sample a random batch of pairs without replacement and reshape them.
    random_indices = np.random.choice(len(img_real_train), batch, replace=False)
    batch_img = img_real_train[random_indices].reshape((batch, 3, 32, 32)).to(device)
    batch_syn = img_syn[random_indices].reshape((batch, 3, 32, 32)).to(device)

    outputs = model(batch_img)
    loss = model.loss_function(batch_syn, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 10 == 0 or step == epochs - 1:
        print("====> Epoch: {} || {}".format(step, loss.item()))

# Pass img_real_test through the CVAE, then train and evaluate a fresh network on
# the first output element with the real labels.
# NOTE(review): the model is still in training mode here; whether that matters depends on CVAE (confirm).
with torch.no_grad():
    output = model(img_real_test.to(device))
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)

====> Epoch: 0 || 0.33656272292137146
====> Epoch: 10 || 0.3376087248325348
====> Epoch: 20 || 0.4088062047958374
====> Epoch: 30 || 0.38783806562423706
====> Epoch: 40 || 0.404959499835968
====> Epoch: 50 || 0.34974777698516846
====> Epoch: 60 || 0.3556029796600342
====> Epoch: 70 || 0.3963489830493927
====> Epoch: 80 || 0.37195566296577454
====> Epoch: 90 || 0.3484606444835663
====> Epoch: 100 || 0.3841311037540436
====> Epoch: 110 || 0.3404640257358551
====> Epoch: 120 || 0.37040287256240845
====> Epoch: 130 || 0.357191264629364
====> Epoch: 140 || 0.4120376706123352
====> Epoch: 150 || 0.3762578070163727
====> Epoch: 160 || 0.3890926241874695
====> Epoch: 170 || 0.3978455662727356
====> Epoch: 180 || 0.3715849220752716
====> Epoch: 190 || 0.34011968970298767
====> Epoch: 200 || 0.3609450161457062
====> Epoch: 210 || 0.32383233308792114
====> Epoch: 220 || 0.3416687250137329
====> Epoch: 230 || 0.3344535827636719
====> Epoch: 240 || 0.3355846405029297
====> Epoch: 250 || 0.372781515

In [15]:

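# Test results on the synthetic set (1500 images)
# Original images fed in directly: 0.5100  0.4517 0.4502 0.4670 0.4689 0.4764
# Trained on sample pairs (original image -> synthetic image difference): 0.5093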

# epoch 2000 + batch 64          (mean, std | individual runs)
# A (all real images):                       0.4586, 0.004  |   0.4564,0.4551,0.4642
# Only the real images of the sample pairs:  0.4612, 0.005  |   0.4679,0.4598,0.4558
# Only the sample pairs:                     0.4637, 0.008 |   0.4748,0.4576,0.4587
# A + synthetic images FINETUNE:             0.4668, 0.002  |   0.4683,0.4641,0.4680
# A + sample pairs FINETUNE (best):          0.4782, 0.009  |    0.4904,0.4758,0.4685

# epoch 3000 + batch 64          (mean, std | individual runs)
# A (all real images):                       0.4687, 0.005  |   0.4717,0.4614, 0.4730
# Only the real images of the sample pairs:  0.4580, 0.007  |   0.4531,0.4532,0.4678
# Only the sample pairs:                     0.4658, 0.006 |   0.4710,0.4576,0.4689
# A + synthetic images FINETUNE:             0.472, 0.005  |   0.4732, 0.4770,0.4658
# A + sample pairs FINETUNE (best):          0.4833, 0.004  |    0.4807,0.4889,0.4802


# epoch 3000 + batch 128         (mean, std | individual runs)
# A (all real images):                       0.4657, 0.003  |   0.4639,0.4693, 0.4640
# Only the real images of the sample pairs:  0.4630, 0.004  |   0.4674,0.4632,0.4585
# Only the sample pairs:                     0.4753, 0.008 |   0.4864,0.4733, 0.4661
# A + synthetic images FINETUNE:             0.4685, 0.004  |   0.4641, 0.4732,0.4682
# A + sample pairs FINETUNE (best):          0.4816, 0.007  |    0.4879,0.4853,0.4716


# accs = [0.3642,0.3756,0.3759]; np.mean(accs), np.std(accs)

In [16]:
import numpy as np

In [28]:
# Test accuracies of the three "A + sample pairs fine-tune" runs (epoch 3000, batch 128).
accs = [0.4879, 0.4853, 0.4716]

In [29]:
(np.mean(accs), np.std(accs))  # mean and population std (ddof=0) of the recorded accuracies

(0.48160000000000003, 0.007150291369354578)

In [19]:
# def test_net(epoch, model, test_loader, cuda, save_path, args, writer):
#     model.eval()
#     loss_dict = model.latest_losses()
#     losses = {k + '_test': 0 for k, v in loss_dict.items()}
#     i, data = None, None
#     with torch.no_grad():
#         for i, (data, _) in enumerate(test_loader):
#             if cuda:
#                 data = data.cuda()
#             outputs = model(data)
#             model.loss_function(data, *outputs)
#             latest_losses = model.latest_losses()
#             for key in latest_losses:
#                 losses[key + '_test'] += float(latest_losses[key])
#             if i == 0:
#                 write_images(data, outputs, writer, 'test')

#                 save_reconstructed_images(data, epoch, outputs[0], save_path, 'reconstruction_test')
#                 # save_checkpoint(model, epoch, save_path)
#             if args.dataset == 'imagenet' and i * len(data) > 1000:
#                 break

#     for key in losses:
#         if args.dataset not in ['imagenet', 'custom']:
#             losses[key] /= (len(test_loader.dataset) / test_loader.batch_size)
#         else:
#             losses[key] /= (i * len(data))
#     loss_string = ' '.join(['{}: {:.6f}'.format(k, v) for k, v in losses.items()])
#     logging.info('====> Test set losses: {}'.format(loss_string))
#     return losses


# def write_images(data, outputs, writer, suffix):
#     original = data.mul(0.5).add(0.5)
#     original_grid = make_grid(original[:6])
#     writer.add_image(f'original/{suffix}', original_grid)
#     reconstructed = outputs[0].mul(0.5).add(0.5)
#     reconstructed_grid = make_grid(reconstructed[:6])
#     writer.add_image(f'reconstructed/{suffix}', reconstructed_grid)


# def save_reconstructed_images(data, epoch, outputs, save_path, name):
#     size = data.size()
#     n = min(data.size(0), 8)
#     batch_size = data.size(0)
#     comparison = torch.cat([data[:n],
#                             outputs.view(batch_size, size[1], size[2], size[3])[:n]])
#     save_image(comparison.cpu(),
#                os.path.join(save_path, name + '_' + str(epoch) + '.png'), nrow=n, normalize=True)