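# Inpainting Experiment with a Vision Transformer (ViT)

Train a `HourglassViT` model to reconstruct an image while a rectangular patch is excluded from the MSE loss, then compare the reconstruction with the target.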

## Dataset

In [37]:
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from vit.datasets import SingleImageDataset

# Inputs and targets both go through the same tensor conversion;
# the loader serves batches of 4 using 2 worker processes.
ds = SingleImageDataset(
    transform=transforms.ToTensor(),
    target_transform=transforms.ToTensor(),
)
dl = DataLoader(dataset=ds, batch_size=4, num_workers=2)

## Create model


In [38]:
from matplotlib import pyplot as plt
from torch import nn


# custom loss function

class PatchedMSELoss:

    def __init__(self, patch_coords):
        self.patch_coords = patch_coords

        self.loss = nn.MSELoss()

    def __call__(self, targets, outputs):
        x_min, y_min, h, w = self.patch_coords

        patch = torch.ones(1, 1, *outputs.shape[2:], dtype=torch.bool, device=outputs.device)
        self.patch = patch
        patch[0, x_min:x_min+h, y_min:y_min+w] = 0

        return self.loss(targets * patch, outputs * patch)

In [39]:
from vit.models import HourglassViT
import torch

patch_size = (16, 16)

device = torch.device('cuda')

model = HourglassViT(
        image_size=ds.image_size,
        patch_size=patch_size,
        num_classes=10,
        dim=1024,
        depth=6,
        heads=16,
        mlp_dim=2048,
        dropout=0.1,
        emb_dropout=0.1).float().to(device)

criterion = PatchedMSELoss([200, 400, 200, 200])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

## Train model

In [None]:
from tqdm import trange

model.train()

train_loss = []

epoch_progress = trange(100)
for epoch in epoch_progress:
    epoch_progress.set_description(f'Epoch {epoch}')
    running_loss = 0.
    running_count = 0

    for inputs, outputs in dl:
        optimizer.zero_grad()

        pred_outputs = model(inputs.to(device))
        loss = criterion(outputs.to(device), pred_outputs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_count += inputs.size(0)

        epoch_progress.set_postfix({
            'train loss': f'{running_loss / running_count:.3f}'
        })

    train_loss.append(running_loss / running_count)
    torch.save(model.state_dict(), f'checkpoints/model_checkpoint-{epoch}.pt')

Epoch 0:   0%|          | 0/100 [00:56<?, ?it/s, train loss=2147.261]

## Evaluation

In [None]:
plt.imshow(ds.image / 255)
plt.show()

In [None]:
test_dl = DataLoader(ds)
for test_input, test_image in test_dl:
    pred_image = model(test_input.to(device)).detach().cpu()

    plt.imshow(test_image[0, ...].permute((1, 2, 0)) / 255)
    plt.show()

    plt.imshow((criterion.patch * test_image)[0, ...].permute((1, 2, 0)) / 255)
    plt.show()

    plt.imshow(pred_image[0, ...].permute((1, 2, 0)) / 255)
    plt.show()

    break