In [2]:
import torch
import torch.nn as nn

In [None]:
class Embeddings(nn.Module):
    """Fuse item-ID embeddings with projected context/attribute features.

    Row 0 of the item table is the padding row (``padding_idx=0``), which is
    why the table has ``n_items + 1`` rows. The feature vector ``q`` (width
    ``n_ctx + n_attrs``) is projected to ``g`` dims, concatenated with the
    ``d``-dim item embedding, and mapped back to ``d`` dims.
    """

    def __init__(self, n_items: int, d: int, g: int, n_ctx: int, n_attrs: int):
        super().__init__()
        n_feats = n_ctx + n_attrs
        self.items_embed = nn.Embedding(n_items + 1, d, padding_idx=0)
        self.feats_embed = nn.Linear(n_feats, g)
        self.joint_embed = nn.Linear(g + d, d)

    def forward(self, x, q):
        item_vecs = self.items_embed(x)
        feat_vecs = self.feats_embed(q)
        fused = torch.cat([item_vecs, feat_vecs], dim=-1)
        return self.joint_embed(fused)

In [None]:
# Profile-level self-attention block
class SelfAttentionBlock(nn.Module):
    """Profile-level self-attention block: pre-normed MHA followed by a pointwise FFN.

    Args:
        d: Model / embedding dimension (must be divisible by ``H``).
        H: Number of attention heads.
        p: Dropout probability.
        batch_first: If True (default), ``x`` is ``(batch, seq, d)``; if False,
            ``(seq, batch, d)`` as in the ``nn.MultiheadAttention`` default.
    """

    def __init__(self, d: int, H: int, p: float, batch_first: bool = True):
        super().__init__()

        # Attention
        self.norm1 = nn.LayerNorm(normalized_shape=d)
        # batch_first=True so that (batch, seq, d) embeddings attend over the
        # sequence axis rather than across batch elements.
        self.attention = nn.MultiheadAttention(embed_dim=d, num_heads=H, batch_first=batch_first)
        self.dropout1 = nn.Dropout(p=p)

        # FFN (kernel_size=1 Conv1d is a position-wise linear layer)
        self.norm2 = nn.LayerNorm(normalized_shape=d)
        self.ffn_1 = nn.Conv1d(in_channels=d, out_channels=d, kernel_size=1)
        self.activation = nn.LeakyReLU()
        self.dropout2 = nn.Dropout(p=p)

        self.ffn_2 = nn.Conv1d(in_channels=d, out_channels=d, kernel_size=1)
        self.dropout3 = nn.Dropout(p=p)
        self.norm3 = nn.LayerNorm(normalized_shape=d)

    def forward(self, x):
        """Encode ``x``; output has the same shape as the input."""
        q = self.norm1(x)

        # MultiheadAttention returns (attn_output, attn_weights); keep the output.
        s, _ = self.attention(q, x, x, need_weights=False)
        s = self.dropout1(s)

        # Multiplicative residual connection.
        # NOTE(review): SASRec/Transformer-style blocks normally use an additive
        # residual (s + x); confirm the multiplicative form is intended.
        s = torch.mul(s, x)
        s = self.norm2(s)

        # Conv1d expects channels in dim -2, but the model dim is last: swap
        # the last two axes around each pointwise convolution.
        f = self.ffn_1(s.transpose(-1, -2))
        f = self.activation(f)
        f = self.dropout2(f)

        f = self.ffn_2(f)
        f = self.dropout3(f)
        f = f.transpose(-1, -2)

        f = torch.mul(f, s)  # Multiplicative residual connection
        f = self.norm3(f)

        return f

In [None]:
# Target-level cross-attention block
class CrossAttentionBlock(nn.Module):
    """Target-level cross-attention: targets ``e`` attend over the encoded profile ``f``.

    Args:
        d: Model / embedding dimension (must be divisible by ``H``).
        H: Number of attention heads.
        p: Dropout probability.
        batch_first: If True (default), inputs are ``(batch, seq, d)``; if False,
            ``(seq, batch, d)`` as in the ``nn.MultiheadAttention`` default.
    """

    def __init__(self, d: int, H: int, p: float, batch_first: bool = True):
        super().__init__()

        # Attention
        self.attention = nn.MultiheadAttention(embed_dim=d, num_heads=H, batch_first=batch_first)
        self.dropout = nn.Dropout(p=p)

        # FFN (kernel_size=1 Conv1d is a position-wise linear layer)
        # NOTE(review): this maps d -> d, so the output holds d sigmoids per
        # target position rather than a single score; confirm against how
        # y_pred is consumed by the loss.
        self.ffn = nn.Conv1d(in_channels=d, out_channels=d, kernel_size=1)
        self.sig = nn.Sigmoid()

    def forward(self, e, f):
        """Return sigmoid scores with the same shape as ``e``.

        Args:
            e: Target-item embeddings (queries).
            f: Encoded profile (keys and values).
        """
        # MultiheadAttention returns (attn_output, attn_weights); keep the output.
        s, _ = self.attention(e, f, f, need_weights=False)
        s = self.dropout(s)

        s = torch.mul(s, e)  # Multiplicative residual connection

        # Conv1d expects channels in dim -2; the model dim is last.
        y = self.ffn(s.transpose(-1, -2)).transpose(-1, -2)
        y = self.sig(y)

        return y

In [None]:
class CARCA(nn.Module):
    """Score target items against a profile sequence.

    One shared ``Embeddings`` module encodes both the profile items
    (``p_x``, ``p_q``) and the target items (``o_x``, ``o_q``). A stack of ``B``
    ``SelfAttentionBlock``s encodes the profile. A ``CrossAttentionBlock``
    then lets the targets attend over that encoding and returns ``y_pred``.
    """

    def __init__(self, n_items: int, d: int, g: int, n_ctx: int, n_attrs: int, H: int, p: float, B: int):
        super().__init__()
        # Submodules are built in a fixed order: embeddings, then the
        # self-attention stack, then cross-attention. Parameter
        # initialisation therefore draws from the RNG in the same sequence.
        self.embeds = Embeddings(n_items, d, g, n_ctx, n_attrs)
        self.sa_blocks = nn.Sequential(*(SelfAttentionBlock(d, H, p) for _ in range(B)))
        self.ca_blocks = CrossAttentionBlock(d, H, p)

    def forward(self, p_x, p_q, o_x, o_q):
        profile_emb = self.embeds(p_x, p_q)
        target_emb = self.embeds(o_x, o_q)
        profile_enc = self.sa_blocks(profile_emb)
        return self.ca_blocks(target_emb, profile_enc)

In [None]:
class BinaryCrossEntropy(nn.Module):
    """Masked, summed binary cross-entropy on probabilities.

    Predictions are clamped into ``[eps, 1 - eps]`` before taking logs. A
    float32 sigmoid saturates to exactly 0.0 or 1.0, which would make
    ``torch.log`` return ``-inf``. ``0 * -inf`` is ``nan``, and that ``nan``
    survives multiplication by a zero mask, so the whole loss (and its
    gradient) would become NaN. For predictions already inside the interval
    the result is unchanged.

    Args:
        eps: Clamp margin for predictions (default 1e-7).
    """

    def __init__(self, eps: float = 1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, y_true, y_pred, mask):
        """Return ``sum(mask * BCE(y_true, y_pred))`` as a scalar tensor.

        Args:
            y_true: Targets; expected to be 0/1 and broadcastable with ``y_pred``.
            y_pred: Predicted probabilities in ``[0, 1]``.
            mask: Per-position weights; zeros remove a position from the sum.
        """
        eps = getattr(self, "eps", 1e-7)  # tolerate instances unpickled from the old class
        y_pred = torch.clamp(y_pred, min=eps, max=1.0 - eps)
        loss = -(y_true * torch.log(y_pred) + (1.0 - y_true) * torch.log(1.0 - y_pred))
        loss = torch.sum(loss * mask)
        return loss