# Patch pooling (10pts)
Your task is to implement patch pooling. Patch pooling takes an input sequence $(a_0, a_1, \ldots, a_{S-1})$ of $D$-dimensional embeddings and outputs a sequence $(b_0, \ldots, b_{k-1})$ of $D$-dimensional embeddings. The output sequence is no longer than the input sequence, and its length $k$ is bounded by $P$. Each element of the input sequence is called a token, and each element of the output sequence is called a patch. Consecutive patches are computed by mean pooling consecutive contiguous spans of tokens.
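
The mean pooling can be written out explicitly. The notation here is introduced only for illustration: let $\ell_j$ be the number of tokens in patch $j$ and let $s_j = \sum_{m<j} \ell_m$ be the index of its first token. Then, for every patch with $\ell_j > 0$,

$$b_j = \frac{1}{\ell_j} \sum_{i = s_j}^{s_j + \ell_j - 1} a_i .$$

For instance, with patch lengths $(3, 1, 2)$ the spans are tokens $0$ to $2$, token $3$, and tokens $4$ to $5$.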

You are given two tensors:
1. `batch` - a $3$-dimensional tensor, which is an input to a standard transformer model with the following dimensions:
* B - batch size
* S - sequence length
* D - dimension of embedding of a single token

`batch[x,y,:]` is the embedding of the $(y+1)$-th token of the $(x+1)$-th sequence in the `batch`.

2. `patch_lengths` - a $2$-dimensional integer-valued tensor with the following dimensions:
* B - batch size
* P - maximal number of patches

`patch_lengths[x,y]` is the number of tokens forming patch number $y+1$ in the $(x+1)$-th sequence in the `batch`.

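The output should be a $3$-dimensional tensor of shape $(B, P, D)$ containing a batch of sequences of patch embeddings. As the example below shows, patches of length $0$ are filled with the padding embedding.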
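
To make the expected behaviour concrete, here is a minimal reference sketch that follows the description literally. It is not the intended solution: it uses Python loops, so it would not meet the requirements for full marks. The function name `patch_pool_reference` and the `pad_value` parameter are hypothetical, and filling empty patches with $-1$ is an assumption based on the example.

```python
import torch


def patch_pool_reference(
    batch: torch.Tensor, patch_lengths: torch.Tensor, pad_value: float = -1.0
) -> torch.Tensor:
    """Loop-based mean pooling of contiguous token spans (readable, not fast)."""
    B, S, D = batch.shape
    P = patch_lengths.shape[1]
    # Start from an output filled with the padding value (assumed to be -1).
    output = torch.full((B, P, D), pad_value, dtype=batch.dtype, device=batch.device)
    for b in range(B):
        start = 0  # index of the first token of the current patch
        for p in range(P):
            length = int(patch_lengths[b, p])
            if length > 0:
                # Average the tokens in the span [start, start + length).
                output[b, p] = batch[b, start:start + length].mean(dim=0)
            start += length
    return output
```

A loop-based version like this can serve as an oracle when testing a vectorized implementation.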

# Example
The following snippet
```python
batch = torch.tensor([[[ 1.,  1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.,  1.],
         [ 2.,  2.,  2.,  2.,  2.],
         [ 3.,  3.,  3.,  3.,  3.],
         [ 3.,  3.,  3.,  3.,  3.]],

        [[ 4.,  4.,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  4.,  4.],
         [ 4.,  4.,  4.,  4.,  4.],
         [ 5.,  5.,  5.,  5.,  5.],
         [-1., -1., -1., -1., -1.]],

        [[ 6.,  6.,  6.,  6.,  6.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.]]])
patch_lengths = torch.tensor([[3, 1, 2],
        [4, 1, 0],
        [1, 0, 0]])
patch_pooling = PatchPooling()
output = patch_pooling(batch, patch_lengths)
output
```

should output

```python
torch.tensor([[[1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3.]],

        [[4., 4., 4., 4., 4.],
         [5., 5., 5., 5., 5.],
         [-1., -1., -1., -1., -1.]],

        [[6., 6., 6., 6., 6.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.]]])
```

Remarks:

1. In this problem you can assume that embeddings of the padding token are vectors with all coordinates equal to $-1$.

2. Solutions will be graded with unit tests. You are given a single test case, which will be part of the evaluation.

3. Solutions that do not satisfy the requirements below will receive at most 4pts:
* You are not allowed to call custom Python functions
* You are not allowed to use Python loops
* You are not allowed to use any other imports
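
One loop-free approach is sketched here as an illustration under stated assumptions, not as the official solution. It builds a boolean token-to-patch membership mask from the cumulative sum of `patch_lengths` and averages with a batched matrix product. The function name `patch_pool_vectorized` and the `pad_value` parameter are hypothetical, and empty patches are assumed to take the padding value $-1$. It is written as a standalone function only so that it can be run on its own; in the graded class the same tensor operations would go directly into `forward`.

```python
import torch


def patch_pool_vectorized(
    batch: torch.Tensor, patch_lengths: torch.Tensor, pad_value: float = -1.0
) -> torch.Tensor:
    B, S, D = batch.shape
    # Exclusive end and inclusive start index of every patch, both of shape (B, P).
    ends = patch_lengths.cumsum(dim=1)
    starts = ends - patch_lengths
    positions = torch.arange(S, device=batch.device)  # (S,)
    # mask[b, p, s] is True when token s belongs to patch p of sequence b.
    mask = (positions >= starts.unsqueeze(-1)) & (positions < ends.unsqueeze(-1))
    # Sum the embeddings of each patch: (B, P, S) @ (B, S, D) -> (B, P, D).
    sums = torch.bmm(mask.to(batch.dtype), batch)
    # Divide by the patch length; clamp avoids division by zero for empty patches.
    counts = patch_lengths.clamp(min=1).unsqueeze(-1).to(batch.dtype)
    means = sums / counts
    # Empty patches are replaced with the assumed padding value.
    empty = (patch_lengths == 0).unsqueeze(-1)
    return torch.where(empty, torch.full_like(means, pad_value), means)


example_batch = torch.tensor(
    [[[1.0, 1.0], [1.0, 1.0], [3.0, 3.0]], [[4.0, 4.0], [-1.0, -1.0], [-1.0, -1.0]]]
)
example_lengths = torch.tensor([[2, 1], [1, 0]])
example_output = patch_pool_vectorized(example_batch, example_lengths)
# example_output: [[[1., 1.], [3., 3.]], [[4., 4.], [-1., -1.]]]
```

The mask has shape $(B, P, S)$, so this trades memory for simplicity. For long sequences, a scatter-based reduction over patch indices might be preferable.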

In [27]:
%%file test_patch_pooling.py

import pytest
import torch


class PatchPooling(torch.nn.Module):
    def forward(self, batch: torch.Tensor, patch_lengths: torch.Tensor) -> torch.Tensor:
        B, S, D = batch.shape
        B_1, P = patch_lengths.shape

        assert B == B_1

        ### Your code goes here ###

        ###########################

        


class TestPatchPooling:
    @pytest.mark.parametrize(
        "batch,patch_lengths,expected_output",
        [
            (
                torch.tensor(
                    [
                        [
                            [1.0, 1.0, 1.0, 1.0, 1.0],
                            [1.0, 1.0, 1.0, 1.0, 1.0],
                            [1.0, 1.0, 1.0, 1.0, 1.0],
                            [2.0, 2.0, 2.0, 2.0, 2.0],
                            [3.0, 3.0, 3.0, 3.0, 3.0],
                            [3.0, 3.0, 3.0, 3.0, 3.0],
                        ],
                        [
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [5.0, 5.0, 5.0, 5.0, 5.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                        ],
                        [
                            [6.0, 6.0, 6.0, 6.0, 6.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                        ],
                    ]
                ),
                torch.tensor([[3, 1, 2], [4, 1, 0], [1, 0, 0]]),
                torch.tensor(
                    [
                        [
                            [1.0, 1.0, 1.0, 1.0, 1.0],
                            [2.0, 2.0, 2.0, 2.0, 2.0],
                            [3.0, 3.0, 3.0, 3.0, 3.0],
                        ],
                        [
                            [4.0, 4.0, 4.0, 4.0, 4.0],
                            [5.0, 5.0, 5.0, 5.0, 5.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                        ],
                        [
                            [6.0, 6.0, 6.0, 6.0, 6.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                            [-1.0, -1.0, -1.0, -1.0, -1.0],
                        ],
                    ]
                ),
            )
        ],
    )
    def test_forward(
        self,
        batch: torch.Tensor,
        patch_lengths: torch.Tensor,
        expected_output: torch.Tensor,
    ) -> None:
        # given
        patch_pooling = PatchPooling()

        # when
        output = patch_pooling(batch=batch, patch_lengths=patch_lengths)

        # then
        assert torch.all(torch.isclose(output, expected_output))

Overwriting test_patch_pooling.py


In [28]:
!python -m pytest test_patch_pooling.py

platform linux -- Python 3.11.11, pytest-8.3.4, pluggy-1.5.0
rootdir: /content
plugins: langsmith-0.3.2, typeguard-4.4.1, anyio-3.7.1
collected 1 item

test_patch_pooling.py .                                                                      [100%]

