In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import string
import random
import pandas as pd

In [None]:
def getNeighbors(Nodes, Pairs):
    """Map every node to the list of nodes it shares a pair with.

    Parameters: Nodes -- list of nodes in visualization
                Pairs -- list of neighbor pairs between nodes
    Returns: dict whose keys are the entries of Nodes (in their original
             order) and whose values are lists of each node's distinct
             neighbors, in the order they are first seen in Pairs.

    A node that appears in no pair maps to an empty list. Members of a
    pair that are not listed in Nodes get no key of their own, but they
    are still reported as neighbors of listed nodes. A node is never its
    own neighbor.
    """
    # Inner dicts are used as insertion-ordered sets, so the neighbor order
    # is reproducible from run to run (plain set order is not, e.g. for
    # string labels under hash randomisation).
    adjacency = {node: {} for node in Nodes}

    # One pass over the pairs instead of rescanning all pairs for each node.
    for pair in Pairs:
        for node in pair:
            if node not in adjacency:
                continue
            for other in pair:
                if other != node:
                    adjacency[node][other] = None

    return {node: list(seen) for node, seen in adjacency.items()}
            

    
def createTable(Nodes, Pairs):
    """Build the node-state table used by the simulation.

    Parameters: Nodes -- list of nodes in visualization
                Pairs -- list of neighbor pairs between nodes
    Returns: DataFrame indexed by 'Node' with two columns:
             'Neighbors' -- list of the node's neighbors (from getNeighbors)
             'State'     -- 1 for every node (1 means healthy)
    """
    adjacency = getNeighbors(Nodes, Pairs)

    # Wrapping the neighbor lists in an object Series keeps each list as a
    # single cell, so no placeholder column is needed before assignment.
    neighbor_column = pd.Series(
        list(adjacency.values()),
        index=list(adjacency),
        dtype=object,
    )

    table = pd.DataFrame({'Neighbors': neighbor_column, 'State': 1})
    table.index.name = 'Node'
    return table



def updateTable(Nodes, stPerimeter, stDefender, stAttacker, Time, Table,
                *, graph=None, draw=True):
    """Advance the simulation by one step, then report and draw the result.

    Every node is re-sampled with a Bernoulli trial whose success
    probability (success = State 1 = healthy) is

        exp(-lmbda * Time) * (1/2) ** k

    where k is the number of the node's neighbors that are currently
    infected (State 0) and lmbda = 1 / (2 * stPerimeter * stDefender /
    stAttacker). Nodes are processed in table order and each new state is
    written immediately, so later nodes see the updated states of earlier
    ones.

    Parameters: Nodes -- accepted for backward compatibility; not used
                         (the nodes are taken from Table's index)
                stPerimeter -- score of perimeter strength (must be > 0)
                stDefender -- score of defender strength (must be > 0)
                stAttacker -- score of attacker strength (must be > 0)
                Time -- current number of iterations in the simulation
                Table -- table indexed by node with 'Neighbors' and 'State'
                graph -- graph to draw; defaults to the module-level
                         `network`
                draw -- set to False to skip the plot
    Returns: Table, the same object, with 'State' updated in place.
    Raises: ValueError if any strength score is not positive.
    """
    if min(stPerimeter, stDefender, stAttacker) <= 0:
        raise ValueError("strength scores must be positive")

    # Derived from the arguments rather than read from a global.
    decay_rate = 1 / (2 * stPerimeter * (stDefender / stAttacker))
    base_survival = np.exp(-decay_rate * Time)

    # NOTE(review): a node that is already infected is re-sampled like any
    # other and can return to State 1; confirm whether infection should be
    # permanent.
    for node, neighbors in Table['Neighbors'].items():
        # State 0 is "infected"; a node with no neighbors keeps the base
        # probability because the sum is then zero.
        infected_neighbors = sum(
            1 for neighbor in neighbors if Table.loc[neighbor, 'State'] == 0
        )
        prob_survival = base_survival * 0.5 ** infected_neighbors

        #run a bernoulli trial with prob_survival to get new status
        Table.loc[node, 'State'] = np.random.binomial(n=1, p=prob_survival)

    #show results
    total = len(Table)
    percent_infected = (Table['State'] == 0).sum() * 100 / total
    percent_healthy = (Table['State'] == 1).sum() * 100 / total
    print("----------------------------------------------------")
    print("At time t=" + str(Time))
    print(str(percent_infected) + "% of nodes are now infected.")
    print(str(percent_healthy) + "% of nodes are still healthy.")

    if draw:
        if graph is None:
            graph = network  # graph built at module level by the script

        #color nodes: red = infected, green = healthy
        color_map = [
            'red' if Table.loc[node, 'State'] == 0 else 'green'
            for node in graph
        ]
        nx.draw(graph, node_color=color_map, with_labels=True)
        plt.show()

    return Table




# Run Network Simulation
#
# Interactive driver: builds a fixed 10-node network (3 central nodes, each
# linked to all 7 regular nodes), then repeatedly asks the user for the
# strength scores and advances the simulation one step at a time.


def _ask_strength(prompt):
    # Re-prompt until the user enters a whole number from 1 to 10.
    while True:
        answer = input(prompt)
        try:
            value = int(answer)
        except ValueError:
            print("Please enter a whole number from 1 to 10.")
            continue
        if 1 <= value <= 10:
            return value
        print("Please enter a whole number from 1 to 10.")


# The Bernoulli trials draw from numpy's generator, so it must be seeded too
# for runs to be reproducible.
random.seed(123)
np.random.seed(123)

#Define starting parameters
N = 10
NODES = list(range(1, N + 1))

#Create static network
network = nx.Graph()
network.add_nodes_from(NODES)

#Create pairs and edges: every regular node is linked to every central node
central_nodes = NODES[:3]
regular_nodes = NODES[3:]
pairs = [
    (central_node, regular_node)
    for regular_node in regular_nodes
    for central_node in central_nodes
]
network.add_edges_from(pairs)

#Create table for updating
table = createTable(NODES, pairs)

time = 0

while True:
    check = input("Continue with simulation? (Y/N): ").upper()
    if check not in ("Y", "YES"):
        break

    stPerimeter = _ask_strength("Input perimeter strength (1-10, 10 being strongest): ")
    stDefender = _ask_strength("Input defender strength (1-10, 10 being strongest): ")
    stAttacker = _ask_strength("Input attacker strength (1-10, 10 being strongest): ")
    showStartNet = input("View initial network? (Y/N) ")
    showNet = input("View static network? (Y/N) ")

    if showStartNet.upper() == "Y":
        #Draw initial network
        print("At t=0:")
        print("100% of nodes are healthy.")
        nx.draw(network, node_color='green', with_labels=True)
        plt.show()

    # The simulation only advances when the user asks to view the network.
    if showNet.upper() == "Y":
        # Decay rate of the survival probability exp(-lmbda * t).
        exp_rate = 2
        rate = exp_rate * stPerimeter * (stDefender / stAttacker)
        lmbda = 1 / rate

        #Update table of node states
        time += 1
        table = updateTable(NODES, stPerimeter, stDefender, stAttacker, time, table)
        
