In [1]:
from dataclasses import dataclass
from typing import Dict, List, Tuple, Set, Iterable
import random, time, json
from collections import deque

In [2]:

class DiGraph:
    """Directed graph on nodes ``0..n-1`` with a probability attached to each edge.

    ``adj[u]`` is the list of ``(v, p)`` out-edges of ``u`` in insertion order;
    parallel edges are kept as separate entries.
    """

    def __init__(self, n: int):
        self.n = n
        # One initially-empty out-edge list per node id.
        self.adj = dict((node, []) for node in range(n))
        self.nodes = [*range(n)]

    def add_edge(self, u: int, v: int, p: float):
        """Append the directed edge ``u -> v`` carrying probability ``p``.

        Raises AssertionError when an endpoint is outside ``[0, n)`` or ``p`` is
        outside ``[0, 1]``. These are ``assert`` checks, so they are skipped
        under ``python -O``.
        """
        assert 0 <= u < self.n and 0 <= v < self.n
        assert 0.0 <= p <= 1.0
        out_edges = self.adj[u]
        out_edges.append((v, p))

    def neighbors(self, u: int):
        """Return ``u``'s stored out-edge list ``[(v, p), ...]`` (the list itself, not a copy)."""
        out_edges = self.adj[u]
        return out_edges

class LiveEdgeSampler:
    """Draw ``R`` live-edge samples of a graph and answer reachability queries on them.

    In every sample each edge ``(u, v, p)`` is kept independently when a draw
    from the sampler's own ``random.Random(seed)`` is below ``p``. A sample is
    stored as a dict mapping each node to the list of its kept successors.
    """

    def __init__(self, g: DiGraph, R: int, seed: int = 42):
        self.g = g
        self.R = R
        self.random = random.Random(seed)
        self.samples = []
        self._generate_samples()

    def _generate_samples(self):
        """Rebuild ``self.samples`` from scratch with ``R`` new samples."""
        rng = self.random
        graph = self.g
        self.samples = []
        for _ in range(self.R):
            # Nodes and their edges are visited in stored order, so the RNG
            # stream (and thus every sample) is fully determined by the seed.
            self.samples.append({
                u: [v for v, p in graph.neighbors(u) if rng.random() < p]
                for u in graph.nodes
            })

    def reach_set(self, sample_idx: int, seeds: Iterable[int]) -> Set[int]:
        """Return every node reachable from ``seeds`` in sample ``sample_idx`` (seeds included), via BFS."""
        live = self.samples[sample_idx]
        seen = set()
        frontier = deque()
        for start in seeds:
            if start in seen:
                continue
            seen.add(start)
            frontier.append(start)
        while frontier:
            node = frontier.popleft()
            for succ in live[node]:
                if succ in seen:
                    continue
                seen.add(succ)
                frontier.append(succ)
        return seen

    def reach_set_single(self, sample_idx: int, s: int) -> Set[int]:
        """Return the reach set of the single seed ``s`` in sample ``sample_idx``."""
        return self.reach_set(sample_idx, (s,))

class InfluenceEstimator:
    """Monte Carlo influence estimates over a sampler's fixed live-edge samples.

    Every query averages over the same ``sampler.R`` pre-drawn samples, so
    repeating a query with the same arguments gives the same result.
    """

    def __init__(self, sampler: LiveEdgeSampler):
        self.sam = sampler

    def influence(self, S: Iterable[int]) -> float:
        """Return the mean number of nodes reachable from ``S`` (seeds included) per sample."""
        S = set(S)
        total = 0
        for r in range(self.sam.R):
            total += len(self.sam.reach_set(r, S))
        return total / self.sam.R

    def _base_reach_sets(self, S: Iterable[int]) -> List[Set[int]]:
        """Return the reach set of ``S`` in each sample, indexed by sample number."""
        return [self.sam.reach_set(r, S) for r in range(self.sam.R)]

    def _gain_from_bases(self, bases: List[Set[int]], v: int) -> float:
        """Average over samples of ``|reach(v) - reach(S)|``, given precomputed ``reach(S)`` sets."""
        total = 0
        for r, B in enumerate(bases):
            if v in B:
                # S already reaches v in this sample, so v adds nothing new here.
                continue
            total += len(self.sam.reach_set_single(r, v) - B)
        return total / self.sam.R

    def marginal_gain_node(self, S: Set[int], v: int) -> float:
        """Return the estimated influence gained by adding ``v`` to ``S`` (0.0 if ``v`` is in ``S``)."""
        if v in S:
            return 0.0
        return self._gain_from_bases(self._base_reach_sets(S), v)

    def marginal_gains_all(self, S: Set[int]):
        """Return ``{v: marginal gain}`` for every node not in ``S``, in node order."""
        candidates = [v for v in self.sam.g.nodes if v not in S]
        if not candidates:
            return {}
        # reach(S) per sample does not depend on the candidate, so compute it
        # once here instead of once per (candidate, sample) pair.
        bases = self._base_reach_sets(S)
        return {v: self._gain_from_bases(bases, v) for v in candidates}

class GreedyBasic:
    """Plain greedy seed selection: each round adds the node with the largest estimated marginal gain."""

    def __init__(self, estimator: InfluenceEstimator):
        self.est = estimator

    def select(self, k: int):
        """Choose up to ``k`` seeds and return them as a list (in set iteration order).

        Stops early when no non-seed node is left. On ties, the candidate that
        appears first in ``marginal_gains_all``'s result wins.
        """
        chosen = set()
        for _ in range(k):
            gains = self.est.marginal_gains_all(chosen)
            if not gains:
                break  # every node is already a seed
            chosen.add(max(gains, key=gains.get))
        return list(chosen)

class CELF:
    """Lazy-forward (CELF) greedy seed selection.

    Candidate gains are kept in a max-heap, and each entry is tagged with the
    seed-set size at which its gain was computed. A popped entry is accepted
    only when its tag is current. Otherwise its gain is recomputed and the
    entry is pushed back. Skipping re-evaluation relies on marginal gains never
    increasing as the seed set grows, which holds for reachability counts on
    fixed samples.
    """

    def __init__(self, estimator: InfluenceEstimator):
        self.est = estimator

    def select(self, k: int):
        """Choose up to ``k`` seeds and return them as a list (in set iteration order).

        Equal gains are broken toward the smaller node id, because heap entries
        compare ``(-gain, node, tag)``.
        """
        import heapq

        S = set()
        # Entries are (-gain, node, |S| when the gain was computed); heapq is a min-heap.
        heap = [(-self.est.marginal_gain_node(set(), v), v, 0) for v in self.est.sam.g.nodes]
        heapq.heapify(heap)
        while len(S) < k and heap:
            _, v, computed_at = heapq.heappop(heap)
            if computed_at == len(S):
                S.add(v)
            else:
                # Stale entry: refresh its gain against the current seed set and retry.
                heapq.heappush(heap, (-self.est.marginal_gain_node(S, v), v, len(S)))
        return list(S)

class NewGreedy:
    """Greedy seed selection that memoizes single-node reach sets per sample.

    A node's reach set in a fixed sample never changes, so it is cached in
    ``self._cache`` keyed by ``(sample_idx, node)`` and reused across rounds.
    """

    def __init__(self, sampler: LiveEdgeSampler):
        self.sam = sampler
        self._cache = {}  # (sample_idx, node) -> reach set of that node alone

    def _reach_single_cached(self, r: int, v: int) -> Set[int]:
        """Return the reach set of ``v`` in sample ``r``, computing it at most once."""
        key = (r, v)
        if key not in self._cache:
            self._cache[key] = self.sam.reach_set_single(r, v)
        return self._cache[key]

    def _avg_marginal(self, S: Set[int], v: int, bases=None) -> float:
        """Return the average over samples of the nodes ``v`` reaches beyond ``reach(S)``.

        ``bases`` may supply precomputed ``reach(S)`` sets, one per sample in
        sample order. When omitted, they are computed here.
        """
        if v in S:
            return 0.0
        if bases is None:
            bases = [self.sam.reach_set(r, S) for r in range(self.sam.R)]
        total = 0
        for r, B in enumerate(bases):
            if v in B:
                # S already reaches v in this sample; adding v gains nothing here.
                continue
            total += len(self._reach_single_cached(r, v) - B)
        return total / self.sam.R

    def select(self, k: int):
        """Choose up to ``k`` seeds greedily and return them as a list (in set iteration order)."""
        S = set()
        for _ in range(k):
            candidates = [v for v in self.sam.g.nodes if v not in S]
            if not candidates:
                break
            # reach(S) per sample is identical for every candidate this round,
            # so compute it once rather than once per candidate.
            bases = [self.sam.reach_set(r, S) for r in range(self.sam.R)]
            gains = {v: self._avg_marginal(S, v, bases) for v in candidates}
            S.add(max(gains, key=gains.get))
        return list(S)

@dataclass
class GAParams:
    """Hyper-parameters for the genetic algorithm in ``GNA``."""

    Nc: int = 80  # population size (number of chromosomes)
    Rs: float = 0.5  # fraction of the population kept as parents each generation
    Rc: float = 0.9  # probability a child is produced by crossover
    Rm: float = 0.05  # probability a child is mutated
    Ni: int = 120  # number of generations

class GNA:
    """Genetic algorithm for seed selection, with a local swap step each generation.

    Chromosomes are sets of ``k`` seed nodes, and fitness is the influence
    estimated by ``InfluenceEstimator``. All randomness comes from
    ``self.rand``, so for a fixed sampler the run is reproducible from ``seed``.
    """

    def __init__(self, sampler: LiveEdgeSampler, k: int, params: "GAParams | None" = None, seed: int = 11):
        self.sam = sampler
        self.k = k
        # A ``GAParams()`` default would be a single instance shared by every
        # GNA object; build a fresh one per instance instead.
        self.params = params if params is not None else GAParams()
        self.rand = random.Random(seed)
        self.est = InfluenceEstimator(sampler)
        self._cache = {}  # (sample_idx, node) -> reach set, used by _reach_single_cached

    def _reach_single_cached(self, r: int, v: int):
        """Return the reach set of ``v`` in sample ``r``, computing it at most once.

        No other GNA method calls this.
        """
        key = (r, v)
        if key not in self._cache:
            self._cache[key] = self.sam.reach_set_single(r, v)
        return self._cache[key]

    def _per_seed_contrib(self, S: Set[int]):
        """Estimate how many nodes each seed in ``S`` is responsible for.

        For each sample, a multi-source BFS runs from all seeds at once. Each
        seed is credited with itself, and each newly reached node is credited
        to the seed whose BFS wave reached it first. Returns
        ``{seed: average credit per sample}``.
        """
        S = set(S)
        contrib = {s: 0 for s in S}
        for r in range(self.sam.R):
            adj = self.sam.samples[r]
            visited = set(S)
            for s in S:
                contrib[s] += 1
            q = deque((s, s) for s in S)
            while q:
                u, own = q.popleft()
                for v in adj[u]:
                    if v not in visited:
                        visited.add(v)
                        contrib[own] += 1
                        q.append((v, own))
        for s in contrib:
            contrib[s] /= self.sam.R
        return contrib

    def _newgreedy_local_improve(self, S: Set[int]):
        """Try to replace the weakest seed of ``S`` with a better non-seed.

        The weakest seed is the one with the smallest credit from
        ``_per_seed_contrib``. Up to 25 random non-seeds are tried as its
        replacement. The best swap is returned only if it strictly raises
        estimated influence; otherwise ``S`` itself is returned.
        """
        if not S:
            return S
        contrib = self._per_seed_contrib(S)
        weakest = min(contrib.items(), key=lambda x: x[1])[0]
        nonseeds = [v for v in self.sam.g.nodes if v not in S]
        pool = self.rand.sample(nonseeds, min(25, len(nonseeds))) if nonseeds else []
        base = self.est.influence(S)
        rest = set(S)
        rest.discard(weakest)
        bestv, bestg = None, 0.0
        for v in pool:
            # Score the actual swap (weakest out, v in). Scoring S | {v} would
            # ignore what is lost by dropping the weakest seed.
            g = self.est.influence(rest | {v}) - base
            if g > bestg:
                bestg, bestv = g, v
        if bestv is not None and bestg > 1e-9:
            return rest | {bestv}
        return S

    def _random_chrom(self):
        """Return a uniformly random set of ``k`` distinct nodes."""
        return set(self.rand.sample(self.sam.g.nodes, self.k))

    def _fitness(self, S):
        """Return the estimated influence of seed set ``S``."""
        return self.est.influence(S)

    def _roulette(self, pop, fits, keep):
        """Select ``keep`` chromosomes: the fittest one (elitism) plus roulette-wheel draws.

        Draws are proportional to fitness (negative values count as 0) and may
        repeat. If no fitness is positive, draws are uniform instead.
        """
        import bisect

        best = max(range(len(pop)), key=lambda i: fits[i])
        out = [pop[best]]
        if keep == 1:
            return out
        cum = []
        running = 0.0
        for f in fits:
            running += max(0.0, f)
            cum.append(running)
        total = cum[-1]
        if total <= 0.0:
            # Degenerate wheel (e.g. every fitness is 0): choose uniformly
            # rather than indexing past the end of ``pop``.
            while len(out) < keep:
                out.append(self.rand.choice(pop))
            return out
        last = len(pop) - 1
        while len(out) < keep:
            # Spin the wheel on the true total; the clamp guards the rare case
            # where rounding lands exactly on ``total``.
            idx = bisect.bisect_right(cum, self.rand.random() * total)
            out.append(pop[min(idx, last)])
        return out

    def _crossover(self, P1: Set[int], P2: Set[int]):
        """Union the parents, then trim to ``k`` seeds by dropping the lowest contributors.

        If the union has fewer than ``k`` nodes, random nodes are added until it
        has ``k``.
        """
        child = set(P1) | set(P2)
        if len(child) > self.k:
            contrib = self._per_seed_contrib(child)
            for s, _ in sorted(contrib.items(), key=lambda x: x[1]):
                if len(child) <= self.k:
                    break
                child.remove(s)
        while len(child) < self.k:
            v = self.rand.choice(self.sam.g.nodes)
            if v not in child:
                child.add(v)
        return child

    def _mutate(self, S: Set[int]):
        """With probability ``params.Rm``, replace one random seed with a random node outside the remaining seeds.

        The replacement may be the dropped seed itself. Returns a new set when
        mutated; otherwise returns ``S`` unchanged.
        """
        if self.rand.random() < self.params.Rm and len(S) > 0:
            drop = self.rand.choice(list(S))
            S2 = set(S)
            S2.remove(drop)
            choices = [v for v in self.sam.g.nodes if v not in S2]
            if choices:
                S2.add(self.rand.choice(choices))
            return S2
        return S

    def run(self, verbose=False):
        """Evolve the population for ``params.Ni`` generations.

        Returns ``(best_set, best_fitness, history)``, where ``history[i]`` is
        the best fitness seen through generation ``i``. When ``Ni`` is 0,
        ``best_set`` is ``None`` and ``best_fitness`` is ``-1.0``.
        """
        pop = [self._random_chrom() for _ in range(self.params.Nc)]
        bestS = None
        bestf = -1.0
        hist = []
        for it in range(self.params.Ni):
            pop = [self._newgreedy_local_improve(S) for S in pop]
            fits = [self._fitness(S) for S in pop]
            bi = max(range(len(pop)), key=lambda i: fits[i])
            if fits[bi] > bestf:
                bestf = fits[bi]
                bestS = set(pop[bi])
            hist.append(bestf)
            if verbose and (it % max(1, self.params.Ni // 10) == 0):
                print(f"[GNA] iter={it} best_f={bestf:.3f}")
            keep = max(1, int(self.params.Rs * len(pop)))
            parents = self._roulette(pop, fits, keep)
            children = []
            while len(children) + len(parents) < len(pop):
                if self.rand.random() < self.params.Rc and len(parents) >= 2:
                    c = self._crossover(*self.rand.sample(parents, 2))
                else:
                    c = set(self.rand.choice(parents))
                children.append(self._mutate(c))
            pop = parents + children
        return bestS, bestf, hist

def build_demo_graph():
    """Return the fixed 12-node demo graph used by the ``__main__`` comparison.

    All 21 edges have small probabilities (0.02-0.06). Edge insertion order is
    significant because it fixes the order of each adjacency list.
    """
    edge_list = [
        (0, 1, 0.05), (0, 2, 0.03), (1, 3, 0.05), (2, 3, 0.04), (2, 4, 0.03),
        (3, 5, 0.04), (4, 5, 0.05), (5, 6, 0.04), (5, 7, 0.03), (6, 8, 0.05),
        (7, 8, 0.04), (8, 9, 0.05), (8, 10, 0.03), (9, 11, 0.06), (10, 11, 0.05),
        (1, 4, 0.02), (3, 7, 0.03), (2, 6, 0.02), (4, 8, 0.03), (0, 9, 0.02),
        (6, 11, 0.04),
    ]
    graph = DiGraph(12)
    for src, dst, prob in edge_list:
        graph.add_edge(src, dst, prob)
    return graph

if __name__ == "__main__":
    # Seed the module-level RNG so any global random draws are repeatable.
    random.seed(0)
    g = build_demo_graph()
    k, R = 3, 300
    sampler = LiveEdgeSampler(g, R=R, seed=7)
    # A single estimator scores every method on the same live-edge samples.
    est = InfluenceEstimator(sampler)

    gb = GreedyBasic(est)
    S1 = gb.select(k)
    f1 = est.influence(S1)

    celf = CELF(est)
    S2 = celf.select(k)
    f2 = est.influence(S2)

    ng = NewGreedy(sampler)
    S3 = ng.select(k)
    f3 = est.influence(S3)

    gna = GNA(sampler, k, GAParams())
    S4, f4, h = gna.run(verbose=False)

    out = {
        "k": k,
        "R": R,
        "GreedyBasic": {"S": sorted(S1), "I(S)": f1},
        "CELF": {"S": sorted(S2), "I(S)": f2},
        "NewGreedy": {"S": sorted(S3), "I(S)": f3},
        "GNA": {"S": sorted(S4), "I(S)": f4, "history_len": len(h)},
    }
    print(json.dumps(out, indent=2))


{
  "k": 3,
  "R": 300,
  "GreedyBasic": {
    "S": [
      0,
      4,
      6
    ],
    "I(S)": 3.3033333333333332
  },
  "CELF": {
    "S": [
      0,
      4,
      6
    ],
    "I(S)": 3.3033333333333332
  },
  "NewGreedy": {
    "S": [
      0,
      4,
      6
    ],
    "I(S)": 3.3033333333333332
  },
  "GNA": {
    "S": [
      0,
      4,
      6
    ],
    "I(S)": 3.3033333333333332,
    "history_len": 120
  }
}
