In [15]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import json

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

In [16]:
# Interaction tables: these are the frames handed to train_model, where
# QuestionDataset reads their user_id, question_id and is_correct columns.
train_df, val_df = (
    pd.read_csv("data/train_data.csv"),
    pd.read_csv("data/valid_data.csv"),
)

# Side tables with per-student, per-question and per-subject metadata.
student_meta, question_meta, subject_meta = (
    pd.read_csv("data/student_meta.csv"),
    pd.read_csv("data/question_meta.csv"),
    pd.read_csv("data/subject_meta.csv"),
)

In [17]:
# Table sizes are "largest id + 1" so that an id can be used directly as a row
# index into the embedding / metadata tables built later.
# int(...) keeps both values plain Python ints (they are later written to JSON).
largest_user_id = student_meta["user_id"].max()
largest_question_id = question_meta["question_id"].max()
n_students = int(largest_user_id) + 1
n_questions = int(largest_question_id) + 1
print(n_students, n_questions)

542 1774


In [18]:
# Dense per-student feature table: row i holds [gender, premium_pupil] for
# user_id i. Rows whose id does not occur in student_meta keep the zero fill.
student_meta_tensor = torch.zeros(n_students, 2)

# Row indices are built as int64 (torch.long): that is the index dtype accepted
# by every PyTorch release, whereas int32 index tensors are rejected by older ones.
user_id = torch.tensor(student_meta['user_id'].values, dtype=torch.long)

# NOTE(review): gender is used as-is (no NaN fill) — confirm the column has no
# missing values, otherwise NaNs flow straight into the network input.
gender = torch.tensor(student_meta['gender'].values, dtype=torch.float32)

# Missing premium_pupil entries are encoded with the sentinel -1.0.
premium_pupil = torch.tensor(
    student_meta['premium_pupil'].fillna(-1.0).values, dtype=torch.float32
)

student_meta_tensor[user_id, 0] = gender
student_meta_tensor[user_id, 1] = premium_pupil
student_meta_tensor

tensor([[ 2., -1.],
        [ 1., -1.],
        [ 0., -1.],
        ...,
        [ 1., -1.],
        [ 1.,  0.],
        [ 1.,  0.]])

In [19]:
class QuestionDataset(Dataset):
    """Map-style dataset over the answer records of a DataFrame.

    The frame must provide ``user_id``, ``question_id`` and ``is_correct``
    columns; their underlying arrays are stored once at construction time.
    Indexing yields ``(user_id, question_id, is_correct)`` for a single row.
    """

    def __init__(self, df):
        user_column = df['user_id']
        question_column = df['question_id']
        label_column = df['is_correct']

        self.user_ids = user_column.values
        self.question_ids = question_column.values
        self.is_correct = label_column.values

    def __len__(self):
        # One sample per label entry.
        return len(self.is_correct)

    def __getitem__(self, idx):
        user = self.user_ids[idx]
        question = self.question_ids[idx]
        label = self.is_correct[idx]
        return user, question, label

class StudentQuestionNet(nn.Module):
    """Feed-forward scorer for a (student, question) pair.

    The student embedding, question embedding and student metadata vectors are
    concatenated along the last dimension and passed through one
    Linear -> BatchNorm1d -> ReLU -> Dropout stage per entry of
    ``hidden_layers``, followed by a final Linear layer with a single output.
    ``forward`` returns that output without any activation applied.
    """

    def __init__(self, student_embed_dim, question_embed_dim, student_meta_dim, hidden_layers, dropout_p=0.3):
        super().__init__()

        # Layer widths from the concatenated input through every hidden layer.
        widths = [student_embed_dim + question_embed_dim + student_meta_dim, *hidden_layers]

        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout_p),
            ]

        # Output head: one value per sample.
        stages.append(nn.Linear(widths[-1], 1))
        self.network = nn.Sequential(*stages)

    def forward(self, student_embed, question_embed, student_meta):
        features = torch.cat((student_embed, question_embed, student_meta), dim=-1)
        return self.network(features)

In [20]:
def initialize_dataloaders(train_df, val_df, batch_size):
    """Wrap the train and validation frames in batched DataLoaders.

    Each frame is wrapped in a QuestionDataset. The training loader is created
    with ``shuffle=True``, the validation loader with ``shuffle=False``; both
    use the same ``batch_size``.

    Returns:
        tuple: ``(train_dataloader, val_dataloader)``.
    """
    def _loader(frame, shuffle):
        return DataLoader(QuestionDataset(frame), batch_size=batch_size, shuffle=shuffle)

    return _loader(train_df, True), _loader(val_df, False)

def initialize_model_and_optimizers(n_students, n_questions, student_embed_dim, question_embed_dim, student_meta_dim,
                                    hidden_layers, dropout_p, learning_rate, device):
    """Create the embedding tables, the scoring network and their optimizers.

    The student and question embedding tables are free-standing Parameters
    (one row per id) rather than members of the network, so they get their own
    Adam optimizer; the network's parameters get a second Adam optimizer. Both
    optimizers use ``learning_rate``.

    Returns:
        tuple: ``(model, student_embed, question_embed, model_optimizer,
        embed_optimizer)``.
    """
    def _embedding_table(n_rows, n_cols):
        # Drawn with torch.randn, then moved to the target device before being
        # wrapped, so the Parameter itself lives on `device`.
        return nn.Parameter(torch.randn(n_rows, n_cols).to(device))

    student_embed = _embedding_table(n_students, student_embed_dim)
    question_embed = _embedding_table(n_questions, question_embed_dim)

    network = StudentQuestionNet(
        student_embed_dim=student_embed_dim,
        question_embed_dim=question_embed_dim,
        student_meta_dim=student_meta_dim,
        hidden_layers=hidden_layers,
        dropout_p=dropout_p,
    )
    model = network.to(device)

    model_optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    embed_optimizer = optim.Adam([student_embed, question_embed], lr=learning_rate)

    return model, student_embed, question_embed, model_optimizer, embed_optimizer

def train_step(model, train_dataloader, student_embed, question_embed, student_meta_tensor, criterion, model_optimizer,
               embed_optimizer, device):
    """Run one optimisation pass over ``train_dataloader``.

    Args:
        model: Module called as ``model(user_embeds, question_embeds, student_meta)``
            that returns one logit per sample.
        train_dataloader: Iterable of ``(user_ids, question_ids, targets)`` batches
            that also supports ``len()``.
        student_embed: Table indexed by the batch of user ids.
        question_embed: Table indexed by the batch of question ids.
        student_meta_tensor: Per-student feature table, indexed by user id after
            being moved to ``device``.
        criterion: Loss callable taking ``(logits, targets)``.
        model_optimizer: Optimizer stepping the network parameters.
        embed_optimizer: Optimizer stepping the embedding tables.
        device: Device the batches and the metadata table are moved to.

    Returns:
        tuple: ``(mean of the per-batch losses, fraction of samples whose
        thresholded prediction equals the target)``.

    Raises:
        ZeroDivisionError: If ``train_dataloader`` yields no batches.
    """
    model.train()

    # Loop-invariant: move the metadata table to the device once instead of
    # once per batch (on an accelerator every .to() call is a full copy).
    student_meta_on_device = student_meta_tensor.to(device)

    train_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for user_ids, question_ids, targets in train_dataloader:
        user_ids = user_ids.to(device)
        question_ids = question_ids.to(device)
        # Targets become a float column so they match the logits' shape.
        targets = targets.float().unsqueeze(1).to(device)

        user_embeds = student_embed[user_ids]
        question_embeds = question_embed[question_ids]
        student_meta = student_meta_on_device[user_ids]

        logits = model(user_embeds, question_embeds, student_meta)
        loss = criterion(logits, targets)

        # Network and embeddings are updated from the same backward pass.
        model_optimizer.zero_grad()
        embed_optimizer.zero_grad()
        loss.backward()
        model_optimizer.step()
        embed_optimizer.step()

        train_loss += loss.item()
        # Accuracy uses the logits of this training-mode forward pass,
        # thresholded at probability 0.5.
        predictions = torch.sigmoid(logits) > 0.5
        correct_predictions += (predictions == targets).sum().item()
        total_samples += targets.size(0)

    return train_loss / len(train_dataloader), correct_predictions / total_samples


def val_step(model, val_dataloader, student_embed, question_embed, student_meta_tensor, criterion, device):
    """Evaluate the model on ``val_dataloader`` without updating anything.

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``.

    Args:
        model: Module called as ``model(user_embeds, question_embeds, student_meta)``
            that returns one logit per sample.
        val_dataloader: Iterable of ``(user_ids, question_ids, targets)`` batches
            that also supports ``len()``.
        student_embed: Table indexed by the batch of user ids.
        question_embed: Table indexed by the batch of question ids.
        student_meta_tensor: Per-student feature table, indexed by user id after
            being moved to ``device``.
        criterion: Loss callable taking ``(logits, targets)``.
        device: Device the batches and the metadata table are moved to.

    Returns:
        tuple: ``(mean of the per-batch losses, fraction of samples whose
        thresholded prediction equals the target)``.

    Raises:
        ZeroDivisionError: If ``val_dataloader`` yields no batches.
    """
    model.eval()

    # Loop-invariant: move the metadata table to the device once instead of
    # once per batch (on an accelerator every .to() call is a full copy).
    student_meta_on_device = student_meta_tensor.to(device)

    val_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for user_ids, question_ids, targets in val_dataloader:
            user_ids = user_ids.to(device)
            question_ids = question_ids.to(device)
            # Targets become a float column so they match the logits' shape.
            targets = targets.float().unsqueeze(1).to(device)

            user_embeds = student_embed[user_ids]
            question_embeds = question_embed[question_ids]
            student_meta = student_meta_on_device[user_ids]

            logits = model(user_embeds, question_embeds, student_meta)
            loss = criterion(logits, targets)
            val_loss += loss.item()

            # Threshold the predicted probability at 0.5.
            predictions = torch.sigmoid(logits) > 0.5
            correct_predictions += (predictions == targets).sum().item()
            total_samples += targets.size(0)

    return val_loss / len(val_dataloader), correct_predictions / total_samples

def checkpoint(epoch, experiment_path, avg_val_loss, val_accuracy, model, student_embed, question_embed,
                     best_val_acc):
    """Persist the model and embeddings when validation accuracy improves.

    When ``val_accuracy`` is strictly greater than ``best_val_acc`` three files
    are written into ``experiment_path``: the model ``state_dict`` (file name
    encodes epoch, validation loss and accuracy), ``student_embed.pt`` and
    ``question_embed.pt``. Every other ``*.pt`` file in the directory is then
    removed, so only the best checkpoint is kept.

    The new files are first written under temporary ``*.tmp`` names and only
    swapped into place once all of them were saved; older checkpoints are
    deleted last. A failure or an interrupt while saving therefore leaves the
    previous best checkpoint untouched.

    Args:
        epoch: Epoch number used in the model file name.
        experiment_path: Existing directory that holds the checkpoint files.
        avg_val_loss: Validation loss used in the model file name.
        val_accuracy: Validation accuracy of the current epoch.
        model: Object whose ``state_dict()`` is saved.
        student_embed: Student embedding table, saved as-is with ``torch.save``.
        question_embed: Question embedding table, saved as-is with ``torch.save``.
        best_val_acc: Best validation accuracy seen so far.

    Returns:
        ``val_accuracy`` if a checkpoint was written, otherwise ``best_val_acc``.
    """
    # Written as "not >" so that the comparison behaves exactly like the
    # improvement test `val_accuracy > best_val_acc`, including for NaN.
    if not val_accuracy > best_val_acc:
        return best_val_acc

    model_filename = f'epoch{epoch}_val_loss{avg_val_loss:.5f}_val_acc{val_accuracy:.5f}.pt'
    artifacts = {
        model_filename: model.state_dict(),
        'student_embed.pt': student_embed,
        'question_embed.pt': question_embed,
    }

    staged = {}
    try:
        for filename, payload in artifacts.items():
            final_path = os.path.join(experiment_path, filename)
            temp_path = final_path + '.tmp'
            # Registered before writing so a partially written file is cleaned up too.
            staged[temp_path] = final_path
            torch.save(payload, temp_path)

        # All three files exist on disk now; swap them into place.
        for temp_path, final_path in staged.items():
            os.replace(temp_path, final_path)
    finally:
        # After a successful swap nothing is left to remove; after a failure
        # this drops the temporary files that were already written.
        for temp_path in staged:
            if os.path.exists(temp_path):
                os.remove(temp_path)

    # Only now drop the checkpoints that the new files supersede.
    for file in os.listdir(experiment_path):
        if file.endswith('.pt') and file not in artifacts:
            os.remove(os.path.join(experiment_path, file))

    return val_accuracy

def train_model(
    n_students,
    n_questions,
    train_df,
    val_df,
    student_meta_tensor,
    student_embed_dim=8,
    question_embed_dim=16,
    hidden_layers=[64, 16],
    dropout_p=0.3,
    batch_size=32,
    learning_rate=1e-3,
    device='cpu',
    max_epochs=None
):
    """Train the student/question model, logging and checkpointing every epoch.

    A fresh ``checkpoints/experiment<N>`` directory is created (first unused N),
    the hyperparameters are written there as ``hyperparameters.json`` and
    TensorBoard scalars go to ``logs/experiment<N>``. After each epoch the
    train/validation loss and accuracy are logged and printed, and
    ``checkpoint`` keeps the weights with the best validation accuracy.

    Args:
        n_students: Number of rows of the student embedding table.
        n_questions: Number of rows of the question embedding table.
        train_df: Training frame passed to ``initialize_dataloaders``.
        val_df: Validation frame passed to ``initialize_dataloaders``.
        student_meta_tensor: Per-student feature tensor; its second dimension
            sets the metadata width of the network input.
        student_embed_dim: Width of a student embedding.
        question_embed_dim: Width of a question embedding.
        hidden_layers: Hidden layer widths of the network. The default list is
            only read, never modified.
        dropout_p: Dropout probability of the network.
        batch_size: Batch size of both dataloaders.
        learning_rate: Learning rate of both Adam optimizers.
        device: Device (string or ``torch.device``) used for training.
        max_epochs: Optional upper bound on the number of epochs. With the
            default ``None`` training continues until it is interrupted
            (``KeyboardInterrupt``), which is reported as "early stopping".

    Returns:
        None.
    """
    checkpoint_dir = 'checkpoints'
    log_dir = 'logs'

    # Pick the first experiment number that has no checkpoint directory yet.
    experiment_num = 1
    while os.path.exists(os.path.join(checkpoint_dir, f'experiment{experiment_num}')):
        experiment_num += 1
    experiment_path = os.path.join(checkpoint_dir, f'experiment{experiment_num}')
    os.makedirs(experiment_path, exist_ok=True)

    # Coerce to JSON-native types: numpy integers and torch.device objects
    # would otherwise make json.dump raise TypeError.
    hyperparameters = {
        'n_students': int(n_students),
        'n_questions': int(n_questions),
        'student_embed_dim': student_embed_dim,
        'question_embed_dim': question_embed_dim,
        'hidden_layers': list(hidden_layers),
        'dropout_p': dropout_p,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'device': str(device)
    }
    with open(os.path.join(experiment_path, 'hyperparameters.json'), 'w') as f:
        json.dump(hyperparameters, f, indent=4)

    writer = SummaryWriter(os.path.join(log_dir, f'experiment{experiment_num}'))
    # Everything after the writer exists runs under try/finally so the writer
    # is closed even when dataloader or model construction fails.
    try:
        train_dataloader, val_dataloader = initialize_dataloaders(train_df, val_df, batch_size)

        model, student_embed, question_embed, model_optimizer, embed_optimizer = initialize_model_and_optimizers(
            n_students, n_questions, student_embed_dim, question_embed_dim, student_meta_tensor.shape[1],
            hidden_layers, dropout_p, learning_rate, device
        )

        # Move the metadata table to the training device once, up front; the
        # per-batch .to(device) calls in the step functions then find it
        # already in place.
        student_meta_tensor = student_meta_tensor.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        best_val_acc = -1
        epoch = 0

        try:
            while max_epochs is None or epoch < max_epochs:
                epoch += 1

                avg_train_loss, train_accuracy = train_step(
                    model, train_dataloader, student_embed, question_embed, student_meta_tensor,
                    criterion, model_optimizer, embed_optimizer, device
                )

                avg_val_loss, val_accuracy = val_step(
                    model, val_dataloader, student_embed, question_embed, student_meta_tensor,
                    criterion, device
                )

                writer.add_scalar('Train/Loss', avg_train_loss, epoch)
                writer.add_scalar('Train/Accuracy', train_accuracy, epoch)
                writer.add_scalar('Validation/Loss', avg_val_loss, epoch)
                writer.add_scalar('Validation/Accuracy', val_accuracy, epoch)

                print(f"Epoch {epoch}")
                print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.5f} | Train Acc: {train_accuracy:.4f} | Val Acc: {val_accuracy:.5f}")

                best_val_acc = checkpoint(
                    epoch, experiment_path, avg_val_loss, val_accuracy, model, student_embed, question_embed, best_val_acc
                )

        except KeyboardInterrupt:
            # Interrupting the kernel is the intended way to stop an unbounded run.
            print("early stopping")
    finally:
        writer.close()

In [None]:
# This run was configured for Apple's MPS backend. Requesting "mps" on a
# machine without it fails, so fall back to CUDA and then CPU when MPS is
# unavailable (the getattr guards PyTorch builds that have no MPS module).
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE: with these arguments train_model keeps training until the kernel is
# interrupted; the best checkpoint so far is kept on disk.
train_model(
    n_students,
    n_questions,
    train_df,
    val_df,
    student_meta_tensor,
    student_embed_dim=16,
    question_embed_dim=64,
    hidden_layers=[16, 128, 16],
    dropout_p=0.5,
    batch_size=256,
    learning_rate=1e-4,
    device=device
)