Skip to content

Commit

Permalink
fixed size and edge_index args in message func
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 22, 2019
1 parent e1f6c73 commit 6960061
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 34 deletions.
Expand Up @@ -248,12 +248,11 @@ <h1>Source code for torch_geometric.nn.conv.agnn_conv</h1><div class="highlight"
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">propagate</span><span class="p">(</span>
<span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">x_norm</span><span class="o">=</span><span class="n">x_norm</span><span class="p">,</span> <span class="n">num_nodes</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span></div>

<span class="k">def</span> <span class="nf">message</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">x_j</span><span class="p">,</span> <span class="n">x_norm_i</span><span class="p">,</span> <span class="n">x_norm_j</span><span class="p">,</span> <span class="n">num_nodes</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">message</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">edge_index_i</span><span class="p">,</span> <span class="n">x_j</span><span class="p">,</span> <span class="n">x_norm_i</span><span class="p">,</span> <span class="n">x_norm_j</span><span class="p">,</span> <span class="n">num_nodes</span><span class="p">):</span>
<span class="c1"># Compute attention coefficients.</span>
<span class="n">beta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">requires_grad</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_buffers</span><span class="p">[</span><span class="s1">&#39;beta&#39;</span><span class="p">]</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">beta</span> <span class="o">*</span> <span class="p">(</span><span class="n">x_norm_i</span> <span class="o">*</span> <span class="n">x_norm_j</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">i</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">flow</span> <span class="o">==</span> <span class="s1">&#39;target_to_source&#39;</span> <span class="k">else</span> <span class="mi">1</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">num_nodes</span><span class="p">)</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">edge_index_i</span><span class="p">,</span> <span class="n">num_nodes</span><span class="p">)</span>

<span class="k">return</span> <span class="n">x_j</span> <span class="o">*</span> <span class="n">alpha</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

Expand Down
Expand Up @@ -282,12 +282,11 @@ <h1>Source code for torch_geometric.nn.conv.gat_conv</h1><div class="highlight">
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_channels</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">propagate</span><span class="p">(</span><span class="n">edge_index</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">num_nodes</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span></div>

<span class="k">def</span> <span class="nf">message</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">,</span> <span class="n">x_i</span><span class="p">,</span> <span class="n">x_j</span><span class="p">,</span> <span class="n">num_nodes</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">message</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">edge_index_i</span><span class="p">,</span> <span class="n">x_i</span><span class="p">,</span> <span class="n">x_j</span><span class="p">,</span> <span class="n">num_nodes</span><span class="p">):</span>
<span class="c1"># Compute attention coefficients.</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x_i</span><span class="p">,</span> <span class="n">x_j</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">att</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">negative_slope</span><span class="p">)</span>
<span class="n">i</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">flow</span> <span class="o">==</span> <span class="s1">&#39;target_to_source&#39;</span> <span class="k">else</span> <span class="mi">1</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">edge_index</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">num_nodes</span><span class="p">)</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">softmax</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">edge_index_i</span><span class="p">,</span> <span class="n">num_nodes</span><span class="p">)</span>

<span class="c1"># Sample attention coefficients stochastically.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
Expand Down

0 comments on commit 6960061

Please sign in to comment.