
THUwangcy/DirectAU


DirectAU


Implementation of the paper "Towards Representation Alignment and Uniformity in Collaborative Filtering" in KDD'22.

This work investigates the desired properties of representations in collaborative filtering (CF) from the perspective of alignment and uniformity. The proposed DirectAU provides a new learning objective for CF-based recommender systems that directly optimizes representation alignment and uniformity on the hypersphere. A simple matrix factorization (MF) encoder optimizing this loss can outperform state-of-the-art CF methods.
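
Reading off the loss code in the next section, the objective can be written as below. Here $\tilde{f}(\cdot)$ is the L2-normalized encoder output, $p_{\mathrm{pos}}$ is the distribution of observed (positive) user-item pairs, and $\gamma$ is the uniformity weight. This notation is ours and may not match the paper's. In training, each expectation is estimated on a mini-batch: alignment over the batch's positive pairs, uniformity over all pairs of users (and of items) in the batch.

```latex
\begin{aligned}
\ell_{\mathrm{align}} &= \mathbb{E}_{(u,i)\sim p_{\mathrm{pos}}} \big\lVert \tilde{f}(u) - \tilde{f}(i) \big\rVert_2^2, \\
\ell_{\mathrm{uniform}} &= \tfrac{1}{2} \log \mathbb{E}_{u,u'\sim p_{\mathrm{user}}} e^{-2\lVert \tilde{f}(u) - \tilde{f}(u') \rVert_2^2}
  + \tfrac{1}{2} \log \mathbb{E}_{i,i'\sim p_{\mathrm{item}}} e^{-2\lVert \tilde{f}(i) - \tilde{f}(i') \rVert_2^2}, \\
\mathcal{L}_{\mathrm{DirectAU}} &= \ell_{\mathrm{align}} + \gamma\, \ell_{\mathrm{uniform}}.
\end{aligned}
```
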

Training with DirectAU

This learning objective is easy to implement as follows (PyTorch-style):

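import torch
import torch.nn.functional as F

# Methods of the DirectAU model class: self.encoder maps user/item ids to
# embeddings, and self.gamma weights the uniformity term.
@staticmethod
def alignment(x, y):
    # Mean squared L2 distance between normalized positive user-item pairs
    x, y = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
    return (x - y).norm(p=2, dim=1).pow(2).mean()

@staticmethod
def uniformity(x):
    # Log of the average Gaussian potential exp(-2 * dist^2) over all pairs in the batch
    x = F.normalize(x, dim=-1)
    return torch.pdist(x, p=2).pow(2).mul(-2).exp().mean().log()

def calculate_loss(self, user, item):
    user_e, item_e = self.encoder(user, item)  # [bsz, dim]
    align = self.alignment(user_e, item_e)
    uniform = (self.uniformity(user_e) + self.uniformity(item_e)) / 2
    loss = align + self.gamma * uniform
    return loss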
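
For a self-contained illustration, the sketch below plugs these methods into a minimal model whose MF encoder is just two `nn.Embedding` lookups, then runs a few Adam steps on one batch of positive pairs. The class name `MFDirectAU`, the sizes, and the random data are hypothetical; the repository's `directau.py` is built on RecBole and may be organized differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MFDirectAU(nn.Module):
    """Hypothetical MF encoder + DirectAU loss, for illustration only."""

    def __init__(self, n_users, n_items, dim=64, gamma=1.0):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, dim)
        self.item_emb = nn.Embedding(n_items, dim)
        self.gamma = gamma

    def encoder(self, user, item):
        # MF encoder: a plain embedding lookup on each side
        return self.user_emb(user), self.item_emb(item)

    @staticmethod
    def alignment(x, y):
        x, y = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
        return (x - y).norm(p=2, dim=1).pow(2).mean()

    @staticmethod
    def uniformity(x):
        x = F.normalize(x, dim=-1)
        return torch.pdist(x, p=2).pow(2).mul(-2).exp().mean().log()

    def calculate_loss(self, user, item):
        user_e, item_e = self.encoder(user, item)
        align = self.alignment(user_e, item_e)
        uniform = (self.uniformity(user_e) + self.uniformity(item_e)) / 2
        return align + self.gamma * uniform


torch.manual_seed(0)
model = MFDirectAU(n_users=100, n_items=200, dim=16, gamma=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

# One toy batch of random "positive" pairs (ids drawn without repetition)
users = torch.randperm(100)[:64]
items = torch.randperm(200)[:64]

losses = []
for step in range(50):
    optimizer.zero_grad()
    loss = model.calculate_loss(users, items)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Both terms act on L2-normalized vectors, so the loss depends only on the directions of the embeddings, not on their norms.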

We integrate our DirectAU method (directau.py) into the RecBole framework. The datasets used in the paper are already included in the dataset folder. Related experimental settings can be found in the properties folder. To reproduce the results, you can run the following commands after installing all the requirements:

# Beauty
python run_recbole.py \
    --model=DirectAU --dataset=Beauty \
    --learning_rate=1e-3 --weight_decay=1e-6 \
    --gamma=0.5 --encoder=MF --train_batch_size=256

# Gowalla
python run_recbole.py \
    --model=DirectAU --dataset=Gowalla \
    --learning_rate=1e-3 --weight_decay=1e-6 \
    --gamma=5 --encoder=MF --train_batch_size=1024

# Yelp2018
python run_recbole.py \
    --model=DirectAU --dataset=Yelp \
    --learning_rate=1e-3 --weight_decay=1e-6 \
    --gamma=1 --encoder=MF --train_batch_size=1024

To test DirectAU on other datasets, prepare them in the same format as the existing ones. More details about the dataset format can be found in the Atomic Files documentation of RecBole.
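
A RecBole dataset is typically a folder of tab-separated atomic files whose header row declares each column as `field:type`. As a minimal sketch, the helper below writes a `.inter` interaction file with user, item, and timestamp columns to `<dataset_dir>/<name>/<name>.inter`. The helper name, the dataset name `MyData`, and the chosen columns are hypothetical. Compare with the files already in the dataset folder to see which fields the provided configs expect.

```python
import csv
import tempfile
from pathlib import Path


def write_inter_file(dataset_dir, name, interactions):
    """Write (user_id, item_id, timestamp) rows as a RecBole-style .inter file.

    Hypothetical helper: assumes the layout <dataset_dir>/<name>/<name>.inter
    and a tab-separated header of `field:type` columns.
    """
    out_dir = Path(dataset_dir) / name
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"{name}.inter"
    with open(path, "w", newline="") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        writer.writerow(["user_id:token", "item_id:token", "timestamp:float"])
        writer.writerows(interactions)
    return path


# Usage: write two toy interactions into a temporary dataset folder
root = tempfile.mkdtemp()
path = write_inter_file(root, "MyData", [("u1", "i3", 1650000000), ("u2", "i1", 1650000100)])
lines = path.read_text().splitlines()
```
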

The main hyper-parameters of DirectAU include:

| Param | Default | Description |
| --- | --- | --- |
| --embedding_size | 64 | The embedding size. |
| --gamma | 1 | The weight of the uniformity loss. |
| --encoder | MF | The encoder type: MF / LightGCN |
| --n_layers | None | The number of layers when --encoder=LightGCN |
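
With `--encoder=LightGCN`, embeddings are refined by propagation over the user-item graph before the loss is computed, and `--n_layers` sets the number of propagation steps. The sketch below follows the standard LightGCN rule (symmetrically normalized adjacency, layer outputs averaged with equal weights). It is written for illustration, not taken from this repository, and the function name and `[2, E]` edge format are assumptions. It also assumes each (user, item) pair appears only once in `edges`.

```python
import torch


def lightgcn_propagate(user_emb, item_emb, edges, n_layers):
    """LightGCN-style propagation on a user-item bipartite graph (sketch).

    edges: LongTensor [2, E] of (user, item) interaction pairs.
    Returns the mean of the layer-0..n_layers embeddings for users and items.
    """
    n_users, n_items = user_emb.shape[0], item_emb.shape[0]
    n = n_users + n_items
    u, i = edges[0], edges[1] + n_users  # item nodes are placed after user nodes
    rows = torch.cat([u, i])  # both directions make the adjacency symmetric
    cols = torch.cat([i, u])

    # Symmetric normalization D^{-1/2} A D^{-1/2}; isolated nodes get weight 0
    deg = torch.bincount(rows, minlength=n).to(user_emb.dtype)
    d_inv_sqrt = deg.pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
    values = d_inv_sqrt[rows] * d_inv_sqrt[cols]
    adj = torch.sparse_coo_tensor(torch.stack([rows, cols]), values, (n, n))

    x = torch.cat([user_emb, item_emb], dim=0)
    layers = [x]
    for _ in range(n_layers):
        x = torch.sparse.mm(adj, x)
        layers.append(x)
    out = torch.stack(layers, dim=0).mean(dim=0)
    return out[:n_users], out[n_users:]
```

With `n_layers=0` this reduces to the MF lookup, since only the layer-0 embeddings are averaged.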

You can use the following command to tune the hyper-parameters of DirectAU (for more details, see Parameter Tuning in RecBole):

python run_hyper.py \
    --model=DirectAU --dataset=Beauty \
    --config_files='recbole/properties/overall.yaml recbole/properties/model/DirectAU.yaml recbole/properties/dataset/sample.yaml' \
    --params_file=directau.hyper \
    --output_file=hyper.result
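
RecBole's params file lists one hyper-parameter per line as `name search_type search_space`. The contents of `directau.hyper` are not shown here, so the search space below is hypothetical. The sketch parses `choice` lines and enumerates the grid an exhaustive search would cover, which is a quick way to estimate how many training runs a given file implies, assuming the tuner is set to exhaustive search. `expand_choice_grid` is an illustrative helper, not a RecBole function, and other search types are not handled.

```python
import ast
import itertools


def expand_choice_grid(hyper_text):
    """Enumerate every configuration of a params file whose lines all have the
    form `name choice [v1,v2,...]` (hypothetical helper)."""
    names, options = [], []
    for line in hyper_text.strip().splitlines():
        name, search_type, space = line.split(maxsplit=2)
        if search_type != "choice":
            raise ValueError(f"unsupported search type: {search_type}")
        names.append(name)
        options.append(ast.literal_eval(space))  # "[0.2,0.5]" -> [0.2, 0.5]
    return [dict(zip(names, combo)) for combo in itertools.product(*options)]


# Hypothetical search space; the actual directau.hyper may differ
hyper_text = """
gamma choice [0.2,0.5,1,2,5,10]
weight_decay choice [0,1e-8,1e-6]
"""
grid = expand_choice_grid(hyper_text)
print(len(grid))  # 18
print(grid[0])    # {'gamma': 0.2, 'weight_decay': 0}
```
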

Measuring Alignment and Uniformity

Alignment and uniformity of the learned representations can be measured as follows (Appendix A.2 in the paper):

import torch
import torch.nn.functional as F

def overall_align(user_index, item_index, user_emb, item_emb, alpha=2):
    """ Args:
    user_index (torch.LongTensor): user ids of positive interactions, shape: [|R|, ]
    item_index (torch.LongTensor): item ids of positive interactions, shape: [|R|, ]
    user_emb (torch.Tensor): embeddings of all the users (e.g. an nn.Embedding's .weight), shape: [|U|, dim]
    item_emb (torch.Tensor): embeddings of all the items, shape: [|I|, dim]
    alpha (int): distance exponent; 2 matches the alignment term of the training loss
    """
    x = F.normalize(user_emb[user_index], dim=-1)
    y = F.normalize(item_emb[item_index], dim=-1)
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()

def overall_uniform(index_list, embedding):
    """ Args:
    index_list (torch.LongTensor): user/item ids of positive interactions, shape: [|R|, ]
    embedding (torch.Tensor): user/item embeddings, shape: [|U|, dim] or [|I|, dim]
    """
    # Distinct ids (sorted) and their interaction counts c_i
    ids, count = torch.unique(index_list, return_counts=True)
    count = count.to(embedding.device, embedding.dtype)

    # Pair weights c_i * c_j for i < j, in the same row-major order as torch.pdist
    weight_matrix = torch.outer(count, count)
    iu = torch.triu_indices(len(ids), len(ids), offset=1, device=embedding.device)
    weight = weight_matrix[iu[0], iu[1]]
    total_freq = (count.sum() ** 2 - weight_matrix.trace()) / 2  # = sum_{i<j} c_i * c_j

    # Normalize onto the hypersphere, as in alignment (a no-op for already-normalized inputs)
    x = F.normalize(embedding[ids.to(embedding.device)], dim=-1)
    return torch.pdist(x, p=2).pow(2).mul(-2).exp().mul(weight).sum().div(total_freq).log()
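
The weighting in `overall_uniform` is easiest to see in a slow reference version. Ids $i$ and $j$ with counts $c_i$ and $c_j$ give $c_i c_j$ pairs of interactions, so the weighted sum above equals the average of $e^{-2\lVert\cdot\rVert_2^2}$ over all pairs of interactions whose ids differ. The helper `uniform_bruteforce` below is hypothetical (not part of the repository) and costs $O(|R|^2)$, so it is only meant for cross-checking on small inputs.

```python
import itertools
import math

import torch
import torch.nn.functional as F


def uniform_bruteforce(index_list, embedding):
    """Log of the mean exp(-2 * ||e_a - e_b||^2) over all interaction pairs
    a < b whose ids differ (needs at least two distinct ids)."""
    emb = F.normalize(embedding, dim=-1)
    ids = index_list.tolist()
    total, n_pairs = 0.0, 0
    for a, b in itertools.combinations(range(len(ids)), 2):
        if ids[a] == ids[b]:
            continue  # same id twice: excluded, like the diagonal of the weight matrix
        d2 = (emb[ids[a]] - emb[ids[b]]).pow(2).sum().item()
        total += math.exp(-2.0 * d2)
        n_pairs += 1
    return math.log(total / n_pairs)


# Id 0 appears twice and id 1 once; both cross-id pairs have squared distance 2
embedding = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
index_list = torch.tensor([0, 0, 1])
result = uniform_bruteforce(index_list, embedding)  # log(exp(-4)) = -4
```
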


Citation

If you find this work helpful for your research, please consider citing our paper:

@inproceedings{wang2022towards,
  title={Towards Representation Alignment and Uniformity in Collaborative Filtering},
  author={Wang, Chenyang and Yu, Yuanqing and Ma, Weizhi and Zhang, Min and Chen, Chong and Liu, Yiqun and Ma, Shaoping},
  booktitle={Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
  pages={1816--1825},
  year={2022}
}

Contact

Chenyang Wang (THUwangcy@gmail.com)
