In [35]:
import torch
import torchvision
import matplotlib.pyplot as plt
import imp
from perceiver_basic import model
from torch.utils.data import DataLoader

In [20]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [63]:
class FMPreprocess(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.shape = (1, 28, 28)
        self.encoding = self._generate_encoding()

    def _generate_encoding(self):
        dims = len(self.shape)-1 #minus batch dim
        encoding = torch.zeros([*self.shape, dims])
        for i in range(encoding[0].shape[0]):
            for j in range(encoding[0].shape[1]):
                encoding[0,i,j][0] = i * 2.0 / encoding[0].shape[0] - 1.0
                encoding[0,i,j][1] = i * 2.0 / encoding[0].shape[0] - 1.0
        return encoding

    def forward(self, in_val):
        x = in_val.squeeze().clone()
        x = x.unsqueeze(-1)
        if in_val.shape[0] == 1:
            x = x.unsqueeze(0)
        batch_count = in_val.shape[0]
        dims = [-1 for _ in range(len(self.encoding.shape))]
        dims[0] = batch_count
        encoding = self.encoding.expand(*dims)
        x = torch.cat([x, encoding], dim=-1) #concatenating positional encodings
        #x = torch.mul(x, encoding) #scaled positional encooding

        #maintain batch and pixel codes, pixels will be in sequential form
        x = torch.flatten(x, start_dim=1, end_dim=-2)

        return x

In [48]:
imp.reload(model)
class PerceiverClassify(torch.nn.Module):
    def __init__(self, latent_dim, heads, wide_factor, latent_count, repeat_count=1, p_dropout=0.1):
        super().__init__()
        self.preprocess = FMPreprocess()
        out_dim = (1,32)
        in_channels = len(self.preprocess.shape) #subtract one for batch dim, but add one for the pixel dim
        self.perceiver = model.PerceiverInternal(in_channels, latent_dim, out_dim, heads, wide_factor, latent_count, repeat_count, p_dropout)
        
        self.lin_out = torch.nn.Linear(32, 10)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        #x = x / self.range
        x = self.preprocess(x)

        x = self.perceiver(x)

        # x.shape <- (batch_count, 32)
        x = x.squeeze()
        x = self.lin_out(x)
        x = self.sigmoid(x) * 2.0 - 1.0

        return x

In [37]:
BATCH_SIZE = 16
EPOCHS = 10

In [38]:
transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
dataset = torchvision.datasets.FashionMNIST('./fm_set', transform=transforms, download=True)
testset = torchvision.datasets.FashionMNIST('./fm_set', train=False, transform=transforms, download=True)

In [39]:
dataloader  = DataLoader(dataset, num_workers=2, batch_size=BATCH_SIZE)

In [64]:
perceiver = PerceiverClassify(latent_dim=(8, 8), heads=8, wide_factor=4, latent_count=6)
optimizer = torch.optim.Adam(perceiver.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

In [65]:
garbo = torch.randn([1, 28, 28])
perceiver(garbo).shape

torch.Size([10])

In [66]:
for epoch in range(EPOCHS):
    print(f"Entering epoch {epoch}")
    for batch_no, (features, labels) in enumerate(dataloader):
        features = features.to(device)
        labels = labels.to(device)

        predictions = perceiver(features)
        
        loss = criterion(predictions, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Iteration - {batch_no}')

Entering epoch 0
Iteration - 0
Iteration - 1
Iteration - 2
Iteration - 3
Iteration - 4
Iteration - 5
Iteration - 6
Iteration - 7
Iteration - 8
Iteration - 9
Iteration - 10
Iteration - 11
Iteration - 12
Iteration - 13
Iteration - 14
Iteration - 15
Iteration - 16
Iteration - 17
Iteration - 18
Iteration - 19
Iteration - 20
Iteration - 21
Iteration - 22
Iteration - 23
Iteration - 24
Iteration - 25
Iteration - 26
Iteration - 27
Iteration - 28
Iteration - 29
Iteration - 30
Iteration - 31
Iteration - 32
Iteration - 33
Iteration - 34
Iteration - 35
Iteration - 36
Iteration - 37
Iteration - 38
Iteration - 39
Iteration - 40
Iteration - 41
Iteration - 42
Iteration - 43
Iteration - 44
Iteration - 45
Iteration - 46
Iteration - 47
Iteration - 48
Iteration - 49
Iteration - 50
Iteration - 51
Iteration - 52
Iteration - 53
Iteration - 54
Iteration - 55
Iteration - 56
Iteration - 57
Iteration - 58
Iteration - 59
Iteration - 60
Iteration - 61
Iteration - 62
Iteration - 63
Iteration - 64
Iteration - 65
Ite

KeyboardInterrupt: 

In [69]:
testloader = DataLoader(testset, batch_size=1)
for i, (feature, label) in enumerate(testloader):
    print('Wow check this out lul')
    print(label)
    prediction = perceiver(feature)
    print(torch.argmax(prediction))

Wow check this out lul
tensor([9])
tensor(4)
Wow check this out lul
tensor([2])
tensor(4)
Wow check this out lul
tensor([1])
tensor(1)
Wow check this out lul
tensor([1])
tensor(1)
Wow check this out lul
tensor([6])
tensor(3)
Wow check this out lul
tensor([1])
tensor(1)
Wow check this out lul
tensor([4])
tensor(1)
Wow check this out lul
tensor([6])
tensor(1)
Wow check this out lul
tensor([5])
tensor(1)
Wow check this out lul
tensor([7])
tensor(1)
Wow check this out lul
tensor([4])
tensor(4)
Wow check this out lul
tensor([5])
tensor(3)
Wow check this out lul
tensor([7])
tensor(4)
Wow check this out lul
tensor([3])
tensor(3)
Wow check this out lul
tensor([4])
tensor(4)
Wow check this out lul
tensor([1])
tensor(1)
Wow check this out lul
tensor([2])
tensor(3)
Wow check this out lul
tensor([4])
tensor(4)
Wow check this out lul
tensor([8])
tensor(4)
Wow check this out lul
tensor([0])
tensor(4)
Wow check this out lul
tensor([2])
tensor(4)
Wow check this out lul
tensor([5])
tensor(3)
Wow check 

KeyboardInterrupt: 