In [9]:
import numpy as np


class PMC(object):
    """Simulated coverage environment driven by an (n, m) probability matrix.

    Entry ``p[i, j]`` is the success probability of the base arm in row ``i``
    and column ``j``. A super arm is a flat vector of length ``n * m`` that is
    multiplied element-wise with the sampled outcomes, so a 0 entry suppresses
    that arm's outcome.
    """

    def __init__(self, probabilities, k):
        """Store the probability matrix, its dimensions and the budget ``k``.

        Args:
            probabilities: 2-D array; its shape defines ``n`` and ``m``.
            k: stored on the instance; not read by any method of this class.
        """
        self.p: np.ndarray = probabilities
        rows, cols = probabilities.shape
        self.n = rows
        self.m = cols
        self.k = k

    def arms(self):
        """Return the total number of base arms, ``n * m``."""
        return self.n * self.m

    def play_super_arm(self, super_arm):
        """Sample one round for ``super_arm``.

        Args:
            super_arm: flat array of length ``n * m`` used as a multiplicative
                mask over the sampled outcomes.

        Returns:
            A pair ``(reward, outcomes)``: ``reward`` is the number of columns
            with at least one non-zero masked outcome, and ``outcomes`` is the
            flattened masked outcome matrix.
        """
        mask = super_arm.reshape(self.n, self.m)
        # One uniform draw per base arm; an arm succeeds when its draw falls
        # strictly below its probability.
        draws = np.random.rand(self.n, self.m)
        successes = (draws < self.p).astype(int)
        arm_outcomes = successes * mask
        # A column is covered as soon as any masked arm in it succeeded.
        covered = arm_outcomes.any(axis=0).astype(int)
        return covered.sum(), arm_outcomes.flatten()

    def play_super_arm_average(self, super_arm, trials):
        """Return the mean reward of ``super_arm`` over ``trials`` independent plays."""
        total = sum(self.play_super_arm(super_arm)[0] for _ in range(trials))
        return total / trials
    

In [14]:
import numpy as np

class CUCB(object):
    """Upper-confidence-bound learner over a fixed number of base arms.

    Each round the learner adds a confidence radius to its per-arm mean
    estimates, clips the result at 1, hands it to ``oracle`` to obtain a super
    arm, plays that super arm and updates the estimates of the played arms.
    Subclasses must implement ``oracle`` and may override ``initialize``.
    """

    def __init__(self, arms, play):
        """
        Args:
            arms: number of base arms (length of every per-arm vector).
            play: callable taking a super arm and returning a pair
                ``(reward, arm_outcomes)``.
        """
        self.arms = arms
        self.play = play
        self.temp = []  # rewards collected since the last progress report

        self.newest_play = None  # most recent super arm played by train()

    def initialize(self):
        """Reset the statistics: no arm played yet, every mean set to 1."""
        self.times_played = np.zeros(self.arms)
        self.means = np.ones(self.arms)

    def oracle(self, mu):
        """Map a vector of per-arm values to a super arm (subclass hook)."""
        raise NotImplementedError

    def _upper_confidence_bounds(self, t):
        """Return the per-arm optimistic estimates for round ``t``, clipped at 1.

        Arms that were never played get the maximal bound 1 directly; dividing
        by their zero play count would produce NaN at ``t == 1`` (0 / 0).
        """
        played = self.times_played > 0
        ratio = np.zeros_like(self.times_played, dtype=float)
        np.divide(1.5 * np.log(t), self.times_played, out=ratio, where=played)
        bounds = np.where(played, self.means + np.sqrt(ratio), 1.0)
        return np.minimum(bounds, 1)

    def train(self, iterations, report_every=1000):
        """Run ``iterations`` rounds of play-and-update.

        Args:
            iterations: number of rounds to play after ``initialize()``.
            report_every: print the round number and the mean reward of the
                window each time this many rewards have been collected; a
                falsy value disables both collection and printing.
        """
        self.initialize()
        for t in range(1, iterations + 1):
            super_arm = self.oracle(self._upper_confidence_bounds(t))
            x, arm_outcomes = self.play(super_arm)

            if report_every:
                self.temp.append(x)
                if len(self.temp) >= report_every:
                    print(t, sum(self.temp) / len(self.temp))
                    self.temp = []

            self.newest_play = super_arm
            # NOTE(review): assumes super_arm holds per-arm play counts for
            # this round (0/1 indicators) - confirm against the oracle used.
            new_counts = self.times_played + super_arm
            # Arms with no prior plays contribute nothing to the running sum,
            # whatever placeholder value their mean currently holds.
            prior_sum = np.where(self.times_played > 0, self.means * self.times_played, 0.0)
            # Only arms with a non-zero count get a new mean; the others keep
            # their previous value instead of becoming 0 / 0 = NaN.
            self.means = np.divide(
                prior_sum + arm_outcomes,
                new_counts,
                out=np.array(self.means, dtype=float),
                where=new_counts > 0,
            )
            self.times_played = new_counts

In [15]:
import numpy as np

from cucb import CUCB

class PMCTrainer(CUCB):
    """CUCB learner whose base arms form an (n, m) grid and whose super arms
    are whole rows ("nodes") of that grid, up to ``k`` rows at a time."""

    def __init__(self, n, m, k, play):
        """
        Args:
            n: number of rows (nodes) in the arm grid.
            m: number of columns in the arm grid.
            k: number of nodes selected per super arm.
            play: callable taking a flat super arm of length ``n * m`` and
                returning ``(reward, arm_outcomes)``.
        """
        super().__init__(n * m, play)
        self.n, self.m, self.k = n, m, k

    def _super_arm_for_nodes(self, nodes):
        """Return the flat 0/1 super arm that plays every arm of ``nodes``.

        ``nodes`` may be an integer index array or a boolean mask of length n.
        """
        grid = np.zeros((self.n, self.m))
        grid[nodes, :] = 1
        return grid.flatten()

    def initialize(self):
        """Play every arm at least once and set the means to the observed averages.

        Nodes are played in consecutive blocks of ``min(k, n)``. A trailing
        partial block is shifted back so it still plays a full block, which
        means some arms are sampled twice. When ``k`` divides ``n`` exactly,
        no extra block is played.

        Raises:
            ValueError: if ``n`` or ``k`` is not positive.
        """
        if self.k <= 0 or self.n <= 0:
            raise ValueError("n and k must be positive integers")

        self.times_played = np.zeros(self.arms)
        accumulative = np.zeros(self.arms)
        block = min(self.k, self.n)
        for start in range(0, self.n, block):
            # Slide the last block back instead of playing fewer nodes.
            first = min(start, self.n - block)
            super_arm = self._super_arm_for_nodes(np.arange(first, first + block))
            _, outcomes = self.play(super_arm)
            self.times_played += super_arm
            accumulative += outcomes

        self.means = accumulative / self.times_played

    def oracle(self, mu):
        """Greedily choose ``min(k, n)`` distinct nodes for the values ``mu``.

        Args:
            mu: flat array of length ``n * m`` of per-arm values, read as
                activation probabilities.

        Returns:
            Flat 0/1 array of length ``n * m`` with all arms of the chosen
            nodes set to 1.
        """
        p = mu.reshape(self.n, self.m)
        chosen = np.zeros(self.n, dtype=bool)
        # Per column: probability that no chosen node has activated it yet.
        unactivated = np.ones(self.m)
        for _ in range(min(self.k, self.n)):
            # greedily pick a node that activates the most
            gains = p @ unactivated
            # Exclude chosen nodes with -inf rather than 0: when every
            # remaining gain is 0, a zero mask lets argmax return an
            # already-chosen node and the super arm ends up with fewer nodes.
            gains = np.where(chosen, -np.inf, gains)
            new_node = int(gains.argmax())
            chosen[new_node] = True
            unactivated = unactivated * (1 - p[new_node])
        return self._super_arm_for_nodes(chosen)


In [16]:
import numpy as np
import scipy.sparse

from pmc_solve import PMCTrainer
from pmc import PMC



def main(n=200, m=100, k=10, iterations=1000000, trials=1000, density=0.05, seed=None):
    """Build a random problem instance, train on it and print reward estimates.

    Args:
        n: number of rows of the random probability matrix.
        m: number of columns of the random probability matrix.
        k: number of nodes per super arm.
        iterations: number of training rounds requested from ``train``.
        trials: number of plays averaged for each printed estimate.
        density: fraction of non-zero entries in the random matrix.
        seed: when not None, seeds NumPy's global generator and is passed to
            ``scipy.sparse.random`` so that a run can be repeated.

    The two "--->" lines both evaluate the oracle's choice for the true
    probabilities ``p``; the "final" line evaluates the last super arm that
    was played during training.
    """
    if seed is not None:
        np.random.seed(seed)
    p = scipy.sparse.random(n, m, density=density, random_state=seed).toarray()

    problem = PMC(p, k)
    model = PMCTrainer(n, m, k, problem.play_super_arm)

    print("--->", problem.play_super_arm_average(model.oracle(p.flatten()), trials=trials))

    try:
        model.train(iterations)
    except KeyboardInterrupt:
        # Training is meant to be stoppable by hand; report what we have.
        print("Interrupted")

    print("Train complete")
    print("--->", problem.play_super_arm_average(model.oracle(p.flatten()), trials=trials))
    if model.newest_play is None:
        # No training round completed (e.g. interrupted immediately), so
        # there is no super arm to evaluate.
        print("final", "n/a (no super arm was played)")
    else:
        print("final", problem.play_super_arm_average(model.newest_play, trials=trials))

    

# Run the experiment when this cell/module is executed as the main program.
if __name__ == "__main__":
    main()



---> 47.234
1000 5.566
2000 5.683
3000 22.5
4000 26.756
5000 27.371
6000 27.378
7000 27.908
8000 28.033
9000 28.285
10000 28.623
11000 28.637
12000 29.092
13000 29.18
14000 29.119
15000 29.507
16000 29.499
17000 29.687
18000 29.858
19000 30.167
20000 30.147
21000 30.355
22000 30.467
23000 31.003
24000 30.67
25000 30.801
26000 30.855
27000 31.169
28000 31.341
29000 31.452
30000 31.373
31000 31.597
32000 31.973
33000 31.936
34000 31.923
35000 31.893
36000 32.214
37000 32.33
38000 32.504
39000 32.631
40000 32.581
41000 32.79
42000 32.876
43000 33.031
44000 33.088
45000 33.031
46000 33.438
47000 33.435
48000 33.454
49000 33.49
50000 33.702
51000 33.626
52000 33.836
53000 33.798
54000 34.071
55000 34.215
56000 34.176
57000 34.384
58000 34.606
59000 34.476
60000 34.553
61000 34.763
62000 34.812
63000 34.824
64000 34.718
65000 34.772
66000 34.695
67000 34.936
68000 35.073
69000 35.395
70000 35.242
71000 35.736
72000 35.76
73000 35.652
74000 35.378
75000 35.539
76000 35.502
77000 35.852
78000 