In [1]:
from reinforce import ReinforceAgent
from visualisation import draw_qap, draw_assignment_graph
from qap import GraphAssignmentProblem, AssignmentGraph
from torch_geometric.utils import from_networkx
import taskgenerators
import torch

In [2]:
# Read the raw text of ../qapdata/bur26a.dat first, then hand it to the
# QAPLIB string parser once the file handle has been released.
with open("../qapdata/bur26a.dat") as instance_file:
    qaplib_text = instance_file.read()
qap = GraphAssignmentProblem.from_qaplib_string(qaplib_text)
# Disabled alternative: obtain the problem from a task generator instead.
#qap = taskgenerators.generators["small_random_graphs"]()

In [3]:
# Build the agent with its default constructor arguments.
# NOTE(review): the checkpoint load below is commented out, so the network
# presumably still has its initial (untrained) weights in every cell that
# follows — confirm before reading anything into the printed values.
agent = ReinforceAgent()
#agent.load_checkpoint("../runs/reinforce_bur26a_2/checkpoint_20000.pth")

In [4]:
# Short handle on the agent's network; it is called directly on the graph
# tensors further down to inspect the node embeddings it produces.
net = agent.network

In [5]:
# Wrap both sides of the problem in a single AssignmentGraph.
# NOTE(review): the empty list is presumably "no assignments made yet" —
# confirm against the AssignmentGraph constructor.
graph = AssignmentGraph(qap.graph_source, qap.graph_target, [])

# Convert the underlying networkx graph to a PyG Data object: the "side" node
# attribute is grouped into `x`, the "weight" edge attribute into `edge_attr`.
pyg_data = from_networkx(
    graph.graph,
    group_node_attrs=["side"],
    group_edge_attrs=["weight"],
)

In [6]:
# Run the network on the node features, connectivity and edge features of the
# converted graph. The printed result is a 2-D tensor of embeddings
# (presumably one row per graph node — confirm against the network definition).
embeddings = net(pyg_data.x, pyg_data.edge_index, pyg_data.edge_attr)
print(embeddings)

tensor([[ 0.2010, -0.2036,  0.1315, -0.1233, -0.1856,  0.0670, -0.1306, -0.0161],
        [ 0.2010, -0.2036,  0.1315, -0.1233, -0.1856,  0.0670, -0.1306, -0.0161],
        [ 0.2010, -0.2036,  0.1315, -0.1233, -0.1856,  0.0670, -0.1306, -0.0161],
        [ 0.2012, -0.2043,  0.1320, -0.1242, -0.1864,  0.0670, -0.1306, -0.0166],
        [ 0.2012, -0.2043,  0.1320, -0.1242, -0.1864,  0.0670, -0.1306, -0.0166],
        [ 0.2007, -0.2024,  0.1306, -0.1217, -0.1842,  0.0670, -0.1306, -0.0153],
        [ 0.2007, -0.2024,  0.1306, -0.1217, -0.1842,  0.0670, -0.1306, -0.0153],
        [ 0.2003, -0.2010,  0.1295, -0.1198, -0.1826,  0.0670, -0.1307, -0.0142],
        [ 0.2003, -0.2010,  0.1295, -0.1198, -0.1826,  0.0670, -0.1307, -0.0142],
        [ 0.2003, -0.2010,  0.1295, -0.1198, -0.1826,  0.0670, -0.1307, -0.0142],
        [ 0.1997, -0.1990,  0.1280, -0.1172, -0.1803,  0.0670, -0.1307, -0.0128],
        [ 0.1997, -0.1990,  0.1280, -0.1172, -0.1803,  0.0670, -0.1307, -0.0128],
        [ 0.1997

In [7]:
# Materialise the node collections of the two subgraphs as plain lists; they
# are used below to pick rows out of the embedding tensor.
source_nodes = [*graph.subgraph_a.nodes]
target_nodes = [*graph.subgraph_b.nodes]

In [8]:
# Select the embedding rows belonging to each side.
# NOTE(review): this assumes the node labels are valid integer row indices
# into `embeddings` — verify against how from_networkx orders the nodes.
embeddings_source = embeddings[source_nodes]
embeddings_target = embeddings[target_nodes]

# Pairwise dot products: entry (i, j) scores source row i against target row j.
logits = embeddings_source @ embeddings_target.T

In [9]:
# Show the raw pairwise scores before any normalisation is applied.
print(logits)

tensor([[0.1004, 0.0929, 0.0979, 0.0987, 0.1111, 0.0946, 0.1009, 0.0981, 0.1059,
         0.0895, 0.0936, 0.1005, 0.0977, 0.1062, 0.0979, 0.0932, 0.0854, 0.1028,
         0.1024, 0.1028, 0.1003, 0.0940, 0.0957, 0.0860, 0.0868, 0.0939],
        [0.1004, 0.0929, 0.0979, 0.0987, 0.1111, 0.0946, 0.1009, 0.0981, 0.1059,
         0.0895, 0.0936, 0.1005, 0.0977, 0.1062, 0.0979, 0.0932, 0.0854, 0.1028,
         0.1024, 0.1028, 0.1003, 0.0940, 0.0957, 0.0860, 0.0868, 0.0939],
        [0.1004, 0.0929, 0.0979, 0.0987, 0.1111, 0.0946, 0.1009, 0.0981, 0.1059,
         0.0895, 0.0936, 0.1005, 0.0977, 0.1062, 0.0979, 0.0932, 0.0854, 0.1028,
         0.1024, 0.1028, 0.1003, 0.0940, 0.0957, 0.0860, 0.0868, 0.0939],
        [0.1005, 0.0931, 0.0981, 0.0988, 0.1113, 0.0947, 0.1011, 0.0983, 0.1061,
         0.0896, 0.0937, 0.1006, 0.0978, 0.1064, 0.0981, 0.0934, 0.0855, 0.1029,
         0.1025, 0.1030, 0.1005, 0.0941, 0.0958, 0.0861, 0.0869, 0.0940],
        [0.1005, 0.0931, 0.0981, 0.0988, 0.1113, 0.0947,

In [10]:
# Sanity check: rows index the selected source nodes, columns the target nodes.
logits.shape

torch.Size([26, 26])

In [11]:
# Flatten the score matrix row by row into one vector and normalise across
# all (source, target) pairs at once, giving a single joint distribution in
# which flat index i * n_targets + j corresponds to pair (i, j).
probs = logits.reshape(-1).softmax(dim=0)
print(probs)

tensor([0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 

In [12]:
# Ask the agent for its policy on the same graph data and node lists that the
# manual embedding -> logits -> softmax steps above were computed from.
policy = agent.compute_policy(pyg_data, source_nodes, target_nodes)

In [13]:
# Probabilities of the policy's distribution, displayed for comparison with
# the manually computed `probs` above.
# NOTE(review): presumably compute_policy performs the same softmax over
# flattened pairwise logits — confirm in ReinforceAgent.
policy.distribution.probs

tensor([0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0015, 0.0015, 0.0015, 