# Analyze train/validation/test set
Run the cells to analyze the age distributions and number of interactions generated by the split sets.


In [19]:
import pandas as pd
import os
from dotenv import load_dotenv
from pathlib import Path
env_path = Path('../..') / 'config.env'
load_dotenv(dotenv_path=env_path)
dataset_dir = os.getenv("dataset_directory")

In [20]:
dataset = 'mlhd' # ml, mlhd, or bx
filtered = True 

In [21]:
if dataset == 'ml':
    data_dir = dataset_dir + f'/processed/ml_rec{"_filtered" if filtered else ""}'
elif dataset == 'mlhd':
    data_dir = dataset_dir + f'/processed/mlhd_rec{"_filtered" if filtered else ""}'
elif dataset == 'bx':
    data_dir = dataset_dir + f'/processed/bx_rec{"_filtered" if filtered else ""}'

In [None]:
train_path = data_dir + f'/train.tsv'
validation_path = data_dir + f'/validation.tsv'
test_path = data_dir + f'/test.tsv'
user_info_path = data_dir + f'/user_info.tsv'

column_names = ['user_id', 'item_id', 'count'] if dataset == 'bx' else ['user_id', 'item_id', 'count', 'timestamp']

train = pd.read_csv(train_path, sep='\t', header=None, names=column_names)
validation = pd.read_csv(validation_path, sep='\t', header=None, names=column_names)
test = pd.read_csv(test_path, sep='\t', header=None, names=column_names)

users = pd.read_csv(user_info_path, sep='\t')


train = pd.merge(train, users, on='user_id', how='inner')
validation = pd.merge(validation, users, on='user_id', how='inner')
test = pd.merge(test, users, on='user_id', how='inner')

In [25]:
len_interactions = len(train) + len(validation) + len(test)

print(f'Train set: {len(train)} rows; {len(train) / len_interactions * 100:.2f}% of total interactions')
print(f'Validation set: {len(validation)} rows; {len(validation) / len_interactions * 100:.2f}% of total interactions')
print(f'Test set: {len(test)} rows; {len(test) / len_interactions * 100:.2f}% of total interactions')

print(f'Train set: {len(train["user_id"].unique())} unique users')
print(f'Validation set: {len(validation["user_id"].unique())} unique users')
print(f'Test set: {len(test["user_id"].unique())} unique users')

train_interactions_per_user = train.groupby('user_id').size()
print(f'Train set: {train_interactions_per_user.mean()} average interactions per user')

validation_interactions_per_user = validation.groupby('user_id').size()
print(f'Validation set: {validation_interactions_per_user.mean()} average interactions per user')

test_interactions_per_user = test.groupby('user_id').size()
print(f'Test set: {test_interactions_per_user.mean()} average interactions per user')

Train set: 11743489 rows; 78.79% of total interactions
Validation set: 1988948 rows; 13.34% of total interactions
Test set: 1171649 rows; 7.86% of total interactions
Train set: 32572 unique users
Validation set: 32572 unique users
Test set: 32572 unique users
Train set: 360.53938965983053 average interactions per user
Validation set: 61.0631216996193 average interactions per user
Test set: 35.97104875353064 average interactions per user


In [27]:
num_items = len(train['track_id'].unique())
print(f'Number of unique items in train set: {num_items}')

Number of unique items in train set: 207603


In [None]:
age_identifier = 'age_at_interaction' if dataset == 'mlhd' else 'age'
train_interactions_per_age_group = train.groupby(age_identifier)

validation_interactions_per_age_group = validation.groupby(age_identifier)

test_interactions_per_age_group = test.groupby(age_identifier)

In [29]:
empty_train_profiles = users[users['user_id'].isin(train['user_id']) == False]
empty_validation_profiles = users[users['user_id'].isin(validation['user_id']) == False]
empty_test_profiles = users[users['user_id'].isin(test['user_id']) == False]

print(f'Train set: {len(empty_train_profiles)} users with empty profiles')
print(f'Validation set: {len(empty_validation_profiles)} users with empty profiles')
print(f'Test set: {len(empty_test_profiles)} users with empty profiles')

small_user_profiles = train_interactions_per_user[train_interactions_per_user < 10]
print(f'Train set: {len(small_user_profiles)} users with less than 10 interactions')
small_user_profiles = validation_interactions_per_user[validation_interactions_per_user < 10]
print(f'Validation set: {len(small_user_profiles)} users with less than 10 interactions')
small_user_profiles = test_interactions_per_user[test_interactions_per_user < 10]
print(f'Test set: {len(small_user_profiles)} users with less than 10 interactions')

Train set: 0 users with empty profiles
Validation set: 0 users with empty profiles
Test set: 0 users with empty profiles
Train set: 358 users with less than 10 interactions
Validation set: 3455 users with less than 10 interactions
Test set: 7148 users with less than 10 interactions


In [30]:
print("Train set")
for age, group in train_interactions_per_age_group:
    print(f'Age: {age}')
    print(f'Number of user profiles: {len(group["user_id"].unique())}')
    print(f'Average items in user profile: {group.groupby("user_id").size().mean()}')
    small_user_profiles = group.groupby("user_id").size()[group.groupby("user_id").size() < 10]
    print(f'Users with less than 10 interactions: {len(small_user_profiles)}')
print()
print()

print("Validation set")
for age, group in validation_interactions_per_age_group:
    print(f'Age: {age}')
    print(f'Number of user profiles: {len(group["user_id"].unique())}')
    print(f'Average items in user profile: {group.groupby("user_id").size().mean()}')
    small_user_profiles = group.groupby("user_id").size()[group.groupby("user_id").size() < 10]
    print(f'Users with less than 10 interactions: {len(small_user_profiles)}')
print()
print()   

print("Test set")
for age, group in test_interactions_per_age_group:
    print(f'Age: {age}')
    print(f'Number of user profiles: {len(group["user_id"].unique())}')
    print(f'Average items in user profile: {group.groupby("user_id").size().mean()}')
    small_user_profiles = group.groupby("user_id").size()[group.groupby("user_id").size() < 10]
    print(f'Users with less than 10 listening events: {len(small_user_profiles)}')
print()
print()

Train set
Age: 12.0
Number of user profiles: 134
Average items in user profile: 291.3134328358209
Users with less than 10 interactions: 2
Age: 13.0
Number of user profiles: 456
Average items in user profile: 301.140350877193
Users with less than 10 interactions: 8
Age: 14.0
Number of user profiles: 995
Average items in user profile: 341.3819095477387
Users with less than 10 interactions: 12
Age: 15.0
Number of user profiles: 1807
Average items in user profile: 377.9894853348091
Users with less than 10 interactions: 21
Age: 16.0
Number of user profiles: 2528
Average items in user profile: 376.38805379746833
Users with less than 10 interactions: 22
Age: 17.0
Number of user profiles: 3126
Average items in user profile: 384.5403071017274
Users with less than 10 interactions: 28
Age: 18.0
Number of user profiles: 3382
Average items in user profile: 373.5198107628622
Users with less than 10 interactions: 39
Age: 19.0
Number of user profiles: 3253
Average items in user profile: 372.5539501998

In [18]:
sparsity = 1 - (len_interactions / (len(users) * len(users)))
print(sparsity)

0.9838616499542093
