In [None]:
# Switch the kernel's working directory to /workspace before anything else runs
# (presumably so the `src` package imported below resolves — confirm against the project layout).
%cd /workspace

In [None]:
import pandas as pd
from pathlib import Path
from src.logger import Logger

from torch.nn.utils.rnn import pad_sequence
import torch

In [None]:
class Config:
    """Notebook-wide settings, read through the module-level `config` instance."""

    # When True, task1 data is restricted to a random subset of users after loading.
    debug: bool = True
    # Passed as `random_state` when drawing the debug subset.
    seed: int = 8823

In [None]:
# Directory layout: <HOME>/resources/input holds the raw datasets.
# NOTE(review): HOME is an absolute path tied to the container layout.
HOME = Path("/workspace")
RESOURCES = HOME.joinpath("resources")
INPUT = RESOURCES.joinpath("input")

# Shared settings object and logger used by the cells below.
config = Config()
logger = Logger(__name__)

In [None]:
def read_parquet_from_csv(filepath:Path, dirpath:Path) -> pd.DataFrame:
    """Load a CSV as a DataFrame, using a parquet copy in `dirpath` as a cache.

    The cache file is named after the part of `filepath`'s file name that
    precedes the first ".", with a ".parquet" suffix.

    Args:
        filepath: Path of the CSV file to read when no cache exists.
        dirpath: Directory where the parquet cache is looked up and written.

    Returns:
        The DataFrame read from the parquet cache if it exists, otherwise
        the DataFrame parsed from the CSV (which is also written to the cache).
    """
    stem, _, _ = filepath.name.partition(".")
    cache_path = dirpath / (stem + ".parquet")

    # Cache miss: parse the CSV once and persist it for later runs.
    if not cache_path.is_file():
        logger.info("load csv & convert to parquet")
        frame = pd.read_csv(filepath)
        frame.to_parquet(cache_path)
        return frame

    logger.info("load parquet file")
    return pd.read_parquet(cache_path)
    

In [None]:
# Load the three input tables (each is cached as parquet next to its CSV on first read).
task1_df = read_parquet_from_csv(filepath=INPUT / "task1_dataset.csv.gz", dirpath=INPUT)
task2_df = read_parquet_from_csv(filepath=INPUT / "task2_dataset.csv.gz", dirpath=INPUT)
poi_df = read_parquet_from_csv(filepath=INPUT / "cell_POIcat.csv.gz", dirpath=INPUT)

if config.debug:
    # Debug mode: keep every row of a small random subset of users.
    # Sample from the *distinct* uids: sampling the raw column draws rows, so the
    # same user can be picked several times and heavy users are over-represented.
    n_debug_users = 100
    unique_uids = task1_df["uid"].drop_duplicates()
    # Cap the sample size so a table with fewer users does not raise.
    user_ids = unique_uids.sample(min(n_debug_users, len(unique_uids)), random_state=config.seed)
    task1_df = task1_df[task1_df["uid"].isin(user_ids)]

In [None]:

def make_sequences(df:pd.DataFrame, group_key:str, group_values:list[str]) -> dict[str, list[torch.Tensor]]:
    """Split selected columns of `df` into one tensor per group.

    Args:
        df: Source table.
        group_key: Column whose values define the groups.
        group_values: Columns to extract; each becomes a key of the result.

    Returns:
        A dict mapping each name in `group_values` to a list of tensors, one
        per group in groupby iteration order. Position i of every list refers
        to the same group. Every list is empty when `df` has no rows.

    Raises:
        KeyError: If `group_key` or a requested column is missing from `df`.
    """
    sequences: dict[str, list[torch.Tensor]] = {group_value: [] for group_value in group_values}

    # Walk the groups once and fill every column's list in the same pass,
    # instead of re-iterating the groupby once per requested column.
    for _, group in df.groupby(group_key):
        for group_value, tensors in sequences.items():
            tensors.append(torch.tensor(group[group_value].to_numpy()))

    return sequences
    
# Build, for each listed column, one tensor per distinct "uid" in task1_df.
sequences = make_sequences(
    df=task1_df,
    group_key="uid",
    group_values=["uid", "d", "t"],
)


In [None]:
class TrainDataset(torch.utils.data.Dataset):
    """Map-style dataset over per-group sequences.

    Args:
        sequences: Mapping with the keys "uid", "d" and "t"; each value is an
            indexable collection of sequences (one entry per sample).
    """

    def __init__(self, sequences):
        self.uid = sequences["uid"]
        # "d" and "t" are stored but not yet returned by __getitem__.
        self.d = sequences["d"]
        self.t = sequences["t"]

    def __len__(self):
        """Return the number of samples, taken from the "uid" collection."""
        return len(self.uid)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return the sample at `index` as a dict with the single key "features".

        NOTE(review): only the "uid" sequence is exposed as features; this
        looks like a placeholder — confirm whether "d"/"t" should be included.
        """
        features = self.uid[index]
        return {"features":features}


def collate_fn(batch):
    """Collate variable-length samples into one zero-padded batch.

    Args:
        batch: List of sample dicts, each holding a sequence under "features".

    Returns:
        A dict with "features" (the sequences stacked batch-first and padded
        to the longest one in the batch) and "lengths" (a list of the
        original, unpadded lengths in batch order).
    """
    tensors = []
    lengths = []
    for item in batch:
        seq = item["features"]
        lengths.append(len(seq))
        tensors.append(torch.as_tensor(seq))

    padded = pad_sequence(tensors, batch_first=True)
    return {"features": padded, "lengths": lengths}

# Wrap the sequences in a dataset and batch them with padding.
ds = TrainDataset(sequences=sequences)
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=5, collate_fn=collate_fn)

# Sanity check: print each batch's true lengths next to the padded row lengths.
for x in dl:
    padded_lengths = [len(row) for row in x["features"]]
    print(x["lengths"], padded_lengths)