In [1]:
import torch
import torch.nn as nn
from torch import optim
import numpy as np
import pandas as pd
import os
import pickle
import random
import time
import importlib
import sys
sys.path.append('./GWFM/')

import GWFM.methods.AlgOT as at
import GWFM.methods.FusedGromovWassersteinFactorization as FW
from methods.AlgOT import cost_mat, ot_fgw
from methods.DataIO import StructuralDataSampler, StructuralDataSampler2, structural_data_split
from sklearn.manifold import MDS, TSNE
from typing import List, Tuple
# import functions as fc

# Load data and model

In [2]:
# Load the pickled structural dataset that feeds the data sampler below.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
dataset_path = './dataset/datas.pickle'
with open(dataset_path, mode='rb') as dataset_file:
    datas = pickle.load(dataset_file)

In [8]:
# Hyper-parameters of the trained FGW factorization model. They must match the
# checkpoint, otherwise load_state_dict reports missing/unexpected keys or
# size mismatches.
num_atoms = 100                      # number of atoms in the factorization
size_atoms = num_atoms * [35]        # size of each atom (35 for every atom)
ot_method = 'ppa'
gamma = 5e-2
gwb_layers = 5
ot_layers = 30
dim_embedding = 4
num_classes = None                   # NOTE(review): original comment read "prior distribution", likely copy-pasted from `prior` — confirm meaning
prior = None                         # prior distribution (none supplied)

# The checkpoint file name encodes the number of atoms, so derive it from
# num_atoms instead of hard-coding it separately.
modelPath = './models/gwModel_atoms{}.pt'.format(num_atoms)

data_sampler = StructuralDataSampler2(datas)
num_samples = len(data_sampler)

model = FW.FGWF(num_samples=num_samples,
                num_classes=num_classes,
                size_atoms=size_atoms,
                dim_embedding=dim_embedding,
                ot_method=ot_method,
                gamma=gamma,
                gwb_layers=gwb_layers,
                ot_layers=ot_layers,
                prior=prior)
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; load_state_dict then copies the values into the model's
# own parameters wherever they live.
model.load_state_dict(torch.load(modelPath, map_location='cpu'))

<All keys matched successfully>

# Get graph attributes 1

In [9]:
# Graph-level attributes: the model's atom weights, detached from autograd and
# transposed so that rows index graph samples and columns index atoms.
graph_attributes1 = model.output_weights().detach().t()

In [10]:
# Sanity check: one row per graph sample and one column per atom.
assert graph_attributes1.shape == (num_samples, num_atoms), graph_attributes1.shape
graph_attributes1.shape

torch.Size([2000, 100])

In [11]:
# Persist the per-graph atom weights. open() does not create missing
# directories, so make sure ./Attributes exists first. The file name encodes
# the number of atoms, matching the checkpoint naming.
os.makedirs('./Attributes', exist_ok=True)
with open('./Attributes/graphAttributes_{}.pickle'.format(num_atoms), 'wb') as f:
    pickle.dump(graph_attributes1, f)

# Get node attributes

In [7]:
# Per-atom node embeddings, detached so no autograd history is kept.
atomAttributeList = [embedding.detach() for embedding in model.embeddings]

# Atom weights per graph, transposed so rows index graph samples and columns
# index atoms (same expression as graph_attributes1).
lambdas = model.output_weights().detach().t()

transports_path = './Attributes/nodeA/transports.pickle'


def compute_transports(fgw_model, sampler):
    """Compute the FGW transport plan between every atom and every sample graph.

    Returns a list with one entry per sample. Each entry is a list with one
    transport matrix per atom, transposed so that (presumably) rows index the
    sample graph's nodes and columns index the atom's nodes — this is how the
    node-attribute cell consumes them.
    """
    all_transports = []
    for idx in range(len(sampler)):
        sample = sampler[idx]
        graph, prob, emb = sample[0], sample[1], sample[2]
        per_atom = []
        for k in range(fgw_model.num_atoms):
            graph_k = fgw_model.output_atoms(k).data
            emb_k = fgw_model.embeddings[k].data
            psk = fgw_model.ps[k]
            _, tran_k = ot_fgw(graph_k, graph, psk, prob,
                               fgw_model.ot_method, fgw_model.gamma, fgw_model.ot_layers,
                               emb_k, emb)
            per_atom.append(tran_k.t())
        all_transports.append(per_atom)
    return all_transports


# Solving num_samples x num_atoms OT problems is expensive, so the result is
# cached on disk; recompute only when the cache is missing, so the notebook
# still works on a fresh checkout (Restart & Run All).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted caches.
if os.path.exists(transports_path):
    with open(transports_path, 'rb') as f:
        transports = pickle.load(f)
else:
    transports = compute_transports(model, data_sampler)
    os.makedirs(os.path.dirname(transports_path), exist_ok=True)
    with open(transports_path, 'wb') as f:
        pickle.dump(transports, f)

In [9]:
def _node_attributes_for_graph(graph_transports, atom_attributes, graph_lambdas):
    """Build the node-attribute matrix of one graph.

    For every atom k, each node's row holds the atom's embedding matrix scaled
    row-wise by that node's transport weights to the atom's nodes and by the
    graph's weight for atom k, flattened row-major. The per-atom pieces are
    concatenated along columns, giving one row per node of the graph.
    """
    blocks = []
    for k, atom_attr in enumerate(atom_attributes):
        tran = graph_transports[k]                      # (num_nodes, atom_size)
        # Broadcast (1, atom_size, dim) * (num_nodes, atom_size, 1): same
        # element-wise products as scaling each node's row separately.
        weighted = atom_attr.unsqueeze(0) * tran.unsqueeze(-1) * graph_lambdas[k]
        # flatten(start_dim=1) (unlike reshape(n, -1)) also works for 0 nodes.
        blocks.append(weighted.flatten(start_dim=1))    # (num_nodes, atom_size * dim)
    return torch.cat(blocks, dim=1)


assert len(transports) >= len(data_sampler), \
    'transports cache has {} entries but there are {} samples'.format(len(transports), len(data_sampler))

# One attribute matrix per graph sample, vectorized per atom instead of a
# Python loop over every (node, atom) pair.
nodeAttributes = [
    _node_attributes_for_graph(transports[g], atomAttributeList, lambdas[g, :])
    for g in range(len(data_sampler))
]

# NOTE(review): unlike the other outputs this file has no .pickle extension;
# kept unchanged because downstream readers may rely on the exact path.
with open('./Attributes/nodeAttributes', 'wb') as f:
    pickle.dump(nodeAttributes, f)