In [None]:
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import warnings
warnings.filterwarnings("ignore")

# watch -n 1 nvidia-smi
import os

# 显示第 0 和第 1 个 GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
parser = argparse.ArgumentParser(description='Parameter Processing')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
parser.add_argument('--model', type=str, default='ConvNet', help='model')
parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,



parser.add_argument('--num_exp', type=int, default=3, help='the number of experiments')



parser.add_argument('--num_eval', type=int, default=1, help='the number of evaluating randomly initialized models')
parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') # it can be small for speeding up with little performance drop
parser.add_argument('--Iteration', type=int, default=2000, help='training iterations')
parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
parser.add_argument('--data_path', type=str, default='/home/ssd7T/ZTL_gcond/data_cv', help='dataset path')
parser.add_argument('--save_path', type=str, default='/home/ssd7T/ztl_dm/gen', help='path to save results')
parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')

args = parser.parse_args([])
args.method = 'DM'
args.outer_loop, args.inner_loop = get_loops(args.ipc)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = False if args.dsa_strategy in ['none', 'None'] else True

if not os.path.exists(args.data_path):
    os.mkdir(args.data_path)

if not os.path.exists(args.save_path):
    os.mkdir(args.save_path)

eval_it_pool = np.arange(0, args.Iteration+1, 2000).tolist() if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration] # The list of iterations when we evaluate models and record results.
print('eval_it_pool: ', eval_it_pool)
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)


accs_all_exps = dict() # record performances of all experiments
for key in model_eval_pool:
    accs_all_exps[key] = []

data_save = []
pairs_real = []
indexs_real = []

for exp in range(args.num_exp):
    # pairs_real = []
    # indexs_real = []
    exp = exp + 85
    print('\n================== Exp %d ==================\n '%exp)
    print('Hyper-parameters: \n', args.__dict__)
    print('Evaluation model pool: ', model_eval_pool)

    ''' organize the real dataset '''
    images_all = []
    labels_all = []
    indices_class = [[] for c in range(num_classes)]

    images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
    labels_all = [dst_train[i][1] for i in range(len(dst_train))]
    for i, lab in enumerate(labels_all):
        indices_class[lab].append(i)
    images_all = torch.cat(images_all, dim=0).to(args.device)
    labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)



    for c in range(num_classes):
        print('class c = %d: %d real images'%(c, len(indices_class[c])))

    def get_images(c, n): # get random n images from class c
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return images_all[idx_shuffle]
    
    def get_images_init(c, n,exp): # get random n images from class c
        # start_idx = i  # 指定起始索引 i
        # end_idx = i + n  # 计算结束索引（不包括结束索引）

        # 从指定的起始索引到结束索引获取元素
        idx_shuffle  = indices_class[c][exp:exp + n]

        # idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return images_all[idx_shuffle],idx_shuffle

    for ch in range(channel):
        print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))


    ''' initialize the synthetic data '''
    image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
    label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]

    if args.init == 'real':
        print('initialize synthetic data from random real images')
        for c in range(num_classes):
            reals,index = get_images_init(c, args.ipc,exp)
            reals = reals.detach().data
            pairs_real.append(reals)
            indexs_real.append(index)
            image_syn.data[c*args.ipc:(c+1)*args.ipc] = reals
            # pairs_real
    else:
        print('initialize synthetic data from random noise')

In [None]:
img_real_test= torch.cat(pairs_real, dim=0)

label_real_test = []
for i in range(int(len(indexs_real)/10)):
    # print(i)
    label_real_test_ = []
    for c in range(num_classes):
        idx_shuffle = indexs_real[c + i*10]
        label_real_test_.append(labels_all[idx_shuffle].to("cpu"))
        # print()
    # img_real = torch.from_numpy(np.concatenate(img_real, axis=0))
    label_real_test_ = torch.from_numpy(np.concatenate(label_real_test_, axis=0))
    label_real_test.append(label_real_test_)
label_real_test = torch.cat(label_real_test, dim=0)

In [None]:
device = args.device

In [None]:
device

In [None]:
img_syn = []
label_syn = []
img_real_train = []
label_real_train = []
# /home/ssd7T/ztl_dm/indexs_real_20.pt
for i in range(80):
    try:
    
        img_syn_ = torch.load(f'/home/ssd7T/ztl_dm/img_syn_{i}.pt')
        label_syn_ = torch.load(f'/home/ssd7T/ztl_dm/label_syn_{i}.pt')
        pairs_real_=torch.load(f'/home/ssd7T/ztl_dm/pairs_real_{i}.pt')
        indexs_real_=torch.load(f'/home/ssd7T/ztl_dm/indexs_real_{i}.pt')
        
        img_real_train_ = torch.cat(pairs_real_, dim=0)
        
    
        label_real_ = []
        for c in range(num_classes):
            idx_shuffle = indexs_real_[c]
            label_real_.append(labels_all[idx_shuffle].to("cpu"))
        # img_real = torch.from_numpy(np.concatenate(img_real, axis=0))
        label_real_train_ = torch.from_numpy(np.concatenate(label_real_, axis=0))
        # label_real_train_ = torch.cat(label_real_train_, dim=0)
        
        img_syn.append(img_syn_)
        label_syn.append(label_syn_)
        img_real_train.append(img_real_train_)
        label_real_train.append( label_real_train_)
        # if i == 3 or i == 22 or i == 42 or i == 62:
        #     pairs_real_=torch.load(f'pairs_real_{i}.pt')
        #     img_real_train_ = torch.cat(pairs_real_, dim=0)
        #     img_real_train.append(img_real_train_)
            
    except:
        pass


img_syn = torch.cat(img_syn, dim=0).to(device)
label_syn = torch.cat(label_syn, dim=0).to(device)
img_real_train = torch.cat(img_real_train, dim=0).to(device)
label_real_train = torch.cat(label_real_train, dim=0).to(device)



In [None]:
img_real_train.shape

In [None]:
accs = []
model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)
import copy
accs_all_exps = dict() # record performances of all experiments
for key in model_eval_pool:
    accs_all_exps[key] = []
args.dsa_param = ParamDiffAug()
args.dsa = False if args.dsa_strategy in ['none', 'None'] else True
model_eval= model_eval_pool[0]

In [None]:


num_classes = 10


In [None]:
from a_cvae import  CVAE

In [10]:

import argparse

import torch.utils.data
from torch import optim

from torchvision.utils import save_image, make_grid
    # 'custom': {'lr': 2e-4, 'k': 512, 'hidden': 128},
    # 'imagenet': {'lr': 2e-4, 'k': 512, 'hidden': 128},
    # 'cifar10': {'lr': 2e-4, 'k': 10, 'hidden': 256},
    # 'mnist': {'lr': 1e-4, 'k': 10, 'hidden': 64}
parser = argparse.ArgumentParser(description='Variational AutoEncoders')

model_parser = parser.add_argument_group('Model Parameters')
model_parser.add_argument('--model', default='vae', choices=['vae', 'vqvae'],
                            help='autoencoder variant to use: vae | vqvae')
model_parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                            help='input batch size for training (default: 128)')
model_parser.add_argument('--hidden', type=int, default=256, metavar='N',
                            help='number of hidden channels')

model_parser.add_argument('-k', '--dict-size',  default=10,type=int, dest='k', metavar='K',
                            help='number of atoms in dictionary')
model_parser.add_argument('--lr', type=float, default=2e-4,
                            help='learning rate')
model_parser.add_argument('--vq_coef', type=float, default=None,
                            help='vq coefficient in loss')
model_parser.add_argument('--commit_coef', type=float, default=None,
                            help='commitment coefficient in loss')
model_parser.add_argument('--kl_coef', type=float, default=None,
                            help='kl-divergence coefficient in loss')

training_parser = parser.add_argument_group('Training Parameters')
training_parser.add_argument('--dataset', default='cifar10', choices=['mnist', 'cifar10', 'imagenet',
                                                                        'custom'],
                                help='dataset to use: mnist | cifar10 | imagenet | custom')
training_parser.add_argument('--dataset_dir_name', default='',
                                help='name of the dir containing the dataset if dataset == custom')
training_parser.add_argument('--data-dir', default='/media/ssd/Datasets',
                                help='directory containing the dataset')
training_parser.add_argument('--epochs', type=int, default=20, metavar='N',
                                help='number of epochs to train (default: 10)')
training_parser.add_argument('--max-epoch-samples', type=int, default=50000,
                                help='max num of samples per epoch')
training_parser.add_argument('--no-cuda', action='store_true', default=False,
                                help='enables CUDA training')
training_parser.add_argument('--seed', type=int, default=6, metavar='S',
                                help='random seed (default: 1)')
training_parser.add_argument('--gpus', default='0',
                                help='gpus used for training - e.g 0,1,3')

logging_parser = parser.add_argument_group('Logging Parameters')
logging_parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                            help='how many batches to wait before logging training status')
logging_parser.add_argument('--results-dir', metavar='RESULTS_DIR', default='./results',
                            help='results dir')
logging_parser.add_argument('--save-name', default='',
                            help='saved folder')
logging_parser.add_argument('--data-format', default='json',
                            help='in which format to save the data')


# 256, 794, 3, 3

args_cave = parser.parse_args([])


torch.manual_seed(args_cave.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(args_cave.seed)

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# dataset = MNIST(
#     root='/home/ssd7T/ZTL_gcond/data_cv', train=True, transform=transforms.ToTensor(),
#     download=True)
# data_loader = DataLoader(
#     dataset=dataset, batch_size=args.batch_size, shuffle=True)


lr = args_cave.lr 
k = args_cave.k 
hidden = args_cave.hidden 
num_channels = 3 # CIFA

model = CVAE(d = hidden, k=k, num_channels=num_channels).to(device) 

# def __init__(self, d, kl_coef=0.1, **kwargs):

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)
    
    
batch = args_cave.batch_size
# A

epochs = 3000
for epoch in range(epochs):

        
    random_indices = np.random.choice(len(img_real_train), batch, replace=False)

# 使用随机选择的索引来获取数据并改变形状
    batch_img = images_all[random_indices].reshape((batch, 3, 32, 32)).to(device)
    # batch_syn = img_syn[random_indices].reshape((batch, 3, 32, 32)).to(device) 
    # batch_img_y = label_real_train[c*batch:(c+1)*batch].to(device) 
        
    outputs = model(batch_img)
    
    loss = model.loss_function(batch_img, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

        # logs['loss'].append(loss.item())

    if epoch % 10 == 0 or epoch == epochs -1:
        # print("Epoch {:02d}/{:02d} , Loss {:9.4f}".format(
        #     epoch, epochs,  loss.item()))
        print("====> Epoch: {} || {}".format(epoch, loss.item()))
        
with torch.no_grad():
  output = model(img_real_test.to(device))
data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)


====> Epoch: 310 || 0.45274198055267334
====> Epoch: 320 || 0.3920760750770569
====> Epoch: 330 || 0.4361448884010315
====> Epoch: 340 || 0.4589168429374695
====> Epoch: 350 || 0.4764600396156311
====> Epoch: 360 || 0.45012331008911133
====> Epoch: 370 || 0.43497318029403687
====> Epoch: 380 || 0.4599984586238861
====> Epoch: 390 || 0.47639983892440796
====> Epoch: 400 || 0.3984176814556122
====> Epoch: 410 || 0.4566979706287384
====> Epoch: 420 || 0.4292212128639221
====> Epoch: 430 || 0.4357945919036865
====> Epoch: 440 || 0.45940259099006653
====> Epoch: 450 || 0.4267178773880005
====> Epoch: 460 || 0.43222978711128235
====> Epoch: 470 || 0.4086083471775055
====> Epoch: 480 || 0.4632788300514221
====> Epoch: 490 || 0.4430690407752991
====> Epoch: 500 || 0.43881309032440186
====> Epoch: 510 || 0.47247737646102905
====> Epoch: 520 || 0.44093096256256104
====> Epoch: 530 || 0.43182870745658875
====> Epoch: 540 || 0.42612016201019287
====> Epoch: 550 || 0.45518621802330017
====> Epoch: 

In [11]:
torch.save(model,'cvae_all_2.pt') 

In [12]:
model = CVAE(d = hidden, k=k, num_channels=num_channels).to(device) 

# 训练原图训

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)

epochs = 3000
for epoch in range(epochs):


    random_indices = np.random.choice(len(img_real_train), batch, replace=False)

# 使用随机选择的索引来获取数据并改变形状
    batch_img = img_real_train[random_indices].reshape((batch, 3, 32, 32)).to(device)
    # batch_syn = img_syn[random_indices].reshape((batch, 3, 32, 32)).to(device) 
    # batch_img_y = label_real_train[c*batch:(c+1)*batch].to(device) 
        
    outputs = model(batch_img)
    
    loss = model.loss_function(batch_img, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

        # logs['loss'].append(loss.item())

    if epoch % 10 == 0 or epoch == epochs -1:
        # print("Epoch {:02d}/{:02d} , Loss {:9.4f}".format(
        #     epoch, epochs,  loss.item()))
        print("====> Epoch: {} || {}".format(epoch, loss.item()))
        
with torch.no_grad():
  output = model(img_real_test.to(device))
data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)

====> Epoch: 0 || 2.1340150833129883
====> Epoch: 10 || 0.9069045186042786
====> Epoch: 20 || 0.7115136981010437
====> Epoch: 30 || 0.6463376879692078
====> Epoch: 40 || 0.613821804523468
====> Epoch: 50 || 0.5472981929779053
====> Epoch: 60 || 0.4958128035068512
====> Epoch: 70 || 0.46100789308547974
====> Epoch: 80 || 0.4959169626235962
====> Epoch: 90 || 0.5134686231613159
====> Epoch: 100 || 0.5289812684059143
====> Epoch: 110 || 0.49020203948020935
====> Epoch: 120 || 0.48215851187705994
====> Epoch: 130 || 0.43007057905197144
====> Epoch: 140 || 0.4775165021419525
====> Epoch: 150 || 0.43863627314567566
====> Epoch: 160 || 0.46100640296936035
====> Epoch: 170 || 0.43461373448371887
====> Epoch: 180 || 0.44392457604408264
====> Epoch: 190 || 0.42891979217529297
====> Epoch: 200 || 0.4349367320537567
====> Epoch: 210 || 0.489485502243042
====> Epoch: 220 || 0.47778603434562683
====> Epoch: 230 || 0.3813686966896057
====> Epoch: 240 || 0.45476803183555603
====> Epoch: 250 || 0.46396

In [13]:
model = CVAE(d = hidden, k=k, num_channels=num_channels).to(device) 

# 样本对训

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)

epochs = 3000
for epoch in range(epochs):


    
        
    random_indices = np.random.choice(len(img_real_train), batch, replace=False)

# 使用随机选择的索引来获取数据并改变形状
    batch_img = img_real_train[random_indices].reshape((batch, 3, 32, 32)).to(device)
    batch_syn = img_syn[random_indices].reshape((batch, 3, 32, 32)).to(device) 
    # batch_img_y = label_real_train[c*batch:(c+1)*batch].to(device) 
    
    outputs = model(batch_img)
    
    loss = model.loss_function(batch_syn, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

        # logs['loss'].append(loss.item())

    if epoch % 10 == 0 or epoch == epochs -1:
        # print("Epoch {:02d}/{:02d} , Loss {:9.4f}".format(
        #     epoch, epochs,  loss.item()))
        print("====> Epoch: {} || {}".format(epoch, loss.item()))
        
with torch.no_grad():
  output = model(img_real_test.to(device))
data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)

====> Epoch: 0 || 2.224699020385742
====> Epoch: 10 || 0.9188709259033203
====> Epoch: 20 || 0.7404829859733582
====> Epoch: 30 || 0.6518579721450806
====> Epoch: 40 || 0.640002429485321
====> Epoch: 50 || 0.5574132204055786
====> Epoch: 60 || 0.5992719531059265
====> Epoch: 70 || 0.5128530859947205
====> Epoch: 80 || 0.529251217842102
====> Epoch: 90 || 0.4890080988407135
====> Epoch: 100 || 0.5311537981033325
====> Epoch: 110 || 0.4657144844532013
====> Epoch: 120 || 0.5117236971855164
====> Epoch: 130 || 0.501807451248169
====> Epoch: 140 || 0.47017645835876465
====> Epoch: 150 || 0.4720045328140259
====> Epoch: 160 || 0.4383760392665863
====> Epoch: 170 || 0.4994126856327057
====> Epoch: 180 || 0.46501556038856506
====> Epoch: 190 || 0.4902753233909607
====> Epoch: 200 || 0.40907248854637146
====> Epoch: 210 || 0.4566342830657959
====> Epoch: 220 || 0.45014289021492004
====> Epoch: 230 || 0.45091938972473145
====> Epoch: 240 || 0.4700677692890167
====> Epoch: 250 || 0.4092158973217

In [14]:
model = torch.load('cvae_all_2.pt').to(device)  
# A + 合成数据集 0.4552 0.4548 0.4674

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)

epochs = 3000
for epoch in range(epochs):


    # batch_img = images_all[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)).to(device) 
    
    # batch_img = img_real_train[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)).to(device) 
    # batch_syn = img_syn[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)).to(device) 
    
    # 随机选择batch个索引
    random_indices = np.random.choice(len(img_real_train), batch, replace=False)

    # 使用随机选择的索引来获取数据并改变形状
    # batch_img = img_real_train[random_indices].reshape((batch, 3, 32, 32)).to(device)
    batch_syn = img_syn[random_indices].reshape((batch, 3, 32, 32)).to(device) 

    # batch_img_y = label_real_train[c*batch:(c+1)*batch].to(device) 
    
    outputs = model(batch_syn)
    
    loss = model.loss_function(batch_syn, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # logs['loss'].append(loss.item())

    if epoch % 10 == 0 or epoch == epochs -1:
        # print("Epoch {:02d}/{:02d} , Loss {:9.4f}".format(
        #     epoch, epochs,  loss.item()))
        print("====> Epoch: {} || {}".format(epoch, loss.item()))
        
with torch.no_grad():
  output = model(img_real_test.to(device))
data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)

====> Epoch: 0 || 0.35182124376296997
====> Epoch: 10 || 0.3487131595611572
====> Epoch: 20 || 0.3917742371559143
====> Epoch: 30 || 0.37243616580963135
====> Epoch: 40 || 0.4066743850708008
====> Epoch: 50 || 0.3578777015209198
====> Epoch: 60 || 0.37656667828559875
====> Epoch: 70 || 0.3317330479621887
====> Epoch: 80 || 0.3398778438568115
====> Epoch: 90 || 0.3423633277416229
====> Epoch: 100 || 0.37219443917274475
====> Epoch: 110 || 0.3841681182384491
====> Epoch: 120 || 0.35255375504493713
====> Epoch: 130 || 0.369001179933548
====> Epoch: 140 || 0.3786856532096863
====> Epoch: 150 || 0.38036856055259705
====> Epoch: 160 || 0.35241803526878357
====> Epoch: 170 || 0.3669365346431732
====> Epoch: 180 || 0.3631303310394287
====> Epoch: 190 || 0.3775741457939148
====> Epoch: 200 || 0.3434806764125824
====> Epoch: 210 || 0.34137073159217834
====> Epoch: 220 || 0.3592779040336609
====> Epoch: 230 || 0.410616010427475
====> Epoch: 240 || 0.3482930362224579
====> Epoch: 250 || 0.37540078

In [15]:
model = torch.load('cvae_all_2.pt').to(device) 
# A + 样本对 0.4552 0.4548 0.4674 0.4720

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, 10 if args_cave.dataset == 'imagenet' else 30, 0.5,)

epochs = 3000
for epoch in range(epochs):
        
    # batch_img = images_all[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)).to(device) 
    
    # batch_img = img_real_train[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)).to(device) 
    # batch_syn = img_syn[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)).to(device) 
    
    # 随机选择batch个索引
    random_indices = np.random.choice(len(img_real_train), batch, replace=False)

    # 使用随机选择的索引来获取数据并改变形状
    batch_img = img_real_train[random_indices].reshape((batch, 3, 32, 32)).to(device)
    batch_syn = img_syn[random_indices].reshape((batch, 3, 32, 32)).to(device) 

    # batch_img_y = label_real_train[c*batch:(c+1)*batch].to(device) 
    
    outputs = model(batch_img)
    
    loss = model.loss_function(batch_syn, *outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # logs['loss'].append(loss.item())

    if epoch % 10 == 0 or epoch == epochs -1:
        # print("Epoch {:02d}/{:02d} , Loss {:9.4f}".format(
        #     epoch, epochs,  loss.item()))
        print("====> Epoch: {} || {}".format(epoch, loss.item()))
        
with torch.no_grad():
  output = model(img_real_test.to(device))
data_save = []
net_eval = get_network(model_eval, channel, num_classes, im_size).to(device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output[0]), copy.deepcopy(label_real_test) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(1, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)

====> Epoch: 0 || 0.3757937252521515
====> Epoch: 10 || 0.3825913071632385
====> Epoch: 20 || 0.3968660533428192
====> Epoch: 30 || 0.40076813101768494
====> Epoch: 40 || 0.4041665494441986
====> Epoch: 50 || 0.345682829618454
====> Epoch: 60 || 0.3817612826824188
====> Epoch: 70 || 0.34802284836769104
====> Epoch: 80 || 0.31748664379119873
====> Epoch: 90 || 0.36636874079704285
====> Epoch: 100 || 0.40857332944869995
====> Epoch: 110 || 0.40054237842559814
====> Epoch: 120 || 0.3380434811115265
====> Epoch: 130 || 0.339511513710022
====> Epoch: 140 || 0.3632967472076416
====> Epoch: 150 || 0.3440603017807007
====> Epoch: 160 || 0.3688947260379791
====> Epoch: 170 || 0.38125380873680115
====> Epoch: 180 || 0.36186617612838745
====> Epoch: 190 || 0.3839729428291321
====> Epoch: 200 || 0.34188199043273926
====> Epoch: 210 || 0.3742049038410187
====> Epoch: 220 || 0.35022541880607605
====> Epoch: 230 || 0.35687145590782166
====> Epoch: 240 || 0.35229793190956116
====> Epoch: 250 || 0.3431

In [16]:

# test 合成数据集结果（1500张）
# 直接输入原图： 0.5100
# 样本对训练原图 - 合成图差异： 0.5093

# A :                  0.3860 - 0.005  |   0.4056
# 只用样本对中的原图 :  0.3808 - 0.006  |   0.4034
# 只用样本对训 ：       0.3719, 0.005 |   0.3762

# A+合成图 FINTUNE :    0.4056 - 0.005  |   0.3825 0.4110 0.4075

# A+样本对 FINTUNE :   0.3840 - 0.004  |   0.3885 0.3785 0.3850


# accs = [0.3642,0.3756,0.3759]np.mean(accs), np.std(accs)

In [17]:
def test_net(epoch, model, test_loader, cuda, save_path, args, writer):
    model.eval()
    loss_dict = model.latest_losses()
    losses = {k + '_test': 0 for k, v in loss_dict.items()}
    i, data = None, None
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            if cuda:
                data = data.cuda()
            outputs = model(data)
            model.loss_function(data, *outputs)
            latest_losses = model.latest_losses()
            for key in latest_losses:
                losses[key + '_test'] += float(latest_losses[key])
            if i == 0:
                write_images(data, outputs, writer, 'test')

                save_reconstructed_images(data, epoch, outputs[0], save_path, 'reconstruction_test')
                # save_checkpoint(model, epoch, save_path)
            if args.dataset == 'imagenet' and i * len(data) > 1000:
                break

    for key in losses:
        if args.dataset not in ['imagenet', 'custom']:
            losses[key] /= (len(test_loader.dataset) / test_loader.batch_size)
        else:
            losses[key] /= (i * len(data))
    loss_string = ' '.join(['{}: {:.6f}'.format(k, v) for k, v in losses.items()])
    logging.info('====> Test set losses: {}'.format(loss_string))
    return losses


def write_images(data, outputs, writer, suffix):
    original = data.mul(0.5).add(0.5)
    original_grid = make_grid(original[:6])
    writer.add_image(f'original/{suffix}', original_grid)
    reconstructed = outputs[0].mul(0.5).add(0.5)
    reconstructed_grid = make_grid(reconstructed[:6])
    writer.add_image(f'reconstructed/{suffix}', reconstructed_grid)


def save_reconstructed_images(data, epoch, outputs, save_path, name):
    size = data.size()
    n = min(data.size(0), 8)
    batch_size = data.size(0)
    comparison = torch.cat([data[:n],
                            outputs.view(batch_size, size[1], size[2], size[3])[:n]])
    save_image(comparison.cpu(),
               os.path.join(save_path, name + '_' + str(epoch) + '.png'), nrow=n, normalize=True)