In [51]:
import torch
import torch.nn as nn
from torch.optim import Adam

In [50]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
    
    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # Split embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)  # (N, value_len, heads, head_dim)
        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
        queries = self.queries(queries)  # (N, query_len, heads, head_dim)
        
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])  # Good source on Einsum: https://www.youtube.com/watch?v=pkVwUVEHmfI
        # Queries shape: (N, query_len, heads, head_dim)
        # Keys shape: (N, key_len, heads, head_dim)
        # Energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim)
        # Attention shape: (N, heads, query_len, key_len)
        # Values shape: (N, value_len, heads, head_dim)\
        # Out shape : (N, query_len, heads, head_dim), then flatten last 2 dims
        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,  # NOTE: Continuous rather than fixed vocab size
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)  # NOTE: Retrieve various embeddings from embedder classes
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))  # NOTE: Convert to multimodal embedding
        for layer in self.layers:
            out = layer(out, out, out, mask)
        return out


class DecoderBlock(nn.Module):
    def __init__(self,
                 embed_size,
                 heads,
                 forward_expansion,
                 dropout,
                 device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out


class Decoder(nn.Module):
    def __init__(self,
                 trg_vocab_size,  # NOTE: Continuous rather than fixed vocab size
                 embed_size,
                 num_layers,
                 heads,
                 forward_expansion,
                 dropout,
                 device,
                 max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)  # NOTE: Convert to multimodal embedding
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    embed_size,
                    heads,
                    forward_expansion,
                    dropout,
                    device
                )
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)  # NOTE: Convert to multimodal embedding
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))
        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)
        out = self.fc_out(x)
        return out


class Transformer(nn.Module):
    def __init__(self,
                 src_vocab_size,  # NOTE: Continuous rather than fixed vocab size
                 trg_vocab_size,  # NOTE: Continuous rather than fixed vocab size
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size=256,
                 num_layers=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device="cuda",
                 max_length=300):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,  # NOTE: Continuous rather than fixed vocab size
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,   # NOTE: Continuous rather than fixed vocab size
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
    
    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # src_mask shape: (N, 1, 1, src_len)
        return src_mask.to(self.device)
    
    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        # trg_mask shape: (N, 1, trg_len, trg_len)
        return trg_mask.to(self.device)
    
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

In [66]:
class ETPatchEmbed(nn.Module):
    """
    Takes a 1D convolution of a patch of ET data to embed and returns the embedded data.
    The out.permute() is done to make the data compatible with the 1D convolution over batches.
    If OSIE is True, the oracle is used to generate a source and target sequence.
    """
    def __init__(self,
                 sample_height=300,
                 in_channels=3,
                 patch_size=15,
                 embed_dim=768,
                 kernel_size=15,
                 stride=15,
                 oracle=None,
                 OSIE=False):
        super().__init__()
        # These three lines are to make the patch embedding compatible with ETTransformer
        # -----
        sample_size = tuple((sample_height, in_channels))
        patch_size = tuple((patch_size, in_channels))
        num_patches = (sample_size[1] // patch_size[1] * sample_size[0] // patch_size[0])
        # -----
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.norm = nn.BatchNorm1d(sample_height)
        self.projection = nn.Conv1d(in_channels=sample_height,
                                    out_channels=embed_dim,
                                    kernel_size=kernel_size,
                                    stride=stride)
        # self.OSIE = OSIE
        # self.oracle = oracle

    def forward(self, x):
        # if self.OSIE:
        #     x, y = self.oracle.generate_batch(x)
        #     x = self.norm(x)
        #     x = self.projection(x.permute(0, 2, 1)).transpose(1, 2)
        #     return x, y
        # else:
        x = self.norm(x.permute(0, 2, 1))
        x = self.projection(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class ImagePatchEmbed(nn.Module):
    """
    Takes a 2D convolution of a patch of an image data to embed in a lower dimensional space.
    """
    def __init__(self,
                 img_width=800,
                 img_height=600,
                 patch_size=25,
                 in_channels=3,
                 embed_dim=768):
        super().__init__()
        # self.img_size = img_size
        self.patch_size = patch_size
        # self.num_patches = (self.img_size // self.patch_size) ** 2
        self.projection = nn.Conv2d(in_channels=in_channels,
                                    out_channels=embed_dim,
                                    kernel_size=patch_size,
                                    stride=patch_size)
    def forward(self, x):
        x = self.projection(x)
        x = x.flatten(2).transpose(1, 2)
        # print(x.shape)
        return x


class SemanticEmbedding(nn.Module):
    """
    Takes semantic tags per pixel of an image and embeds the information in the same space as ImagePatchEmbed.
    """
    def __init__(self,
                 img_width=800,
                 img_height=600,
                 patch_size=25,
                 in_channels=12,
                 embed_dim=768):
        super().__init__()
        # self.img_size = img_size
        self.patch_size = patch_size
        # self.num_patches = (self.img_size // self.patch_size) ** 2
        self.projection = nn.Conv2d(in_channels=in_channels,
                                    out_channels=embed_dim,
                                    kernel_size=patch_size,
                                    stride=patch_size)
    def forward(self, x):
        x = self.projection(x)
        x = x.flatten(2).transpose(1, 2)
        # print(x.shape)
        return x


class MultimodalBottleneckTransformer(nn.Module):
    def __init__(self,
                 num_layers,
                 num_heads,
                 et_embed_dim,
                 img_embed_dim,
                 sem_embed_dim,
                 et_patch_size,
                 img_patch_size,
                 sem_patch_size,
                 et_sample_height,
                 img_height,
                 img_width):
        super(MultimodalBottleneckTransformer, self).__init__()
        self.et_embed = ETPatchEmbed(sample_height=et_sample_height,
                                      in_channels=3,
                                      patch_size=et_patch_size,
                                      embed_dim=et_embed_dim,
                                      kernel_size=et_patch_size,
                                      stride=et_patch_size)
        self.img_embed = ImagePatchEmbed(img_width=img_width,
                                         img_height=img_height,
                                         patch_size=img_patch_size,
                                         in_channels=3,
                                         embed_dim=img_embed_dim)
        self.sem_embed = SemanticEmbedding(img_width=img_width,
                                           img_height=img_height,
                                           patch_size=sem_patch_size,
                                           in_channels=12,
                                           embed_dim=sem_embed_dim)

        self.transformer = nn.Transformer(
            d_model=et_embed_dim + img_embed_dim + sem_embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=4 * (et_embed_dim + img_embed_dim + sem_embed_dim)
        )

    def forward(self, et_data, img_data, sem_data):
        et_emb = self.et_embed(et_data)
        img_emb = self.img_embed(img_data)
        sem_emb = self.sem_embed(sem_data)

        embeddings = torch.cat([et_emb, img_emb, sem_emb], dim=-1)

        transformer_out = self.transformer(embeddings)
        return transformer_out


# Instantiate the model
model = MultimodalBottleneckTransformer(num_layers=6,
                                        num_heads=8,
                                        et_embed_dim=768,
                                        img_embed_dim=768,
                                        sem_embed_dim=768,
                                        et_patch_size=15,
                                        img_patch_size=25,
                                        sem_patch_size=25,
                                        et_sample_height=300,
                                        img_height=600,
                                        img_width=800)

# Initialize input data
et_data = torch.randn(8, 3, 300)
img_data = torch.randn(8, 3, 600, 800)
sem_data = torch.randn(8, 12, 600, 800)

# Forward pass
output = model(et_data, img_data, sem_data)
print(output.shape)


RuntimeError: Given groups=1, weight of size [768, 300, 15], expected input[8, 3, 300] to have 300 channels, but got 3 channels instead