In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


In [9]:
#load CIFAR-10 data
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader

# CIFAR-10 preprocessing: convert each image to a tensor, then normalize each
# of the three channels with mean 0.5 and std 0.5.
transform_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# The generic name `transform` is rebound by any other cell that builds its own
# pipeline, so `transform_cifar` is the stable handle for this dataset. The
# alias is kept only so cells that still refer to `transform` keep working.
transform = transform_cifar

# Train and test splits share one root directory and the same preprocessing.
# download=True fetches the files only when they are not already in the root.
train_data_cifar = datasets.CIFAR10('data_cifar', train=True, download=True, transform=transform_cifar)
test_data_cifar = datasets.CIFAR10('data_cifar', train=False, download=True, transform=transform_cifar)



Files already downloaded and verified
Files already downloaded and verified


In [10]:
#load mnist dataset
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler

# MNIST preprocessing: convert each image to a tensor, then normalize the
# single channel with mean 0.5 and std 0.5.
transform_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Rebinding the generic name `transform` here replaces whatever pipeline an
# earlier cell stored under it, so `transform_mnist` is the stable handle for
# this dataset. The alias is kept only for cells that still use `transform`.
transform = transform_mnist

# Train and test splits share one root directory and the same preprocessing.
# download=True fetches the files only when they are not already in the root.
train_data_mnist = datasets.MNIST('data_mnist', train=True, download=True, transform=transform_mnist)
test_data_mnist = datasets.MNIST('data_mnist', train=False, download=True, transform=transform_mnist)


In [11]:
class General_Learnable_RPE(nn.Module):
    """Learnable per-head weights applied to scaled pairwise distances.

    The module owns one parameter, ``rpe``, of shape
    ``(num_heads, embed_dim // num_heads)``, initialised from a standard
    normal distribution.

    Args:
        embed_dim: Total embedding width; split evenly across heads with
            integer division.
        num_heads: Number of attention heads.
        max_len: Divisor used to scale the incoming distances.
    """

    def __init__(self, embed_dim, num_heads, max_len=512):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.max_len = max_len
        head_dim = embed_dim // num_heads
        self.rpe = nn.Parameter(torch.randn(num_heads, head_dim))

    def forward(self, distances):
        """Weight the scaled distances by ``rpe`` and sum over the last axis.

        Args:
            distances: A 3-D tensor; its first two sizes are read as
                ``batch_size`` and ``num_patches``.

        Returns:
            The element-wise product of the tiled ``rpe`` weights and the
            scaled distances, reduced with a sum over the final dimension.
        """
        batch_size, num_patches, _ = distances.shape

        # (batch, last two dims) -> (batch, num_heads, ..., ...), then scale.
        scaled = distances.unsqueeze(1).repeat(1, self.num_heads, 1, 1) / self.max_len

        # (num_heads, head_dim) -> (num_heads, 1, 1, head_dim), tiled to
        # (num_heads * batch_size, num_patches, 1, head_dim).
        weights = self.rpe[:, None, None, :].repeat(batch_size, num_patches, 1, 1)

        # NOTE(review): the two operands are combined by plain broadcasting,
        # which only lines up for restricted size combinations (for example
        # num_heads == 1 with the last distance dim equal to head_dim).
        # Confirm the intended layout of the bias against the caller.
        return torch.sum(weights * scaled, dim=-1)

In [4]:
class PatchEmbedding(nn.Module):
    """Project non-overlapping image patches to embedding vectors.

    A single ``Conv2d`` whose kernel size and stride both equal
    ``patch_size`` performs the patch split and the linear projection in
    one step.

    Args:
        image_size: Stored on the instance; not used by the projection.
        patch_size: Side length of each patch (kernel size and stride).
        in_channels: Number of channels in the input image.
        embed_dim: Number of output channels, i.e. the embedding width.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch embeddings with the channel axis moved last.

        Args:
            x: Image batch passed directly to the convolution.

        Returns:
            Tensor laid out as (batch, grid_h, grid_w, embed_dim).
        """
        # Conv output is (batch, embed_dim, grid_h, grid_w); permuting
        # moves the embedding axis to the end.
        return self.projection(x).permute(0, 2, 3, 1)