
Is there any way to convert a PyTorch Tensor adjacency matrix into a pytorch_geometric Data object while allowing backprop? #1511

Closed
remingtonkim opened this issue Aug 6, 2020 · 10 comments

@remingtonkim

Is there a way to convert an adjacency tensor produced by an MLP into a Data object while allowing backprop for a generative adversarial network? The GAN has an MLP generator and a pytorch_geometric-based GNN as the discriminator. I have not been able to find an answer to this question yet. Here is a simplified example of the problem.

Say I have this MLP generator:

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 6)

    def forward(self, z):
        return torch.tanh(self.fc1(z))

gen = Generator()
output = gen(torch.randn(3))
# output = tensor([ 0.2085, -0.0576,  0.4957, -0.6059,  0.2571, -0.2866], grad_fn=<TanhBackward>)

So, this generator returns a vector representing a graph with two nodes, which we can reshape to form an adjacency matrix and a node feature vector.

adj = output[:4].view(2, 2)
# adj = tensor([[-0.5811,  0.0070],
#               [ 0.3754, -0.2587]], grad_fn=<ViewBackward>)

node_features = output[4:].view(2, 1)
# node_features = tensor([[0.1591],
#                         [0.0821]], grad_fn=<ViewBackward>)

Now, to convert this to a pytorch_geometric Data object, we must construct a COO edge index (the node_features already serve as the x parameter of the Data object). However, if we loop through the adj matrix and collect the edges into a COO list with the code below, back propagation does not work from the pytorch_geometric GNN to the PyTorch MLP.

coo = [[], []]
for i in range(len(adj)):
    for j in range(len(adj[i])):
        # for our purposes, say there is an edge if the value > 0
        if adj[i][j] > 0:
            coo[0].append(i)
            coo[1].append(j)

We can now construct the Data object like so:

d = Data(x=node_features, edge_index=torch.LongTensor(coo))

However, when training a GAN by converting the generator output to a Data object for the GNN discriminator, back propagation and optimization do not work (I assume because the grad_fn and grad properties are lost). Does anyone know how to convert a tensor to a pytorch_geometric Data object while still allowing back propagation through the GAN, i.e., from the pytorch_geometric-based GNN discriminator (which takes a Data object as input) back to the MLP generator that outputs the adjacency matrix and node features?

@rusty1s
Member

rusty1s commented Aug 7, 2020

It is correct that you lose gradients that way. In order to backpropagate through sparse matrices, you need to compute both edge_index and edge_weight (the first holding the COO indices and the second holding the value of each edge). This way, gradients flow from edge_weight back to your dense adjacency matrix.

In code, this would look as follows:

edge_index = (adj > 0).nonzero().t()
row, col = edge_index
edge_weight = adj[row, col]
self.conv(x, edge_index, edge_weight)
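
For completeness, here is a quick sanity check (my own sketch, not from the original comment) that gradients do reach the dense adjacency matrix through edge_weight:

import torch
from torch_geometric.data import Data

adj = torch.randn(2, 2, requires_grad=True)  # stand-in for the generator output
edge_index = (adj > 0).nonzero().t()
row, col = edge_index
edge_weight = adj[row, col]

d = Data(x=torch.randn(2, 1), edge_index=edge_index, edge_attr=edge_weight)
d.edge_attr.sum().backward()  # a dummy loss
print(adj.grad)  # non-zero exactly where adj > 0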

@smorad

smorad commented Mar 19, 2021

Using batch mode, would this be

batch = Batch(x=torch.rand(num_graphs, n, num_feats), adj=torch.randint(0, 2, (num_graphs, n, n)))

edge_index = (batch.adj > 0).nonzero().t()
b, row, col = edge_index
batch.edge_weight = batch.adj[b, row, col]

I'm having trouble understanding the sparse storage mechanisms.

@rusty1s
Copy link
Member

rusty1s commented Mar 19, 2021

Nearly:

adj = torch.randint(0, 2, (num_graphs, n, n))
offset, row, col = (adj > 0).nonzero().t()
edge_weight = adj[offset, row, col]
row += offset * n
col += offset * n
edge_index = torch.stack([row, col], dim=0)
x = x.view(num_graphs * n, num_feats)
batch = torch.arange(0, num_graphs).view(-1, 1).repeat(1, n).view(-1)

Here, we combine the node dimension and the batch dimension, so that separate graphs are represented as a "super graph".
This is especially useful when mini-batching graphs of differing sizes.
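
To make the offset trick concrete, here is a tiny worked example (my own, for illustration):

import torch

num_graphs, n = 2, 3
batch = torch.arange(0, num_graphs).view(-1, 1).repeat(1, n).view(-1)
print(batch)  # tensor([0, 0, 0, 1, 1, 1])
# Node i of graph g becomes node g * n + i in the super graph,
# e.g. node 1 of graph 1 becomes 1 * 3 + 1 = 4.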

@smorad

smorad commented Mar 27, 2021

Apparently, not all models take edge_weight as input, e.g. GATConv. Is there a way to propagate gradients in this case? I suppose one option would be extending said models to handle edge_weight, but is there an easier way?

@rusty1s
Copy link
Member

rusty1s commented Mar 29, 2021

Yes, not all models support edge_weight. For example, GATConv computes its edge weights (the attention coefficients) on the fly. Nonetheless, you can just extend GATConv to handle edge features, e.g., just as we do here.
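
As a side note, here is a minimal sketch (my own, not from the thread) of feeding the weights in as edge features, assuming a recent PyG release (2.0+) where GATConv accepts an edge_dim argument:

import torch
from torch_geometric.nn import GATConv

adj = torch.randn(4, 4, requires_grad=True)  # stand-in for a generator output
edge_index = (adj > 0).nonzero().t()
edge_weight = adj[edge_index[0], edge_index[1]]

x = torch.randn(4, 1)
conv = GATConv(in_channels=1, out_channels=8, edge_dim=1)
out = conv(x, edge_index, edge_attr=edge_weight.view(-1, 1))
out.sum().backward()  # gradients reach adj through the attention computation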

@vctorwei

Can I also use the following code to generate adjacency matrices in batch format?

My goal is to build a simple GNN-based GAN for graphs with the same number of nodes (90), with an MLP generator and a graph-level classifier as the discriminator.

All my graphs are 90x90 symmetric adjacency matrices with zero diagonal, so what I'm trying to do here is to generate only 4005 elements (half of a 90x90 matrix) and map them to the upper and lower triangles of another 90x90 torch tensor. I then put a batch (n_samples_in_batch=16) of 90x90 tensors into a list and convert them into a Batch.

Does this look like a reasonable way to do it? Sorry for my clumsy code.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GraphConv, global_mean_pool

def get_noise(n_samples_in_batch, noise_dim):
    return torch.rand(n_samples_in_batch, noise_dim)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Linear(90, 1335)
        self.conv2 = nn.Linear(1335, 4005)
        self.relu = nn.ReLU()
        self.sig = nn.Sigmoid()

    def forward(self, noise):
        noise = self.conv1(noise)
        noise = self.relu(noise)
        noise = self.conv2(noise)
        noise = self.sig(noise)
        return noise

def mlp_to_graph_batch(mlp_output, n_samples):
    mlp_output = mlp_output.to(device)
    datalist = []
    for i in range(n_samples):
        # Scatter the 4005 generated values into one triangle, then mirror
        # them into the other to get a symmetric adjacency matrix.
        adj = torch.zeros(90, 90).to(device)
        adj[torch.triu(torch.ones(90, 90)) != 1] = mlp_output[i]
        adj_t = adj.t()
        adj[torch.tril(torch.ones(90, 90)) != 1] = adj_t[torch.tril(torch.ones(90, 90)) != 1]
        x = torch.ones(90, 1).type(torch.float)
        edge_index = (adj > 0).nonzero().t()
        row, col = edge_index
        edge_weight = adj[row, col]
        dataformat = Data(x=x, edge_index=edge_index, edge_attr=edge_weight)
        datalist.append(dataformat)
    batchlist = Batch.from_data_list(datalist)
    return batchlist

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GraphConv(1, 128)
        self.conv2 = GraphConv(128, 128)
        self.conv3 = GraphConv(128, 128)
        self.lin1 = torch.nn.Linear(128, 64)
        self.lin2 = torch.nn.Linear(64, 1)

    def forward(self, data):
        x, edge_index, edge_weight, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = F.relu(self.conv1(x, edge_index, edge_weight))
        x = F.relu(self.conv2(x, edge_index, edge_weight))
        x = F.relu(self.conv3(x, edge_index, edge_weight))
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = self.lin2(x) 
        return x

criterion = torch.nn.BCEWithLogitsLoss()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
modelD = Net().to(device)
modelG = Generator().to(device)
optimizerD = torch.optim.Adam(modelD.parameters(), lr=0.00005)
optimizerG = torch.optim.Adam(modelG.parameters(), lr=0.0005)
G_losses = []
D_losses = []
num_epochs = 1000


for epoch in range(num_epochs):
    # For each batch in the dataloader
    for real in train_loader:
        real = real.to(device)
        n_samples = real.num_graphs
        label_real = torch.ones(n_samples).type(torch.float).view(-1, 1).to(device)
        label_fake = torch.zeros(n_samples).type(torch.float).view(-1, 1).to(device)

        #### Training generator ####
        optimizerG.zero_grad()
        noise = get_noise(n_samples, noise_dim=90).to(device)
        fake_mlp = modelG(noise)
        fake = mlp_to_graph_batch(fake_mlp, n_samples).to(device)

        output = modelD(fake)
        G_loss = criterion(output, label_real)
        G_loss.backward()
        optimizerG.step()
        G_losses.append(G_loss.item())

        #### Training discriminator ####
        optimizerD.zero_grad()
        out = modelD(real)
        D_loss_real = criterion(out, label_real)

        fake = mlp_to_graph_batch(fake_mlp.detach(), n_samples).to(device)
        out = modelD(fake)
        D_loss_fake = criterion(out, label_fake)

        D_loss = (D_loss_real + D_loss_fake) / 2
        D_loss.backward()
        optimizerD.step()
        D_losses.append(D_loss.item())

@rusty1s
Member

rusty1s commented May 26, 2021

Yes, this looks good to me. You might want to get rid of the for-loop iterating over each example in the mini-batch at some point though, as I assume it slows down your code.
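
For reference, a vectorized version might look like the following (my own sketch; mlp_to_graph_batch_vectorized is a hypothetical name, not from this thread). It builds all adjacency matrices at once and reuses the batch-offset trick from the earlier comment:

import torch
from torch_geometric.data import Data

def mlp_to_graph_batch_vectorized(mlp_output, n_samples, n=90):
    # Scatter the 4005 generated values into the strict upper triangle
    # of each 90x90 matrix, then symmetrize.
    adj = mlp_output.new_zeros(n_samples, n, n)
    iu = torch.triu_indices(n, n, offset=1)
    adj[:, iu[0], iu[1]] = mlp_output
    adj = adj + adj.transpose(1, 2)

    # Extract all edges across the batch in one call.
    offset, row, col = (adj > 0).nonzero().t()
    edge_weight = adj[offset, row, col]
    edge_index = torch.stack([row + offset * n, col + offset * n], dim=0)

    data = Data(x=torch.ones(n_samples * n, 1, device=adj.device),
                edge_index=edge_index, edge_attr=edge_weight)
    data.batch = torch.arange(n_samples, device=adj.device).repeat_interleave(n)
    return data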

@smorad

smorad commented May 27, 2021

@vctorwei you may want to look at #2543

I'm planning on upstreaming this after a paper deadline in ~3 weeks.

@akul-goyal

I am interested in implementing an adversarial attack against a GATConv model I have created, using the DeepRobust library. They use a separate dense adjacency matrix that they update and add to the original adjacency matrix to create the adversarial samples. Furthermore, I am using a NeighborSampler as a data loader for my training process. So:
(a) Is there a way to pass edge weights/attributes to NeighborSampler such that they can later be passed into the model?
(b) Will the gradients flow back to the separate dense adjacency matrix that is being used to modify the original adjacency matrix?

@rusty1s
Copy link
Member

rusty1s commented Sep 16, 2022

(a) The NeighborSampler should return an e_id vector which you can use to query the related edge weights. Note that our new NeighborLoader does this internally already. (b) I would assume so, as long as there is no detach taking place somewhere.
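
A minimal sketch of the e_id lookup (my own, assuming the NeighborSampler API that yields (batch_size, n_id, adjs) tuples for multi-layer sampling; the graph tensors here are dummies):

import torch
from torch_geometric.loader import NeighborSampler

edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_weight = torch.rand(edge_index.size(1), requires_grad=True)

loader = NeighborSampler(edge_index, sizes=[2, 2], batch_size=2,
                         node_idx=torch.arange(4))

for batch_size, n_id, adjs in loader:
    for sub_edge_index, e_id, size in adjs:
        # e_id indexes into the original edge list, so this lookup keeps
        # the gradient path into edge_weight intact.
        sampled_weight = edge_weight[e_id]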
