In [2]:
import torch
import torch.nn as nn

class TwoPartEmbeddings(nn.Module):
    def __init__(self, n_users, n_items, embedding_dim):
        super().__init__()
        
        total_nodes = n_users + n_items
        self.node_feats = nn.Embedding(total_nodes, embedding_dim)
        
        # Create views for client_feats and item_feats from node_feats
        self.client_feats = nn.Embedding.from_pretrained(self.node_feats.weight[:n_users], freeze=False)
        self.item_feats = nn.Embedding.from_pretrained(self.node_feats.weight[n_users:], freeze=False)

    def forward(self, node_ids):
        return self.node_feats(node_ids)

In [3]:
# class TwoPartEmbeddingsv2(nn.Module):
#     def __init__(self, n_users, n_items, embedding_dim):
#         super().__init__()

#         self.client_feats = nn.Embedding(n_users, embedding_dim)
#         self.item_feats = nn.Embedding(n_items, embedding_dim)
        
#         user_feats_weight = self.client_feats.weight.data
#         item_feats_weight = self.item_feats.weight.data
#         # Concatenate the embedding matrices
#         self.node_feats = torch.cat([user_feats_weight, item_feats_weight])

#     def forward(self, node_ids):
#         return self.node_feats[node_ids]

In [4]:
class TwoPartEmbeddingsv3(nn.Module):
    def __init__(self, n_users, n_items, embedding_dim):
        super().__init__()

        self.client_feats = nn.Embedding(n_users, embedding_dim)
        self.item_feats = nn.Embedding(n_items, embedding_dim)
        
        user_feats_weight = self.client_feats.weight
        item_feats_weight = self.item_feats.weight
        # Concatenate the embedding matrices
        self.node_feats = torch.cat([user_feats_weight, item_feats_weight])

    def forward(self, node_ids):
        return self.node_feats[node_ids]

In [5]:
def tests_two_part_embeddings(embs_ctor):
    # Setup
    n_users = 10
    n_items = 10
    embedding_dim = 64
    model = embs_ctor(n_users, n_items, embedding_dim)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    # Create dummy input and target
    node_ids = torch.LongTensor([0, 1, 2, n_users, n_users + 1, n_users + 2])  # Mix of user and item IDs
    dummy_target = torch.randn(len(node_ids), embedding_dim)


    # Initial weights
    if type(model.client_feats) != nn.Embedding:
        initial_client_weights = model.client_feats.clone().detach()
    else:
        initial_client_weights = model.client_feats.weight.clone().detach()
    if type(model.item_feats) != nn.Embedding:
        initial_item_weights = model.item_feats.clone().detach()
    else:
        initial_item_weights = model.item_feats.weight.clone().detach()
    if type(model.node_feats) != nn.Embedding:
        initial_node_weights = model.node_feats.clone().detach()
    else:
        initial_node_weights = model.node_feats.weight.clone().detach()

    # Forward pass
    output = model(node_ids)
    # print("Initial output:", output)

    # Compute dummy loss
    loss = criterion(output, dummy_target)

    # Backward pass
    loss.backward()

    # Optimization step
    optimizer.step()

    # updated weights
    if type(model.client_feats) != nn.Embedding:
        updated_client_weights = model.client_feats.clone().detach()
    else:
        updated_client_weights = model.client_feats.weight.clone().detach()
    if type(model.item_feats) != nn.Embedding:
        updated_item_weights = model.item_feats.clone().detach()
    else:
        updated_item_weights = model.item_feats.weight.clone().detach()
    if type(model.node_feats) != nn.Embedding:
        updated_node_weights = model.node_feats.clone().detach()
    else:
        updated_node_weights = model.node_feats.weight.clone().detach()

    # Check that weights have been updated
    assert not torch.equal(initial_client_weights, updated_client_weights), "Client weights did not update"
    assert not torch.equal(initial_item_weights, updated_item_weights), "Item weights did not update"
    assert not torch.equal(initial_node_weights, updated_node_weights), "Node weights did not update"

    # Check that the client and item weights are still parts of the node weights
    assert torch.equal(model.client_feats.weight, updated_node_weights[:n_users]), "Client weights are not a slice of node weights"
    assert torch.equal(model.item_feats.weight, updated_node_weights[n_users:]), "Item weights are not a slice of node weights"

    # print("Updated client weights:", updated_client_weights)
    # print("Updated item weights:", updated_item_weights)
    # print("Updated node weights:", updated_node_weights)

    print("Passed")

In [6]:
tests_two_part_embeddings(TwoPartEmbeddings)

Passed


In [8]:
tests_two_part_embeddings(TwoPartEmbeddingsv3)

AssertionError: Node weights did not update