In [1]:
from torch.utils.data import IterDataPipe

In [2]:
# Example IterDataPipe
class ExampleIterPipe(IterDataPipe):
    """Minimal IterDataPipe that yields the integers 0, 1, ..., range - 1 in order.

    Args:
        range: how many consecutive integers, starting at 0, to yield
            (default 20). The parameter name shadows the builtin ``range``
            only inside ``__init__``; ``__iter__`` still resolves ``range``
            to the builtin.
    """

    def __init__(self, range=20):
        # Upper bound (exclusive) of the produced integer sequence.
        self.range = range

    def __iter__(self):
        # ``range`` below is the builtin; the bound comes from the instance.
        yield from range(self.range)

## Concat

Function: `concat`

Description: Returns a DataPipe that yields all elements of the first DataPipe, followed by all elements of each additional DataPipe, in order.

Alternatives:

    `dp = dp + dp2`
    
    `dp = dp.concat(dp2, dp3)`


In [3]:
# Concatenate two pipes: every element of dp (0..3) is yielded first, then every element of dp2 (0..2).
dp = ExampleIterPipe(4)
dp2 = ExampleIterPipe(3)
dp = dp.concat(dp2)
for i in dp:
    print(i)

0
1
2
3
0
1
2


## Batch

Function: `batch`

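Description: Groups consecutive elements into lists (mini-batches) of `batch_size` elements; the last batch may be smaller unless `drop_last=True`.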

Alternatives:

Arguments:
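  - `batch_size: int` desired batch size
  - `batch_level: bool = False` whether each element of the source DataPipe (including an already-batched list) is treated as a single independent object, or the items inside mini-batches are counted individually.
  - `drop_last: bool = False` whether to drop the final batch if it contains fewer than `batch_size` elements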

In [4]:
# Batch 10 elements into lists of 3; the final batch is shorter ([9]).
dp = ExampleIterPipe(10).batch(3)
for i in dp:
    print(i)

[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9]


In [5]:
# drop_last=True discards the final incomplete batch, so [9] is not yielded.
dp = ExampleIterPipe(10).batch(3, drop_last=True)
for i in dp:
    print(i)

[0, 1, 2]
[3, 4, 5]
[6, 7, 8]


Without overriding `batch_level`, consecutive `batch` calls re-batch the underlying elements instead of nesting the batches.

In [6]:
# Re-batching: batch(2) regroups the individual elements, not the lists produced by batch(3).
dp = ExampleIterPipe(10).batch(3).batch(2)
for i in dp:
    print(i)

[0, 1]
[2, 3]
[4, 5]
[6, 7]
[8, 9]


Setting `batch_level=True` treats each incoming batch (list) as a single object, so the batches are nested instead of re-batched.

In [7]:
# batch_level=True: batch(2) treats each list from batch(3) as one element, producing nested batches.
dp = ExampleIterPipe(10).batch(3).batch(2, batch_level=True)
for i in dp:
    print(i)

[[0, 1, 2], [3, 4, 5]]
[[6, 7, 8], [9]]


In [8]:
# map on a batched pipe applies the function to each element inside every batch.
dp = ExampleIterPipe(10).batch(3).map(lambda x: x * 2)
for i in dp:
    print(i)

[0, 2, 4]
[6, 8, 10]
[12, 14, 16]
[18]


In [9]:
# batch_level=True: map receives each whole batch (a list), so `x * 2` repeats the list
# instead of doubling its values.
dp = ExampleIterPipe(10).batch(3).map(lambda x: x * 2, batch_level=True)
for i in dp:
    print(i)

[0, 1, 2, 0, 1, 2]
[3, 4, 5, 3, 4, 5]
[6, 7, 8, 6, 7, 8]
[9, 9]


In [10]:
# map on an unbatched pipe applies the function to each individual element.
dp = ExampleIterPipe(10).map(lambda x: x * 2)
for i in dp:
    print(i)

0
2
4
6
8
10
12
14
16
18


In [11]:
# filter on a batched pipe keeps only matching elements inside each batch;
# the batch [9] becomes empty and, per the recorded output, is not yielded.
dp = ExampleIterPipe(10)
dp = dp.batch(3).filter(lambda x: x % 2 == 0)
for i in dp:
    print(i)

[0, 2]
[4]
[6, 8]


In [12]:
# filter(x > 4): the first batch [0, 1, 2] becomes empty and is dropped; [3, 4, 5] keeps only [5].
dp = ExampleIterPipe(10)
dp = dp.batch(3).filter(lambda x: x > 4)
for i in dp:
    print(i)

[5]
[6, 7, 8]
[9]


In [13]:
# drop_empty_batches=False keeps batches emptied by the filter (the leading [] in the output).
dp = ExampleIterPipe(10)
dp = dp.batch(3).filter(lambda x: x > 4, drop_empty_batches=False)
for i in dp:
    print(i)

[]
[5]
[6, 7, 8]
[9]


In [14]:
# batch_level=True: the predicate receives whole batches; only batches shorter than 3 are kept.
dp = ExampleIterPipe(10)
dp = dp.batch(3).filter(lambda l: len(l) < 3, batch_level=True)
for i in dp:
    print(i)

[9]


In [15]:
# shuffle on a batched pipe: in the recorded output, elements move across batches
# while the batch sizes (3, 3, 3, 1) are kept.
dp = ExampleIterPipe(10).batch(3).shuffle()
for i in dp:
    print(i)

[9, 6, 7]
[2, 1, 5]
[8, 4, 3]
[0]


In [16]:
# cross_shuffle=False: in the recorded output, elements are shuffled only within their own batch.
dp = ExampleIterPipe(10).batch(3).shuffle(cross_shuffle = False)
for i in dp:
    print(i)

[1, 0, 2]
[4, 5, 3]
[8, 7, 6]
[9]


In [17]:
# shuffle on an unbatched pipe yields the individual elements in a shuffled order.
dp = ExampleIterPipe(10).shuffle()
for i in dp:
    print(i)

1
9
2
7
6
5
8
3
4
0


In [18]:
# batch -> shuffle -> unbatch flattens the shuffled batches back into individual elements.
dp = ExampleIterPipe(10).batch(3).shuffle().unbatch()
for i in dp:
    print(i)

5
7
1
9
3
4
2
0
6
8


In [19]:
# collate turns each batch (a list of ints) into a tensor, as shown in the recorded output.
dp = ExampleIterPipe(10).batch(3).collate()
for i in dp:
    print(i)

tensor([0, 1, 2])
tensor([3, 4, 5])
tensor([6, 7, 8])
tensor([9])


In [20]:
# Tag each value with a key (x % 3), shuffle, then group the (key, value) tuples by that key.
dp = ExampleIterPipe(10).map(lambda x: (x % 3, x)).shuffle().groupby(lambda x: x[0])
for i in dp:
    print(i)

[(0, 0), (0, 9), (0, 3), (0, 6)]
[(1, 4), (1, 1), (1, 7)]
[(2, 5), (2, 2), (2, 8)]


In [21]:
# batch_level=True: groupby keys whole batches (here by their length), so the three
# full batches form one group and [9] forms another.
dp = ExampleIterPipe(10).batch(3).groupby(lambda x: len(x), batch_level = True)
for i in dp:
    print(i)

[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
[[9]]


## GroupBy

Function: `groupby`

Usage: `dp.groupby(lambda x: x[0])`

Description: Groups elements by the key returned by the given function and yields each group as a list.

Arguments:

`group_size` - yield a group as soon as `group_size` elements with the same key have been accumulated

#### Attention
As a data stream can be arbitrarily large, grouping is done on a best-effort basis, and there is no guarantee that the same key will not appear in different groups.

In [22]:
# Group 15 shuffled (key, value) tuples by key with default settings; in the recorded
# output each key ends up in a single group of 5.
dp = ExampleIterPipe(15).map(lambda x: (x % 3, x)).shuffle().groupby(lambda x: x[0])
for i in dp:
    print(i)

[(2, 8), (2, 5), (2, 11), (2, 14), (2, 2)]
[(0, 12), (0, 9), (0, 3), (0, 6), (0, 0)]
[(1, 4), (1, 10), (1, 13), (1, 7), (1, 1)]


In [23]:
# With a small buffer_size=5, groups are yielded before all items for a key have arrived,
# so the same key appears in several groups in the recorded output.
dp = ExampleIterPipe(15).map(lambda x: (x % 3, x)).shuffle().groupby(lambda x: x[0], buffer_size = 5)
for i in dp:
    print(i)

[(1, 1), (1, 4)]
[(0, 3), (0, 0), (0, 12)]
[(2, 2), (2, 5), (2, 11), (2, 14)]
[(1, 7), (1, 10), (1, 13)]
[(0, 6), (0, 9)]
[(2, 8)]


`groupby` yields `group_size`-sized groups as soon as they are complete; the remaining groups must contain at least `guaranteed_group_size` elements.

In [24]:
# group_size=3 yields full groups of 3 eagerly; the leftovers form groups of at least
# guaranteed_group_size=2 elements.
dp = ExampleIterPipe(15).map(lambda x: (x % 3, x)).shuffle().groupby(lambda x: x[0], group_size = 3, guaranteed_group_size = 2)
for i in dp:
    print(i)

[(1, 4), (1, 7), (1, 1)]
[(2, 11), (2, 2), (2, 8)]
[(0, 12), (0, 0), (0, 9)]
[(2, 5), (2, 14)]
[(0, 3), (0, 6)]
[(1, 10), (1, 13)]


When `group_size` is not defined, the function tries to accumulate at least `guaranteed_group_size` elements before yielding a group.

In [25]:
# Only guaranteed_group_size=2 is set; with the default buffer, the recorded output shows
# each key collected into a single group.
dp = ExampleIterPipe(15).map(lambda x: (x % 3, x)).shuffle().groupby(lambda x: x[0], guaranteed_group_size = 2)
for i in dp:
    print(i)

[(2, 8), (2, 5), (2, 11), (2, 2), (2, 14)]
[(0, 3), (0, 9), (0, 12), (0, 6), (0, 0)]
[(1, 1), (1, 13), (1, 7), (1, 4), (1, 10)]


This behaviour becomes noticeable when the data is bigger than the buffer and some groups get evicted before all of their potential items have been gathered.

In [26]:
# No shuffle and buffer_size=6: groups are evicted from the buffer early (keys repeat in
# the output), but every yielded group still has at least 2 items.
dp = ExampleIterPipe(15).map(lambda x: (x % 3, x)).groupby(lambda x: x[0], guaranteed_group_size = 2, buffer_size = 6)
for i in dp:
    print(i)

[(0, 0), (0, 3)]
[(1, 1), (1, 4), (1, 7)]
[(2, 2), (2, 5), (2, 8)]
[(0, 6), (0, 9), (0, 12)]
[(1, 10), (1, 13)]
[(2, 11), (2, 14)]


With randomness involved, you might end up with incomplete groups (so the next example is expected to fail in most cases).

In [27]:
# Same as the previous cell but shuffled. Intentionally expected to fail in most runs: a
# leftover group smaller than guaranteed_group_size raises "Failed to group items"
# (see the recorded exception below the output).
dp = ExampleIterPipe(15).map(lambda x: (x % 3, x)).shuffle().groupby(lambda x: x[0], guaranteed_group_size = 2, buffer_size = 6)
for i in dp:
    print(i)

[(2, 2), (2, 14), (2, 8), (2, 5)]
[(0, 3), (0, 9), (0, 12), (0, 6)]
[(1, 10), (1, 1), (1, 4), (1, 13)]


Exception: ('Failed to group items', '[(2, 11)]')

In [28]:
# Attempt to zip two pipes element-wise.
# NOTE(review): the recorded output of this cell is an AttributeError, so `zip` does not
# appear to be available as a functional form on the DataPipe build used here — confirm
# before relying on this example.
_dp = ExampleIterPipe(5)
dp = ExampleIterPipe(5).zip(_dp)
for i in dp:
    print(i)

AttributeError: 