In [3]:
from typing import Any, Iterable, List

def batches(
    iterable: Iterable[Any], batch_size: int
) -> Iterable[List[Any]]:
    pass

In [6]:
from itertools import chain

def test_batch_on_lists():
    assert list(batches([1, 2, 3, 4, 5, 6], 1)) == [
        [1], [2], [3], [4], [5], [6]
    ]
    assert list(batches([1, 2, 3, 4, 5, 6], 2)) == [
        [1,2], [3,4], [5,6]
    ]
    assert list(batches([1, 2, 3, 4, 5, 6], 3)) == [
        [1,2,3], [4,5,6]
    ]
    assert list(batches([1, 2, 3, 4, 5, 6], 4)) == [
        [1,2,3,4], [5,6]
    ]
    

In [7]:
test_batch_on_lists()

TypeError: 'NoneType' object is not iterable

In [8]:
def test_batch_order():
    iterable = range(100)
    batch_size = 2 
    
    output = batches(iterable, batch_size)
    
    assert list(chain.from_iterable(output)) == list(iterable)
    
def test_batch_sizes():
    iterable = range(100)
    batch_size = 2 
    
    output = list(batches(iterable, batch_size))
    
    for batch in output[:-1]:
        assert len(batch) == batch_size
        
    assert len(output[-1]) <= batch_size
    

In [11]:
! pip install pytest


Collecting pytest
  Downloading pytest-7.2.2-py3-none-any.whl (317 kB)
     -------------------------------------- 317.2/317.2 kB 9.6 MB/s eta 0:00:00
Collecting iniconfig
  Downloading iniconfig-2.0.0-py3-none-any.whl (5.9 kB)
Collecting tomli>=1.0.0
  Downloading tomli-2.0.1-py3-none-any.whl (12 kB)
Collecting pluggy<2.0,>=0.12
  Using cached pluggy-1.0.0-py2.py3-none-any.whl (13 kB)
Collecting exceptiongroup>=1.0.0rc8
  Downloading exceptiongroup-1.1.1-py3-none-any.whl (14 kB)
Installing collected packages: tomli, pluggy, iniconfig, exceptiongroup, pytest
Successfully installed exceptiongroup-1.1.1 iniconfig-2.0.0 pluggy-1.0.0 pytest-7.2.2 tomli-2.0.1


In [12]:
! pytest -v

platform win32 -- Python 3.10.8, pytest-7.2.2, pluggy-1.0.0 -- C:\Users\shshin\anaconda3\envs\data_analyzer\python.exe
cachedir: .pytest_cache
rootdir: C:\Users\shshin\Documents\jupyter\test_code_sample\Expert Python Programming
plugins: anyio-3.5.0
[1mcollecting ... [0mcollected 0 items



In [13]:
test_batch_order()

TypeError: 'NoneType' object is not iterable

In [14]:
test_batch_sizes()

TypeError: 'NoneType' object is not iterable

In [15]:
def batches(
    iterable: Iterable[Any], batch_size: int
) -> Iterable[List[Any]]:
    results = []
    batch = []
    
    for item in iterable:
        batch.append(item)
        if len(batch) == batch_size:
            results.append(batch)
            batch = []
            
    if batch:
        results.append(batch)
        
    return results

In [17]:
! pytest -v

platform win32 -- Python 3.10.8, pytest-7.2.2, pluggy-1.0.0 -- C:\Users\shshin\anaconda3\envs\data_analyzer\python.exe
cachedir: .pytest_cache
rootdir: C:\Users\shshin\Documents\jupyter\test_code_sample\Expert Python Programming
plugins: anyio-3.5.0
[1mcollecting ... [0mcollected 0 items



In [18]:
test_batch_on_lists()

In [19]:
test_batch_order()

In [20]:
test_batch_sizes()

In [21]:
# iterable 객체가 무한대이면 어플리케이션에 큰 부담을 주기 때문에 제너레이터로 바꿔서 성공 시켜보자

def batches(
    iterable: Iterable[Any], batch_size: int
) -> Iterable[List[Any]]:
    results = []
    batch = []
    
    for item in iterable:
        batch.append(item)
        
        if len(batch) == batch_size:
            yield batch
            batch = []
            
    if batch:
        yield batch
       
    

In [22]:
test_batch_on_lists()

In [23]:
test_batch_order()


In [24]:
test_batch_sizes()

In [25]:
# 이터레이터와 itertools 모듈을 이용하는 방법

from itertools import islice

def batches(
    iterable: Iterable[Any], batch_size: int
) -> Iterable[List[Any]]:
    iterator = iter(iterable)
    
    while True:
        batch = list(islice(iterator, batch_size))
        
        if not batch:
            return
        
        yield batch


In [26]:
! pytest -v

platform win32 -- Python 3.10.8, pytest-7.2.2, pluggy-1.0.0 -- C:\Users\shshin\anaconda3\envs\data_analyzer\python.exe
cachedir: .pytest_cache
rootdir: C:\Users\shshin\Documents\jupyter\test_code_sample\Expert Python Programming
plugins: anyio-3.5.0
[1mcollecting ... [0mcollected 0 items



In [27]:
test_batch_on_lists()

In [28]:
test_batch_order()

In [29]:
test_batch_sizes()

In [31]:
import pytest

@pytest.mark.parametrize(
    "batch_size, expected", [
        # 동일한 크기의 배치
        [1, [[1], [2], [3], [4], [5], [6]]],
        [2, [[1, 2], [3, 4], [5, 6]]],
        [3, [[1,2,3], [4,5,6]]],
        # 마지막 잔여 배치 포함
        [4, [[1,2,3,4], [5,6]]]
    ]
)
        
def test_batch_parameterized(batch_size, expected):
    iterable = [1, 2, 3, 4, 5, 6]
    assert list(batches(iterable, batch_size)) == expected

In [32]:
test_batch_parameterized()

TypeError: test_batch_parameterized() missing 2 required positional arguments: 'batch_size' and 'expected'