In [1]:
# Import required libraries
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm.notebook import tqdm

In [2]:
# Constants
SELECTED_RELATIONSHIP_TYPES = ['ss', 'bb', 'ms', 'fs', 'fd', 'md', 'sibs']
TRAIN_SIZE = 0.7
VAL_SIZE = 0.15
TEST_SIZE = 0.15
RANDOM_SEED = 42
BATCH_SIZE = 128

# Paths
DATA_ROOT = '/mimer/NOBACKUP/groups/naiss2023-22-1358/samir_kinship_data/processed/fiw/train'
TRIPLETS_CSV = os.path.join(DATA_ROOT, 'filtered_triplets_with_labels.csv')
OUTPUT_DIR = os.path.join(DATA_ROOT, 'splits')
MODEL_DIR = os.path.join(DATA_ROOT, 'models')

# Create necessary directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)



### 1. Load and Analyze Data

In [3]:
# Load the filtered triplets
df = pd.read_csv(TRIPLETS_CSV)
print("Total triplets before filtering:", len(df))
print("\nInitial relationship distribution:")
print(df['ptype'].value_counts())

# Filter out grandparent relationships
df = df[df['ptype'].isin(SELECTED_RELATIONSHIP_TYPES)]
print("\nTotal triplets after filtering:", len(df))
print("\nFinal relationship distribution:")
print(df['ptype'].value_counts())

print("\nSample triplet:")
print(df.iloc[0])

Total triplets before filtering: 189550

Initial relationship distribution:
ptype
ss      34230
bb      34230
ms      30874
fs      25801
fd      23708
md      23321
sibs    12131
gfgs     1570
gmgs     1368
gfgd     1277
gmgd     1040
Name: count, dtype: int64

Total triplets after filtering: 184295

Final relationship distribution:
ptype
ss      34230
bb      34230
ms      30874
fs      25801
fd      23708
md      23321
sibs    12131
Name: count, dtype: int64

Sample triplet:
Triplet_ID                                                    1
Anchor        data/processed/fiw/train/train-faces/F0001/MID...
Positive      data/processed/fiw/train/train-faces/F0001/MID...
Negative      data/processed/fiw/train/train-faces/F0361/MID...
ptype                                                        fs
Name: 0, dtype: object


## 2. Create Stratified Splits

In [4]:
def create_stratified_splits(df, train_size=0.7, val_size=0.15, random_state=42):
    """Create train/val/test splits while maintaining relationship distribution"""
    
    # First split: separate test set
    train_val_df, test_df = train_test_split(
        df,
        test_size=TEST_SIZE,
        stratify=df['ptype'],
        random_state=random_state
    )
    
    # Second split: separate train and validation sets
    relative_val_size = val_size / (train_size + val_size)
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=relative_val_size,
        stratify=train_val_df['ptype'],
        random_state=random_state
    )
    
    return train_df, val_df, test_df

# Create splits
train_df, val_df, test_df = create_stratified_splits(
    df, 
    train_size=TRAIN_SIZE, 
    val_size=VAL_SIZE,
    random_state=RANDOM_SEED
)


## 3. Save Splits and Print Statistics

In [5]:
# Save splits
train_df.to_csv(os.path.join(OUTPUT_DIR, 'train_triplets.csv'), index=False)
val_df.to_csv(os.path.join(OUTPUT_DIR, 'val_triplets.csv'), index=False)
test_df.to_csv(os.path.join(OUTPUT_DIR, 'test_triplets.csv'), index=False)

# Print statistics
print("Dataset Split Sizes:")
print(f"Train set: {len(train_df)} triplets ({len(train_df)/len(df):.1%})")
print(f"Validation set: {len(val_df)} triplets ({len(val_df)/len(df):.1%})")
print(f"Test set: {len(test_df)} triplets ({len(test_df)/len(df):.1%})")

print("\nRelationship Distribution in Each Split:")
print("\nTrain set:")
print(train_df['ptype'].value_counts())
print("\nValidation set:")
print(val_df['ptype'].value_counts())
print("\nTest set:")
print(test_df['ptype'].value_counts())

Dataset Split Sizes:
Train set: 129005 triplets (70.0%)
Validation set: 27645 triplets (15.0%)
Test set: 27645 triplets (15.0%)

Relationship Distribution in Each Split:

Train set:
ptype
bb      23960
ss      23960
ms      21612
fs      18061
fd      16596
md      16325
sibs     8491
Name: count, dtype: int64

Validation set:
ptype
bb      5135
ss      5135
ms      4631
fs      3870
fd      3556
md      3498
sibs    1820
Name: count, dtype: int64

Test set:
ptype
bb      5135
ss      5135
ms      4631
fs      3870
fd      3556
md      3498
sibs    1820
Name: count, dtype: int64


## 4. Verify Data Integrity

In [6]:
# Check for overlapping triplets between splits
train_triplets = set(train_df['Triplet_ID'])
val_triplets = set(val_df['Triplet_ID'])
test_triplets = set(test_df['Triplet_ID'])

print("Checking for overlapping triplets...")
print(f"Train-Val overlap: {len(train_triplets & val_triplets)}")
print(f"Train-Test overlap: {len(train_triplets & test_triplets)}")
print(f"Val-Test overlap: {len(val_triplets & test_triplets)}")

# Check unique images in each split
def get_unique_images(df):
    images = set()
    for col in ['Anchor', 'Positive', 'Negative']:
        images.update(df[col].unique())
    return images

train_images = get_unique_images(train_df)
val_images = get_unique_images(val_df)
test_images = get_unique_images(test_df)

print("\nUnique images in each split:")
print(f"Train: {len(train_images)}")
print(f"Validation: {len(val_images)}")
print(f"Test: {len(test_images)}")

Checking for overlapping triplets...
Train-Val overlap: 0
Train-Test overlap: 0
Val-Test overlap: 0

Unique images in each split:
Train: 14463
Validation: 12657
Test: 12655


## 5. Create DataLoader Test

In [7]:
class KinshipDataset(Dataset):
    def __init__(self, df):
        self.df = df
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        return {
            'triplet_id': row['Triplet_ID'],
            'anchor_path': row['Anchor'],
            'positive_path': row['Positive'],
            'negative_path': row['Negative'],
            'relationship': SELECTED_RELATIONSHIP_TYPES.index(row['ptype'])
        }

# Test dataloader creation
train_dataset = KinshipDataset(train_df)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Test a batch
batch = next(iter(train_loader))
print("\nSample batch contents:")
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(f"{k}: shape={v.shape}, dtype={v.dtype}")
    else:
        print(f"{k}: type={type(v)}, length={len(v)}")


Sample batch contents:
triplet_id: shape=torch.Size([128]), dtype=torch.int64
anchor_path: type=<class 'list'>, length=128
positive_path: type=<class 'list'>, length=128
negative_path: type=<class 'list'>, length=128
relationship: shape=torch.Size([128]), dtype=torch.int64
