# Christofides algorithm

In [2]:
# step 1: Create a minimum spanning tree T of G.

# step 1.0: build a complete graph (adjacency dict) from the distance matrix
def build_graph(dist_matrix_):
    """Convert a square distance matrix into a complete-graph adjacency dict.

    The result maps every index i to {j: dist_matrix_[i][j]} for all j != i,
    so the matrix diagonal is never read.
    """
    size = len(dist_matrix_)
    return {
        row: {col: dist_matrix_[row][col] for col in range(size) if col != row}
        for row in range(size)
    }


# step 1.1: Prim's algorithm to find a minimum spanning tree
import heapq
def prim(graph, start_node=None):
    """Lazy Prim's algorithm over an adjacency dict {node: {neighbor: weight}}.

    Grows a tree from `start_node` (or from the first key of `graph` when
    start_node is None) and returns its edges as (from_node, to_node, weight)
    tuples, in the order they were added.
    """
    root = next(iter(graph)) if start_node is None else start_node
    in_tree = {root}
    tree = []
    frontier = []  # min-heap of (weight, from_node, to_node) candidate edges

    def push_edges(node):
        # queue every edge leaving `node` toward a vertex not yet in the tree
        for nbr, w in graph[node].items():
            if nbr not in in_tree:
                heapq.heappush(frontier, (w, node, nbr))

    push_edges(root)
    while frontier and len(in_tree) < len(graph):
        w, src, dst = heapq.heappop(frontier)
        if dst in in_tree:
            continue  # stale entry: dst already joined via an earlier-popped edge
        in_tree.add(dst)
        tree.append((src, dst, w))
        push_edges(dst)

    return tree

# testing

# graph = {
#     'A' : {'B': 2, 'C': 3},
#     'B' : {'A': 2, 'C': 1, 'D': 1}, 
#     'C' : {'A': 3, 'B': 1, 'D': 4},
#     'D' : {'B': 1, 'C': 4}
# }
# minimum_spanning_tree = prim(graph)
# print("mst: ")
# for edge in minimum_spanning_tree:
#     print(edge)
# # type(minimum_spanning_tree[0])
    
# graph = build_graph(tiny_dist_matrix)
# graph
# minimum_spanning_tree = prim(graph)
# print("mst: ")
# for edge in minimum_spanning_tree:
#     print(edge)

In [3]:
# step 2: Let O be the set of vertices with odd degree in T.
# By the handshaking lemma, O has an even number of vertices.


def find_odd_degree_vertices(mst):
    """Return the vertices with odd degree in the edge list `mst`.

    `mst` is a list of (vertex_1, vertex_2, distance) tuples. Vertices are
    returned in order of their first appearance in the edge list.
    """
    degree = {}
    for a, b, _ in mst:
        for endpoint in (a, b):
            degree[endpoint] = degree.get(endpoint, 0) + 1
    return [vertex for vertex, count in degree.items() if count % 2 == 1]

# testing:
# example_mst = [
#     ('A', 'B', 2),
#     ('B', 'C', 1),
#     ('B', 'D', 1),
# ]
# odd_nodes = find_odd_degree_vertices(example_mst)
# print("Nodes with odd degree:", odd_nodes)
    

In [4]:
# step 3: Find a minimum-weight perfect matching M in the subgraph induced in G by O.

# exact bitmask DP: O(2^k * k) time and O(2^k) memory for k odd vertices, so only practical for small k
def find_minimum_perfect_matching_brute_force(distance_matrix, odd_nodes):
    """Exact minimum-weight perfect matching on `odd_nodes` via bitmask DP.

    Parameters
    ----------
    distance_matrix : indexable as distance_matrix[u][v]
        Pairwise distances between vertices.
    odd_nodes : sequence
        Vertices to match (indices into distance_matrix); must have even length.

    Returns
    -------
    list of (u, v, distance) tuples that covers every vertex of odd_nodes
    exactly once with minimum total distance.

    Raises
    ------
    ValueError
        If len(odd_nodes) is odd, or no finite-cost perfect matching exists.

    Runs in O(2^k * k) time and O(2^k) memory for k = len(odd_nodes), so it
    is only usable for small k.
    """
    INF = float('inf')
    k = len(odd_nodes)
    if k % 2:
        raise ValueError(f"a perfect matching needs an even number of vertices, got {k}")

    # 1) k x k distance table restricted to the odd nodes
    sub_dist = [[distance_matrix[u][v] for v in odd_nodes] for u in odd_nodes]

    # 2) dp[mask] = minimum cost to perfectly match the vertices whose bits are set in mask
    full_mask = (1 << k) - 1
    dp = [INF] * (full_mask + 1)
    dp[0] = 0  # nothing left to match -> zero cost
    pair_for_mask = {}  # mask -> (i, j) pair chosen at that state, for reconstruction

    for mask in range(1, full_mask + 1):
        if bin(mask).count('1') % 2:
            continue  # an odd-sized subset has no perfect matching
        # The lowest set bit must be paired with someone; fixing it avoids
        # enumerating the same matching several times.
        i = (mask & -mask).bit_length() - 1
        remaining = mask ^ (1 << i)
        for j in range(i + 1, k):
            if remaining & (1 << j):
                cost = sub_dist[i][j] + dp[remaining ^ (1 << j)]
                if cost < dp[mask]:
                    dp[mask] = cost
                    pair_for_mask[mask] = (i, j)

    if dp[full_mask] == INF:
        raise ValueError("no finite-cost perfect matching exists for odd_nodes")

    # 3) walk back from the full mask, peeling off the chosen pair each step
    matching_edges = []
    mask = full_mask
    while mask:
        i, j = pair_for_mask[mask]
        matching_edges.append((odd_nodes[i], odd_nodes[j], sub_dist[i][j]))
        mask ^= (1 << i) | (1 << j)

    return matching_edges

# dist_test = [
#     [0, 2, 3, 100], 
#     [2, 0, 1, 1],
#     [3, 1, 0, 4],
#     [100, 1, 4, 0]
# ]
# odd_nodes = [0, 1, 2, 3]
# perfect_matching = find_minimum_perfect_matching_brute_force(dist_test, odd_nodes)

In [5]:
import networkx as nx

# method 1: using networkx, which implements Edmonds’ blossom algorithm (exact)
def find_minimum_perfect_matching_networkx(dist_matrix_, odd_nodes):
    """Exact minimum-weight perfect matching on `odd_nodes` using networkx.

    Builds the complete graph on odd_nodes with *negated* distances as edge
    weights. With maxcardinality=True, networkx's max_weight_matching returns
    the heaviest matching among those of maximum cardinality. On a complete
    graph over an even number of vertices, that is a perfect matching of
    minimum total distance.

    Returns a list of (u, v, distance) tuples. networkx returns the matching
    as a set, so the order of the edges carries no meaning.
    """
    G = nx.Graph()
    for pos, u in enumerate(odd_nodes):
        for v in odd_nodes[pos + 1:]:
            G.add_edge(u, v, weight=-dist_matrix_[u][v])

    matching = nx.algorithms.matching.max_weight_matching(G, maxcardinality=True)
    return [(u, v, dist_matrix_[u][v]) for u, v in matching]

In [6]:
# method 2: greedy nearest-neighbour matching (fast, but not guaranteed to be minimum-weight)
def find_minimum_perfect_matching_approximate(dist_matrix_, odd_nodes):
    """Greedy (nearest-neighbour) perfect matching on `odd_nodes`.

    Repeatedly takes an arbitrary unmatched vertex and pairs it with its
    closest remaining unmatched vertex. This is O(k^2) for k distinct nodes
    and is NOT guaranteed to be minimum-weight. Using it inside Christofides
    therefore gives up the algorithm's 1.5-approximation guarantee.

    Parameters
    ----------
    dist_matrix_ : indexable as dist_matrix_[u][v]
    odd_nodes : iterable of vertices (duplicates are ignored)

    Returns
    -------
    list of (u, v, distance) tuples.

    Raises
    ------
    ValueError
        If the number of distinct vertices in odd_nodes is odd.
    """
    unmatched = set(odd_nodes)
    if len(unmatched) % 2:
        raise ValueError(
            f"a perfect matching needs an even number of vertices, got {len(unmatched)}"
        )

    matching_edges = []
    while unmatched:
        u = unmatched.pop()
        # first closest vertex in iteration order (same tie-breaking as a strict '<' scan)
        v = min(unmatched, key=lambda w: dist_matrix_[u][w])
        unmatched.remove(v)
        matching_edges.append((u, v, dist_matrix_[u][v]))

    return matching_edges

In [7]:
# step 4: Combine the edges of M and T to form a connected multigraph H in which each vertex has even degree.

def union_mst_and_matching(mst, perfect_matching):
    """Concatenate MST edges and matching edges into one multigraph edge list.

    Duplicate edges are kept (the result is a multigraph). A new list is
    returned and neither input is modified.
    """
    multigraph_edges = mst + perfect_matching
    return multigraph_edges

# test_dist = [
#     [0, 2, 3, 100], 
#     [2, 0, 1, 1],
#     [3, 1, 0, 4],
#     [100, 1, 4, 0]
# ]
# test_graph = build_graph(test_dist)
# test_mst = prim(test_graph)
# test_odd = find_odd_degree_vertices(test_mst)
# test_odd_matching = find_minimum_perfect_matching_brute_force(test_dist, test_odd)
# print(test_mst, test_odd_matching)
# test_union = union_mst_and_matching(test_mst, test_odd_matching)
# print(test_union)

In [8]:
def find_euler_tour(edge_list):
    """Walk an Eulerian circuit of a multigraph and shortcut it into a tour.

    Parameters
    ----------
    edge_list : list of tuples
        Edges of the multigraph. Each edge is (u, v) or (u, v, weight); only
        the first two entries are used. Parallel edges are kept.

    Returns
    -------
    list
        The circuit found by Hierholzer's algorithm, passed through
        ``remove_repeated_vertices``. The circuit starts at the first
        endpoint of the first edge. The result lists each vertex once, in
        order of first visit, then the start vertex again. An empty edge
        list yields an empty list.

    NOTE(review): assumes every vertex has even degree and the multigraph is
    connected (true for MST + perfect matching); neither is checked here.
    """
    # undirected adjacency lists; a parallel edge appears once per copy
    adjacency = {}
    for edge in edge_list:
        u, v = edge[0], edge[1]
        adjacency.setdefault(u, []).append(v)
        adjacency.setdefault(v, []).append(u)

    if not adjacency:
        return []  # no edges -> no tour (previously raised StopIteration)

    start_node = next(iter(adjacency))

    # Hierholzer's algorithm: follow unused edges, emit vertices on backtrack
    stack = [start_node]
    path = []
    while stack:
        curr_node = stack[-1]
        if adjacency[curr_node]:
            # still has an unused edge: walk it and delete its mirror entry
            next_node = adjacency[curr_node].pop()
            adjacency[next_node].remove(curr_node)
            stack.append(next_node)
        else:
            # dead end: this vertex is finished, record it and backtrack
            path.append(stack.pop())

    path.reverse()

    return remove_repeated_vertices(path)


def remove_repeated_vertices(euler_path):
    """Shortcut a walk into a closed tour.

    Keeps only the first occurrence of each vertex, then appends the first
    vertex again to close the cycle: A->B->A->C becomes A->B->C->A.
    An empty walk gives an empty list.
    """
    # dict.fromkeys preserves insertion order, so first occurrences survive
    tour = list(dict.fromkeys(euler_path))
    if tour:
        tour.append(tour[0])
    return tour


# testing on dataset

In [10]:
import pandas as pd
import numpy as np
import time

def load_cities(csv_path):
    """Read a headerless CSV of x, y coordinates into an (n, 2) numpy array."""
    # names=['x', 'y'] labels the columns; selecting them keeps exactly two
    return pd.read_csv(csv_path, header=None, names=['x', 'y'])[['x', 'y']].values


tiny_cities = load_cities('tiny.csv')
small_cities = load_cities('small.csv')
medium_cities = load_cities('medium.csv')
large_cities = load_cities('large.csv')

In [11]:
import math

# function to compute the Euclidean distance matrix
# @param:  an (n, d) array of city coordinates
# @return: an n x n distance matrix (diagonal set to inf)

def find_dist_matrix(cities):
    """Return the n x n Euclidean distance matrix for `cities`.

    Parameters
    ----------
    cities : array-like of shape (n, d), or (n,) for 1-D points
        Numpy arrays and plain lists of tuples are both accepted.

    Returns
    -------
    numpy.ndarray of float, shape (n, n). The matrix is symmetric and its
    diagonal is set to inf rather than 0.

    Vectorized with broadcasting instead of a Python double loop. It builds
    an (n, n, d) intermediate array, so memory use is O(n^2 * d).
    """
    points = np.asarray(cities, dtype=float)
    if points.ndim == 1:
        points = points[:, None]  # scalar coordinates -> 1-D points
    # diff[i, j] = points[i] - points[j]
    diff = points[:, None, :] - points[None, :, :]
    dist_matrix_ = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(dist_matrix_, math.inf)
    return dist_matrix_

In [12]:
tiny_dist_matrix = find_dist_matrix(tiny_cities)
small_dist_matrix = find_dist_matrix(small_cities)
medium_dist_matrix = find_dist_matrix(medium_cities)
large_dist_matrix = find_dist_matrix(large_cities)

# sanity check: one row per city in each dataset
print(*(len(m) for m in (tiny_dist_matrix, small_dist_matrix, medium_dist_matrix, large_dist_matrix)))

10 30 100 1000


In [13]:
# from IPython.display import display
# display(pd.read_csv('medium.csv', header = None, names = ['x', 'y'])[['x', 'y']])
# display(medium_dist_matrix)

# df = pd.DataFrame(medium_dist_matrix)

# # in a notebook this will give you the familiar grid display
# display(df)

In [14]:
def solve_tsp_christophdes(dist_matrix_, dataset_name, exact_flag = False):
    """Run the Christofides pipeline on a distance matrix and print a report.

    Steps: MST (Prim) -> odd-degree vertices -> perfect matching on them ->
    union into a multigraph -> Eulerian circuit shortcut into a tour.

    Parameters
    ----------
    dist_matrix_ : square matrix indexable as dist_matrix_[u][v]
    dataset_name : str
        Used only in the printed report.
    exact_flag : bool
        True uses the exact networkx (blossom) matching. False uses the
        greedy matching, which is faster but not minimum-weight, so the
        standard Christofides 3/2 bound no longer applies.

    Returns
    -------
    (final_path, min_dist)
        The closed tour (the start vertex is repeated at the end) and its
        total length. This is an approximate tour, not a proven optimum,
        despite the "Best"/"Minimum" wording in the printed report.

    The reported time covers tour construction only, not the cost sum.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time()
    start_time = time.perf_counter()

    graph = build_graph(dist_matrix_)

    print("prim...")
    mst = prim(graph)

    print("odds...")
    odds = find_odd_degree_vertices(mst)

    print("odds.matching...")
    if exact_flag:
        odds_matching = find_minimum_perfect_matching_networkx(dist_matrix_, odds)
    else:
        odds_matching = find_minimum_perfect_matching_approximate(dist_matrix_, odds)

    print("union..")
    union = union_mst_and_matching(mst, odds_matching)

    print("eulerian_tour with repeadted remove...")
    final_path = find_euler_tour(union)

    elapsed = time.perf_counter() - start_time

    print("total cost...")
    # sum the edge lengths between consecutive vertices of the closed tour
    min_dist = 0.0
    for u, v in zip(final_path, final_path[1:]):
        min_dist += dist_matrix_[u][v]

    print(f"Using Christofides Algorithm on {dataset_name}:")
    print(f"Time used: {elapsed:.4f} seconds")
    print(f"Best tour: {final_path}")
    print(f"Minimum distance: {min_dist}\n\n")

    return final_path, min_dist

In [15]:
tiny_result = solve_tsp_christophdes(tiny_dist_matrix, dataset_name="tiny dataset", exact_flag=False)


prim...
odds...
odds.matching...
union..
eulerian_tour with repeadted remove...
total cost...
Using Christofides Algorithm on tiny dataset:
Time used: 0.0001 seconds
Best tour: [0, 4, 9, 2, 5, 1, 8, 7, 6, 3, 0]
Minimum distance: 14.740072636356153




In [16]:
small_result = solve_tsp_christophdes(small_dist_matrix, dataset_name="small dataset", exact_flag=False)


prim...
odds...
odds.matching...
union..
eulerian_tour with repeadted remove...
total cost...
Using Christofides Algorithm on small dataset:
Time used: 0.0010 seconds
Best tour: [0, 19, 20, 13, 4, 28, 3, 7, 25, 15, 12, 24, 18, 5, 23, 2, 1, 26, 14, 11, 10, 27, 8, 17, 16, 21, 22, 29, 9, 6, 0]
Minimum distance: 70.17212815436365




In [17]:
# run both matching strategies; keep each result instead of overwriting the first
medium_result_exact = solve_tsp_christophdes(medium_dist_matrix, "medium dataset", exact_flag=True)
medium_result = solve_tsp_christophdes(medium_dist_matrix, "medium dataset", exact_flag=False)

prim...
odds...
odds.matching...
union..
eulerian_tour with repeadted remove...
total cost...
Using Christofides Algorithm on medium dataset:
Time used: 0.0565 seconds
Best tour: [0, 58, 64, 70, 80, 7, 49, 90, 13, 4, 9, 15, 34, 33, 42, 37, 68, 83, 44, 95, 3, 75, 46, 97, 60, 1, 55, 63, 88, 54, 27, 12, 5, 78, 11, 43, 79, 57, 16, 71, 77, 53, 69, 45, 82, 74, 59, 86, 29, 67, 25, 72, 41, 6, 93, 96, 51, 39, 26, 62, 22, 91, 14, 65, 38, 87, 73, 18, 20, 81, 84, 17, 61, 40, 52, 76, 92, 35, 21, 94, 30, 19, 56, 66, 23, 48, 24, 32, 28, 89, 31, 47, 36, 50, 2, 10, 85, 8, 99, 98, 0]
Minimum distance: 8.560429789680372


prim...
odds...
odds.matching...
union..
eulerian_tour with repeadted remove...
total cost...
Using Christofides Algorithm on medium dataset:
Time used: 0.0062 seconds
Best tour: [0, 58, 64, 70, 42, 68, 83, 44, 95, 3, 75, 46, 97, 60, 82, 45, 69, 53, 77, 71, 16, 57, 11, 37, 12, 27, 54, 88, 63, 55, 1, 43, 79, 78, 5, 33, 34, 15, 9, 90, 13, 4, 49, 7, 80, 38, 18, 81, 20, 87, 73, 14, 65, 91, 

In [18]:
# run both matching strategies; keep each result instead of overwriting the first
large_result_exact = solve_tsp_christophdes(large_dist_matrix, "large dataset", exact_flag=True)
large_result = solve_tsp_christophdes(large_dist_matrix, "large dataset", exact_flag=False)

prim...
odds...
odds.matching...
union..
eulerian_tour with repeadted remove...
total cost...
Using Christofides Algorithm on large dataset:
Time used: 41.5300 seconds
Best tour: [0, 924, 118, 92, 702, 891, 664, 227, 750, 906, 977, 903, 515, 478, 132, 855, 947, 212, 60, 467, 476, 829, 394, 754, 59, 135, 462, 365, 809, 181, 237, 318, 553, 639, 89, 328, 390, 748, 351, 950, 993, 741, 892, 251, 607, 638, 72, 136, 916, 934, 979, 923, 781, 701, 185, 147, 81, 510, 973, 432, 217, 165, 595, 245, 874, 198, 816, 431, 632, 99, 151, 659, 357, 887, 511, 62, 483, 458, 695, 752, 936, 291, 179, 61, 334, 28, 285, 736, 742, 438, 751, 347, 319, 545, 912, 526, 14, 958, 186, 763, 339, 383, 166, 706, 919, 877, 509, 480, 811, 277, 7, 176, 299, 447, 238, 464, 187, 69, 46, 130, 578, 787, 117, 398, 806, 297, 669, 213, 579, 149, 783, 943, 29, 444, 794, 53, 729, 691, 376, 770, 824, 93, 16, 992, 825, 490, 712, 140, 489, 456, 486, 471, 975, 939, 866, 233, 637, 670, 937, 634, 896, 761, 108, 884, 549, 656, 861, 830, 8