In [None]:
"""
from CoLLM paper

ML-1M
preserve the interactions from the most recent twenty months, using the first 10 months for training, the middle 5 months for validation, and the last 5 months for testing
Train: 33,891
Valid: 10,401
Test: 7,331
User: 839
Item: 3,256


Amazon-Book dataset
preserve interactions from the year 2017 (including about 4 million interactions)
allocating the first 11 months for training, and splitting the remaining month into two halves for validation and testing, respectively

filtered out users and items with fewer than 20 interactions to ensure data quality for measuring warm-start performance

Train: 727,468
Valid: 25,747
Test: 25,747
User: 22,967
Item: 34,154
"""

In [1]:
import os
from pathlib import Path
from urllib.parse import urljoin
import requests
import pprint
import gzip
import json

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
# from torch_geometric.data import Dataset, download_url

from datasets import load_dataset

from src.data.utils import loadFileFromURL
from src.utils.wrapper import tryExcept, timeMeasured


class AmazonDataset(Dataset):
    """
    Dataset class for the Amazon Review Dataset from 2023.
    Overview: https://amazon-reviews-2023.github.io/main.html

    On construction the class reads its settings from a JSON config file,
    downloads the interaction splits and the item meta/review dumps into
    ``<root>/raw`` (files that already exist are not downloaded again) and
    unwraps the item metadata into one JSON file per item.

    The config file is expected to map a dataset name to an entry of the form::

        {
            "category": "<category used in the file names>",
            "urls": {
                "interactionDataUrl": "...",
                "metaDataUrl": "...",
                "reviewDataUrl": "..."
            },
            "required_fields": {"<field>": true | false, ...}
        }

    Note: this class only prepares raw files; it does not define
    ``__len__`` / ``__getitem__``.

    Args:
        root: Directory under which the ``raw`` folder is created.
        datasetConfig: Path to the JSON config file.
        datasetName: Key of the entry to use inside the config file.

    Raises:
        ValueError: If the selected config entry has no ``category``
            (this includes the case that ``datasetName`` is not in the file).
    """

    def __init__(self, root, datasetConfig, datasetName):
        self.root = root
        # Keep the path separately; `self.datasetConfig` holds the parsed entry.
        self.datasetConfigPath = datasetConfig
        self.datasetName = datasetName

        with open(datasetConfig, "r") as configFile:
            configData = json.load(configFile)
        self.datasetConfig = configData.get(datasetName, {})
        print("Dataset Config set as:")
        pprint.pp(self.datasetConfig)

        self.category = self.datasetConfig.get("category")
        if not self.category:
            # Without a category every file name would start with "None",
            # so fail early with a clear message instead of downloading junk.
            raise ValueError(
                f"No 'category' configured for dataset {datasetName!r} in {datasetConfig!r}"
            )

        urls = self.datasetConfig.get("urls", {})
        self.interactionDataUrl = urls.get("interactionDataUrl", "default_interaction_data_url")
        self.metaDataUrl = urls.get("metaDataUrl", "default_meta_data_url")
        self.reviewDataUrl = urls.get("reviewDataUrl", "default_review_data_url")

        self.rawDataDir = Path(self.root) / "raw"
        self.rawDataDir.mkdir(parents=True, exist_ok=True)

        self.downloadInteractionData()
        self.downloadItemDataAsJSON()
        self.unwrapItemData(self.rawMetaDatasetPath)

    @staticmethod
    def _joinUrl(baseUrl, filename):
        """
        Append ``filename`` to ``baseUrl``, treating ``baseUrl`` as a directory.

        ``urljoin`` replaces the last path segment of a base URL that does not
        end in "/", which would silently drop the final directory of the
        configured URL. Adding the slash first avoids that.
        """
        if not baseUrl.endswith("/"):
            baseUrl += "/"
        return urljoin(baseUrl, filename)

    @tryExcept
    @timeMeasured
    def downloadInteractionData(self):
        """Download the train/valid/test interaction CSVs that are not on disk yet."""
        interactionsDir = self.rawDataDir / "Interactions"
        interactionsDir.mkdir(parents=True, exist_ok=True)
        for split in ["train", "valid", "test"]:
            datasetFilename = f"{self.category}.{split}.csv.gz"
            datasetPath = interactionsDir / datasetFilename
            if not datasetPath.exists():
                datasetUrl = self._joinUrl(self.interactionDataUrl, datasetFilename)
                loadFileFromURL(datasetUrl, datasetPath)

    @tryExcept
    @timeMeasured
    def downloadItemDataAsJSON(self):
        """
        Download the item metadata and review JSONL dumps if they are missing.

        Sets ``self.rawMetaDatasetPath`` and ``self.rawReviewDatasetPath``.
        """
        itemsDir = self.rawDataDir / "Items"
        metaDataFilename = f"meta_{self.category}.jsonl.gz"
        reviewDataFilename = f"{self.category}.jsonl.gz"

        # Assign both paths before any download is attempted so the attributes
        # exist even when one of the downloads fails.
        self.rawMetaDatasetPath = itemsDir / metaDataFilename
        self.rawReviewDatasetPath = itemsDir / reviewDataFilename
        itemsDir.mkdir(parents=True, exist_ok=True)

        if not self.rawMetaDatasetPath.exists():
            loadFileFromURL(self._joinUrl(self.metaDataUrl, metaDataFilename), self.rawMetaDatasetPath)

        if not self.rawReviewDatasetPath.exists():
            loadFileFromURL(self._joinUrl(self.reviewDataUrl, reviewDataFilename), self.rawReviewDatasetPath)

    @tryExcept
    @timeMeasured
    def downloadItemDataFromHF(self):
        """
        Fetch the item metadata through Hugging Face ``load_dataset`` and store
        it as a gzip-compressed CSV, unless that CSV already exists.
        """
        datasetFilename = f"{self.category}ItemMetadata.csv.gz"
        datasetPath = self.rawDataDir / "Items" / datasetFilename
        datasetPath.parent.mkdir(parents=True, exist_ok=True)
        if not datasetPath.exists():
            itemInformation = load_dataset(
                "McAuley-Lab/Amazon-Reviews-2023",
                f"raw_meta_{self.category}",
                split="full",
                trust_remote_code=True
            )
            dataframe = pd.DataFrame.from_records(itemInformation)
            dataframe.to_csv(datasetPath, index=False, compression="gzip")
            print(f"File downloaded and saved to {datasetPath}")

    def checkRequiredFields(self, jsonLine):
        """
        Checks whether the required fields are present based on the config.
        Returns True if all required fields are valid; False otherwise.

        Only fields whose flag in the config's ``required_fields`` mapping is
        truthy are checked.
        """
        requiredFields = self.datasetConfig.get("required_fields", {})
        for field, isRequired in requiredFields.items():
            if not isRequired:
                continue
            value = jsonLine.get(field)
            # For description and images, we ensure they are non-empty lists
            if field in ("description", "images"):
                if not isinstance(value, list) or not value:
                    return False
            # For other fields, just check if they are truthy (non-null, non-empty)
            elif not value:
                return False
        return True

    @tryExcept
    @timeMeasured
    def unwrapItemData(self, datasetPath, maxLines=None):
        """
        Split a gzip-compressed JSONL item dump into one JSON file per item.

        Every record that passes ``checkRequiredFields`` is written to
        ``<rawDataDir>/Items/Unwrapped/<datasetName>/<parent_asin>.json`` and
        its ``parent_asin`` is appended to ``self.dumpedJSONlist``.

        Args:
            datasetPath: Path to the ``.jsonl.gz`` file to unwrap.
            maxLines: Optional cap on the number of records to read (handy for
                quick experiments). ``None`` reads the whole file.
        """
        rawUnwrappedDataDir = self.rawDataDir / "Items" / "Unwrapped" / self.datasetName
        rawUnwrappedDataDir.mkdir(parents=True, exist_ok=True)

        # Initialise before opening the file so the attribute exists even if
        # the archive cannot be read.
        self.dumpedJSONlist = []
        linesCount, missingIdCount = 0, 0

        with gzip.open(datasetPath, "rt", encoding="utf-8") as f:
            for line in f:
                if maxLines is not None and linesCount >= maxLines:
                    break
                line = line.strip()
                if not line:
                    # Blank lines are not valid JSON; skip them instead of crashing.
                    continue
                linesCount += 1
                jsonLine = json.loads(line)
                if not self.checkRequiredFields(jsonLine):  # Only save if required fields are valid
                    continue
                parent_asin = jsonLine.get("parent_asin")
                if not parent_asin:
                    # Without an id the record would be written to "None.json"
                    # and overwrite every other id-less record.
                    missingIdCount += 1
                    continue
                outputFilePath = rawUnwrappedDataDir / f"{parent_asin}.json"
                with open(outputFilePath, "w", encoding="utf-8") as outputFile:
                    json.dump(jsonLine, outputFile, indent=4)
                self.dumpedJSONlist.append(parent_asin)

        print(f"Unwrapped dataset from {datasetPath}")
        print(f"Saved {len(self.dumpedJSONlist)} from a total of {linesCount} lines.")
        if missingIdCount:
            print(f"Skipped {missingIdCount} records without a parent_asin.")


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Where the raw downloads go, which config file to read, and which entry of it to use.
root = "data/AmazonReviews"
datasetConfig = "src/data/datasetConfigAmazon.json"
datasetName = "AmazonAllBeautyDataset"

Path(root).mkdir(parents=True, exist_ok=True)

# NOTE: despite the variable name, the dataset built here is the one selected
# by `datasetName` above ("AmazonAllBeautyDataset").
AmazonBooksDataset = AmazonDataset(
    root=root,
    datasetConfig=datasetConfig,
    datasetName=datasetName,
)

In [None]:
from src.data.utils import timestampToYear

# Load the training split of the interaction data.
interactionDataTrainingPath = "data/AmazonReviews/raw/Interactions/All_Beauty.train.csv.gz"
interactionDataTraining = pd.read_csv(interactionDataTrainingPath)

# Possibility to filter year
# interactionDataTraining["year"] = interactionDataTraining["timestamp"].apply(timestampToYear)

# Filter out users with fewer than k interactions, i.e. keep users with at
# least k rows. (The previous `> k` comparison also dropped users with exactly
# k interactions, contradicting the stated intent.)
k = 10
frequent_users = interactionDataTraining["user_id"].value_counts()
# Despite the name this is a list of user ids, not a boolean mask.
frequent_users_mask = frequent_users[frequent_users >= k].index.tolist()
# .copy() detaches the result from the unfiltered frame so that columns can be
# added to it later without pandas emitting SettingWithCopyWarning.
interactionDataTraining = interactionDataTraining[
    interactionDataTraining["user_id"].isin(frequent_users_mask)
].copy()

In [3]:
# NOTE(review): `interactionDataTraining_filtered` is not assigned in any cell
# visible here — presumably left over from an earlier kernel session. Confirm
# where it is defined; otherwise this cell raises NameError after
# Restart Kernel -> Run All.
interactionDataTraining_filtered

Unnamed: 0,user_id,parent_asin,rating,timestamp,history,interaction_count
715,AE3PLZHW6NXWBMZ76TDVFQG2MJFA,B08P2DZB4X,4.0,1623685953622,B07Z4CVTRP B08BB3P4VQ B08DXZ5VXB B086N136NH B0...,12
2035,AE5ESL52LWWBJTSFOAXSFZA3XCGQ,B08S1LWF9V,5.0,1625767684071,B08BZ1RHPS B07D5FBFQ4 B0851QJPZY B08G5YVHQP B0...,18
1988,AE5IMGWRBJA7JQFBQTBK25HDYGVA,B08W4WQMNM,2.0,1621831743566,B00O2FGBJS B085J354CL B085J24382 B07M9D3WYW B0...,11
360,AEAXAJACFMXIAAH4WOHRMXPSZWFA,B08LPGZMQK,5.0,1624289267108,B083BGJ4P9 B08DK74M1P B08F4ZDVZQ B0B2L218H2 B0...,12
1283,AECADZLPUNH3BDNACLFF7PSHN5MQ,B0863FZFPV,4.0,1616639207392,B085NYYLQ8 B085MCMZLX B088PYN4VM B08BS3WDPJ B0...,11
...,...,...,...,...,...,...
2101,AHSV4TYSAX52BIHH7PLZRD44KZHA,B08W8LKLHB,5.0,1625971856890,B08GJJ5RV9 B085SY4WC3 B085TFXLH1 B08CVCLVS2 B0...,11
388,AHT6AM6BNIZUHFJB5V2M6XM72G7Q,B08MC3ZLV4,5.0,1625104151819,B07TLMZL3T B082VKPJV5 B087Z9X39L B08FD2KP9R B0...,13
19,AHV6QCNBJNSGLATP56JAWJ3C4G2A,B0BTJ6SYKB,4.0,1623188037011,B07NPWK167 B07SW7D6ZR B07WNBZQGT B082NKQ4ZT B0...,12
518,AHX2B4DEER2QR3IU3CCNB3CWC6TA,B08VJ7CZW3,5.0,1628199199569,B081ZN3TD5 B015A5DGG4 B088FBNQXW B07ZJX5MNJ B0...,18


In [4]:
# Map users and items into one contiguous integer id space:
# users are numbered from 0, items continue right after the last user id.
mappingsDir = Path("data/AmazonReviews") / "Mappings"
# open(..., 'w') does not create missing parent directories, so make sure the
# output folder exists before writing the mapping files.
mappingsDir.mkdir(parents=True, exist_ok=True)

# Work on an explicit copy: the frame may be a filtered slice of the raw data,
# and adding columns to such a slice triggers SettingWithCopyWarning.
interactionDataTraining = interactionDataTraining.copy()

# users
interactionDataTraining['user_id_mapped'] = pd.factorize(interactionDataTraining['user_id'])[0]
# One row per user, in order of first appearance, so the labels below line up
# with the codes stored in `user_id_mapped`.
user_first_occurrence_df = interactionDataTraining.drop_duplicates(subset='user_id_mapped', keep='first')
user_id_labels, user_id_mapping = pd.factorize(user_first_occurrence_df['user_id'])
# int(...) converts the numpy integers, which json.dump cannot serialise.
user_id_dict = {user: int(label) for user, label in zip(user_first_occurrence_df['user_id'].unique(), user_id_labels)}

userMappingFileName = 'user_id_mapping.json'
userMappingFilePath = mappingsDir / userMappingFileName
with open(userMappingFilePath, 'w') as jsonFile:
    json.dump(user_id_dict, jsonFile, indent=4)

# items
# Item ids start after the user ids so both share one id space without overlap.
no_nodes_offset = len(user_id_dict)
interactionDataTraining['parent_asin_mapped'] = pd.factorize(interactionDataTraining['parent_asin'])[0] + no_nodes_offset
item_first_occurrence_df = interactionDataTraining.drop_duplicates(subset='parent_asin_mapped', keep='first')
item_id_labels, item_id_mapping = pd.factorize(item_first_occurrence_df['parent_asin'])
item_id_dict = {item: int(label) + no_nodes_offset for item, label in zip(item_first_occurrence_df['parent_asin'].unique(), item_id_labels)}

itemMappingFileName = 'parent_asin_mapping.json'
itemMappingFilePath = mappingsDir / itemMappingFileName
with open(itemMappingFilePath, 'w') as jsonFile:
    json.dump(item_id_dict, jsonFile, indent=4)


In [5]:
# Inspect the training interactions, which now carry the integer id columns
# `user_id_mapped` and `parent_asin_mapped` added in the previous cell.
interactionDataTraining

Unnamed: 0,user_id,parent_asin,rating,timestamp,history,interaction_count,user_id_mapped,parent_asin_mapped
7,AHV6QCNBJNSGLATP56JAWJ3C4G2A,B07NPWK167,5.0,1578928089878,,0,0,62
8,AHV6QCNBJNSGLATP56JAWJ3C4G2A,B07SW7D6ZR,4.0,1580742144072,B07NPWK167,1,0,63
9,AHV6QCNBJNSGLATP56JAWJ3C4G2A,B07WNBZQGT,4.0,1581772960365,B07NPWK167 B07SW7D6ZR,2,0,64
10,AHV6QCNBJNSGLATP56JAWJ3C4G2A,B082NKQ4ZT,4.0,1583932042329,B07NPWK167 B07SW7D6ZR B07WNBZQGT,3,0,65
11,AHV6QCNBJNSGLATP56JAWJ3C4G2A,B083TLNBJJ,4.0,1585261263214,B07NPWK167 B07SW7D6ZR B07WNBZQGT B082NKQ4ZT,4,0,66
...,...,...,...,...,...,...,...,...
2218,AESI2BA4YODTHOSRFLJCSTAM6XDQ,B08GC5GSNG,2.0,1603681904184,B07ZQRX7FX B07YL4485K B087J3H22J B08BF4BKKM B0...,8,61,339
2219,AESI2BA4YODTHOSRFLJCSTAM6XDQ,B08DRBZNZJ,5.0,1604872578751,B07ZQRX7FX B07YL4485K B087J3H22J B08BF4BKKM B0...,9,61,268
2220,AESI2BA4YODTHOSRFLJCSTAM6XDQ,B08DK74M1P,3.0,1605406046699,B07ZQRX7FX B07YL4485K B087J3H22J B08BF4BKKM B0...,10,61,220
2221,AESI2BA4YODTHOSRFLJCSTAM6XDQ,B08CVTNQP1,5.0,1609209532186,B07ZQRX7FX B07YL4485K B087J3H22J B08BF4BKKM B0...,11,61,312
