Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MovieLens 1M dataset #397

Merged
merged 7 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `MovieLens 1M` dataset ([#397](https://github.com/pyg-team/pytorch-frame/pull/397))
- Added light-weight MLP ([#372](https://github.com/pyg-team/pytorch-frame/pull/372))
- Added R^2 metric ([#403](https://github.com/pyg-team/pytorch-frame/pull/403))

Expand Down
51 changes: 51 additions & 0 deletions test/datasets/test_movielens_1m.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import tempfile

import torch

import torch_frame
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.data.stats import StatType
from torch_frame.datasets import Movielens1M
from torch_frame.testing.text_embedder import HashTextEmbedder


def test_movielens_1m():
    """Check loading, materialization and column stats of `Movielens1M`.

    The dataset is built inside a temporary directory using the
    `HashTextEmbedder` testing helper for the text column; every assertion
    runs after the directory context has been left.
    """
    num_ratings = 1000209
    embedder_cfg = TextEmbedderConfig(text_embedder=HashTextEmbedder(10))

    with tempfile.TemporaryDirectory() as root:
        dataset = Movielens1M(root, col_to_text_embedder_cfg=embedder_cfg)

    # Basic dataset properties before materialization.
    assert str(dataset) == 'Movielens1M()'
    assert len(dataset) == num_ratings
    expected_feat_cols = [
        'user_id', 'gender', 'age', 'occupation', 'zip', 'movie_id', 'title',
        'genres', 'timestamp'
    ]
    assert dataset.feat_cols == expected_feat_cols

    materialized = dataset.materialize()

    # One feature tensor is expected per semantic type in use.
    frame = materialized.tensor_frame
    feats = frame.feat_dict
    assert len(feats) == 4

    categorical_feat = feats[torch_frame.categorical]
    assert categorical_feat.dtype == torch.int64
    assert categorical_feat.size() == (num_ratings, 6)
    assert feats[torch_frame.multicategorical].dtype == torch.int64
    assert feats[torch_frame.embedding].dtype == torch.float32

    expected_col_names = {
        torch_frame.categorical:
        ['age', 'gender', 'movie_id', 'occupation', 'user_id', 'zip'],
        torch_frame.multicategorical: ['genres'],
        torch_frame.timestamp: ['timestamp'],
        torch_frame.embedding: ['title'],
    }
    assert frame.col_names_dict == expected_col_names

    # The target holds one rating per row, bounded by 1 and 5.
    target = frame.y
    assert target.size() == (num_ratings, )
    assert target.min() == 1 and target.max() == 5

    # Each column must expose the statistic tied to its semantic type.
    stats = materialized.col_stats
    assert len(stats) == 10
    expected_stat_per_col = [
        ('user_id', StatType.COUNT),
        ('genres', StatType.MULTI_COUNT),
        ('timestamp', StatType.YEAR_RANGE),
        ('title', StatType.EMB_DIM),
    ]
    for col_name, stat_type in expected_stat_per_col:
        assert stat_type in stats[col_name]
2 changes: 2 additions & 0 deletions torch_frame/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .data_frame_benchmark import DataFrameBenchmark
from .data_frame_text_benchmark import DataFrameTextBenchmark
from .mercari import Mercari
from .movielens_1m import Movielens1M
from .amazon_fine_food_reviews import AmazonFineFoodReviews
from .diamond_images import DiamondImages
from .huggingface_dataset import HuggingFaceDatasetDict
Expand All @@ -34,6 +35,7 @@
'DataFrameBenchmark',
'DataFrameTextBenchmark',
'Mercari',
'Movielens1M',
'AmazonFineFoodReviews',
'DiamondImages',
]
Expand Down
90 changes: 90 additions & 0 deletions torch_frame/datasets/movielens_1m.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

import os.path as osp
import zipfile

import pandas as pd

import torch_frame
from torch_frame.config.text_embedder import TextEmbedderConfig


class Movielens1M(torch_frame.data.Dataset):
    r"""The MovieLens 1M rating dataset: approximately 1 million ratings
    collected by GroupLens Research from the MovieLens web site, linking
    users (6,040 nodes) and movies (3,883 nodes).

    Every row of the resulting table is a single rating joined with the
    attributes of the rating user and of the rated movie. Rows are ordered
    by ascending :obj:`timestamp`, and :obj:`rating` is the target column.

    Args:
        root (str): Location handed to :meth:`download_url` together with
            :attr:`url`. The archive is extracted into the directory that
            contains the downloaded file.
        col_to_text_embedder_cfg (dict or TextEmbedderConfig, optional):
            Text embedder configuration(s), forwarded unchanged to the
            parent dataset constructor. (default: :obj:`None`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 20
        :header-rows: 1

        * - #Users
          - #Items
          - #User Field
          - #Item Field
          - #Samples
        * - 6040
          - 3952
          - 5
          - 3
          - 1000209
    """

    url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip'

    def __init__(
        self,
        root: str,
        col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
        | TextEmbedderConfig | None = None,
    ):
        archive_path = self.download_url(self.url, root)
        extract_dir = osp.dirname(archive_path)

        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(extract_dir)

        raw_dir = osp.join(extract_dir, 'ml-1m')

        # The `.dat` files have no header row and use `::` as delimiter; a
        # multi-character separator is handled by pandas' python engine.
        read_opts = dict(header=None, sep='::', engine='python')

        users = pd.read_csv(
            osp.join(raw_dir, 'users.dat'),
            names=['user_id', 'gender', 'age', 'occupation', 'zip'],
            **read_opts,
        )
        # Only the movie file is decoded as ISO-8859-1.
        movies = pd.read_csv(
            osp.join(raw_dir, 'movies.dat'),
            names=['movie_id', 'title', 'genres'],
            encoding='ISO-8859-1',
            **read_opts,
        )
        ratings = pd.read_csv(
            osp.join(raw_dir, 'ratings.dat'),
            names=['user_id', 'movie_id', 'rating', 'timestamp'],
            **read_opts,
        )

        # Joins happen on the shared column names: `user_id` between
        # ratings and users, then `movie_id` between the result and movies.
        joined = ratings.merge(users).merge(movies)
        df = joined.sort_values(by='timestamp').reset_index(drop=True)

        # Insertion order is kept identical to the column listing below.
        col_to_stype = dict.fromkeys(
            ['user_id', 'gender', 'age', 'occupation', 'zip', 'movie_id'],
            torch_frame.categorical,
        )
        col_to_stype.update({
            'title': torch_frame.text_embedded,
            'genres': torch_frame.multicategorical,
            'rating': torch_frame.numerical,
            'timestamp': torch_frame.timestamp,
        })

        super().__init__(
            df,
            col_to_stype,
            target_col='rating',
            col_to_sep='|',
            col_to_text_embedder_cfg=col_to_text_embedder_cfg,
        )
Loading