In [1]:
import pandas as pd
import numpy as np

In [2]:
class PacardoCell():
    def __init__(self, wires, price):
        self.wires = wires
        self.price = price
        
    def calculate_cost(self, viewed_wires):
        return len(np.setdiff1d(self.wires,viewed_wires))/self.price

In [3]:
table = [
    PacardoCell([1,2], 6),#0
    PacardoCell([1,2], 2),#1
    PacardoCell([1,3], 3),#2
    PacardoCell([6], 4),#3
    PacardoCell([6], 5),#4
    PacardoCell([], 6),#5
    PacardoCell([2], 7),#6
    PacardoCell([3,6], 8),#7
    PacardoCell([3,5], 9),#8
    PacardoCell([5], 10),#9
    PacardoCell([2], 11),#10
    PacardoCell([2,4], 12),#11
    PacardoCell([4,6], 13),#12
    PacardoCell([4,3,5],14),#13
    PacardoCell([5], 15)#14
]

In [4]:
table2 = [
    PacardoCell([1,2], 15),#0
    PacardoCell([1,2], 14),#1
    PacardoCell([1,3], 13),#2
    PacardoCell([6], 12),#3
    PacardoCell([6], 11),#4
    PacardoCell([], 10),#5
    PacardoCell([2], 9),#6
    PacardoCell([3,6], 8),#7
    PacardoCell([3,5], 7),#8
    PacardoCell([5], 6),#9
    PacardoCell([2], 5),#10
    PacardoCell([2,4], 4),#11
    PacardoCell([4,6], 3),#12
    PacardoCell([4,3,5],2),#13
    PacardoCell([5], 7)#14
]

In [5]:
table3 = [
    PacardoCell([1,2], 1),#0
    PacardoCell([1,2], 1),#1
    PacardoCell([1,3], 1),#2
    PacardoCell([6], 1),#3
    PacardoCell([6], 1),#4
    PacardoCell([], 1),#5
    PacardoCell([2], 1),#6
    PacardoCell([3,6], 1),#7
    PacardoCell([3,5], 1),#8
    PacardoCell([5], 1),#9
    PacardoCell([2], 1),#10
    PacardoCell([2,4], 1),#11
    PacardoCell([4,6], 1),#12
    PacardoCell([4,3,5],1),#13
    PacardoCell([5], 1)#14
]

In [6]:
def get_position(available_cells):
    return np.where(available_cells == 0)[0]

In [7]:
def update_state(available_cells, position):
    available_cells[position] = 1

In [8]:
def update_wires(pacardoCell, checked_wires):
    for w in pacardoCell.wires :
        if not checked_wires.count(w):
            checked_wires.append(w)

In [9]:
def greedy_policy(table, available_cells, penalty_cache, exp_rate = .3):
    action = None
    if np.random.uniform(0, 1) <= exp_rate:
        action = np.random.choice(np.where(available_cells == 0)[0])
    else:    
        value_max = -np.inf
        for p in np.where(available_cells == 0)[0]:
            cost = 0 if (penalty_cache.get(p)) is None else penalty_cache.get(p)
            if (cost >= value_max):
                value_max = cost
                action = p
    return action

In [10]:
def feedReward(table, penalty_cache, states_cache):
    for st in reversed(states_cache):
        if penalty_cache.get(st) is None:
            penalty_cache[st] = 0
        wires =  np.inf if len(table[st].wires) == 0 else table[st].price/len(table[st].wires)
        penalty_cache[st] += .01 * (.9 * (-1 * wires) - penalty_cache[st])

In [11]:
def fit(table, penalty_cache):
    available_cells = np.zeros(15)
    checked_wires = []
    states_cache = []
    while(len(checked_wires) < 6):
        action = greedy_policy(table, available_cells, penalty_cache)
        states_cache.append(action)
        update_state(available_cells, action)
        update_wires(table[action], checked_wires)
        if (len(checked_wires) >= 6):
            return states_cache
        
        

In [12]:
import operator
from tqdm import tqdm
def fit_epoch(epochs, table):
    penalty_cache = {}
    for _ in tqdm(range(epochs)):
        states_cache = fit(table, penalty_cache)
        feedReward(table, penalty_cache, states_cache)
    #выводим результаты     
    tmp = []  
    tmp2 = []
    for i in tqdm(reversed(sorted(penalty_cache.items(), key=operator.itemgetter(1)))) :
        if (len(tmp) < 6):
            for w in table[i[0]].wires:
                if (not tmp.count(w)):
                    tmp.append(w)
                    tmp2.append(i[0])
    return set(tmp2)

In [13]:
fit_epoch(100000, table)

100%|██████████| 100000/100000 [00:10<00:00, 9679.21it/s]
15it [00:00, 64726.91it/s]


{1, 7, 8, 13}

In [14]:
fit_epoch(100000, table3)

100%|██████████| 100000/100000 [00:07<00:00, 12599.22it/s]
15it [00:00, 36856.80it/s]


{0, 7, 13}

In [15]:
fit_epoch(100000, table2)

100%|██████████| 100000/100000 [00:09<00:00, 10832.05it/s]
15it [00:00, 67216.41it/s]


{2, 11, 12, 13}