In [99]:
# read input.txt
lines = []  # pre-bound so the name exists even if opening the file fails
with open('input.txt') as input_file:
    # Iterating the handle yields one row per line; strip() drops the
    # trailing newline together with any surrounding whitespace.
    lines = [raw_line.strip() for raw_line in input_file]


In [101]:
import networkx as nx
import matplotlib.pyplot as plt
from functools import reduce
import operator

# create a graph from the input
def toGraph(lines):
    """Build an 8-neighbour grid graph from rows of text.

    Every character becomes a node keyed by ``(row, column)`` whose ``value``
    attribute holds that character. Each node is linked to the cell on its
    left and to the up, up-left and up-right cells of the previous row;
    because the graph is undirected this gives full 8-way adjacency.

    Links to the previous row are only made for columns that actually exist
    in that row. ``add_edge`` would otherwise create an endpoint that has no
    ``value`` attribute, and an upper-right neighbour of a longer previous
    row would be missed.

    Args:
        lines: iterable of strings, one per grid row.

    Returns:
        The populated ``nx.Graph``.
    """
    G = nx.Graph()
    prev_len = 0  # length of the row above; 0 while processing the first row
    for i, line in enumerate(lines):
        for j, c in enumerate(line):
            node = (i, j)
            G.add_node(node, value=c)
            if i > 0 and j < prev_len:  # directly above
                G.add_edge(node, (i - 1, j))
            if j > 0:  # left
                G.add_edge(node, (i, j - 1))
            if i > 0 and 0 < j <= prev_len:  # above-left, i.e. column j-1 exists above
                G.add_edge(node, (i - 1, j - 1))
            if i > 0 and j + 1 < prev_len:  # above-right
                G.add_edge(node, (i - 1, j + 1))
        prev_len = len(line)
    return G

# put digits together
def merge_digit_nodes(G, rows=None):
    """Collapse every horizontal run of two or more digits into one node.

    Scans the text rows, collects the ``(row, column)`` keys of consecutive
    digit characters and hands each run of length > 1 to ``merge_nodes``,
    which folds the run into its first node. Single digits are left as-is.

    Args:
        G: graph whose nodes are keyed by ``(row, column)``; modified in place.
        rows: the text rows G was built from. When omitted, the module-level
            ``lines`` is used, which keeps existing ``merge_digit_nodes(G)``
            calls working.
    """
    if rows is None:
        rows = lines  # backward-compatible fallback to the notebook-level input
    for i, row in enumerate(rows):
        run = []
        for j, c in enumerate(row):
            if c.isdigit():
                run.append((i, j))
                continue
            # A non-digit ends the current run.
            if len(run) > 1:
                merge_nodes(G, run)
            run = []
        # Flush a run that reaches the end of the row.
        if len(run) > 1:
            merge_nodes(G, run)
   
# merge nodes in G into first element of a list of nodes
def merge_nodes(G, nodes):
    """Merge ``nodes`` into the first node of the list.

    The first node survives: its ``value`` becomes the concatenation of all
    the nodes' ``value`` attributes in list order, and it inherits every edge
    the other nodes had to nodes outside the list. The other nodes are
    removed from ``G``.

    The ``nodes`` list itself is left untouched, and an empty list is a no-op.

    Args:
        G: graph to modify in place; every listed node needs a ``value``
            attribute.
        nodes: sequence of node keys, survivor first.
    """
    if not nodes:
        return
    keep, *absorbed = nodes  # unpack instead of pop() so the caller's list is not mutated
    absorbed_set = set(absorbed)
    merged_value = G.nodes[keep]['value']
    for node in absorbed:
        merged_value += G.nodes[node]['value']
        # Snapshot the neighbours: edges are added to G inside the loop.
        for neighbour in list(G.neighbors(node)):
            # Skipping `keep` avoids creating a self-loop on the survivor.
            if neighbour != keep and neighbour not in absorbed_set:
                G.add_edge(keep, neighbour)
        G.remove_node(node)
    G.nodes[keep]['value'] = merged_value

# Build the grid graph from the input rows, then collapse each run of
# adjacent digits into a single node (merge_digit_nodes modifies G in place).
G = toGraph(lines)
merge_digit_nodes(G)

# Part1
# iterate through the graph and collect the values of digit nodes that have at least one neighbour whose value is not "."
def getValues(G):
    """Return the ``value`` strings of digit nodes next to a non-"." node.

    A node qualifies when its own ``value`` is all digits and at least one
    neighbour has a ``value`` other than ".". Any such neighbour counts,
    including another digit node. Results follow ``G.nodes()`` order.
    """
    def touches_non_blank(node):
        # Read every neighbour's value up front (no short-circuit), so each
        # neighbour is inspected regardless of where the first match sits.
        neighbour_values = [G.nodes[other]['value'] for other in G.neighbors(node)]
        return any(value != "." for value in neighbour_values)

    return [
        G.nodes[node]['value']
        for node in G.nodes()
        if G.nodes[node]['value'].isdigit() and touches_non_blank(node)
    ]

# Convert the collected digit strings to integers and report their total.
values = list(map(int, getValues(G)))
print("Part1", sum(values))

# Part2
# find the nodes whose value is "*"; for each one with more than one digit neighbour, multiply those neighbours' values together
def filterGraph(G):
    """Return one product per "*" node that has more than one digit neighbour.

    For every node whose ``value`` is "*", the neighbours with an all-digit
    ``value`` are collected. If there are at least two, their integer values
    are multiplied together and the product is appended to the result.
    Results follow ``G.nodes()`` order.
    """
    products = []
    for candidate in G.nodes():
        if G.nodes[candidate]['value'] != "*":
            continue
        digit_neighbours = [
            adjacent
            for adjacent in G.neighbors(candidate)
            if G.nodes[adjacent]['value'].isdigit()
        ]
        if len(digit_neighbours) <= 1:
            continue
        # Convert only after the count check, as a lone neighbour is never parsed.
        factors = (int(G.nodes[adjacent]['value']) for adjacent in digit_neighbours)
        products.append(reduce(operator.mul, factors))
    return products

# Part 2 answer: the total of all products returned by filterGraph.
print("Part2",sum(filterGraph(G)))

def plotGraph(graph):
    """Draw ``graph`` on a grid layout, save it as ``graph.png`` and show it.

    Node keys are unpacked as ``(row, column)`` pairs and placed at
    ``x = column``, ``y = -row``, so row 0 ends up at the top of the figure.
    Each node is labelled with its ``value`` attribute.
    """
    plt.figure(figsize=(10, 10))

    # Map every node key to its plotting coordinates.
    layout = {}
    for row, col in graph.nodes():
        layout[(row, col)] = (col, -row)

    # Draw the bare structure first, then overlay the value labels.
    nx.draw(graph, layout, with_labels=False)
    node_labels = nx.get_node_attributes(graph, 'value')
    nx.draw_networkx_labels(graph, layout, labels=node_labels)

    plt.savefig('graph.png')
    plt.show()


# plotGraph(G)



Part1 550934
Part2 81997870
