In [14]:
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import warnings
warnings.filterwarnings("ignore")



parser = argparse.ArgumentParser(description='Parameter Processing')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
parser.add_argument('--model', type=str, default='ConvNet', help='model')
parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
parser.add_argument('--num_eval', type=int, default=1, help='the number of evaluating randomly initialized models')
parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data') # it can be small for speeding up with little performance drop
parser.add_argument('--Iteration', type=int, default=2000, help='training iterations')
parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
parser.add_argument('--data_path', type=str, default='/home/ssd7T/ZTL_gcond/data_cv', help='dataset path')
parser.add_argument('--save_path', type=str, default='result/gen', help='path to save results')
parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')

args = parser.parse_args([])
args.method = 'DM'
args.outer_loop, args.inner_loop = get_loops(args.ipc)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.dsa_param = ParamDiffAug()
args.dsa = False if args.dsa_strategy in ['none', 'None'] else True

if not os.path.exists(args.data_path):
    os.mkdir(args.data_path)

if not os.path.exists(args.save_path):
    os.mkdir(args.save_path)

eval_it_pool = np.arange(0, args.Iteration+1, 2000).tolist() if args.eval_mode == 'S' or args.eval_mode == 'SS' else [args.Iteration] # The list of iterations when we evaluate models and record results.
print('eval_it_pool: ', eval_it_pool)
channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)


accs_all_exps = dict() # record performances of all experiments
for key in model_eval_pool:
    accs_all_exps[key] = []

data_save = []

pairs_real = []
indexs_real = []
for exp in range(args.num_exp):
    print('\n================== Exp %d ==================\n '%exp)
    print('Hyper-parameters: \n', args.__dict__)
    print('Evaluation model pool: ', model_eval_pool)

    ''' organize the real dataset '''
    images_all = []
    labels_all = []
    indices_class = [[] for c in range(num_classes)]

    images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
    labels_all = [dst_train[i][1] for i in range(len(dst_train))]
    for i, lab in enumerate(labels_all):
        indices_class[lab].append(i)
    images_all = torch.cat(images_all, dim=0).to(args.device)
    labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)



    for c in range(num_classes):
        print('class c = %d: %d real images'%(c, len(indices_class[c])))

    def get_images(c, n): # get random n images from class c
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return images_all[idx_shuffle]
    
    def get_images_init(c, n,exp): # get random n images from class c
        # start_idx = i  # 指定起始索引 i
        # end_idx = i + n  # 计算结束索引（不包括结束索引）

        # 从指定的起始索引到结束索引获取元素
        idx_shuffle  = indices_class[c][exp:exp + n]

        # idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return images_all[idx_shuffle],idx_shuffle

    for ch in range(channel):
        print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))


    ''' initialize the synthetic data '''
    image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
    label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=args.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]

    if args.init == 'real':
        print('initialize synthetic data from random real images')
        for c in range(num_classes):
            reals,index = get_images_init(c, args.ipc,exp)
            reals = reals.detach().data
            pairs_real.append(reals)
            indexs_real.append(index)
            image_syn.data[c*args.ipc:(c+1)*args.ipc] = reals
            # pairs_real
    else:
        print('initialize synthetic data from random noise')
    
    # init_syn = image_syn.detach()
    torch.save(pairs_real, 'pairs_real_{exp}.pt')
    torch.save(indexs_real, 'indexs_real_{exp}.pt')


    ''' training '''
    optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
    optimizer_img.zero_grad()
    print('%s training begins'%get_time())

    for it in range(args.Iteration+1):

        ''' Evaluate synthetic data '''
        if it in eval_it_pool:
            for model_eval in model_eval_pool:
                print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))

                print('DSA augmentation strategy: \n', args.dsa_strategy)
                print('DSA augmentation parameters: \n', args.dsa_param.__dict__)

                accs = []
                for it_eval in range(args.num_eval):
                    net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                    image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification
                    _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                    accs.append(acc_test)
                print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                if it == args.Iteration: # record the final results
                    accs_all_exps[model_eval] += accs

            ''' visualize and save '''
            save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
            image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
            for ch in range(channel):
                image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
            image_syn_vis[image_syn_vis<0] = 0.0
            image_syn_vis[image_syn_vis>1] = 1.0
            save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.



        ''' Train synthetic data '''
        net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
        net.train()
        for param in list(net.parameters()):
            param.requires_grad = False

        embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallel

        loss_avg = 0

        ''' update synthetic data '''
        if 'BN' not in args.model: # for ConvNet
            loss = torch.tensor(0.0).to(args.device)
            for c in range(num_classes):
                img_real = get_images(c, args.batch_real)
                img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                if args.dsa:
                    seed = int(time.time() * 1000) % 100000
                    img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                    img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                output_real = embed(img_real).detach()
                output_syn = embed(img_syn)

                loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)

        else: # for ConvNetBN
            images_real_all = []
            images_syn_all = []
            loss = torch.tensor(0.0).to(args.device)
            for c in range(num_classes):
                img_real = get_images(c, args.batch_real)
                img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                if args.dsa:
                    seed = int(time.time() * 1000) % 100000
                    img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                    img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                images_real_all.append(img_real)
                images_syn_all.append(img_syn)

            images_real_all = torch.cat(images_real_all, dim=0)
            images_syn_all = torch.cat(images_syn_all, dim=0)

            output_real = embed(images_real_all).detach()
            output_syn = embed(images_syn_all)

            loss += torch.sum((torch.mean(output_real.reshape(num_classes, args.batch_real, -1), dim=1) - torch.mean(output_syn.reshape(num_classes, args.ipc, -1), dim=1))**2)



        optimizer_img.zero_grad()
        loss.backward()
        optimizer_img.step()
        loss_avg += loss.item()


        loss_avg /= (num_classes)

        if it%10 == 0:
            print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))

        if it == args.Iteration: # only record the final results
            # img_syn = image_syn.detach().to(device)
            # label_syn = label_syn.to(device)
            torch.save(image_syn.detach().cpu(), f'img_syn_{exp}.pt')
            torch.save(label_syn.detach().cpu(), f'label_syn_{exp}.pt')

            # data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
            # torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))


print('\n==================== Final Results ====================\n')
for key in model_eval_pool:
    accs = accs_all_exps[key]
    print('Run %d experiments, train on %s, evaluate %d random %s, mean  = %.2f%%  std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))



# if __name__ == '__main__':
#     main()




eval_it_pool:  [0, 2000]
Files already downloaded and verified
Files already downloaded and verified

 
Hyper-parameters: 
 {'dataset': 'CIFAR10', 'model': 'ConvNet', 'ipc': 50, 'eval_mode': 'SS', 'num_exp': 1, 'num_eval': 1, 'epoch_eval_train': 1000, 'Iteration': 2000, 'lr_img': 1.0, 'lr_net': 0.01, 'batch_real': 256, 'batch_train': 256, 'init': 'real', 'dsa_strategy': 'color_crop_cutout_flip_scale_rotate', 'data_path': '/home/ssd7T/ZTL_gcond/data_cv', 'save_path': 'result/gen', 'dis_metric': 'ours', 'method': 'DM', 'outer_loop': 50, 'inner_loop': 10, 'device': 'cuda', 'dsa_param': <utils.ParamDiffAug object at 0x7f143069d340>, 'dsa': True}
Evaluation model pool:  ['ConvNet']
class c = 0: 5000 real images
class c = 1: 5000 real images
class c = 2: 5000 real images
class c = 3: 5000 real images
class c = 4: 5000 real images
class c = 5: 5000 real images
class c = 6: 5000 real images
class c = 7: 5000 real images
class c = 8: 5000 real images
class c = 9: 5000 real images
real images ch

KeyboardInterrupt: 

In [2]:
device = img_syn.device

In [3]:
img_syn = image_syn.detach().to(device)
label_syn = label_syn.to(device)

In [7]:
img_syn.shape

torch.Size([500, 3, 32, 32])

In [8]:
pairs_real[0].shape

torch.Size([50, 3, 32, 32])

In [9]:
concatenated_tensor = torch.cat(pairs_real, dim=0)

In [10]:
concatenated_tensor.shape

torch.Size([500, 3, 32, 32])

In [16]:
indexs_real[0]

[29,
 30,
 35,
 49,
 77,
 93,
 115,
 116,
 129,
 165,
 179,
 185,
 189,
 199,
 213,
 220,
 223,
 233,
 264,
 276,
 279,
 284,
 293,
 308,
 317,
 332,
 341,
 344,
 348,
 349,
 352,
 371,
 373,
 376,
 392,
 401,
 404,
 405,
 407,
 415,
 417,
 436,
 439,
 448,
 453,
 455,
 457,
 467,
 468,
 481]

In [12]:
torch.save(img_syn, 'img_syn.pt')
torch.save(label_syn, 'label_syn.pt')
torch.save(pairs_real, 'pairs_real.pt')
torch.save(indexs_real, 'indexs_real.pt')

In [None]:
# pairs_real = []
# indexs_real = []

In [None]:
img_syn.shape

In [None]:
# x = torch.rand(3,3)
# torch.save(x, 'tensor.pt')

# # 读取tensor
# x = torch.load('tensor.pt')

In [None]:
import torch
import torch.nn as nn

# 定义自动编码器
class Autoencoder(nn.Module):
    
    def __init__(self,num_feat,num_classes = 10):
        super(Autoencoder, self).__init__()
        
        # 输入图像通道数修改为3
        self.encoder = nn.Sequential(  
            nn.Conv2d(3, 64, 3, stride=2, padding=1),  
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 7)
        )
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )
        # self.aux_classifier = nn.Linear(32, 10) 
        self.classifier = nn.Linear(num_feat, num_classes)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
    

In [None]:
batch = 50
num_feat = 3072
model = Autoencoder(num_feat).to(device)  
criterion = nn.MSELoss().to(device)  
# aux loss函数
aux_loss_fn = nn.CrossEntropyLoss().to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 训练
for epoch in range(1000):
  # for i in range(0, len(img_syn), batch):
  #   batch_img = img_syn[i:i+batch] 
  #   batch_label = label_syn[i:i+batch] 
  #   batch_img = batch_img.to(device) 
  #   recon = model(batch_img)
  #   loss = criterion(recon, batch_img)

  #   # 计算辅助分类损失
  #   classifiation = model.classifier(recon.view(recon.size(0), -1)) 
  #   # aux_loss = aux_loss_fn(aux_logits, batch_label)
  #   cls_loss = aux_loss_fn(classifiation, batch_label)
    
  #   total_loss = 0.9 *loss + 0.1 * cls_loss
  #   print("mse:",loss.item())
  #   print("cls_loss:",cls_loss.item())
  #   # if total_loss < 0.3:
  #   #   break
    
  #   optimizer.zero_grad()
  #   loss.backward()
  #   optimizer.step() 
  # for i in range(0, len(img_syn), batch):
  total_loss = 0
  for c in range(num_classes):
                # 获取类别c的合成图像和类别中心
                # image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
    batch_img = image_syn[c*batch:(c+1)*batch].reshape((batch, 3, 32, 32)) 
    batch_label = label_syn[c*batch:(c+1)*batch] 
    batch_img = batch_img.to(device) 
    recon = model(batch_img)
    loss = criterion(recon, batch_img)

    # 计算辅助分类损失
    classifiation = model.classifier(recon.view(recon.size(0), -1)) 
    # aux_loss = aux_loss_fn(aux_logits, batch_label)
    cls_loss = aux_loss_fn(classifiation, batch_label)
    
    total_loss += 0.4 *loss + 0.6 * cls_loss
    # print(total_loss.item())
    # total_loss = 0.9 *loss + 0.1 * cls_loss
    # print("mse:",loss.item())
    # print("cls_loss:",cls_loss.item())
  if total_loss.item()<18.87:

      print("done")
      break
  # print(total_loss.item())
  optimizer.zero_grad()
  loss.backward()
  optimizer.step() 
  
    
# 生成新图像
noise = torch.randn(500, 3, 32, 32).to(device) 
with torch.no_grad():
  output = model.decoder(image_syn)

In [121]:
with torch.no_grad():
  output = model(img_real.to(device))

In [122]:
accs = []

net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(output), copy.deepcopy(label_syn) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)
print()
print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))
# [2023-10-06 20:45:24] Evaluate_00: epoch = 1000 train time = 61 s train loss = 0.000965 train acc = 1.0000, test acc = 0.5001
# Evaluate 1 random ConvNet, mean = 0.5001 std = 0.0000

[2023-10-06 23:52:01] Evaluate_00: epoch = 1000 train time = 67 s train loss = 0.027192 train acc = 0.9980, test acc = 0.2098

Evaluate 1 random ConvNet, mean = 0.2098 std = 0.0000
-------------------------


In [123]:
image_syn_vis = copy.deepcopy(output.detach().cpu())
from torchvision.utils import save_image
image_syn_vis[image_syn_vis<0] = 0.0
image_syn_vis[image_syn_vis>1] = 1.0
save_image(image_syn_vis, "/home/wangkai/ztl_project/difussion-dd/DatasetCondensation-master/111.png", nrow=50) # Trying normalize = True/False may get better visual effects.

In [None]:
img_real = []
label_real = []
for c in range(num_classes):
    idx_shuffle = np.random.permutation(indices_class[c])[:50]
    img_real.append(images_all[idx_shuffle].to("cpu") )
    label_real.append(labels_all[idx_shuffle].to("cpu"))
img_real = torch.from_numpy(np.concatenate(img_real, axis=0))
label_real = torch.from_numpy(np.concatenate(label_real, axis=0))


In [None]:
accs = []
# [2023-10-06 20:45:24] Evaluate_00: epoch = 1000 train time = 61 s train loss = 0.000965 train acc = 1.0000, test acc = 0.5001
# Evaluate 1 random ConvNet, mean = 0.5001 std = 0.0000
net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
image_syn_eval, label_syn_eval = copy.deepcopy(img_real), copy.deepcopy(label_real) # avoid any unaware modification
_, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
accs.append(acc_test)
print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))