In [10]:
import numpy as np

from tinytorch.core.dataloader import Dataset, TensorDataset
from tinytorch.core.tensor import Tensor


## Unit Test - Dataset Abstract Base Class

In [11]:
def test_unit_dataset():
    print("ðŸ”¬ Unit Test: Dataset Abstract Base Class...")

    try:
        dataset = Dataset()
        assert False, 'SHould not be able to instantiate abstract Dataset'
    except TypeError:
        print("âœ… Dataset is properly abstract")
    print("âœ… Dataset interface works correctly!")

    class TestDataset(Dataset):
        def __init__(self, size):
            self.size = size

        def __len__(self) -> int:
            return self.size

        def __getitem__(self, idx: int):
            return f'item_{idx}'

    dataset = TestDataset(10)
    assert len(dataset) == 10
    assert dataset[0] == 'item_0'
    assert dataset[9] == 'item_9'

if __name__=='__main__':
    test_unit_dataset()

🔬 Unit Test: Dataset Abstract Base Class...
✅ Dataset is properly abstract
✅ Dataset interface works correctly!


## Unit Test - TensorDataset Class

In [12]:
def test_unit_tensordataset():
    print("ðŸ”¬ Unit Test: TensorDataset...")

    features = Tensor([[1, 2], [3, 4], [5, 6]])
    labels = Tensor([0, 1, 0])
    dataset = TensorDataset(features, labels)

    # Test length
    assert len(dataset) == 3, f"Expected length 3, got {len(dataset)}"

    # Test indexing
    sample = dataset[0]
    assert len(sample) == 2, "Should return tuple with 2 tensors"
    assert np.array_equal(sample[0].data, [1, 2]), f'Wrong features: {sample[0].data}'

    # Test error handling
    try:
        dataset[10]
        assert False, "Shoudl raise IndexError for out of bounds access"
    except IndexError:
        pass

    # Test mismatch between tensor sizes
    try:
        bad_features = Tensor([[1,2], [3, 4]])
        bad_labels = Tensor([0, 1, 0])
        TensorDataset(bad_features, bad_labels)
        assert False, "Should riase error for mismatched tensor sizes"
    except ValueError:
        pass
    
    
    print("âœ… TensorDataset works correctly!")

if __name__ =='__main__':
    test_unit_tensordataset()

🔬 Unit Test: TensorDataset...
✅ TensorDataset works correctly!
