In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import networkx as nx
import collections

In [2]:
# Select the interactive `notebook` backend; the FuncAnimation cells further
# down need a live figure canvas to be redrawn frame by frame.
# NOTE(review): this backend targets the classic Notebook UI -- in JupyterLab
# `%matplotlib widget` (ipympl) is presumably the equivalent; confirm.
%matplotlib notebook

# Classes

In [3]:
class BS:
    """Network of ``N`` species slots, each node carrying an ``"id"`` and a ``"fitness"``.

    The class only stores the graph and offers accessors / a ``mutate``
    primitive; the update rule that decides *which* node changes lives in the
    caller (see ``Lattice.run``).
    NOTE(review): the name presumably abbreviates "Bak-Sneppen" -- confirm.

    Parameters
    ----------
    N : int
        Number of nodes in the network.
    network : tuple
        Network specification. Only ``("watts-strogatz", k, p)`` is supported;
        ``k`` and ``p`` are forwarded to ``nx.watts_strogatz_graph``.
    random_relations : bool, optional
        If true, ids and fitnesses are shuffled with one common permutation
        before being written onto the nodes.
    species_ids : sequence of length >= N, optional
        Species id per node; defaults to ``range(N)``.
    fitnesses : sequence of length >= N, optional
        Fitness per node; defaults to ``np.random.random(N)``.

    Raises
    ------
    ValueError
        If ``network[0]`` is not a known network type.
    """

    def __init__(self, N, network, random_relations=False, species_ids=None, fitnesses=None):
        self.N = N
        self.network = network
        self.node_idx = N

        if network[0] != "watts-strogatz":
            raise ValueError("Unknown network type")

        # create graph
        (_, k, p) = network
        self.g = nx.watts_strogatz_graph(N, k, p)
        self._assign_species(species_ids, fitnesses, random_relations)

    def _assign_species(self, species_ids, fitnesses, random_relations):
        """Write an ``"id"`` and a ``"fitness"`` attribute onto every node of ``self.g``."""
        if species_ids is None:
            species_ids = range(self.N)
        if fitnesses is None:
            fitnesses = np.random.random(self.N)
        if random_relations:
            # range/list objects cannot be indexed with an index array, so go
            # through ndarray to apply the same permutation to both sequences.
            random_indices = np.random.permutation(self.N)
            species_ids = np.asarray(species_ids)[random_indices]
            fitnesses = np.asarray(fitnesses)[random_indices]
        for node in self.g:
            self[node]["id"] = species_ids[node]
            self[node]["fitness"] = fitnesses[node]

    def __getitem__(self, i):
        """Return the (mutable) attribute dict of node ``i``."""
        # `Graph.node` was removed in networkx 2.4; `Graph.nodes` is the
        # supported accessor and is what the other methods already use.
        return self.g.nodes[i]

    def species(self):
        """Return a view of ``(node, species id)`` pairs."""
        return self.g.nodes.data("id")

    def species_list(self):
        """Return the species ids of all nodes (one entry per node)."""
        return dict(self.g.nodes.data("id")).values()

    def fitness(self):
        """Return a view of ``(node, fitness)`` pairs."""
        return self.g.nodes.data("fitness")

    def min_fitness(self):
        """Return the ``(node, fitness)`` pair with the lowest fitness."""
        return min(self.fitness(), key=lambda x: x[1])

    def max_fitness(self):
        """Return the ``(node, fitness)`` pair with the highest fitness."""
        return max(self.fitness(), key=lambda x: x[1])

    def sorted_fitness(self):
        """Return all ``(node, fitness)`` pairs sorted by ascending fitness."""
        return sorted(self.fitness(), key=lambda x: x[1])

    def mutate(self, node_id, new_species_id, fitness=None):
        """Replace the species on ``node_id``.

        ``fitness`` defaults to a fresh ``np.random.random()`` draw; pass an
        explicit value to carry an existing fitness over.
        """
        if fitness is None:
            fitness = np.random.random()
        self[node_id]["fitness"] = fitness
        self[node_id]["id"] = new_species_id


In [4]:
class Lattice:
    """Periodic grid whose every site hosts its own ``BS`` species network.

    Parameters
    ----------
    dimensions : tuple of int
        Side lengths of the grid, forwarded to ``nx.grid_graph(..., periodic=True)``.
        ``run`` and ``area_curve`` address sites as ``(i, j)``, i.e. they expect
        a two-dimensional grid.
        NOTE(review): only square grids are exercised in this notebook; whether
        node tuples follow the order of ``dimensions`` on non-square grids
        should be confirmed against the installed networkx version.
    N : int
        Number of species slots per site.
    network : tuple
        Network specification passed on to every ``BS``.
    P : float
        Probability that an affected node mutates into a brand-new species;
        otherwise it is replaced by a migrant from a neighbouring site.
    """

    def __init__(self, dimensions, N, network, P):
        self.dimensions = tuple(dimensions)
        self.N = N
        self.P = P
        # Site number i starts with ids N*i .. N*(i+1)-1, so this is the first unused id.
        self.species_idx = int(np.prod(dimensions)) * N

        # create lattice
        self.lattice = nx.grid_graph(list(dimensions), periodic=True)

        # create BS network in each point of the lattice
        for (i, point) in enumerate(self.lattice):
            # `Graph.node` was removed in networkx 2.4; use `Graph.nodes`.
            self.lattice.nodes[point]["BS"] = BS(N, network, species_ids=range(N*i, N*(i+1)), fitnesses=None)

    def __getitem__(self, i):
        """Return the ``BS`` network stored at lattice site ``i``."""
        return self.lattice.nodes[i]["BS"]

    def run(self, t_max, collect_data=False):
        """Advance the model by ``t_max`` steps.

        In every step each site is visited once, in random order. At the
        visited site the least-fit node and its network neighbours are each
        either mutated (probability ``P``) or replaced by a migrant.

        If ``collect_data`` is true, ``self.data`` becomes an object array of
        shape ``(t_max + 1,) + dimensions`` holding the set of species ids of
        every site: index 0 is the state before the run, index ``t`` the state
        right after the site was updated in step ``t``.
        """
        if collect_data:
            self.data = np.empty((t_max+1,) + tuple(self.dimensions), dtype=object)
            for i, j in np.ndindex(self.dimensions):
                self.data[0, i, j] = set(self[i, j].species_list())

        for t in range(1, t_max+1):
            # Permute an explicit node list rather than the Graph object itself.
            for i, j in np.random.permutation(list(self.lattice)):
                site = (int(i), int(j))
                self._update_site(site)

                if collect_data:
                    self.data[(t,) + site] = set(self[site].species_list())

    def _update_site(self, site):
        """Mutate or replace the least-fit node of ``site`` and that node's network neighbours."""
        bs = self[site]
        lattice_neighbours = list(self.lattice[site])

        (idx, _) = bs.min_fitness()
        nodes = [idx] + list(bs.g[idx])

        for node in nodes:
            if np.random.random() < self.P:
                # mutate: brand-new species id, fresh random fitness
                bs.mutate(node, self.species_idx)
                self.species_idx += 1
            else:
                # die + migrate
                self._migrate_into(bs, node, lattice_neighbours)

    def _migrate_into(self, bs, node, lattice_neighbours):
        """Replace ``node`` of ``bs`` by a species (id and fitness) from a random neighbouring site.

        Only species not yet present in ``bs`` are eligible. Returns whether a
        migrant was found; a failure is reported on stdout.
        """
        bs_from = self[lattice_neighbours[np.random.randint(0, len(lattice_neighbours))]]
        # `bs` is not modified until a migrant is accepted (and the search
        # stops there), so one snapshot of its species is enough.
        present = set(bs.species_list())
        for node_id in np.random.permutation(list(bs_from.g)):
            donor = bs_from[node_id]
            if donor["id"] not in present:
                bs.mutate(node, donor["id"], fitness=donor["fitness"])
                return True
        print("Migration failed from {} to {}".format(bs_from.species(), bs.species()))
        return False

    def draw(self):
        """Draw the lattice graph with node labels."""
        nx.draw(self.lattice, with_labels=True)

    def mean_species(self):
        """Return the number of distinct species ids on the whole lattice divided by the number of sites."""
        ids = set()
        for point in self.lattice:
            ids = ids.union(self[point].species_list())
        return len(ids) / np.prod(self.dimensions)

    def _quadrat_counts(self, grain_size):
        """Return the species count of every ``grain_size`` x ``grain_size`` window (no wrap-around)."""
        rows, cols = self.dimensions
        counts = []
        for top in range(rows - grain_size + 1):
            for left in range(cols - grain_size + 1):
                species = set()
                for di, dj in np.ndindex((grain_size, grain_size)):
                    species.update(self[top + di, left + dj].species_list())
                counts.append(len(species))
        return counts

    def area_curve(self, sampling_scheme='quadrats', log=True):
        """Plot the species-area curve and print its log-log slope.

        Parameters
        ----------
        sampling_scheme : {'quadrats', 'nested'}
            ``'nested'`` counts the species in the growing squares
            ``[0, s) x [0, s)``. ``'quadrats'`` averages the species count over
            all windows of each size and draws an error bar of
            ``1.96 * std / sqrt(number of windows)`` (the windows overlap, so
            this is only a rough indication of the spread of the mean).
        log : bool
            Use logarithmic axes.

        Raises
        ------
        ValueError
            If ``sampling_scheme`` is not one of the two supported values.
        """
        if sampling_scheme not in ('nested', 'quadrats'):
            raise ValueError("Unknown sampling scheme: {!r}".format(sampling_scheme))

        side = self.dimensions[0]
        fig, ax = plt.subplots()

        if sampling_scheme == 'nested':
            area_curve = np.empty((side, 2))
            for s in range(1, side+1):
                species = set()
                for i, j in np.ndindex((s, s)):
                    species.update(self[(i, j)].species_list())
                area_curve[s-1] = (s**2, len(species))
            print(area_curve)
            ax.plot(*area_curve.T, "o")

            smallest, largest = area_curve[0, 1], area_curve[-1, 1]

        else:
            assert self.dimensions[0] == self.dimensions[1]
            area_curve = collections.OrderedDict()
            ci = []
            for grain_size in range(1, side+1):
                counts = self._quadrat_counts(grain_size)
                mean = np.mean(counts)
                area_curve[grain_size**2] = mean
                # Standard error uses the number of windows sampled at this grain size.
                ci.append(1.96 * np.std(counts) / np.sqrt(len(counts)))
                print('grain size: %i, mean: %0.3f'% (grain_size, mean))
            ax.errorbar(list(area_curve.keys()), list(area_curve.values()), yerr=ci, fmt="x", capsize=2)

            smallest, largest = area_curve[1], area_curve[side**2]

        # Slope between the smallest (area 1) and the largest area in log-log space.
        power = (np.log10(largest) - np.log10(smallest)) / (np.log10(side**2) - np.log10(1))

        ax.set_xlabel("Area")
        ax.set_ylabel("Number of species")

        if log:
            ax.set_xscale("log")
            ax.set_yscale("log")

        print("power: {}".format(power))
        
    

# Experiments

In [5]:
# 10x10 lattice, 100 species slots per site, ring network ("watts-strogatz"
# with k=3, rewiring probability 0), mutation probability P = 1/3.
# NOTE(review): NumPy's RNG is not seeded in this cell, so the numbers printed
# below will differ from run to run.
l = Lattice((10,10), 100, ("watts-strogatz", 3, 0), 1/3)
# Time 100 model steps.
%time l.run(100)

# Species-area curve from sliding quadrats; also prints the log-log slope.
l.area_curve("quadrats")

CPU times: user 1.46 s, sys: 0 ns, total: 1.46 s
Wall time: 1.49 s


<IPython.core.display.Javascript object>

grain size: 1, mean: 100.000
grain size: 2, mean: 300.407
grain size: 3, mean: 592.078
grain size: 4, mean: 977.224
grain size: 5, mean: 1457.056
grain size: 6, mean: 2028.400
grain size: 7, mean: 2694.125
grain size: 8, mean: 3428.778
grain size: 9, mean: 4160.500
grain size: 10, mean: 4748.000
power: 0.8382553551412768


## Mean number of species over time

In [6]:
# Follow `Lattice.mean_species()` while the model runs on a 6x6 lattice.
dimensions = (6,6)
l = Lattice(dimensions, 100, ("watts-strogatz", 3, 0), 1/3)
step_size = 10  # model steps simulated between two frames
steps = 10      # number of frames after the initial state

fig, ax = plt.subplots()
line, = ax.plot([], [], "o")
ax.set_xlim(0, steps*step_size)
ax.set_xlabel("Steps")
ax.set_ylabel("Mean number of species")

def animate(i):
    """Advance the model by `step_size` steps (except for frame 0) and append one data point."""
    if i > 0:
        l.run(step_size)

    t = i * step_size

    ax.set_title("step: {}".format(t))
    xs, ys = line.get_data()
    line.set_data(np.append(xs, [t]), np.append(ys, [l.mean_species()]))

    ax.relim()
    ax.autoscale_view()

    # FuncAnimation expects an iterable of artists, not a bare artist.
    return (line,)

# Title and y-limits change on every frame and are not part of the returned
# artists, so the figure has to be fully redrawn: no blitting.
# Keep a reference to the animation, otherwise it is garbage-collected and stops.
ani = animation.FuncAnimation(fig, animate, np.arange(steps+1), interval=500, repeat=False, blit=False, save_count=1000)

<IPython.core.display.Javascript object>

## Area curve over time

In [7]:
# Seed NumPy's global RNG: it drives the fitness draws, the site visiting
# order and the mutation/migration choices made by BS and Lattice.
np.random.seed(0)

l = Lattice((10,10), 100, ("watts-strogatz", 3, 0), 1/3)
l.run(1000, collect_data=True)

fig, ax = plt.subplots()
line, = ax.plot([],[], "o")
ax.set_xlabel("Area")
ax.set_ylabel("Number of species")
# ax.set_xscale("log")
# ax.set_yscale("log")

step_size = 10  # show every 10th recorded model step

def animate(data):
    """Redraw the quadrat species-area curve for one recorded snapshot.

    `data` is a `(frame number, snapshot)` pair, where the snapshot is the
    object array of per-site species sets stored by `Lattice.run`.
    """
    step, species = data

    ax.set_title("step: {}".format(step * step_size))

    # calculate area curve
    # Use the dimensions of the lattice that produced the data, not the
    # `dimensions` global left behind by an earlier cell.
    rows, cols = l.dimensions
    area_curve = collections.OrderedDict()
    for grain_size in range(1, max(l.dimensions)+1):
        counts = []
        for top in range(0, rows-grain_size+1):
            for left in range(0, cols-grain_size+1):
                window = species[top:top+grain_size, left:left+grain_size]
                unique_species = set().union(*window.flat)
                counts.append(len(unique_species))

        area_curve[grain_size**2] = np.mean(counts)
    line.set_xdata(list(area_curve.keys()))
    line.set_ydata(list(area_curve.values()))

    ax.relim()
    ax.autoscale_view()

    # FuncAnimation expects an iterable of artists, not a bare artist.
    return (line,)

# Title and axis limits change on every frame and are not part of the returned
# artists, so the figure has to be fully redrawn: no blitting.
# Keep a reference to the animation, otherwise it is garbage-collected and stops.
ani = animation.FuncAnimation(fig, animate, enumerate(l.data[::step_size]), interval=100, repeat=False, blit=False, save_count=1000)

<IPython.core.display.Javascript object>