In [1]:
import gymnasium as gym
import pygame
import numpy as np
from PIL import Image

import snntorch as snn
from snntorch import spikegen, surrogate
from snntorch import utils as snnutils

import torch
from torch import nn
from torch import utils as torchtils

from typing import TypeVar, Union, List, Callable
from torchtyping import TensorType

from itertools import count, product
from collections import deque

from tqdm.auto import tqdm

torch.set_default_device(['cpu', 'cuda'][torch.cuda.is_available()])

pygame.init()
DISPLAYURF = pygame.display.set_mode((500,500),0,32)
clock = pygame.time.Clock()
pygame.display.flip()

env = gym.make('CartPole-v1', render_mode="rgb_array")

In [2]:
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [3]:
T_EligibleLeakyLinear = TypeVar("T_EligibleLeakyLinear", bound="EligibleLeakyLinear")
class EligibleLeakyLinear(nn.Module):
    def __init__(self, in_features:int, out_features:int,
                 R=5, C=1e-3, time_step=1e-4, threshold=1.0, spike_grad=None,
                 decay_rate:float=1.):
        super().__init__()
        if spike_grad == None:
            spike_grad = surrogate.fast_sigmoid()
        
        self.decay_rate = decay_rate
        self.beta = 1 - time_step/(R*C)
        
        # self.linear = nn.Linear(in_features, out_features, bias=False) # weight.shape == output_features by in_features
        # self.linear:TensorType[...] = (torch.rand((out_features, in_features))+1)/(2*(in_features**.5))
        self.linear:TensorType[...] = (torch.rand((out_features, in_features))+2)/(2)
        # self.leaky = snn.Leaky(beta, threshold, spike_grad, init_hidden=True, output=True)
        # self.leaky = snn.Leaky(beta, threshold, surrogate_disable=True, output=True)
        self.leaky = snn.Lapicque(R=R, C=C, time_step=time_step, threshold=threshold)
        # self.eligibility = torch.zeros_like((self.linear.state_dict()["weight"]))
        self.reset()
    
    def forward(self, presynaptic_spk:TensorType[...]):
        spk, self.mem = self.leaky(self.linear @ presynaptic_spk, self.mem)
        # print(self.mem)
        return spk

    def reset(self):
        self.eligibility = torch.zeros_like(self.linear)
        self.mem = torch.rand(self.linear.shape[0])
    
    

In [4]:
T_Brain = TypeVar("T_Brain", bound="Brain")
class Brain:
    @classmethod
    def remap_dopamine_impulse(self, x):
        return 2*np.arctan(x)/np.pi
    
    def __init__(self, alpha=1e-5, R:float=5, C:float=1e-3, 
                 time_step:float=1e-4, threshold:float=1.0, spike_grad:Union[None, Callable]=None,
                 feature_list:List=[4,32,2], obs_duration:int=8, max_queue_size:Union[int,None]=None):
        if spike_grad == None:
            spike_grad = surrogate.fast_sigmoid()
            
        self.alpha, self.threshold, self.spike_grad, self.feature_list, self.obs_duration, self.max_queue_size = \
            alpha, threshold, spike_grad, feature_list, obs_duration, max_queue_size
        self.R, self.C, self.time_step = R, C, time_step
        self.LTP_coeff, self.LTD_coeff = 1, 1.5
        
        self.lif_linears:List[T_EligibleLeakyLinear] = []
        for in_features, out_features in zip(feature_list[:-1], feature_list[1:]):
            self.lif_linears.append(
                EligibleLeakyLinear(in_features, out_features,
                                    R=R,    C=C,    time_step=time_step,
                                    threshold=threshold, spike_grad=spike_grad)
                )
        
        self.reset()
    
    def get_beta(self): return 1-self.time_step/self.R/self.C
    
    def step(self, x:TensorType["observation_space"], dopamine_impulse:float=0., attention=False) -> TensorType["action_space"]:
        assert len(self.lif_linears) + 1 == len(self.last_spikes)
        
        next_spikes:List[TensorType[...]] = [x] # next_spike[0] will be provided in the next step.
        for i, lif_linear in enumerate(self.lif_linears):
            next_spikes.append(lif_linear(self.last_spikes[i]))
        
        for l, lif_linear in enumerate(self.lif_linears):
            lif_linear.eligibility = self.get_beta() * lif_linear.eligibility \
                + self.alpha * self.LTP_coeff * torch.einsum("i,j -> ij", next_spikes[l+1], self.last_spikes[l]) \
                - self.alpha * self.LTD_coeff * torch.einsum("i,j -> ij", self.last_spikes[l+1], next_spikes[l])
                # + self.LTP_coeff * torch.einsum("i,j -> ij", next_spikes[l+1], next_spikes[l])
        
        self.dopamine = self.dopamine * self.get_beta() + dopamine_impulse
        # print(f"DA: {self.dopamine}")
        for lif_linear in self.lif_linears:
            # print(f"delta weight : {(lif_linear.eligibility**2).mean() * self.dopamine}")
            lif_linear.linear = torch.sigmoid(
                torch.logit(lif_linear.linear * lif_linear.decay_rate, eps=1e-6)\
                    + lif_linear.eligibility * self.dopamine
                )
            # assert (lif_linear.linear<0).sum() == 0
            
        self.last_spikes = next_spikes
        
        assert next_spikes[-1].shape == (2,)
        self.spike_history.append(next_spikes[-1])
        return next_spikes[-1]
    
    def get_max_steps(self):
        return self.obs_duration + len(self.feature_list)
    
    def encode(self, x:np.ndarray) -> TensorType["num_steps", "observation_space"]:
        pos_prob = lambda pos: pos / 9.6 + .5
        v_prob = lambda v: np.arctan(v)/np.pi + .5
        angle_prob = lambda theta: theta/.836 + .5
        angle_v_prob = lambda omega: np.arctan(omega)/np.pi + .5 
        assert len(x) == 4
        prob_x = pos_prob(x[0]), v_prob(x[1]), angle_prob(x[2]), angle_v_prob(x[3])
        prob_x = torch.tensor(prob_x, dtype=torch.float32)
        return spikegen.rate(prob_x, num_steps=self.get_max_steps()) #last # of layers steps will not passed to output layer. 
    
    def get_action(self) -> TensorType[int]:
        action = torch.zeros((self.feature_list[-1]))
        while len(self.spike_history) != 0:
            # print(spike:=self.spike_history.popleft())
            action = action * (1-1/len(self.feature_list)) + self.spike_history.popleft()
        #     print(f"{action}", end=", ")
        # print()
        # if action[0]!=action[1]:
        #     print(action)
        return action.argmax()
    
    def observe2action(self, x:np.ndarray, dopamine_impulse:float) -> int:
        encoded_x = self.encode(x)
        # print(encoded_x.mean(dim=0))
        self.step(encoded_x[0], Brain.remap_dopamine_impulse(dopamine_impulse))
        for step in range(1, encoded_x.shape[0]):
            self.step(encoded_x[step])
        action = self.get_action().item()
        return action

    def reset(self):
        self.last_spikes:List[TensorType[...]] = [torch.zeros(n_features) for n_features in self.feature_list]
        self.spike_history = deque([],
            maxlen= self.obs_duration if self.max_queue_size==None else min(self.obs_duration, self.max_queue_size)
            )
        self.dopamine = 0.
        for lif_linear in self.lif_linears:
            lif_linear.reset()
        print("\r",(self.lif_linears[0].linear**2).sum(), end="")

In [5]:
def test(**kwargs):
    brain = Brain(**kwargs)
    max_t = -float("inf")
    for i_episode in (pbar:=tqdm(range(1000))):
        observation, _ = env.reset()
        reward, done = 0, False
        brain.reset()
        accm_reward = 0
        for t in range(100):
            image = env.render()
            # action = env.action_space.sample()
            action = brain.observe2action(observation, reward)
            if done:
                pbar.desc = f"Timestep {kwargs.get('time_step', None)} Episode {i_episode} finished after {t+1} timesteps, with reward {accm_reward}."
                break
            
            observation, reward, done, *info = env.step(action)
            reward = max_t < t - done
            accm_reward += reward
            
            # image = Image.fromarray(image, "RGB")
            # mode, size, data = image.mode, image.size, image.tobytes()
            # pygame.event.get()
            # image = pygame.image.fromstring(data, size, mode)
            
            # DISPLAYURF.blit(image, (0,0))
            ## print(f"Episode {i_episode}, Step {t}, Reward {reward}")
            # pygame.display.update()
            # clock.tick(100)
        max_t = max(max_t, t)
        if i_episode==0:
            t_mean = t
        else:
            t_mean = 0.99*t_mean + 0.05*t
    return t_mean
        

In [6]:
result = []
# time_steps = [*np.arange(0.01, 0.1, 0.01)]
options = {"time_step":[8e-5], "alpha":[10**(-x) for x in range(1,6)]}
# time_steps = [*np.arange(8e-5,13e-5,1e-5)]
options_prod = (dict(zip(options.keys(), x)) for x in product(*options.values()))
for kwargs in options_prod:
    result.append(test(**kwargs))

 tensor(203.9799, device='cuda:0')

  0%|          | 0/1000 [00:00<?, ?it/s]

 tensor(42.8917, device='cuda:0'))

In [None]:
env.close()
pygame.quit()

In [None]:
result

[13.157577955045467]