In [94]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torchvision.models import vgg19

# Vision Transformer

In [95]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [156]:
class VisionTransformer(nn.Module):
    def __init__(self):
        super(VisionTransformer, self).__init__()
        self.embeddingLayer = nn.Conv2d(3, 64, 16, 16)
        self.positionalEncoding = PositionalEncoding(64, max_len=196)
        self.transformerEncoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(64, 8, 64, activation="gelu"), 6)
        cls_tensor = torch.randn(1,1,64)
        self.cls = nn.Parameter(cls_tensor)

    def forward(self, x):
        x = self.embeddingLayer(x)
                
        n, c, w, h = x.shape
        x = torch.reshape(x, [n, h * w, c])
        
        x = self.positionalEncoding(x)

        x = torch.cat((self.cls.repeat(n,1,1), x), 1)
        
        x = self.transformerEncoder(x)
        return x

In [72]:
test = torch.rand(10, 3, 32, 32)
vision_transformer = VisionTransformer()
res = vision_transformer(test)
res.shape

RuntimeError: The size of tensor a (4) must match the size of tensor b (1040) at non-singleton dimension 1

# Raw feature extractor

In [97]:
vgg = vgg19(pretrained=True)

In [160]:
class PerceptualLoss(nn.Module):
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        self.feature_extractor = nn.Sequential(*list(vgg.features.children())[:3])

    def forward(self, x, y, masked_patch):
        height_offset = 16 * ((masked_patch) // 2)
        width_offset = 16*((masked_patch) % 2)
        y = y[:,:,height_offset:height_offset + 16,width_offset:width_offset + 16]
              
        y = self.feature_extractor(y)
        y = F.avg_pool2d(y, 16, 1)
        y = y.view(y.shape[0],64)
        
        x = x[:,masked_patch,:]
        
        loss = F.mse_loss(x, y)
        return loss

In [75]:
p = PerceptualLoss()
y = torch.rand(2, 3, 32, 32)
x = torch.rand(2,4, 64)
p(x, y, 3)

16
16


RuntimeError: shape '[1, 64]' is invalid for input of size 128

# Train loader

In [151]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224,224))
])

In [152]:
dataset = datasets.ImageFolder('../data', transform=transform)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=32)

In [146]:
len(train_loader)

157

In [143]:
next(iter(train_loader))[0].shape

torch.Size([1, 3, 224, 336])

# Training step

In [161]:
def train(model, loss_fn, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        masked_patch = np.random.randint(0,14)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, data, masked_patch)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [162]:
model = VisionTransformer().to("cpu")
optimizer = optim.Adadelta(model.parameters())
scheduler = StepLR(optimizer, step_size=1)
perceptual_loss = PerceptualLoss()

In [163]:
for epoch in range(1, 2):
    train(model, perceptual_loss, "cpu", train_loader, optimizer, epoch)
    scheduler.step()

