# Bidirectional Gram Converter

We will use the Bidirectional Gram Grid Converter architecture to convert between gram embeddings and the original representation. It uses autoencoding for initial training and makes decoding the inverse of encoding. It also encourages the model to encode information into the gram encodings.

**Premise**

* If each encoder cell uses a gram embedding to create an update and then subtracts it from the input, each decoder cell can reverse the step by using that same layer to recreate the update and then add it back.
* With the proper loss on the encoder output, we can encourage the model to shunt its information through the gram embeddings instead of the main embeddings.
* By injecting positional embeddings regularly before convolutions, we let the model retain access to positional information.

**Design**

At a high level, the encoder/decoder model acts a lot like a ResNet. It consists of cells with residual bypasses that operate more or less on pixel embeddings. However, each cell produces both the next pixel embedding in the chain and a Gram Encoding of the layer.

The cells themselves are more sophisticated than plain residual blocks: they are designed to allow bidirectional encoding and decoding with almost exactly the same parameters.

A BResNet cell has three main components:

* Latent Summary stage: Uses internal parameters and actions to create a Gram Encoding. 
* Latent Decode stage: Uses a provided Gram Encoding and a known grid shape to create an update of the same shape as the summary input.
* Merge stage: Either subtracts or adds the update. Subtract when encoding. Add when decoding. This stage replaces the residual bypass: its output is the residual.

The effect of this is that, from the model's perspective, the only distinction between encoding and decoding is whether you subtract and use the summary stage, or add and use externally provided Gram Encodings. Let's consider one particular cell operating in encoding and decoding mode to see this illustrated.
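The following is a minimal sketch of that symmetry, not the notebook's actual cell. It assumes a toy cell whose names (`ToyBidirectionalCell`, `summarize`, `expand`) are hypothetical, and whose summary is a plain mean-pool plus a linear layer standing in for the gram encoder defined later. It only shows that subtracting an update while encoding and adding the same update while decoding recovers the input.

```python
import torch
from torch import nn


class ToyBidirectionalCell(nn.Module):
    """Hypothetical stand-in for one cell; not the notebook's real implementation."""

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.summarize = nn.Linear(embedding_dim, embedding_dim)  # stands in for the Latent Summary stage
        self.expand = nn.Linear(embedding_dim, embedding_dim)     # stands in for the Latent Decode stage

    def make_update(self, gram: torch.Tensor, grid_shape: torch.Size) -> torch.Tensor:
        # Broadcast the per-example encoding over the grid: (batch x E) -> (batch x N x M x E)
        update = self.expand(gram)
        return update[:, None, None, :].expand(grid_shape)

    def encode(self, embeddings: torch.Tensor):
        # Summary stage creates the encoding, then the merge stage subtracts the update
        gram = self.summarize(embeddings.mean(dim=(1, 2)))
        update = self.make_update(gram, embeddings.shape)
        return embeddings - update, gram

    def decode(self, residual: torch.Tensor, gram: torch.Tensor) -> torch.Tensor:
        # The encoding is supplied externally, and the merge stage adds the same update back
        update = self.make_update(gram, residual.shape)
        return residual + update


torch.manual_seed(0)
cell = ToyBidirectionalCell(embedding_dim=8)
x = torch.randn(2, 4, 4, 8)
residual, gram = cell.encode(x)
restored = cell.decode(residual, gram)
print(torch.allclose(restored, x, atol=1e-5))
```

Both directions share `make_update`, which is why the same parameters can serve the encoder and the decoder in this sketch.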

**Training**

The mechanism can be trained as an autoencoder. However, we shall discourage the autoencoder from relying on the encodings output by the encoder by penalizing those encodings when they are nonzero.
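One way such an objective could look is sketched below. This is an assumption rather than the notebook's loss: the function name `autoencoder_loss`, the L1 form of the penalty, and the `penalty_weight` value are all hypothetical choices made for illustration.

```python
import torch
import torch.nn.functional as F


def autoencoder_loss(reconstruction: torch.Tensor,
                     target: torch.Tensor,
                     encoder_output: torch.Tensor,
                     penalty_weight: float = 0.1) -> torch.Tensor:
    """Reconstruction loss plus a penalty on nonzero encoder output (hypothetical form)."""
    # Standard autoencoder term: how well the decoder rebuilt the input
    reconstruction_loss = F.mse_loss(reconstruction, target)
    # Assumed L1 penalty: grows with the magnitude of whatever the encoder leaves behind
    leak_penalty = encoder_output.abs().mean()
    return reconstruction_loss + penalty_weight * leak_penalty


target = torch.randn(2, 4, 4, 8)
loss_without_leak = autoencoder_loss(target, target, torch.zeros(2, 4, 4, 8))
loss_with_leak = autoencoder_loss(target, target, torch.ones(2, 4, 4, 8))
print(loss_without_leak.item(), loss_with_leak.item())
```

With a perfect reconstruction the loss is zero only when the penalized output is also zero, which is the pressure the training description asks for.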


# Hyperparameters and Imports

Hyperparameters and imports go here.

In [None]:
import torch
import unittest
from torch import nn
from typing import Tuple


# Primitive Layers

We begin by defining the primitive pieces that the architecture needs.

## Pixel Cell

The convolution processing cell is really the only architecture-specific piece here. It is specialized for processing embedded pixel data.

**Premise**

* We need a convolutional processing mechanism for pixel images.
* We will need a cell for that.

**Dependencies**

* `embedding_dim`: The dimensions of the embedding.
* `num_layers`: The number of layers deep the cell is.
* `kernel_size`: The size of the convolutional kernel.
* `dropout_prob`: The probability of an element to be zeroed during dropout (default: 0.5).

**Accepts**

* `embeddings`: Image embeddings. Shape (batch x N x M x embedding_dim)
* `mask` (optional): A mask to apply at the end. Shape (batch x N x M)

**Returns**

* `embeddings`: New image embeddings. Shape (batch x N x M x embedding_dim)
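
The shape contract above relies on "same" padding. As a small sketch (the helper name `conv_output_hw` is hypothetical), the padding rule `(kernel_size - 1) // 2` keeps the grid size unchanged for odd kernel sizes but shrinks it for even ones, so an odd `kernel_size` is assumed here:

```python
import torch
from torch import nn


def conv_output_hw(kernel_size: int, height: int = 8, width: int = 8, embedding_dim: int = 4):
    """Return the (N, M) grid size after one padded convolution (hypothetical helper)."""
    padding = (kernel_size - 1) // 2
    conv = nn.Conv2d(embedding_dim, embedding_dim, kernel_size, padding=padding)
    embeddings = torch.randn(1, height, width, embedding_dim)  # channels-last, as the cell accepts
    channels = embeddings.permute(0, 3, 1, 2)                  # channels-first, as Conv2d expects
    out = conv(channels).permute(0, 2, 3, 1)                   # back to channels-last
    return tuple(out.shape[1:3])


print(conv_output_hw(3))  # (8, 8)
print(conv_output_hw(5))  # (8, 8)
print(conv_output_hw(4))  # (7, 7): even kernels lose a row and a column
```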


In [None]:
class PixelCell(nn.Module):
    """
    Convolutional cell network. Processes image embeddings without reduction.
    """
    def __init__(self,
                 embedding_dim: int,
                 num_layers: int,
                 kernel_size: int,
                 dropout_prob: float = 0.5):
        """
        Initialize the conv cell
        :param embedding_dim: The dimension of the embeddings
        :param num_layers: The number of layers to make
        :param kernel_size: The kernel size of the convolutional layers
        :param dropout_prob: The probability of an element to be zeroed (default: 0.5)
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.dropout_prob = dropout_prob

        # Construct layers
        padding_size = (kernel_size - 1) // 2
        layers = []
        for i in range(num_layers):
            layer = nn.Sequential(
                nn.Conv2d(embedding_dim, embedding_dim, kernel_size, padding=padding_size),
                nn.LeakyReLU()
            )
            layers.append(layer)
        self.layers = nn.Sequential(*layers)
        self.dropout = nn.Dropout2d(dropout_prob)

    def forward(self, embeddings: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass to process embeddings through convolutional layers.
        :param embeddings: The input embeddings. Shape (batch x N x M x embedding_dim)
        :param mask: Optional mask to apply at the end. Shape (batch x N x M)
        :return: The output embeddings. Shape (batch x N x M x embedding_dim)
        """
        assert embeddings.dim() == 4, "Embeddings must have 4 dimensions (batch x N x M x embedding_dim)"
        assert embeddings.shape[-1] == self.embedding_dim, f"Expected embedding dimension {self.embedding_dim}, but got {embeddings.shape[-1]}"

        # Permute to match the expected input shape for Conv2d
        channels = embeddings.permute(0, 3, 1, 2)  # shape: (batch x embedding_dim x N x M)
        channels = self.layers(channels)  # shape: (batch x embedding_dim x N x M)
        channels = self.dropout(channels)  # Apply dropout here
        embeddings = channels.permute(0, 2, 3, 1)  # shape: (batch x N x M x embedding_dim)

        if mask is not None:
            assert mask.dim() == 3, "Mask must have 3 dimensions (batch x N x M)"
            mask = mask.unsqueeze(-1)  # shape: (batch x N x M x 1)
            embeddings = embeddings * mask  # Apply mask

        return embeddings


In [None]:
class TestPixelCell(unittest.TestCase):
    def setUp(self):
        # Setting up the necessary parameters and creating an instance of PixelCell
        self.embedding_dim = 64
        self.num_layers = 3
        self.kernel_size = 3
        self.dropout_prob = 0.5
        self.batch_size = 2
        self.height, self.width = 8, 8
        self.pixel_cell = PixelCell(self.embedding_dim, self.num_layers, self.kernel_size, self.dropout_prob)

    def test_forward_shape(self):
        # Creating dummy data for embeddings
        embeddings = torch.randn(self.batch_size, self.height, self.width, self.embedding_dim)
        
        # Running the forward pass
        output = self.pixel_cell(embeddings)
        
        # Asserting the output shape is as expected
        self.assertEqual(output.shape, (self.batch_size, self.height, self.width, self.embedding_dim))

    def test_forward_invalid_dim(self):
        # Creating dummy data with invalid number of dimensions
        embeddings = torch.randn(self.batch_size, self.height, self.embedding_dim)
        
        # Asserting that an AssertionError is raised for invalid input dimensions
        with self.assertRaises(AssertionError):
            self.pixel_cell(embeddings)

    def test_forward_invalid_embedding_dim(self):
        # Creating dummy data with invalid embedding dimension
        embeddings = torch.randn(self.batch_size, self.height, self.width, self.embedding_dim + 1)
        
        # Asserting that an AssertionError is raised for invalid embedding dimension
        with self.assertRaises(AssertionError):
            self.pixel_cell(embeddings)

    def test_no_nan_values(self):
        # Creating dummy data for embeddings
        embeddings = torch.randn(self.batch_size, self.height, self.width, self.embedding_dim)
        
        # Running the forward pass
        output = self.pixel_cell(embeddings)
        
        # Asserting the output does not contain NaNs
        self.assertFalse(torch.isnan(output).any(), "Output contains NaNs")

    def test_no_inf_values(self):
        # Creating dummy data for embeddings
        embeddings = torch.randn(self.batch_size, self.height, self.width, self.embedding_dim)
        
        # Running the forward pass
        output = self.pixel_cell(embeddings)
        
        # Asserting the output does not contain infinite values
        self.assertFalse(torch.isinf(output).any(), "Output contains infinite values")

    def test_dropout_effect(self):
        # Creating dummy data for embeddings
        embeddings = torch.randn(self.batch_size, self.height, self.width, self.embedding_dim)
        
        # Running the forward pass multiple times to check for dropout effect
        outputs = [self.pixel_cell(embeddings) for _ in range(5)]
        
        # Asserting that the outputs are different due to dropout
        different_outputs = any(not torch.equal(outputs[i], outputs[i + 1]) for i in range(len(outputs) - 1))
        self.assertTrue(different_outputs, "Dropout does not seem to have an effect")

    def test_forward_with_mask(self):
        # Creating dummy data for embeddings and mask
        embeddings = torch.randn(self.batch_size, self.height, self.width, self.embedding_dim)
        mask = torch.ones(self.batch_size, self.height, self.width)
        
        # Dropout is stochastic in training mode, so switch to eval mode
        # to make the two forward passes comparable
        self.pixel_cell.eval()

        # Running the forward pass with mask
        output = self.pixel_cell(embeddings, mask)
        
        # Asserting the output shape is as expected
        self.assertEqual(output.shape, (self.batch_size, self.height, self.width, self.embedding_dim))
        
        # Check that masking was applied correctly (if mask is all ones, output should be unaffected)
        self.assertTrue(torch.equal(output, self.pixel_cell(embeddings)))

    def test_forward_with_partial_mask(self):
        # Creating dummy data for embeddings and partial mask
        embeddings = torch.randn(self.batch_size, self.height, self.width, self.embedding_dim)
        mask = torch.ones(self.batch_size, self.height, self.width)
        mask[:, :self.height//2, :self.width//2] = 0  # Zero out the top-left quarter of the mask
        
        # Running the forward pass with partial mask
        output = self.pixel_cell(embeddings, mask)
        
        # Asserting the masked areas in the output are zeroed
        self.assertTrue(torch.equal(output[:, :self.height//2, :self.width//2, :], torch.zeros_like(output[:, :self.height//2, :self.width//2, :])))


## TextCell

Designed for processing a stream of text embeddings. I am not sure if I will use it.

**Dependencies**

* num_layers: The number of encoding cells to use.
* embedding_dim: The width of each individual embedding.
* num_heads: The number of transformer heads.
* dim_feedforward: The size of the feedforward layer

**Accepts**

* embeddings: A tensor of text embeddings. Shape (batch x N x embedding_dim)
* mask: A mask of active text embeddings, True means active. Shape (batch x N)

**Returns**

* embeddings: An output sequence of embeddings. Shape (batch x N x embedding_dim)

**Design**

Basically, we use a sequence of transformer encoder layers to encode the embeddings.
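
One detail worth illustrating is the mask convention. `nn.TransformerEncoder` takes a `src_key_padding_mask` in which True marks positions to ignore, which is the opposite of a "True means active" mask. The sketch below assumes small, arbitrary sizes and is not the notebook's cell. It inverts an active mask and checks that changing the padded tokens leaves the outputs at active positions unchanged.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32,
                                   dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()  # disable dropout so passes are comparable

embeddings = torch.randn(2, 6, 16)
active = torch.ones(2, 6, dtype=torch.bool)
active[:, 4:] = False          # the last two positions are padding
padding_mask = ~active         # True now means "ignore this position"

# Replace the padded tokens with different values
altered = embeddings.clone()
altered[:, 4:] = torch.randn(2, 2, 16)

out_a = encoder(embeddings, src_key_padding_mask=padding_mask)
out_b = encoder(altered, src_key_padding_mask=padding_mask)

# Active positions never attend to padded keys, so their outputs should agree
valid_unchanged = torch.allclose(out_a[:, :4], out_b[:, :4], atol=1e-5)
print(valid_unchanged)
```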

In [None]:
class TextCell(nn.Module):
    """
    Transformer-based text processing cell. Processes a stream of text embeddings.
    """
    def __init__(self,
                 num_layers: int,
                 embedding_dim: int,
                 num_heads: int,
                 dim_feedforward: int,
                 dropout_prob: float = 0.1):
        """
        Initialize the transformer-based text cell
        :param num_layers: The number of encoding cells to use.
        :param embedding_dim: The width of each individual embedding.
        :param num_heads: The number of transformer heads.
        :param dim_feedforward: The size of the feedforward layer.
        :param dropout_prob: The probability of an element to be zeroed (default: 0.1)
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_feedforward = dim_feedforward
        self.dropout_prob = dropout_prob

        # Construct transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout_prob,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, embeddings: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass to process text embeddings through transformer encoder layers.
        :param embeddings: A tensor of text embeddings. Shape (batch x N x embedding_dim)
        :param mask: A mask of active text embeddings. Shape (batch x N)
        :return: An output sequence of embeddings. Shape (batch x N x embedding_dim)
        """
        assert embeddings.dim() == 3, "Embeddings must have 3 dimensions (batch x N x embedding_dim)"
        assert embeddings.shape[-1] == self.embedding_dim, f"Expected embedding dimension {self.embedding_dim}, but got {embeddings.shape[-1]}"

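        # nn.TransformerEncoder expects a padding mask in which True marks positions to IGNORE,
        # so invert the "True means active" mask before passing it in
        padding_mask = None
        if mask is not None:
            assert mask.dim() == 2, "Mask must have 2 dimensions (batch x N)"
            padding_mask = ~mask.bool()
        output = self.transformer_encoder(embeddings, src_key_padding_mask=padding_mask)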

        return output

In [None]:
class TestTextCell(unittest.TestCase):
    def setUp(self):
        # Setting up the necessary parameters and creating an instance of TextCell
        self.num_layers = 4
        self.embedding_dim = 64
        self.num_heads = 8
        self.dim_feedforward = 256
        self.dropout_prob = 0.1
        self.batch_size = 2
        self.seq_length = 10
        self.text_cell = TextCell(self.num_layers, self.embedding_dim, self.num_heads, self.dim_feedforward, self.dropout_prob)

    def test_forward_shape(self):
        # Creating dummy data for embeddings
        embeddings = torch.randn(self.batch_size, self.seq_length, self.embedding_dim)
        
        # Running the forward pass
        output = self.text_cell(embeddings)
        
        # Asserting the output shape is as expected
        self.assertEqual(output.shape, (self.batch_size, self.seq_length, self.embedding_dim))

    def test_forward_invalid_dim(self):
        # Creating dummy data with invalid number of dimensions
        embeddings = torch.randn(self.batch_size, self.seq_length, self.embedding_dim, 2)
        
        # Asserting that an AssertionError is raised for invalid input dimensions
        with self.assertRaises(AssertionError):
            self.text_cell(embeddings)

    def test_forward_invalid_embedding_dim(self):
        # Creating dummy data with invalid embedding dimension
        embeddings = torch.randn(self.batch_size, self.seq_length, self.embedding_dim + 1)
        
        # Asserting that an AssertionError is raised for invalid embedding dimension
        with self.assertRaises(AssertionError):
            self.text_cell(embeddings)

    def test_no_nan_values(self):
        # Creating dummy data for embeddings
        embeddings = torch.randn(self.batch_size, self.seq_length, self.embedding_dim)
        
        # Running the forward pass
        output = self.text_cell(embeddings)
        
        # Asserting the output does not contain NaNs
        self.assertFalse(torch.isnan(output).any(), "Output contains NaNs")

    def test_no_inf_values(self):
        # Creating dummy data for embeddings
        embeddings = torch.randn(self.batch_size, self.seq_length, self.embedding_dim)
        
        # Running the forward pass
        output = self.text_cell(embeddings)
        
        # Asserting the output does not contain infinite values
        self.assertFalse(torch.isinf(output).any(), "Output contains infinite values")

    def test_forward_with_mask(self):
        # Creating dummy data for embeddings and mask
        embeddings = torch.randn(self.batch_size, self.seq_length, self.embedding_dim)
        mask = torch.ones(self.batch_size, self.seq_length, dtype=torch.bool)
        
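        # Dropout is stochastic in training mode, so switch to eval mode before comparing passes
        self.text_cell.eval()

        # Running the forward pass with mask
        output = self.text_cell(embeddings, mask)
        
        # Asserting the output shape is as expected
        self.assertEqual(output.shape, (self.batch_size, self.seq_length, self.embedding_dim))
        
        # With an all-active mask the output should match the unmasked pass, up to floating-point
        # tolerance (masked and unmasked attention may take different numerical paths)
        self.assertTrue(torch.allclose(output, self.text_cell(embeddings), atol=1e-5))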

    def test_forward_with_partial_mask(self):
        # Creating dummy data for embeddings and partial mask
        embeddings = torch.randn(self.batch_size, self.seq_length, self.embedding_dim)
        mask = torch.ones(self.batch_size, self.seq_length, dtype=torch.bool)
        mask[:, :self.seq_length//2] = False  # Zero out the first half of the sequence for the mask
        
        # Running the forward pass with partial mask
        output = self.text_cell(embeddings, mask)
        
        # Since the mask is True (or 1) where tokens are valid and False (or 0) where they are not,
        # The masked parts of the output should be unaffected.
        # No straightforward way to check without knowing transformer internals or ground truth values.
        # For simplicity, we'll just check the output shape here.
        self.assertEqual(output.shape, (self.batch_size, self.seq_length, self.embedding_dim))


## GramEncoder

**Premise**

* We need something to convert a sequence of embeddings into a reduced-dimensions gram embedding.
* It is easier to process gram matrices back to embeddings when the matrix dimensions are small

**Dependencies**

* embedding_dim: The embedding dimension.
* num_heads: The number of encoding heads.

**Accepts**

* Embeddings: The embeddings to process. (batch x ... x E)
* Mask: The masked embeddings. (batch x ...)

**Returns**

* Gram Embedding: The gram embedding for the situation. (batch x E)

**Design**

Basically, we make heads as in a transformer, then gram encode the heads, then recombine the results and feed them forward. This needs significantly fewer parameters than directly processing the encoding. Specifically, we:

* Mask the embeddings.
* Flatten the grid into a 1d sequence of embeddings.
* Split the embedding dimension into num_heads heads.
* Create and flatten a gram matrix for each head.
* Concatenate the heads together and run them through a linear projection back into embedding_dim.
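
The per-head gram step can be sketched on its own as follows. The function name `head_gram_encodings` is hypothetical, and the sketch omits the final linear projection. For embedding_dim = 64 and num_heads = 4, the per-head matrices hold 4 × 16² = 1024 values in total, versus 64² = 4096 for a single full gram matrix, which is where the parameter saving in the projection comes from.

```python
import torch


def head_gram_encodings(embeddings: torch.Tensor, mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    Mean gram matrix per head (hypothetical helper).
    :param embeddings: Shape (batch x L x E)
    :param mask: Shape (batch x L), nonzero means active
    :return: Shape (batch x num_heads x head_dim^2)
    """
    batch, length, dim = embeddings.shape
    head_dim = dim // num_heads
    masked = embeddings * mask.unsqueeze(-1)                                     # zero out inactive positions
    heads = masked.view(batch, length, num_heads, head_dim).permute(0, 2, 1, 3)  # (batch x heads x L x head_dim)
    # Contract over the sequence dimension l to get one (head_dim x head_dim) matrix per head
    gram = torch.einsum('bnlh,bnlk->bnhk', heads, heads)
    counts = mask.sum(dim=-1).clamp(min=1).view(batch, 1, 1, 1)                  # active positions per example
    return (gram / counts).flatten(2, -1)


embeddings = torch.randn(2, 10, 8)
mask = torch.ones(2, 10)
encodings = head_gram_encodings(embeddings, mask, num_heads=2)
print(encodings.shape)  # torch.Size([2, 2, 16])
```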



In [1]:
class GramEncoder(nn.Module):
    """
    Encodes embeddings as gram encodings.
    """
    def __init__(self,
                 embedding_dim: int,
                 num_heads: int):
        """
        :param embedding_dim: The dimension of the embeddings
        :param num_heads: The number of heads
        """
        super().__init__()

        assert embedding_dim % num_heads == 0, "embedding_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_dim = embedding_dim // num_heads
        self.encode_dim = self.head_dim ** 2

        self.combine = nn.Linear(self.encode_dim * self.num_heads, self.embedding_dim)

    @staticmethod
    def create_gram_encodings(embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Create gram encodings from embeddings. These are created from a mean.
        :param embeddings: The embeddings. (batch x num_heads x L x head_dim)
        :param mask: The mask. (batch x L)
        :return: The encodings. (batch x num_heads x head_dim^2)
        """
        # Apply mask
        embeddings = embeddings * mask.unsqueeze(1).unsqueeze(-1)

        # Compute the gram matrix by contracting over the sequence dimension L
        gram_matrix = torch.einsum('bnlh,bnlk->bnhk', embeddings, embeddings)  # Shape (batch x num_heads x head_dim x head_dim)
        encodings = gram_matrix.flatten(2, -1)  # shape (batch x num_heads x head_dim^2)

        # Normalize the gram encodings by the active entries that contributed. This takes the mean and handles variable sizes.
        counts = mask.sum(dim=-1, keepdim=True)  # (batch x 1)
        counts = counts.unsqueeze(1)  # shape: (batch x 1 x 1)
        encodings = encodings / (counts + 1e-8)  # shape (batch x num_heads x head_dim^2)

        return encodings

    def forward(self, embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Forward method to produce gram embeddings
        :param embeddings: The input embeddings. Shape (batch_size x ... x E)
        :param mask: The input mask. Shape (batch_size x ...)
        :return: The gram embedding. Shape (batch_size x E)
        """
        # Check input shapes
        assert embeddings.dim() >= 3, "Embeddings must have at least 3 dimensions (batch_size, ..., E)"
        assert mask.dim() == embeddings.dim() - 1, "Mask must have one less dimension than embeddings"

        # Flatten
        batch_size = embeddings.size(0)
        embeddings = embeddings.flatten(1, -2)  # Shape(batch_size, L, E)
        mask = mask.flatten(1, -1)  # Shape (batch x L)

        # Create heads
        heads = embeddings.view(batch_size, -1, self.num_heads, self.head_dim)  # Shape (batch x L x num_heads x head_dim)
        heads = heads.permute(0, 2, 1, 3)  # Shape (batch x num_heads x L x head_dim)

        # Process gram encodings
        encodings = self.create_gram_encodings(heads, mask)  # (batch x num_heads x head_dim^2)

        # Recombine
        encodings = encodings.flatten(1, -1)  # shape: (batch x num_heads * head_dim^2)
        embeddings = self.combine(encodings)  # shape: (batch x E)
        return embeddings


In [None]:
class TestGramEncoder(unittest.TestCase):
    def setUp(self):
        # Setting up the necessary parameters and creating an instance of GramEncoder
        self.embedding_dim = 64
        self.num_heads = 4
        self.batch_size = 2
        self.height, self.width = 8, 8
        self.encoder = GramEncoder(self.embedding_dim, self.num_heads)

    def test_gram_encoding(self):
        # Creating dummy data for embeddings and mask
        embeddings = torch.randn(self.batch_size, self.height, self.width, self.embedding_dim)
        mask = torch.ones(self.batch_size, self.height, self.width)
        
        # Running the forward pass
        gram_embedding = self.encoder(embeddings, mask)
        
        # Asserting the output shape is as expected
        self.assertEqual(gram_embedding.shape, (self.batch_size, self.embedding_dim))
        
        # Additional tests can be added here to check specific values or properties
        # Example: check if the output is not NaN
        self.assertFalse(torch.isnan(gram_embedding).any(), "Output contains NaNs")
        
        # Example: check if the output is not infinite
        self.assertFalse(torch.isinf(gram_embedding).any(), "Output contains infinite values")


## Feedforward

The feedforward layer works much like the feedforward block in a transformer. It lets the model respond differently to different encoding inputs.

**Dependencies**

* embedding_dim: The dimension of the embedding input
* feedforward_dim: The dimension of the internal feedforward channel

**Accepts**

* gram_embeddings: Embeddings. Shape (batch x embedding_dim)

**Returns**

* gram_embeddings: Embeddings. Shape (batch x embedding_dim)

**Design**

We have a linear layer, a ReLU, and a second linear layer. Not much else to say.

In [None]:
class FeedForward(nn.Module):
    """
    Feedforward layer. Works similarly to a transformer feedforward layer.
    """
    def __init__(self, embedding_dim: int, feedforward_dim: int):
        """
        Initialize the feedforward layer
        :param embedding_dim: The dimension of the embedding input.
        :param feedforward_dim: The dimension of the internal feedforward channel.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim

        # Define the feedforward network
        self.network = nn.Sequential(
            nn.Linear(embedding_dim, feedforward_dim),
            nn.ReLU(),
            nn.Linear(feedforward_dim, embedding_dim)
        )

    def forward(self, gram_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the feedforward network.
        :param gram_embeddings: Embeddings. Shape (batch x embedding_dim)
        :return: Embeddings. Shape (batch x embedding_dim)
        """
        assert gram_embeddings.dim() == 2, "gram_embeddings must have 2 dimensions (batch x embedding_dim)"
        assert gram_embeddings.shape[-1] == self.embedding_dim, f"Expected embedding dimension {self.embedding_dim}, but got {gram_embeddings.shape[-1]}"

        return self.network(gram_embeddings)

In [None]:
class TestFeedforward(unittest.TestCase):
    def setUp(self):
        # Setting up the necessary parameters and creating an instance of FeedForward
        self.embedding_dim = 64
        self.feedforward_dim = 256
        self.batch_size = 32
        self.feedforward_layer = FeedForward(self.embedding_dim, self.feedforward_dim)

    def test_forward_shape(self):
        # Creating dummy data for gram_embeddings
        gram_embeddings = torch.randn(self.batch_size, self.embedding_dim)
        
        # Running the forward pass
        output = self.feedforward_layer(gram_embeddings)
        
        # Asserting the output shape is as expected
        self.assertEqual(output.shape, (self.batch_size, self.embedding_dim))

    def test_forward_invalid_dim(self):
        # Creating dummy data with invalid number of dimensions
        gram_embeddings = torch.randn(self.batch_size, self.embedding_dim, 2)
        
        # Asserting that an AssertionError is raised for invalid input dimensions
        with self.assertRaises(AssertionError):
            self.feedforward_layer(gram_embeddings)

    def test_forward_invalid_embedding_dim(self):
        # Creating dummy data with invalid embedding dimension
        gram_embeddings = torch.randn(self.batch_size, self.embedding_dim + 1)
        
        # Asserting that an AssertionError is raised for invalid embedding dimension
        with self.assertRaises(AssertionError):
            self.feedforward_layer(gram_embeddings)

    def test_no_nan_values(self):
        # Creating dummy data for gram_embeddings
        gram_embeddings = torch.randn(self.batch_size, self.embedding_dim)
        
        # Running the forward pass
        output = self.feedforward_layer(gram_embeddings)
        
        # Asserting the output does not contain NaNs
        self.assertFalse(torch.isnan(output).any(), "Output contains NaNs")

    def test_no_inf_values(self):
        # Creating dummy data for gram_embeddings
        gram_embeddings = torch.randn(self.batch_size, self.embedding_dim)
        
        # Running the forward pass
        output = self.feedforward_layer(gram_embeddings)
        
        # Asserting the output does not contain infinite values
        self.assertFalse(torch.isinf(output).any(), "Output contains infinite values")

# Architecture Layers

The main architecture layers and their mechanisms. Everything from here onward is set up using dependency injection.

## BGCEncoderCell

An encoder cell that takes in an image embedding and produces a gram embedding. 

**Premise**

We need to be able to encode an input into a gram embedding for this architecture to work.

**Dependencies**

* encoder: The encoding stack.
* gram_encoder: The gram encoder
* feedforward: A feedforward network that is applied to the gram embedding

**Accepts**

* embeddings: A batch of embeddings laid out on a grid. Shape (batch x ... x E)
* mask: A mask for the batch elements, True means active. Shape (batch x ...)
* pos_encoding: A grid of positional encodings to inject. Shape (batch x ... x E)

**Returns**

* gram_embedding: A gram embedding. Shape (batch x E)

**Design**

The layer starts by injecting the positional encodings into the input, then processes the result using
the encoder. The encoder output is passed into the GramEncoder, and the resulting gram embedding is run through the feedforward network before being returned.


In [None]:
class BGCEncoderCell(nn.Module):
    """
    An encoder cell that takes in a collection of embeddings and produces a gram embedding.
    """
    def __init__(self, encoder: nn.Module, gram_encoder: nn.Module, feedforward: nn.Module):
        """
        Initialize the BGCEncoderCell
        :param encoder: The encoding stack.
        :param gram_encoder: The gram encoder.
        :param feedforward: A feedforward network that goes off against the gram embeddings.
        """
        super().__init__()
        self.encoder = encoder
        self.gram_encoder = gram_encoder
        self.feedforward = feedforward

    def forward(self, embeddings: torch.Tensor, mask: torch.Tensor, pos_encoding: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to produce a gram embedding from input embeddings.
        :param embeddings: A batch of positional embeddings. Shape (batch x ... x embedding_dim)
        :param mask: A mask for the batch elements, True means active. Shape (batch x ...)
        :param pos_encoding: A grid of positional encodings to inject. Shape (batch x ... x embedding_dim)
        :return: A gram embedding. Shape (batch x embedding_dim)
        """
        assert embeddings.dim() == pos_encoding.dim(), "Embeddings and positional encodings must have the same number of dimensions"
        assert embeddings.shape[-1] == pos_encoding.shape[-1], "Embeddings and positional encodings must have the same embedding dimension"
        assert embeddings.shape[:-1] == mask.shape, "Embeddings and mask must have the same batch and spatial dimensions"

        # Inject positional encodings
        embeddings = embeddings + pos_encoding

        # Process with the encoder
        encoded = self.encoder(embeddings, mask)

        # Produce the gram embedding
        gram_embedding = self.gram_encoder(encoded, mask)

        # Pass the gram embedding through the feedforward network
        gram_embedding = self.feedforward(gram_embedding)

        return gram_embedding

In [None]:
class TestBGCEncoderCell(unittest.TestCase):
    def setUp(self):
        # Setting up the necessary parameters and creating an instance of BGCEncoderCell
        self.embedding_dim = 64

        # Dummy modules for encoder, gram_encoder, and feedforward
        class DummyEncoder(nn.Module):
            def forward(self, x, mask):
                return x

        class DummyGramEncoder(nn.Module):
            def forward(self, x, mask):
                return torch.mean(x, dim=tuple(range(1, x.dim() - 1)))

        class DummyFeedforward(nn.Module):
            def forward(self, x):
                return x

        self.encoder = DummyEncoder()
        self.gram_encoder = DummyGramEncoder()
        self.feedforward = DummyFeedforward()

        self.bgc_encoder_cell = BGCEncoderCell(self.encoder, self.gram_encoder, self.feedforward)

    def test_forward_shape_1d(self):
        batch_size = 2
        seq_length = 10
        embeddings = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim)
        
        output = self.bgc_encoder_cell(embeddings, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, self.embedding_dim))

    def test_forward_shape_2d(self):
        batch_size = 2
        height, width = 8, 8
        embeddings = torch.randn(batch_size, height, width, self.embedding_dim)
        mask = torch.ones(batch_size, height, width)
        pos_encoding = torch.randn(batch_size, height, width, self.embedding_dim)
        
        output = self.bgc_encoder_cell(embeddings, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, self.embedding_dim))

    def test_forward_shape_3d(self):
        batch_size = 2
        depth, height, width = 4, 8, 8
        embeddings = torch.randn(batch_size, depth, height, width, self.embedding_dim)
        mask = torch.ones(batch_size, depth, height, width)
        pos_encoding = torch.randn(batch_size, depth, height, width, self.embedding_dim)
        
        output = self.bgc_encoder_cell(embeddings, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, self.embedding_dim))

    def test_forward_invalid_dim(self):
        batch_size = 2
        seq_length = 10
        embeddings = torch.randn(batch_size, seq_length, self.embedding_dim, 2)
        mask = torch.ones(batch_size, seq_length, 2)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim, 2)
        
        with self.assertRaises(AssertionError):
            self.bgc_encoder_cell(embeddings, mask, pos_encoding)

    def test_forward_invalid_embedding_dim(self):
        batch_size = 2
        seq_length = 10
        embeddings = torch.randn(batch_size, seq_length, self.embedding_dim + 1)
        mask = torch.ones(batch_size, seq_length)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim + 1)
        
        with self.assertRaises(AssertionError):
            self.bgc_encoder_cell(embeddings, mask, pos_encoding)

    def test_forward_with_mask(self):
        batch_size = 2
        seq_length = 10
        embeddings = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim)
        
        output = self.bgc_encoder_cell(embeddings, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, self.embedding_dim))
        
        # With an all-ones mask and deterministic dummy modules, a second call must reproduce the output exactly.
        self.assertTrue(torch.equal(output, self.bgc_encoder_cell(embeddings, mask, pos_encoding)))

    def test_forward_with_partial_mask(self):
        batch_size = 2
        seq_length = 10
        embeddings = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        mask[:, :seq_length // 2] = 0  # Zero out the first half of the sequence for the mask
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim)
        
        output = self.bgc_encoder_cell(embeddings, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, self.embedding_dim))
        
        # The mask is 1 (True) where tokens are valid and 0 (False) where they are padding.
        # How masking changes the values depends on the encoder internals, which the dummy
        # modules do not reproduce, so this test only checks the output shape.

In [None]:
def build_bgc_encoder_cell_for_text(
        num_layers: int,
        embedding_dim: int,
        transformer_heads: int,
        transformer_feedforward: int,
        dropout_prob: float,
        gram_feedforward: int,
        gram_heads: int,
        **kwargs
    ) -> BGCEncoderCell:
    """
    Build a BGCEncoderCell for text data.
    :param num_layers: Number of layers in the TextCell.
    :param embedding_dim: The dimension of the embeddings.
    :param transformer_heads: The number of transformer heads in the TextCell.
    :param transformer_feedforward: The size of the feedforward layer in the TextCell.
    :param dropout_prob: The dropout probability in the TextCell.
    :param gram_feedforward: The dimension of the internal feedforward network for the gram encoder.
    :param gram_heads: The number of heads in the GramEncoder.
    :return: An instance of BGCEncoderCell configured for text data.
    """
    encoder = TextCell(
        num_layers=num_layers,
        embedding_dim=embedding_dim,
        num_heads=transformer_heads,
        dim_feedforward=transformer_feedforward,
        dropout_prob=dropout_prob
    )
    gram_encoder = GramEncoder(
        embedding_dim=embedding_dim,
        num_heads=gram_heads
    )
    feedforward = FeedForward(
        embedding_dim=embedding_dim,
        feedforward_dim=gram_feedforward
    )
    return BGCEncoderCell(encoder, gram_encoder, feedforward)

def build_bgc_encoder_cell_for_pixel(
        num_layers: int,
        embedding_dim: int,
        kernel_size: int,
        dropout_prob: float,
        gram_feedforward: int,
        gram_heads: int,
        **kwargs
    ) -> BGCEncoderCell:
    """
    Build a BGCEncoderCell for pixel data.
    :param num_layers: Number of layers in the PixelCell.
    :param embedding_dim: The dimension of the embeddings.
    :param kernel_size: The size of the convolutional kernel in the PixelCell.
    :param dropout_prob: The dropout probability in the PixelCell.
    :param gram_feedforward: The dimension of the internal feedforward network for the gram encoder.
    :param gram_heads: The number of heads in the GramEncoder.
    :return: An instance of BGCEncoderCell configured for pixel data.
    """
    encoder = PixelCell(
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        kernel_size=kernel_size,
        dropout_prob=dropout_prob
    )
    gram_encoder = GramEncoder(
        embedding_dim=embedding_dim,
        num_heads=gram_heads
    )
    feedforward = FeedForward(
        embedding_dim=embedding_dim,
        feedforward_dim=gram_feedforward
    )
    return BGCEncoderCell(encoder, gram_encoder, feedforward)

In [None]:
class TestBGCEncoderCellBuilders(unittest.TestCase):
    def setUp(self):
        # Setting up the necessary parameters for text and pixel BGCEncoderCell builders
        self.num_layers_text = 4
        self.embedding_dim_text = 64
        self.num_heads_text = 8
        self.dim_feedforward_text = 256
        self.dropout_prob_text = 0.1
        self.feedforward_dim_text = 256
        self.gram_heads_text = 8

        self.num_layers_pixel = 3
        self.embedding_dim_pixel = 64
        self.kernel_size_pixel = 3
        self.dropout_prob_pixel = 0.5
        self.feedforward_dim_pixel = 256
        self.gram_heads_pixel = 8

    def test_build_bgc_encoder_cell_for_text(self):
        text_bgc_encoder_cell = build_bgc_encoder_cell_for_text(
            num_layers=self.num_layers_text,
            embedding_dim=self.embedding_dim_text,
            transformer_heads=self.num_heads_text,
            transformer_feedforward=self.dim_feedforward_text,
            dropout_prob=self.dropout_prob_text,
            gram_feedforward=self.feedforward_dim_text,
            gram_heads=self.gram_heads_text
        )
        self.assertIsInstance(text_bgc_encoder_cell, BGCEncoderCell)
        self.assertIsInstance(text_bgc_encoder_cell.encoder, TextCell)
        self.assertIsInstance(text_bgc_encoder_cell.gram_encoder, GramEncoder)
        self.assertIsInstance(text_bgc_encoder_cell.feedforward, FeedForward)

    def test_build_bgc_encoder_cell_for_pixel(self):
        pixel_bgc_encoder_cell = build_bgc_encoder_cell_for_pixel(
            num_layers=self.num_layers_pixel,
            embedding_dim=self.embedding_dim_pixel,
            kernel_size=self.kernel_size_pixel,
            dropout_prob=self.dropout_prob_pixel,
            gram_feedforward=self.feedforward_dim_pixel,
            gram_heads=self.gram_heads_pixel
        )
        self.assertIsInstance(pixel_bgc_encoder_cell, BGCEncoderCell)
        self.assertIsInstance(pixel_bgc_encoder_cell.encoder, PixelCell)
        self.assertIsInstance(pixel_bgc_encoder_cell.gram_encoder, GramEncoder)
        self.assertIsInstance(pixel_bgc_encoder_cell.feedforward, FeedForward)

    def test_integration_text_bgc_encoder_cell(self):
        text_bgc_encoder_cell = build_bgc_encoder_cell_for_text(
            num_layers=self.num_layers_text,
            embedding_dim=self.embedding_dim_text,
            transformer_heads=self.num_heads_text,
            transformer_feedforward=self.dim_feedforward_text,
            dropout_prob=self.dropout_prob_text,
            gram_feedforward=self.feedforward_dim_text,
            gram_heads=self.gram_heads_text
        )
        batch_size = 2
        seq_length = 10
        embeddings = torch.randn(batch_size, seq_length, self.embedding_dim_text)
        mask = torch.ones(batch_size, seq_length)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim_text)
        
        output = text_bgc_encoder_cell(embeddings, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, self.embedding_dim_text))

    def test_integration_pixel_bgc_encoder_cell(self):
        pixel_bgc_encoder_cell = build_bgc_encoder_cell_for_pixel(
            num_layers=self.num_layers_pixel,
            embedding_dim=self.embedding_dim_pixel,
            kernel_size=self.kernel_size_pixel,
            dropout_prob=self.dropout_prob_pixel,
            gram_feedforward=self.feedforward_dim_pixel,
            gram_heads=self.gram_heads_pixel
        )
        batch_size = 2
        height, width = 8, 8
        embeddings = torch.randn(batch_size, height, width, self.embedding_dim_pixel)
        mask = torch.ones(batch_size, height, width)
        pos_encoding = torch.randn(batch_size, height, width, self.embedding_dim_pixel)
        
        output = pixel_bgc_encoder_cell(embeddings, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, self.embedding_dim_pixel))

## BGCDecoderCell

A decode cell turns a gram embedding into an update that can be applied.

**Premise**

We need to be able to create updates from gram embeddings in order to encode and decode.

**Dependencies**

* `decoder`: The primary decoder mechanism, likely a stack of convolutional networks or similar.

**Accepts**

* `gram_embedding`: The gram embedding of the layer. Shape: (batch x embedding_dim).
* `mask`: The mask indicating active elements. Shape: (batch x ...).
* `pos_encodings`: The positional encodings to use. Shape: (batch x ... x embedding_dim).

**Returns**

* `update`: The update built from the gram embeddings. Shape: (batch x ... x embedding_dim).

**Design**

The forward pass proceeds as follows:

* Expand the `gram_embedding` to match the shape of `pos_encodings`.
* Add the positional encodings to the expanded gram embedding and multiply by the mask.
* Run the combined input through the decoder layer.
* Return the result as the update.
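
The first two steps rely on broadcasting, which is easy to get wrong for grids of different rank. The sketch below isolates them in a standalone helper so the shapes can be inspected directly. The helper name `build_decoder_input` and the example sizes are assumptions made for illustration, not part of the notebook.

```python
import torch

def build_decoder_input(gram_embedding, mask, pos_encodings):
    # Insert singleton axes after the batch axis until the ranks match,
    # so (batch, dim) broadcasts against (batch, ..., dim).
    while gram_embedding.dim() < pos_encodings.dim():
        gram_embedding = gram_embedding.unsqueeze(1)
    # Broadcast-add the positional encodings, then zero out inactive positions.
    return (gram_embedding + pos_encodings) * mask.unsqueeze(-1)

batch, height, width, dim = 2, 3, 4, 8
gram = torch.randn(batch, dim)
pos = torch.randn(batch, height, width, dim)
mask = torch.ones(batch, height, width)
mask[:, :, -1] = 0  # treat the last column as padding

combined = build_decoder_input(gram, mask, pos)
print(combined.shape)
```

Every active position receives the same gram embedding plus its own positional encoding, so the decoder can only tell positions apart through the positional term.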


In [None]:
class BGCDecoderCell(nn.Module):
    """
    A decode cell turns a gram embedding into an update that can be applied.
    """
    def __init__(self, decoder: nn.Module):
        """
        Initialize the BGCDecoderCell
        :param decoder: The primary decoder mechanism. Likely a stack of convolutional networks.
        """
        super().__init__()
        self.decoder = decoder

    def forward(self, gram_embedding: torch.Tensor, mask: torch.Tensor, pos_encodings: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to produce an update from gram embeddings.
        :param gram_embedding: The gram embedding of the layer. Shape (batch x embedding_dim).
        :param mask: The mask indicating active elements. Shape (batch x ...).
        :param pos_encodings: The positional encodings to use. Shape (batch x ... x embedding_dim).
        :return: The update built from the gram embeddings. Shape: (batch x ... x embedding_dim).
        """
        assert gram_embedding.dim() == 2, "Gram embedding must have 2 dimensions (batch x embedding_dim)"
        assert pos_encodings.shape[:-1] == mask.shape, "Positional encodings and mask must have the same shape except for the last dimension"
        assert pos_encodings.shape[-1] == gram_embedding.shape[-1], "Positional encodings and gram embeddings must have the same embedding dimension"

        # Expand gram_embedding to match pos_encoding shape.
        while gram_embedding.dim() < pos_encodings.dim():
            gram_embedding = gram_embedding.unsqueeze(1)

        # Take the positional encodings, add it to the expanded gram embedding, and multiply by the mask
        combined_input = (gram_embedding + pos_encodings) * mask.unsqueeze(-1)

        # Run the combined input through the decoder
        update = self.decoder(combined_input)

        return update

In [None]:

class DummyDecoder(nn.Module):
    def forward(self, x):
        return x

class TestBGCDecoderCell(unittest.TestCase):
    def setUp(self):
        # Setting up the necessary parameters and creating an instance of BGCDecoderCell
        self.decoder = DummyDecoder()
        self.bgc_decoder_cell = BGCDecoderCell(self.decoder)
        self.embedding_dim = 64

    def test_forward_shape_1d(self):
        batch_size = 2
        seq_length = 10
        gram_embedding = torch.randn(batch_size, self.embedding_dim)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        
        output = self.bgc_decoder_cell(gram_embedding, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, seq_length, self.embedding_dim))

    def test_forward_shape_2d(self):
        batch_size = 2
        height, width = 8, 8
        gram_embedding = torch.randn(batch_size, self.embedding_dim)
        pos_encoding = torch.randn(batch_size, height, width, self.embedding_dim)
        mask = torch.ones(batch_size, height, width)
        
        output = self.bgc_decoder_cell(gram_embedding, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, height, width, self.embedding_dim))

    def test_forward_shape_3d(self):
        batch_size = 2
        depth, height, width = 4, 8, 8
        gram_embedding = torch.randn(batch_size, self.embedding_dim)
        pos_encoding = torch.randn(batch_size, depth, height, width, self.embedding_dim)
        mask = torch.ones(batch_size, depth, height, width)
        
        output = self.bgc_decoder_cell(gram_embedding, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, depth, height, width, self.embedding_dim))

    def test_forward_invalid_dim(self):
        batch_size = 2
        seq_length = 10
        gram_embedding = torch.randn(batch_size, self.embedding_dim, 2)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        
        with self.assertRaises(AssertionError):
            self.bgc_decoder_cell(gram_embedding, mask, pos_encoding)

    def test_forward_invalid_embedding_dim(self):
        batch_size = 2
        seq_length = 10
        gram_embedding = torch.randn(batch_size, self.embedding_dim + 1)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        
        with self.assertRaises(AssertionError):
            self.bgc_decoder_cell(gram_embedding, mask, pos_encoding)

    def test_forward_with_mask(self):
        batch_size = 2
        seq_length = 10
        gram_embedding = torch.randn(batch_size, self.embedding_dim)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        
        output = self.bgc_decoder_cell(gram_embedding, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, seq_length, self.embedding_dim))

    def test_forward_with_partial_mask(self):
        batch_size = 2
        seq_length = 10
        gram_embedding = torch.randn(batch_size, self.embedding_dim)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        mask[:, :seq_length // 2] = 0  # Zero out the first half of the sequence for the mask
        
        output = self.bgc_decoder_cell(gram_embedding, mask, pos_encoding)
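        self.assertEqual(output.shape, (batch_size, seq_length, self.embedding_dim))
        # The dummy decoder is the identity, so masked positions must come out as zeros.
        self.assertTrue(torch.all(output[:, :seq_length // 2] == 0))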

In [None]:
def build_bgc_decoder_cell_for_text(
        num_layers: int,
        embedding_dim: int,
        transformer_heads: int,
        transformer_feedforward: int,
        dropout_prob: float,
        **kwargs
    ) -> BGCDecoderCell:
    """
    Build a BGCDecoderCell for text data.
    :param num_layers: Number of layers in the TextCell.
    :param embedding_dim: The dimension of the embeddings.
    :param transformer_heads: The number of transformer heads in the TextCell.
    :param transformer_feedforward: The size of the feedforward layer in the TextCell.
    :param dropout_prob: The dropout probability in the TextCell.
    :return: An instance of BGCDecoderCell configured for text data.
    """
    decoder = TextCell(
        num_layers=num_layers,
        embedding_dim=embedding_dim,
        num_heads=transformer_heads,
        dim_feedforward=transformer_feedforward,
        dropout_prob=dropout_prob
    )
    return BGCDecoderCell(decoder)

def build_bgc_decoder_cell_for_image(
        num_layers: int,
        embedding_dim: int,
        kernel_size: int,
        dropout_prob: float,
        **kwargs
    ) -> BGCDecoderCell:
    """
    Build a BGCDecoderCell for image data.
    :param num_layers: Number of layers in the PixelCell.
    :param embedding_dim: The dimension of the embeddings.
    :param kernel_size: The size of the convolutional kernel in the PixelCell.
    :param dropout_prob: The dropout probability in the PixelCell.
    :return: An instance of BGCDecoderCell configured for image data.
    """
    decoder = PixelCell(
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        kernel_size=kernel_size,
        dropout_prob=dropout_prob
    )
    return BGCDecoderCell(decoder)


In [None]:
class TestBGCDecoderCellBuilders(unittest.TestCase):
    def setUp(self):
        # Setting up the necessary parameters for text and image BGCDecoderCell builders
        self.num_layers_text = 4
        self.embedding_dim_text = 64
        self.transformer_heads_text = 8
        self.transformer_feedforward_text = 256
        self.dropout_prob_text = 0.1

        self.num_layers_image = 3
        self.embedding_dim_image = 64
        self.kernel_size_image = 3
        self.dropout_prob_image = 0.5

    def test_build_bgc_decoder_cell_for_text(self):
        text_bgc_decoder_cell = build_bgc_decoder_cell_for_text(
            num_layers=self.num_layers_text,
            embedding_dim=self.embedding_dim_text,
            transformer_heads=self.transformer_heads_text,
            transformer_feedforward=self.transformer_feedforward_text,
            dropout_prob=self.dropout_prob_text
        )
        self.assertIsInstance(text_bgc_decoder_cell, BGCDecoderCell)
        self.assertIsInstance(text_bgc_decoder_cell.decoder, TextCell)

    def test_build_bgc_decoder_cell_for_image(self):
        image_bgc_decoder_cell = build_bgc_decoder_cell_for_image(
            num_layers=self.num_layers_image,
            embedding_dim=self.embedding_dim_image,
            kernel_size=self.kernel_size_image,
            dropout_prob=self.dropout_prob_image
        )
        self.assertIsInstance(image_bgc_decoder_cell, BGCDecoderCell)
        self.assertIsInstance(image_bgc_decoder_cell.decoder, PixelCell)

    def test_integration_text_bgc_decoder_cell(self):
        text_bgc_decoder_cell = build_bgc_decoder_cell_for_text(
            num_layers=self.num_layers_text,
            embedding_dim=self.embedding_dim_text,
            transformer_heads=self.transformer_heads_text,
            transformer_feedforward=self.transformer_feedforward_text,
            dropout_prob=self.dropout_prob_text
        )
        batch_size = 2
        seq_length = 10
        gram_embedding = torch.randn(batch_size, self.embedding_dim_text)
        pos_encoding = torch.randn(batch_size, seq_length, self.embedding_dim_text)
        mask = torch.ones(batch_size, seq_length)
        
        output = text_bgc_decoder_cell(gram_embedding, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, seq_length, self.embedding_dim_text))

    def test_integration_image_bgc_decoder_cell(self):
        image_bgc_decoder_cell = build_bgc_decoder_cell_for_image(
            num_layers=self.num_layers_image,
            embedding_dim=self.embedding_dim_image,
            kernel_size=self.kernel_size_image,
            dropout_prob=self.dropout_prob_image
        )
        batch_size = 2
        height, width = 8, 8
        gram_embedding = torch.randn(batch_size, self.embedding_dim_image)
        pos_encoding = torch.randn(batch_size, height, width, self.embedding_dim_image)
        mask = torch.ones(batch_size, height, width)
        
        output = image_bgc_decoder_cell(gram_embedding, mask, pos_encoding)
        self.assertEqual(output.shape, (batch_size, height, width, self.embedding_dim_image))

## BGCCell

A decoder/encoder layer capable of performing the decoding or encoding action on demand.

**Premise**

* We need to specify a cell to actually do the BGC process.
* It can elegantly support bidirectionality if we subtract the update when encoding and add it back when decoding.

**Dependencies**

* `encoder_cell` [BGCEncoderCell]: Encodes an input into gram embeddings.
* `decoder_cell` [BGCDecoderCell]: Starts from a gram embedding and the shape, and produces an update.

### Method: Encode

This is the encode action for the cell.

**Accepts**

* `embeddings`: A batch of embeddings. Shape: (batch x ... x embedding_dim)
* `pos_encodings`: Positional encodings. Shape: (batch x ... x embedding_dim)
* `mask`: A mask indicating active elements. Shape: (batch x ...)

**Returns**

* `embeddings`: The output embeddings. Shape: (batch x ... x embedding_dim)
* `gram_embeddings`: The gram embedding for the layer. Shape: (batch x embedding_dim)

**Design**

Encoding proceeds in the following order:

* Use the `encoder_cell` to create a `gram_embedding` from the inputs.
* Use the `decoder_cell` to create an 'update' from the `gram_embedding`.
* Subtract the 'update' from the original embeddings.
* Return the embeddings and `gram_embedding`.

### Method: Decode

Runs a decode action for the layer. Conceptually, this is like the encode action but in reverse.

**Accepts**

* `embeddings`: The current embeddings. Shape: (batch x ... x embedding_dim)
* `pos_encodings`: Positional encodings. Shape: (batch x ... x embedding_dim)
* `gram_encoding`: The gram encoding for the layer. Shape: (batch x embedding_dim)
* `mask`: The mask for the layer. Shape: (batch x ...)

**Returns**

* `embeddings`: The output embeddings. Shape: (batch x ... x embedding_dim)

**Design**

Decoding mirrors encoding, except that the gram encoding is supplied by the caller, so only the decoder cell is used, and the update is added rather than subtracted:

* Use the `decoder_cell` to create an 'update' from the `gram_encoding`.
* Add the update to the embeddings.
* Return the resulting embeddings.
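
To see why the two methods undo each other, note that the update depends only on the gram embedding, the mask, and the positional encodings. Decoding with the gram embedding produced during encoding therefore rebuilds the same update and adds back what was subtracted. The sketch below checks this with toy stand-ins. `toy_encoder_cell` and `toy_decoder_cell` are hypothetical placeholders, not the notebook's cells.

```python
import torch

def toy_encoder_cell(embeddings, mask, pos_encodings):
    # Hypothetical stand-in: masked mean over every axis between batch and embedding.
    weights = mask.unsqueeze(-1)
    summed = (embeddings * weights).flatten(1, -2).sum(dim=1)
    count = weights.flatten(1, -2).sum(dim=1).clamp(min=1)
    return summed / count

def toy_decoder_cell(gram_embedding, mask, pos_encodings):
    # Hypothetical stand-in: broadcast the gram embedding, add positions, apply the mask.
    while gram_embedding.dim() < pos_encodings.dim():
        gram_embedding = gram_embedding.unsqueeze(1)
    return (gram_embedding + pos_encodings) * mask.unsqueeze(-1)

def encode(embeddings, pos_encodings, mask):
    gram = toy_encoder_cell(embeddings, mask, pos_encodings)
    update = toy_decoder_cell(gram, mask, pos_encodings)
    return embeddings - update, gram

def decode(embeddings, pos_encodings, gram, mask):
    return embeddings + toy_decoder_cell(gram, mask, pos_encodings)

torch.manual_seed(0)
x = torch.randn(2, 5, 16)
pos = torch.randn(2, 5, 16)
mask = torch.ones(2, 5)

residual, gram = encode(x, pos, mask)
restored = decode(residual, pos, gram, mask)
print(torch.allclose(restored, x, atol=1e-5))
```

The round trip is exact only up to floating-point error, and only when the decoder cell is deterministic. A decoder with active dropout would produce a different update on each call.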


In [None]:
class BGCCell(nn.Module):
    """
    A decoder/encoder layer capable of performing the decoding or encoding action on demand.
    """
    def __init__(self, encoder_cell: nn.Module, decoder_cell: nn.Module):
        """
        Initialize the BGCCell
        :param encoder_cell: Encodes an input into gram embeddings.
        :param decoder_cell: Starts from a gram embedding and the shape, and produces an update.
        """
        super().__init__()
        self.encoder_cell = encoder_cell
        self.decoder_cell = decoder_cell

    def encode(self, embeddings: torch.Tensor, pos_encodings: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode action for the cell.
        :param embeddings: A batch of embeddings. Shape: (batch x ... x embedding_dim)
        :param pos_encodings: Positional encodings. Shape: (batch x ... x embedding_dim)
        :param mask: A mask indicating active elements. Shape: (batch x ...)
        :return: embeddings: The output embeddings. Shape: (batch x ... x embedding_dim)
                 gram_embeddings: The gram embedding for the layer. Shape: (batch x embedding_dim)
        """
        # Use the encoder_cell to create a gram_embedding from the inputs
        gram_embedding = self.encoder_cell(embeddings, mask, pos_encodings)

        # Use the decoder_cell to create an 'update' from the gram_embedding
        update = self.decoder_cell(gram_embedding, mask, pos_encodings)

        # Subtract the 'update' from the original embeddings
        embeddings = embeddings - update

        return embeddings, gram_embedding

    def decode(self, embeddings: torch.Tensor, pos_encodings: torch.Tensor, gram_encoding: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Decode action for the cell.
        :param embeddings: The current embeddings. Shape: (batch x ... x embedding_dim)
        :param pos_encodings: Positional encodings. Shape: (batch x ... x embedding_dim)
        :param gram_encoding: The gram encoding for the layer. Shape: (batch x embedding_dim)
        :param mask: The mask for the layer. Shape: (batch x ...)
        :return: The output embeddings. Shape: (batch x ... x embedding_dim)
        """
        # Use the decoder_cell to create an 'update' from the gram_encoding
        update = self.decoder_cell(gram_encoding, mask, pos_encodings)

        # Add the update to the embeddings
        embeddings = embeddings + update

        return embeddings

In [None]:
class DummyEncoderCell(nn.Module):
    def forward(self, embeddings, mask, pos_encodings):
        # Flatten everything between the first and last dimension
        flattened_embeddings = embeddings.flatten(start_dim=1, end_dim=-2)
        # Take the mean of the middle dimension
        return torch.mean(flattened_embeddings, dim=1)  # Simplified

class DummyDecoderCell(nn.Module):
    def forward(self, gram_embedding, mask, pos_encodings):
        # Just return the pos_encodings as it has the right shape
        return pos_encodings

class TestBGCCell(unittest.TestCase):
    def setUp(self):
        self.embedding_dim = 64
        self.encoder_cell = DummyEncoderCell()
        self.decoder_cell = DummyDecoderCell()
        self.bgc_cell = BGCCell(self.encoder_cell, self.decoder_cell)

    def test_encode_shape_1d(self):
        batch_size = 2
        seq_length = 10
        embeddings = torch.randn(batch_size, seq_length, self.embedding_dim)
        pos_encodings = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        
        encoded_embeddings, gram_embedding = self.bgc_cell.encode(embeddings, pos_encodings, mask)
        self.assertEqual(encoded_embeddings.shape, (batch_size, seq_length, self.embedding_dim))
        self.assertEqual(gram_embedding.shape, (batch_size, self.embedding_dim))

    def test_encode_shape_2d(self):
        batch_size = 2
        height, width = 8, 8
        embeddings = torch.randn(batch_size, height, width, self.embedding_dim)
        pos_encodings = torch.randn(batch_size, height, width, self.embedding_dim)
        mask = torch.ones(batch_size, height, width)
        
        encoded_embeddings, gram_embedding = self.bgc_cell.encode(embeddings, pos_encodings, mask)
        self.assertEqual(encoded_embeddings.shape, (batch_size, height, width, self.embedding_dim))
        self.assertEqual(gram_embedding.shape, (batch_size, self.embedding_dim))

    def test_decode_shape_1d(self):
        batch_size = 2
        seq_length = 10
        embeddings = torch.randn(batch_size, seq_length, self.embedding_dim)
        pos_encodings = torch.randn(batch_size, seq_length, self.embedding_dim)
        gram_encoding = torch.randn(batch_size, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)
        
        decoded_embeddings = self.bgc_cell.decode(embeddings, pos_encodings, gram_encoding, mask)
        self.assertEqual(decoded_embeddings.shape, (batch_size, seq_length, self.embedding_dim))

    def test_decode_shape_2d(self):
        batch_size = 2
        height, width = 8, 8
        embeddings = torch.randn(batch_size, height, width, self.embedding_dim)
        pos_encodings = torch.randn(batch_size, height, width, self.embedding_dim)
        gram_encoding = torch.randn(batch_size, self.embedding_dim)
        mask = torch.ones(batch_size, height, width)
        
        decoded_embeddings = self.bgc_cell.decode(embeddings, pos_encodings, gram_encoding, mask)
        self.assertEqual(decoded_embeddings.shape, (batch_size, height, width, self.embedding_dim))

    def test_encode_decode_consistency_1d(self):
        batch_size = 2
        seq_length = 10
        embeddings = torch.randn(batch_size, seq_length, self.embedding_dim)
        pos_encodings = torch.randn(batch_size, seq_length, self.embedding_dim)
        mask = torch.ones(batch_size, seq_length)

        encoded_embeddings, gram_embedding = self.bgc_cell.encode(embeddings, pos_encodings, mask)
        decoded_embeddings = self.bgc_cell.decode(encoded_embeddings, pos_encodings, gram_embedding, mask)
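        self.assertEqual(decoded_embeddings.shape, (batch_size, seq_length, self.embedding_dim))
        # Decoding with the gram embedding from the encode step should undo that step.
        self.assertTrue(torch.allclose(decoded_embeddings, embeddings, atol=1e-5))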

    def test_encode_decode_consistency_2d(self):
        batch_size = 2
        height, width = 8, 8
        embeddings = torch.randn(batch_size, height, width, self.embedding_dim)
        pos_encodings = torch.randn(batch_size, height, width, self.embedding_dim)
        mask = torch.ones(batch_size, height, width)

        encoded_embeddings, gram_embedding = self.bgc_cell.encode(embeddings, pos_encodings, mask)
        decoded_embeddings = self.bgc_cell.decode(encoded_embeddings, pos_encodings, gram_embedding, mask)
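        self.assertEqual(decoded_embeddings.shape, (batch_size, height, width, self.embedding_dim))
        # Decoding with the gram embedding from the encode step should undo that step.
        self.assertTrue(torch.allclose(decoded_embeddings, embeddings, atol=1e-5))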


In [None]:
def build_bgc_cell_for_text(
        num_layers: int,
        embedding_dim: int,
        num_heads: int,
        dim_feedforward: int,
        dropout_prob: float,
        gram_heads: int,
        **kwargs
    ) -> BGCCell:
    """
    Build a BGCCell for text data.
    :param num_layers: Number of layers in the TextCell.
    :param embedding_dim: The dimension of the embeddings.
    :param num_heads: The number of transformer heads in the TextCell.
    :param dim_feedforward: The size of the feedforward layer in the TextCell.
    :param dropout_prob: The dropout probability in the TextCell.
    :param gram_heads: The number of heads for the GramEncoder.
    :return: An instance of BGCCell configured for text data.
    """
    # Build encoder
    encoder = TextCell(
        num_layers=num_layers,
        embedding_dim=embedding_dim,
        num_heads=num_heads,
        dim_feedforward=dim_feedforward,
        dropout_prob=dropout_prob
    )
    gram_encoder = GramEncoder(
        embedding_dim=embedding_dim,
        num_heads=gram_heads
    )
    feedforward = FeedForward(
        embedding_dim=embedding_dim,
        feedforward_dim=dim_feedforward
    )
    encoder_cell = BGCEncoderCell(encoder, gram_encoder, feedforward)

    # Build decoder
    decoder = TextCell(
        num_layers=num_layers,
        embedding_dim=embedding_dim,
        num_heads=num_heads,
        dim_feedforward=dim_feedforward,
        dropout_prob=dropout_prob
    )
    decoder_cell = BGCDecoderCell(decoder)

    # Return BGCCell
    return BGCCell(encoder_cell, decoder_cell)

def build_bgc_cell_for_image(
        num_layers: int,
        embedding_dim: int,
        kernel_size: int,
        dropout_prob: float,
        gram_heads: int,
        dim_feedforward: int,
        **kwargs
    ) -> BGCCell:
    """
    Build a BGCCell for image data.
    :param num_layers: Number of layers in the PixelCell.
    :param embedding_dim: The dimension of the embeddings.
    :param kernel_size: The size of the convolutional kernel in the PixelCell.
    :param dropout_prob: The dropout probability in the PixelCell.
    :param gram_heads: The number of heads for the GramEncoder.
    :param dim_feedforward: The size of the feedforward layer applied to the gram embedding.
    :return: An instance of BGCCell configured for image data.
    """
    # Build encoder
    encoder = PixelCell(
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        kernel_size=kernel_size,
        dropout_prob=dropout_prob
    )
    gram_encoder = GramEncoder(
        embedding_dim=embedding_dim,
        num_heads=gram_heads
    )
    feedforward = FeedForward(
        embedding_dim=embedding_dim,
        feedforward_dim=dim_feedforward
    )
    encoder_cell = BGCEncoderCell(encoder, gram_encoder, feedforward)

    # Build decoder
    decoder = PixelCell(
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        kernel_size=kernel_size,
        dropout_prob=dropout_prob
    )
    decoder_cell = BGCDecoderCell(decoder)

    # Return BGCCell
    return BGCCell(encoder_cell, decoder_cell)

## BidirectionalGramConverter

The main converter model. It can convert an image grid into latent space, or convert from latent space back into an image grid. It has an encode mode and a decode mode.

**Premise**

* Addition and subtraction are inverses, so a merge applied while encoding can be undone by the opposite merge while decoding.
* Gram encodings make good latent representations for this task.
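
A minimal sketch of the first premise, using plain Python lists in place of tensors. `make_update`, `encode_merge`, and `decode_merge` are hypothetical names used only for illustration, and the choice of subtracting on encode and adding on decode is an assumption. The point is only that the update is rebuilt from the gram encoding alone, so the opposite merge restores the input.

```python
from typing import List

Vector = List[float]


def make_update(gram: Vector, length: int) -> Vector:
    """Hypothetical stand-in for the latent decode stage.

    Expands a gram encoding into an update of the requested length.
    It reads only the gram encoding, never the embedding being updated.
    """
    return [gram[i % len(gram)] * (i + 1) for i in range(length)]


def encode_merge(x: Vector, gram: Vector) -> Vector:
    """Assumed encode direction: subtract the update."""
    update = make_update(gram, len(x))
    return [xi - ui for xi, ui in zip(x, update)]


def decode_merge(y: Vector, gram: Vector) -> Vector:
    """Assumed decode direction: rebuild the same update and add it back."""
    update = make_update(gram, len(y))
    return [yi + ui for yi, ui in zip(y, update)]


x = [1.0, 2.0, 3.0, 4.0]
gram = [0.5, -0.25]
encoded = encode_merge(x, gram)
restored = decode_merge(encoded, gram)
```

With real floating-point tensors the round trip holds only up to rounding error; the values here were chosen so that it is exact.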

**Dependencies**

*primary*

* embeddings: An nn.Embedding layer. Embeds the int grid into a grid of embeddings.
* logits: An nn.Module. Converts an embedding into logits that can be interpreted.
* cells: A List[BGCCell]. The sequence of BGCCells that perform the actual encoding and decoding.
* pos_encoding_2d: A 2D positional encoding mechanism that provides encodings with dimensionality embed_dim/2.

*config*

* return_internal_embeddings [bool, default: False]: Whether to also return the intermediate embeddings as a tensor.

### Method: compute_pos_encoding

A helper method. Precomputes positional encodings that carry both the position of each
element and the shape of the grid.

**Accepts**

* mask: The mask for the case. (batch x N x M)

**Returns** 

* pos_encoding_2d: A specialized 2D positional encoding that encodes x position, y position, and grid shape. (batch x N x M x E)

**Design**

We operate as follows.

* Determine the size of the grid from the mask. This gives an int for x and an int for y.
* Take 2D positional encodings large enough for the input from the precomputed grid. These are only half as long as the final encoding needs to be.
* Extract the element at (x_size, y_size) from the precomputed grid and concatenate it onto the positional encodings we are building.
* Return the encodings.

As a result, one part of each encoding encodes grid position and the other part encodes grid size.
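
The steps above can be sketched as follows for a single case, with nested lists in place of tensors. This is an illustration under several assumptions, not the actual implementation: the sinusoidal scheme, the helper names `sinusoid` and `build_table`, the treatment of rows and columns as the two grid axes, and the requirement that active elements form a rectangle anchored at the top-left corner are all hypothetical.

```python
import math
from typing import List


def sinusoid(pos: int, dim: int) -> List[float]:
    """Sinusoidal encoding of one integer position; dim must be even."""
    out: List[float] = []
    for i in range(0, dim, 2):
        freq = 1.0 / (10000 ** (i / dim))
        out += [math.sin(pos * freq), math.cos(pos * freq)]
    return out


def build_table(max_rows: int, max_cols: int, half_dim: int):
    """Precompute a (max_rows x max_cols x half_dim) table.

    The first half of each entry encodes the row, the second half the column.
    """
    quarter = half_dim // 2
    return [
        [sinusoid(r, quarter) + sinusoid(c, quarter) for c in range(max_cols)]
        for r in range(max_rows)
    ]


def compute_pos_encoding(mask, table):
    """mask: N x M booleans for one case; returns N x M x (2 * half_dim)."""
    # Grid size from the mask, assuming a top-left-anchored active rectangle.
    n_rows = sum(1 for row in mask if any(row))
    n_cols = max(sum(row) for row in mask)
    # The table entry at (n_rows, n_cols) acts as the grid-size code,
    # so the table must be larger than the grid in both directions.
    size_code = table[n_rows][n_cols]
    # Position half followed by the shared size half.
    return [
        [table[r][c] + size_code for c in range(len(mask[0]))]
        for r in range(len(mask))
    ]


embed_dim = 16
table = build_table(max_rows=8, max_cols=8, half_dim=embed_dim // 2)
mask = [
    [True, True, True, False],
    [True, True, True, False],
    [False, False, False, False],
]
encoding = compute_pos_encoding(mask, table)
```

Every position shares the same size half, so a cell can read the grid shape from any single element.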

### Method: encode

Encodes the input grid into a latent representation: a sequence of gram embeddings.

**Accepts**

* input: A grid of integers. (batch x N x M)
* mask: A mask indicating which items are active. (batch x N x M)

**Returns**

If not return_internal_embeddings:

* embeddings: The output embeddings after processing. (batch x N x M x E)
* gram_embeddings: The stack of gram embeddings. (batch x L x E)

Else:

* embeddings: The output embeddings after processing. (batch x N x M x E)
* gram_embeddings: The stack of gram embeddings. (batch x L x E)
* internal_embeddings: The embeddings produced at each intermediate layer. (batch x L x N x M x E)

**Design**

We embed, build positional encodings, and then run the cells' encode methods in order.

* Embed: Run the input through the embedding layer. Mask it.
* Build encodings: Build the positional encodings.
* Run layers: For each BGCCell, moving forward:
    * Call .encode
    * Store the resulting gram embedding
    * Update the current embedding
    * If relevant, store the intermediate embedding
* Stack the gram embeddings and, if relevant, the intermediate embeddings.
* Return the results.
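
A dependency-free sketch of this loop, assuming a hypothetical cell interface `cell.encode(embedding, pos_encoding, mask) -> (embedding, gram)`. `ToyCell` is a stand-in written for this example, not the real BGCCell: each embedding is a single float, the "gram embedding" is one number, and a list lookup plays the role of the nn.Embedding layer.

```python
from typing import List


class ToyCell:
    """Hypothetical stand-in for a BGCCell; embeddings are single floats."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def encode(self, embedding, pos_encoding, mask):
        # Summary stage: reduce the active positions to one "gram" number.
        active = [e + p for e, p, m in zip(embedding, pos_encoding, mask) if m]
        gram = sum(active) / max(len(active), 1)
        # Merge stage (assumed direction): subtract an update built from the gram.
        update = self.scale * gram
        updated = [e - update if m else 0.0 for e, m in zip(embedding, mask)]
        return updated, gram


def encode(cells, tokens, mask, table, pos_encoding,
           return_internal_embeddings: bool = False):
    # Embed: look each token up, then zero out the masked positions.
    embedding = [table[t] if m else 0.0 for t, m in zip(tokens, mask)]
    grams: List[float] = []
    internals: List[List[float]] = []
    # Run layers: walk the cells in forward order.
    for cell in cells:
        embedding, gram = cell.encode(embedding, pos_encoding, mask)
        grams.append(gram)
        if return_internal_embeddings:
            internals.append(list(embedding))
    if return_internal_embeddings:
        return embedding, grams, internals
    return embedding, grams


cells = [ToyCell(0.5), ToyCell(0.25)]
table = [0.0, 1.0, 2.0, 3.0]          # toy embedding table: token id -> float
tokens = [1, 2, 3, 0]
mask = [True, True, True, False]
pos_encoding = [0.0, 1.0, 2.0, 3.0]   # placeholder positional values
final, grams, internals = encode(
    cells, tokens, mask, table, pos_encoding, return_internal_embeddings=True
)
```

One gram embedding is stored per cell, so the length of `grams` corresponds to the L dimension in the shapes listed above.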

### Method: decode

Decodes a latent representation based on gram embeddings back into the original content by running the layers in reverse.

**Accepts**

* embeddings: A grid of pixel embeddings. (batch x N x M x E). May be filled with zeros.
* mask: An activity mask. (batch x N x M). True means include.
* gram_embeddings: The gram embeddings. (batch x L x E)

**Returns**

If not return_internal_embeddings:

* embeddings: The decoded embeddings.
* logits: A logit projection for each embedding.

Else:

* embeddings: The decoded embeddings.
* logits: A logit projection for each embedding.
* internal_embeddings: The embedding stack from the various layers, drawn from the same places as in the encoding process.

**Design**

We essentially run the encode process in reverse.

Start from the embeddings and the gram_embeddings. Iterate over the cells and their gram embeddings in reverse order, calling each cell's decode method with its own gram embedding. Then project the final embedding into logits and return.
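
A dependency-free sketch of the reverse pass, assuming a hypothetical cell interface `cell.decode(embedding, gram, pos_encoding, mask) -> embedding`. `ToyCell` and the distance-based logit projection are stand-ins written for this example, embeddings are single floats, and the return_internal_embeddings branch is omitted for brevity.

```python
from typing import List


class ToyCell:
    """Hypothetical stand-in for a BGCCell; embeddings are single floats."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def decode(self, embedding, gram, pos_encoding, mask):
        # Rebuild the update from the supplied gram, then merge it back in.
        # Adding on decode is an assumed direction for this sketch.
        update = self.scale * gram
        return [e + update if m else 0.0 for e, m in zip(embedding, mask)]


def decode(cells, embedding, grams, pos_encoding, mask, table):
    # Walk cells and gram embeddings together, last cell first, so each
    # cell receives the gram embedding that belongs to it.
    for cell, gram in zip(reversed(cells), reversed(grams)):
        embedding = cell.decode(embedding, gram, pos_encoding, mask)
    # Toy logit projection: closer to a table entry means a higher logit.
    logits = [[-abs(e - t) for t in table] for e in embedding]
    return embedding, logits


cells = [ToyCell(0.5), ToyCell(0.25)]
table = [0.0, 1.0, 2.0, 3.0]            # toy embedding table: token id -> float
mask = [True, True, True, False]
pos_encoding = [0.0, 1.0, 2.0, 3.0]     # unused by the toy cell, kept for the interface
grams = [3.0, 1.5]                      # one gram per cell, in encode order
latent = [-0.875, 0.125, 1.125, 0.0]    # hand-picked encoder output for this example

decoded, logits = decode(cells, latent, grams, pos_encoding, mask, table)
tokens = [row.index(max(row)) for row in logits]
```

Because the two toy cells use different scales, pairing a cell with the wrong gram embedding would not restore the original values, which is why the cells and the gram stack are reversed together.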