Skip to content

Commit

Permalink
Add MovieLens 1M dataset (#397)
Browse files Browse the repository at this point in the history
Co-authored-by: zhou.nanxuan <zhou.nanxuan@lmwn.com>
Co-authored-by: Yiwen Yuan <yyuanlisette@gmail.com>
  • Loading branch information
3 people committed Jun 5, 2024
1 parent 7d6d5c6 commit 4a6bb50
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `MovieLens 1M` dataset ([#397](https://github.com/pyg-team/pytorch-frame/pull/397))
- Added light-weight MLP ([#372](https://github.com/pyg-team/pytorch-frame/pull/372))
- Added R^2 metric ([#403](https://github.com/pyg-team/pytorch-frame/pull/403))

Expand Down
51 changes: 51 additions & 0 deletions test/datasets/test_movielens_1m.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import tempfile

import torch

import torch_frame
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.data.stats import StatType
from torch_frame.datasets import Movielens1M
from torch_frame.testing.text_embedder import HashTextEmbedder


def test_movielens_1m():
with tempfile.TemporaryDirectory() as temp_dir:
dataset = Movielens1M(
temp_dir,
col_to_text_embedder_cfg=TextEmbedderConfig(
text_embedder=HashTextEmbedder(10)),
)
assert str(dataset) == 'Movielens1M()'
assert len(dataset) == 1000209
assert dataset.feat_cols == [
'user_id', 'gender', 'age', 'occupation', 'zip', 'movie_id', 'title',
'genres', 'timestamp'
]

dataset = dataset.materialize()

tensor_frame = dataset.tensor_frame
assert len(tensor_frame.feat_dict) == 4
assert tensor_frame.feat_dict[torch_frame.categorical].dtype == torch.int64
assert tensor_frame.feat_dict[torch_frame.categorical].size() == (1000209,
6)
assert tensor_frame.feat_dict[
torch_frame.multicategorical].dtype == torch.int64
assert tensor_frame.feat_dict[torch_frame.embedding].dtype == torch.float32
assert tensor_frame.col_names_dict == {
torch_frame.categorical:
['age', 'gender', 'movie_id', 'occupation', 'user_id', 'zip'],
torch_frame.multicategorical: ['genres'],
torch_frame.timestamp: ['timestamp'],
torch_frame.embedding: ['title'],
}
assert tensor_frame.y.size() == (1000209, )
assert tensor_frame.y.min() == 1 and tensor_frame.y.max() == 5

col_stats = dataset.col_stats
assert len(col_stats) == 10
assert StatType.COUNT in col_stats['user_id']
assert StatType.MULTI_COUNT in col_stats['genres']
assert StatType.YEAR_RANGE in col_stats['timestamp']
assert StatType.EMB_DIM in col_stats['title']
2 changes: 2 additions & 0 deletions torch_frame/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .data_frame_benchmark import DataFrameBenchmark
from .data_frame_text_benchmark import DataFrameTextBenchmark
from .mercari import Mercari
from .movielens_1m import Movielens1M
from .amazon_fine_food_reviews import AmazonFineFoodReviews
from .diamond_images import DiamondImages
from .huggingface_dataset import HuggingFaceDatasetDict
Expand All @@ -34,6 +35,7 @@
'DataFrameBenchmark',
'DataFrameTextBenchmark',
'Mercari',
'Movielens1M',
'AmazonFineFoodReviews',
'DiamondImages',
]
Expand Down
90 changes: 90 additions & 0 deletions torch_frame/datasets/movielens_1m.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

import os.path as osp
import zipfile

import pandas as pd

import torch_frame
from torch_frame.config.text_embedder import TextEmbedderConfig


class Movielens1M(torch_frame.data.Dataset):
r"""The MovieLens 1M rating dataset, assembled by GroupLens Research
from the MovieLens web site, consisting of movies (3,883 nodes) and
users (6,040 nodes) with approximately 1 million ratings between them.
**STATS:**
.. list-table::
:widths: 10 10 10 10 20
:header-rows: 1
* - #Users
- #Items
- #User Field
- #Item Field
- #Samples
* - 6040
- 3952
- 5
- 3
- 1000209
"""

url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip'

def __init__(
self,
root: str,
col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
| TextEmbedderConfig | None = None,
):
path = self.download_url(self.url, root)
folder_path = osp.dirname(path)

with zipfile.ZipFile(path, 'r') as zip_ref:
zip_ref.extractall(folder_path)

data_path = osp.join(folder_path, 'ml-1m')
users = pd.read_csv(
osp.join(data_path, 'users.dat'),
header=None,
names=['user_id', 'gender', 'age', 'occupation', 'zip'],
sep='::',
engine='python',
)
movies = pd.read_csv(
osp.join(data_path, 'movies.dat'),
header=None,
names=['movie_id', 'title', 'genres'],
sep='::',
engine='python',
encoding='ISO-8859-1',
)
ratings = pd.read_csv(
osp.join(data_path, 'ratings.dat'),
header=None,
names=['user_id', 'movie_id', 'rating', 'timestamp'],
sep='::',
engine='python',
)

df = pd.merge(pd.merge(ratings, users), movies) \
.sort_values(by='timestamp') \
.reset_index().drop('index', axis=1)

col_to_stype = {
'user_id': torch_frame.categorical,
'gender': torch_frame.categorical,
'age': torch_frame.categorical,
'occupation': torch_frame.categorical,
'zip': torch_frame.categorical,
'movie_id': torch_frame.categorical,
'title': torch_frame.text_embedded,
'genres': torch_frame.multicategorical,
'rating': torch_frame.numerical,
'timestamp': torch_frame.timestamp,
}
super().__init__(df, col_to_stype, target_col='rating', col_to_sep='|',
col_to_text_embedder_cfg=col_to_text_embedder_cfg)

0 comments on commit 4a6bb50

Please sign in to comment.