In [2]:
# Imports
import numpy as np
import os
import matplotlib.pyplot as plt

In [3]:
'''
Simulation Class simulates the experiment environment for Reinforcement Learning agent
Arguments:
    - n_maxstep: the maximum number of steps the agent can take in the simulation
    - dictionary: each row represents a different electrode-amplitude pair. Each column represents a activation probability of different cell.
    - n_elecs: the number of electrodes
    - n_amps: the number of amplitudes
    - elecs: the electrode numbers corresponding to the rows of the dictionary
    - amps: the amplitudes corresponding to the rows of the dictionary
    - elec_map: maps electrode numbers to their locations on the brain
    - cell_ids: the cell ids corresponding to the columns of the dictionary

Variables:
    - n_step: the current number of steps taken in the simulation
    - elec: the current electrode
    - amp: the current amplitude
    - done: whether the episode is done
    - reward: the reward for the current step
    - state: the current state of the simulation
    
Functions:
    - __init__: initializes the simulation
    - reset: resets the simulation to the initial state
    - step: takes in an action and returns the next state, reward, and whether the episode is done
    - sample: samples cell activations using the probability specified in the dictionary
    - render: renders the current state of the simulation
    - close: closes the simulation
'''
class SimulationEnv:
    def __init__(self, path, reward_func, score_func, n_maxstep, n_elecs, n_amps):
        # Load relevant data from .npz files
        try:
            with np.load(os.path.join(path,"dictionary.npz")) as data:
                self.dict = data["dictionary"]
                self.elecs = data["entry_elecs"]
                self.amps = data["entry_amps"]
                self.elec_map = data["elec_map"]
            with np.load(os.path.join(path,"decoders.npz")) as data:
                self.cell_ids = data["cell_ids"]
        except FileNotFoundError:
            print("Please make sure the dictionary.npz and decoders.npz files are present in the specified path")

        # Initialize variables
        self.reward_func = reward_func
        self.score_func = score_func
        self.n_maxstep = n_maxstep
        self.n_elecs = n_elecs
        self.n_amps = n_amps
        self.n_cells = len(self.cell_ids)
        self.reset()
        
    def reset(self):
        # Reset variables
        self.n_step = 0
        self.elec = 0 # electrode number (1~n_elecs)
        self.amp = 0 # amplitude (1~n_amps)
        self.done = False
        self.reward = 0
        self.dict_hat = np.zeros((self.n_elecs*self.n_amps, len(self.cell_ids)), dtype=np.uint16)
        self.state = 0
        return self.state
    
    def step(self, action):
        self.elec = action[0]
        self.amp = action[1]

        if self.n_step >= self.n_maxstep:
            self.done = True
        else:
            sampled_activations = self.sample(self.elec, self.amp)
            self.dict_hat[(self.elec-1)*self.n_amps + (self.amp-1)] += sampled_activations
            self.state = self.score_func(self.dict_hat)
            self.reward = self.reward_func(sampled_activations)
            self.n_step += 1
        return self.state, self.reward, self.done
    
    def sample(self, elec, amp):
        try:
            idx = np.where((self.elecs == elec) & (self.amps == amp))[0][0]
            dist = self.dict[idx]
        except IndexError:
            # print(f"Electrode {elec} with amplitude {amp} was not in the dictionary")
            # print(f"Assume no cells were activated")
            dist = np.zeros(len(self.cell_ids), dtype=np.float64)

        if np.any(dist < 0):
            invalid_idx = np.where(dist < 0)[0]
            print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 0
        if np.any(dist > 1):
            invalid_idx = np.where(dist > 1)[0]
            print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 1
        if np.any(np.isnan(dist)):
            invalid_idx = np.where(np.isnan(dist))[0]
            print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 0

        # if np.any(dist < 0) or np.any(dist > 1) or np.any(np.isnan(dist)):
        #     print(f"Invalid distribution: {dist}")
        #     dist = np.zeros(len(self.cell_ids), dtype=np.float64)
        sampled_activations = np.random.binomial(1, dist).astype(dtype=np.uint8)

        return sampled_activations
    
    def render(self, elec, amp):
        print(self.dict_hat[(elec-1)*self.n_amps + (amp-1)])
    
    def close(self):
        pass

In [24]:
class FullStateSimulationEnv:
    def __init__(self, path, reward_func, score_func, n_maxstep, n_elecs, n_amps):
        # Load relevant data from .npz files
        try:
            with np.load(os.path.join(path,"dictionary.npz")) as data:
                self.dict = data["dictionary"]
                self.elecs = data["entry_elecs"]
                self.amps = data["entry_amps"]
                self.elec_map = data["elec_map"]
            with np.load(os.path.join(path,"decoders.npz")) as data:
                self.cell_ids = data["cell_ids"]
        except FileNotFoundError:
            print("Please make sure the dictionary.npz and decoders.npz files are present in the specified path")

        # Initialize variables
        self.reward_func = reward_func
        self.score_func = score_func
        self.n_maxstep = n_maxstep
        self.n_elecs = n_elecs
        self.n_amps = n_amps
        self.n_cells = len(self.cell_ids)
        self.reset()
        
    def reset(self):
        # Reset variables
        self.n_step = 0
        self.elec = 0 # electrode number (1~n_elecs)
        self.amp = 0 # amplitude (1~n_amps)
        self.done = False
        self.reward = 0
        self.state = np.zeros((self.n_elecs*self.n_amps, len(self.cell_ids)), dtype=np.uint16)
        return self.state
    
    def step(self, action):
        self.elec = action[0]
        self.amp = action[1]

        if self.n_step >= self.n_maxstep:
            self.done = True
        else:
            sampled_activations = self.sample(self.elec, self.amp)
            self.state[(self.elec-1)*self.n_amps + (self.amp-1)] += sampled_activations
            # self.state = self.score_func(self.state)
            self.reward = self.reward_func(sampled_activations)
            self.n_step += 1
        return self.state, self.reward, self.done
    
    def sample(self, elec, amp):
        try:
            idx = np.where((self.elecs == elec) & (self.amps == amp))[0][0]
            dist = self.dict[idx]
        except IndexError:
            # print(f"Electrode {elec} with amplitude {amp} was not in the dictionary")
            # print(f"Assume no cells were activated")
            dist = np.zeros(len(self.cell_ids), dtype=np.float64)

        if np.any(dist < 0):
            invalid_idx = np.where(dist < 0)[0]
            # print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 0
        if np.any(dist > 1):
            invalid_idx = np.where(dist > 1)[0]
            # print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 1
        if np.any(np.isnan(dist)):
            invalid_idx = np.where(np.isnan(dist))[0]
            # print(f"Invalid value at index {invalid_idx}: {dist[invalid_idx]}")
            dist[invalid_idx] = 0

        # if np.any(dist < 0) or np.any(dist > 1) or np.any(np.isnan(dist)):
        #     print(f"Invalid distribution: {dist}")
        #     dist = np.zeros(len(self.cell_ids), dtype=np.float64)
        sampled_activations = np.random.binomial(1, dist).astype(dtype=np.uint8)

        return sampled_activations
    
    def render(self, elec, amp):
        print(self.state[(elec-1)*self.n_amps + (amp-1)])
    
    def close(self):
        pass

In [5]:
'''
Reward func calculates the reward for the agent
'''
def inverse_reward_func(array):
    if np.sum(array) == 1:
        return 1
    elif np.sum(array) == 0:
        return 0
    else:
        return 1/np.sum(array)

def span_reward_function(array):
    pass

In [6]:
'''
Score func calculates how well the produced dictionary is
Score is between 0 and the number of cells
'''
def span_score_func(vectors):
    # calculate the span of a set of vectors
    # :param vectors: list of vectors
    # :return: span of vectors
    span = np.linalg.matrix_rank(vectors)
    return span

def cosine_sim_score_func(vectors):
    # calculate the cosine similarity of a set of vectors
    # :param vectors: list of vectors
    # :return: cosine similarity of vectors
    sim = np.dot(vectors, vectors.T)
    return sim

In [7]:
# Eplison Greedy Agent
'''
Epsilon Greedy Agent takes in a simulation environment and uses epsilon greedy policy to find the optimal policy
'''

class EpsilonGreedyAgent:
    def __init__(self, env, epsilon=0.1, gamma=0.9, alpha=0.1):
        self.env = env
        self.epsilon = epsilon
        self.gamma = gamma
        self.alpha = alpha
        self.Q = np.zeros((env.n_cells, env.n_elecs*env.n_amps), dtype=np.float32)
        self.n = np.zeros((env.n_cells, env.n_elecs*env.n_amps), dtype=np.uint16)
        self.policy = np.zeros((env.n_cells), dtype=np.uint16)
        self.reset()

    def reset(self):
        self.state = self.env.reset()
        self.action = self.policy[self.state]
        return self.state, self.action
    
    def get_action(self, state):
        if np.random.random() < self.epsilon:
            action_idx = np.random.randint(0, self.env.n_elecs*self.env.n_amps)
        else:
            action_idx = np.argmax(self.Q[state])
            
        # convert action_idx to (elec, amp)
        action = (action_idx//self.env.n_amps + 1, action_idx%self.env.n_amps + 1)
        return action
    
    def step(self):
        self.action = self.get_action(self.state)
        self.s_next, self.reward,self.done = self.env.step(self.action)
        return self.s_next, self.reward, self.done
    
    def update(self):
        self.n[self.state, self.action] += 1
        self.Q[self.state, self.action] += self.alpha*(self.reward + self.gamma*np.max(self.Q[self.s_next]))
        self.policy[self.state] = np.argmax(self.Q[self.state])
        self.state = self.s_next
        return self.state, self.action
    
    def run(self, n_episodes=1000):
        self.reset()
        for i in range(n_episodes):
            # while not self.done:
            self.step()
            self.update()
        return self.policy
    
    def render(self):
        pass


In [8]:
# Stateless Eplison Greedy Agent
class StatelessEpsilonGreedyAgent:
    def __init__(self, env, epsilon=0.8, gamma=0.9, alpha=0.1):
        self.env = env
        self.epsilon = epsilon
        self.gamma = gamma
        self.alpha = alpha
        self.Q = np.zeros((env.n_elecs*env.n_amps), dtype=np.float32)
        self.n = np.zeros((env.n_elecs*env.n_amps), dtype=np.uint16)
        self.reset()

    def reset(self):
        self.action = self.get_action()
        return self.action
    
    def get_action(self):
        if np.random.random() < self.epsilon:
            action_idx = np.random.randint(0, self.env.n_elecs*self.env.n_amps)
        else:
            action_idx = np.argmax(self.Q)
        return action_idx

    def map_action(self, action_idx):
        # convert action_idx to (elec, amp)
        action = (action_idx//self.env.n_amps + 1, action_idx%self.env.n_amps + 1)
        return action
    
    def step(self):
        self.action = self.get_action()
        _, self.reward,self.done = self.env.step(self.map_action(self.action))
    
    def update(self):
        self.n[self.action] += 1
        self.Q[self.action] = self.reward + self.gamma*np.max(self.Q[self.action])
    
    def run(self, n_episodes=1000):
        self.reset()
        for i in range(n_episodes):
            # while not self.done:
            self.step()
            self.update()
        return self.Q
    
    def render(self):
        pass


In [25]:
experiments = ["2022-11-04-2", "2022-11-28-1"]
path = f"./data/{experiments[0]}/dictionary"

In [26]:
sim = FullStateSimulationEnv(path, reward_func=inverse_reward_func, score_func=span_score_func, n_maxstep=100000, n_elecs=512, n_amps=42)

In [27]:
agent = StatelessEpsilonGreedyAgent(sim, epsilon=0.7, gamma=0.9, alpha=0.1)
Q_value = agent.run(n_episodes=50000)

In [35]:
def display_non_zero(array):
    # array is a 2D numpy array
    non_zero_entries = np.nonzero(array)
    print("Number of non-zero entries: {}".format(len(non_zero_entries[0])))

    # sort non-zero entries by value
    non_zero_entries = zip(*sorted(zip(*non_zero_entries), key=lambda x: array[x[0]][x[1]], reverse=True))
    
    print("Non-zero entries:")
    for i, j in zip(*non_zero_entries):
        print("Row: {}, Column: {}, Value: {}".format(i, j, array[i][j]))

In [36]:
display_non_zero(Q_value.reshape(-1,1))

Number of non-zero entries: 1922
Non-zero entries:
Row: 14228, Column: 0, Value: 9.073711395263672
Row: 3994, Column: 0, Value: 4.095099925994873
Row: 4772, Column: 0, Value: 4.095099925994873
Row: 4973, Column: 0, Value: 4.095099925994873
Row: 5175, Column: 0, Value: 4.095099925994873
Row: 12078, Column: 0, Value: 4.095099925994873
Row: 13258, Column: 0, Value: 4.095099925994873
Row: 14459, Column: 0, Value: 4.095099925994873
Row: 14635, Column: 0, Value: 4.095099925994873
Row: 16067, Column: 0, Value: 4.095099925994873
Row: 16564, Column: 0, Value: 4.095099925994873
Row: 12073, Column: 0, Value: 3.940345048904419
Row: 2112, Column: 0, Value: 3.883690118789673
Row: 1442, Column: 0, Value: 3.6451001167297363
Row: 3210, Column: 0, Value: 3.595100164413452
Row: 273, Column: 0, Value: 3.439000129699707
Row: 1310, Column: 0, Value: 3.439000129699707
Row: 2950, Column: 0, Value: 3.439000129699707
Row: 4337, Column: 0, Value: 3.439000129699707
Row: 4338, Column: 0, Value: 3.439000129699707
R

(1922,)

In [51]:
action_idx = 3995

In [43]:
def action2elec_amp(action_idx, n_amps):
    # convert action_idx to (elec, amp)
    (elec, amp) = (action_idx//n_amps + 1, action_idx%n_amps + 1)
    return (elec, amp)

In [44]:
def load_dictionary(path):
    # Load relevant data from .npz files
    try:
        with np.load(os.path.join(path,"dictionary.npz")) as data:
            dict = data["dictionary"]
            elecs = data["entry_elecs"]
            amps = data["entry_amps"]
            elec_map = data["elec_map"]
        with np.load(os.path.join(path,"decoders.npz")) as data:
            cell_ids = data["cell_ids"]
    except FileNotFoundError:
        print("Please make sure the dictionary.npz and decoders.npz files are present in the specified path")

    return dict, elecs, amps, elec_map, cell_ids

In [46]:
def get_dict_from_action_idx(action_idx, n_amps, path):
    # convert action_idx to (elec, amp)
    (elec, amp) = action2elec_amp(action_idx, n_amps)
    
    dict, elecs, amps, elec_map, cell_ids = load_dictionary(path)

    # get dictionary entry
    try:
        idx = np.where((elecs == elec) & (amps == amp))[0][0]
        dict_entry = dict[idx]
    except IndexError:
        print(f"Electrode {elec} with amplitude {amp} was not in the dictionary")
        print(f"Assume no cells were activated")
        dict_entry = np.zeros(len(cell_ids), dtype=np.float64)

    return dict_entry

In [52]:
get_dict_from_action_idx(action_idx, sim.n_amps, path)

array([0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       1.  , 0.  , 0.  , 0.  , 0.  , 0.05, 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.05, 0.  ,
       0.  , 0.05, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.05, 0.  , 0.  ,
       0.05, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  ])

In [59]:
new_dict = []
for a in np.nonzero(Q_value)[0]:
    new_dict.append(get_dict_from_action_idx(a, sim.n_amps, path))
new_dict = np.array(new_dict)

In [60]:
new_dict.shape

(1922, 147)

In [61]:
span_score_func(new_dict)

109

In [64]:
dict = load_dictionary(path)[0]

In [65]:
dict.shape

(4702, 147)

In [76]:
def mag_cosine_sim_score_func(vectors):
    score = 0.0
    _, dim = vectors.shape
    # calculate the magnitude of each vector
    mag = np.linalg.norm(vectors, axis=1)

    # calculate the magnitude difference between each vector and unit vector
    abs_mag_diff = np.abs(mag - 1)

    # calculate the cosine similarity between each vector and unit vector
    for d in range(dim):
        unit_vec = np.zeros(dim)
        unit_vec[d] = 1
        cos_sim = np.dot(vectors, unit_vec) / mag
        score += np.max(cos_sim)-abs_mag_diff[np.argmax(cos_sim)]

    return score

In [79]:
vectors = np.array([[1,0,0],[0,0.1,0],[0,0,1]])
score = mag_cosine_sim_score_func(vectors)
print(score)

2.1


In [67]:
hoge = np.zeros(5)
hoge[0]=1
np.linalg.norm(hoge)

1.0

In [69]:
np.linalg.norm(np.array([0.1,0.05,0,0,0,1]))

1.0062305898749053