# Analyze train/validation/test set
Run the cells to analyze the age distributions and number of interactions generated by the split sets.


In [1]:
import pandas as pd
import os
from dotenv import load_dotenv
from pathlib import Path
env_path = Path('../..') / 'config.env'
load_dotenv(dotenv_path=env_path)
dataset_dir = os.getenv("dataset_directory")

In [2]:
dataset = 'ml' # ml, mlhd, or bx
filtered = True 

In [3]:
if dataset == 'ml':
    data_dir = dataset_dir + f'/processed/ml_rec{"_filtered" if filtered else ""}'
elif dataset == 'mlhd':
    data_dir = dataset_dir + f'/processed/mlhd_rec{"_filtered" if filtered else ""}'
elif dataset == 'bx':
    data_dir = dataset_dir + f'/processed/bx_rec{"_filtered" if filtered else ""}'

In [4]:
train_path = data_dir + f'/train.tsv'
validation_path = data_dir + f'/validation.tsv'
test_path = data_dir + f'/test.tsv'
user_info_path = data_dir + f'/user_info.tsv'

column_names = ['user_id', 'item_id', 'count'] if dataset == 'bx' else ['user_id', 'item_id', 'count', 'timestamp']

train = pd.read_csv(train_path, sep='\t', header=None, names=column_names)
validation = pd.read_csv(validation_path, sep='\t', header=None, names=column_names)
test = pd.read_csv(test_path, sep='\t', header=None, names=column_names)

users = pd.read_csv(user_info_path, sep='\t')


train = pd.merge(train, users, on='user_id', how='inner')
validation = pd.merge(validation, users, on='user_id', how='inner')
test = pd.merge(test, users, on='user_id', how='inner')

In [5]:
len_interactions = len(train) + len(validation) + len(test)

print(f'Train set: {len(train)} rows; {len(train) / len_interactions * 100:.2f}% of total interactions')
print(f'Validation set: {len(validation)} rows; {len(validation) / len_interactions * 100:.2f}% of total interactions')
print(f'Test set: {len(test)} rows; {len(test) / len_interactions * 100:.2f}% of total interactions')

print(f'Train set: {len(train["user_id"].unique())} unique users')
print(f'Validation set: {len(validation["user_id"].unique())} unique users')
print(f'Test set: {len(test["user_id"].unique())} unique users')

train_interactions_per_user = train.groupby('user_id').size()
print(f'Train set: {train_interactions_per_user.mean()} average interactions per user')

validation_interactions_per_user = validation.groupby('user_id').size()
print(f'Validation set: {validation_interactions_per_user.mean()} average interactions per user')

test_interactions_per_user = test.groupby('user_id').size()
print(f'Test set: {test_interactions_per_user.mean()} average interactions per user')

Train set: 340573 rows; 59.59% of total interactions
Validation set: 111917 rows; 19.58% of total interactions
Test set: 119041 rows; 20.83% of total interactions
Train set: 5949 unique users
Validation set: 5949 unique users
Test set: 5949 unique users
Train set: 57.24878130778282 average interactions per user
Validation set: 18.812741637249957 average interactions per user
Test set: 20.01025382417213 average interactions per user


In [6]:
num_items = len(train['item_id'].unique())
print(f'Number of unique items in train set: {num_items}')

Number of unique items in train set: 2810


In [7]:
age_identifier = 'age_at_interaction' if dataset == 'mlhd' else 'age'
train_interactions_per_age_group = train.groupby(age_identifier)

validation_interactions_per_age_group = validation.groupby(age_identifier)

test_interactions_per_age_group = test.groupby(age_identifier)

In [8]:
empty_train_profiles = users[users['user_id'].isin(train['user_id']) == False]
empty_validation_profiles = users[users['user_id'].isin(validation['user_id']) == False]
empty_test_profiles = users[users['user_id'].isin(test['user_id']) == False]

print(f'Train set: {len(empty_train_profiles)} users with empty profiles')
print(f'Validation set: {len(empty_validation_profiles)} users with empty profiles')
print(f'Test set: {len(empty_test_profiles)} users with empty profiles')

small_user_profiles = train_interactions_per_user[train_interactions_per_user < 10]
print(f'Train set: {len(small_user_profiles)} users with less than 10 interactions')
small_user_profiles = validation_interactions_per_user[validation_interactions_per_user < 10]
print(f'Validation set: {len(small_user_profiles)} users with less than 10 interactions')
small_user_profiles = test_interactions_per_user[test_interactions_per_user < 10]
print(f'Test set: {len(small_user_profiles)} users with less than 10 interactions')

Train set: 2 users with empty profiles
Validation set: 2 users with empty profiles
Test set: 2 users with empty profiles
Train set: 474 users with less than 10 interactions
Validation set: 2578 users with less than 10 interactions
Test set: 2317 users with less than 10 interactions


In [9]:
print("Train set")
for age, group in train_interactions_per_age_group:
    print(f'Age: {age}')
    print(f'Number of user profiles: {len(group["user_id"].unique())}')
    print(f'Average items in user profile: {group.groupby("user_id").size().mean()}')
    small_user_profiles = group.groupby("user_id").size()[group.groupby("user_id").size() < 10]
    print(f'Users with less than 10 interactions: {len(small_user_profiles)}')
print()
print()

print("Validation set")
for age, group in validation_interactions_per_age_group:
    print(f'Age: {age}')
    print(f'Number of user profiles: {len(group["user_id"].unique())}')
    print(f'Average items in user profile: {group.groupby("user_id").size().mean()}')
    small_user_profiles = group.groupby("user_id").size()[group.groupby("user_id").size() < 10]
    print(f'Users with less than 10 interactions: {len(small_user_profiles)}')
print()
print()   

print("Test set")
for age, group in test_interactions_per_age_group:
    print(f'Age: {age}')
    print(f'Number of user profiles: {len(group["user_id"].unique())}')
    print(f'Average items in user profile: {group.groupby("user_id").size().mean()}')
    small_user_profiles = group.groupby("user_id").size()[group.groupby("user_id").size() < 10]
    print(f'Users with less than 10 listening events: {len(small_user_profiles)}')
print()
print()

Train set
Age: 1
Number of user profiles: 218
Average items in user profile: 42.11467889908257
Users with less than 10 interactions: 21
Age: 18
Number of user profiles: 1076
Average items in user profile: 55.43494423791822
Users with less than 10 interactions: 88
Age: 25
Number of user profiles: 2076
Average items in user profile: 63.552023121387286
Users with less than 10 interactions: 160
Age: 35
Number of user profiles: 1175
Average items in user profile: 58.84
Users with less than 10 interactions: 88
Age: 45
Number of user profiles: 540
Average items in user profile: 54.17777777777778
Users with less than 10 interactions: 33
Age: 50
Number of user profiles: 488
Average items in user profile: 54.592213114754095
Users with less than 10 interactions: 38
Age: 56
Number of user profiles: 376
Average items in user profile: 39.297872340425535
Users with less than 10 interactions: 46


Validation set
Age: 1
Number of user profiles: 218
Average items in user profile: 13.761467889908257
User

In [10]:
sparsity = 1 - (len_interactions / (len(users) * len(users)))
print(sparsity)

0.9838616217171088
