In [None]:
 def collate(samples):
    """Transpose a batch of 10-field samples into 10 per-field lists.

    Each element of ``samples`` is expected to be a 10-item sequence; the
    i-th returned list gathers the i-th item of every sample, in batch order.
    A batch whose samples do not have exactly 10 fields (including an empty
    batch) fails at the unpacking step with ``ValueError``.

    Returns:
        A 10-tuple of lists: support graphs, support labels, query graphs,
        query labels, support centers, query centers, support node indices,
        query node indices, support graph indices, query graph indices.
    """
    columns = [list(column) for column in zip(*samples)]
    (graphs_spt, labels_spt, graph_qry, labels_qry,
     center_spt, center_qry, nodeidx_spt, nodeidx_qry,
     support_graph_idx, query_graph_idx) = columns

    return (graphs_spt, labels_spt, graph_qry, labels_qry,
            center_spt, center_qry, nodeidx_spt, nodeidx_qry,
            support_graph_idx, query_graph_idx)

class GCond:
    # Constructor: takes the data, the arguments and the device
    def __init__(self, data, args, device='cuda', **kwargs):
        """Set up the learnable synthetic graph and its two optimizers.

        Args:
            data: dataset wrapper; ``idx_train`` and ``feat_train`` are read
                here.
            args: run configuration; ``reduction_rate``, ``lr_feat`` and
                ``lr_adj`` are read here.
            device: device on which the synthetic tensors and PGE are placed.
            **kwargs: accepted and ignored.
        """
        self.data, self.args, self.device = data, args, device

        # Size of the synthetic graph: a fixed fraction of the training set.
        num_syn_nodes = int(len(data.idx_train) * args.reduction_rate)
        feat_dim = data.feat_train.shape[1]
        self.nnodes_syn = num_syn_nodes

        # Uninitialised storage; actual values are written by
        # reset_parameters() below.
        raw_feat = torch.FloatTensor(num_syn_nodes, feat_dim).to(device)
        self.feat_syn = nn.Parameter(raw_feat)

        # PGE is the module that turns synthetic features into a synthetic
        # adjacency matrix (it is called that way in train()).
        self.pge = PGE(nfeat=feat_dim, nnodes=num_syn_nodes,
                       device=device, args=args).to(device)

        # NOTE(review): generate_labels_syn is not defined in the visible part
        # of this class — confirm where it comes from.
        syn_labels = self.generate_labels_syn(data)
        self.labels_syn = torch.LongTensor(syn_labels).to(device)

        # Fill the synthetic features with random values.
        self.reset_parameters()

        # One optimizer for the synthetic features, one for PGE.
        self.optimizer_feat = torch.optim.Adam([self.feat_syn], lr=args.lr_feat)
        self.optimizer_pge = torch.optim.Adam(self.pge.parameters(), lr=args.lr_adj)

        print('adj_syn:', (num_syn_nodes, num_syn_nodes),
              'feat_syn:', self.feat_syn.shape)

    def reset_parameters(self):
        """Refill ``feat_syn`` in place with standard-normal samples."""
        shape = self.feat_syn.size()
        noise = torch.randn(shape)
        self.feat_syn.data.copy_(noise)

    def loadCSV(self,path):
        """Group the file names listed in a CSV file by their label.

        The first line of the file is treated as a header and skipped. For
        every following non-blank row, column 1 is taken as the file name and
        column 2 as the label (kept as the raw string read from the file).

        Args:
            path: location of the CSV file.

        Returns:
            dict mapping each label string to the list of file names that
            carry it, in file order.

        Raises:
            IndexError: if a non-blank row has fewer than three columns.
        """
        import csv

        dictLabels = {}

        # newline='' leaves line-ending handling to the csv module, as its
        # documentation recommends.
        with open(path, 'r', newline='') as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            next(csvreader, None)  # skip the header row
            for row in csvreader:
                if not row:
                    # csv.reader yields [] for a blank line; nothing to record
                    continue
                filename, label = row[1], row[2]
                dictLabels.setdefault(label, []).append(filename)

        return dictLabels

    def generate_labels_from_csv(self, file_path, ratio):
        """Derive the synthetic label list from the labels stored in a CSV file.

        The file is read with ``csv.DictReader`` and must have a ``label``
        column whose values parse as integers. Classes are handled from the
        least to the most frequent one:

        * every class except the last receives ``max(int(count * ratio), 1)``
          synthetic labels;
        * the last (most frequent) class receives whatever is left of
          ``int(n * ratio)``; if nothing is left it contributes no labels.

        Args:
            file_path: location of the CSV file.
            ratio: fraction of the rows to keep.

        Returns:
            ``(n * ratio, labels_syn)`` where ``n`` is the number of data rows.
            The first element is the plain product (a float for a float
            ``ratio``), not ``len(labels_syn)``.
        """
        import csv
        from collections import Counter

        with open(file_path, 'r') as handle:
            rows = list(csv.DictReader(handle))

        total = len(rows)
        class_counts = Counter(int(row['label']) for row in rows)

        # Ascending by frequency; sorted() is stable, so equally frequent
        # classes keep the order in which they first appear in the file.
        ordered = sorted(class_counts.items(), key=lambda item: item[1])

        labels_syn = []
        allotted = 0  # labels handed out to all classes but the last
        last_position = len(ordered) - 1
        for position, (cls, occurrences) in enumerate(ordered):
            if position == last_position:
                quota = int(total * ratio) - allotted
            else:
                quota = max(int(occurrences * ratio), 1)
                allotted += quota
            labels_syn.extend([cls] * quota)

        return total * ratio, labels_syn

    def get_sub_adj_feat(self,features, labels_syn, label_dict):
        """Randomly draw feature rows, class by class, to match ``labels_syn``.

        For every distinct value in ``labels_syn`` whose string form is a key
        of ``label_dict``, as many rows as that value occurs are drawn without
        replacement via ``np.random.choice`` and appended to the result.
        Classes missing from ``label_dict`` are skipped.

        NOTE(review): this class defines ``get_sub_adj_feat`` a second time
        further down with a different signature; the later definition is the
        one that ends up bound on the class.

        Args:
            features: array indexed by an integer index array along its first
                axis.
            labels_syn: tensor of labels (``.cpu().numpy()`` is called on it).
            label_dict: maps label strings to lists of (hashable) entries.

        Returns:
            ``np.ndarray`` stacking the drawn rows, in class order.

        Raises:
            ValueError: from ``np.random.choice`` when a class needs more rows
                than it has entries.
        """
        from collections import Counter

        per_class_quota = Counter(labels_syn.cpu().numpy())
        selected_features = []
        for c, num_features in per_class_quota.items():
            class_label = str(c)
            if class_label not in label_dict:
                continue

            members = label_dict[class_label]
            # Position of the first occurrence of every entry: the same
            # values list.index() returns, computed in one pass instead of
            # one linear scan per entry.
            first_pos = {}
            for pos, node in enumerate(members):
                first_pos.setdefault(node, pos)
            node_indices = [first_pos[node] for node in members]

            # NOTE(review): these are positions inside the class list, yet
            # they are used as row indices into ``features`` — confirm that
            # this is the intended indexing.
            selected_indices = np.random.choice(node_indices, size=num_features, replace=False)
            selected_features.extend(features[selected_indices])

        return np.array(selected_features)
    
    def generate_syn_feat(self, feat, pge=None, labels_syn=None, dictLabels=None, args=None):
        """Build a DataLoader of tasks sampled from a generated synthetic graph.

        Synthetic features are sampled from ``feat``, ``pge`` turns them into
        an adjacency matrix that is thresholded at 0.5, and the resulting
        graph is wrapped in ``Subgraphs_syn`` and a ``DataLoader``.

        The parameters after ``pge`` carry ``None`` defaults only because a
        parameter without a default may not follow one that has a default;
        their positional order is unchanged.

        Args:
            feat: source features handed to ``get_sub_adj_feat``.
            pge: adjacency generator. When ``None`` a new ``PGE`` is created
                and ``labels_syn`` is regenerated from the hard-coded CSV path.
            labels_syn: synthetic labels; required when ``pge`` is given.
            dictLabels: label -> entries mapping handed to
                ``get_sub_adj_feat``; always required.
            args: run configuration (``n_way``, ``k_spt``, ``k_qry``, ``h``,
                ``task_num``, ``num_workers`` are read); always required.

        Returns:
            The ``DataLoader`` over the synthetic sub-graph dataset.

        Raises:
            ValueError: if ``dictLabels`` or ``args`` is missing, or if ``pge``
                is given without ``labels_syn``.
        """
        if dictLabels is None or args is None:
            raise ValueError("generate_syn_feat requires both dictLabels and args")

        # Raw string: in a normal literal "\a" and "\t" are escape sequences
        # (bell and tab) and silently corrupt this Windows path.
        path = r"D:\pythonProject\python_file\Graph_DD\META-DD\DATA\arxiv\train.csv"

        # Bound before the branch so it also exists when a PGE is passed in;
        # falls back to the previously hard-coded 'cuda'.
        device = getattr(self, 'device', 'cuda')

        if pge is None:
            # NOTE(review): called as a bare name with a single argument, so
            # this resolves to a module-level function that is not visible
            # here, not to the method of this class (which also needs
            # ``ratio``) — confirm.
            n, labels_syn = generate_labels_from_csv(path)

            # Number of nodes of the synthetic graph
            n = int(n)
            # Feature dimension.
            # NOTE(review): hard-coded; presumably the width of ``feat`` — confirm.
            d = 128
            # Create the PGE model
            pge = PGE(nfeat=d, nnodes=n, device=device, args=args).to(device)
        elif labels_syn is None:
            raise ValueError("labels_syn is required when pge is provided")

        # Labels of the synthetic graph
        labels_syn = torch.LongTensor(labels_syn).to(device)

        # NOTE(review): bare-name call as well; the method bound on this class
        # under that name takes only ``features`` — confirm which
        # implementation is meant.
        feat_syn = torch.from_numpy(get_sub_adj_feat(feat, labels_syn, dictLabels)).to(device)

        # Generate the adjacency and binarise it at 0.5
        adj_syn = pge(feat_syn)
        adj_syn[adj_syn < 0.5] = 0
        adj_syn[adj_syn >= 0.5] = 1

        adj_syn_norm_sp = sp.csr_matrix(adj_syn.cpu().detach().numpy())
        # Author's note: the generated synthetic graph is too dense, so a
        # two-hop subgraph is roughly the whole graph.
        g = dgl.from_scipy(adj_syn_norm_sp)

        # Re-express labels_syn as a {"0_<node index>": label} dictionary
        dict_labels = {f"0_{i}": label for i, label in enumerate(labels_syn.tolist())}
        db_syn = Subgraphs_syn('syn', dict_labels, n_way=args.n_way, k_shot=args.k_spt,
                               k_query=args.k_qry, batchsz=100, args=args, adjs=g,
                               h=args.h, labels=labels_syn)

        db = DataLoader(db_syn, args.task_num, shuffle=True, num_workers=args.num_workers,
                        pin_memory=True, collate_fn=collate)

        return db
    def train(self, verbose=True):
        """Optimise the synthetic features and PGE by gradient matching.

        For ``args.epochs + 1`` epochs a fresh GNN (SGC, SGC1 or GCN, chosen
        by ``args.sgc``) is created. In every outer-loop step the gradients of
        that model on sampled real nodes and on the synthetic graph are
        compared per class with ``match_loss``; the summed loss (plus an
        optional regulariser) is back-propagated and either the PGE optimizer
        or the feature optimizer takes a step. Between outer steps the GNN
        itself is trained on the detached synthetic graph for ``inner_loop``
        steps.

        NOTE(review): ``self.syn_class_indices`` and ``self.num_class_dict``
        are read here but are not assigned anywhere in the visible part of
        this class, and ``self.test_with_val`` is not defined in it either —
        confirm where they come from.

        Args:
            verbose: when true, ``self.test_with_val()`` is run on the epochs
                listed in ``eval_epochs`` and the mean/std are printed.
        """
        # Unpack the configuration, the data and the synthetic-graph state
        args = self.args
        data = self.data
        feat_syn, pge, labels_syn = self.feat_syn, self.pge, self.labels_syn
        features, adj, labels = data.feat_train, data.adj_train, data.labels_train
        syn_class_indices = self.syn_class_indices

        # Convert the data to tensors and put them on the device
        features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        # Get the sub-feature matrix and sub-adjacency matrix
        # (adj_sub is not used anywhere below)
        feat_sub, adj_sub = self.get_sub_adj_feat(features)
        # Copy the sub-feature matrix into feat_syn
        self.feat_syn.data.copy_(feat_sub)

        # Normalise the adjacency matrix, depending on whether it is a sparse tensor
        if utils.is_sparse_tensor(adj):
            adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
        else:
            adj_norm = utils.normalize_adj_tensor(adj)

        # Convert the normalised adjacency matrix to a (transposed) SparseTensor
        adj = adj_norm
        adj = SparseTensor(row=adj._indices()[0], col=adj._indices()[1],
                value=adj._values(), sparse_sizes=adj.size()).t()

        # Number of outer-loop and inner-loop iterations
        outer_loop, inner_loop = get_loops(args)

        for it in range(args.epochs+1):
            loss_avg = 0
            # A new model is built and initialised at the start of every epoch
            if args.sgc==1:
                model = SGC(nfeat=data.feat_train.shape[1], nhid=args.hidden,
                            nclass=data.nclass, dropout=args.dropout,
                            nlayers=args.nlayers, with_bn=False,
                            device=self.device).to(self.device)
            elif args.sgc==2:
                model = SGC1(nfeat=data.feat_train.shape[1], nhid=args.hidden,
                            nclass=data.nclass, dropout=args.dropout,
                            nlayers=args.nlayers, with_bn=False,
                            device=self.device).to(self.device)

            else:
                model = GCN(nfeat=data.feat_train.shape[1], nhid=args.hidden,
                            nclass=data.nclass, dropout=args.dropout, nlayers=args.nlayers,
                            device=self.device).to(self.device)

            model.initialize() # model is the GNN model

            model_parameters = list(model.parameters())

            optimizer_model = torch.optim.Adam(model_parameters, lr=args.lr_model)

            model.train()

            # Outer loop
            for ol in range(outer_loop):
                # Generate and normalise the synthetic adjacency matrix
                adj_syn = pge(self.feat_syn)
                adj_syn_norm = utils.normalize_adj_tensor(adj_syn, sparse=False)
                feat_syn_norm = feat_syn

                # Check whether the model contains a BatchNorm layer
                BN_flag = False
                for module in model.modules():
                    if 'BatchNorm' in module._get_name(): #BatchNorm
                        BN_flag = True

                # With BatchNorm layers, run one forward pass in train mode so
                # that their mu and sigma are updated, then freeze them
                if BN_flag:
                    model.train() # for updating the mu, sigma of BatchNorm
                    output_real = model.forward(features, adj_norm)
                    for module in model.modules():
                        if 'BatchNorm' in module._get_name():  #BatchNorm
                            module.eval() # fix mu and sigma of every BatchNorm layer

                loss = torch.tensor(0.0).to(self.device)
                for c in range(data.nclass): # accumulate the loss class by class
                    if c not in self.num_class_dict:
                        continue
                    # Sample nodes of class c
                    batch_size, n_id, adjs = data.retrieve_class_sampler(
                            c, adj, transductive=False, args=args)

                    # With a single layer, wrap the adjacency in a list
                    if args.nlayers == 1:
                        adjs = [adjs]
                    adjs = [adj.to(self.device) for adj in adjs]

                    # Forward pass, loss and gradients on the sampled real nodes
                    output = model.forward_sampler(features[n_id], adjs)
                    loss_real = F.nll_loss(output, labels[n_id[:batch_size]])
                    gw_real = torch.autograd.grad(loss_real, model_parameters)
                    gw_real = list((_.detach().clone() for _ in gw_real))

                    # Rows of the synthetic adjacency that belong to class c
                    ind = syn_class_indices[c]
                    if args.nlayers == 1:
                        adj_syn_norm_list = [adj_syn_norm[ind[0]: ind[1]]]
                    else:
                        adj_syn_norm_list = [adj_syn_norm]*(args.nlayers-1) + \
                                [adj_syn_norm[ind[0]: ind[1]]]

                    # Output and loss on the synthetic graph
                    output_syn = model.forward_sampler_syn(feat_syn, adj_syn_norm_list)
                    loss_syn = F.nll_loss(output_syn, labels_syn[ind[0]: ind[1]])

                    # Gradients on the synthetic graph
                    gw_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)
                    # create_graph: boolean saying whether to build the graph needed for
                    # higher-order derivatives. It defaults to False (first-order only).
                   
                    # Matching loss, weighted by the relative size of the class
                    coeff = self.num_class_dict[c] / max(self.num_class_dict.values())
                    loss += coeff  * match_loss(gw_syn, gw_real, args, device=self.device)

                loss_avg += loss.item()
                # Regularisation loss
                # TODO: regularize
                if args.alpha > 0:
                    loss_reg = args.alpha * regularization(adj_syn, utils.tensor2onehot(labels_syn))
                else:
                    loss_reg = torch.tensor(0)

                loss = loss + loss_reg

                # Update the synthetic graph
                self.optimizer_feat.zero_grad()
                # Gradients are accumulated into the gradient buffers on every backward
                # pass, so the buffers are cleared before each update to avoid
                # accumulation.
                self.optimizer_pge.zero_grad()
                loss.backward() # gradients of the loss w.r.t. the parameters

                # Depending on it, step either self.optimizer_pge (first 10 epochs
                # of every 50) or self.optimizer_feat (the remaining 40)
                if it % 50 < 10:
                    self.optimizer_pge.step() # apply the optimizer update
                else:
                    self.optimizer_feat.step()

                if args.debug and ol % 5 ==0: # print the gradient matching loss
                    print('Gradient matching loss:', loss.item())

                # No inner loop after the last outer step
                if ol == outer_loop - 1:
                    # print('loss_reg:', loss_reg.item())
                    # print('Gradient matching loss:', loss.item())
                    break

                # Inner loop: update the parameters of the GNN model
                feat_syn_inner = feat_syn.detach()
                adj_syn_inner = pge.inference(feat_syn)
                adj_syn_inner_norm = utils.normalize_adj_tensor(adj_syn_inner, sparse=False)
                feat_syn_inner_norm = feat_syn_inner
                for j in range(inner_loop):
                    optimizer_model.zero_grad()
                    output_syn_inner = model.forward(feat_syn_inner_norm, adj_syn_inner_norm)
                    loss_syn_inner = F.nll_loss(output_syn_inner, labels_syn)
                    loss_syn_inner.backward()
                    optimizer_model.step() # update gnn param

            # Average the loss and print it every 50 epochs
            loss_avg /= (data.nclass*outer_loop)
            if it % 50 == 0:
                print('Epoch {}, loss_avg: {}'.format(it, loss_avg))

            eval_epochs = [100, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 3000, 4000, 5000]

            if verbose and it in eval_epochs:
            # if verbose and (it+1) % 500 == 0:
                res = []
                # A single evaluation run for the three listed datasets, three otherwise
                runs = 1 if args.dataset in ['ogbn-arxiv', 'reddit', 'flickr'] else 3
                for i in range(runs):
                    # self.test()
                    res.append(self.test_with_val())
                res = np.array(res)
                print('Test:',
                        repr([res.mean(0), res.std(0)]))


    def get_sub_adj_feat(self, features):
        """Pick real nodes to seed the synthetic graph and link them by kNN.

        For every class, ``data.retrieve_class`` is asked for as many node
        indices as that class occurs in ``self.labels_syn``. The features of
        the chosen nodes are returned together with a cosine-similarity
        adjacency in which each row keeps only its 2 largest off-diagonal
        entries.

        NOTE(review): an earlier method of this class has the same name and a
        different signature; this later definition is the one bound on the
        class.

        Args:
            features: feature tensor (``.cpu().numpy()`` is called on the
                selected rows).

        Returns:
            ``(features, adj_knn)``: the selected feature rows and the sparse
            similarity matrix as a float tensor on ``self.device``.
        """
        from collections import Counter
        from sklearn.metrics.pairwise import cosine_similarity

        data = self.data

        # How many synthetic nodes each class has
        per_class_quota = Counter(self.labels_syn.cpu().numpy())

        picked = []
        for cls in range(data.nclass):
            picked.extend(list(data.retrieve_class(cls, num=per_class_quota[cls])))
        picked = np.array(picked).reshape(-1)
        features = features[picked]

        # adj_knn = torch.zeros((data.nclass*args.nsamples, data.nclass*args.nsamples)).to(self.device)
        # for i in range(data.nclass):
        #     idx = np.arange(i*args.nsamples, i*args.nsamples+args.nsamples)
        #     adj_knn[np.ix_(idx, idx)] = 1

        # features[features!=0] = 1
        top_k = 2
        sims = cosine_similarity(features.cpu().numpy())
        np.fill_diagonal(sims, 0)  # no self-similarity
        for row in sims:
            # ``row`` is a view into ``sims``: zero everything but the top_k
            # largest entries of this row in place.
            row[np.argsort(row)[:-top_k]] = 0

        adj_knn = torch.FloatTensor(sims).to(self.device)
        return features, adj_knn


def get_loops(args):
    """Return the ``(outer_loop, inner_loop)`` iteration counts for a run.

    ``args.one_step`` takes precedence and yields ``(10, 0)``. Otherwise the
    counts depend on ``args.dataset``: fixed values for 'ogbn-arxiv' and
    'cora', ``(args.outer, args.inner)`` for 'reddit' and 'flickr', and
    ``(20, 5)`` for 'citeseer' and any other dataset. The source notes that
    the fixed values were found to work well empirically.
    """
    if args.one_step:
        return 10, 0

    dataset = args.dataset
    if dataset == 'ogbn-arxiv':
        return 20, 0
    if dataset in ('reddit', 'flickr'):
        # Only these two datasets read the counts from the configuration.
        return args.outer, args.inner
    if dataset == 'cora':
        return 20, 10
    # 'citeseer' (at least 200 epochs) and every other dataset
    return 20, 5

