### Transformer Encoder Model

Main Idea: 
* model student question performance in a sequential manner...use past information to predict future 
* transformer model -> similar to NLP translation models (eg. BERT) 

In [1]:
# Import libraries 
import pandas as pd 
import numpy as np 
import re  
import seaborn as sns 
import matplotlib.pyplot as plt
import os
import datetime as dt

In [2]:
# %% load data
data_dir = r'C:\Users\Timothy Lim\Desktop\junyi\archive'
info_content = pd.read_csv(os.path.join(data_dir, 'Info_Content.csv'))
info_userdata = pd.read_csv(os.path.join(data_dir, 'Info_UserData.csv'))
log_problem = pd.read_csv(os.path.join(data_dir, 'Log_Problem.csv'))

log_problem['timestamp_TW'] = log_problem['timestamp_TW'].apply(lambda x: dt.datetime.strptime(x[:-4], "%Y-%m-%d %H:%M:%S"))
log_problem.sort_values(by='timestamp_TW', ascending=True, inplace=True)
log_problem.head()

Unnamed: 0,timestamp_TW,uuid,ucid,upid,problem_number,exercise_problem_repeat_session,is_correct,total_sec_taken,total_attempt_cnt,used_hint_cnt,is_hint_used,is_downgrade,is_upgrade,level
105021,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,9Ksy3bdm0HTi0D+cdKSBKDQJjo5GNahep91FqHUrpts=,2,1,True,8,1,0,False,,,0
11669798,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,/Wgjdl2BsldHZDdXXvzwGimusaMX548lqV2b7PgwXAs=,1,1,True,10,1,0,False,,,0
3768239,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,vsGlLPd9C58B8myBoGMGre2pDHjh62eRjsqX57D98fU=,5,1,True,6,1,0,False,False,True,1
10163558,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,hv7kHCAIdj7thZUmlqz553leG5bFNYgzXmLfB5m4Xvw=,h3CI/U4QJd6mjYE5xRH8QEst8lRG7otYIz+q1V6Och4=,5,1,True,5,1,0,False,False,True,1
8372688,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,OSuDd4rDo2muXXwwJRU2DQVHkk6/JOGgNzfzNi4PMJM=,3,1,True,4,1,0,False,,,0


Each feature will be represented using embeddings. The following features will be used as input to the model: 
* is_correct
* ucid (content id) 
* level 1 id (ucid may be different, but content may be similar) 
* time taken for previous qn 
* is hint used for previous qn
* problem_number (the Nth problem that the student encountered from this exercise) 

We will need to add 1 to each embedding column, as 0 will be used as the padding token. 

In [19]:
key_cols = ['timestamp_TW', 'uuid']
feature_cols = ['is_correct', 'ucid', 'upid', 'problem_number', 'total_sec_taken', 'is_hint_used']
df = log_problem[key_cols + feature_cols]

In [20]:
first_n = 100  # return first_n observations for each user 
user_log_problem = df.groupby(by='uuid').head(first_n)
user_log_problem

Unnamed: 0,timestamp_TW,uuid,is_correct,ucid,upid,problem_number,total_sec_taken,is_hint_used
105021,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,9Ksy3bdm0HTi0D+cdKSBKDQJjo5GNahep91FqHUrpts=,2,8,False
11669798,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,/Wgjdl2BsldHZDdXXvzwGimusaMX548lqV2b7PgwXAs=,1,10,False
3768239,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,vsGlLPd9C58B8myBoGMGre2pDHjh62eRjsqX57D98fU=,5,6,False
10163558,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,hv7kHCAIdj7thZUmlqz553leG5bFNYgzXmLfB5m4Xvw=,h3CI/U4QJd6mjYE5xRH8QEst8lRG7otYIz+q1V6Och4=,5,5,False
8372688,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,OSuDd4rDo2muXXwwJRU2DQVHkk6/JOGgNzfzNi4PMJM=,3,4,False
...,...,...,...,...,...,...,...,...
13360785,2019-07-31 23:45:00,90OCKSK8bkx4M+KdLllF8XNzykbJldJtxWaxQRuxPWY=,True,HXnIAxmCDEWAJJV2J7f2pEeGvddv9MpcRjbckar9YTY=,tvXS6hz6XdlgPmxZs3jfoX+saIvfhCzq8/D4gzrULeI=,5,364,False
13345074,2019-07-31 23:45:00,90OCKSK8bkx4M+KdLllF8XNzykbJldJtxWaxQRuxPWY=,True,HXnIAxmCDEWAJJV2J7f2pEeGvddv9MpcRjbckar9YTY=,uQDGUKEsX23MCS0xt/EDL7MhUW0kEPZeazOQSHVGsJg=,4,135,False
7960197,2019-07-31 23:45:00,90OCKSK8bkx4M+KdLllF8XNzykbJldJtxWaxQRuxPWY=,False,gQVxGCahF6+39K3glaRceVrND6A90VnDLSaj2NrTuis=,6M+OindLkGYAsjblVmm+OCNFhT2ch9OE1CWQofaIBYU=,1,436,True
14876381,2019-07-31 23:45:00,90OCKSK8bkx4M+KdLllF8XNzykbJldJtxWaxQRuxPWY=,True,HXnIAxmCDEWAJJV2J7f2pEeGvddv9MpcRjbckar9YTY=,Mo8JichdK3Bf2EAL96fRC9wWB/fXa67Bl9E7i2GLAnk=,7,93,False


In [21]:
# how many questions did each user take on average?
user_log_problem.groupby(by='uuid')['ucid'].count().describe()

count    72758.000000
mean        58.211949
std         39.532770
min          1.000000
25%         17.000000
50%         59.000000
75%        100.000000
max        100.000000
Name: ucid, dtype: float64

In [22]:
# add 1
user_log_problem['problem_number'] += 1
user_log_problem['is_hint_used'] += 1

# shift data
user_log_problem['total_sec_taken_shifted'] = user_log_problem[['uuid', 'total_sec_taken']].groupby('uuid')['total_sec_taken'].shift(1)
user_log_problem['is_hint_used_shifted'] = user_log_problem[['uuid', 'is_hint_used']].groupby('uuid')['is_hint_used'].shift(1)

# add question
user_log_problem['level2_id'] = user_log_problem.merge(info_content, how='left', on='ucid')['level2_id']
user_log_problem['level2_id'].fillna(method='ffill', inplace=True)

user_log_problem.head()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead


Unnamed: 0,timestamp_TW,uuid,is_correct,ucid,upid,problem_number,total_sec_taken,is_hint_used,total_sec_taken_shifted,is_hint_used_shifted,level2_id
105021,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,9Ksy3bdm0HTi0D+cdKSBKDQJjo5GNahep91FqHUrpts=,3,8,1,,,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=
11669798,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,/Wgjdl2BsldHZDdXXvzwGimusaMX548lqV2b7PgwXAs=,2,10,1,8.0,1.0,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=
3768239,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,vsGlLPd9C58B8myBoGMGre2pDHjh62eRjsqX57D98fU=,6,6,1,10.0,1.0,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=
10163558,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,hv7kHCAIdj7thZUmlqz553leG5bFNYgzXmLfB5m4Xvw=,h3CI/U4QJd6mjYE5xRH8QEst8lRG7otYIz+q1V6Och4=,6,5,1,6.0,1.0,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=
8372688,2018-08-01 07:45:00,U+lqK/FKWkPuoNUM1AbGyrKZfXeQrRRoKOToKrjqDt4=,True,CPI+5YCeEmhqdk6znJeii6jJUNl1QWGEvwCUJ6uLflg=,OSuDd4rDo2muXXwwJRU2DQVHkk6/JOGgNzfzNi4PMJM=,4,4,1,5.0,1.0,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=


In [23]:
user_log_problem.fillna(0, inplace=True)

user_log_problem.isna().any()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


timestamp_TW               False
uuid                       False
is_correct                 False
ucid                       False
upid                       False
problem_number             False
total_sec_taken            False
is_hint_used               False
total_sec_taken_shifted    False
is_hint_used_shifted       False
level2_id                  False
dtype: bool

In [24]:
user_log_problem['is_correct'].value_counts(normalize=True)

True     0.721829
False    0.278171
Name: is_correct, dtype: float64

In [53]:
user_log_problem.columns

Index(['timestamp_TW', 'uuid', 'is_correct', 'ucid', 'upid', 'problem_number',
       'total_sec_taken', 'is_hint_used', 'total_sec_taken_shifted',
       'is_hint_used_shifted', 'level2_id'],
      dtype='object')

In [54]:
# group rows by uuid and put in hashtable (dictionary)

user_dict = {uuid: u.drop(columns=['uuid', 'timestamp_TW', 'total_sec_taken', 'is_hint_used'])
             for uuid, u in user_log_problem.groupby('uuid')}

### Data Generator

* to feed data into the model, subclass tf.keras.utils.Sequence
* transform 2D tensor into 3D tensor, eg. [5,4] -> [5, 3, 4] if window_size=3

In [26]:
user_log_problem['is_correct'] = np.where(user_log_problem['is_correct'], 1, 0)
user_log_problem['is_correct'].value_counts(normalize=True)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.


1    0.721829
0    0.278171
Name: is_correct, dtype: float64

In [27]:
import tensorflow as tf

# helper functions

def rolling_window(a, w):
    s0, s1 = a.strides
    m, n = a.shape
    return np.lib.stride_tricks.as_strided(
        a, 
        shape=(m-w+1, w, n), 
        strides=(s0, s0, s1)
    )


def make_time_series(x, windows_size):
  x = np.pad(x, [[ windows_size-1, 0], [0, 0]], constant_values=0)
  x = rolling_window(x, windows_size)
  return x 

def add_features_to_user(user):
    # We add one to the column in order to have zeros as padding values
    # Start Of Sentence (SOS) token will be 3. 
    user['is_correct'] = user['is_correct'].shift(fill_value=2)+1
    return user



In [57]:
class RiidSequence(tf.keras.utils.Sequence):
    def __init__(self, 
                 users, 
                 windows_size,
                 batch_size=256,
                 start=0,
                 end=None):
        self.users = users # {'user_id': user_df, ...}
        self.windows_size = windows_size
        # to convert indices to our keys
        self.mapper = dict(zip(range(len(users)), users.keys()))
        # start and end to easy generate training and validation
        self.start = start
        self.end = end if end else len(users)
        # To know where the is_correct_column is
        self.is_correct_index = list(self.user_example().columns).index('is_correct')
        
    def __len__(self):
        return self.end-self.start
    
    def __getitem__(self, idx):
        uid = self.mapper[idx+self.start]
        user = self.users[uid].copy()
        y = user['is_correct'].to_numpy().copy()
        x = add_features_to_user(user)
        return make_time_series(x, self.windows_size), y

    def user_example(self):
        """Just to check what we have till' now."""
        uid = self.mapper[self.start]
        return add_features_to_user(self.users[uid].copy())

      # INFERENCE PART    
    def get_user_for_inference(self, user_row):
        """Picks a new user row and concats it to previous interactions 
        if it was already stored.

        Maybe the biggest trick in the notebook is here. We reuse the user_id column to 
        insert the is_correct SOS token because we previously placed the column 
        there on purpose.

        After it, we roll that column and then crop it if it was bigger than the window
        size, making the SOS token disapear if out of the sequence.

        If the sequence if shorter than the window size, then we pad it.
        """
        uid = user_row[self.is_correct_index]
        user_row[self.is_correct_index] = 2 # SOS token
        user_row = user_row[np.newaxis, ...]
        if uid in self.users:
            x = np.concatenate([self.users[uid], user_row])
              # same as in training, we need to add one!!!
            x[:, self.is_correct_index] = np.roll(x[:, self.is_correct_index], 1) + 1
        else:
            x = user_row

        if x.shape[0] < self.windows_size:
            return np.pad(x, [[self.windows_size-x.shape[0], 0], [0, 0]])
        elif x.shape[0] > self.windows_size:
            return x[-self.windows_size:]
        else:
            return x
    def update_user(self, uid, user):
        """Concat the new user's interactions to the old ones if already stored."""
        if uid in self.users:
            self.users[uid] = np.concatenate([self.users[uid], user])[-self.windows_size:]
        else:
            self.users[uid] = user

In [30]:
RiidSequence(user_dict , windows_size=30).user_example().head()

Unnamed: 0,timestamp_TW,is_correct,ucid,upid,problem_number,total_sec_taken,is_hint_used,total_sec_taken_shifted,is_hint_used_shifted,level2_id
10682047,2018-12-22 08:30:00,3,VY6aXT7f64ny+uy4pszHVNSy3WHyoFPuhwToxBhB3wM=,uIHmyJc0Ia07OXiS1k2wayrEA57AKPTypHaTYtKc4Po=,6,7,1,0.0,0.0,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=
13798435,2018-12-22 08:30:00,2,VY6aXT7f64ny+uy4pszHVNSy3WHyoFPuhwToxBhB3wM=,hrdzJpcOlqUZbvMB6A0P2fHVvn6US4bPm05s/PsjRi4=,4,55,1,7.0,1.0,xYDz4OEv0xsri1IpmXlrgMLJ848rgySf+39xWpq4DBI=
5851364,2018-12-22 08:30:00,2,VY6aXT7f64ny+uy4pszHVNSy3WHyoFPuhwToxBhB3wM=,wgsncTN/ZG/jV1BVhOb0ExV10Y7nbx4HDXgpqfV98Ls=,3,17,1,55.0,1.0,1EzKLzTq9Ax8/wlR9cJNrtthvk9lBi/SFdx/4L1PIaE=
15461504,2018-12-22 08:30:00,2,VY6aXT7f64ny+uy4pszHVNSy3WHyoFPuhwToxBhB3wM=,s2T1bZCtwOT2xgCGkHFRiqlSX8MHEOiW36tBJRXYj4E=,5,30,1,17.0,1.0,xYDz4OEv0xsri1IpmXlrgMLJ848rgySf+39xWpq4DBI=
11828870,2018-12-22 08:30:00,2,VY6aXT7f64ny+uy4pszHVNSy3WHyoFPuhwToxBhB3wM=,SZt4MEfoLNcDjOc4qWBsd5Lw5M03DA0f22Baij04cbg=,2,193,1,30.0,1.0,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=


0 is the padding token. if first observation in sequence, **total_sec_taken_shifted** and **is_hint_used_shifted** are 0.

In [31]:
x, y = RiidSequence(user_dict, windows_size=30)[0]
x.shape, y.shape

((5, 30, 10), (5,))

#### Modelling part

https://www.tensorflow.org/tutorials/text/transformer

In [73]:

def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
    return pos * angle_rates


def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)

# NN THINGS
def scaled_dot_product_attention(q, k, v, mask):
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)  
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)
    return output, attention_weights

    
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
    
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        concat_attention = tf.reshape(scaled_attention, 
                                      (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)

        return output, attention_weights


def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),
      tf.keras.layers.Dense(d_model)
    ])


class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):

        attn_output, _ = self.mha(x, x, x, mask) 
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output) 

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output) 

        return out2
    
def create_padding_mask(seqs):
  # We mask only those vectors of the sequence in which we have all zeroes 
  # (this is more scalable for some situations).
    mask = tf.cast(tf.reduce_all(tf.math.equal(seqs, 0), axis=-1), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

In [61]:
columns = list(RiidSequence(user_dict, 30).user_example().columns)
columns

['is_correct',
 'ucid',
 'upid',
 'problem_number',
 'total_sec_taken_shifted',
 'is_hint_used_shifted',
 'level2_id']

In [62]:
def get_series_model(
        n_features,
        ucids, upids, problem_numbers,
        prev_times_taken, prev_hints_used,
        level2_ids,
        windows_size=30,
        d_model=24,
        num_heads=4,
        n_encoder_layers = 2
    ):
    # Input
    inputs = tf.keras.Input(shape=(windows_size, n_features), name='inputs')
    mask = create_padding_mask(inputs)
    pos_enc = positional_encoding(windows_size, d_model)    
    
    # Divide branches
    is_correct = inputs[..., 0]
    ucid = inputs[..., 1]
    upid = inputs[..., 2]
    problem_number = inputs[..., 3]
    prev_time_taken = inputs[..., 4]
    prev_hint_used = inputs[..., 5]
    level2_id = inputs[..., 6]
    
    # Create embeddings
    ucid_embeddings = tf.keras.layers.Embedding(ucids, d_model)(ucid)
    upid_embeddings = tf.keras.layers.Embedding(upids, d_model)(upid)
    problem_number_embeddings = tf.keras.layers.Embedding(problem_numbers, d_model)(problem_number)
    prev_time_taken_embeddings = tf.keras.layers.Embedding(prev_times_taken, d_model)(prev_time_taken)
    prev_hint_used_embeddings = tf.keras.layers.Embedding(prev_hints_used, d_model)(prev_hints_used)
    level2_id_embeddings = tf.keras.layers.Embedding(level2_ids, d_model)(level2_ids)
    is_correct_embeddings = tf.keras.layers.Embedding(4, d_model)(is_correct)
    
    # Continuous! Only a learnable layer for it.
    # elapsed_time_embeddings = tf.keras.layers.Dense(d_model, use_bias=False)(elapsed_time)
    # part_embeddings = tf.keras.layers.Embedding(part_ids, d_model)(part)
    
    # Add embeddings
    x = tf.keras.layers.Add()([
        pos_enc,
        ucid_embeddings,
        upid_embeddings,
        problem_number_embeddings,
        prev_time_taken_embeddings, 
        prev_hint_used_embeddings,
        level2_id_embeddings,
        is_correct_embeddings
    ])

    for _ in range(n_encoder_layers):
        x = EncoderLayer(d_model=d_model, num_heads=num_heads, dff=d_model*4, rate=0.1)(x, mask=mask)

    x = tf.keras.layers.GlobalAveragePooling1D()(x)
    x = tf.keras.layers.Dropout(0.2)(x)    
    output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
    return tf.keras.Model(inputs, output, name='model')

#### Training

In [65]:
train_idx = int(len(user_dict)*0.8)
windows_size = 30
epochs = 300
patience = 2
d_model = 32
num_heads = 4
n_encoder_layers = 2

In [70]:
info_content

Unnamed: 0,ucid,content_pretty_name,content_kind,difficulty,subject,learning_stage,level1_id,level2_id,level3_id,level4_id
0,odIwFdIiecFwVUAEEV40K3MSuCSlIZkbq92Zp9tkZq8=,【基礎】怎樣解題：數量關係,Exercise,easy,math,elementary,aH0Dz0KdH9gio7rrcGRHvrmd9vcd/0WJbeEFB7qeUKA=,ICgke8JJv5eapCPwyj1aco8PEtoBkUbTZYIqxmYtqBk=,bo3jsx1beVLEZ+2sckxdZNYnlLpVS7hb5lWU2baQ66k=,KPJMQebU0O24+NzlQ4udb2BXLlKV1Hte61+hV5Xb+oU=
1,dfeeBaa8zDhWS6nu7zeXKwLyi4zqEajI3tJM9/fSBPM=,【基礎】和差問題 1,Exercise,easy,math,elementary,aH0Dz0KdH9gio7rrcGRHvrmd9vcd/0WJbeEFB7qeUKA=,ICgke8JJv5eapCPwyj1aco8PEtoBkUbTZYIqxmYtqBk=,bo3jsx1beVLEZ+2sckxdZNYnlLpVS7hb5lWU2baQ66k=,KPJMQebU0O24+NzlQ4udb2BXLlKV1Hte61+hV5Xb+oU=
2,C2AT0OBTUn+PRxEVd39enhW/DJtka1Tk90DUAR6yVdA=,【基礎】雞兔問題 1,Exercise,easy,math,elementary,aH0Dz0KdH9gio7rrcGRHvrmd9vcd/0WJbeEFB7qeUKA=,ICgke8JJv5eapCPwyj1aco8PEtoBkUbTZYIqxmYtqBk=,bo3jsx1beVLEZ+2sckxdZNYnlLpVS7hb5lWU2baQ66k=,KPJMQebU0O24+NzlQ4udb2BXLlKV1Hte61+hV5Xb+oU=
3,jZvYpEa6VB/WrlKKmQHnfbv/xJ4OypBzq0epVcn500Q=,【基礎】年齡問題 1,Exercise,easy,math,elementary,aH0Dz0KdH9gio7rrcGRHvrmd9vcd/0WJbeEFB7qeUKA=,ICgke8JJv5eapCPwyj1aco8PEtoBkUbTZYIqxmYtqBk=,bo3jsx1beVLEZ+2sckxdZNYnlLpVS7hb5lWU2baQ66k=,KPJMQebU0O24+NzlQ4udb2BXLlKV1Hte61+hV5Xb+oU=
4,M+UxJPgRIW57a0YS3eik8A9YDj+AwaMpTa5yWYn/kAw=,【基礎】追趕問題,Exercise,easy,math,elementary,aH0Dz0KdH9gio7rrcGRHvrmd9vcd/0WJbeEFB7qeUKA=,ICgke8JJv5eapCPwyj1aco8PEtoBkUbTZYIqxmYtqBk=,bo3jsx1beVLEZ+2sckxdZNYnlLpVS7hb5lWU2baQ66k=,KPJMQebU0O24+NzlQ4udb2BXLlKV1Hte61+hV5Xb+oU=
...,...,...,...,...,...,...,...,...,...,...
1325,NPkrCjbLK35wefMCv6y6XMO5eYO/gthS6LGwrdjG2OQ=,【進階】平方公分綜合習題,Exercise,hard,math,elementary,aH0Dz0KdH9gio7rrcGRHvrmd9vcd/0WJbeEFB7qeUKA=,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=,8OFhw5s0EmQIKBlKbJL+Dvp+u3ss4rN3foLwLM4xXls=,JtYpwXqNWwRqPrXYh0JhE2lEEUI1gKRPdDH3jYwhqMo=
1326,3xqxJN2W+KEo3SBjh7HnXpGjB4ewSbm1j/18fCn05yc=,【進階】周長與面積綜合習題,Exercise,hard,math,elementary,aH0Dz0KdH9gio7rrcGRHvrmd9vcd/0WJbeEFB7qeUKA=,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=,8OFhw5s0EmQIKBlKbJL+Dvp+u3ss4rN3foLwLM4xXls=,Ny1/uHUXA4pvWVN1nVWv+vsdaQde7StyoQV8HAyJD80=
1327,j4rGpwpqhLE9foelXD2yjokS0u3QR+ULhNqLGeF/4sk=,【一般】平方公分綜合習題,Exercise,normal,math,elementary,aH0Dz0KdH9gio7rrcGRHvrmd9vcd/0WJbeEFB7qeUKA=,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=,8OFhw5s0EmQIKBlKbJL+Dvp+u3ss4rN3foLwLM4xXls=,JtYpwXqNWwRqPrXYh0JhE2lEEUI1gKRPdDH3jYwhqMo=
1328,W4l0TIo0YQXLT/c8/Uy7OLtElmNuNWPODI7HkJ0NaI0=,【一般】周長與面積綜合習題,Exercise,normal,math,elementary,aH0Dz0KdH9gio7rrcGRHvrmd9vcd/0WJbeEFB7qeUKA=,7f73q332BKPBXaixasa4EkUb+pF6VAsLxNIg4506JJs=,8OFhw5s0EmQIKBlKbJL+Dvp+u3ss4rN3foLwLM4xXls=,Ny1/uHUXA4pvWVN1nVWv+vsdaQde7StyoQV8HAyJD80=


In [74]:
s_train = RiidSequence(user_dict, windows_size, start=0, end=train_idx)
s_val = RiidSequence(user_dict, windows_size, start=train_idx)

# init variables
n_features = s_train[0][0].shape[-1]
ucids = info_content['ucid'].nunique()
upids = user_log_problem['upid'].nunique()
problem_numbers = user_log_problem['problem_number'].nunique()
prev_times_taken = user_log_problem['total_sec_taken'].nunique()
prev_hints_used = 3
level2_ids = 10

tf.keras.backend.clear_session()
model = get_series_model(
        n_features,
        ucids, upids, problem_numbers,
        prev_times_taken, prev_hints_used,
        level2_ids,
        windows_size=windows_size,
        d_model=d_model,
        num_heads=num_heads,
        n_encoder_layers=n_encoder_layers
    )

model.compile(
    optimizer='adam', 
    loss='binary_crossentropy', 
    metrics=[tf.keras.metrics.AUC(name='AUC'), tf.keras.metrics.BinaryAccuracy(name='acc')]
)

InvalidArgumentError: indices = 3 is not in [0, 3) [Op:ResourceGather]