In [1]:
# Imports
import numpy as np
import os
import matplotlib.pyplot as plt

In [90]:
'''
Simulation Class simulates the experiment environment for Reinforcement Learning agent
Arguments:
    - n_maxstep: the maximum number of steps the agent can take in the simulation
    - dictionary: each row represents a different electrode-amplitude pair. Each column represents a activation probability of different cell.
    - n_elecs: the number of electrodes
    - n_amps: the number of amplitudes
    - elecs: the electrode numbers corresponding to the rows of the dictionary
    - amps: the amplitudes corresponding to the rows of the dictionary
    - elec_map: maps electrode numbers to their locations on the brain
    - cell_ids: the cell ids corresponding to the columns of the dictionary

Variables:
    - n_step: the current number of steps taken in the simulation
    - elec: the current electrode
    - amp: the current amplitude
    - done: whether the episode is done
    - reward: the reward for the current step
    - state: the current state of the simulation
    
Functions:
    - __init__: initializes the simulation
    - reset: resets the simulation to the initial state
    - step: takes in an action and returns the next state, reward, and whether the episode is done
    - sample: samples cell activations using the probability specified in the dictionary
    - render: renders the current state of the simulation
    - close: closes the simulation
'''
class SimulationEnv:
    def __init__(self, path, reward_func, score_func, n_maxstep, n_elecs, n_amps):
        # Load relevant data from .npz files
        try:
            with np.load(os.path.join(path,"dictionary.npz")) as data:
                self.dict = data["dictionary"]
                self.elecs = data["entry_elecs"]
                self.amps = data["entry_amps"]
                self.elec_map = data["elec_map"]
            with np.load(os.path.join(path,"decoders.npz")) as data:
                self.cell_ids = data["cell_ids"]
        except FileNotFoundError:
            print("Please make sure the dictionary.npz and decoders.npz files are present in the specified path")

        # Initialize variables
        self.reward_func = reward_func
        self.score_func = score_func
        self.n_maxstep = n_maxstep
        self.n_elecs = n_elecs
        self.n_amps = n_amps
        self.n_cells = len(self.cell_ids)
        self.reset()
        
    def reset(self):
        # Reset variables
        self.n_step = 0
        self.elec = 0 # electrode number (1~n_elecs)
        self.amp = 0 # amplitude (1~n_amps)
        self.done = False
        self.reward = 0
        self.dict_hat = np.zeros((self.n_elecs*self.n_amps, len(self.cell_ids)), dtype=np.uint16)
        self.state = 0
        return self.state
    
    def step(self, action):
        self.elec = action[0]
        self.amp = action[1]

        if self.n_step >= self.n_maxstep:
            self.done = True
        else:
            sampled_activations = self.sample(self.elec, self.amp)
            self.dict_hat[(self.elec-1)*self.n_amps + (self.amp-1)] += sampled_activations
            self.state = self.score_func(self.dict_hat)
            self.reward = self.reward_func(sampled_activations)
            self.n_step += 1
        return self.state, self.reward, self.done
    
    def sample(self, elec, amp):
        try:
            idx = np.where((self.elecs == elec) & (self.amps == amp))[0][0]
            dist = self.dict[idx]
        except IndexError:
            # print(f"Electrode {elec} with amplitude {amp} was not in the dictionary")
            # print(f"Assume no cells were activated")
            dist = np.zeros(len(self.cell_ids), dtype=np.float64)

        if np.any(dist < 0):
            invalid_idx = np.where(dist < 0)[0]
            print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 0
        if np.any(dist > 1):
            invalid_idx = np.where(dist > 1)[0]
            print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 1
        if np.any(np.isnan(dist)):
            invalid_idx = np.where(np.isnan(dist))[0]
            print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 0

        # if np.any(dist < 0) or np.any(dist > 1) or np.any(np.isnan(dist)):
        #     print(f"Invalid distribution: {dist}")
        #     dist = np.zeros(len(self.cell_ids), dtype=np.float64)
        sampled_activations = np.random.binomial(1, dist).astype(dtype=np.uint8)

        return sampled_activations
    
    def render(self, elec, amp):
        print(self.dict_hat[(elec-1)*self.n_amps + (amp-1)])
    
    def close(self):
        pass

In [None]:
'''
Simulation Class simulates the experiment environment for Reinforcement Learning agent
Arguments:
    - n_maxstep: the maximum number of steps the agent can take in the simulation
    - dictionary: each row represents a different electrode-amplitude pair. Each column represents a activation probability of different cell.
    - n_elecs: the number of electrodes
    - n_amps: the number of amplitudes
    - elecs: the electrode numbers corresponding to the rows of the dictionary
    - amps: the amplitudes corresponding to the rows of the dictionary
    - elec_map: maps electrode numbers to their locations on the brain
    - cell_ids: the cell ids corresponding to the columns of the dictionary

Variables:
    - n_step: the current number of steps taken in the simulation
    - elec: the current electrode
    - amp: the current amplitude
    - done: whether the episode is done
    - reward: the reward for the current step
    - state: the current state of the simulation
    
Functions:
    - __init__: initializes the simulation
    - reset: resets the simulation to the initial state
    - step: takes in an action and returns the next state, reward, and whether the episode is done
    - sample: samples cell activations using the probability specified in the dictionary
    - render: renders the current state of the simulation
    - close: closes the simulation
'''
class FullStateSimulationEnv:
    def __init__(self, path, reward_func, score_func, n_maxstep, n_elecs, n_amps):
        # Load relevant data from .npz files
        try:
            with np.load(os.path.join(path,"dictionary.npz")) as data:
                self.dict = data["dictionary"]
                self.elecs = data["entry_elecs"]
                self.amps = data["entry_amps"]
                self.elec_map = data["elec_map"]
            with np.load(os.path.join(path,"decoders.npz")) as data:
                self.cell_ids = data["cell_ids"]
        except FileNotFoundError:
            print("Please make sure the dictionary.npz and decoders.npz files are present in the specified path")

        # Initialize variables
        self.reward_func = reward_func
        self.score_func = score_func
        self.n_maxstep = n_maxstep
        self.n_elecs = n_elecs
        self.n_amps = n_amps
        self.n_cells = len(self.cell_ids)
        self.reset()
        
    def reset(self):
        # Reset variables
        self.n_step = 0
        self.elec = 0 # electrode number (1~n_elecs)
        self.amp = 0 # amplitude (1~n_amps)
        self.done = False
        self.reward = 0
        self.state = np.zeros((self.n_elecs*self.n_amps, len(self.cell_ids)), dtype=np.uint16)
        return self.state
    
    def step(self, action):
        self.elec = action[0]
        self.amp = action[1]

        if self.n_step >= self.n_maxstep:
            self.done = True
        else:
            sampled_activations = self.sample(self.elec, self.amp)
            self.state[(self.elec-1)*self.n_amps + (self.amp-1)] += sampled_activations
            self.state = self.score_func(self.state)
            self.reward = self.reward_func(sampled_activations)
            self.n_step += 1
        return self.state, self.reward, self.done
    
    def sample(self, elec, amp):
        try:
            idx = np.where((self.elecs == elec) & (self.amps == amp))[0][0]
            dist = self.dict[idx]
        except IndexError:
            # print(f"Electrode {elec} with amplitude {amp} was not in the dictionary")
            # print(f"Assume no cells were activated")
            dist = np.zeros(len(self.cell_ids), dtype=np.float64)

        if np.any(dist < 0):
            invalid_idx = np.where(dist < 0)[0]
            print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 0
        if np.any(dist > 1):
            invalid_idx = np.where(dist > 1)[0]
            print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 1
        if np.any(np.isnan(dist)):
            invalid_idx = np.where(np.isnan(dist))[0]
            print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 0

        # if np.any(dist < 0) or np.any(dist > 1) or np.any(np.isnan(dist)):
        #     print(f"Invalid distribution: {dist}")
        #     dist = np.zeros(len(self.cell_ids), dtype=np.float64)
        sampled_activations = np.random.binomial(1, dist).astype(dtype=np.uint8)

        return sampled_activations
    
    def render(self, elec, amp):
        print(self.state[(elec-1)*self.n_amps + (amp-1)])
    
    def close(self):
        pass

In [91]:
'''
Reward func calculates the reward for the agent
'''
def inverse_reward_func(array):
    if np.sum(array) == 1:
        return 1
    elif np.sum(array) == 0:
        return 0
    else:
        return 1/np.sum(array)

def span_reward_function(array):
    pass

In [92]:
'''
Score func calculates how well the produced dictionary is
Score is between 0 and the number of cells
'''
def span_score_func(vectors):
    # calculate the span of a set of vectors
    # :param vectors: list of vectors
    # :return: span of vectors
    span = np.linalg.matrix_rank(vectors)
    return span

def cosine_sim_score_func(vectors):
    # calculate the cosine similarity of a set of vectors
    # :param vectors: list of vectors
    # :return: cosine similarity of vectors
    sim = np.dot(vectors, vectors.T)
    return sim

In [109]:
# Eplison Greedy Agent
'''
Epsilon Greedy Agent takes in a simulation environment and uses epsilon greedy policy to find the optimal policy
'''

class EpsilonGreedyAgent:
    def __init__(self, env, epsilon=0.1, gamma=0.9, alpha=0.1):
        self.env = env
        self.epsilon = epsilon
        self.gamma = gamma
        self.alpha = alpha
        self.Q = np.zeros((env.n_cells, env.n_elecs*env.n_amps), dtype=np.float32)
        self.n = np.zeros((env.n_cells, env.n_elecs*env.n_amps), dtype=np.uint16)
        self.policy = np.zeros((env.n_cells), dtype=np.uint16)
        self.reset()

    def reset(self):
        self.state = self.env.reset()
        self.action = self.policy[self.state]
        return self.state, self.action
    
    def get_action(self, state):
        if np.random.random() < self.epsilon:
            action_idx = np.random.randint(0, self.env.n_elecs*self.env.n_amps)
        else:
            action_idx = np.argmax(self.Q[state])
            
        # convert action_idx to (elec, amp)
        action = (action_idx//self.env.n_amps + 1, action_idx%self.env.n_amps + 1)
        return action
    
    def step(self):
        self.action = self.get_action(self.state)
        self.s_next, self.reward,self.done = self.env.step(self.action)
        return self.s_next, self.reward, self.done
    
    def update(self):
        self.n[self.state, self.action] += 1
        self.Q[self.state, self.action] += self.alpha*(self.reward + self.gamma*np.max(self.Q[self.s_next]))
        self.policy[self.state] = np.argmax(self.Q[self.state])
        self.state = self.s_next
        return self.state, self.action
    
    def run(self, n_episodes=1000):
        self.reset()
        for i in range(n_episodes):
            # while not self.done:
            self.step()
            self.update()
        return self.policy
    
    def render(self):
        pass


In [110]:
experiments = ["2022-11-04-2", "2022-11-28-1"]
path = f"./data/{experiments[0]}/dictionary"

In [111]:
sim = SimulationEnv(path, reward_func=inverse_reward_func, score_func=span_score_func, n_maxstep=5000, n_elecs=512, n_amps=42)

In [112]:
agent = EpsilonGreedyAgent(sim, epsilon=0.8, gamma=0.9, alpha=0.1)
policy = agent.run(n_episodes=1000)

Invalid value at index [107]: [-1.11022302e-16]
Invalid value at index [114]: [-1.11022302e-16]
Invalid value at index [107]: [-1.11022302e-16]
Invalid value at index [40]: [-1.11022302e-16]


In [113]:
25*512*42

537600

In [114]:
span_score_func(agent.Q)

46

In [115]:
def display_non_zero(array):
    non_zero_entries = np.nonzero(array)
    print("Non-zero entries:")
    for i, j in zip(*non_zero_entries):
        print("Row: {}, Column: {}, Value: {}".format(i, j, array[i][j]))

In [116]:
display_non_zero(agent.Q)

Non-zero entries:
Row: 0, Column: 1, Value: 0.10000000149011612
Row: 1, Column: 1, Value: 0.10900000482797623
Row: 1, Column: 2, Value: 0.009000000543892384
Row: 1, Column: 17, Value: 0.009000000543892384
Row: 1, Column: 19, Value: 0.10000000149011612
Row: 1, Column: 203, Value: 0.009000000543892384
Row: 1, Column: 441, Value: 0.10000000149011612
Row: 2, Column: 1, Value: 0.10000000149011612
Row: 2, Column: 20, Value: 0.10000000149011612
Row: 2, Column: 21, Value: 0.10000000149011612
Row: 2, Column: 386, Value: 0.10000000149011612
Row: 3, Column: 7, Value: 0.10000000149011612
Row: 3, Column: 410, Value: 0.10000000149011612
Row: 4, Column: 26, Value: 0.05000000074505806
Row: 4, Column: 101, Value: 0.05000000074505806
Row: 5, Column: 11, Value: 0.05000000074505806
Row: 5, Column: 474, Value: 0.05000000074505806
Row: 6, Column: 35, Value: 0.03333333507180214
Row: 6, Column: 453, Value: 0.03333333507180214
Row: 7, Column: 13, Value: 0.10000000149011612
Row: 7, Column: 18, Value: 0.10000000