In [None]:
import pytorch_lightning as pl
from pl_bolts.models.vision import ImageGPT
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
import torch
import matplotlib.pyplot as plt

In [None]:
# Fashion-MNIST train/test splits, downloaded into ./data on first use and
# converted to tensors by ToTensor().
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=transforms.ToTensor())

# The plotting cells further down hard-code 16 images per batch.
BATCH_SIZE = 16
# os.cpu_count() can return None when the count is undetermined; fall back to
# 0 workers, which makes the DataLoader load data in the main process.
NUM_WORKERS = (os.cpu_count() or 0) // 2

# Shuffle the training data so each epoch sees the samples in a new order.
# The test loader keeps dataset order so its first batch is always the same.
train_loader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [None]:
# Train an ImageGPT model (default constructor arguments) for 5 epochs on the
# training loader, then persist its weights.

# torch.save does not create missing parent directories. Create the output
# directory up front so a missing folder cannot throw away a finished run.
os.makedirs("./saved_models", exist_ok=True)

trainer = pl.Trainer(max_epochs=5)
model = ImageGPT()
trainer.fit(model, train_dataloaders=train_loader)
torch.save(model.state_dict(), "./saved_models/bolt_imagegpt.pth")

In [None]:
# Rebuild the model and restore the weights written by the training cell, so
# the inference cells can run without retraining.
model = ImageGPT()
# Switch to inference mode before the prediction cells use the model.
model.eval()
# map_location="cpu" lets the checkpoint load on a CPU-only machine even if
# the weights were saved from an accelerator.
model.load_state_dict(torch.load("./saved_models/bolt_imagegpt.pth", map_location="cpu"))

In [None]:
# Pull the first batch from the test loader; element 0 holds the images.
first_batch = next(iter(test_loader))
images = first_batch[0]

# Report the batch shape and the shape of one single-channel image plane.
print(images.shape)
print(images[0, 0].shape)

# Forward pass on the full images, without tracking gradients.
# NOTE(review): ToTensor() produces float pixel values; confirm that
# ImageGPT's forward is meant to receive them in this form.
with torch.no_grad():
    pred = model(images)

In [None]:
# Keep only rows 0-13 of every image in the batch and run the model on that
# partial input.
top_half = first_batch[0][:, :, 0:14, :]
print(top_half.size())

with torch.no_grad():
    # Stored under its own name on purpose: the final plotting cell reshapes
    # `pred` to 28x28, so `pred` must keep the full-image prediction.
    # NOTE(review): assumes the model output length follows the input length,
    # which would make this result too short for that reshape - confirm.
    pred_top_half = model(top_half)

In [None]:
# Show the first 16 images of the batch (channel 0) in a 2x8 grid without
# axis ticks.
images = first_batch[0]
fig, axes = plt.subplots(2, 8, figsize=(20, 5), subplot_kw={"xticks": [], "yticks": []})
for idx, axis in enumerate(axes.flat):
    axis.imshow(images[idx, 0].numpy(), cmap='gray')

In [None]:
# One figure per index along axis 1 of `pred` (16 figures). Each figure shows,
# in a 2x8 grid, the 16 slices along the last axis, reshaped to 28x28.
# This requires `pred` to hold 28*28 = 784 values along axis 0.
with torch.no_grad():
    for img_idx in range(16):
        fig = plt.figure(figsize=(20, 5))
        per_image = pred[:, img_idx]
        for channel in range(16):
            panel = fig.add_subplot(2, 8, channel + 1, xticks=[], yticks=[])
            grid = per_image[:, channel].unflatten(0, (28, 28))
            panel.imshow(grid.numpy(), cmap='gray')