In [None]:
import glob
from itertools import combinations

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
from tqdm.auto import tqdm
from sklearn.model_selection import GroupKFold


In [None]:
# Read the three competition CSVs in a single unpacking step
# (train rows, test rows, submission template).
train_df, test_df, sample_submission = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../input/shopee-product-matching/train.csv",
        "../input/shopee-product-matching/test.csv",
        "../input/shopee-product-matching/sample_submission.csv",
    )
)

# About this notebook
In this notebook, I share a way to split the training data into cross-validation folds (train/validation splits).  
The idea is that:
* rows with the same `label_group` should be in the same fold
* rows with the same image should be in the same fold

While doing EDA, I noticed that there are
* images that share the same `image_phash` but have different `image` paths
* images that share the same `image` but have different `label_group`s

So splitting only by `label_group` could lead to data leakage, since the same image could end up in different folds.

For example, `image_phash = d0c0ea37bd9acce0` has 20 rows whose `label_group` is either `4198148727` or `2403374241`, although they are all the same image. 


In [None]:
# All rows sharing the example perceptual hash (displayed as the cell output).
train_df[train_df["image_phash"].eq("d0c0ea37bd9acce0")]

In [None]:
# Display one image of the example phash to show that its rows are the same picture.
sample_phash = "d0c0ea37bd9acce0"
sample_image_path = train_df.loc[train_df["image_phash"] == sample_phash, "image"].values[0]
sample_image_bgr = cv2.imread("../input/shopee-product-matching/train_images/" + sample_image_path)
# cv2.imread signals a missing/unreadable file by returning None rather than raising,
# which would otherwise surface later as an opaque TypeError on slicing.
if sample_image_bgr is None:
    raise FileNotFoundError(f"Could not read image file: {sample_image_path}")
# OpenCV loads channels in BGR order; matplotlib expects RGB.
sample_image = cv2.cvtColor(sample_image_bgr, cv2.COLOR_BGR2RGB)
plt.imshow(sample_image)
print(f"The image of image_phash == {sample_phash}")

# How to split

My approach is to
1. create a dictionary mapping each `image_phash` to the list of its `label_group`s,
2. find `image_phash`s that share a `label_group`,
3. assign `image_phash`s that share a `label_group` to the same group, which is then passed to `GroupKFold`.


In [None]:
# Number of GroupKFold splits; also drives the per-fold checks at the end.
n_folds = 5

In [None]:
# Rows grouped by perceptual image hash.
phash_group = train_df.groupby(by="image_phash")

In [None]:
# image_phash -> list of the label_group values of every row carrying that phash.
phash_dict = {
    phash: phash_rows["label_group"].tolist()
    for phash, phash_rows in phash_group
}

In [None]:
# Build image_phash -> group id such that phashes linked by a shared label_group,
# directly or through a chain of other phashes, always land in the same group.
# A union-find over phashes computes this transitive closure in near-linear time;
# a pairwise scan that overwrites one phash's group id can leave phashes that were
# already visited behind in a stale group, splitting a label_group across groups.


def _find_root(parent, node):
    """Return the representative of ``node``'s set, compressing the walked path."""
    root = node
    while parent[root] != root:
        root = parent[root]
    # Re-point every node on the path directly at the root for faster later lookups.
    while parent[node] != root:
        next_node = parent[node]
        parent[node] = root
        node = next_node
    return root


def build_phash_groups(phash_to_labels):
    """Assign a group id to every image_phash using connected components.

    Two phashes are linked when they have at least one label_group in common;
    links are transitive, so A~B and B~C puts A, B and C in one group.

    Args:
        phash_to_labels: dict mapping each image_phash to an iterable of the
            label_group values of its rows.

    Returns:
        dict mapping each image_phash to a dense integer group id
        (0, 1, 2, ... in order of first appearance in ``phash_to_labels``).
    """
    parent = {phash: phash for phash in phash_to_labels}
    first_phash_with_label = {}
    for phash, labels in phash_to_labels.items():
        for label in labels:
            other = first_phash_with_label.setdefault(label, phash)
            if other == phash:
                continue
            root_a = _find_root(parent, phash)
            root_b = _find_root(parent, other)
            if root_a != root_b:
                parent[root_b] = root_a
    group_of_root = {}
    # len() is evaluated before insertion, so each new root gets the next dense id.
    return {
        phash: group_of_root.setdefault(_find_root(parent, phash), len(group_of_root))
        for phash in phash_to_labels
    }


phash_group_dict = build_phash_groups(phash_dict)

# Sanity check: every label_group must now belong to exactly one group.
group_of_label = {}
for phash, labels in phash_dict.items():
    for label in labels:
        expected_group = group_of_label.setdefault(label, phash_group_dict[phash])
        assert expected_group == phash_group_dict[phash], f"label_group {label} spans several groups"

In [None]:
# phash_group_dict

In [None]:
# Vectorized lookup of each row's group id; building one boolean mask per phash
# would rescan the whole frame once for every phash.
# Rows whose image_phash has no entry in phash_group_dict (e.g. a missing phash,
# which groupby drops by default) keep the placeholder -1.
train_df["group"] = train_df["image_phash"].map(phash_group_dict).fillna(-1).astype(int)
train_df.head()

In [None]:


# Give each row the index of the fold in which it is used for validation.
# GroupKFold.split yields *positional* indices, so collect them in a positional
# array instead of passing them to label-based .loc (which only coincides with
# positions when train_df has a default RangeIndex).
fold_of_row = np.full(len(train_df), -1, dtype=int)
gkf = GroupKFold(n_splits=n_folds)
for fold, (_, val_idx) in enumerate(gkf.split(train_df, groups=train_df["group"])):
    fold_of_row[val_idx] = fold
train_df["fold"] = fold_of_row
assert (train_df["fold"] >= 0).all(), "every row should be in exactly one validation fold"
train_df.head()

# Save the CSV

In [None]:
# Persist the frame, including the group and fold columns, without the pandas index.
train_df.to_csv(path_or_buf="train_folds.csv", index=False)

# Checking the number of rows in each fold

In [None]:
# Report how many rows were assigned to each validation fold.
for fold_id in range(n_folds):
    n_rows_in_fold = int((train_df["fold"] == fold_id).sum())
    print(f"fold{fold_id} has {n_rows_in_fold} data")
    

# Checking for `label_group` duplicates across folds

In [None]:
# For every pair of folds, count label_group values present in both.
# A leak-free split must report 0 for every pair; the assert enforces it.
label_group_sets = [
    set(train_df.loc[train_df["fold"] == fold, "label_group"].tolist())
    for fold in range(n_folds)
]

label_group_overlaps = {}
for fold, fold_ in combinations(range(n_folds), 2):
    n_shared = len(label_group_sets[fold] & label_group_sets[fold_])
    label_group_overlaps[fold, fold_] = n_shared
    print(f"The number of duplicates of label_group in fold {fold} and fold {fold_} are {n_shared}")

assert not any(label_group_overlaps.values()), f"label_group shared across folds: {label_group_overlaps}"

No duplicates!

# Checking for `image_phash` duplicates across folds

In [None]:
# For every pair of folds, count image_phash values present in both.
# A leak-free split must report 0 for every pair; the assert enforces it.
phash_sets = [
    set(train_df.loc[train_df["fold"] == fold, "image_phash"].tolist())
    for fold in range(n_folds)
]

phash_overlaps = {}
for fold, fold_ in combinations(range(n_folds), 2):
    n_shared = len(phash_sets[fold] & phash_sets[fold_])
    phash_overlaps[fold, fold_] = n_shared
    print(f"The number of duplicates of image_phash in fold {fold} and fold {fold_} are {n_shared}")

assert not any(phash_overlaps.values()), f"image_phash shared across folds: {phash_overlaps}"

No duplicates!