In [1]:
import os
import sys
import torch
from datetime import datetime
from torch.optim import Adam
from torch.utils.data import DataLoader

In [2]:
sys.path.append("..")

In [3]:
from carca.data import CARCADataset, load_attrs, load_ctx, load_profiles
from carca.model import CARCA, BinaryCrossEntropy

In [4]:
def to(*tensors: torch.Tensor, device: str):
    """Move each of the given tensors to ``device``.

    Args:
        *tensors: Tensors to move, in the order they should be returned.
        device: Target device string passed to ``Tensor.to`` (e.g. ``"cpu"``, ``"cuda"``).

    Returns:
        A tuple holding the moved tensors, in the same order as given
        (an empty tuple when no tensors are passed).
    """
    moved = (tensor.to(device) for tensor in tensors)
    return tuple(moved)

In [5]:
# Name of the dataset every loader below reads; kept in one place so switching
# datasets is a single edit instead of three.
dataset_name = "video_games"

attrs = load_attrs(dataset_name)
ctx = load_ctx(dataset_name)
user_ids, item_ids, profiles = load_profiles(dataset_name)

In [6]:
# Model input sizes derived from the loaded data.

# One slot more than the number of known items — presumably reserved for a
# padding id (TODO confirm against CARCADataset / CARCA).
n_items = len(item_ids) + 1

# The context width is read from an arbitrary entry of `ctx`; this assumes every
# entry has the same shape (TODO confirm). Fail loudly on an empty `ctx` instead
# of leaking a bare StopIteration from `next(iter(...))`.
if not ctx:
    raise ValueError("ctx is empty; cannot infer the context dimension n_ctx")
n_ctx = next(iter(ctx.values())).shape[0]

# Attribute width: size of the second axis of `attrs`.
n_attrs = attrs.shape[1]

In [7]:
# Training hyperparameters. Each comment names where the value is consumed.
learning_rate = 6e-6   # Adam `lr`
seq_len = 35           # CARCADataset `profile_seq_len`
n_blocks = 3           # CARCA `B`
n_heads = 3            # CARCA `H`
dropout_rate = 0.3     # CARCA `p`
l2_reg = 0.0001        # Adam `weight_decay`
d_dim = 390            # CARCA `d`
g_dim = 1950           # CARCA `g`
residual_sa = True     # CARCA `res_sa` (presumably self-attention residuals — TODO confirm)
residual_ca = False    # CARCA `res_ca` (presumably cross-attention residuals — TODO confirm)
epochs = 800           # configured number of training epochs
batch_size = 128       # DataLoader `batch_size`
beta1 = 0.9            # Adam betas[0]
beta2 = 0.98           # Adam betas[1]

# Seed PyTorch's RNGs before the dataset, model and shuffled DataLoader are
# created so weight init, dropout and batch order are reproducible.
# NOTE(review): if CARCADataset samples with numpy or `random`, seed those too.
seed = 42
torch.manual_seed(seed)

In [8]:
# Training split of CARCADataset, wrapped in a DataLoader that reshuffles the
# samples every epoch.
# NOTE(review): target_seq_len=100 is a magic number; consider moving it to the config cell.
dataset = CARCADataset(
    mode="train",
    user_ids=user_ids,
    item_ids=item_ids,
    profiles=profiles,
    attrs=attrs,
    ctx=ctx,
    profile_seq_len=seq_len,
    target_seq_len=100,
)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # load batches in the main process
)

In [9]:
# Build CARCA from the data-derived sizes and the configured hyperparameters.
# The short keyword names (d, g, H, p, B, res_sa, res_ca) are CARCA's own
# constructor parameters.
model = CARCA(
    n_items=n_items,
    n_ctx=n_ctx,
    n_attrs=n_attrs,
    d=d_dim,
    g=g_dim,
    B=n_blocks,
    H=n_heads,
    p=dropout_rate,
    res_sa=residual_sa,
    res_ca=residual_ca,
)

In [10]:
# Adam over every model parameter. Note: Adam's `weight_decay` adds an L2
# penalty term to the gradients (coupled), unlike the decoupled decay of AdamW.
optim = Adam(
    params=model.parameters(),
    lr=learning_rate,
    betas=(beta1, beta2),
    weight_decay=l2_reg,
)
# Project loss; the training loop calls it as loss_fn(y_pred, y_true, mask).
loss_fn = BinaryCrossEntropy()

In [11]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Move the model's parameters and buffers to the chosen device.
model = model.to(device)

Using cpu device


In [None]:
# Debug cap on the number of epochs. This uses its own name instead of
# reassigning the config's `epochs`, so running cells in any order never
# silently changes the configured value. Set `max_epochs = None` to train for
# the full configured `epochs`.
max_epochs = 2
n_train_epochs = epochs if max_epochs is None else min(epochs, max_epochs)
log_every = 1  # print the running batch average every `log_every` batches

# Put the model in training mode explicitly (enables e.g. dropout), in case an
# earlier cell switched it to eval mode.
model.train()

for epoch in range(n_train_epochs):
    sum_loss = 0.0
    n_batches = 0

    for i, batch in enumerate(loader, start=1):
        # assumes each batch yields these six tensors in this order — TODO
        # confirm against CARCADataset.
        p_x, p_q, o_x, o_q, y_true, mask = to(*batch, device=device)

        optim.zero_grad()
        y_pred = model(p_x, p_q, o_x, o_q)
        loss = loss_fn(y_pred, y_true, mask)
        loss.backward()
        optim.step()

        sum_loss += loss.item()
        n_batches = i

        if i % log_every == 0:
            stamp = datetime.now().strftime("%H:%M:%S")
            print(f"{stamp} - Batch {i:03d}: Average Loss = {(sum_loss / i):.4f}")

    # Average over the batches actually processed; max(..., 1) avoids a
    # ZeroDivisionError when the loader yields nothing.
    epoch_loss = sum_loss / max(n_batches, 1)
    stamp = datetime.now().strftime("%H:%M:%S")
    print(f"{stamp} - Epoch {epoch:03d}: Average Loss = {epoch_loss:.4f}")

In [None]:
# TODO: possible issues to investigate — Conv1d layers, data loading, unordered sequences, and torch operations (some may prevent gradients from being computed).