In [2]:
from src import engine, data_setup, models, utils

In [3]:
# Presumably picks the compute device used by the rest of the notebook; the
# cell output reports a connected GPU. NOTE(review): confirm the return type
# and any side effects in src/engine.
device = engine.initialize()

[CONNECTED] NVIDIA GeForce RTX 3070


In [13]:
import torch
from torch import nn

class PatchEmbed(nn.Module):
    """Turn an image into a sequence of embedded patches.

    The image is cut into non-overlapping square patches and every patch is
    projected to a vector of size ``embed_dim``. Both steps are carried out by
    a single convolution whose kernel size and stride equal the patch size.

    Parameters
    ----------
    img_size : int
        Side length of the (square) input image.

    patch_size : int
        Side length of each (square) patch.

    in_channels : int
        Number of channels of the input image.

    embed_dim : int
        Size of the embedding vector produced for every patch.

    Attributes
    ----------
    n_patches : int
        Number of patches per image, ``(img_size // patch_size) ** 2``.

    proj : nn.Conv2d
        Convolution that performs both the patch split and the embedding.
    """

    def __init__(self, img_size, patch_size, in_channels=3, embed_dim=768):
        super().__init__()

        # Keep the constructor arguments around for later inspection.
        self.img_size, self.patch_size = img_size, patch_size
        self.in_channels, self.embed_dim = in_channels, embed_dim

        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        # kernel_size == stride, so each kernel application sees exactly one
        # patch and neighbouring patches never overlap.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed every patch of a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_channels, img_size, img_size)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches, embed_dim)`.
        """
        # (n_samples, embed_dim, n_patches**.5, n_patches**.5)
        patch_grid = self.proj(x)
        # Collapse the spatial grid: (n_samples, embed_dim, n_patches)
        patch_seq = patch_grid.flatten(start_dim=2)
        # Put the patch axis before the embedding axis:
        # (n_samples, n_patches, embed_dim)
        return patch_seq.transpose(1, 2)
    
class Attention(nn.Module):

    """Attention mechanism (stub: only this docstring exists so far).

    The parameters below describe the intended constructor arguments. The
    class does not define ``__init__`` or ``forward`` yet, so none of them
    are accepted at this point.

    Parameters
    ----------
    dim : int
        Embedding dimension: the input and output dimension of per-token features.

    n_heads : int
        Number of attention heads.

    qkv_bias : bool
        If True, include a bias in the query, key and value projections.

    attn_p : float
        Dropout probability applied to the query, key and value tensors.

    proj_p : float
        Dropout probability applied to the output tensor.
    """



In [12]:
# Scratch check of the flatten step: flattening from dim 2 onward merges the
# two trailing 5x5 dims into a single dim of size 25 (see the cell output).
x = torch.randn(10, 700, 5, 5)
x.flatten(start_dim=2).shape

torch.Size([10, 700, 25])