In [76]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math

In [3]:
import warnings
# Suppress every warning for the rest of the session: no category, module or
# message filter is passed, so the "ignore" action applies to all of them.
warnings.filterwarnings("ignore")

In [4]:
#load CIFAR-10 data
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader

# Convert each image to a tensor, then resize it to 32x32.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32))
])

# download=True fetches the archive into ./data_cifar only when it is missing.
train_data_cifar = datasets.CIFAR10('data_cifar', train=True, download=True, transform=transform)
test_data_cifar = datasets.CIFAR10('data_cifar', train=False, download=True, transform=transform)

# Only the training split is shuffled. The test split is iterated in a fixed
# order so that evaluation batches are reproducible from run to run.
train_loader_cifar = DataLoader(train_data_cifar, batch_size=64, shuffle=True)
test_loader_cifar = DataLoader(test_data_cifar, batch_size=64, shuffle=False)



Files already downloaded and verified
Files already downloaded and verified


In [5]:
#load mnist dataset

# Convert each image to a tensor, then resize it to 32x32.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32))
])

# download=True fetches the archive into ./data_mnist only when it is missing.
train_data_mnist = datasets.MNIST('data_mnist', train=True, download=True, transform=transform)
test_data_mnist = datasets.MNIST('data_mnist', train=False, download=True, transform=transform)

# Only the training split is shuffled. The test split is iterated in a fixed
# order so that evaluation batches are reproducible from run to run.
train_loader_mnist = DataLoader(train_data_mnist, batch_size=64, shuffle=True)
test_loader_mnist = DataLoader(test_data_mnist, batch_size=64, shuffle=False)


### Patch a Single Image

In [63]:
# Split one MNIST image into non-overlapping patch_size x patch_size tiles.
image_size = 32
patch_size = 4

image, label = train_data_mnist[0]
image = image.squeeze()  # drop the size-1 axes of the sample
print(image.size(), label)

# (rows, cols) -> (patch_row, row_in_patch, patch_col, col_in_patch)
patches = image.reshape(image_size // patch_size, patch_size, -1, patch_size)
# Bring both patch indices to the front, then merge them into one patch axis.
patches = patches.transpose(1, 2).reshape(-1, patch_size, patch_size)

# plt.imshow(image)
# plt.show()

# plt.figure(figsize=(10, 10))
# for i in range(patches.shape[0]):
#     plt.subplot(8, 8, i+1)
#     plt.imshow(patches[i])
#     plt.axis('off')

torch.Size([32, 32]) 5


### Patching a Batch of Images

In [65]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a grid of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size`` is
    applied, so every output location is a learned linear projection of one
    non-overlapping patch of the input.

    Args:
        image_size: stored on the module; not used by the computation here.
        patch_size: side length of a square patch (conv kernel size and stride).
        in_channels: number of channels of the input images.
        embed_dim: number of conv output channels, i.e. the embedding size.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size, self.patch_size = image_size, patch_size
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, in_channels, H, W) to (batch, H', W', embed_dim).

        ``H'`` and ``W'`` are the numbers of patches along the height and width.
        """
        patch_grid = self.projection(x)  # (batch, embed_dim, H', W')
        # Channels-last layout: move the embedding axis from position 1 to the end.
        return patch_grid.permute(0, 2, 3, 1)

In [72]:
#testing the PatchEmbedding class on a batch of the MNIST dataset
image, label = next(iter(train_loader_mnist))
print(image.shape, label)

patch_size = 4
image_size = 32
in_channels = 1
embed_dim = 64

patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
x = patch_embed(image)
print(x.shape)

# Verify the shape programmatically instead of eyeballing the printed value:
# one embed_dim-sized vector per patch, on an (image_size // patch_size)^2 grid.
assert x.shape == (
    image.shape[0],
    image_size // patch_size,
    image_size // patch_size,
    embed_dim,
), f"unexpected patch embedding shape: {tuple(x.shape)}"


torch.Size([64, 1, 32, 32]) tensor([3, 2, 3, 1, 3, 8, 9, 1, 9, 0, 8, 5, 2, 6, 2, 7, 3, 2, 8, 6, 2, 2, 8, 0,
        7, 8, 3, 3, 2, 7, 2, 5, 3, 3, 4, 3, 5, 6, 7, 1, 9, 2, 0, 3, 5, 3, 3, 3,
        5, 0, 0, 4, 0, 7, 1, 7, 3, 9, 4, 8, 3, 8, 9, 3])
torch.Size([64, 8, 8, 64])


### Absolute Positional Encoding (RPE-free Baseline)


In [77]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal (absolute) positional encoding followed by dropout.

    A table ``pe`` of shape ``(max_len, 1, d_model)`` is precomputed and stored
    as a non-trainable buffer. Even feature indices hold ``sin(pos * freq)``
    and odd feature indices hold ``cos(pos * freq)``.

    Args:
        d_model: size of the feature (last) axis. Both even and odd values are
            supported.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd indices, so for odd d_model the
        # last cosine column has no slot. Trim it instead of failing with a
        # shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions and apply dropout.

        The table is sliced along axis 0 of ``x``, so axis 0 is treated as the
        position axis and the size-1 middle axis of ``pe`` broadcasts over
        axis 1 of ``x``.
        NOTE(review): a batch-first input would be indexed by batch instead of
        by position -- confirm the layout at the call site.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


In [None]:
# TODO: test the regular (absolute) positional encoding -- this cell was left unfinished

### RPE Methodologies
- General Learnable Function: $f_\Theta : \mathbb{R} \rightarrow \mathbb{R}$
- Monotonically Decreasing Function: $f = e^{-\alpha x}$
- Ratio of two polynomial functions: $f = \frac{h}{g}$

General Learnable Function: 

In [68]:
class GeneralLearnableFunction(nn.Module):
    """Learnable map from a scalar distance to an ``embedding_dim``-sized vector.

    Implemented as a single ``nn.Linear(1, embedding_dim)`` that is applied
    independently to every entry of the distance tensor.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.fc = nn.Linear(1, embedding_dim)

    def forward(self, distances):
        """Encode a 3-D distance tensor; the result gains a trailing embedding axis.

        Raises:
            ValueError: if ``distances`` does not have exactly three dimensions.
        """
        # The three-way unpack is kept purely as a rank check on the input.
        _batch, _rows, _cols = distances.shape
        return self.fc(distances[..., None])


In [73]:
#testing the general RPE function on a random symmetric distance matrix
glf = GeneralLearnableFunction(64)

# Random 8x8 matrix -> symmetrize -> zero the diagonal -> add a leading batch axis.
distances = torch.randn(8, 8)
distances = distances + distances.T
distances = (distances - torch.diag(torch.diagonal(distances))).unsqueeze(0)

positional_encodings = glf(distances)
print(positional_encodings.size())
print(positional_encodings[0, 0, 0]) #should be embedding of dimension 64 with random values

torch.Size([1, 8, 8, 64])
tensor([ 0.5504, -0.0272,  0.8280, -0.0335, -0.5601, -0.1169,  0.9159,  0.4995,
        -0.0534,  0.4756, -0.8882,  0.1858,  0.2654,  0.0998,  0.3770,  0.0293,
         0.7467, -0.8272,  0.7443, -0.1012,  0.2873, -0.4944, -0.9595, -0.4560,
         0.9493, -0.4706,  0.9414,  0.9434, -0.7811,  0.9605, -0.8934,  0.0583,
         0.5981, -0.4510, -0.8587,  0.2166,  0.9725,  0.0690, -0.7834, -0.1792,
        -0.4014, -0.0596, -0.1858,  0.2153,  0.1585, -0.9607, -0.0823,  0.9454,
        -0.3373, -0.7897,  0.9248, -0.4429,  0.1606,  0.2784, -0.3660,  0.4903,
        -0.6577,  0.3554, -0.0659,  0.9351,  0.7247,  0.3084,  0.8950,  0.2198],
       grad_fn=<SelectBackward0>)


In [74]:
#test adding the positional encodings to a batch of patch embeddings
# NOTE(review): the recorded outputs show x as (64, 8, 8, 64) and
# positional_encodings as (1, 8, 8, 64), so this sum relies on broadcasting
# over the leading axis -- confirm that pairing of axes is what is intended.

print(x[0, 0, 0])
encoded_patches = torch.add(x, positional_encodings)
print(encoded_patches.size())
print(encoded_patches[0, 0, 0])


tensor([-1.3224e-01, -1.7628e-03, -1.6146e-01,  3.1187e-02,  1.6961e-01,
         2.0496e-01, -6.3421e-02, -1.2596e-01, -6.0501e-02,  1.7331e-01,
        -1.1854e-01, -2.4561e-02, -1.2756e-02, -5.9731e-02, -1.0772e-02,
        -8.2388e-02, -8.3777e-02, -1.5694e-01,  4.7398e-02,  1.2669e-02,
         2.1559e-01, -1.3391e-01,  5.3702e-02,  1.1682e-01,  2.1324e-01,
        -4.7466e-02, -9.0230e-02,  1.9399e-01, -1.5392e-02,  1.2149e-01,
         1.4484e-01,  2.0822e-01,  1.3077e-01, -9.0394e-02, -1.4869e-01,
         1.9546e-02, -1.5460e-01,  2.2889e-01,  1.0046e-04, -6.6848e-02,
        -2.4300e-01, -1.4527e-01, -2.3898e-01, -1.7482e-01, -1.9071e-01,
         2.2830e-01, -1.4923e-01,  2.3880e-01,  1.3067e-01,  1.3353e-01,
        -1.4982e-01, -1.6649e-01,  1.0929e-01,  1.1080e-01, -7.0484e-02,
        -4.4598e-02, -2.4193e-01,  1.7137e-01, -5.9003e-02, -1.0231e-01,
        -1.6886e-01, -4.9098e-02,  1.0148e-01, -3.0133e-02],
       grad_fn=<SelectBackward0>)
torch.Size([64, 8, 8, 64])
te

In [75]:
# Show the same grid position before and after the encoding is added.
print(x[5, 5, 5], encoded_patches[5, 5, 5], sep="\n")

tensor([-0.0657, -0.3909,  0.3240,  0.3323, -0.1932, -0.1612,  0.5429, -0.4009,
        -0.9807, -0.2675,  0.2946, -0.9324, -0.2248,  0.0372, -0.7210, -0.8905,
        -0.8576, -0.6487,  0.4725,  0.5766, -0.2445, -0.4656, -0.5605, -0.3511,
         0.2134,  0.4169, -0.5610,  1.0331,  0.2065,  0.2802, -0.2306,  0.1593,
         0.2018, -0.6545,  1.1062, -0.5613, -0.1923,  0.4613,  0.6055,  0.8612,
        -0.0611, -0.6883,  0.1176, -0.0898,  0.2557,  0.9772,  1.0901,  0.5571,
         0.3595,  0.6526, -0.7675, -1.1276,  0.5628, -0.6863,  0.3676,  0.3017,
        -0.7870,  0.0079, -0.5558,  0.9946,  0.0598,  0.3599,  0.6744,  0.1113],
       grad_fn=<SelectBackward0>)
tensor([ 4.8467e-01, -4.1807e-01,  1.1520e+00,  2.9883e-01, -7.5329e-01,
        -2.7814e-01,  1.4588e+00,  9.8546e-02, -1.0341e+00,  2.0816e-01,
        -5.9358e-01, -7.4662e-01,  4.0611e-02,  1.3697e-01, -3.4401e-01,
        -8.6118e-01, -1.1085e-01, -1.4758e+00,  1.2168e+00,  4.7540e-01,
         4.2842e-02, -9.5998e-01,

this works!

Monotonically-Decreasing Function: 

In [None]:
class MonotonicFunction(nn.Module):
    """Monotonically decreasing relative positional weighting ``exp(-rate * d)``.

    The raw parameter ``alpha`` is unconstrained. The decay rate actually used
    is ``softplus(alpha)``, which is always positive, so the weight strictly
    decreases as the distance grows regardless of how ``alpha`` is initialised
    or updated.

    Args:
        embedding_dim: size of the trailing axis the weights are expanded to.
    """

    def __init__(self, embedding_dim):
        super(MonotonicFunction, self).__init__()
        self.alpha = nn.Parameter(torch.randn(1))
        self.embedding_dim = embedding_dim

    def forward(self, distances):
        """Return row-normalised decay weights expanded along an embedding axis.

        Args:
            distances: 3-D tensor unpacked as ``(batch_size, num_patches, _)``.

        Returns:
            Tensor of shape ``(batch_size, num_patches, num_patches, embedding_dim)``.
            The weights sum to 1 along axis 2 and are repeated (as an expanded
            view, not a copy) along the last axis.

        Raises:
            ValueError: if ``distances`` does not have exactly three dimensions.
        """
        batch_size, num_patches, _ = distances.size()
        rate = F.softplus(self.alpha)  # strictly positive decay rate
        # softmax(-rate * d) equals exp(-rate * d) / sum(exp(-rate * d)), but it
        # subtracts the row maximum first. Rows of large distances therefore do
        # not underflow to 0 / 0 = NaN.
        weights = F.softmax(-rate * distances, dim=-1)
        return weights.unsqueeze(-1).expand(batch_size, num_patches, num_patches, self.embedding_dim)

Ratio of Two Polynomials (TODO: not yet implemented):

### Encoder Layer

In [64]:
class EncoderLayer(nn.Module):
    """Placeholder for the encoder layer; nothing is implemented yet."""

### Encoder Block

In [61]:
class EncoderBlock(nn.Module):
    """Placeholder for the encoder block; nothing is implemented yet."""

### Classification Head

In [62]:
class ClassificationHead(nn.Module):
    """Linear classifier followed by a softmax over the class axis.

    Args:
        embedding_dim: size of the last axis of the input features.
        num_classes: number of output classes.
    """

    # The constructor was previously spelled ``__init``, so Python never ran it:
    # ``fc`` and ``softmax`` were never created and passing arguments to the
    # class raised a TypeError.
    def __init__(self, embedding_dim, num_classes):
        super(ClassificationHead, self).__init__()
        self.fc = nn.Linear(embedding_dim, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return class probabilities with shape ``(..., num_classes)``.

        The output is already normalised along the last axis.
        NOTE(review): ``nn.CrossEntropyLoss`` expects unnormalised logits, so if
        that loss is used, feed it ``self.fc(x)`` rather than this output.
        """
        x = self.fc(x)
        x = self.softmax(x)
        return x
    

### Vision Transformer

In [None]:
class ViT(nn.Module):
    """Vision Transformer skeleton; no submodules or forward pass are defined yet."""

    def __init__(self):
        # nn.Module's own initialiser has to run first. Without it the internal
        # registries are missing, so assigning submodules or calling
        # parameters() fails.
        super().__init__()
    