In [1]:
import numpy as np

In [94]:

# Example score matrix: scores[head][dependent]; column 0 is -inf so node 0 has
# no incoming arcs, and the diagonal is -inf (no self-loops).
# Expected decoding: [[0, 1], [1, 3], [3, 2]].
# NOTE: the next cell overwrites `scores`.
# np.inf (lowercase) is the canonical name; np.Inf was removed in NumPy 2.0.
scores = np.array([
    [-np.inf, 10, 3, 5],
    [-np.inf, -np.inf, 10, 8],
    [-np.inf, 1, -np.inf, 10],
    [-np.inf, 5, 20, -np.inf]
])


In [141]:
# Score matrix with a 2-cycle (1 <-> 2) among the greedy best heads.
# Expected decoding: [[2, 3], [0, 2], [2, 1]].
# np.inf (lowercase) is the canonical name; np.Inf was removed in NumPy 2.0.
scores = np.array([
    [-np.inf, 9, 10, 9],
    [-np.inf, -np.inf, 20, 3],
    [-np.inf, 30, -np.inf, 30],
    [-np.inf, 11, 0, -np.inf]
])

In [142]:
def matrix_to_graph(scores):
    '''Convert a score matrix into an adjacency dict {head: {dependent: score}}.

    scores[i][j] is read as the score of the arc i -> j. Entries that are not
    >= 0 (e.g. -inf or NaN, but also negative finite values) are treated as
    missing arcs. Rows left without any arc are omitted from the result.
    '''
    graph = {}
    for head, row in enumerate(scores):
        kept = {}
        for dep, value in enumerate(row):
            if value >= 0:
                kept[dep] = value
        # Only heads with at least one outgoing arc get an entry.
        if kept:
            graph[head] = kept
    return graph

# Forward view {head: {dependent: score}} and, via the transpose, reverse view
# {dependent: {head: score}} of the same score matrix.
graph, reverse_graph = matrix_to_graph(scores), matrix_to_graph(scores.T)

In [143]:
# Incoming arcs per dependent, {dependent: {head: score}}, built from scores.T.
reverse_graph

{1: {0: 9.0, 2: 30.0, 3: 11.0},
 2: {0: 10.0, 1: 20.0, 3: 0.0},
 3: {0: 9.0, 1: 3.0, 2: 30.0}}

In [144]:
# Greedy step: for every dependent, pick the head of its highest-scoring
# incoming arc (ties go to the first head in dict order, as with max()).
max_index_pairs = {
    dependent: max(incoming, key=incoming.get)
    for dependent, incoming in reverse_graph.items()
}
index_pairs = max_index_pairs  # the same dict under both names
max_index_pairs

{1: 2, 2: 1, 3: 2}

In [145]:
def find_cycle(index_pairs):
    '''Find one cycle in a {dependent: head} mapping.

    Starting from each dependent in turn, follow dependent -> head links until
    the walk either leaves the mapping (e.g. reaches the root, which has no
    entry) or revisits a node. A revisit means the revisited part of the walk
    is a cycle. That includes cycles the starting node merely leads into.

    Args:
        index_pairs: {dependent: head}, e.g. the greedy best head per dependent.

    Returns:
        {dependent: head} restricted to the cycle's nodes, ordered from the node
        where the walk entered the cycle, or None if there is no cycle.
    '''
    # Nodes whose walk already ended without a cycle: reaching one of them
    # means the current walk cannot end in a cycle either.
    acyclic = set()
    for start in index_pairs:
        if start in acyclic:
            continue
        path = [start]
        position = {start: 0}  # node -> index in `path`, for O(1) revisit checks
        current = start
        # Membership test (not truthiness), so a head of 0 is not special.
        while current in index_pairs:
            head = index_pairs[current]
            if head in position:
                # Everything from the first visit of `head` onwards is the cycle.
                return {dep: index_pairs[dep] for dep in path[position[head]:]}
            if head in acyclic:
                break
            position[head] = len(path)
            path.append(head)
            current = head
        acyclic.update(path)
    return None

# Look for a cycle among the greedily chosen {dependent: head} pairs.
cycle = find_cycle(max_index_pairs)
cycle

{1: 2, 2: 1}

In [146]:
def resolve(graph, cycle):
    '''Contract `cycle` in `graph` into a single node (one Chu-Liu-Edmonds step).

    Args:
        graph: {head: {dependent: score}}, as built by matrix_to_graph.
        cycle: {dependent: head} pairs forming a cycle, as returned by find_cycle.

    Returns:
        (new_graph, mapping), where:
        - new_graph is `graph` with every cycle member replaced by one node
          whose index is min(cycle).
          - An arc from head h into the contracted node gets the best value of
            score(h, d) + s(C) - score(cycle_head(d), d) over cycle members d,
            where s(C) is the total score of the cycle's arcs.
          - Arcs leaving the contracted node keep, per target, the highest
            score among the cycle members' arcs to that target.
          - Arcs between cycle members are dropped.
        - mapping[h] is the cycle member d that achieved that best value for
          head h, or -1 when h has no arc into the cycle. It has an entry for
          every head in `graph`, including the cycle members themselves.

    `graph` is not modified.

    NOTE(review): the cycle member that each outgoing arc of the contracted
    node came from is not recorded, so it cannot be recovered from the result.
    '''
    # `cycle` maps dependent -> head, and rows of `graph` are heads, so the
    # cycle arc entering `dep` is graph[head][dep].
    sum_cycle_scores = sum(graph[head][dep] for dep, head in cycle.items())
    # Score of the cycle once the arc entering `dep` is broken.
    cycle_scores = {dep: sum_cycle_scores - graph[head][dep] for dep, head in cycle.items()}

    # Index that will represent the newly created (contracted) node.
    cycle_node_index = min(cycle_scores)

    new_graph = {}
    mapping = {}  # head -> cycle member targeted by its best arc into the cycle
    for node, arcs in graph.items():
        new_nodes = {}
        best_val = None  # None: no arc from `node` into the cycle seen yet
        best_dep = -1
        for dep, score in arcs.items():
            if dep in cycle_scores:
                val = score + cycle_scores[dep]
                if best_val is None or val > best_val:
                    best_val, best_dep = val, dep
            else:
                new_nodes[dep] = score
        # Set once, after every arc was scanned, so this entry is never read
        # back as an original arc into node `cycle_node_index`.
        if best_val is not None:
            new_nodes[cycle_node_index] = best_val
        mapping[node] = best_dep
        new_graph[node] = new_nodes

    # Merge the cycle members' outgoing arcs into the contracted node,
    # keeping the highest score per target.
    new_node = {}
    for node in cycle:
        for dep, score in new_graph.pop(node, {}).items():
            if dep in cycle:
                continue  # arc inside the contracted node
            if dep not in new_node or new_node[dep] < score:
                new_node[dep] = score

    new_graph[cycle_node_index] = new_node
    return new_graph, mapping

In [147]:
# Outgoing arcs per head, {head: {dependent: score}}, i.e. scores[head][dependent].
graph

{0: {1: 9.0, 2: 10.0, 3: 9.0},
 1: {2: 20.0, 3: 3.0},
 2: {1: 30.0, 3: 30.0},
 3: {1: 11.0, 2: 0.0}}

In [148]:
# Contract the detected cycle into one node; `mapping` records, per head, which
# cycle member its best arc into the cycle pointed at.
new_graph, mapping = resolve(graph, cycle)

In [149]:
# Hand-written decoding of the contracted graph, apparently as
# {head: dependent}: 0 -> 1 (the contracted cycle node) and 1 -> 3.
# The earlier `y = {1: 0, 2: 1}` draft was a dead store and has been dropped.
# NOTE(review): a dict keyed by head can hold only one dependent per head.
y = {0: 1, 1: 3}

In [150]:
# find_cycle's result, as {dependent: head} pairs.
cycle

{1: 2, 2: 1}

In [151]:
# Per head: the cycle member targeted by its best arc into the cycle
# (-1 if the head has no arc into the cycle).
mapping

{0: 2, 1: 2, 2: 1, 3: 1}

In [154]:
def resolve_cycle(new_graph, cycle):
    '''Prototype: splice a contracted cycle back into a decoded tree.

    Args:
        new_graph: decoded tree of the contracted graph as a dict; each
            (key, value) item is emitted as a [key, value] pair.
        cycle: the contracted cycle as {dependent: head}.

    Walks the cycle members in order and stops at the first member whose
    cycle head also appears as a value in `new_graph`. That member's cycle
    pair is dropped, and the remaining cycle pairs are appended to the pairs
    of `new_graph`. Returns None if no member matches.

    NOTE(review): the `mapping` produced by resolve() is not used, so the
    contracted node is never replaced by the real cycle member. On the
    notebook's example this gives [[0, 1], [1, 3], [1, 2]], while decoder.CLE
    reports [[2, 3], [0, 2], [2, 1]] for the same scores.
    '''
    for member, member_head in cycle.items():
        matches = any(member_head == new_graph[key] for key in new_graph)
        if not matches:
            continue
        kept_cycle_pairs = [[dep, head] for dep, head in cycle.items() if dep != member]
        tree_pairs = [list(item) for item in new_graph.items()]
        return tree_pairs + kept_cycle_pairs
    return None

# NOTE(review): decoder.CLE returns [[2, 3], [0, 2], [2, 1]] for these scores
# (see its verbose run further down), so this prototype's output differs.
resolve_cycle(y, cycle)

[[0, 1], [1, 3], [1, 2]]

In [155]:
# Unchanged: resolve_cycle only reads its arguments.
y

{0: 1, 1: 3}

In [7]:
import numpy as np

# Expected decoding: [[0, 3], [3, 2], [2, 1]]
# np.inf (lowercase) is the canonical name; np.Inf was removed in NumPy 2.0.
scores = np.array([
    [-np.inf, 10, 5, 15],
    [-np.inf, -np.inf, 20, 15],
    [-np.inf, 25, -np.inf, 25],
    [-np.inf, 30, 10, -np.inf]
])

In [5]:
# Expected decoding: [[0, 1], [1, 3], [3, 2]]
# np.inf (lowercase) is the canonical name; np.Inf was removed in NumPy 2.0.
scores = np.array([
    [-np.inf, 10, 3, 5],
    [-np.inf, -np.inf, 10, 8],
    [-np.inf, 1, -np.inf, 10],
    [-np.inf, 5, 20, -np.inf]
])

In [3]:
# Expected decoding: [[0, 2], [2, 3], [3, 1]]
# np.inf (lowercase) is the canonical name; np.Inf was removed in NumPy 2.0.
scores = np.array([
    [-np.inf, 3, 10, 5],
    [-np.inf, -np.inf, 1, 10],
    [-np.inf, 10, -np.inf, 8],
    [-np.inf, 20, 5, -np.inf],
])

In [1]:
import numpy as np

# Expected decoding: [[2, 3], [0, 2], [2, 1]]
# np.inf (lowercase) is the canonical name; np.Inf was removed in NumPy 2.0.
scores = np.array([
    [-np.inf, 9, 10, 9],
    [-np.inf, -np.inf, 20, 3],
    [-np.inf, 30, -np.inf, 30],
    [-np.inf, 11, 0, -np.inf]
])

In [2]:
from decoder import CLE

# Decode the current `scores` with logging enabled, to trace each CLE step.
cle = CLE(verbose=True)
decoded_tree = cle.decode(scores)
print(decoded_tree)


[decode] Max index pairs: {1: 2, 2: 1, 3: 2}
[decode] Cycle found: {1: 2, 2: 1}
[resolve] New updated graph:
{0: {1: 29.0, 2: 40.0, 3: 9.0},
 1: {2: 20.0, 3: 3.0},
 2: {1: 30.0, 3: 30.0},
 3: {1: 31.0, 2: 30.0}}
[resolve] Starting to create new graph.
[resolve] New created graph:
{0: {1: 40.0, 3: 9.0}, 1: {3: 30.0}, 3: {1: 31.0}}
[decode] Max index pairs: {3: 1, 1: 0}
[decode] Cycle found: None
[decode] No cycle found. Returning max index pairs.
[decode] New graph
[[1, 3], [0, 1]]
[decode] New (resolved) graph
[[2, 3], [0, 2]]
[resolve cycle] Resolved cycle [[2, 1]]
[resolve cycle] Resolved graph
[[2, 3], [0, 2], [2, 1]]
[[2, 3], [0, 2], [2, 1]]


In [14]:
# Stress test: decode many random score matrices. On the first failure, replay
# that matrix with a verbose decoder so its step log is printed before the
# error propagates.
SEED = 0         # fixed so a failing matrix can be reproduced
N_TRIALS = 1000
N_NODES = 100
MAX_SCORE = 30   # scores are drawn from the integers in [0, MAX_SCORE)

rng = np.random.default_rng(SEED)
cle = CLE(verbose=False)

for _ in range(N_TRIALS):
    scores = rng.integers(0, MAX_SCORE, (N_NODES, N_NODES)).astype(float)
    # np.inf (lowercase): np.Inf was removed in NumPy 2.0.
    scores[:, 0] = -np.inf                          # no arcs into node 0
    scores[np.diag_indices_from(scores)] = -np.inf  # no self-loops

    try:
        cle.decode(scores)
    except Exception:  # not bare `except:`, so Ctrl-C still stops the loop
        # Replay verbosely; presumably this raises the same error again.
        cle = CLE(verbose=True)
        cle.decode(scores)
        break

[decode] Max index pairs: {1: 13, 2: 5, 3: 43, 4: 18, 5: 17, 6: 26, 7: 29, 8: 70, 9: 11, 10: 70, 11: 5, 12: 11, 13: 30, 14: 20, 15: 3, 16: 26, 17: 5, 18: 11, 19: 2, 20: 5, 21: 22, 22: 73, 23: 33, 24: 6, 25: 6, 26: 72, 27: 6, 28: 11, 29: 6, 30: 6, 31: 45, 32: 17, 33: 84, 34: 57, 35: 26, 36: 4, 37: 0, 38: 24, 39: 12, 40: 14, 41: 22, 42: 15, 43: 12, 44: 1, 45: 35, 46: 4, 47: 65, 48: 32, 49: 25, 50: 73, 51: 5, 52: 7, 53: 26, 54: 2, 55: 70, 56: 11, 57: 30, 58: 38, 59: 13, 60: 2, 61: 8, 62: 11, 63: 54, 64: 88, 65: 1, 66: 50, 67: 2, 68: 7, 69: 9, 70: 8, 71: 21, 72: 25, 73: 29, 74: 8, 75: 53, 76: 41, 77: 25, 78: 41, 79: 87, 80: 3, 81: 43, 82: 7, 83: 49, 84: 31, 85: 43, 86: 42, 87: 1, 88: 11, 89: 15, 90: 32, 91: 23, 92: 27, 93: 20, 94: 10, 95: 3, 96: 3, 97: 8, 98: 15, 99: 44}
[decode] Cycle found: {5: 17, 17: 5}
[resolve] New updated graph:
{0: {1: 13.0,
     2: 18.0,
     3: 2.0,
     4: 21.0,
     5: 57.0,
     6: 13.0,
     7: 11.0,
     8: 1.0,
     9: 22.0,
     10: 6.0,
     11: 2.0,
    

KeyError: 2