In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torchvision.models import vgg19

# Vision Transformer

In [2]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [152]:
class VisionTransformer(nn.Module):
    def __init__(self):
        super(VisionTransformer, self).__init__()
        self.embeddingLayer = nn.Conv2d(3, 64, 16, 16)
        self.positionalEncoding = PositionalEncoding(64, max_len=196)
        self.transformerEncoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(64, 8, 64, activation="gelu"), 6)
        cls_tensor = torch.randn(1,1,64)
        self.cls = nn.Parameter(cls_tensor)
        mask_token = torch.randn(1,1,64)
        self.mask_token = nn.Parameter(mask_token)

    def forward(self, x, unmasked, masked, swapped):
        x = self.embeddingLayer(x)
                
        n, c, w, h = x.shape
        x = torch.reshape(x, [n, h * w, c])
        
        swapped_embedding_ids = unmasked[:len(swapped)]
        swapped_embeddings = x[:,swapped_embedding_ids,:].detach()
        x[:,swapped_embedding_ids,:] = swapped_embeddings
        
        x[:,masked,:] = self.mask_token
        
        x = self.positionalEncoding(x)

        x = torch.cat((self.cls.repeat(n,1,1), x), 1)
        
        x = self.transformerEncoder(x)
        return x

In [138]:
test = torch.rand(10, 3, 244, 244)
vision_transformer = VisionTransformer()
res = vision_transformer(test, [1,2,3], [4], [5], [6])
res.shape

True


RuntimeError: The size of tensor a (225) must match the size of tensor b (196) at non-singleton dimension 1

# Raw feature extractor

In [5]:
vgg = vgg19(pretrained=True)

In [95]:
class PerceptualLoss(nn.Module):
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        self.feature_extractor = nn.Sequential(*list(vgg.features.children())[:3])

    def forward(self, x, y, masked_patches):
        masked_patches_dims = []
        for masked_patch in masked_patches:
                height_offset = 16 * ((masked_patch) // 2)
                width_offset = 16*((masked_patch) % 2)
                masked_patches_dims.append([height_offset,width_offset])
        
        target_patches = []
        for target in range(y.shape[0]):
            for masked_patch in masked_patches_dims:
                height_offset, width_offset = masked_patch
                extracted_patch = y[target,:,height_offset:height_offset + 16,width_offset:width_offset + 16]
                target_patches.append(extracted_patch)

        target_patches_tensor = torch.stack(target_patches)
        target_patches_tensor = self.feature_extractor(target_patches_tensor)
        target_patches_tensor = F.avg_pool2d(target_patches_tensor, 16, 1)
        target_patches_tensor = target_patches_tensor.view(target_patches_tensor.shape[0],64)
        
        masked_patches_shifted = [masked_patch + 1 for masked_patch in masked_patches]
        x = x[:,masked_patches_shifted,:]
        x = x.view(target_patches_tensor.shape[0], 64)
        
        print(x.shape)
        print(target_patches_tensor.shape)
        
        loss = F.mse_loss(x, target_patches_tensor)
        return loss

In [79]:
p = PerceptualLoss()
y = torch.rand(2, 3, 32, 32)
x = torch.rand(2,4, 64)
p(x, y, 3)

TypeError: 'int' object is not iterable

# Train loader

In [80]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224,224))
])

In [81]:
dataset = datasets.ImageFolder('../data', transform=transform)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=32)

In [82]:
len(train_loader)

157

In [83]:
next(iter(train_loader))[0].shape

torch.Size([32, 3, 224, 224])

# Training step

In [153]:
def train(model, loss_fn, device, train_loader, optimizer, epoch):
    patches = list(range(14))
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        np.random.shuffle(patches)
        masked_patches = patches[:len(patches)//2]
        unmasked = [patch for patch in patches if patch not in masked_patches]
        tokenized, swapped, unchanged = np.split(masked_patches, [int(.8*len(masked_patches)),
                                                                  int(.9*len(masked_patches))])
        print(masked_patches)
        print(unmasked)
        print(tokenized, swapped, unchanged)
        optimizer.zero_grad()
        output = model(data, unmasked, masked_patches, swapped)
        loss = loss_fn(output, data, masked_patches)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [154]:
model = VisionTransformer().to("cpu")
optimizer = optim.Adadelta(model.parameters())
scheduler = StepLR(optimizer, step_size=1)
perceptual_loss = PerceptualLoss()

In [155]:
for epoch in range(1, 2):
    train(model, perceptual_loss, "cpu", train_loader, optimizer, epoch)
    scheduler.step()

[13, 10, 11, 2, 9, 6, 1]
[8, 12, 3, 0, 5, 4, 7]
[13 10 11  2  9] [6] [1]
torch.Size([224, 64])
torch.Size([224, 64])
[7, 4, 2, 0, 6, 11, 12]
[5, 9, 13, 1, 10, 3, 8]
[7 4 2 0 6] [11] [12]
torch.Size([224, 64])
torch.Size([224, 64])
[2, 4, 6, 10, 13, 0, 3]
[9, 7, 8, 5, 1, 11, 12]
[ 2  4  6 10 13] [0] [3]
torch.Size([224, 64])
torch.Size([224, 64])
[1, 11, 10, 12, 6, 7, 5]
[4, 9, 2, 3, 8, 0, 13]
[ 1 11 10 12  6] [7] [5]
torch.Size([224, 64])
torch.Size([224, 64])
[10, 3, 0, 1, 6, 4, 11]
[13, 9, 2, 8, 12, 7, 5]
[10  3  0  1  6] [4] [11]
torch.Size([224, 64])
torch.Size([224, 64])
[3, 10, 12, 11, 5, 6, 8]
[13, 2, 1, 7, 0, 4, 9]
[ 3 10 12 11  5] [6] [8]
torch.Size([224, 64])
torch.Size([224, 64])
[7, 0, 6, 2, 3, 8, 11]
[9, 4, 1, 10, 5, 12, 13]
[7 0 6 2 3] [8] [11]
torch.Size([224, 64])
torch.Size([224, 64])
[2, 10, 12, 9, 6, 3, 11]
[4, 5, 0, 8, 13, 1, 7]
[ 2 10 12  9  6] [3] [11]
torch.Size([224, 64])
torch.Size([224, 64])
[3, 7, 8, 12, 13, 10, 6]
[2, 0, 5, 11, 9, 1, 4]
[ 3  7  8 12 13] [10]

KeyboardInterrupt: 