In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
tqdm.pandas()

# Load both mission exports, build a combined "<type>_<target>" mission key,
# keep only the modelling columns, and drop users whose rows all share a
# single creation timestamp.
df = (
    pd.concat(
        [
            pd.read_csv('./data/October_missions_full.csv'),
            pd.read_csv('./data/November_1stW_missions_full.csv'),
        ],
        ignore_index=True,
    )
    .assign(mission=lambda frame: frame['type'] + '_' + frame['target'].astype(str))
    .loc[:, ['user', 'mission', 'createdAtT', 'type', 'target', 'completed']]
    .groupby('user')
    .filter(lambda rows: rows['createdAtT'].nunique(dropna=False) > 1)
)

# Replace the raw identifiers with dense integer codes so they can be used
# directly as embedding indices.
df = df.assign(**{
    column: df[column].astype('category').cat.codes
    for column in ('user', 'type', 'mission')
})

# Order rows by creation time, then keep only the last row for every
# (user, type, target) combination.
df = (
    df.sort_values(by=['createdAtT'], ignore_index=True)
    .drop_duplicates(subset=['user', 'type', 'target'], keep='last')
)

# Disabled experiment: per-type min-max scaling of `target`.
#def min_max_scaler(x):
#    return (x - x.min()) / (x.max() - x.min())
#df['target'] = df.groupby('type')['target'].transform(min_max_scaler).fillna(0)

df

Unnamed: 0,user,mission,createdAtT,type,target,completed
0,2019,25,1727740807698,5,2,False
1,2019,20,1727740807698,4,1,True
2,2019,8,1727740807698,1,6,False
6,2259,8,1727740821680,1,6,True
9,2133,30,1727740830583,6,1,True
...,...,...,...,...,...,...
96669,3384,12,1731023654775,2,1,False
96670,3384,23,1731023654775,4,4,False
96671,365,11,1731023935124,1,9,False
96672,365,24,1731023935124,5,1,False


In [2]:
import torch
import torch.nn as nn

def train_model(model: nn.Module, device, df: pd.DataFrame, epochs=10, lr=1e-3, batch_size=32, weight_decay=0.0):
    """Train ``model`` with class-weighted binary cross-entropy and return it.

    Args:
        model: Module called as ``model(user, mission_type, target)`` that
            returns one logit per row.
        device: Device the model and all training tensors are placed on.
        df: Frame with the columns ``user``, ``type``, ``target`` and
            ``completed``. ``user`` and ``type`` are converted to long
            tensors, ``target`` and ``completed`` to float tensors.
        epochs: Number of passes over ``df``.
        lr: Adam learning rate.
        batch_size: Mini-batch size of the shuffled loader.
        weight_decay: Adam weight decay.

    Returns:
        The same ``model`` instance, moved to ``device`` and switched to
        evaluation mode.
    """
    # Move the model before building the optimizer so the optimizer is
    # constructed over the parameters as they exist on ``device``.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Weight positives by the negative/positive ratio. If either class is
    # missing the ratio degenerates (inf -> NaN loss, or 0 -> zero loss), so
    # fall back to an unweighted loss in that case.
    n_negative = int((df['completed'] == 0).sum())
    n_positive = int((df['completed'] == 1).sum())
    if n_positive > 0 and n_negative > 0:
        pos_weight = n_negative / n_positive
    else:
        pos_weight = 1.0
    criterion = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor(pos_weight, dtype=torch.float, device=device)
    )

    dataset = torch.utils.data.TensorDataset(
        torch.tensor(df['user'].values, dtype=torch.long, device=device),
        torch.tensor(df['type'].values, dtype=torch.long, device=device),
        torch.tensor(df['target'].values, dtype=torch.float, device=device),
        torch.tensor(df['completed'].values, dtype=torch.float, device=device),
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for _ in (bar := tqdm(range(epochs), leave=False)):
        epoch_loss = 0.0
        # ``mission_type`` rather than ``type`` to avoid shadowing the builtin.
        for user, mission_type, target, completed in tqdm(loader, leave=False):
            optimizer.zero_grad()
            pred = model(user, mission_type, target)
            loss = criterion(pred, completed)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # Report the mean batch loss of the epoch on the outer progress bar.
        bar.set_postfix(loss=epoch_loss / len(loader))

    model.eval()
    return model

In [3]:
class MF(nn.Module):
    """Biased matrix factorisation over two index spaces of size ``n`` and ``m``.

    The score for a pair ``(i, j)`` is the dot product of the two embeddings
    plus one learned bias per index and a single global bias.
    """

    def __init__(self, n, m, embedding_dim=8):
        super().__init__()
        # Latent factors for each side of the interaction.
        self.n = nn.Embedding(num_embeddings=n, embedding_dim=embedding_dim)
        self.m = nn.Embedding(num_embeddings=m, embedding_dim=embedding_dim)
        # One scalar bias per index, stored as width-1 embeddings.
        self.n_bias = nn.Embedding(num_embeddings=n, embedding_dim=1)
        self.m_bias = nn.Embedding(num_embeddings=m, embedding_dim=1)
        # Shared offset, initialised to zero.
        self.global_bias = nn.Parameter(torch.zeros(1))

    def forward(self, i, j):
        """Return one raw score per ``(i, j)`` pair of index tensors."""
        interaction = torch.sum(self.n(i) * self.m(j), dim=1)
        return (
            interaction
            + self.n_bias(i).squeeze()
            + self.m_bias(j).squeeze()
            + self.global_bias
        )

class Net(nn.Module):
    """Score missions from a learned user scalar and the mission's difficulty.

    A user embedding is reduced to one scalar by a small MLP. That scalar is
    concatenated with the ``difficulty_args`` difficulty features of the
    mission, and the result is combined linearly using a weight vector looked
    up per mission type. The output is a raw logit per row.
    """

    def __init__(self, num_users, num_types, embedding_dim=8, difficulty_args=1, device='cpu'):
        super().__init__()
        self.user = nn.Embedding(num_users, embedding_dim)
        # One weight for the user scalar plus one per difficulty feature.
        self.difficulty_weights = nn.Embedding(num_types, difficulty_args + 1)

        self.user_mlp = nn.Sequential(
            nn.Linear(embedding_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

        # Remembered so forward() can shape ``target`` to match the width of
        # ``difficulty_weights``.
        self.difficulty_args = difficulty_args
        self.device = device

    def forward(self, user, mission_type, target):
        """Return one logit per row.

        Args:
            user: Long tensor of user indices.
            mission_type: Long tensor of mission-type indices.
            target: Difficulty features holding ``difficulty_args`` values per
                row (a flat tensor of one value per row for the default
                ``difficulty_args=1``).
        """
        user_score = self.user_mlp(self.user(user)).view(-1, 1)
        weights = self.difficulty_weights(mission_type)
        # Previously the features were stacked into exactly two columns, which
        # only matched ``weights`` when difficulty_args == 1.
        difficulty = target.reshape(-1, self.difficulty_args)
        features = torch.cat((user_score, difficulty), dim=1)
        return (features * weights).sum(dim=1).view(-1)

    def fit(self, df, epochs=10, lr=1e-3, batch_size=32, weight_decay=0.0):
        """Train this model on ``df`` via ``train_model`` and return its result."""
        return train_model(self, self.device, df, epochs, lr, batch_size, weight_decay)


class ClassicNet(nn.Module):
    """Baseline that scores (user, mission) pairs with the ``MF`` module."""

    def __init__(self, num_users, num_missions, embedding_dim=8, device='cpu'):
        super().__init__()
        self.mf = MF(num_users, num_missions, embedding_dim)
        self.device = device

    def forward(self, user, mission):
        """Return one raw logit per (user, mission) pair as a flat tensor."""
        x = self.mf(user, mission)
        return x.view(-1)

    def fit(self, df: pd.DataFrame, epochs=10, lr=1e-3, batch_size=32, weight_decay=0.0):
        """Train with class-weighted binary cross-entropy and return ``self``.

        Args:
            df: Frame with the columns ``user``, ``mission`` and ``completed``.
                ``user`` and ``mission`` are converted to long tensors and
                ``completed`` to a float tensor.
            epochs: Number of passes over ``df``.
            lr: Adam learning rate.
            batch_size: Mini-batch size of the shuffled loader.
            weight_decay: Adam weight decay.

        Returns:
            This model, moved to ``self.device`` and in evaluation mode.
        """
        # Move the model before building the optimizer so the optimizer is
        # constructed over the parameters as they exist on the device.
        self.to(self.device)
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)

        # Weight positives by the negative/positive ratio. If either class is
        # missing the ratio degenerates (inf -> NaN loss, or 0 -> zero loss),
        # so fall back to an unweighted loss in that case.
        n_negative = int((df['completed'] == 0).sum())
        n_positive = int((df['completed'] == 1).sum())
        if n_positive > 0 and n_negative > 0:
            pos_weight = n_negative / n_positive
        else:
            pos_weight = 1.0
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor(pos_weight, dtype=torch.float, device=self.device)
        )

        dataset = torch.utils.data.TensorDataset(
            torch.tensor(df['user'].values, dtype=torch.long, device=self.device),
            torch.tensor(df['mission'].values, dtype=torch.long, device=self.device),
            torch.tensor(df['completed'].values, dtype=torch.float, device=self.device),
        )
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        self.train()
        for _ in (bar := tqdm(range(epochs), leave=False)):
            epoch_loss = 0.0
            for user, mission, completed in tqdm(loader, leave=False):
                optimizer.zero_grad()
                pred = self(user, mission)
                loss = criterion(pred, completed)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            # Report the mean batch loss of the epoch on the outer progress bar.
            bar.set_postfix(loss=epoch_loss / len(loader))

        self.eval()
        return self

In [4]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_curve, auc, precision_recall_curve

# 5-fold stratified cross-validation comparing Net against the ClassicNet
# matrix-factorisation baseline on ROC AUC and PR AUC.
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Use the GPU when one is present; otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
cv_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The fold split is already seeded; seed torch as well so weight
# initialisation and batch shuffling are repeatable across runs.
torch.manual_seed(42)

roc_auc = []
pr_auc = []
for train_idx, test_idx in tqdm(kfold.split(df, df['completed']), total=kfold.get_n_splits()):
    train = df.iloc[train_idx]
    test = df.iloc[test_idx]

    # Embedding tables are sized from the full frame so that indices seen
    # only in the test fold are still valid lookups.
    net = Net(
        num_users=df['user'].nunique(),
        num_types=df['type'].nunique(),
        embedding_dim=10,
        device=cv_device,
    ).fit(train, epochs=20, weight_decay=1e-4, lr=1e-3)

    mlp = ClassicNet(
        num_users=df['user'].nunique(),
        num_missions=df['mission'].nunique(),
        embedding_dim=16,
        device=cv_device,
    ).fit(train, epochs=15, weight_decay=1e-3, lr=1e-2)

    # Inference only: no autograd graph is needed for the held-out scores.
    # The scores are raw logits, which is sufficient for the threshold-sweeping
    # metrics below.
    with torch.no_grad():
        y_hat = pd.DataFrame({
            'net': net(
                torch.tensor(test['user'].values, dtype=torch.long, device=net.device),
                torch.tensor(test['type'].values, dtype=torch.long, device=net.device),
                torch.tensor(test['target'].values, dtype=torch.float, device=net.device),
            ).cpu().numpy(),

            'classic_net': mlp(
                torch.tensor(test['user'].values, dtype=torch.long, device=mlp.device),
                torch.tensor(test['mission'].values, dtype=torch.long, device=mlp.device),
            ).cpu().numpy()
        })

    y = test['completed'].values

    # Each fold appends one [net, classic_net] pair per metric.
    fpr, tpr, _ = roc_curve(y, y_hat['net'])
    roc_auc.append([auc(fpr, tpr)])
    fpr, tpr, _ = roc_curve(y, y_hat['classic_net'])
    roc_auc[-1].append(auc(fpr, tpr))

    precision, recall, _ = precision_recall_curve(y, y_hat['net'])
    pr_auc.append([auc(recall, precision)])
    precision, recall, _ = precision_recall_curve(y, y_hat['classic_net'])
    pr_auc[-1].append(auc(recall, precision))

# Shape (n_folds, 2): column 0 is Net, column 1 is ClassicNet.
roc_auc = np.array(roc_auc)
pr_auc = np.array(pr_auc)

print('ROC AUC')
print('Net:', roc_auc[:, 0].mean(), roc_auc[:, 0].std())
print('Classic Net:', roc_auc[:, 1].mean(), roc_auc[:, 1].std())

print('PR AUC')
print('Net:', pr_auc[:, 0].mean(), pr_auc[:, 0].std())
print('Classic Net:', pr_auc[:, 1].mean(), pr_auc[:, 1].std())

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

  0%|          | 0/1170 [00:00<?, ?it/s]

ROC AUC
Net: 0.8594447412116129 0.0047086781952528555
Classic Net: 0.8509574779020396 0.0021364983703082087
PR AUC
Net: 0.720157201572131 0.002819871003479395
Classic Net: 0.6808018893133212 0.0030073667772234696


In [5]:
# One row per distinct (mission, type, target) combination, ordered by type
# and then by target.
missions = (
    df.loc[:, ['mission', 'type', 'target']]
    .drop_duplicates()
    .sort_values(by=['type', 'target'])
    .reset_index(drop=True)
)

In [8]:
# Mission features for every row of `missions`, placed on the model's device.
types = torch.tensor(missions['type'].to_numpy(), dtype=torch.long, device=net.device)
target = torch.tensor(missions['target'].to_numpy(), dtype=torch.float, device=net.device)

# Draw one user code at random from those present in `df`.
user = torch.tensor(np.random.choice(df['user'].unique()), dtype=torch.long, device=net.device)

print('User:', user)

# Repeat the single user id once per mission and turn the logits into
# probabilities; no gradients are needed for this read-only scoring.
with torch.no_grad():
    missions['x'] = torch.sigmoid(net(user.repeat(len(types)), types, target)).cpu().numpy()

missions

User: tensor(1237, device='cuda:0')


Unnamed: 0,mission,type,target,x
0,0,0,1,0.09707
1,1,0,2,0.037929
2,2,1,1,0.109202
3,4,1,2,0.08372
4,5,1,3,0.063759
5,6,1,4,0.048307
6,7,1,5,0.036453
7,8,1,6,0.027425
8,9,1,7,0.020584
9,10,1,8,0.015423
