In [1]:
%load_ext autoreload
%autoreload 2

In [162]:
from common import (
    get_starting_df,
    find_sessions,
    add_regression_target,
    calc_y,
    add_fold,
)
import polars as pl
import gymnasium as gym
from IPython import display

In [10]:
# Build the working frame: load the raw data, split it into sessions, then
# attach the regression target.
# NOTE(review): 60 * 30 (= 1800) is presumably a 30-minute threshold expressed
# in seconds -- confirm against find_sessions in common.py.
df = add_regression_target(
    find_sessions(get_starting_df(), 60 * 30)
)

# Target derived from the prepared frame; reused below when assigning folds.
y = calc_y(df)

In [11]:
# Attach the fold assignment and persist the result so later cells can reload it.
# NOTE(review): the extension is spelled "parquest"; the reload cell below reads
# the same name, so rename both together (and any other consumer) if correcting it.
add_fold(df, y).write_parquet("with_fold.parquest")

In [12]:
# Drop the in-memory frame; the next cell rebinds `df` from the file on disk.
del df

In [13]:
# Reload the frame (now including the `fold` column) written by the cell above.
df = pl.read_parquet("with_fold.parquest")

In [15]:
# Sanity check: (rows, columns) of the reloaded frame.
df.shape

(99999, 9)

In [16]:
# Largest value in the `session` column.
df['session'].max()

2289

In [17]:
# Largest value in the `user` column.
df['user'].max()

943

In [18]:
# Number of rows assigned to each fold.
df['fold'].value_counts()

fold,count
i32,u32
3,25004
0,25023
1,24986
2,24986


In [19]:
# Balance of the target values produced by calc_y.
y.value_counts()

disengage,count
bool,u32
True,32180
False,67819


In [194]:
class CitizenScienceEnv(gym.Env):
    """Gymnasium environment that replays one recorded session per episode.

    Each episode walks through the rows of a single session in timestamp
    order. At every step the agent picks one of two actions:

    * ``0`` -- do not disengage: move on to the next row of the session.
    * ``1`` -- disengage: end the episode immediately.

    Rewards:

    * action ``1``: reward ``0``.
    * action ``0`` with rows remaining: the ``time_passed`` value (seconds
      since the session's first timestamp) of the row being left.
    * action ``0`` on the session's last row: reward ``-5000`` and the
      episode terminates.

    NOTE(review): the observation returned by ``_obs`` is the row's absolute
    timestamp as epoch seconds (a float such as 1420354030.0), which does not
    fit the declared ``Box(0, 6000, dtype=int)``. The bounds look like they
    were meant for ``time_passed`` -- confirm which one is intended before
    handing this environment to code that validates observations.
    """

    def __init__(self, df, seed: int = 1):
        """Prepare the session data and seed the environment's RNG.

        Args:
            df: polars DataFrame with at least ``timestamp`` and ``session``
                columns.
            seed: seed for ``self.np_random``, which drives session selection
                in ``reset``.
        """
        # This resolves to gym.Env.reset (not the override below); it is used
        # here only to seed self.np_random.
        super().reset(seed=seed)
        self.df = df.with_columns(
            time_passed=(
                pl.col('timestamp') - pl.col('timestamp').min().over('session')
            ).dt.total_seconds()
        )
        # One row per session. maintain_order keeps the row order stable so a
        # seeded draw in reset() selects the same session on every run.
        self.sessions = self.df.unique('session', maintain_order=True)
        # gymnasium tooling looks for `observation_space`; `state_space` is
        # kept as an alias for any existing code that reads the old name.
        self.observation_space = gym.spaces.Box(0, 6000, dtype=int)
        self.state_space = self.observation_space
        self.action_space = gym.spaces.Discrete(2)  # 0 will not disengage, 1 will disengage
        self.session_entries = None
        self.entry = None

    def reset(self, *, seed=None, options=None):
        """Start a new episode on a randomly chosen session.

        Args:
            seed: optional seed; when given, re-seeds ``self.np_random``
                before the session is drawn.
            options: accepted for gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` for the first row of the chosen session.

        Raises:
            ValueError: if the environment was built from a frame with no
                sessions.
        """
        if seed is not None:
            super().reset(seed=seed)
        n_sessions = self.sessions.height
        if n_sessions == 0:
            raise ValueError("cannot reset: the environment has no sessions")
        # Draw through np_random so the seed actually controls which session
        # is replayed.
        row = int(self.np_random.integers(n_sessions))
        session = self.sessions[row, 'session']
        self.session_entries = (
            self.df.filter(pl.col('session') == session).sort('timestamp')
        )
        self.entry = 0
        return self._obs()

    def _obs(self):
        """Return ``(observation, info)`` for the current row.

        The observation is the row's timestamp converted with
        ``datetime.timestamp()``; ``info`` is always an empty dict.
        """
        # NOTE(review): if the column is timezone-naive, datetime.timestamp()
        # interprets it in the machine's local timezone -- confirm intended.
        return self.session_entries[self.entry, 'timestamp'].timestamp(), {}

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        Args:
            action: ``0`` (do not disengage) or ``1`` (disengage).

        Raises:
            ValueError: if ``action`` is neither 0 nor 1.
            RuntimeError: if no episode is active (``reset`` was never called,
                or the previous episode ended with a disengage).
        """
        if action not in (0, 1):
            raise ValueError(f"invalid action {action!r}; expected 0 or 1")
        if self.session_entries is None:
            raise RuntimeError("no active episode; call reset() before step()")

        if action == 1:
            # Read the final observation *before* clearing the episode state;
            # _obs() needs session_entries and entry to still be set.
            obs, info = self._obs()
            reward = 0
            terminated = True
            self.entry = None
            self.session_entries = None
        elif self.entry < self.session_entries.shape[0] - 1:
            # Reward is taken from the row being left, then the cursor advances.
            reward = self.session_entries[self.entry, 'time_passed']
            self.entry += 1
            terminated = False
            obs, info = self._obs()
        else:
            # Already on the last row: penalise continuing past the session's
            # end. Episode state is deliberately left in place.
            reward = -5000
            terminated = True
            obs, info = self._obs()

        truncated = False
        return obs, reward, terminated, truncated, info

    def render(self):
        """Display the current session's rows with the current row highlighted.

        Does nothing when no episode is active.
        """
        if self.session_entries is None:
            return
        display.display(self.session_entries.to_pandas().style.apply(
            lambda col: [
                ('background-color: yellow' if row == self.entry else '')
                for row in range(len(col))
            ]
        ))

In [195]:
# Build the environment from every row whose fold is not 0.
env = CitizenScienceEnv(df.filter(pl.col('fold') != 0))

In [196]:
# Start an episode on a randomly picked session; returns (observation, info).
env.reset()

(1420354030.0, {})

In [197]:
# Show the chosen session's rows; the current row is highlighted in yellow.
env.render()

Unnamed: 0,user_id,timestamp,ts_diff,user_diff,new_session_mark,session,user,target,fold,time_passed
0,54a998989cd118469400024d,2015-01-04 07:47:10,False,True,True,2253,915,782,3,0
1,54a998989cd118469400024d,2015-01-04 07:49:09,False,False,False,2253,915,663,3,119
2,54a998989cd118469400024d,2015-01-04 07:49:24,False,False,False,2253,915,648,3,134
3,54a998989cd118469400024d,2015-01-04 07:49:25,False,False,False,2253,915,647,3,135
4,54a998989cd118469400024d,2015-01-04 07:49:28,False,False,False,2253,915,644,3,138
5,54a998989cd118469400024d,2015-01-04 07:49:58,False,False,False,2253,915,614,3,168
6,54a998989cd118469400024d,2015-01-04 07:50:19,False,False,False,2253,915,593,3,189
7,54a998989cd118469400024d,2015-01-04 07:50:25,False,False,False,2253,915,587,3,195
8,54a998989cd118469400024d,2015-01-04 07:50:34,False,False,False,2253,915,578,3,204
9,54a998989cd118469400024d,2015-01-04 07:50:41,False,False,False,2253,915,571,3,211


In [198]:
# Action 0 = do not disengage; returns (obs, reward, terminated, truncated, info).
env.step(0)

(1420354149.0, 0, False, False, {})

In [199]:
# Render again after the step to see where the current-row marker is now.
env.render()

Unnamed: 0,user_id,timestamp,ts_diff,user_diff,new_session_mark,session,user,target,fold,time_passed
0,54a998989cd118469400024d,2015-01-04 07:47:10,False,True,True,2253,915,782,3,0
1,54a998989cd118469400024d,2015-01-04 07:49:09,False,False,False,2253,915,663,3,119
2,54a998989cd118469400024d,2015-01-04 07:49:24,False,False,False,2253,915,648,3,134
3,54a998989cd118469400024d,2015-01-04 07:49:25,False,False,False,2253,915,647,3,135
4,54a998989cd118469400024d,2015-01-04 07:49:28,False,False,False,2253,915,644,3,138
5,54a998989cd118469400024d,2015-01-04 07:49:58,False,False,False,2253,915,614,3,168
6,54a998989cd118469400024d,2015-01-04 07:50:19,False,False,False,2253,915,593,3,189
7,54a998989cd118469400024d,2015-01-04 07:50:25,False,False,False,2253,915,587,3,195
8,54a998989cd118469400024d,2015-01-04 07:50:34,False,False,False,2253,915,578,3,204
9,54a998989cd118469400024d,2015-01-04 07:50:41,False,False,False,2253,915,571,3,211


In [200]:
# Continue once more; the reward is the time_passed of the row being left.
env.step(0)

(1420354164.0, 119, False, False, {})