In [2]:
import pandas as pd
import numpy as np
import json
import os
from collections import defaultdict 
def preprocess_data(
    ratings_path="ratings.dat",
    movies_path="movies.dat",
    save_path="../../save/",
    top_k=10,
):
    """Build genre-level bandit arms and their reward pools from rating data.

    Reads ``::``-separated ratings and movies files, splits each movie's
    ``|``-separated genre list so that every (rating, genre) pair becomes one
    row, keeps the ``top_k`` genres with the most such rows, and writes two
    JSON files into ``save_path``:

    * ``arms.json``    -- sorted list of genre names (the bandit arms);
    * ``rewards.json`` -- mapping of genre name -> list of ratings divided by 5.

    Args:
        ratings_path: Ratings file with columns UserID::MovieID::Rating::Timestamp.
        movies_path: Movies file with columns MovieID::Title::Genres,
            decoded as ISO-8859-1.
        save_path: Output directory; created if it does not exist.
        top_k: Number of most-rated genres to keep as arms.

    Returns:
        Tuple ``(arms, sample_pool)``: the sorted genre list and a dict mapping
        each genre to a float numpy array of its normalized rewards.
    """
    # Load datasets
    ratings = pd.read_csv(
        ratings_path,
        sep="::",
        engine="python",
        names=["UserID", "MovieID", "Rating", "Timestamp"],
    )
    movies = pd.read_csv(
        movies_path,
        sep="::",
        engine="python",
        encoding="ISO-8859-1",
        names=["MovieID", "Title", "Genres"],
    )

    # One row per (rating, genre) pair so each genre sees every rating of its movies.
    data = pd.merge(ratings, movies, on="MovieID")
    data["Genres"] = data["Genres"].str.split("|")
    data = data.explode("Genres")

    # Keep only the genres with the most ratings to limit the number of arms.
    popular_genres = data["Genres"].value_counts().head(top_k).index.tolist()
    data = data[data["Genres"].isin(popular_genres)]

    arms = sorted(data["Genres"].unique())
    print(f"Using {len(arms)} arms (genres): {arms}")

    # Vectorized replacement for a per-row iterrows() loop; groupby preserves
    # the original row order within each genre. Dividing by 5 maps the maximum
    # rating to 1.0 (assumes a 5-point rating scale).
    ratings_by_genre = data.groupby("Genres")["Rating"]
    sample_pool = {
        genre: ratings_by_genre.get_group(genre).to_numpy() / 5
        for genre in arms
    }

    # Save arms and rewards to JSON files
    os.makedirs(save_path, exist_ok=True)

    with open(os.path.join(save_path, "arms.json"), "w", encoding="utf-8") as f:
        json.dump(arms, f)

    with open(os.path.join(save_path, "rewards.json"), "w", encoding="utf-8") as f:
        json.dump({genre: sample_pool[genre].tolist() for genre in arms}, f)

    print(f"Arms and rewards saved to {save_path}")
    return arms, sample_pool

# Entry point: run the preprocessing step (reads the .dat files and writes the
# JSON outputs) only when this code is executed directly, not when imported.
if __name__ == "__main__":
    preprocess_data()

Using 10 arms (genres): ['Action', 'Adventure', "Children's", 'Comedy', 'Crime', 'Drama', 'Horror', 'Romance', 'Sci-Fi', 'Thriller']
Arms and rewards saved to ../../save
