In [5]:
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

# Load all edges and put them in a numpy array, taking only 'Source' and 'Target'
edges = pd.read_csv(r"../data/edges_updated_reversed.csv", sep=' ')
edges_array = np.array(edges.loc[:, ['Source', 'Target']])

# Create the digraph
G = nx.DiGraph()
G.add_edges_from(edges_array)

# Parameters that virtually don't effect the simulation, please adjust to your preference :)
doSingletonReduction = True
doVisualizeSingletonReduction = True
doVisualization = True

# If you're not reducing singletons, there is no point in visualizing them
if not doSingletonReduction:
    doVisualizeSingletonReduction = False

"""    
Singletons are nodes that have in-degree of 1 and out-degree of 0. 
This means, in an SIR model, a singleton doesn't infect, but is only infected, then after can recover.
Because of this, we can essentially reduce singletons to a number stored in its parent node, reducing the size of our network by 20%.
"""
if doSingletonReduction:
    # Singleton information stored in parent.
    nx.set_node_attributes(G, 0, "s_singletons") # Susceptible singletons
    nx.set_node_attributes(G, 0, "i_singletons") # Infected singletons
    nx.set_node_attributes(G, 0, "r_singletons") # Recovered/removed singletons

    # Find and store singletons
    singletons = G.out_degree
    singletons = [x[0] for x in singletons if x[1] == 0]
    singletons = [x for x in singletons if G.in_degree(x) == 1]

    # Find nodes that are not singletons
    non_singletons = [x for x in list(G.nodes) if not x in singletons]

    # Store singletons as numbers
    for node in non_singletons:
        adj = G.neighbors(node)
        G.nodes[node]['s_singletons'] = len([x for x in adj if x in singletons])

    # Remove nodes that are singletons
    print("Node count before singleton reduction: ", len(G.nodes()))
    G.remove_nodes_from(singletons)

    # > Remove nodes that don't have singletons
    # print(len(G.nodes()))
    # not_have_singletons = [x[0] for x in list(G.nodes.data('s_singletons')) if x[1] == 0]
    # have_singletons = [x[0] for x in list(G.nodes.data('s_singletons')) if x[1] == 1]
    # G.remove_nodes_from(not_have_singletons)

Node count before singleton reduction:  26234


In [6]:
# > Initialize layout and position of nodes for visualization (only have to run this only once, this can take a while)
if doVisualization:
    print("Calculating node layout and positions...")
    pos = nx.spring_layout(G, seed=8020, gravity=0.75)
    # pos = nx.forceatlas2_layout(G, seed=6969420, gravity=0.5, scaling_ratio=4.0, max_iter=50)
    print("Done!")

Calculating node layout and positions...
Done!


In [7]:
# Initialize nodes and perform any required analysis on topography, centrality, etc.


print("Final node count: ", len(G.nodes()))

# ------------------------------------------------------------------------------------------------------------

nodes_outdeg_sorted = sorted(list(G.out_degree), key=lambda x:x[1], reverse=True)


Final node count:  19616


In [10]:
# > Initialize model
print("Number of nodes: ", len(G.nodes()))

nx.set_node_attributes(G, 0, "state")
"""
When state = 0, a node is susceptible.
When state = 1, a node is infected.
When state = 2, a node is recovered/removed.
"""

# > Initially infect 1 node
# init_infected = np.random.choice(G.nodes, 1)[0]
# init_infected = np.random.choice(nodes_outdeg_sorted, 1)[0]
init_infected = 9
G.nodes[init_infected]['state'] = 1

# > Simulation hyperparameters
max_steps = 100
infection_probability = 0.99
recover_probability = 0

# > Drawing parameters
if doVisualization:
    # Visualization options
    options = {"node_size": 20}

    # Node colors
    init_susceptible_color = np.array([0,1,0,0.5])
    susceptible_color = np.array([0,1,0,1])
    infected_color = np.array([1,0,0,1])
    removed_color = np.array([0,0,1,1])

    # Draw network edges
    print("Drawing network edges...")
    nx.draw_networkx_edges(G, pos, width=0.2, alpha=0.25)

    # Draw network nodes, incl. init infected
    nx.draw_networkx_nodes(G, pos, nodelist=[x[0] for x in G.nodes.data('state') if x[1] == 0], alpha=0.45, node_color=[init_susceptible_color], **options, edgecolors=[susceptible_color])
    if doVisualizeSingletonReduction:
        nx.draw_networkx_nodes(G, pos, nodelist=[x[0] for x in G.nodes.data('state') if x[1] == 1], node_color=[infected_color], **options, edgecolors=[susceptible_color])
    else:
        nx.draw_networkx_nodes(G, pos, nodelist=[x[0] for x in G.nodes.data('state') if x[1] == 1], node_color=[infected_color], **options, edgecolors=[infected_color])

    # nx.draw_networkx_labels(G, pos=pos, font_size=6)
    print("Done!")

    # Export init graph
    plt.axis("off")
    plt.tight_layout()
    print("Saving graph_0.png...")
    plt.savefig("graphs/graph_0.png", format="PNG")
    print("Done!")
    # plt.show()

# > Run model
if doSingletonReduction:
    """
    To visualize a node's singletons, we set the color of the border/shell/outer boundary of the node to the average color of its singletons.
    """
    def get_singleton_edgecolor(node_num):
        if doVisualizeSingletonReduction:
            x = G.nodes[node_num]
            singleton_sum = x['s_singletons'] + x['i_singletons'] + x['r_singletons']
        if singleton_sum == 0 or not doVisualizeSingletonReduction:
            match x['state']:
                case 0:
                    temp_color = susceptible_color
                case 1:
                    temp_color = infected_color
                case 2:
                    temp_color = removed_color
        else:
            temp_color = x['s_singletons']*susceptible_color + x['i_singletons']*infected_color + x['r_singletons']*removed_color
            temp_color = temp_color/singleton_sum
        return temp_color

# Keep track of which nodes are infected, or have infected singletons for performance purposes.
infected_list = set()
for i in G.nodes.data('state'):
    if i[1] == 1:
        infected_list.add(i[0])
has_infected_singletons = set()

# Run model
print("Running model...")
for step in range(max_steps):
    infected_singletons = sum(x[1] for x in G.nodes.data('i_singletons'))
    infected_non_singletons = len(infected_list)

    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print(f"Step: {step}")
    print(f"Infected left: {infected_non_singletons + infected_singletons}")
    print(f"> Non-singletons infected: {infected_non_singletons}")
    print(f"> Singletons infected: {infected_singletons}")

    # Early Stopping
    if len(infected_list) == 0 and sum(x[1] for x in G.nodes.data('i_singletons')) == 0:
        break

    # >>> Infection and recovery
    if doSingletonReduction:
        """
        Perform singleton recovery.
        For each node in the "has infected singletons" set, check if it has infected singletons.
        If it does, simulate recovery via binomial distribution, else remove it from the set.
        """
        print("Recovering singletons...")
        temp_infsng_list = list(has_infected_singletons)
        for i in temp_infsng_list:
            inf_sng = G.nodes[i]['i_singletons']
            if not inf_sng == 0:
                # 
                temp = np.random.binomial(n=inf_sng, p=recover_probability)
                G.nodes[i]['i_singletons'] -= temp
                G.nodes[i]['r_singletons'] += temp

                if doVisualization:
                    match G.nodes[i]['state']:
                        case 0:
                            raise Exception("the hell")
                        case 1:
                            result_color = infected_color
                        case 2: 
                            result_color = removed_color
                    
                    nx.draw_networkx_nodes(G, pos, nodelist=[i], node_color=result_color, **options, edgecolors=get_singleton_edgecolor(i))
            else:
                has_infected_singletons.remove(i)

    if doSingletonReduction:
        print("Infecting singletons and neighbors + Recovering infected nodes...")
    else:
        print("Infecting neighbors + Recovering infected nodes...")
    temp_inf_list = list(infected_list)
    for i in temp_inf_list:
        """
        Infect adjacent nodes.
        """
        adj = G.neighbors(i)
        for j in adj:
            if G.nodes[j]['state'] == 0 and np.random.sample() < infection_probability:
                infected_list.add(j)
                G.nodes[j]['state'] = 1
                if doVisualization:
                    nx.draw_networkx_nodes(G, pos, nodelist=[j], node_color=infected_color, **options, edgecolors=get_singleton_edgecolor(j))

        # Infect singletons
        if doSingletonReduction:
            """
            Given we are looping through nodes that are infected, simulate singleton infection via binomial distribution.
            """
            temp = np.random.binomial(n=G.nodes[i]['s_singletons'], p=infection_probability)
            if temp > 0:
                has_infected_singletons.add(i)
            G.nodes[i]['s_singletons'] -= temp
            G.nodes[i]['i_singletons'] += temp

        """
        Recover infected nodes
        """
        if np.random.sample() < recover_probability:
            infected_list.remove(i)
            G.nodes[i]['state'] = 2

        if doVisualization:
            match G.nodes[i]['state']:
                case 0:
                    raise Exception("the hell")
                case 1:
                    result_color = infected_color
                case 2: 
                    result_color = removed_color
            nx.draw_networkx_nodes(G, pos, nodelist=[i], node_color=result_color, **options, edgecolors=get_singleton_edgecolor(i))

    # > Export graphs
    if doVisualization:
        plt.axis("off")
        plt.tight_layout()
        print(f"Saving graph_{step+1}.png...")
        plt.savefig(f"graphs/graph_{step+1}.png", format="PNG")
        print(f"Saved graph!")

Number of nodes:  19616
Drawing network edges...
Done!
Saving graph_0.png...
Done!
Running model...
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Step: 0
Infected left: 1
> Non-singletons infected: 1
> Singletons infected: 0
Recovering singletons...
Infecting singletons and neighbors + Recovering infected nodes...


  node_collection = ax.scatter(


Saving graph_1.png...
Saved graph!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Step: 1
Infected left: 3049
> Non-singletons infected: 2312
> Singletons infected: 737
Recovering singletons...
Infecting singletons and neighbors + Recovering infected nodes...
Saving graph_2.png...
Saved graph!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Step: 2
Infected left: 3811
> Non-singletons infected: 2970
> Singletons infected: 841
Recovering singletons...
Infecting singletons and neighbors + Recovering infected nodes...
Saving graph_3.png...
Saved graph!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Step: 3
Infected left: 3977
> Non-singletons infected: 3120
> Singletons infected: 857
Recovering singletons...
Infecting singletons and neighbors + Recovering infected nodes...
Saving graph_4.png...
Saved graph!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Step: 4
Infected left: 4076
> Non-singletons infected: 3208
> Singletons infected: 868
Recovering singletons...
Infecting singletons and neighbors + Recovering infected nodes...
Saving graph_5.png..

KeyboardInterrupt: 

Error in callback <function flush_figures at 0x7f0b40cb1300> (for post_execute), with arguments args (),kwargs {}:


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
import os
import gc
# --- 1. Setup Environment ---
if not os.path.exists("graphs"):
    os.makedirs("graphs")

# > Initialize model
print("Number of nodes: ", len(G.nodes()))

nx.set_node_attributes(G, 0, "state")
"""
When state = 0, a node is susceptible.
When state = 1, a node is infected.
When state = 2, a node is recovered/removed.
"""

# > Initially infect 1 node
init_infected = 9
G.nodes[init_infected]['state'] = 1

# > Simulation hyperparameters
max_steps = 100
infection_probability = 0.99
recover_probability = 0.05 # Changed to non-zero for logic testing

# > Drawing colors & options
options = {"node_size": 20}
susceptible_color = np.array([0, 1, 0, 1])
infected_color = np.array([1, 0, 0, 1])
removed_color = np.array([0, 0, 1, 1])

def get_singleton_edgecolor(node_num):
    """
    Calculates border color based on the ratio of susceptible/infected/removed singletons.
    """
    if doVisualizeSingletonReduction:
        node = G.nodes[node_num]
        s_sng = node.get('s_singletons', 0)
        i_sng = node.get('i_singletons', 0)
        r_sng = node.get('r_singletons', 0)
        total = s_sng + i_sng + r_sng
        
        if total == 0:
            # Default to node's own state color if no singletons
            state = node['state']
            return susceptible_color if state == 0 else (infected_color if state == 1 else removed_color)
        
        # Weighted average of colors
        avg_color = (s_sng * susceptible_color + i_sng * infected_color + r_sng * removed_color) / total
        return avg_color
    else:
        # Standard borders if visualization is off
        state = G.nodes[node_num]['state']
        return susceptible_color if state == 0 else (infected_color if state == 1 else removed_color)

# --- 2. Initial Tracking ---
infected_list = set()
for i, state in G.nodes.data('state'):
    if state == 1:
        infected_list.add(i)

has_infected_singletons = set()
# Check if initial infected has singletons
if doSingletonReduction and G.nodes[init_infected].get('i_singletons', 0) > 0:
    has_infected_singletons.add(init_infected)

# --- 3. Main Simulation Loop ---
print("Running model...")
for step in range(max_steps + 1):
    if step % 5 == 0:
        gc.collect()
    # ... (logic for population counts and infection/recovery) ...
    # Count populations
    inf_sng_total = sum(d.get('i_singletons', 0) for n, d in G.nodes(data=True)) if doSingletonReduction else 0
    inf_non_sng = len(infected_list)
    
    print(f"Step: {step} | Non-Sng Infected: {inf_non_sng} | Sng Infected: {inf_sng_total}")

    # --- A. Visualization Block (Draw ONCE per step) ---
    # --- A. Visualization Block ---
    if doVisualization:
        # 1. Create the figure and axes properly
        fig, ax = plt.subplots(figsize=(8, 8))
        
        # Draw edges
        nx.draw_networkx_edges(G, pos, width=0.2, alpha=0.15, ax=ax)

        # Batch draw nodes
        all_node_data = list(G.nodes(data=True))
        for state_val, base_col in [(0, susceptible_color), (1, infected_color), (2, removed_color)]:
            nodelist = [n for n, d in all_node_data if d['state'] == state_val]
            if not nodelist: continue
            
            edge_cols = [get_singleton_edgecolor(n) for n in nodelist]
            alpha_val = 0.45 if state_val == 0 else 1.0
            
            nx.draw_networkx_nodes(
                G, pos, 
                nodelist=nodelist,
                node_color=[base_col], 
                edgecolors=edge_cols,
                linewidths=1.2,
                alpha=alpha_val,
                ax=ax,
                **options
            )

        ax.set_axis_off()
        ax.set_title(f"Epidemic Spread: Step {step}")
        plt.savefig(f"graphs/graph_{step}.png", dpi=100)
        
        # 2. CRITICAL: Close the figure inside the IF block
        plt.close(fig)

    # --- B. Exit Condition ---
    if inf_non_sng == 0 and inf_sng_total == 0:
        print("Epidemic ended.")
        break

    # --- C. Logic and State Updates ---
    if doSingletonReduction:
        # Recover singletons
        for i in list(has_infected_singletons):
            inf_sng = G.nodes[i]['i_singletons']
            if inf_sng > 0:
                recovered = np.random.binomial(n=inf_sng, p=recover_probability)
                G.nodes[i]['i_singletons'] -= recovered
                G.nodes[i]['r_singletons'] += recovered
            else:
                has_infected_singletons.remove(i)

    # Spread to neighbors and Infect Singletons
    next_infected_list = list(infected_list)
    for i in next_infected_list:
        # Neighbor Infection
        for j in G.neighbors(i):
            if G.nodes[j]['state'] == 0 and np.random.sample() < infection_probability:
                G.nodes[j]['state'] = 1
                infected_list.add(j)

        # Singleton Infection
        if doSingletonReduction:
            s_sng = G.nodes[i]['s_singletons']
            if s_sng > 0:
                new_inf = np.random.binomial(n=s_sng, p=infection_probability)
                if new_inf > 0:
                    G.nodes[i]['s_singletons'] -= new_inf
                    G.nodes[i]['i_singletons'] += new_inf
                    has_infected_singletons.add(i)

        # Node Recovery (The infected node itself)
        if np.random.sample() < recover_probability:
            G.nodes[i]['state'] = 2
            infected_list.remove(i)

print("Simulation Complete.")

Number of nodes:  19616
Running model...
Step: 0 | Non-Sng Infected: 1 | Sng Infected: 749
Step: 1 | Non-Sng Infected: 2304 | Sng Infected: 728
Step: 2 | Non-Sng Infected: 2867 | Sng Infected: 703
Step: 3 | Non-Sng Infected: 2877 | Sng Infected: 677
Step: 4 | Non-Sng Infected: 2815 | Sng Infected: 641
Step: 5 | Non-Sng Infected: 3128 | Sng Infected: 612
Step: 6 | Non-Sng Infected: 3297 | Sng Infected: 589
Step: 7 | Non-Sng Infected: 3209 | Sng Infected: 581
Step: 8 | Non-Sng Infected: 3039 | Sng Infected: 565
Step: 9 | Non-Sng Infected: 2869 | Sng Infected: 539
Step: 10 | Non-Sng Infected: 2736 | Sng Infected: 519
Step: 11 | Non-Sng Infected: 2598 | Sng Infected: 504
Step: 12 | Non-Sng Infected: 2474 | Sng Infected: 486
Step: 13 | Non-Sng Infected: 2362 | Sng Infected: 472
Step: 14 | Non-Sng Infected: 2256 | Sng Infected: 454
Step: 15 | Non-Sng Infected: 2139 | Sng Infected: 438
Step: 16 | Non-Sng Infected: 2021 | Sng Infected: 430
Step: 17 | Non-Sng Infected: 1920 | Sng Infected: 414


In [59]:
# > Clear all pngs in graphs folder
for filename in os.listdir("graphs"):
   file_path = os.path.join("graphs", filename)
   if os.path.isfile(file_path):
      os.remove(file_path)

In [3]:
# > Export images in graphs folder to a gif
from PIL import Image

# List of image file paths
images = sorted(os.listdir("graphs"), key=lambda x:int(x[x.find('_')+1:x.find('.')]))
# print(images)

# Open images and store them in a list
frames = [Image.open(f"graphs/{image}") for image in images]

# Save frames as an animated GIF
frames[0].save(
    'simple_sir_model.gif',
    save_all=True,
    append_images=frames[1:],
    duration=300,
    loop=0
)