In [1]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

import torch
from torch import nn
from torch.nn import functional
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import MNIST
from torchvision import transforms

import lightning as pl

import warnings
warnings.filterwarnings("ignore")

In [18]:
from typing import Any


from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler


class PositionalEncoding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.1,
        max_len: int = 5000,
        batch_first: bool = True,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim_model, 2) * -(math.log(1e6) / dim_model)
        )

        if self.batch_first:
            pe = torch.zeros(1, max_len, dim_model)
            pe[0, :, 0::2] = torch.sin(position * div_term)
            pe[0, :, 1::2] = torch.cos(position * div_term)
        else:
            pe = torch.zeros(max_len, 1, dim_model)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)

        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.batch_first:
            x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        else:
            x = x + self.pe[: x.size(0)].requires_grad_(False)
        return self.dropout(x)


class Transformer(pl.LightningModule):
    
    def __init__(self,
                 input_shape: int,
                 d_model: int = 256, 
                 n_encoder_heads: int = 2, 
                 dim_encoder_feedforward: int = 2048,
                 n_encoder_layers: int = 4,
                 n_decoder_heads: int = 2, 
                 dim_decoder_feedforward: int = 2048,
                 n_decoder_layers: int = 4
                 ) -> None:
        super(Transformer, self).__init__()
        self.encoder_input_layer = nn.Linear(input_shape, d_model)
        self.pos_enc_layer = PositionalEncoding(d_model)
        
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_encoder_heads, dim_encoder_feedforward, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, n_encoder_layers)
        
        decoder_layer = nn.TransformerDecoderLayer(d_model, n_decoder_heads, dim_decoder_feedforward, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, n_decoder_layers)
        
    def encoder_forward(self, x: torch.Tensor):
        print(self.encoder_input_layer(x))
        pos_enc = self.pos_enc_layer(self.encoder_input_layer(x))
        return self.encoder(pos_enc)
    
    def decoder_forward(self, x: torch.Tensor):
        return self.decoder(x)
    
    def forward(self, x: torch.Tensor):
        encoded = self.encoder_forward(x)
        print(encoded)
        return self.decoder_forward(encoded), encoded
    
    def training_step(self, batch) -> STEP_OUTPUT:
        x = batch[0]
        print(x.shape)
        x_hat, _ = self.forward(x)
        return 0
        
    def configure_optimizers(self) -> OptimizerLRScheduler:
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer
        
        

In [9]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.flatten())])
train_dataset = MNIST(root="data", download=True, train=True, transform=transform)

BATCH_SIZE = 10
trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [19]:
model = Transformer(input_shape=784)

trainer = pl.Trainer(max_epochs=1, accelerator="cpu")
trainer.fit(model, trainloader)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name                | Type               | Params
-----------------------------------------------------------
0 | encoder_input_layer | Linear             | 200 K 
1 | pos_enc_layer       | PositionalEncoding | 0     
2 | encoder             | TransformerEncoder | 5.3 M 
3 | decoder             | TransformerDecoder | 6.3 M 
-----------------------------------------------------------
11.8 M    Trainable params
0         Non-trainable params
11.8 M    Total params
47.105    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

torch.Size([10, 784])
tensor([[ 0.3026,  0.0389,  0.1447,  ...,  0.3451, -0.0884, -0.0962],
        [ 0.2861, -0.2062,  0.0057,  ...,  0.4230,  0.1344, -0.0146],
        [ 0.3609, -0.1533,  0.0149,  ...,  0.3861,  0.0784, -0.1618],
        ...,
        [ 0.2914, -0.3885,  0.2410,  ...,  0.2855, -0.1360,  0.1645],
        [ 0.3813,  0.1413, -0.1474,  ...,  0.1727, -0.0688,  0.1978],
        [ 0.4350,  0.1103,  0.2334,  ..., -0.0290, -0.2093, -0.0628]],
       grad_fn=<AddmmBackward0>)


RuntimeError: The size of tensor a (10) must match the size of tensor b (256) at non-singleton dimension 1

In [9]:
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import pytorch_lightning as pl
import torch
import torchvision
from torch import nn

from lightly.models import utils
from lightly.models.modules import masked_autoencoder
from lightly.transforms.mae_transform import MAETransform


class MAE(pl.LightningModule):
    def __init__(self):
        super().__init__()

        decoder_dim = 512
        vit = torchvision.models.vit_l_16(pretrained=False)
        self.mask_ratio = 0.75
        self.patch_size = vit.patch_size
        self.sequence_length = vit.seq_length
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit)
        self.decoder = masked_autoencoder.MAEDecoder(
            seq_length=vit.seq_length,
            num_layers=1,
            num_heads=16,
            embed_input_dim=vit.hidden_dim,
            hidden_dim=decoder_dim,
            mlp_dim=decoder_dim * 4,
            out_dim=vit.patch_size**2 * 3,
            dropout=0,
            attention_dropout=0,
        )
        self.criterion = nn.MSELoss()

    def forward_encoder(self, images, idx_keep=None):
        return self.backbone.encode(images, idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        # build decoder input
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(
            self.mask_token, (batch_size, self.sequence_length)
        )
        print(x_masked)
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def training_step(self, batch, batch_idx):
        views = batch[0]
        images = views[0]  # views contains only a single view
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images, idx_keep)
        x_pred = self.forward_decoder(x_encoded, idx_keep, idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for missing class token
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)
        return loss

    def configure_optimizers(self):
        optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)
        return optim


model = MAE()

from torchvision import transforms

transform = transforms.Compose([transforms.Grayscale(3), MAETransform()])

# we ignore object detection annotations by setting target_transform to return 0
dataset = torchvision.datasets.MNIST(
    "data",
    download=True,
    transform=transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

accelerator = "gpu" if torch.cuda.is_available() else "mps"

trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type        | Params
------------------------------------------
0 | backbone  | MAEBackbone | 88.2 M
1 | decoder   | MAEDecoder  | 5.1 M 
2 | criterion | MSELoss     | 0     
------------------------------------------
93.4 M    Trainable params
0         Non-trainable params
93.4 M    Total params
373.494   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 