In [17]:
import os
import random
from pathlib import Path
from tqdm import tqdm

In [19]:
# Where the preprocessed data_k*.txt interaction files are read from.
INPUT_DIR = Path("Processed_Data")

# Where the *_train.txt / *_test.txt splits are written; created if missing.
OUTPUT_DIR = Path("Test_Train_Data")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

In [21]:
# TEST_RATIO: fraction of each user's items held out for the test split.
# SEED: shuffle seed so every re-run produces the same split.
TEST_RATIO, SEED = 0.2, 42

In [23]:
def process_and_split(file_path, test_ratio=0.2, seed=42):
    """
    Split every user's items from a 'User Item1 Item2...' file into train/test sets.

    Each line is whitespace-split into a user id followed by item ids; lines with
    fewer than two tokens (blank lines, users without items) are skipped. For
    every remaining user the items are shuffled and the first
    ``int(n_items * test_ratio)`` go to test -- raised to 1 when that rounds to 0
    and the user has at least two items -- while the rest go to train.

    Args:
        file_path: Path of the input text file (str or path-like).
        test_ratio: Fraction of each user's items to hold out; must be in [0, 1].
        seed: Seed for the shuffle. The same file, ratio and seed always give
            the same split.

    Returns:
        A tuple ``(train_lines, test_lines, stats)``:
          * train_lines, test_lines: lists of newline-terminated
            "user item item ..." strings, ready for ``file.writelines``. A user
            is only written to a split in which they have at least one item.
          * stats: dict with counts 'users', 'train_interactions' and
            'test_interactions'.

    Raises:
        ValueError: if ``test_ratio`` is outside [0, 1] (or NaN).
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio!r}")

    # Use a private RNG rather than random.seed(): reseeding the module-level
    # generator would silently alter any other randomness in the kernel.
    # Random(seed).shuffle draws the same sequence as random.seed(seed) followed
    # by random.shuffle, so the resulting splits are unchanged.
    rng = random.Random(seed)

    train_lines = []
    test_lines = []

    stats = {
        'users': 0,
        'train_interactions': 0,
        'test_interactions': 0
    }

    with open(file_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue

            user_id, *items = parts
            rng.shuffle(items)

            # Determine the split point.
            n_total = len(items)
            n_test = int(n_total * test_ratio)
            # Small histories would round down to zero test items; make sure every
            # user with 2+ items still contributes one to the test set.
            if n_test < 1 and n_total > 1:
                n_test = 1

            test_items = items[:n_test]
            train_items = items[n_test:]

            if train_items:
                train_lines.append(f"{user_id} {' '.join(train_items)}\n")
                stats['train_interactions'] += len(train_items)

            if test_items:
                test_lines.append(f"{user_id} {' '.join(test_items)}\n")
                stats['test_interactions'] += len(test_items)

            stats['users'] += 1

    return train_lines, test_lines, stats

In [25]:
input_files = sorted(INPUT_DIR.glob("data_k*.txt"))
if not input_files:
    # Fail loudly: an empty glob most likely means the notebook is running from
    # a different working directory or the preprocessing step has not been run.
    raise FileNotFoundError(f"No data_k*.txt files found in {INPUT_DIR.resolve()}")

for file_path in input_files:
    filename = file_path.name
    print(f"Processing {filename}...")

    train_data, test_data, stats = process_and_split(file_path, TEST_RATIO, SEED)

    # e.g. data_k5.txt -> data_k5_train.txt / data_k5_test.txt
    base_name = file_path.stem
    train_out = OUTPUT_DIR / f"{base_name}_train.txt"
    test_out = OUTPUT_DIR / f"{base_name}_test.txt"

    with open(train_out, 'w') as f:
        f.writelines(train_data)

    with open(test_out, 'w') as f:
        f.writelines(test_data)

    total_interactions = stats['train_interactions'] + stats['test_interactions']
    # A file with no usable "user item ..." lines has zero interactions; avoid
    # ZeroDivisionError and report 0% instead.
    test_pct = (stats['test_interactions'] / total_interactions) * 100 if total_interactions else 0.0

    print(f"   Users: {stats['users']:,}")
    print(f"   Train Interactions: {stats['train_interactions']:,}")
    print(f"   Test Interactions:  {stats['test_interactions']:,} ({test_pct:.2f}%)")


Processing data_k2.txt...
   Users: 52,643
   Train Interactions: 1,924,114
   Test Interactions:  455,835 (19.15%)
Processing data_k3.txt...
   Users: 52,643
   Train Interactions: 1,922,908
   Test Interactions:  455,545 (19.15%)
Processing data_k5.txt...
   Users: 52,642
   Train Interactions: 1,918,235
   Test Interactions:  454,380 (19.15%)
