<a href="https://colab.research.google.com/github/upriyam-cmu/EDGE-Rec/blob/main/execute.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install project code

In [None]:
!pip uninstall -y edge-rec
!pip install -e.

# Train model

In [None]:
from edge_rec.datasets import MovieLensDataHolder, RatingsTransform, FeatureTransform

from edge_rec.model import GraphReconstructionModel, GraphTransformer
from edge_rec.model.embed import MovieLensFeatureEmbedder, SinusoidalPositionalEmbedding

from edge_rec.diffusion import GaussianDiffusion
from edge_rec.exec import Trainer, compute_metrics

In [None]:
# Build the MovieLens data holder together with the transforms it applies
# to the ratings and to the rating-count features.
data_holder = MovieLensDataHolder(
    augmentations={
        "ratings": RatingsTransform.ToGaussian(),
        # degree 2 --> dim_size = 2 (this is the value the embedder below is given
        # for its *_rating_counts_dims arguments)
        "rating_counts": FeatureTransform.LogPolynomial(2),
    },
)

In [None]:
# Feature embedder; both rating-count dims are 2, the dim_size noted next to
# the LogPolynomial(2) transform on the data holder.
embed = MovieLensFeatureEmbedder(user_rating_counts_dims=2, movie_rating_counts_dims=2)

# Transformer core whose feature sizes are read from the embedder.
core = GraphTransformer(
    n_blocks=1,
    n_channels=1,
    n_features=embed.output_sizes,
    time_embedder=SinusoidalPositionalEmbedding(16),
    # Optional attention settings, currently left disabled:
    # attn_kwargs=dict(heads=1, dim_head=8, num_mem_kv=0)
)

# Combine embedder and core into the model handed to the diffusion wrapper.
model = GraphReconstructionModel(embed, core, feature_dim_size=None)

In [None]:
# Diffusion wrapper; the "image" is the full (n_users, n_movies) grid.
diffusion_model = GaussianDiffusion(
    model,
    image_size=(data_holder.n_users, data_holder.n_movies),
)

trainer = Trainer(
    # --- model ---
    diffusion_model=diffusion_model,
    # --- datasets (no subgraph size / target density given for either split) ---
    train_dataset=data_holder.get_dataset(subgraph_size=None, target_density=None, train=True),
    test_dataset=data_holder.get_dataset(subgraph_size=None, target_density=None, train=False),
    # --- training schedule ---
    batch_size=1,
    gradient_accumulate_every=1,
    force_batch_size=True,
    train_num_steps=10_000,
    train_mask_unknown_ratings=True,
    # --- optimizer ---
    train_lr=1e-4,
    adam_betas=(0.9, 0.99),
    max_grad_norm=1.0,
    # --- logging / checkpoints ---
    results_folder="./results",
    ema_update_every=10,
    ema_decay=0.995,
    save_and_sample_every=200,
    # --- accelerator ---
    amp=False,
    mixed_precision_type="fp16",
    split_batches=True,
)

# Report which device the trainer selected.
print("Using device:", trainer.device)

In [None]:
# Start training with the Trainer configured above.
trainer.train()

# Sample ratings

In [None]:
def eval_model(use_inpainting: bool, milestone: int):
    """Sample a ratings graph for the full dataset via ``trainer.eval``.

    Uses the notebook-level ``data_holder`` and ``trainer`` objects.

    Args:
        use_inpainting: Forwarded to ``trainer.eval`` as ``do_inpainting_sampling``.
        milestone: Forwarded to ``trainer.eval`` as ``milestone`` (presumably
            selects which saved checkpoint is evaluated -- confirm against Trainer).

    Returns:
        A tuple ``(denoised_graph, rating_data_train, rating_data_test)``: the
        result of ``trainer.eval`` followed by the train-edge and test-edge
        slices of the full graph.
    """
    # Passing (None, None) requests the indices of the full graph.
    users, products = data_holder.get_subgraph_indices(None, None)

    def _slice_edges(want_train: bool):
        # Slice the same user/product indices, keeping exactly one edge split.
        return data_holder.slice_subgraph(
            user_indices=users,
            product_indices=products,
            return_train_edges=want_train,
            return_test_edges=not want_train,
        )

    train_edges = _slice_edges(True)
    test_edges = _slice_edges(False)

    sampled_graph = trainer.eval(
        # eval receives a clone; the un-cloned train slice is what gets returned.
        rating_data=train_edges.clone(),
        milestone=milestone,
        do_inpainting_sampling=use_inpainting,
        tiled_sampling=True,
        batch_size=16,
        subgraph_size=128,
        silence_inner_tqdm=True,
    )
    return sampled_graph, train_edges, test_edges

In [None]:
# Sample the full ratings graph without inpainting, evaluating milestone 4000.
# NOTE(review): the Trainer above is configured with train_num_steps=int(1e4) and
# save_and_sample_every=200; how milestones are numbered is not visible here --
# confirm that milestone 4000 refers to a checkpoint this training run produces.
denoised_graph, rating_data_train, rating_data_test = eval_model(use_inpainting=False, milestone=4000)

# Evaluate metrics

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)

# Global matplotlib setup: high-DPI figures with a small base font, drawn in the 'bmh' style.
mpl.rcParams['figure.dpi'] = 300
mpl.rcParams['font.size'] = 4
plt.style.use('bmh')

# Human-readable subplot titles, keyed by the metric names looked up in plot_metrics.
hrv_name = dict(
    precision='Precision',
    recall='Recall',
    mean_reciprocal_rank='MRR',
    hit_rate='HR',
    ndcg='NDCG',
)


def plot_metrics(metrics: dict, plot_name: str):
    """Plot the five top-K metric curves side by side in one figure.

    Args:
        metrics: Mapping from metric key ('precision', 'recall', 'ndcg',
            'mean_reciprocal_rank', 'hit_rate') to a sequence of seven values,
            one for each K in (1, 5, 10, 20, 30, 40, 50).
        plot_name: Text used as the y-axis label of the first (precision) panel.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    def _plot(ax, curve_name):
        # Draw one metric curve on `ax`, titled with its display name.
        ks = (1, 5, 10, 20, 30, 40, 50)
        curve = metrics[curve_name]

        ax.set_title(hrv_name[curve_name])
        ax.xaxis.grid(True, which='major')
        ax.xaxis.set_major_locator(MultipleLocator(10))
        ax.xaxis.set_major_formatter('{x:.0f}')
        ax.plot(ks, curve, '.-', linewidth=1.0, markersize=4.0)
        ax.set_xlabel("Top-K")

        # Make the y-range occupy as much screen length as 50 x-units, so every
        # panel comes out roughly square regardless of the metric's scale.
        # A flat (or non-finite) curve has no usable range: dividing by it would
        # produce an infinite/NaN aspect, which set_aspect rejects, so in that
        # case the axes keep their default aspect instead of aborting the figure.
        span = float(np.max(curve) - np.min(curve))
        if np.isfinite(span) and span > 0:
            ax.set_aspect(50 / span)

    _, axs = plt.subplots(1, 5)
    for ax, name in zip(axs, ['precision', 'recall', 'ndcg', 'mean_reciprocal_rank', 'hit_rate']):
        _plot(ax, name)
        if name == 'precision':
            # Only the left-most panel carries the overall plot label.
            ax.set_ylabel(plot_name)
    plt.tight_layout(h_pad=-25.0)
    plt.show()

In [None]:
# Score the sampled graph against the train and test rating slices returned by eval_model.
# The data holder's ratings transform is passed along as well (presumably so predictions
# can be mapped back to the original rating scale -- TODO confirm in compute_metrics).
metrics = compute_metrics(
    predicted_ratings_graph=denoised_graph,
    train_rating_data=rating_data_train,
    test_rating_data=rating_data_test,
    rating_transform=data_holder.ratings_transform,
)

In [None]:
# Plot every metric curve; the label records the 128-sized tiled sampling used in eval_model.
plot_metrics(metrics=metrics, plot_name="ML-100k, 128x128 Patch Sampling")

# Display sampled ratings distribution

In [None]:
# 20-bin histogram over every entry of the sampled ratings graph.
hist_fig, hist_ax = plt.subplots(figsize=(3, 2))
hist_ax.hist(denoised_graph.numpy().ravel(), bins=20)
plt.show()