In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import random
import torch
import torch.nn as nn

In [None]:
# Columns of `data` that are kept for inference (applied further down via
# `data[cols_to_load]`).
cols_to_load = [
    'user_id',
    'timestamp',
    'content_id',
    'prior_question_had_explanation',
    'answered_correctly',
]

In [None]:
# Load the historical interaction table. Later cells read at least the columns
# user_id, timestamp, content_id, prior_question_had_explanation and
# answered_correctly from it.
# NOTE(security): read_pickle unpickles the file and can execute arbitrary code;
# only load pickles from a trusted source.
data = pd.read_pickle("../input/real-data-for-version1/out.pkl")

In [None]:
# Per-content answer counts: one row per content_id, with the number of rows
# answered 0 ('Wrong') and 1 ('Right'). Rows whose answered_correctly is -1 are
# left out (presumably non-question rows -- TODO confirm).
correct = (
    data[data.answered_correctly != -1]
    .groupby(["content_id", 'answered_correctly'])
    .size()
    .unstack('answered_correctly', fill_value=0)  # missing (content, answer) pairs count as 0
)
correct.columns = ['Wrong', 'Right']
correct = correct.astype(int)

In [None]:
# Historical accuracy of every content_id: Right / (Wrong + Right).
acu_id = correct['Right'].div(correct[['Wrong', 'Right']].sum(axis=1))

In [None]:
# Drop every column that inference does not need.
data = data.loc[:, cols_to_load]

In [None]:
# Competition inference API. `env` receives the predictions (env.predict is
# called in the final loop) and `iter_test` is iterated there to obtain
# (test_df, sample_prediction_df) pairs.
# NOTE(review): make_env() is presumably meant to be called only once per
# kernel session -- confirm before re-running this cell.
import riiideducation
env = riiideducation.make_env()
iter_test = env.iter_test()

In [None]:
def clust_info(user_id, timestamp, data):
    """Return one user's interaction history strictly before `timestamp`.

    Args:
        user_id: value matched against ``data.user_id``.
        timestamp: exclusive upper bound on ``data.timestamp``.
        data: DataFrame holding at least the columns user_id, timestamp,
            content_id, prior_question_had_explanation and answered_correctly.

    Returns:
        DataFrame with the columns content_id, prior_question_had_explanation
        and answered_correctly (in that order), keeping the original row order
        and index labels. Empty when nothing matches.
    """
    is_user = data.user_id == user_id
    is_earlier = data.timestamp < timestamp
    wanted_cols = ['content_id', 'prior_question_had_explanation', 'answered_correctly']
    return data.loc[is_user & is_earlier, wanted_cols]

In [None]:
def df_to_input(df):
    """Turn a 2-D table into a tensor with a leading batch axis of size 1.

    Args:
        df: DataFrame (or other 2-D array-like) of shape (rows, cols).

    Returns:
        torch.Tensor of shape (1, rows, cols) holding a copy of the values.
    """
    batched = np.array(df)[np.newaxis, :, :]
    return torch.tensor(batched)

In [None]:
content_num = 40000  # rows in the content embedding table; content ids must be < content_num
device = torch.device('cpu')  # device the model inputs are moved to inside forward()


class Riiid(nn.Module):
    """LSTM model that predicts whether the next question is answered correctly.

    The answered history (content id + answer per step) is embedded and run
    through an LSTM; its final hidden state is concatenated with the embedding
    of the target question, the embedding of its explanation flag and a learned
    projection of the user's mean past answer, then passed through two linear
    layers and a sigmoid.

    Args:
        emb_con_size: embedding width for content ids.
        emb_bin_size: embedding width for the binary answer / explanation flags.
        hidden_size: LSTM hidden size.
        middle_size: width of the intermediate linear layer.
        state_size: width of the projected mean-answer feature.
        dropout: probability for ``self.dropout`` (the layer is created but is
            not applied anywhere in ``forward``).
    """

    def __init__(self, emb_con_size=16, emb_bin_size=8, hidden_size=32, middle_size=16, state_size=4, dropout=0.2):
        # nn.Module has to be initialised before attributes are assigned on it.
        super(Riiid, self).__init__()
        self.hidden_size = hidden_size
        self.emb_con_size = emb_con_size
        self.emb_bin_size = emb_bin_size
        self.middle_size = middle_size
        self.state_size = state_size
        # Attribute names below are the state_dict keys of saved checkpoints;
        # do not rename them.
        self.Emb_content = nn.Embedding(content_num, self.emb_con_size)
        self.Emb_ans = nn.Embedding(2, self.emb_bin_size)
        self.Emb_explanation = nn.Embedding(2, self.emb_bin_size)
        self.LSTM = nn.LSTM(input_size=self.emb_bin_size + self.emb_con_size, hidden_size=self.hidden_size,
                            batch_first=True)
        self.fc = nn.Linear(in_features=self.hidden_size + self.emb_con_size + self.emb_bin_size + self.state_size,
                            out_features=self.middle_size)
        self.decoder = nn.Linear(in_features=self.middle_size, out_features=1)
        self.ln = nn.Linear(1, self.state_size)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.decoder.weight)

    def forward(self, seq_data, seq_tar, seq_ans, length):
        """Predict the probability of a correct answer for each target question.

        Args:
            seq_data: integer tensor (batch, steps, features); only feature 0
                (the content id) is used.
            seq_tar: integer tensor (batch, >=1, 2) with [content id,
                explanation flag]; only step 0 is used.
            seq_ans: integer tensor (batch, steps, 1) of past answers; values
                index a 2-row embedding, so they must be 0 or 1.
            length: integer tensor (batch, 1, 1) with the number of valid
                history steps per sample (must be >= 1).

        Returns:
            Tensor (batch, 1) of probabilities in (0, 1), in input batch order.
        """
        # pack_padded_sequence expects the batch ordered longest-first.
        len_sorted, sorted_idx = length.sort(0, descending=True)
        sorted_idx = sorted_idx.long().reshape(-1)
        len_sorted = len_sorted.long().reshape(-1).cpu()  # lengths must live on the CPU

        # History features per step: [content id, answer].
        history = torch.cat([seq_data[:, :, 0][:, :, None], seq_ans], dim=-1)
        history = history[sorted_idx].to(device)
        emb_history = torch.cat([self.Emb_content(history[:, :, 0]),
                                 self.Emb_ans(history[:, :, 1])], dim=-1)
        packed_seq = nn.utils.rnn.pack_padded_sequence(emb_history, len_sorted, batch_first=True)

        # The final hidden state acts as the user's knowledge state.
        _, (h_n, _) = self.LSTM(packed_seq)

        # Inverse permutation: bring the hidden states back to input order.
        _, origin_index = sorted_idx.sort(0, descending=False)
        knowledge = h_n[0][origin_index]

        # Target question: content embedding + explanation-flag embedding.
        seq_tar = seq_tar.to(device)
        q_state = torch.cat([self.Emb_content(seq_tar[:, :, 0]),
                             self.Emb_explanation(seq_tar[:, :, 1])], -1)

        # Mean past answer of each user (padding answers are expected to be 0).
        u_state_in = torch.div(torch.sum(seq_ans, 1).float(), length[:, :, 0].float())
        u_state = self.ln(u_state_in.to(device))

        state = torch.cat([q_state[:, 0, :], knowledge.to(device), u_state], -1)
        out = torch.sigmoid(self.decoder(self.fc(state.float())))
        return out

In [None]:
net = Riiid()
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
model = torch.load("../input/model-parameters/model.pickle", map_location=torch.device('cpu'))
# Actually apply the checkpoint. Assigning it to an attribute (net.param = model)
# leaves the network with its random initial weights.
if isinstance(model, nn.Module):
    # The checkpoint is a whole pickled model: use it directly.
    net = model
else:
    # The checkpoint is a state_dict: copy the tensors into the fresh network.
    # Raises if the keys or shapes do not match the architecture.
    net.load_state_dict(model)
net.eval()  # inference mode

In [None]:
def convert_test(test_line):
    """Encode one test row as the model's target-question tensor.

    Args:
        test_line: object exposing ``content_id`` and
            ``prior_question_had_explanation`` as attributes (a DataFrame row
            or a namedtuple).

    Returns:
        Integer tensor of shape (1, 1, 2): [[[content_id, explanation_flag]]],
        where explanation_flag is 1 for a truthy value and 0 for a falsy or
        missing (None / NaN / pd.NA) value.
    """
    out_0 = int(test_line.content_id)
    flag = test_line.prior_question_had_explanation
    # Check for missing values first: NaN is truthy and bool(pd.NA) raises.
    if flag is None or pd.isna(flag):
        out_1 = 0
    else:
        out_1 = 1 if flag else 0
    return torch.tensor([[[out_0, out_1]]])

In [None]:
def test_out(test):
    """Predict the probability of a correct answer for every row of `test`.

    Rows of users with earlier interactions in the global `data` are scored by
    the global `net`; all other rows fall back to the historical accuracy of
    the question (`acu_id`), or 0.5 when the question is unknown as well.

    Args:
        test: DataFrame with the columns user_id, timestamp, content_id and
            prior_question_had_explanation.

    Returns:
        List of floats, one per row of `test`, in row order.
    """
    re = list()
    # itertuples keeps each column's own dtype (iloc rows are upcast to a
    # common dtype) and still offers attribute access for convert_test.
    for tar in test.itertuples(index=False):
        # Decide "known user" from the actual history rows:
        # `x in series` tests the index labels, not the values, and a user with
        # no rows before this timestamp would give a zero-length sequence that
        # pack_padded_sequence rejects.
        prior_seq = clust_info(tar.user_id, tar.timestamp, data)
        if len(prior_seq) > 0:
            # If the content id was never seen in training, its embedding is
            # just the random initialisation, which may hurt the prediction.
            temp = df_to_input(prior_seq)
            with torch.no_grad():  # inference only, no autograd graph needed
                pred = net(temp[:, :, :2],
                           convert_test(tar),
                           # abs() maps answered_correctly == -1 onto 1 so that it
                           # is a valid index for the 2-row answer embedding
                           torch.abs(temp[:, :, 2][:, :, None]),
                           torch.tensor([[[temp.shape[1]]]]))
            re += pred[0].tolist()
        else:
            # No usable history: label-based lookup of the question's accuracy,
            # with 0.5 for questions that have no statistics.
            re.append(float(acu_id.get(tar.content_id, 0.5)))

    return re

In [None]:
# Feature columns handed to test_out for every test batch.
cols_to_load_ = [
    'user_id',
    'timestamp',
    'content_id',
    'prior_question_had_explanation',
]

In [None]:
# Score every batch delivered by the competition API and submit the predictions.
for (test_df, sample_prediction_df) in iter_test:
    # Keep only rows with content_type_id == 0. reset_index returns a new frame,
    # so its result must be kept; this also makes test_df an independent frame,
    # so the column assignment below does not write into a slice of the batch.
    test_df = test_df[test_df['content_type_id'] == 0].reset_index(drop=True)
    re = test_out(test_df[cols_to_load_])
    test_df['answered_correctly'] = re
    env.predict(test_df[['row_id', 'answered_correctly']])