In [None]:
import ast
import sqlite3
from pathlib import Path

import pandas as pd
import riiideducation
from sklearn.metrics import roc_auc_score
from tqdm.notebook import tqdm

In [None]:
class PredictEnv:
    """Offline mock of the competition prediction environment.

    Records belonging to the requested folds are copied from an on-disk SQLite
    database into an in-memory one. `iter_test` then replays them group by
    group (one task container per selected user), `predict` stores submitted
    predictions, and `get_predictions` joins them back to the true labels and
    reports the ROC AUC.

    Attributes:
        conn: in-memory `sqlite3` connection holding the working tables.
        c: cursor on `conn`.
        group_num: number of the group `iter_test` is currently building.
        records_remaining: rows of `b_records` not yet yielded by `iter_test`.
        df_users: snapshot of the `b_users` table taken right after setup.
        score: ROC AUC computed by the last `get_predictions` call.
    """

    def __init__(self, folds_path, folds):
        """Open the in-memory database and load the given folds.

        Args:
            folds_path: path of the SQLite file that contains the `train` table.
            folds: iterable of integer fold numbers to replay.
        """
        self.conn = sqlite3.connect(':memory:')
        self.c = self.conn.cursor()
        self.setup_folds(folds_path, folds)

    def setup_folds(self, folds_path, folds):
        """Build the `b_records` and `b_users` working tables.

        Args:
            folds_path: path of the SQLite file that contains the `train` table.
            folds: iterable of integer fold numbers; rows whose `fold` is in
                this collection are copied into `b_records`.

        Raises:
            TypeError, ValueError: if a fold cannot be converted to `int`.
        """
        # Coerce to int so only numeric literals are ever placed in the SQL text.
        fold_list = ', '.join(str(int(fold)) for fold in folds)

        # Bind the path as a parameter instead of splicing it into a
        # double-quoted SQL token: this survives quotes in the path and does
        # not depend on SQLite treating double-quoted text as a string literal.
        self.c.execute('ATTACH DATABASE ? AS folds_db', (str(folds_path),))

        self.c.executescript(f"""
            DROP TABLE IF EXISTS b_records;

            CREATE TABLE b_records AS
            SELECT row_id, timestamp, user_id, content_id, content_type_id, task_container_id, prior_question_elapsed_time,
                prior_question_had_explanation, answered_correctly, user_answer
            FROM folds_db.train
            WHERE fold in ({fold_list})
            ORDER BY user_id, task_container_id, row_id;

            CREATE INDEX user_id_task_container_id_index ON b_records (user_id, task_container_id);

            DROP TABLE IF EXISTS b_users;

            CREATE TABLE b_users AS
            SELECT user_id, MIN(task_container_id) - 1 task_container_id_next, MAX(task_container_id) task_container_id_max
            FROM b_records
            GROUP BY user_id
                ORDER BY user_id, task_container_id_next;

            CREATE UNIQUE INDEX user_id_index ON b_users (user_id);

            ALTER TABLE b_users
                ADD COLUMN group_num INTEGER;

        """)

        self.group_num = 0
        self.records_remaining = self.c.execute('SELECT COUNT(*) FROM b_records').fetchone()[0]
        self.df_users = pd.read_sql('SELECT * FROM b_users', self.conn)

    def iter_test(self):
        """Yield `(test_df, sample_prediction_df)` pairs until every record is served.

        Each pass advances a randomly sized set of users to their next task
        container and yields the matching rows. `answered_correctly` and
        `user_answer` are removed from the yielded frame; their values are
        exposed on the first row of the *following* group through
        `prior_group_answers_correct` and `prior_group_responses`.
        Both frames are indexed by `group_num`; the sample frame carries a
        constant 0.5 in `answered_correctly`.
        """
        next_correct = '[]'
        next_responses = '[]'

        while self.records_remaining:
            # Upsert trick: re-inserting an existing user_id always conflicts,
            # so the DO UPDATE branch advances the selected users by one task
            # container and stamps them with the current group number.
            # `<` (not `<=`) keeps users that already reached their last
            # container from being picked again and wasting a batch slot.
            self.c.execute(f"""
                INSERT INTO b_users (user_id)
                SELECT user_id
                FROM b_users
                WHERE task_container_id_next < task_container_id_max
                LIMIT 1 + ABS(RANDOM() % 40) + ABS(RANDOM() % 1000) * (ABS(RANDOM() % 100) < 5)
                ON CONFLICT (user_id) DO UPDATE SET
                    task_container_id_next = task_container_id_next + 1,
                    group_num = {self.group_num};
            """)

            self.conn.commit()

            df_b = pd.read_sql(f"""
                SELECT r.*
                FROM b_records r
                JOIN b_users u
                ON group_num = {self.group_num}
                    AND r.user_id = u.user_id
                    AND r.task_container_id = u.task_container_id_next
            """, self.conn)

            # A group can be empty when the selected users have no rows for
            # the container number they were advanced to; such groups are skipped.
            if len(df_b):
                df_b['group_num'] = self.group_num

                # Only the first row of a group carries the previous group's outcomes.
                df_b['prior_group_answers_correct'] = None
                df_b.at[0, 'prior_group_answers_correct'] = next_correct

                df_b['prior_group_responses'] = None
                df_b.at[0, 'prior_group_responses'] = next_responses

                # Remember this group's outcomes for the next yield, then hide them.
                next_correct = f'[{(", ").join(df_b.answered_correctly.astype(str))}]'
                next_responses = f'[{(", ").join(df_b.user_answer.astype(str))}]'
                del df_b['answered_correctly']
                del df_b['user_answer']

                df_b = df_b.set_index('group_num')

                df_p = df_b[['row_id']].copy()
                df_p['answered_correctly'] = 0.5

                self.records_remaining -= len(df_b)

                yield df_b, df_p

            self.group_num += 1

    def predict(self, df_pred):
        """Append a frame of predictions to the `predictions` table.

        Args:
            df_pred: DataFrame with at least `row_id` and `answered_correctly`;
                its index is stored as an additional column.

        Raises:
            RuntimeError: if any `answered_correctly` value equals -1.
        """
        if (df_pred.answered_correctly == -1).any():
            # The original bare `raise` also surfaced as RuntimeError, so the
            # exception type is unchanged; only the message is now meaningful.
            raise RuntimeError(
                'df_pred contains answered_correctly == -1, which is not a valid prediction.'
            )
        df_pred.reset_index().to_sql('predictions', self.conn, if_exists='append', index=False)

    def get_predictions(self):
        """Join stored predictions with the true labels and print the ROC AUC.

        Reads the `predictions` table, so `predict` must have been called at
        least once beforehand. The score is also stored on `self.score`.

        Returns:
            DataFrame with columns `row_id`, `y_true` and `y_pred`.
        """
        df_preds = pd.read_sql("""
            SELECT p.row_id, b.answered_correctly y_true, p.answered_correctly y_pred
            FROM predictions p
            JOIN b_records b
            ON p.row_id = b.row_id
        """, self.conn)

        self.score = roc_auc_score(df_preds.y_true, df_preds.y_pred)

        print(f'ROC AUC Score: {self.score:0.4f}')

        return df_preds

    def close(self):
        """Close the in-memory database and release its tables."""
        self.conn.close()

The folds loaded by the mock environment are defined in the notebook [RIIID: Folding is Fun!](https://www.kaggle.com/calebeverett/riiid-folding-is-fun).

In [None]:
# Set to False to run against the real competition API instead of the local mock.
mock = True

if not mock:
    env = riiideducation.make_env()
else:
    # Replay folds 0 and 1 from the pre-built folds database.
    FOLDS = Path('../input/riiid-folds/riiid.db')
    env = PredictEnv(FOLDS, [0, 1])

# Both environments expose the same iterator interface.
iter_test = env.iter_test()

In [None]:
%%time
# The progress bar is only available for the mock, which knows how many rows remain.
if mock:
    pbar = tqdm(total=env.records_remaining)

df_batch_prior = None

for test_batch in iter_test:
    df_test = test_batch[0]

    if df_batch_prior is not None:
        # The first row of each group carries the previous group's outcomes as
        # a list literal in a string; parse it with literal_eval rather than
        # eval so the text is never executed as code.
        answers = ast.literal_eval(df_test['prior_group_answers_correct'].iloc[0])
        df_batch_prior['answered_correctly'] = answers

    df_batch_prior = df_test

    # Only rows with content_type_id == 0 are predicted.
    df_batch = df_test[df_test.content_type_id == 0].copy()
    # NOTE(review): groups with no such rows skip env.predict entirely -
    # confirm the real API accepts moving on without a predict call.
    if len(df_batch):
        df_batch['answered_correctly'] = 0.65  # constant baseline prediction
        env.predict(df_batch[['row_id', 'answered_correctly']])

    if mock:
        pbar.update(len(df_batch_prior))

if mock:
    pbar.close()

In [None]:
# Scoring is only possible with the mock, which has access to the true labels.
if mock:
    df_pred = env.get_predictions()
    # An expression inside an `if` body is not the cell's last top-level
    # expression, so it is never rendered; display it explicitly instead.
    display(df_pred.head())