In [2]:
import pandas as pd
import os
import shutil

In [42]:
def swire_scenes_for(split: str, data_dir: str = "../data") -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load one split's product and price boxes, keeping only all-Swire scenes.

    Product boxes are left-joined to ``swire_master.csv`` on ``upc``. A scene
    (``attributionset_id``) is excluded as soon as any of its product rows has
    a null ``brand_code`` after the join, i.e. the product was not matched in
    the master list.
    NOTE(review): this assumes every master row has ``brand_code`` populated;
    a master row with a null ``brand_code`` would also exclude its scenes.

    Args:
        split: Name of the split sub-directory (e.g. "train", "val", "test").
        data_dir: Directory containing ``swire_master.csv`` and the
            ``price-graphs-ii`` tree. Defaults to the original ``../data``.

    Returns:
        ``(products, prices)``: the joined product rows of the retained scenes
        with exact duplicate rows removed, and the price rows whose scene is
        not in the excluded set. Prices are filtered by exclusion only, so
        price rows for scenes that have no product boxes at all are kept.
    """
    master = pd.read_csv(f"{data_dir}/swire_master.csv", dtype={"upc": str})
    raw_dir = f"{data_dir}/price-graphs-ii/{split}/raw"

    # Read the label as str so it joins against the str-typed master `upc`.
    boxes = pd.read_csv(
        f"{raw_dir}/product_boxes.csv", dtype={"ml_label_name": str}
    ).rename(columns={"ml_label_name": "upc"})
    joined = boxes.merge(master, on="upc", how="left")

    # Vectorized replacement for groupby(...).filter(lambda ...): ids of scenes
    # with at least one unmatched product. dropna() mirrors groupby's default
    # of ignoring rows whose scene id is missing.
    unmatched = joined["brand_code"].isna()
    scenes_with_nonswire_product = (
        joined.loc[unmatched, "attributionset_id"].dropna().unique()
    )

    swire_scenes = joined[~joined["attributionset_id"].isin(scenes_with_nonswire_product)]

    prices = pd.read_csv(f"{raw_dir}/price_boxes.csv")
    prices = prices[~prices["attributionset_id"].isin(scenes_with_nonswire_product)]

    print(f"There are {swire_scenes.attributionset_id.unique().shape[0]} unique scenes with Swire product only in '{split}'")
    return swire_scenes.drop_duplicates(), prices

In [43]:
# Load the filtered (product rows, price rows) pair for each split.
train, train_prices = swire_scenes_for(split="train")
val, val_prices = swire_scenes_for(split="val")
test, test_prices = swire_scenes_for(split="test")

There are 715 unique scenes with Swire product only in 'train'
There are 85 unique scenes with Swire product only in 'val'
There are 76 unique scenes with Swire product only in 'test'


In [47]:
def save_swire_split(
    product_df: pd.DataFrame,
    price_df: pd.DataFrame,
    split: str,
    source: str = "../data/price-graphs-ii",
    target: str = "../data/swire",
) -> None:
    """Write one filtered split into the target dataset tree.

    Copies every file referenced by ``product_df.local_path`` from ``source``
    to the same relative location under ``target`` (files already present are
    left alone), writes both tables to ``<target>/<split>/raw`` as CSV, and
    copies the split's ``graph_components.pt`` unchanged.

    Args:
        product_df: Product rows to save; must have a ``local_path`` column
            holding paths relative to ``source``.
        price_df: Price rows to save.
        split: Name of the split sub-directory (e.g. "train").
        source: Root of the dataset being copied from.
        target: Root of the dataset being written; missing directories are
            created as needed.
    """
    raw_dir = os.path.join(target, split, "raw")
    # to_csv and shutil.copy do not create parent directories themselves.
    os.makedirs(raw_dir, exist_ok=True)

    # Many product rows share one file, so visit each distinct path once.
    for rel_path in product_df["local_path"].unique():
        destination = os.path.join(target, rel_path)
        if os.path.exists(destination):
            continue
        parent = os.path.dirname(destination)
        if parent:
            os.makedirs(parent, exist_ok=True)
        shutil.copy(os.path.join(source, rel_path), destination)

    product_df.to_csv(os.path.join(raw_dir, "product_boxes.csv"), index=False)
    price_df.to_csv(os.path.join(raw_dir, "price_boxes.csv"), index=False)
    shutil.copy(
        os.path.join(source, split, "raw", "graph_components.pt"),
        os.path.join(raw_dir, "graph_components.pt"),
    )

In [49]:
# Persist each filtered split under its own sub-directory of the target tree.
save_swire_split(product_df=train, price_df=train_prices, split="train")
save_swire_split(product_df=val, price_df=val_prices, split="val")
save_swire_split(product_df=test, price_df=test_prices, split="test")


In [40]:
# (rows, columns) of the filtered train product table returned by swire_scenes_for.
train.shape

(38589, 15)

In [3]:
# Sanity check: count the distinct scenes in the train product table saved to disk.
pd.read_csv("../data/swire/train/raw/product_boxes.csv")["attributionset_id"].unique().shape

(715,)