In [1]:
import os
from typing import Callable, Iterable, Optional, Union

import numpy as np
import pandas as pd
import scipy.sparse as sps
from dask import array as da
from dask import delayed
from dask_ml.preprocessing import DummyEncoder
from tqdm import tqdm

from rec.constants import DATA_PATH, MSDMetadata
from rec.data_loader import Dataset
from rec.msdutils import _get_song_index

In [2]:
def create_occurrence_matrix(
    dataset, included_users: Union[str, Iterable[str], Iterable[int]] = 'all',
    include_play_counts=False,
    num_unique_songs: Optional[int] = None,
    song_index: Optional[Callable[[str], int]] = None,
) -> sps.coo_matrix:
    """
    Given a dataset of (user_id, song_id, num_plays) triplets, creates an
    occurrence matrix to be used downstream in a latent semantic analysis-based
    recommender system (or otherwise).

    ============================================
    Args:
    ============================================
    (rec.data_loader.Dataset) dataset:
    * A dataset implementing the method `iterate_over_visible_data()`, which
      yields (user_id, user_data) pairs where user_data is an iterable of
      (song_id, num_plays) pairs.

    (Union[str, Iterable[str], Iterable[int]]) included_users:
    * This argument controls which users' data is included in the matrix. This
      exists because the full data is too large to fit in memory, even as a
      sparse matrix.
    * It has three valid forms:
        (1) 'all':
            - Includes all users. Will require very large amounts of memory.
        (2) Iterable[str] of user IDs:
            - A user is included if their user ID is in this collection.
        (3) Iterable[int] of integer indices (Python or NumPy integers):
            - A user is included if the position at which they occur in the
              full (text file) dataset is in this collection.
            - This can be used if the exact group of users included in this
              matrix does not matter.
    * Any iterable (list, set, tuple, NumPy array, generator) is accepted. It
      is materialized into a set once, so membership tests are O(1).
    * An empty iterable yields a matrix with zero rows.
    * Any other string than 'all' is rejected. Pass a one-element list to
      select a single user.

    (bool) include_play_counts:
    * Whether to return a binary vector (has the user played/not played each
      song?) or a non-negative integer vector (where each entry contains the
      number of times the user played each song).
    * Note that the majority of the entries will be zero.

    (Optional[int]) num_unique_songs:
    * Number of columns of the matrix. Pass the size of the full song catalogue
      so that matrices built from different user subsets share one column
      space.
    * If None, it is inferred as (largest song column index seen) + 1, or 0
      when no data was included.

    (Optional[Callable[[str], int]]) song_index:
    * Maps a song ID to its column index. Defaults to
      `rec.msdutils._get_song_index`.
    ============================================

    Returns:
    ============================================
    (scipy.sparse.coo_matrix) occurrence_matrix
    * Has shape (num_included_users, num_unique_songs), with integer dtype.
      Row i holds the i-th included user in the order `dataset` yields them.
    ============================================

    Raises:
    ============================================
    ValueError
    * If `included_users` is not one of the forms described above.
    ============================================
    """
    include = _make_user_filter(included_users)
    to_column = _get_song_index if song_index is None else song_index

    rows, cols, data = [], [], []
    num_rows = 0  # users kept so far; doubles as the next row index
    for position, (user_id, user_data) in enumerate(dataset.iterate_over_visible_data()):
        if not include(user_id, position):
            continue

        for song_id, num_plays in user_data:
            rows.append(num_rows)
            cols.append(to_column(song_id))
            data.append(num_plays if include_play_counts else 1)

        num_rows += 1

    if num_unique_songs is None:
        # NOTE(review): the original read an undefined global `song_ids` here.
        # Callers that know the catalogue size should pass it explicitly.
        num_unique_songs = max(cols) + 1 if cols else 0

    return sps.coo_matrix(
        (data, (rows, cols)), shape=(num_rows, num_unique_songs), dtype=int
    )


def _make_user_filter(included_users) -> Callable[[str, int], bool]:
    """Builds a predicate (user_id, position) -> bool from `included_users`."""
    def invalid():
        # Wrap in a tuple so '%' formatting also works when given a tuple.
        return ValueError("Invalid argument 'included_users' (got %s)" % (included_users,))

    if isinstance(included_users, str):
        if included_users == 'all':
            return lambda user_id, position: True
        # A bare string is not a collection of IDs; testing membership in it
        # would silently match substrings.
        raise invalid()

    try:
        wanted = set(included_users)
    except TypeError:
        # Not iterable, or elements are unhashable.
        raise invalid() from None

    if not wanted:
        return lambda user_id, position: False

    sample = next(iter(wanted))
    if isinstance(sample, str):
        return lambda user_id, position: user_id in wanted
    # bool subclasses int but is never a meaningful dataset position.
    if isinstance(sample, (int, np.integer)) and not isinstance(sample, bool):
        return lambda user_id, position: position in wanted
    raise invalid()


    
    

In [3]:
def _make_full_occ_matrix(dataset, include_play_counts=False):
    """
    Placeholder for building the occurrence matrix over *all* users from the
    dataset's dask dataframe (`dataset.load_dask_dataframe()`).

    NOTE(review): this was an unfinished stub. It loaded the dataframe and then
    implicitly returned None, so callers silently received None instead of a
    matrix. It now fails loudly, before doing any expensive loading, until it
    is implemented. `create_occurrence_matrix(dataset, 'all',
    include_play_counts)` builds the same kind of matrix by iterating the
    dataset instead.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError(
        "_make_full_occ_matrix is not implemented yet; use "
        "create_occurrence_matrix(dataset, 'all', include_play_counts) instead"
    )
    


In [4]:
# Build a (user x song) table from the validation split via a one-hot encoding
# of `song_id`.
# NOTE(review): this materializes one dense column per distinct song; the
# follow-up `.head()` was interrupted, so it is likely too expensive for the
# full catalogue. `create_occurrence_matrix` keeps the result sparse.
include_play_counts = False  # True: weight each one-hot entry by num_plays

loader = Dataset(which='valid')
data_df = loader.load_dask_dataframe()

de = DummyEncoder(columns=['song_id'])
de.fit(data_df)
X = de.transform(data_df)

# The one-hot columns are exactly those the encoder added, i.e. the ones not
# present in the input frame. This avoids relying on column positions.
song_cols = [col for col in X.columns if col not in data_df.columns]

if include_play_counts:
    # Scale each row's one-hot vector by that row's play count.
    song_frame = X[song_cols].mul(X['num_plays'], axis=0).assign(user_id=X['user_id'])
else:
    song_frame = X.drop(columns=['num_plays'])

# One groupby-sum over the selected columns; equivalent to a per-column
# {'col': 'sum'} agg, but builds a far smaller task graph.
aggregated = song_frame.groupby('user_id')[song_cols].sum()


In [12]:
# Preview the first rows. NOTE(review): on a dask frame this triggers
# computation of the groupby above, which was interrupted when last run.
aggregated.head(n=5)

KeyboardInterrupt: 