In [21]:
from reinforce import ReinforceAgent
from visualisation import draw_qap, draw_assignment_graph
from qap import GraphAssignmentProblem, AssignmentGraph
from torch_geometric.utils import from_networkx
import taskgenerators
import torch
from evaluate import random_assignment
import numpy as np

In [16]:
# Problem instance under inspection: the QAPLIB instance bur26a.
QAPLIB_INSTANCE_PATH = "../qapdata/bur26a.dat"
with open(QAPLIB_INSTANCE_PATH) as instance_file:
    instance_text = instance_file.read()
qap = GraphAssignmentProblem.from_qaplib_string(instance_text)
# Alternative (disabled): a task from the training generator instead of QAPLIB.
#qap = taskgenerators.generators["small_random_graphs"]()

In [17]:
# Sanity check of the QAP value function against the published QAPLIB optimum
# for bur26a. The published solution permutation is 1-based, so shift it to
# 0-based indices before evaluating.
BUR26A_OPTIMAL_VALUE = 5426670
bur26a_optimal_permutation = [
    int(s) - 1
    for s in "26 15 11 7 4 12 13 2 6 18 1 5 9 21 8 14 3 20 19 25 17 10 16 24 23 22".split()
]
# The hard-coded solution must be a permutation of 0..n-1, otherwise the value is meaningless.
assert sorted(bur26a_optimal_permutation) == list(range(len(bur26a_optimal_permutation)))
bur26a_optimal_value = qap.compute_value(bur26a_optimal_permutation)
# Fail loudly if the instance was parsed/evaluated differently than QAPLIB intends,
# instead of relying on the reader to compare the number by eye.
assert bur26a_optimal_value == BUR26A_OPTIMAL_VALUE, bur26a_optimal_value
bur26a_optimal_value

5426670.0

In [18]:
# Agent trained with REINFORCE on random graphs; restore its final checkpoint
# to inspect how it behaves on the QAPLIB instance loaded above.
CHECKPOINT_PATH = "../runs/reinforce_gcn_randomgraphs/checkpoint_end.pth"
agent = ReinforceAgent()
agent.load_checkpoint(CHECKPOINT_PATH)

In [19]:
# The agent's underlying network; it is called directly below as
# net(x, edge_index, edge_attr) to inspect the per-node embeddings.
net = agent.network

In [6]:
# Graph combining the QAP's source and target graphs.
# NOTE(review): the third argument `[]` presumably is an empty partial
# assignment — confirm against AssignmentGraph's signature.
graph = AssignmentGraph(qap.graph_source, qap.graph_target, [])
# Convert to PyTorch Geometric: node attribute "side" is concatenated into
# `x`, edge attribute "weight" into `edge_attr`.
pyg_data = from_networkx(
    graph.graph,
    group_node_attrs=["side"],
    group_edge_attrs=["weight"],
)

In [7]:
# Forward pass of the network over the whole assignment graph: one embedding
# row per node. This is inspection only, so run it under torch.no_grad() to
# avoid recording an autograd graph. That saves memory and keeps autograd
# metadata off every downstream tensor (logits, probs).
with torch.no_grad():
    embeddings = net(pyg_data.x, pyg_data.edge_index, pyg_data.edge_attr)
print(embeddings)

tensor([[ 0.0662,  0.2369,  0.0389,  0.2958, -0.3285,  0.2196, -0.5065, -0.0813],
        [ 0.0662,  0.2369,  0.0389,  0.2958, -0.3285,  0.2196, -0.5065, -0.0813],
        [ 0.0662,  0.2369,  0.0389,  0.2958, -0.3285,  0.2196, -0.5065, -0.0813],
        [ 0.0669,  0.2377,  0.0395,  0.2970, -0.3299,  0.2196, -0.5074, -0.0813],
        [ 0.0669,  0.2377,  0.0395,  0.2970, -0.3299,  0.2196, -0.5074, -0.0813],
        [ 0.0649,  0.2355,  0.0377,  0.2939, -0.3263,  0.2195, -0.5049, -0.0813],
        [ 0.0649,  0.2355,  0.0377,  0.2939, -0.3263,  0.2195, -0.5049, -0.0813],
        [ 0.0633,  0.2337,  0.0363,  0.2916, -0.3235,  0.2195, -0.5031, -0.0814],
        [ 0.0633,  0.2337,  0.0363,  0.2916, -0.3235,  0.2195, -0.5031, -0.0814],
        [ 0.0633,  0.2337,  0.0363,  0.2916, -0.3235,  0.2195, -0.5031, -0.0814],
        [ 0.0592,  0.2292,  0.0326,  0.2854, -0.3161,  0.2195, -0.4981, -0.0814],
        [ 0.0592,  0.2292,  0.0326,  0.2854, -0.3161,  0.2195, -0.4981, -0.0814],
        [ 0.0592

In [8]:
# Node ids of the two sides of the assignment graph; used below as row
# indices into `embeddings`.
# NOTE(review): this assumes the node ids coincide with the 0..N-1 node
# ordering that from_networkx used to build `pyg_data` — confirm.
source_nodes = [*graph.subgraph_a.nodes]
target_nodes = [*graph.subgraph_b.nodes]

In [9]:
# Score every (source, target) pair by the dot product of their embeddings:
# logits[i, j] = <embeddings[source_nodes[i]], embeddings[target_nodes[j]]>.
embeddings_source, embeddings_target = embeddings[source_nodes], embeddings[target_nodes]
logits = embeddings_source @ embeddings_target.T

In [10]:
# Inspect the (source x target) score matrix. In the recorded run the first
# rows are identical, mirroring the near-identical source embeddings above.
print(logits)

tensor([[0.9053, 0.6010, 0.7441, 0.6380, 1.3483, 0.6230, 0.6914, 0.8087, 1.0343,
         0.3846, 0.5552, 0.8563, 0.7270, 1.2470, 0.7783, 0.5369, 0.3567, 1.1315,
         0.9118, 0.9549, 0.8394, 0.4271, 0.4955, 0.3785, 0.3858, 0.5173],
        [0.9053, 0.6010, 0.7441, 0.6380, 1.3483, 0.6230, 0.6914, 0.8087, 1.0343,
         0.3846, 0.5552, 0.8563, 0.7270, 1.2470, 0.7783, 0.5369, 0.3567, 1.1315,
         0.9118, 0.9549, 0.8394, 0.4271, 0.4955, 0.3785, 0.3858, 0.5173],
        [0.9053, 0.6010, 0.7441, 0.6380, 1.3483, 0.6230, 0.6914, 0.8087, 1.0343,
         0.3846, 0.5552, 0.8563, 0.7270, 1.2470, 0.7783, 0.5369, 0.3567, 1.1315,
         0.9118, 0.9549, 0.8394, 0.4271, 0.4955, 0.3785, 0.3858, 0.5173],
        [0.9077, 0.6025, 0.7460, 0.6395, 1.3521, 0.6245, 0.6931, 0.8108, 1.0371,
         0.3854, 0.5566, 0.8586, 0.7288, 1.2505, 0.7803, 0.5382, 0.3574, 1.1346,
         0.9142, 0.9574, 0.8416, 0.4281, 0.4967, 0.3793, 0.3866, 0.5185],
        [0.9077, 0.6025, 0.7460, 0.6395, 1.3521, 0.6245,

In [11]:
# One score per (source node, target node) pair; check this rather than eyeballing it.
assert logits.shape == (len(source_nodes), len(target_nodes)), logits.shape
logits.shape

torch.Size([26, 26])

In [12]:
# Joint distribution over all source x target pairs: softmax over the flattened
# logit matrix (row-major, so flat index i * len(target_nodes) + j is pair (i, j)).
probs = logits.flatten().softmax(dim=0)
print(probs)

tensor([0.0017, 0.0013, 0.0015, 0.0013, 0.0027, 0.0013, 0.0014, 0.0016, 0.0019,
        0.0010, 0.0012, 0.0016, 0.0014, 0.0024, 0.0015, 0.0012, 0.0010, 0.0021,
        0.0017, 0.0018, 0.0016, 0.0011, 0.0011, 0.0010, 0.0010, 0.0012, 0.0017,
        0.0013, 0.0015, 0.0013, 0.0027, 0.0013, 0.0014, 0.0016, 0.0019, 0.0010,
        0.0012, 0.0016, 0.0014, 0.0024, 0.0015, 0.0012, 0.0010, 0.0021, 0.0017,
        0.0018, 0.0016, 0.0011, 0.0011, 0.0010, 0.0010, 0.0012, 0.0017, 0.0013,
        0.0015, 0.0013, 0.0027, 0.0013, 0.0014, 0.0016, 0.0019, 0.0010, 0.0012,
        0.0016, 0.0014, 0.0024, 0.0015, 0.0012, 0.0010, 0.0021, 0.0017, 0.0018,
        0.0016, 0.0011, 0.0011, 0.0010, 0.0010, 0.0012, 0.0017, 0.0013, 0.0015,
        0.0013, 0.0027, 0.0013, 0.0014, 0.0016, 0.0020, 0.0010, 0.0012, 0.0016,
        0.0014, 0.0024, 0.0015, 0.0012, 0.0010, 0.0022, 0.0017, 0.0018, 0.0016,
        0.0011, 0.0011, 0.0010, 0.0010, 0.0012, 0.0017, 0.0013, 0.0015, 0.0013,
        0.0027, 0.0013, 0.0014, 0.0016, 

In [13]:
# Same inputs pushed through the agent's own API, to cross-check the manual
# embeddings -> logits -> softmax pipeline above.
# NOTE(review): presumably compute_policy uses the same dot-product scoring;
# the comparison in the next cell tests that assumption.
policy = agent.compute_policy(pyg_data, source_nodes, target_nodes)

In [14]:
# Probabilities from the agent's policy. They should coincide with the manually
# computed `probs`. Both printouts are truncated, so compare them numerically
# instead of by eye. Flattened in case the policy keeps a different (but
# equal-size) layout.
policy_probs = policy.distribution.probs
max_abs_diff = (policy_probs.reshape(-1) - probs.reshape(-1)).abs().max().item()
print(f"max |policy probs - manual probs| = {max_abs_diff:.3e}")
policy_probs

tensor([0.0017, 0.0013, 0.0015, 0.0013, 0.0027, 0.0013, 0.0014, 0.0016, 0.0019,
        0.0010, 0.0012, 0.0016, 0.0014, 0.0024, 0.0015, 0.0012, 0.0010, 0.0021,
        0.0017, 0.0018, 0.0016, 0.0011, 0.0011, 0.0010, 0.0010, 0.0012, 0.0017,
        0.0013, 0.0015, 0.0013, 0.0027, 0.0013, 0.0014, 0.0016, 0.0019, 0.0010,
        0.0012, 0.0016, 0.0014, 0.0024, 0.0015, 0.0012, 0.0010, 0.0021, 0.0017,
        0.0018, 0.0016, 0.0011, 0.0011, 0.0010, 0.0010, 0.0012, 0.0017, 0.0013,
        0.0015, 0.0013, 0.0027, 0.0013, 0.0014, 0.0016, 0.0019, 0.0010, 0.0012,
        0.0016, 0.0014, 0.0024, 0.0015, 0.0012, 0.0010, 0.0021, 0.0017, 0.0018,
        0.0016, 0.0011, 0.0011, 0.0010, 0.0010, 0.0012, 0.0017, 0.0013, 0.0015,
        0.0013, 0.0027, 0.0013, 0.0014, 0.0016, 0.0020, 0.0010, 0.0012, 0.0016,
        0.0014, 0.0024, 0.0015, 0.0012, 0.0010, 0.0022, 0.0017, 0.0018, 0.0016,
        0.0011, 0.0011, 0.0010, 0.0010, 0.0012, 0.0017, 0.0013, 0.0015, 0.0013,
        0.0027, 0.0013, 0.0014, 0.0016, 

In [23]:
# Baseline for comparison: 16 random 16-dimensional vectors (uniform in [0, 1)).
# Drawn from a seeded generator so the baseline below is reproducible across
# kernel restarts. The original used the unseeded global NumPy RNG.
RANDOM_BASELINE_SEED = 0
random_vectors = np.random.default_rng(RANDOM_BASELINE_SEED).random((16, 16))

In [30]:
# Softmax over all pairwise dot products of the random vectors. This is the
# same construction as `probs` above, applied to random embeddings, for comparison.
gram_matrix = random_vectors @ random_vectors.T
torch.tensor(gram_matrix.ravel()).softmax(dim=0)

tensor([7.4448e-03, 8.9952e-04, 1.4860e-03, 1.5596e-03, 7.3317e-03, 3.9320e-03,
        2.1644e-03, 6.7865e-04, 1.0337e-03, 2.0622e-03, 1.7857e-03, 2.3455e-03,
        1.5993e-03, 4.5757e-03, 9.9394e-04, 1.8732e-03, 8.9952e-04, 1.0299e-03,
        5.5090e-04, 6.6443e-04, 1.5880e-03, 1.7305e-03, 8.7228e-04, 4.6118e-04,
        1.7778e-04, 9.7041e-04, 6.8239e-04, 6.6224e-04, 8.0821e-04, 7.6630e-04,
        3.9282e-04, 6.8405e-04, 1.4860e-03, 5.5090e-04, 5.1434e-03, 1.0089e-03,
        7.7183e-03, 1.4325e-03, 1.6282e-03, 3.8936e-04, 1.0680e-03, 1.4846e-03,
        1.8560e-03, 1.1187e-03, 2.2665e-03, 3.1241e-03, 1.6892e-03, 1.3138e-03,
        1.5596e-03, 6.6443e-04, 1.0089e-03, 3.7904e-03, 4.0090e-03, 2.5384e-03,
        1.3806e-03, 7.1361e-04, 3.7543e-04, 1.1550e-03, 1.6827e-03, 1.2903e-03,
        1.1953e-03, 1.7044e-03, 1.5325e-03, 1.3722e-03, 7.3317e-03, 1.5880e-03,
        7.7183e-03, 4.0090e-03, 2.3176e-01, 1.9778e-02, 5.9665e-03, 2.3060e-03,
        5.4838e-03, 9.9237e-03, 1.3314e-