In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch import optim

# Image Captioning Based on the Transformer
* Encoder: Vision Transformer (ViT)
* Decoder: the decoder of the original Transformer ("Attention Is All You Need")

#### Architecture: ViT encoder (left) and Transformer decoder (right)
<div style="display: flex;">
  <img src="Img/ViT.png" alt="ViT" style="width: 80%;">
  <img src="Img/Decoder.png" alt="Decoder" style="width: 20%;">
</div>

### Patch Embedding
```
    Split the original image into n non-overlapping patches (tokens)
    Project each patch to a 1D embedding vector of size embed_dim
```

In [2]:
class PatchEmbedding(nn.Module):
    """Split a square image into non-overlapping patches and embed each patch.

    A Conv2d whose kernel size and stride both equal ``patch_size`` cuts the
    image into ``patch_size x patch_size`` patches and applies the same linear
    projection to every patch.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of image channels (3 for RGB).
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_channels = 3, embed_dim = 512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # Pixels that do not fill a whole patch are dropped by the strided,
        # unpadded convolution, hence floor division.
        self.num_patches = (img_size // patch_size) ** 2
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, img):
        """Embed ``img`` of shape (batch_size, in_channels, img_size, img_size).

        Returns:
            Tensor of shape (batch_size, num_patches, embed_dim).
        """
        img = self.conv(img)            # (batch_size, embed_dim, img_size // patch_size, img_size // patch_size)
        img = img.flatten(2)            # (batch_size, embed_dim, num_patches)
        img = img.transpose(1, 2)       # (batch_size, num_patches, embed_dim)
        return img

In [3]:
# Test PatchEmbedding: 96x96 images with 16x16 patches give (96 // 16) ** 2 = 36 tokens.
X = torch.ones((4, 3, 96, 96))
out = PatchEmbedding(96, 16, 3, 512)(X)
assert out.shape == (4, 36, 512)
out.shape

torch.Size([4, 36, 512])

### Encoder Block
```
    Sublayer 1 (pre-norm):
        Layer Norm
        Multi-Head Attention
        Residual connection
    Sublayer 2 (pre-norm):
        Layer Norm
        MLP (with dropout)
        Residual connection
```

In [4]:
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block (as used in ViT).

        X = X + MultiHeadAttention(LayerNorm(X))
        X = X + MLP(LayerNorm(X))

    Args:
        embed_dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        mlp_hidden_size: Width of the MLP hidden layer.
        dropout: Dropout probability for the attention layer and the MLP.
    """

    def __init__(self, embed_dim = 512, num_heads = 8, mlp_hidden_size = 1024, dropout = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # Forward the caller's num_heads instead of a fixed head count.
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True, dropout=dropout)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(mlp_hidden_size, embed_dim),
            nn.Dropout(p=dropout)
        )

    def forward(self, X):
        """Apply the block to X of shape (batch_size, num_patches, embed_dim); the shape is preserved."""
        normed_X = self.norm1(X)
        # MultiheadAttention returns (output, weights); with need_weights=False the
        # weights are not computed, and [0] selects the attended output.
        X = X + self.attention(normed_X, normed_X, normed_X, need_weights=False)[0]

        normed_X = self.norm2(X)
        return X + self.mlp(normed_X)

In [5]:
# test EncoderBlock: the block must preserve the (batch_size, num_patches, embed_dim) shape
X = torch.ones((4, 36, 512))
encoder = EncoderBlock()

out = encoder(X)
assert out.shape == X.shape
out.shape

torch.Size([4, 36, 512])

### Vision Transformer
```
    Patch Embedding
    Position Embedding (learned, added to every patch token) + dropout
    Encoder blocks (applied one after another as a stack)
```

In [None]:
class ViT(nn.Module):
    """Vision Transformer encoder: patch embedding, learned position embedding, encoder stack.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        embed_dim: Token embedding size.
        embed_dropout: Dropout applied after adding the position embedding.
        num_blocks: Number of stacked EncoderBlocks (each with its own weights).
        num_heads: Attention heads per block.
        mlp_hidden_size: Hidden width of each block's MLP.
        mlp_dropout: Dropout probability inside each block.
    """

    def __init__(self, img_size, patch_size,
                 embed_dim=512, embed_dropout=0.1,
                 num_blocks=2, num_heads=8, mlp_hidden_size=1024, mlp_dropout=0.1):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, 3, embed_dim)

        # One learned vector per patch position; the leading dim of 1 broadcasts
        # over the batch, so the same embedding is used for every sample.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.patch_embedding.num_patches, embed_dim))
        self.embed_dropout = nn.Dropout(embed_dropout)

        # Build a *new* EncoderBlock per layer. `[block] * n` would repeat a single
        # module instance, making every layer share the same weights.
        self.encoder_blocks = nn.Sequential(*[
            EncoderBlock(embed_dim, num_heads, mlp_hidden_size, mlp_dropout)
            for _ in range(num_blocks)
        ])

    def forward(self, X):
        """Encode images X of shape (batch_size, 3, img_size, img_size).

        Returns:
            Tensor of shape (batch_size, num_patches, embed_dim).
        """
        X = self.patch_embedding(X)         # (batch_size, num_patches, embed_dim)
        # Broadcast add instead of an in-place per-sample loop.
        X = self.embed_dropout(X + self.pos_embedding)
        return self.encoder_blocks(X)

In [7]:
# test ViT: 96x96 images, 16x16 patches, embed_dim=512 -> (4, 36, 512)
X = torch.ones((4, 3, 96, 96))
model = ViT(96, 16, 512)
out = model(X)
assert out.shape == (4, 36, 512)
out.shape

torch.Size([4, 36, 512])

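### Decoder Block
Masked self-attention, encoder–decoder attention and an MLP, each followed by a residual connection and Layer Norm (post-norm, as in the original Transformer).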

In [8]:
# Causal mask for a length-3 sequence: True (strictly above the diagonal)
# marks the later positions that a query is not allowed to attend to.
mask = torch.ones(3, 3).triu(diagonal=1).bool()
mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [9]:
class DecoderBlock(nn.Module):
    """Post-norm Transformer decoder block, as in the original Transformer.

        X = LayerNorm(X + MaskedSelfAttention(X))
        X = LayerNorm(X + CrossAttention(X, encoder_output))
        X = LayerNorm(X + MLP(X))

    Args:
        embed_dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability for both attention layers and the MLP.
        mlp_hidden_size: Hidden width of the MLP.
    """

    def __init__(self, embed_dim = 512, num_heads = 8, dropout = 0.1, mlp_hidden_size = 1024):
        super().__init__()
        # dropout is applied inside the attention layers too, matching EncoderBlock.
        self.mask_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.lNorm1 = nn.LayerNorm(embed_dim)

        self.e_d_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.lNorm2 = nn.LayerNorm(embed_dim)

        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_size, embed_dim),
            nn.Dropout(dropout)
        )
        self.lNorm3 = nn.LayerNorm(embed_dim)

    def forward(self, X, encoder_output):
        """Run one decoder block.

        Args:
            X: (batch_size, num_steps, embed_dim) decoder-side embeddings.
            encoder_output: (batch_size, num_patches, embed_dim) sequence attended to
                by the cross-attention layer.

        Returns:
            Tensor of shape (batch_size, num_steps, embed_dim).
        """
        num_steps = X.shape[1]

        # Causal mask: True entries (strictly above the diagonal) are blocked, so
        # step i attends only to steps 0..i. Built on X's device so the block
        # also runs when the model is on a GPU.
        mask = torch.triu(torch.ones(num_steps, num_steps, device=X.device), diagonal=1).bool()
        self_attended, _ = self.mask_attention(X, X, X, attn_mask=mask, need_weights=False)
        X = self.lNorm1(X + self_attended)

        # Encoder-decoder (cross) attention: queries from the decoder,
        # keys/values from the encoder output.
        cross_attended, _ = self.e_d_attention(X, encoder_output, encoder_output, need_weights=False)
        X = self.lNorm2(X + cross_attended)

        # Position-wise MLP.
        return self.lNorm3(X + self.mlp(X))

In [10]:
# test decoder block
X = torch.ones((4, 36, 512))
encoder_output = torch.ones((4, 36, 512))
decoder = DecoderBlock()
out = decoder(X, encoder_output)
# The decoder block preserves the shape of its decoder-side input.
assert out.shape == X.shape
out.shape

torch.Size([4, 36, 512])

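### Decoder
The decoder embeds the caption tokens, adds sinusoidal positional encodings, passes the result through a stack of decoder blocks, and projects each step to vocabulary logits.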

### Positional Encoding
$$PE_{(pos,2i)} = \sin(pos / 10000^{2i/d_{\text{model}}})$$

$$PE_{(pos,2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}})$$

In [11]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence of embeddings.

        PE[pos, 2i]   = sin(pos / 10000^(2i / embed_dim))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / embed_dim))

    Args:
        embed_dim: Embedding size of the inputs (odd sizes are supported).
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, embed_dim, max_len = 5000):
        super().__init__()
        pos_encoding = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()     # (max_len, 1)
        # 1 / 10000^(2i / embed_dim), computed in log space.
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-np.log(10000.0) / embed_dim))
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_dim there is one fewer cosine column than sine column.
        pos_encoding[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        # A buffer moves with the module on `.to(device)` / `.cuda()`; persistent=False
        # keeps it out of state_dict, as the original plain attribute was.
        self.register_buffer("pos_encoding", pos_encoding.unsqueeze(0), persistent=False)

    def forward(self, X):
        """Return X + PE for X of shape (batch_size, num_steps, embed_dim), num_steps <= max_len."""
        return X + self.pos_encoding[:, :X.shape[1]]

In [12]:
class Decoder(nn.Module):
    """Transformer decoder: token embedding, sinusoidal positions, DecoderBlock stack, vocab projection.

    Args:
        vocab_size: Number of tokens in the vocabulary (embedding rows and output logits).
        embed_dim: Token embedding size.
        num_blocks: Number of stacked DecoderBlocks (each with its own weights).
        num_heads: Attention heads per block.
        mlp_hidden_size: Hidden width of each block's MLP.
        dropout: Dropout probability inside each block.
    """

    def __init__(self, vocab_size, embed_dim=512, num_blocks=2,
                  num_heads=8, mlp_hidden_size=1024, dropout=0.1):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim)

        # A new DecoderBlock per layer (`[block] * n` would share one set of weights
        # across all layers). ModuleList rather than Sequential because each block
        # takes two inputs and is called explicitly in forward().
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(embed_dim, num_heads, dropout, mlp_hidden_size)
            for _ in range(num_blocks)
        ])

        # Projects each step's hidden state to vocabulary logits.
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, X, encoder_output):
        """Decode token ids X against the encoder output.

        Args:
            X: (batch_size, num_steps) token ids.
            encoder_output: (batch_size, num_patches, embed_dim).

        Returns:
            Tensor of shape (batch_size, num_steps, vocab_size).
        """
        X = self.embedding(X)           # (batch_size, num_steps, embed_dim)
        X = self.pos_encoding(X)        # (batch_size, num_steps, embed_dim)

        for block in self.decoder_blocks:
            X = block(X, encoder_output)        # (batch_size, num_steps, embed_dim)

        return self.fc(X)           # (batch_size, num_steps, vocab_size)

In [13]:
# test Decoder
vocab_size, num_steps = 10000, 20

X = torch.randint(0, vocab_size, (4, num_steps))
encoder_output = torch.ones((4, 36, 512))
decoder = Decoder(vocab_size)

out = decoder(X, encoder_output)
# One row of vocabulary logits per input token.
assert out.shape == (4, num_steps, vocab_size)
out.shape

torch.Size([4, 20, 10000])

In [None]:
def generator(output, vocab_size):
    """Greedy choice of the highest-scoring token at the last time step.

    Args:
        output: Logits of shape (batch_size, num_steps, vocab_size).
        vocab_size: Not used by this function.

    Returns:
        Tensor of shape (batch_size,) with the argmax token id for each sample.
    """
    last_step_logits = output[:, -1, :]
    return last_step_logits.argmax(dim=-1)

# Greedy next-token ids from the last step of the Decoder test output.
generator(output=out, vocab_size=vocab_size)

tensor([8685, 1270, 4086, 4023])

## Model
```
    Combine Encoder and Decoder
```

In [None]:
class Transformer(nn.Module):
    """Image-captioning model: a ViT image encoder followed by a Transformer text decoder.

    Args:
        img_size: Side length of the square input image (passed to the ViT).
        patch_size: Side length of each image patch (passed to the ViT).
        vocab_size: Size of the caption vocabulary (decoder output dimension).
        embed_dim: Embedding size shared by encoder and decoder.
        num_blocks: Number of blocks in each of the encoder and the decoder.
        num_heads: Attention heads per block.
        mlp_hidden_size: Hidden width of every block's MLP.
        dropout: Dropout probability for the ViT embedding dropout, the ViT
            blocks, and the decoder blocks.
    """

    def __init__(self, img_size, patch_size, vocab_size, embed_dim=512, num_blocks=2,
                  num_heads=8, mlp_hidden_size=1024, dropout=0.1):
        super().__init__()
        self.vit = ViT(
            img_size, patch_size, embed_dim, embed_dropout=dropout,
            num_blocks=num_blocks, num_heads=num_heads,
            mlp_hidden_size=mlp_hidden_size, mlp_dropout=dropout
        )
        self.decoder = Decoder(
            vocab_size, embed_dim, num_blocks,
            num_heads, mlp_hidden_size, dropout
        )

    def forward(self, img, state):
        """Compute vocabulary logits for every step of ``state``.

        Args:
            img: (batch_size, 3, img_size, img_size) input images.
            state: (batch_size, num_steps) caption token ids fed to the decoder.

        Returns:
            Tensor of shape (batch_size, num_steps, vocab_size).
        """
        encoder_output = self.vit(img)                  # (batch_size, num_patches, embed_dim)
        return self.decoder(state, encoder_output)     # (batch_size, num_steps, vocab_size)

In [20]:
# test Transformer
vocab_size, num_steps = 10000, 20

X = torch.ones((4, 3, 96, 96))
state = torch.randint(0, vocab_size, (4, num_steps))

model = Transformer(96, 16, vocab_size)
out = model(X, state)
assert out.shape == (4, num_steps, vocab_size)
out.shape, out[:, -1, :].shape

(torch.Size([4, 20, 10000]), torch.Size([4, 10000]))

In [17]:
# Greedy next-token ids from the last step of the full model's output.
generator(output=out, vocab_size=vocab_size)

tensor([2832, 5690, 5770, 6117])

In [18]:
# Loss function and optimizer for a single smoke-test training step.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
crossEntropy = nn.CrossEntropyLoss()

In [51]:
# One training step with a small (26-token) vocabulary.
# Build a fresh model and optimizer here: the model from the Transformer test cell
# has 10000 output classes, so reusing it would not match `vocab_size` on a clean run.
vocab_size = 26
model = Transformer(96, 16, vocab_size)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

X = torch.ones((4, 3, 96, 96))
state = torch.randint(0, vocab_size, (4, num_steps))
label = torch.ones(4, dtype=torch.long)

model.train()
out = model(X, state)

print(out[:, -1, :].shape, label.shape)

# crossEntropy((batch_size, num_classes) logits, (batch_size,) class indices)
loss = crossEntropy(out[:, -1, :], label)

optimizer.zero_grad()
loss.backward()
optimizer.step()
assert torch.isfinite(loss)

torch.Size([4, 26]) torch.Size([4])


## Handle data

#### Convert an image to a tensor of shape (channels, img_height, img_width)

In [None]:
# convert image to tensor
def image_to_tensor(image_path):
    """Load an image file as a float32 tensor of shape (channels, height, width).

    The image is converted to RGB, so ``channels`` is 3. Pixel values are not
    rescaled (no normalization is applied).

    Args:
        image_path: Path to the image file.

    Returns:
        torch.Tensor of shape (3, height, width), or None if the file is missing
        or cannot be read/decoded.
    """
    # NOTE(review): `Image` is not imported in the visible imports cell; presumably
    # this is Pillow's `from PIL import Image` - add that import to the imports cell.
    try:
        # `with` closes the file that Image.open keeps open until the data is loaded.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert('RGB')  # Ensure 3 channels
        img_array = np.array(img)  # (height, width, 3)
        return torch.tensor(img_array, dtype=torch.float32).permute(2, 0, 1)
    except FileNotFoundError:
        print(f"Error: Image file not found at {image_path}")
        return None
    except (OSError, ValueError) as e:
        # Unreadable or undecodable files only; programming errors (e.g. a
        # missing import) are no longer hidden behind a printed message.
        print(f"An error occurred: {e}")
        return None

#### Handle text
```
    Clean the text
    Build the vocabulary
    Convert each caption to a 1D tensor of token indices (length = number of tokens)
```