In [None]:
from pathlib import Path
from typing import Tuple, Union

import torch
from torch.utils.data import TensorDataset, random_split

def get_random_subset_tensordataset(
    original_dataset: TensorDataset,
    subset_size: Union[int, float] = 0.002,
    generator: Union[torch.Generator, None] = None,
) -> Tuple[TensorDataset, TensorDataset]:
    """
    Split a TensorDataset into a random subset and its complement.

    Both returned datasets are new TensorDatasets built by indexing every
    tensor in ``original_dataset.tensors`` with the same random indices, so
    rows stay aligned across tensors. Every sample of the original ends up in
    exactly one of the two outputs.

    Args:
        original_dataset: The source torch.utils.data.TensorDataset.
        subset_size: The desired size of the random subset. Either an int
                     (number of samples, 1 <= subset_size <= len(dataset))
                     or a float in (0.0, 1.0] (fraction of the dataset; the
                     resulting sample count is rounded down).
        generator: Optional torch.Generator for reproducible sampling. When
                   None, torch's global RNG is used.

    Returns:
        A tuple ``(subset, remaining)``:
        1. subset (TensorDataset): The randomly sampled subset.
        2. remaining (TensorDataset): All samples not in ``subset``.

    Raises:
        TypeError: If subset_size is not an int or a float (bools are rejected).
        ValueError: If subset_size is out of range, or if a fractional
            subset_size rounds down to zero samples.
    """
    total_size = len(original_dataset)
    subset_len = _resolve_subset_len(subset_size, total_size)

    # random_split draws a single randperm of the whole dataset internally;
    # drawing it directly keeps the indices as a tensor that can index the
    # data without going through Subset objects and Python lists.
    permutation = torch.randperm(total_size, generator=generator)
    subset_indices = permutation[:subset_len]
    remaining_indices = permutation[subset_len:]

    tensors = original_dataset.tensors
    subset = TensorDataset(*(tensor[subset_indices] for tensor in tensors))
    remaining = TensorDataset(*(tensor[remaining_indices] for tensor in tensors))
    return subset, remaining


def _resolve_subset_len(subset_size: Union[int, float], total_size: int) -> int:
    """Convert an int count or a float fraction into a validated sample count."""
    # bool is a subclass of int; without this check True/False would be
    # accepted as the counts 1/0.
    if isinstance(subset_size, bool):
        raise TypeError("subset_size must be an int or a float.")

    if isinstance(subset_size, float):
        if not (0.0 < subset_size <= 1.0):
            raise ValueError("Fractional subset_size must be between 0.0 and 1.0.")
        subset_len = int(total_size * subset_size)
        # Flooring can yield 0 for small datasets; an empty subset would
        # otherwise be returned silently despite the "> 0" contract.
        if subset_len == 0:
            raise ValueError(
                f"Fractional subset_size {subset_size} of a dataset with "
                f"{total_size} samples selects no samples."
            )
        return subset_len

    if isinstance(subset_size, int):
        if not (0 < subset_size <= total_size):
            raise ValueError(f"Integer subset_size must be > 0 and <= {total_size}.")
        return subset_size

    raise TypeError("subset_size must be an int or a float.")

# Write a small random sample of each dataset split for quick experiments.
SOURCE_DIR = Path("../data/main/7_datasets")
OUTPUT_DIR = Path("../experiments/data/241025")
# Fixed seed so re-running this cell draws the same samples.
SEED = 0

# torch.save does not create missing parent directories.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
sampling_generator = torch.Generator().manual_seed(SEED)

subsets = {}
for split in ("test", "train", "val"):
    # NOTE(review): weights_only=False unpickles arbitrary objects; presumably
    # required because these files hold whole TensorDataset objects rather
    # than plain tensors. Only load trusted files this way.
    full_dataset = torch.load(SOURCE_DIR / f"{split}.pt", weights_only=False)
    subsets[split], _ = get_random_subset_tensordataset(
        full_dataset, generator=sampling_generator
    )
    torch.save(subsets[split], OUTPUT_DIR / f"{split}.pt")
    # Drop the reference to the full split before loading the next one.
    del full_dataset

# Keep the original module-level names for any later cells that use them.
test, train, val = subsets["test"], subsets["train"], subsets["val"]
