In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [3]:
import warnings
# Silence every warning raised for the rest of the session.
# NOTE(review): this is a blanket filter, so deprecation and numerical warnings
# emitted by later cells are hidden as well — consider narrowing it by category.
warnings.filterwarnings("ignore")

In [4]:
#load CIFAR-10 data
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader

# Convert each image to a tensor, then resize to 32x32.
# CIFAR-10 images are already 32x32, so the resize leaves their size unchanged.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32))
])

# Both splits are stored under ./data_cifar and downloaded on first use.
train_data_cifar = datasets.CIFAR10('data_cifar', train=True, download=True, transform=transform)
test_data_cifar = datasets.CIFAR10('data_cifar', train=False, download=True, transform=transform)

# Shuffle only the training split. The test split is iterated in a fixed order so
# that evaluation batches are reproducible from run to run.
train_loader_cifar = DataLoader(train_data_cifar, batch_size=64, shuffle=True)
test_loader_cifar = DataLoader(test_data_cifar, batch_size=64, shuffle=False)



Files already downloaded and verified
Files already downloaded and verified


In [5]:
#load mnist dataset

# Convert each image to a tensor, then resize it to 32x32 so MNIST matches the
# spatial size used for CIFAR-10 in this notebook.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32))
])

# Both splits are stored under ./data_mnist and downloaded on first use.
train_data_mnist = datasets.MNIST('data_mnist', train=True, download=True, transform=transform)
test_data_mnist = datasets.MNIST('data_mnist', train=False, download=True, transform=transform)

# Shuffle only the training split. The test split is iterated in a fixed order so
# that evaluation batches are reproducible from run to run.
train_loader_mnist = DataLoader(train_data_mnist, batch_size=64, shuffle=True)
test_loader_mnist = DataLoader(test_data_mnist, batch_size=64, shuffle=False)


### Patch a Single Image

In [63]:
# Split one MNIST image into non-overlapping patch_size x patch_size tiles.
patch_size = 4
image_size = 32

# First training sample; squeeze() drops the size-1 channel axis, leaving (32, 32).
image, label = train_data_mnist[0]
image = image.squeeze()
print(image.shape, label)

# View the image as (row_block, row_in_patch, col_block, col_in_patch), bring the
# two block axes next to each other, then merge them into a single patch axis.
patches_per_side = image_size // patch_size
patch_grid = image.reshape(patches_per_side, patch_size, -1, patch_size)
patches = patch_grid.swapaxes(1, 2).reshape(-1, patch_size, patch_size)

# Optional visual check (left disabled): show the image, then the 8x8 grid of patches.
# plt.imshow(image)
# plt.show()

# plt.figure(figsize=(10, 10))
# for i in range(patches.shape[0]):
#     plt.subplot(8, 8, i+1)
#     plt.imshow(patches[i])
#     plt.axis('off')

torch.Size([32, 32]) 5


### RPE Methodologies
- General Learnable Function: $f_\Theta : \mathbb{R} \rightarrow \mathbb{R}$
- Monotonically Decreasing Function: $f(x) = e^{-\alpha x}$ with $\alpha > 0$
- Ratio of two polynomial functions: $f(x) = \frac{h(x)}{g(x)}$

General Learnable Function: 

In [50]:
class GeneralLearnableFunction(nn.Module):
    """Learnable positional encoding that maps each scalar distance to a vector.

    Every entry of the distance tensor is passed independently through a single
    ``nn.Linear(1, embedding_dim)`` layer, so the encoding is an affine function
    of the distance.

    Args:
        embedding_dim: Length of the vector produced for each distance entry.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.fc = nn.Linear(1, embedding_dim)

    def forward(self, distances):
        """Encode a rank-3 distance tensor.

        Args:
            distances: Tensor with exactly three dimensions, e.g.
                ``(batch_size, num_patches, num_patches)``.

        Returns:
            Tensor shaped like ``distances`` with one extra trailing axis of
            length ``embedding_dim``.

        Raises:
            ValueError: If ``distances`` does not have exactly three dimensions.
        """
        # The unpack exists only to reject inputs that are not rank 3.
        _batch, _rows, _cols = distances.shape
        # nn.Linear acts on the last axis, so give every distance its own size-1 axis.
        scalar_features = distances[..., None]
        return self.fc(scalar_features)


In [57]:
# Smoke test: run the general RPE function on a random symmetric matrix whose
# diagonal is zero (a stand-in for pairwise distances).
glf = GeneralLearnableFunction(64)

noise = torch.randn(8, 8)
symmetric = noise + noise.transpose(0, 1)
zero_diagonal = symmetric - torch.diag(symmetric.diagonal())
distances = zero_diagonal.unsqueeze(0)  # add the leading batch axis -> (1, 8, 8)

positional_encodings = glf(distances)
print(positional_encodings.shape)
# Entry [0, 0] lies on the zeroed diagonal, so this 64-dim vector equals the
# linear layer's randomly initialised bias.
print(positional_encodings[0, 0, 0])

torch.Size([1, 8, 8, 64])
tensor([ 0.5951,  0.7530, -0.6586,  0.2280, -0.6728, -0.2245, -0.7294, -0.1395,
         0.5804, -0.4483, -0.9598,  0.5763,  0.6590,  0.3613, -0.0698,  0.5394,
        -0.4711,  0.8144,  0.2408,  0.0364, -0.9000, -0.9715,  0.1651, -0.8909,
        -0.0119,  0.3047, -0.3610,  0.9064, -0.0058,  0.7433, -0.8007, -0.9259,
        -0.6425, -0.0829, -0.1002, -0.1912,  0.9596, -0.8411, -0.0797, -0.8629,
         0.2683,  0.7774, -0.6628,  0.8056,  0.0256, -0.3716, -0.2160, -0.8486,
         0.0528, -0.9476, -0.4661, -0.2484, -0.7321,  0.9030,  0.1189,  0.8965,
         0.9787, -0.4790, -0.4192, -0.0835, -0.4569, -0.6565,  0.1636,  0.8846],
       grad_fn=<SelectBackward0>)


Monotonically-Decreasing Function: 

In [None]:
class MonotonicFunction(nn.Module):
    """Positional encoding that decays exponentially with distance.

    Each distance ``d`` is scored as ``exp(-rate * d)``, the scores are
    normalised to sum to one along the last axis, and the resulting weight is
    repeated across ``embedding_dim`` channels.

    ``rate`` is ``softplus(alpha)``, which is strictly positive, so the encoding
    is monotonically decreasing in distance for every value of the learnable
    parameter ``alpha``. Using the raw ``alpha`` directly would let a negative
    value turn the decay into growth.

    Args:
        embedding_dim: Number of channels the normalised weight is expanded to.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        # Raw, unconstrained parameter; forward() maps it through softplus to get
        # the positive decay rate.
        self.alpha = nn.Parameter(torch.randn(1))
        self.embedding_dim = embedding_dim

    def forward(self, distances):
        """Encode a rank-3 distance tensor.

        Args:
            distances: Tensor with exactly three dimensions, e.g.
                ``(batch_size, num_patches, num_patches)``.

        Returns:
            Tensor shaped like ``distances`` with one extra trailing axis of
            length ``embedding_dim``. It is an expanded (broadcast) view, so all
            channels of one entry hold the same value.

        Raises:
            ValueError: If ``distances`` does not have exactly three dimensions.
        """
        if distances.dim() != 3:
            raise ValueError(
                f"distances must have 3 dimensions, got {distances.dim()}"
            )

        rate = F.softplus(self.alpha)
        # softmax(-rate * d) equals exp(-rate * d) / sum(exp(-rate * d)) but stays
        # finite when every exponential would underflow to zero.
        weights = torch.softmax(-rate * distances, dim=-1)
        # Expand with the input's own shape so non-square matrices work too.
        return weights.unsqueeze(-1).expand(*weights.shape, self.embedding_dim)

Ratio of Two Polynomials: 

Absolute Positional Encoding:

In [None]:
#regular positional encoding


### Transformer Encoder Layer

In [60]:
class TransformerEncoderLayer:
    """Placeholder for the transformer encoder layer; nothing is implemented yet.

    TODO: implement.
    """

### Encoder Block

In [61]:
class EncoderBlock:
    """Placeholder for the encoder block; nothing is implemented yet.

    TODO: implement.
    """

### Classification Head

In [62]:
class ClassificationHead(nn.Module):
    """Linear classifier followed by a softmax over the class axis.

    Args:
        embedding_dim: Size of the last axis of the input features.
        num_classes: Number of output classes.
    """

    # The constructor must be spelled ``__init__``; under any other name Python
    # never calls it and ``fc`` / ``softmax`` are never created.
    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(embedding_dim, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return class probabilities for ``x``.

        Args:
            x: Tensor whose last axis has length ``embedding_dim``.

        Returns:
            Tensor with the same leading axes and a last axis of length
            ``num_classes`` that sums to one.

        NOTE(review): ``nn.CrossEntropyLoss`` expects unnormalised logits; if
        that loss is used for training, feed it the output of ``self.fc``
        rather than these probabilities.
        """
        logits = self.fc(x)
        return self.softmax(logits)
    