In [None]:
import gymnasium as gym
import pygame
import numpy as np
from PIL import Image

import snntorch as snn
from snntorch import spikegen, surrogate
from snntorch import utils as snnutils

import torch
from torch import nn
from torch import utils as torchtils
from torch.utils.tensorboard import SummaryWriter

from typing import TypeVar, Union, List, Callable, Any, Deque
from torchtyping import TensorType

from itertools import count, product
from collections import deque

from tqdm.auto import tqdm

import logging
from time import strftime, time, localtime

import argparse

# Create every new tensor on the GPU when one is visible, otherwise on the CPU.
torch.set_default_device('cuda' if torch.cuda.is_available() else 'cpu')

# Only declared here; the actual SummaryWriter is created once the run options are known.
writer: SummaryWriter

# Fixed seed so weight sampling and stochastic spike encoding are repeatable.
torch.manual_seed(42)

# Monotonic source of global micro-step indices (consumed by Brain.step).
counter = count()

In [None]:
# Month-day-hour-minute stamp taken when this cell runs.
date = strftime("%m-%d-%H-%M", localtime())

# Route INFO-and-above records of the root logger into last.log.
logging.basicConfig(
    filename="last.log",
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
)
logger = logging.getLogger()

In [None]:
# Optional on-screen rendering through pygame; currently disabled.
# pygame.init()
# DISPLAYURF = pygame.display.set_mode((500,500),0,32)
# clock = pygame.time.Clock()
# pygame.display.flip()

# env = gym.make('CartPole-v1', render_mode="rgb_array")
# CartPole environment without a render mode; test() steps it and the final cell closes it.
env = gym.make('CartPole-v1')
# Bare expression so the notebook shows the observation space as this cell's output.
env.observation_space

In [None]:
# Module-level mutable state shared with Brain.
# step[0] holds the most recent global micro-step index: Brain.step overwrites it on
# every call and uses it as the x-axis value for its TensorBoard scalars.
step = [0]
# Every encoded observation (spike train) produced by Brain.observe2action, in call order.
obs_history = []

In [None]:
T_EligibleLeakyLinear = TypeVar("T_EligibleLeakyLinear", bound="EligibleLeakyLinear")
class EligibleLeakyLinear(nn.Module):
    """Dense synaptic weights feeding a layer of Lapicque leaky integrate-and-fire neurons.

    The weight matrix ``linear`` is a plain tensor (not an ``nn.Parameter``): it is
    meant to be modified in place by an external learning rule that reads the two
    eligibility tensors ``f_eligibility`` and ``b_eligibility`` stored on this object.

    Args:
        in_features: number of presynaptic units.
        out_features: number of postsynaptic neurons.
        R, C, time_step, threshold: forwarded to ``snn.Lapicque``.
        spike_grad: surrogate gradient handed to ``snn.Lapicque``; defaults to
            ``surrogate.fast_sigmoid()``. It only matters for backpropagation.
    """
    def __init__(self, in_features:int, out_features:int,
                 R=5, C=1e-2, time_step=1e-3, threshold=1.0, spike_grad=None):
        super().__init__()
        if spike_grad is None:
            spike_grad = surrogate.fast_sigmoid()

        self.time_step = time_step
        self.R, self.C = R, C
        self.in_features, self.out_features = in_features, out_features
        # Upper bound used to scale the initial weights; the learning rule also reads it.
        self.linear_max = 5

        # Initial weights are Beta(2, 2) samples stretched to (0, linear_max).
        weight_prior = torch.distributions.Beta(2, 2)
        self.linear:TensorType[...] = weight_prior.sample((out_features, in_features)) * self.linear_max
        # spike_grad is now actually forwarded; previously it was resolved above and then dropped.
        self.leaky = snn.Lapicque(R=R, C=C, time_step=time_step, threshold=threshold,
                                  spike_grad=spike_grad, reset_mechanism="zero")
        self.reset()

    @property
    def beta(self):
        """Return ``time_step / R / C``."""
        return self.time_step/self.R/self.C

    def forward(self, presynaptic_spk:TensorType[...]):
        """Advance the neurons one step and return their output spikes.

        The input current is ``linear @ presynaptic_spk``; the membrane state is
        kept on ``self.mem`` between calls.
        """
        spk, self.mem = self.leaky(self.linear @ presynaptic_spk, self.mem)
        return spk

    def reset(self):
        """Zero both eligibility tensors and the membrane state; weights are left untouched."""
        self.f_eligibility = torch.zeros_like(self.linear)
        self.b_eligibility = torch.zeros_like(self.linear)
        # new_zeros keeps the weights' dtype/device; torch.full(shape, 0) would give an int64 tensor.
        self.mem = self.linear.new_zeros(self.linear.shape[0])
    
    

In [None]:
T_Brain = TypeVar("T_Brain", bound="Brain")
class Brain:
    """Spiking controller trained with dopamine-gated eligibility traces.

    A stack of ``EligibleLeakyLinear`` layers is simulated at a fixed resolution of
    ``MIN_LIF_TIMESTEP``. Each observation is rate-coded into ``max_steps`` spike
    vectors that are fed through the stack one micro-step at a time. While doing so,
    a forward and a backward eligibility tensor are accumulated per layer; a dopamine
    impulse (the clamped reward) turns them into in-place weight changes.

    The class relies on module-level objects defined elsewhere in the notebook:
    ``writer`` (TensorBoard), ``logger``, ``counter``, ``step`` and ``obs_history``.

    Args:
        lr: learning rate of the weight update.
        R, C, threshold, spike_grad: forwarded to every ``EligibleLeakyLinear``.
        time_step: duration represented by one observation; together with
            ``MIN_LIF_TIMESTEP`` it fixes ``num_steps``.
        feature_list: layer widths, input first and output last.
        elg_coeff: exponent applied in ``beta`` (the eligibility decay).

    Raises:
        ValueError: if ``time_step`` is too small to give at least one micro-step.
    """
    # Simulation resolution of the LIF layers.
    MIN_LIF_TIMESTEP = 1e-3
    # Number of past spike snapshots kept for the eligibility traces.
    SPIKE_HISTORY_LEN = 7
    # Per-snapshot decay of the pre-/post-synaptic traces built from the history.
    F_TRACE_DECAY = 0.5
    B_TRACE_DECAY = 0.75

    def __init__(self, lr=1e-4, R:float=5, C:float=1e-2,
                 time_step:float=1e-3, threshold:float=1.0, spike_grad:Union[None, Callable]=None,
                 feature_list:List=[4, 128, 2], elg_coeff=5e-1):
        if spike_grad is None:
            spike_grad = surrogate.fast_sigmoid()

        self.lr, self.threshold, self.spike_grad, self.elg_coeff = \
            lr, threshold, spike_grad, elg_coeff
        # Copy so the shared default list (or a caller's list) is never aliased.
        self.feature_list = list(feature_list)
        self.R, self.C, self.time_step = R, C, time_step
        if self.num_steps < 1:
            # Without this check ``beta`` would fail with an opaque ZeroDivisionError.
            raise ValueError(
                f"time_step={time_step} is below the LIF resolution "
                f"MIN_LIF_TIMESTEP={Brain.MIN_LIF_TIMESTEP}"
            )
        assert self.beta < 1
        self.LTP_coeff, self.LTD_coeff = 1., 1.5

        self.lif_linears:List[T_EligibleLeakyLinear] = []
        for in_features, out_features in zip(self.feature_list[:-1], self.feature_list[1:]):
            self.lif_linears.append(
                EligibleLeakyLinear(in_features, out_features,
                                    R=R,    C=C,    time_step=Brain.MIN_LIF_TIMESTEP,
                                    threshold=threshold, spike_grad=spike_grad)
                )
        self.reset()

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Alias for ``observe2action``."""
        return self.observe2action(*args, **kwds)

    @property
    def num_steps(self):
        """Number of LIF micro-steps per observation: ``time_step / MIN_LIF_TIMESTEP``, rounded."""
        # round() instead of int(): a float quotient can land just below the intended
        # integer, and truncation would then silently drop one step.
        return int(round(self.time_step/Brain.MIN_LIF_TIMESTEP))

    @property
    def beta(self):
        """Per-micro-step decay factor applied to both eligibility tensors."""
        return (Brain.MIN_LIF_TIMESTEP/self.R/self.C/self.num_steps)**self.elg_coeff

    def step(self, x:TensorType["observation_space"], dopamine_impulse:float=0., step_idx:int=0) -> TensorType["action_space"]:
        """Run one micro-step: propagate spikes, update eligibility, apply dopamine.

        Args:
            x: input spike vector for this micro-step.
            dopamine_impulse: scalar that gates the weight update; 0 means no update.
            step_idx: index of this micro-step within the current observation.

        Returns:
            The output-layer spikes produced during this micro-step.
        """
        assert len(self.lif_linears) + 1 == len(self.spk_history[0])

        # Publish the global micro-step index; it is the x-axis of every scalar below.
        step[0] = next(counter)
        global_step = step[0]

        # Each layer consumes the spikes its predecessor emitted on the previous
        # micro-step, so an input needs one micro-step per layer to reach the output.
        next_spikes:List[TensorType[...]] = [x]
        for i, lif_linear in enumerate(self.lif_linears):
            next_spikes.append(lif_linear(self.spk_history[-1][i]))

        decay = self.beta
        for l, lif_linear in enumerate(self.lif_linears):
            # Exponentially weighted sums of the recent pre- and post-synaptic spikes
            # (oldest snapshot first, so the newest has the largest weight).
            f_eligible_presynaptic = torch.zeros_like(self.spk_history[0][l])
            b_eligible_postsynaptic = torch.zeros_like(self.spk_history[0][l+1])
            for prev_spks in self.spk_history:
                f_eligible_presynaptic = f_eligible_presynaptic * Brain.F_TRACE_DECAY + prev_spks[l]
                b_eligible_postsynaptic = b_eligible_postsynaptic * Brain.B_TRACE_DECAY + prev_spks[l+1]
            # forward: past presynaptic activity followed by a postsynaptic spike now;
            # backward: past postsynaptic activity followed by a presynaptic spike now.
            forward_elg = torch.outer(next_spikes[l+1], f_eligible_presynaptic)
            backward_elg = torch.outer(b_eligible_postsynaptic, next_spikes[l])
            lif_linear.f_eligibility = lif_linear.f_eligibility * decay + forward_elg
            lif_linear.b_eligibility = lif_linear.b_eligibility * decay + backward_elg

        writer.add_scalar("Elg[0][0,0]", self.lif_linears[0].f_eligibility[0,0], global_step)
        writer.add_scalar("Linear[0][0,0]", self.lif_linears[0].linear[0,0], global_step)
        writer.add_scalar("Linear[0]Mem[0]", self.lif_linears[0].mem[0], global_step)

        # A zero impulse contributes a zero update, so skip the tensor arithmetic entirely.
        if dopamine_impulse != 0:
            for lif_linear in self.lif_linears:
                if dopamine_impulse > 0:
                    # Soft-bounded potentiation: the step shrinks as a weight nears linear_max.
                    lif_linear.linear += self.lr * dopamine_impulse * lif_linear.f_eligibility\
                                                * (1 - lif_linear.linear/lif_linear.linear_max)
                else:
                    lif_linear.linear += self.lr * dopamine_impulse * lif_linear.b_eligibility\
                                                * (- lif_linear.linear/lif_linear.linear_max)

        logger.info("%s", next_spikes[-1])
        # Only the tail of the observation window votes for the action.
        if step_idx + 1 >= self.num_steps:
            self.spike_avarage += next_spikes[-1]
        writer.add_scalar("Num-of-Spk", sum(spk.sum() for spk in next_spikes[1:]), global_step)
        self.spk_history.append(next_spikes)
        return next_spikes[-1]

    @property
    def max_steps(self):
        """Micro-steps simulated per observation: ``num_steps`` plus one per entry of ``feature_list``."""
        return self.num_steps + len(self.feature_list)

    def encode(self, x:np.ndarray) -> TensorType["num_steps", "observation_space"]:
        """Rate-code a 4-element observation into ``max_steps`` spike vectors.

        Elements 0 and 2 are mapped linearly, elements 1 and 3 through a sigmoid;
        the results are used as per-step firing probabilities by ``spikegen.rate``.
        """
        assert len(x) == 4
        position, velocity, angle, angular_velocity = x
        # 9.6 and .836 presumably are the full widths of Gymnasium's documented CartPole
        # position (+-4.8) and angle (+-0.418 rad) ranges -- TODO confirm.
        prob_x = torch.tensor(
            (
                position / 9.6 + .5,
                torch.sigmoid(torch.tensor(velocity)).item(),
                angle/.836 + .5,
                torch.sigmoid(torch.tensor(angular_velocity)).item(),
            ),
            dtype=torch.float32,
        )
        return spikegen.rate(prob_x, num_steps=self.max_steps)

    def get_action(self) -> TensorType[int]:
        """Return the index of the output neuron with the largest accumulated spike count."""
        return self.spike_avarage.argmax()

    def observe2action(self, x:np.ndarray, dopamine_impulse:float) -> int:
        """Present one observation (and the reward for the previous action); return an action.

        The reward is clamped by ``remap_dopamine_impulse`` and delivered only on the
        first micro-step; all remaining micro-steps run with a zero impulse.
        """
        encoded_x = self.encode(x)
        obs_history.append(encoded_x)
        self.step(encoded_x[0], Brain.remap_dopamine_impulse(dopamine_impulse), 0)
        # Named step_idx (not step) so the module-level ``step`` list is not shadowed.
        for step_idx in range(1, encoded_x.shape[0]):
            self.step(encoded_x[step_idx], step_idx=step_idx)
        return self.get_action().item()

    def reset(self):
        """Clear per-episode state (spike history, action votes, traces, membranes); keep weights."""
        self.spk_history:Deque[List[TensorType[...]]] = deque(
            [[torch.zeros(n_features) for n_features in self.feature_list]],
            maxlen=Brain.SPIKE_HISTORY_LEN,
        )
        self.spike_avarage = torch.zeros(self.feature_list[-1])
        self.dopamine = 0.
        for lif_linear in self.lif_linears:
            lif_linear.reset()
        writer.add_scalar("F-Norm", self.lif_linears[0].linear.norm(), step[-1])

    @staticmethod
    def remap_dopamine_impulse(x):
        """Clamp ``x`` to [0, 1]; negative inputs therefore become 0."""
        return min(max(x, 0), 1)

    @staticmethod
    def inv_softplus(x):
        """Inverse of softplus, ``log(exp(x) - 1)``, for a tensor ``x``."""
        # Algebraically identical to log(exp(x) - 1), but exp(x) is never formed,
        # so large x no longer overflows to inf.
        return x + torch.log(-torch.expm1(-x))

In [None]:
def test(alpha = 0.99, **kwargs):
    """Train a fresh ``Brain`` on the module-level CartPole ``env`` and report progress.

    Runs 1000 episodes of at most 100 environment steps each. The brain receives
    the reward of its previous action together with the next observation, so after
    the episode ends it is called once more to deliver the final reward.

    Args:
        alpha: weight kept on the previous value of the running mean of episode
            length (``t_mean = alpha*t_mean + (1-alpha)*t``).
        **kwargs: forwarded unchanged to ``Brain``.

    Returns:
        The running mean of the episode length after the last episode.
    """
    brain = Brain(**kwargs)
    max_t = -float("inf")
    for i_episode in (pbar:=tqdm(range(1000))):
        observation, _ = env.reset()
        brain.reset()
        reward, done = 0, False
        accm_reward = 0
        for t in range(100):
            action = brain(observation, reward)
            if done:
                pbar.set_description(
                    f"time_step {kwargs.get('time_step', None)} Episode {i_episode} finished after {t} time_steps, with reward {accm_reward}."
                )
                break

            # Gymnasium returns (obs, reward, terminated, truncated, info); the
            # environment's own reward is replaced by the shaped one below.
            observation, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Small bonus per surviving step, penalty only for a real failure.
            reward = -1 if terminated else .1
            accm_reward += reward

        max_t = max(max_t, t)
        if i_episode==0:
            t_mean = t
        else:
            t_mean = alpha*t_mean + (1-alpha)*t
        writer.add_histogram("Step", t, i_episode)
        writer.add_scalar("Max-Step", max_t, i_episode)
        writer.add_scalar("Episode-Reward", accm_reward, i_episode)
    logger.info(f"test about {kwargs}: mean {t_mean}")

    return t_mean
        

In [None]:
# Command-line interface: two mandatory float hyper-parameters and an optional smoothing factor.
parser = argparse.ArgumentParser()
parser.add_argument("--time_step", type=float, required=True)
parser.add_argument("--elg_coeff", type=float, required=True)
parser.add_argument("--alpha", type=float, default=1e-4)

In [None]:
# Take the hyper-parameters from the command line when available.
try:
    options = vars(parser.parse_args())
except SystemExit:
    # argparse reports unusable arguments by raising SystemExit (for example when the
    # process argv does not carry the required flags, as in a notebook kernel).
    # Catch exactly that so Ctrl-C and genuine errors still propagate.
    options = {"time_step":1e-1, "alpha":1e-2, "lr":1e-1, "elg_coeff":1e-2}
# The options are embedded in the run name so TensorBoard runs can be told apart.
writer = SummaryWriter(max_queue=100000, comment=str(options))
print(options)
logger.info(options)
test(**options)
    # options = {"time_step":[*np.arange(1e-5, 13e-5, 1e-5)], "alpha":[1e-5, 3e-5]}
    # options = {"time_step":[*np.arange(1e-5, 13e-5, 1e-5)], "alpha":[1e-4, 3e-4]}
    # options = {"time_step":[*np.arange(1e-5, 13e-5, 1e-5)], "alpha":[1e-3, 3e-3]}
    # options = {"time_step":[*np.arange(1e-5, 13e-5, 1e-5)], "alpha":[1e-2, 3e-2]}
    # time_steps = [*np.arange(8e-5,13e-5,1e-5)]

    # options = (dict(zip(options.keys(), x)) for x in product(*options.values()))
    # for kwargs in options:
    #     test(**kwargs)

In [None]:
# Close the TensorBoard writer, then the environment.
writer.close()
env.close()
# pygame.init() is commented out earlier in this notebook, so this presumably has
# nothing to shut down -- TODO confirm it can be dropped together with the import.
pygame.quit()

In [None]:
# from snntorch import spikeplot as splt
# from matplotlib import pyplot as plt

In [None]:
# fig = plt.figure()
# fig.set_size_inches(15, 4)
# ax = fig.add_subplot(111)
# total_history = torch.concat(obs_history[:4], dim=0)
# splt.raster(total_history, ax, s=.5, alpha=1)
# ax.plot()
# fig.show()