In [1]:
from reinforce import ReinforceAgent
from visualisation import draw_qap, draw_assignment_graph
from qap import GraphAssignmentProblem, AssignmentGraph
from torch_geometric.utils import from_networkx
import taskgenerators
import torch

In [2]:
# Read the QAPLIB instance text from disk first, then parse it into a
# GraphAssignmentProblem once the file handle has been released.
with open("../qapdata/bur26a.dat") as instance_file:
    instance_text = instance_file.read()
qap = GraphAssignmentProblem.from_qaplib_string(instance_text)
# Alternative problem source (disabled): draw an instance from a task generator.
#qap = taskgenerators.generators["small_random_graphs"]()

In [3]:
# Seed PyTorch's global RNG before the agent is built so that a fresh
# "Restart & Run All" yields the same untrained-network outputs every time.
# NOTE(review): assumes ReinforceAgent initialises its weights through
# torch's global RNG -- confirm against the reinforce module.
torch.manual_seed(0)
agent = ReinforceAgent()
# Disabled: load trained weights instead of inspecting a freshly built agent.
#agent.load_checkpoint("../runs/reinforce_bur26a_2/checkpoint_20000.pth")

In [4]:
# Keep a direct handle on the agent's network so it can be called by hand
# on the graph tensors, independently of agent.compute_policy.
net = agent.network

In [5]:
# Build the assignment graph over the problem's source and target graphs.
# NOTE(review): the empty list is presumably "no pairs assigned yet" -- confirm
# against AssignmentGraph's signature.
graph = AssignmentGraph(qap.graph_source, qap.graph_target, [])
# Convert the networkx graph to a PyG Data object, grouping the "side" node
# attribute into node features and the "weight" edge attribute into edge features.
pyg_data = from_networkx(
    graph.graph,
    group_node_attrs=["side"],
    group_edge_attrs=["weight"],
)

In [6]:
# Feed every node the same constant feature [1.0] (one row per graph node)
# and run the network by hand on the PyG edge tensors.
node_count = len(graph.graph)
constant_features = torch.tensor([[1.0]] * node_count)
embeddings = net(constant_features, pyg_data.edge_index, pyg_data.edge_attr)
print(embeddings)

tensor([[0.0000, 0.0000, 0.0397, 0.0000, 0.0172, 0.0989, 0.2408, 0.0000],
        [0.0000, 0.0000, 0.0397, 0.0000, 0.0172, 0.0989, 0.2408, 0.0000],
        [0.0000, 0.0000, 0.0397, 0.0000, 0.0172, 0.0989, 0.2408, 0.0000],
        [0.0000, 0.0000, 0.0399, 0.0000, 0.0173, 0.0995, 0.2422, 0.0000],
        [0.0000, 0.0000, 0.0399, 0.0000, 0.0173, 0.0995, 0.2422, 0.0000],
        [0.0000, 0.0000, 0.0393, 0.0000, 0.0170, 0.0978, 0.2383, 0.0000],
        [0.0000, 0.0000, 0.0393, 0.0000, 0.0170, 0.0978, 0.2383, 0.0000],
        [0.0000, 0.0000, 0.0388, 0.0000, 0.0168, 0.0966, 0.2354, 0.0000],
        [0.0000, 0.0000, 0.0388, 0.0000, 0.0168, 0.0966, 0.2354, 0.0000],
        [0.0000, 0.0000, 0.0388, 0.0000, 0.0168, 0.0966, 0.2354, 0.0000],
        [0.0000, 0.0000, 0.0381, 0.0000, 0.0165, 0.0949, 0.2313, 0.0000],
        [0.0000, 0.0000, 0.0381, 0.0000, 0.0165, 0.0949, 0.2313, 0.0000],
        [0.0000, 0.0000, 0.0381, 0.0000, 0.0165, 0.0949, 0.2313, 0.0000],
        [0.0000, 0.0000, 0.0384, 0.000

In [7]:
# Materialise the node ids of each side of the assignment graph as lists;
# they are used below to select rows of the embedding matrix.
source_nodes = [*graph.subgraph_a.nodes]
target_nodes = [*graph.subgraph_b.nodes]

In [8]:
# Split the embedding rows by side, then score every (source, target) pair
# with the dot product of their embeddings: rows index sources, columns targets.
embeddings_source, embeddings_target = embeddings[source_nodes], embeddings[target_nodes]
logits = embeddings_source @ embeddings_target.T

In [9]:
# Dump the raw pairwise score matrix (source rows x target columns).
print(logits)

tensor([[0.0683, 0.0372, 0.0580, 0.0611, 0.1130, 0.0441, 0.0706, 0.0588, 0.0913,
         0.0229, 0.0399, 0.0687, 0.0570, 0.0927, 0.0580, 0.0384, 0.0056, 0.0782,
         0.0765, 0.0784, 0.0679, 0.0415, 0.0486, 0.0084, 0.0115, 0.0411],
        [0.0683, 0.0372, 0.0580, 0.0611, 0.1130, 0.0441, 0.0706, 0.0588, 0.0913,
         0.0229, 0.0399, 0.0687, 0.0570, 0.0927, 0.0580, 0.0384, 0.0056, 0.0782,
         0.0765, 0.0784, 0.0679, 0.0415, 0.0486, 0.0084, 0.0115, 0.0411],
        [0.0683, 0.0372, 0.0580, 0.0611, 0.1130, 0.0441, 0.0706, 0.0588, 0.0913,
         0.0229, 0.0399, 0.0687, 0.0570, 0.0927, 0.0580, 0.0384, 0.0056, 0.0782,
         0.0765, 0.0784, 0.0679, 0.0415, 0.0486, 0.0084, 0.0115, 0.0411],
        [0.0687, 0.0374, 0.0583, 0.0615, 0.1137, 0.0444, 0.0710, 0.0592, 0.0918,
         0.0231, 0.0402, 0.0691, 0.0573, 0.0933, 0.0583, 0.0386, 0.0056, 0.0787,
         0.0769, 0.0788, 0.0683, 0.0417, 0.0489, 0.0085, 0.0116, 0.0414],
        [0.0687, 0.0374, 0.0583, 0.0615, 0.1137, 0.0444,

In [10]:
# Sanity check: one row per entry of source_nodes, one column per entry of target_nodes.
logits.shape

torch.Size([26, 26])

In [11]:
# Flatten the score matrix to one entry per (source, target) pair and
# normalise over all pairs jointly, giving a single distribution over pairs.
flat_logits = logits.flatten()
probs = flat_logits.softmax(dim=0)
print(probs)

tensor([0.0015, 0.0015, 0.0015, 0.0015, 0.0016, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0014, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0014, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0014, 0.0014, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0016, 0.0015, 0.0015, 0.0015, 0.0015, 0.0014,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0014, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0014, 0.0014, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0016, 0.0015, 0.0015, 0.0015, 0.0015, 0.0014, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0014, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0014, 0.0014, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0016, 0.0015, 0.0015, 0.0015, 0.0015, 0.0014, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0014, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0014, 0.0014, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0016, 0.0015, 0.0015, 0.0015, 

In [12]:
# Ask the agent for its policy on the same PyG data and node lists that the
# manual computation above used.
policy = agent.compute_policy(pyg_data, source_nodes, target_nodes)

In [13]:
# Probabilities of the agent's policy distribution, to set against the manual `probs`.
# NOTE(review): the saved output here is not identical to the manual softmax
# printed earlier -- compute_policy presumably feeds the network different
# inputs or post-processes the scores; confirm in the reinforce module.
policy.distribution.probs

tensor([0.0015, 0.0014, 0.0015, 0.0015, 0.0018, 0.0014, 0.0016, 0.0015, 0.0017,
        0.0013, 0.0014, 0.0015, 0.0015, 0.0017, 0.0015, 0.0014, 0.0013, 0.0016,
        0.0016, 0.0016, 0.0015, 0.0014, 0.0014, 0.0013, 0.0013, 0.0014, 0.0015,
        0.0014, 0.0015, 0.0015, 0.0018, 0.0014, 0.0016, 0.0015, 0.0017, 0.0013,
        0.0014, 0.0015, 0.0015, 0.0017, 0.0015, 0.0014, 0.0013, 0.0016, 0.0016,
        0.0016, 0.0015, 0.0014, 0.0014, 0.0013, 0.0013, 0.0014, 0.0015, 0.0014,
        0.0015, 0.0015, 0.0018, 0.0014, 0.0016, 0.0015, 0.0017, 0.0013, 0.0014,
        0.0015, 0.0015, 0.0017, 0.0015, 0.0014, 0.0013, 0.0016, 0.0016, 0.0016,
        0.0015, 0.0014, 0.0014, 0.0013, 0.0013, 0.0014, 0.0015, 0.0014, 0.0015,
        0.0015, 0.0018, 0.0014, 0.0016, 0.0015, 0.0017, 0.0013, 0.0014, 0.0015,
        0.0015, 0.0017, 0.0015, 0.0014, 0.0013, 0.0016, 0.0016, 0.0016, 0.0015,
        0.0014, 0.0014, 0.0013, 0.0013, 0.0014, 0.0015, 0.0014, 0.0015, 0.0015,
        0.0018, 0.0014, 0.0016, 0.0015, 