In [None]:
from typing import Protocol
from contextlib import contextmanager
from functools import wraps

class Profiler:
    @contextmanager
    def profile(self, name: str):
        yield
    
    @property
    def results(self):
        return {}


def profile(_func=None, *, name=None):
    def _decorator(func):
        if name is None:
            name = func.__qualname__
        
        @wraps(func)
        def _wrapper(self, *args, **kwargs):
            with self.profiler.profile(name=name):
                return func(*args, **kwargs)
        
        return _wrapper
    
    return _decorator if _func is None else _decorator(_func)

In [None]:
from torch import Tensor
import torch.distributions as D
import torch.distributions.constraints as constraints

class _OneOf(constraints.Constraint):
    is_discrete = True
    
    def __init__(self, value_set: Tensor):
        super().__init__()
        self.value_set = value_set
    
    def check(self, value):
        return (value in self.value_set)

one_of = _OneOf

class Particle(D.Distribution):
    def __init__(self, supp: Tensor, index: D.Distribution):
        super().__init__(batch_shape=index.batch_shape, event_shape=index.event_shape)
        
        self.index = index
        self.supp = supp
    
    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        return one_of(self.supp)
    
    

In [2]:
import torch

x = torch.empty(1, 3)
x.broadcast_to(2, *x.shape[1:])

tensor([[ 6.2937e+35,  4.5731e-41, -3.2200e+17],
        [ 6.2937e+35,  4.5731e-41, -3.2200e+17]])

In [1]:
from rsrch.rl.agent import RandomAgent
from rsrch.rl.wrappers import CollectEpisodes
from rsrch.rl.data.buffer import EpisodeBuffer
from rsrch.rl.data.transforms import *
from rsrch.rl.data.online import StepRollout
import gymnasium as gym

env = gym.make("CartPole-v1")
buf = EpisodeBuffer(int(1e3))
env = CollectEpisodes(env, buf)

ds = buf.map(Compose(
    Subsample(max_seq_len=32),
    ToTensorSeq(),
    PadTensorSeq(min_seq_len=32)
))

agent = RandomAgent(env)
play = StepRollout(env, agent, num_steps=int(1e4))

for _ in play:
    ...

In [13]:
ds[5]

TensorTrajectory(obs=tensor([[-7.4647e-03, -6.6075e-03, -1.7085e-02,  3.4957e-02],
        [-7.5968e-03, -2.0148e-01, -1.6386e-02,  3.2220e-01],
        [-1.1626e-02, -3.9637e-01, -9.9419e-03,  6.0967e-01],
        [-1.9554e-02, -2.0111e-01,  2.2515e-03,  3.1387e-01],
        [-2.3576e-02, -6.0158e-03,  8.5290e-03,  2.1902e-02],
        [-2.3696e-02, -2.0126e-01,  8.9670e-03,  3.1726e-01],
        [-2.7721e-02, -6.2659e-03,  1.5312e-02,  2.7422e-02],
        [-2.7847e-02, -2.0160e-01,  1.5861e-02,  3.2490e-01],
        [-3.1879e-02, -6.7115e-03,  2.2359e-02,  3.7257e-02],
        [-3.2013e-02,  1.8808e-01,  2.3104e-02, -2.4829e-01],
        [-2.8251e-02, -7.3614e-03,  1.8138e-02,  5.1591e-02],
        [-2.8399e-02,  1.8750e-01,  1.9170e-02, -2.3531e-01],
        [-2.4649e-02, -7.8947e-03,  1.4464e-02,  6.3353e-02],
        [-2.4807e-02,  1.8702e-01,  1.5731e-02, -2.2473e-01],
        [-2.1066e-02,  3.8191e-01,  1.1236e-02, -5.1241e-01],
        [-1.3428e-02,  1.8663e-01,  9.8776e-04, -

In [1]:
import torch

logits = torch.ones(10) * (-torch.inf)
logits[3] = 1.0

Ix = torch.distributions.Ca

tensor([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf])

In [41]:
import numpy as np

class SegmentArray:
    def __init__(self, min_size: int, dtype=np.float32):
        exp = int(np.ceil(np.log2(min_size)))
        self.size = 2 ** exp
        self.array = np.zeros(self.size, dtype=dtype)
        self.tree = np.zeros(self.size, dtype=dtype)
    
    def __getitem__(self, idx):
        return self.array[idx]
    
    def __setitem__(self, idx: np.ndarray, val: np.ndarray):
        if isinstance(idx, slice):
            idx = np.arange(idx.start, idx.stop, idx.step)
        else:
            idx = np.asarray(idx)
        val = np.asarray(val)
        
        self.array[idx] = val
        self._update_tree(idx, val, 0, self.buf[0], 0, self.size)
    
    def _update_tree(self, idxes, values, root, root_v, root_beg, root_end):
        print(locals())
        is_leaf = (root >= self.size - 1)
        holds_v = root == 0 or root % 2 != 0
        
        if is_leaf:
            if holds_v:
                self.tree[(root + 1) // 2] = values[0]
            return values[0] - root_v
        else:
            left, right = 2 * root + 1, 2 * root + 2
            left_v = self.tree[(left + 1) // 2]
            right_v = root_v - left_v
            pivot = (root_end + root_beg) // 2
            pivot_idx = np.searchsorted(idxes, pivot)
            
            if pivot_idx > 0:
                left_dv = self._update_tree(idxes[:pivot_idx], values[:pivot_idx],
                                            left, left_v, root_beg, pivot)
            else:
                left_dv = 0.0
                
            if pivot_idx < len(idxes):
                right_dv = self._update_tree(idxes[pivot_idx:], values[pivot_idx:],
                                            right, right_v, pivot, root_end)
            else:
                right_dv = 0.0
            
            dv = left_dv + right_dv
            if holds_v:
                self.tree[(root + 1) // 2] = root_v + dv
            
            return dv
    
    def _refresh(self):
        full_tree = np.empty(2 * self.size - 1, dtype=self.array.dtype)
        full_tree[self.size-1:] = self.array
        
        for root in reversed(range(2 * self.size - 1)):
            
        

In [46]:
x = np.array([0, 3, -1, 2])
np.cumsum(x)

array([0, 3, 2, 4])

In [43]:
s = SegmentArray(8)
s.update([0, 3, 4, 7], [0.5, -0.2, 0.4, 0.1])
s.numpy()

{'self': <__main__.SegmentArray object at 0x7fce10d85840>, 'idxes': array([0, 3, 4, 7]), 'values': array([ 0.5, -0.2,  0.4,  0.1]), 'root': 0, 'root_v': 0.0, 'root_beg': 0, 'root_end': 8}
{'self': <__main__.SegmentArray object at 0x7fce10d85840>, 'idxes': array([0, 3]), 'values': array([ 0.5, -0.2]), 'root': 1, 'root_v': 0.0, 'root_beg': 0, 'root_end': 4}
{'self': <__main__.SegmentArray object at 0x7fce10d85840>, 'idxes': array([0]), 'values': array([0.5]), 'root': 3, 'root_v': 0.0, 'root_beg': 0, 'root_end': 2}
{'self': <__main__.SegmentArray object at 0x7fce10d85840>, 'idxes': array([0]), 'values': array([0.5]), 'root': 7, 'root_v': 0.0, 'root_beg': 0, 'root_end': 1}
{'self': <__main__.SegmentArray object at 0x7fce10d85840>, 'idxes': array([3]), 'values': array([-0.2]), 'root': 4, 'root_v': 0.0, 'root_beg': 2, 'root_end': 4}
{'self': <__main__.SegmentArray object at 0x7fce10d85840>, 'idxes': array([3]), 'values': array([-0.2]), 'root': 10, 'root_v': 0.0, 'root_beg': 3, 'root_end': 4}

array([ 0.5       ,  0.        ,  0.        , -0.19999999,  0.4       ,
        0.        ,  0.        ,  0.09999999], dtype=float32)

In [63]:
class ProportionalSampler:
    def __init__(self, size: int, dtype=None, eps=1e-8):
        self.size = size
        self.eps = eps
        self._levels = int(np.ceil(np.log2(self.size)))
        self._cap = 2 ** self._levels
        self._off = self._cap - 1
        self.buf = np.zeros(2 * self._cap - 1, dtype=dtype)
    
    def __getitem__(self, idx):
        idx = self._resolve_idx(idx)
        return self.buf[idx]

    def _numpy(self):
        return self.buf[self._off:self._off+self.size]
    
    def __setitem__(self, idx, val):
        idx = self._resolve_idx(idx)
        for i, v in zip(idx, val):
            self.buf[i] = np.abs(v) + self.eps
            while i > 0:
                i = (i - 1) // 2
                self.buf[i] = self.buf[2*i+1] + self.buf[2*i+2]

    def _resolve_idx(self, idx):
        if isinstance(idx, slice):
            idx = range(self.size)[idx]
            idx = range(idx.start + self._off, idx.stop + self._off, idx.step)
        else:
            idx = np.asarray(idx) + self._off
        return idx
    
    def _sample(self, u: float):
        node = 0
        while node < self._off:
            left, right = 2*node+1, 2*node+2
            left_v = self.buf[left]
            if left_v <= u:
                node, u = left, u
            else:
                node, u = right, u - left_v
        
        idx = node - self._off
        prob = u / self.buf[0]
        return idx, prob
        
    def sample(self):
        u = np.random.rand() * self.buf[0]
        return self._sample(u)

    def stratified_sample(self, k: int):
        u_batch = (np.arange(k) + np.random.rand(k)) / k * self.buf[0]
        return [self._sample(u) for u in u_batch]