In [None]:
import pandas as pd
from pathlib import Path
import subprocess
import sqlite3
from tqdm.notebook import trange

import riiideducation

In [None]:
# Root directory of the competition data (train.csv is read from here).
PATH: Path = Path('../input/riiid-test-answer-prediction')

In [None]:
%%time
df_train = pd.read_csv(PATH/'train.csv', usecols=['row_id', 'user_id', 'task_container_id', 'timestamp'])

In [None]:
# Scratch SQLite database held entirely in RAM; nothing is written to disk.
conn = sqlite3.connect(database=':memory:')
c = conn.cursor()

In [None]:
%%time
chunk_size = 20000
total = len(df_train)
n_chunks = (total // chunk_size + 1)

for i in trange(n_chunks):
    df_train.iloc[i * chunk_size:(i + 1) * chunk_size].to_sql('train', conn, method='multi', if_exists='append', index=False)

In [None]:
%%time
c.executescript("""
    DROP TABLE IF EXISTS user_tids_all;
    
    CREATE TABLE user_tids_all AS
        SELECT user_id, task_container_id, MAX(timestamp) timestamp
        FROM train
        GROUP BY user_id, task_container_id
        ORDER BY user_id;
        
    CREATE UNIQUE INDEX user_id_tid_idx ON user_tids_all (user_id, task_container_id);
""").fetchone()

In [None]:
# Sanity check: number of distinct (user_id, task_container_id) pairs.
count_sql = 'select count(*) from user_tids_all'
c.execute(count_sql).fetchone()

In [None]:
# Count, over the whole training set, how often a user's task_container_id
# jumps by more than one between consecutive containers. Containers are
# ordered by their latest timestamp, with task_container_id as a tie-breaker
# so that LAG sees a deterministic order.
# The first container of each user has no predecessor (LAG is NULL), so its
# delta is NULL and SUM ignores it. COALESCE turns the NULL that SUM returns
# when there are no deltas at all into 0, keeping the result an int for the
# `:0,d` format used when it is printed.
delta_count_train = c.execute("""
    WITH numbered_records AS (
        SELECT user_id, CAST(task_container_id - LAG(task_container_id) OVER(
            PARTITION BY user_id ORDER BY timestamp, task_container_id
        ) > 1 AS INTEGER) delta_count
        FROM user_tids_all
    )
    SELECT COALESCE(SUM(delta_count), 0)
    FROM numbered_records
""").fetchone()[0]

In [None]:
# Report the training-set count of task_container_id jumps greater than one,
# with thousands separators.
train_summary = f'{delta_count_train:0,d} deltas greater than one in the full training set.'
print(train_summary)

In [None]:
%%time
c.executescript("""
    DROP TABLE IF EXISTS user_tids;
    
    CREATE TABLE user_tids AS
        SELECT user_id, MAX(task_container_id) tid_current, 0 tid_delta_count
        FROM train
        GROUP BY user_id
        ORDER BY user_id;
        
    CREATE UNIQUE INDEX user_id_idx ON user_tids (user_id);
""").fetchone()

In [None]:
# Start the competition's test API; iter_test is the iterator of test batches
# replayed below.
# NOTE(review): the API is believed to allow only one make_env()/iter_test()
# per session, so re-running this cell likely fails. Confirm before relying on it.
env = riiideducation.make_env()
iter_test = env.iter_test()

In [None]:
# Replay the test stream. For each user, count how often task_container_id
# jumps by more than one between consecutive rows. Abort the run with an
# explicit error once the running total exceeds DELTA_COUNT_LIMIT.
DELTA_COUNT_LIMIT = 100_000

# Parameterised upsert, applied row by row in batch order.
#   - New users: tid_delta_count is written explicitly as 0 rather than left to
#     a column default. A NULL here would make every later
#     `+ tid_delta_count` NULL.
#   - Existing users: in DO UPDATE, bare column names refer to the stored row
#     and `excluded.*` to the incoming row. All SET expressions see the stored
#     (pre-update) values.
UPSERT_TID_SQL = """
    INSERT INTO user_tids (user_id, tid_current, tid_delta_count)
        VALUES (?, ?, 0)
    ON CONFLICT (user_id) DO UPDATE SET
        tid_delta_count = CAST(excluded.tid_current - tid_current > 1 AS INTEGER) + tid_delta_count,
        tid_current = excluded.tid_current
"""

for i, test_batch in enumerate(iter_test):
    test_df = test_batch[0]

    # Plain Python ints: sqlite3 does not accept numpy integer scalars as
    # parameters. An empty batch simply produces no rows.
    rows = [
        (int(user_id), int(task_container_id))
        for user_id, task_container_id in zip(test_df['user_id'], test_df['task_container_id'])
    ]
    c.executemany(UPSERT_TID_SQL, rows)
    conn.commit()

    # Total over every user, which matches how the training-set count is
    # summed. COALESCE maps an all-NULL/empty SUM to 0.
    delta_count = c.execute("""
        SELECT COALESCE(SUM(tid_delta_count), 0)
        FROM user_tids
        """).fetchone()[0]

    if delta_count > DELTA_COUNT_LIMIT:
        raise RuntimeError(
            f'{delta_count:0,d} deltas greater than one after {i+1:0,d} batches '
            f'exceeds the limit of {DELTA_COUNT_LIMIT:0,d}.'
        )

    if not i % 1000 or i < 5:
        print(f'{delta_count:0,d} deltas greater than one after {i+1:0,d} batches.')

    env.predict(test_batch[1][test_df.content_type_id == 0])