# In[ ]:
import copy, pathlib, dataclasses, itertools as it, numpy as np, pandas as pd, networkx as nx
import matplotlib.pyplot as plt, bokeh as bk, bokeh.plotting, ipywidgets as widgets
# Direct bokeh show()/push_notebook() output into this notebook's output cells.
bk.io.output_notebook()

@dataclasses.dataclass
class PHE():
    """Stochastic epidemic simulation on a Watts-Strogatz small-world graph.

    Each node carries a ``state`` (1 = infected, 0 otherwise), a per-contact
    transmission probability ``beta`` and a ``recovery`` countdown; each edge
    carries a ``contact_prob`` drawn uniformly from [0.5, 1).  The graph is
    built in ``__post_init__`` with one infected node chosen among the nodes
    of maximum degree.

    Fields can also be read/written like dictionary keys: ``S['beta']``.
    """
    degree: int  # neighbours per node (k) -- only used by the Watts-Strogatz generator
    steps: int  # default number of evolve() steps run by simulate()
    beta: float  # initial per-contact transmission probability of every node
    rewire: float  # Watts-Strogatz rewiring probability (p)
    nodes: int  # number of nodes in the graph
    seed: int = None  # seeds the graph generator and self.rng; None = fresh randomness
    palette: list = bk.palettes.Spectral8  # node colours, indexed by state, used by draw_graph()

    def __getitem__(self, key):
        return self.__dict__[key]

    def __setitem__(self, key, value):
        self.__dict__[key] = value

    def __post_init__(self):
        self.rng = np.random.default_rng(self.seed)
        self.make_graph()

    def _states(self):
        """Return the current ``state`` of every node as a Series indexed by node."""
        return pd.Series(dict(self.G.nodes(data='state')))

    def simulate(self, steps=None):
        """Run ``steps`` evolve() iterations (default ``self.steps``) on ``self.G``.

        Stores ``self.history``: a DataFrame indexed by node whose column ``t``
        holds every node's state after ``t`` steps (column 0 = initial states).
        """
        if steps is None:
            steps = self.steps
        snapshots = [self._states()]
        for _ in range(steps):
            self.evolve()
            snapshots.append(self._states())
        self.history = pd.concat(snapshots, axis=1)

    def analyze(self, plot=True):
        """Summarise ``self.history``; simulate() must have been run first.

        Sets ``self.stats`` (per-node ``describe()`` of its state over time) and
        ``self.infections`` (number of nodes with state 1 at each step) and
        returns the infected fraction of the population at each step.
        ``plot`` is currently ignored; callers plot the returned Series.
        """
        self.stats = self.history.T.describe().T
        self.infections = (self.history == 1).sum()
        return self.infections / self.nodes

    def make_graph(self):
        """Build ``self.G`` and initialise its node and edge attributes.

        Every node gets ``degree``, ``state`` = 0, ``beta`` = self.beta and a
        ``recovery`` countdown of max(1, round(Exp(mean 10))) steps; every edge
        gets ``contact_prob`` ~ U[0.5, 1).  One node of maximum degree is set
        to ``state`` = 1.  Random draws are made in a fixed order (edges, seed
        node, recovery times) from ``self.rng``.
        """
        self.G = nx.watts_strogatz_graph(self.nodes, self.degree, self.rewire, seed=self.seed)
        # Alternative generators (would need matching dataclass fields):
        # self.G = nx.erdos_renyi_graph(self.nodes, self.edge_creation, seed=self.seed)
        # self.G = nx.barabasi_albert_graph(self.nodes, self.degree, seed=self.seed)
        max_degree = max(d for _, d in self.G.degree)
        hubs = [n for n, d in self.G.degree if d == max_degree]
        for n in self.G.nodes():
            self.G.nodes[n]['degree'] = self.G.degree(n)
            self.G.nodes[n]['state'] = 0
            self.G.nodes[n]['beta'] = self.beta
        for u, v in self.G.edges():
            self.G.edges[u, v]['contact_prob'] = self.rng.uniform(low=0.5, high=1)
        # Patient zero: one node picked at random among the highest-degree nodes.
        for n in self.rng.choice(hubs, 1, replace=False):
            self.G.nodes[n]['state'] = 1
        for n in self.G.nodes():
            self.G.nodes[n]['recovery'] = max(1, np.round(self.rng.exponential(10)))

    def _add_step_colors(self, source, step):
        """Add per-step node colour columns to a bokeh node data source.

        Column ``str(t)`` holds, in the order of the source's ``index`` column,
        the palette colour of each node's state at step ``t`` of
        ``self.history``.  Before simulate() has run, only column ``str(step)``
        is added, built from the current node states.
        """
        order = list(source.data['index'])
        history = getattr(self, 'history', None)
        if history is None:
            history = self._states().to_frame(name=step)
        colors = list(self.palette)
        for col in history.columns:
            states = history[col].reindex(order)
            source.data[str(col)] = [colors[int(s) % len(colors)] for s in states]

    def draw_graph(self, step=0, title=None, filename=False):
        """Draw ``self.G`` with bokeh, colouring nodes by their state at ``step``.

        Node positions come from ``self.G.layout`` when set; otherwise a
        spring layout is computed and cached on ``self.G.layout``.  If
        ``filename`` is given the figure is saved to that HTML file, otherwise
        it is shown in the notebook.
        """
        hover = [('index', '@index'), ('degree', '@degree'), ('state', '@state'),
                 ('beta', '@beta{0.000}'), ('recovery', '@recovery')]
        p = bk.plotting.figure(tooltips=hover, title=title)
        # make_graph() does not compute positions, so build them on demand.
        positions = getattr(self.G, 'layout', None)
        if positions is None:
            positions = nx.spring_layout(self.G, seed=self.seed)
            self.G.layout = positions
        network_graph = bk.plotting.from_networkx(self.G, positions)
        # One colour column per step so the fill can be switched by column name.
        self._add_step_colors(network_graph.node_renderer.data_source, step)
        network_graph.node_renderer.glyph = bk.models.Circle(size=10, fill_color={'field': str(step)})
        network_graph.edge_renderer.glyph = bk.models.MultiLine(line_alpha=0.5, line_width=1)
        p.renderers.append(network_graph)

        if filename:
            bk.plotting.output_file(filename=filename, title=title)
            bk.plotting.save(p)
        else:
            def update(step):
                network_graph.node_renderer.glyph.fill_color = {'field': str(step)}
                bk.io.push_notebook()
            bk.plotting.show(p, notebook_handle=True)
            #widgets.interact(update, step=widgets.IntSlider(min=0, max=self.history.shape[1]-1, step=1, value=step))

    def evolve(self):
        """Advance the epidemic by one synchronous step.

        Each infected node (state 1) tries every neighbour once, infecting it
        with probability ``beta * contact_prob``, and decrements its own
        ``recovery`` countdown.  Any node whose recovery counter, as read from
        the graph, is <= 0 is set to state 0 with beta 0 (it stops
        transmitting).  Changes are staged in a copy of the node attributes
        and written back at the end, so all decisions use the previous state.
        """
        # Attribute values are scalars, so per-node shallow copies suffice.
        updates = {n: dict(d) for n, d in self.G.nodes(data=True)}

        for n in self.G.nodes():
            if self.G.nodes[n]['state'] == 1:
                for m in self.G.adj[n]:
                    r = self.rng.uniform()
                    if r <= self.G.nodes[n]['beta'] * self.G.edges[n, m]['contact_prob']:
                        updates[m]['state'] = 1
                        # NOTE(review): this redraws the *infector's* recovery
                        # on the graph itself; the write-back below overwrites
                        # it from `updates`, so its only effects are consuming a
                        # random draw and skipping n's recovery check this step.
                        # Possibly meant updates[m]['recovery'] -- confirm.
                        self.G.nodes[n]['recovery'] = max(1, np.round(self.rng.exponential(10)))
                updates[n]['recovery'] += -1
            # NOTE(review): tests the pre-decrement value, so recovery takes
            # effect one step after the countdown reaches 0.
            if self.G.nodes[n]['recovery'] <= 0:
                updates[n]['state'] = 0
                updates[n]['beta'] = 0
        for n, d in updates.items():
            self.G.nodes[n].update(d)


# Parameter sweep: every combination of the values in `params` is simulated
# N_TRIALS times from the same initial graph; the mean infected-fraction curve
# is plotted and each trial's curve is stored in r['results'].
N_TRIALS = 1  # independent epidemics per parameter combination

params = {
    'nodes': [10000],
    'beta': [0.011],
    'degree': [10],
    'rewire': [0.5],
}
runs = [dict(zip(params.keys(), p)) for p in it.product(*params.values())]
for r in runs:
    # **r unpacks the parameter combination into the PHE fields.
    S = PHE(**r, steps=200, seed=None, palette=bk.palettes.Inferno3)
    # Every trial restarts from a copy of the same initial graph (same edges,
    # contact probabilities and patient zero); only later draws from S.rng differ.
    G_init = S.G.copy()

    run_list = []
    for _ in range(N_TRIALS):
        S.G = G_init.copy()
        S.simulate()
        run_list.append(S.analyze())

    y_fin = np.array(run_list).mean(axis=0)
    plt.plot(y_fin)
    plt.xlabel('step')
    plt.ylabel('fraction infected')
    plt.title(str(r))
    plt.show()

    # One row per trial, one column per step: infected fraction of the population.
    r['results'] = pd.DataFrame(run_list)