In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import os
import pandas as pd
from tqdm import tqdm_notebook

## Load Datasets

In [3]:
data_dir = "data/AAAI2023Competition/"

In [4]:
df_train_valid = pd.read_csv(os.path.join(data_dir,"train_valid_sequences.csv"))

In [5]:
df_train_valid.head(1)

Unnamed: 0,fold,uid,questions,concepts,responses,timestamps,selectmasks,is_repeat
0,0,11066,"3751,3752,3753,3754,1990,3739,3740,3742,3756,3...","187,187,374,187,374,188,188,228,166,170,221,40...","1,1,1,0,1,1,1,1,0,0,0,1,1,1,1,1,1,1,1,1,0,1,1,...","1595229836000,1595233013000,1595233687000,1595...","1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,..."


In [6]:
valid_fold = 0 #use the 0 fold to train and valid
df_train = df_train_valid[df_train_valid['fold']!=valid_fold].copy()#train dataset for fold 0
df_valid = df_train_valid[df_train_valid['fold']==valid_fold].copy()#valid dataset for fold 0

In [7]:
df_train.shape,df_valid.shape

((26728, 8), (6669, 8))

In [8]:
def flatten_dataset(df):
    interaction_list = []
    for _, row in tqdm_notebook(df.iterrows()):
        uid = row['uid']
        for question, concept, response in zip(row['questions'].split(","),
                                               row['concepts'].split(","),
                                               row['responses'].split(",")):
            if response == "-1":#remove the padding
                break
            interaction = {"uid": int(uid),
                           "question": int(question),
                           "concept": int(concept),
                           "response": int(response)}
            interaction_list.append(interaction)
    df_interaction = pd.DataFrame(interaction_list)
    print(f"# interaction is {len(interaction_list)}")
    return df_interaction

In [9]:
df_train_inter = flatten_dataset(df_train)
df_valid_inter = flatten_dataset(df_valid)

0it [00:00, ?it/s]

# interaction is 4109190


0it [00:00, ?it/s]

# interaction is 1029854


## Train a Model

In [10]:
df_train_inter.head(1)

Unnamed: 0,uid,question,concept,response
0,11779,3934,243,1


### Question Level

In [11]:
top_answers_que_level = df_train_inter.groupby(
    'question')['response'].agg("mean")
top_answers_dict_que_level = top_answers_que_level.to_dict()

### Concept Level

In [12]:
top_answers_concept_level = df_train_inter.groupby('concept')[
    'response'].agg("mean")
top_answers_dict_concept_level = top_answers_concept_level.to_dict()

## Predict on the Validation Dataset

In [13]:
df_valid_inter['que_pred'] = df_valid_inter['question'].apply(
    lambda x: top_answers_dict_que_level.get(x, 0))
df_valid_inter['concept_pred'] = df_valid_inter['concept'].apply(
    lambda x: top_answers_dict_concept_level.get(x, 0))

## Evaluate on the Validation Dataset

In [14]:
from sklearn.metrics import roc_auc_score

In [15]:
que_auc = roc_auc_score(df_valid_inter['response'],df_valid_inter['que_pred'])
que_auc

0.7323341119020579

In [16]:
concept_auc = roc_auc_score(df_valid_inter['response'],df_valid_inter['concept_pred'])
concept_auc

0.619159927619923

## Train use the Whole Dataset

In [17]:
top_answers_que_level = pd.concat([df_train_inter,df_valid_inter]).groupby(
    'question')['response'].agg("mean")
top_answers_dict_que_level = top_answers_que_level.to_dict()

In [18]:
# check again
df_valid_inter['que_pred'] = df_valid_inter['question'].apply(
    lambda x: top_answers_dict_que_level.get(x, 0))
que_auc = roc_auc_score(df_valid_inter['response'],df_valid_inter['que_pred'])
que_auc

0.735416364022037

## Predict on the Test Dataset

From the experiement, we notice the question level model has better performance, we use the question level model as our final model

In [19]:
df_test = pd.read_csv(os.path.join(data_dir,"pykt_test.csv"))

In [22]:
df_test

Unnamed: 0,uid,questions,concepts,responses,timestamps,is_repeat,num_test
0,8572,"2203,268,266,271,270,269,2204,274,277,2206,220...","140,139,9,144,143,142,119,119,62,65,62,65,148,...","1,1,1,0,1,1,1,1,0,0,1,1,1,0,1,1,1,1,0,1,1,0,1,...","1594547742000,1594547742000,1594547742000,1594...","0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,...",186
1,179,"3436,3437,3438,5243,2335,292,293,297,3440,1950...","155,155,155,55,55,155,155,55,117,307,55,307,55...","1,1,1,1,1,0,0,1,1,1,1,1,1,0,0,1,1,1,0,1,1,1,0,...","1599133602000,1599133602000,1599133602000,1599...","0,0,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,0,0,0,0,0,...",138
2,5664,"1367,2504,983,4358,4356,4356,1370,4355,4357,43...","187,188,187,335,188,471,335,188,188,335,188,18...","1,1,0,1,1,1,1,1,1,0,1,0,1,0,1,1,1,1,0,0,1,1,1,...","1595219073000,1595244223000,1595244223000,1595...","0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...",199
3,3587,"268,266,267,270,271,269,1937,2323,2322,812,811...","139,9,141,143,144,142,232,231,232,30,30,30,147...","1,1,0,1,1,0,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...","1597571706000,1597571706000,1597571706000,1597...","0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...",163
4,6711,"2208,795,794,2333,2334,791,1945,1138,292,1946,...","155,304,303,155,155,155,155,155,155,155,155,10...","1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,...","1599216384000,1599290697000,1599290697000,1599...","0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,0,...",270
...,...,...,...,...,...,...,...
3608,601,"3293,3294,3292,5373,5372,5374,1130,3292,3293,3...","587,588,587,66,9,11,366,587,587,588,17,714,66,...","1,1,1,1,0,1,1,1,1,1,1,0,1,1,1,0,0,1,1,0,1,1,1,...","1594458929000,1594458929000,1594458929000,1594...","0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,1,0,0,0,1,...",129
3609,2540,"2309,2313,2318,2318,2328,2328,862,3268,3031,86...","365,142,231,232,497,488,223,144,144,142,140,9,...","1,0,1,1,0,0,1,0,0,1,1,1,0,1,1,1,1,1,1,1,1,1,1,...","1594446365000,1594446676000,1594447133000,1594...","0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,...",226
3610,14945,"3421,3419,3420,2202,1130,1129,3263,2310,2310,2...","140,502,365,365,366,365,139,66,9,140,7,11,9,6,...","1,0,1,1,1,1,0,1,1,1,1,1,1,1,0,1,1,1,1,0,1,1,0,...","1595073081000,1595073081000,1595073081000,1595...","0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,...",122
3611,9105,"1985,451,452,1986,1987,1989,1988,1990,453,454,...","475,138,187,188,188,374,188,374,228,228,214,21...","0,0,1,0,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,...","1594466792000,1594467267000,1594467816000,1594...","0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...",241


In [27]:
predict_str_list = []
num_test = 0
for _, row in df_test.iterrows():
    predict_results = []
    for question, response, is_repeat in zip(row['questions'].split(","), 
                                             row['responses'].split(","), 
                                             row['is_repeat'].split(",")):
        question, response,is_repeat = int(question), int(response),int(is_repeat)
        if is_repeat!=0:#skip the repeat
            continue
        if response == -1:
            num_test += 1
            predict_results.append(top_answers_dict_que_level.get(question, 0))
    predict_str = ",".join([str(x) for x in predict_results])
    predict_str_list.append(predict_str)

In [28]:
num_test

552290

In [25]:
df_submit = pd.DataFrame({"responses":predict_str_list})

In [26]:
df_submit.to_csv("prediction.csv",index=False)