# Euler 666
https://projecteuler.net/problem=666

In [14]:
import numpy as np
from collections import Counter

k = 2
m = 2

rn = [306]
for n in range(1, k*m+m+1):
    rn.append((rn[-1]*rn[-1]) % 10007)

# The (u,v)th value in Pq is the probability of q=v for species u
Pq = np.zeros((k, 5))
for i in range(k):
    q = [rn[i*m+j] % 5 for j in range(m)]
    qf = Counter(q)
    Pq[i] = [qf[v]/m for v in range(5)]
    
print(Pq)
death_prob = Pq[:,0]
min_death_prob = np.min(np.extract(np.greater(death_prob, np.zeros(len(death_prob))), death_prob))
print(min_death_prob)

[[0.  0.5 0.  0.5 0. ]
 [0.5 0.  0.  0.  0.5]]
0.5


In [15]:
# Tm is transition array of matrices Tm[i] = np.zeros(5, k). A row v Tm[i] will give 
# the change in ith species at a time step when q = v
Tm = []
for i in range(k):
    t = np.zeros((5,k), dtype=np.int)
    t[0][i] = -1
    t[1][i] = 1 # i clones itself resulting in new type of i

    # i mutates in (2i)mod k
    t[2][i] = -1
    s = (2*i) % k
    t[2][s] += 1
    
    t[3][i] = -1
    s = (i*i+1) % k
    t[3][s] += 3
    
    s = (i+1) % k
    t[4][s] += 1
    
    Tm.append(t)
        
print(Tm)        

[array([[-1,  0],
       [ 1,  0],
       [ 0,  0],
       [-1,  3],
       [ 0,  1]]), array([[ 0, -1],
       [ 0,  1],
       [ 1, -1],
       [ 3, -1],
       [ 1,  0]])]


Let X be a map from a tuple of size k to its probability. The ith entry of the tuple is the population of the ith species

In [16]:
from itertools import product
from queue import PriorityQueue



zero_state = (0,) * k
zero_state_prob = 0.0
init_population_distribution = (1,) + (0,) * (k-1)
X = {init_population_distribution: 1.0}

for i in range(100):
    delete_dist = []
    next_X = {}
    
    print(i)
    max_pop = 0
    for p in X:
        #print("curr population state", p)
        nonzero_species = np.where(np.array(p) > 0)[0]
        possible_evolution = set(product(set((0,1,2,3,4)), repeat=len(nonzero_species)))
        
        for evolution in possible_evolution:
            #print("evo", evolution)
            #print(len(evolution))
            new_prob = X[p] * np.prod(Pq[nonzero_species, evolution])
            if new_prob == 0.0:
                continue
            new_state = np.array(p)
            #print("curr population state", new_state, Pq[nonzero_species, evolution])
            
            for i,q in enumerate(evolution):
                #print("Tm[i][q]", nonzero_species[i], q, Tm[nonzero_species[i]][q])
                new_state = new_state + Tm[nonzero_species[i]][q]
            new_state = tuple(new_state)
            prob_death = 1.0
            for i,n in enumerate(new_state):
                if Pq[i][0] < min_death_prob:
                    death_prob = min_death_prob
                else:
                    death_prob = Pq[i][0]
                prob_death *= np.power(death_prob, n)
                #print(n)
            #print("++++++++", min_death_prob, np.sum(np.array(new_state)), prob_death)
            
            #print("new_state", evolution, new_state)
            if new_state == zero_state:
                zero_state_prob += new_prob
                print("zero_p ", i, zero_state_prob)
            else:
                if prob_death > 0.00000001:
                    if new_state in next_X:
                        next_X[new_state] += new_prob
                    else:
                        next_X[new_state] = new_prob
                    new_population = np.sum(np.array(new_state))
                    if new_population > max_pop:
                        max_pop = new_population
                        #print("max_pop", max_pop, prob_death)
                #else:
                #    print("prob_death", prob_death)
    X = next_X
    print("X_len", max_pop, len(X))
    
    
            

0
X_len 3 2
1
X_len 4 3
2
X_len 7 8
3
zero_p  1 0.0625
X_len 10 17
4
X_len 13 35
5
X_len 16 58
6
X_len 19 93
7
zero_p  1 0.0703125
X_len 22 135
8
X_len 25 184
9
X_len 26 235
10
zero_p  1 0.07049560546875
X_len 26 285
11
zero_p  1 0.0718994140625
X_len 26 326
12
X_len 26 353
13
zero_p  1 0.07190895080566406
X_len 26 368
14
zero_p  1 0.07200050354003906
X_len 26 374
15
zero_p  1 0.07228851318359375
X_len 26 376
16
zero_p  1 0.0722891092300415
X_len 26 376
17
zero_p  1 0.07229667901992798
X_len 26 376
18
zero_p  1 0.07233035564422607
X_len 26 376
19
zero_p  1 0.07239411771297455
X_len 26 376
20
zero_p  1 0.07239478826522827
X_len 26 376
21
zero_p  1 0.07239865325391293
X_len 26 376
22
zero_p  1 0.07240978057961911
X_len 26 376
23
zero_p  1 0.07242465048329905
X_len 26 376
24
zero_p  1 0.07242509035859257
X_len 26 376
25
zero_p  1 0.07242673603104777
X_len 26 376
26
zero_p  1 0.07243023752016597
X_len 26 376
27
zero_p  1 0.0724338502259343
X_len 26 376
28
zero_p  1 0.0724340783471007
X_len

In [None]:
from dataclasses import dataclass
from itertools import product
from typing import List

@dataclass
class State:
    p: float # Probability of the state
    state: np.array # list of size k ith element denoting population of i
        
S_list = [State(1.0, np.array([1] + [0]*(k-1)))]

p = 0
p_die = Pq[:][0]

def gen_next_state_set(s):
    
    s_species = np.where(s.state > 0)[0]
    possible_states = set(product(set((0,1,2,3,4)), repeat=len(s_species)))
    
    
    for p in possible_states:
        new_state = np.zeros(k)
        new_state += s.state
        new_state_prob = s.p
        valid_state = True
        for i,q in enumerate(p):
            if Pq[s_species[i]][q] > 0:
                new_state += Tm[s_species[i]][q]
                new_state_prob *= Pq[s_species[i]][q]
            else:
                valid_state = False
                break
        if valid_state:
            ns = State(new_state_prob, new_state)
            print(ns)
            S_list.append(ns)
        
    
                          

#while True:
for t in range(1000):
    s = S_list.pop(0)
    '''
    for i in range(k):
        for q in range(5):
            if Pq[i][q] > 0.0 and s.state[i] > 0:
                ns = State(s.p*Pq[i][q], s.state + Tm[i][q])
                S_list.append(ns)
    '''
    gen_next_state_set(s)
                
    for index, s in enumerate(S_list):
        substate = np.where(s.state > 0)
        p_die_out = s.p * np.prod(np.power(p_die[substate], s.state[substate]))
        if p_die_out > 0.0:
            print(S_list[index])
            #del S_list[index]
            p += p_die_out
            if len(substate) > 1:
                print("This should not happen", p_die[substate], s.state[substate])
            
    print(p, len(S_list))
    #print(S_list)
            
            