# Masked Auto Encoder E2E

This notebook walks step by step through the masked autoencoder (MAE). It also includes setup for PyTorch Lightning training and some CPU profiling.

In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's `model` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Put the parent directory on the import path so the `model` package imported
# below resolves when the notebook is run from its own folder.
import sys  
sys.path.insert(1, '../')

In [1]:
import os
import numpy as np
import webdataset as wds
from dotenv import load_dotenv

import torch
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F
from pytorch_lightning import Trainer

from model.patch import patch_images
from model.aer_mae import AerMae, MaeLoss
from model.aer_mae_bolt import AerMaeBolt

In [None]:
# Load variables from a local .env file into the environment (DATA_ROOT is read below).
load_dotenv()

## Load Dataset

In [None]:
# Dataset root directory, taken from the DATA_ROOT environment variable.
root = os.getenv('DATA_ROOT')
if not root:
    # Fail here with a clear message; otherwise the shard URL built from `root`
    # would silently become 'file:None/dataset/...'.
    raise RuntimeError('DATA_ROOT is not set; export it or add it to the .env file')

In [None]:
# Shard URL for the training set. The doubled braces emit a literal `{000000..000015}`
# brace range, i.e. shards train-000000.tar through train-000015.tar.
train_loc = f'file:{root}/dataset/train-{{000000..000015}}.tar'

In [2]:
def to_tensor(sample):
    """Convert one decoded ``(image, metadata)`` sample into an ``(image, location)`` tensor pair.

    The image is converted to grayscale and then to a tensor with torchvision's
    functional transforms. The location is the first vertex of the first ring of
    the first feature's geometry in the metadata, min-max scaled from the degree
    range (-180, 180):

        (x - (-180)) / (180 - (-180)) == (x + 180) / 360

    Args:
        sample: Two-item sequence ``(img, meta)``. ``meta`` must provide
            ``meta['features'][0]['geometry']['coordinates'][0][0]``.

    Returns:
        Tuple ``(img, loc)``: the grayscale image tensor and a float32 tensor
        holding the scaled vertex.
    """
    image, metadata = sample
    pixels = F.to_tensor(F.to_grayscale(image))

    # First vertex only, not the whole polygon: presumably a (lon, lat) pair
    # following GeoJSON ordering -- TODO confirm against the dataset writer.
    vertex = metadata['features'][0]['geometry']['coordinates'][0][0]
    # Every component is scaled by the same (-180, 180) range, so a latitude
    # component (if present) would land in [0.25, 0.75] rather than [0, 1].
    scaled = (np.array(vertex) + 180) / 360

    return pixels, torch.from_numpy(scaled).float()

In [6]:
# Streaming training set, built stage by stage:
#   1. open the shards (with shard-level shuffling),
#   2. shuffle samples through a 1000-sample buffer and decode images to PIL,
#   3. keep the ("jpg", "json") fields as a tuple and convert them with `to_tensor`.
ds_train = wds.WebDataset(train_loc, shardshuffle=True)
ds_train = ds_train.shuffle(1000).decode("pil")
ds_train = ds_train.to_tuple("jpg", "json").map(to_tensor)

# Batches of four samples for the walkthrough and the Lightning run.
train_loader = DataLoader(dataset=ds_train, batch_size=4)

## Sample Data

In [23]:
# Pull a single batch from the loader; the trailing expression displays both tensor shapes.
images, loc = next(iter(train_loader))
images.shape, loc.shape

(torch.Size([4, 1, 256, 256]), torch.Size([4, 2]))

## Create MAE Model

In [16]:
# Hyperparameters passed to AerMae below.
model_dim = 32   # used for both enc_dim and dec_dim
geo_dim = 8      # passed as geo_dim; the positional-encoding tables shown below have last dim 24 (= 32 - 8)
patch_size = 16  # also used by patch_images when building the loss target

In [17]:
# Build the model (encoder and decoder share the same width) and its loss.
mae = AerMae(enc_dim=model_dim, dec_dim=model_dim, geo_dim=geo_dim, patch_size=patch_size)
loss_fn = MaeLoss()

## Walk through MAE steps

In [18]:
# Shapes of the `pe` tensors held by the encoder's and decoder's positional encoders.
mae.encoder.pos_encoder.pe.shape, mae.decoder.pos_encoder.pe.shape

(torch.Size([1, 256, 24]), torch.Size([1, 256, 24]))

### 1. Encoding

Apply embedding

In [None]:
# Embed the image batch, then run the positional encoder (called without `loc` here,
# unlike the decoding step); the last line displays the resulting shape.
# NOTE(review): an earlier cell reads `mae.encoder.pos_encoder` -- confirm that
# `mae.pos_encoder` is the intended attribute on the current model.
src = mae.src_embed(images)
src = mae.pos_encoder(src)
src.shape

torch.Size([4, 256, 32])

Apply masking

In [None]:
# Draw random masks from the first two dims of `src`. `masks` is multiplied into
# `src` below; `pad_masks` is handed to the encoder as its key-padding mask.
masks, pad_masks = mae.rand_mask(src.shape[:2])
masks.shape, pad_masks.shape

(torch.Size([4, 256, 1]), torch.Size([4, 256]))

In [None]:
# Apply the mask; its trailing singleton dim (see the shapes above) broadcasts across the feature dim.
src = src * masks
src.shape

torch.Size([4, 256, 32])

Encode

In [None]:
# Encode the masked sequence, supplying `pad_masks` as the key-padding mask.
mem = mae.encoder(src, src_key_padding_mask=pad_masks)
mem.shape

torch.Size([4, 256, 32])

### 2. Decoding

Apply Embedding

In [None]:
# Embed the encoder output for the decoder, then run the positional encoder -- this time passing `loc`.
tgt = mae.tgt_embed(mem)
tgt = mae.pos_encoder(tgt, loc)
tgt.shape

torch.Size([4, 256, 32])

Decode

In [None]:
# Decode the target sequence against the encoder memory.
out = mae.decoder(tgt, mem)
out.shape

torch.Size([4, 256, 32])

Project full image

In [None]:
# Final projection. Per the displayed shape the last dim becomes 256, which equals
# patch_size ** 2 -- presumably one flattened patch per token; confirm in the model.
out = mae.fc(out)
out.shape

torch.Size([4, 256, 256])

## Model Forward Pass

In [24]:
# Full forward pass. The first two items of the result are displayed by shape;
# the loss cell below uses them as the prediction (out[0]) and the mask (out[1]).
out = mae(images, loc)
out[0].shape, out[1].shape

(torch.Size([4, 256, 256]), torch.Size([4, 256, 1]))

Calculate loss

In [25]:
# Build the loss target by patching the input images with the model's patch size.
true = patch_images(images, patch_size)
true.shape

torch.Size([4, 256, 256])

In [26]:
# Loss between the patched target and the prediction, using the mask returned by the model.
loss = loss_fn(y_true=true, y_pred=out[0], mask=out[1])
loss

tensor(175.2708, grad_fn=<DivBackward0>)

## Profiling

In [None]:
# Profile one forward pass on CPU, recording operator input shapes and memory usage.
# `record_function` labels the whole pass as "model-forward" in the report; its
# context value is not needed, so it is not bound to a name.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=True) as prof:
    with record_function('model-forward'):
        mae(images, loc)

In [None]:
# Show the ten profiler entries with the highest total CPU time.
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=10))

--------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
--------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               model-forward         5.62%     276.015ms       100.00%        4.907s        4.907s           0 b      -6.12 Gb             1  
                               aten::dropout         0.12%       5.885ms        39.66%        1.946s      24.329ms       3.23 Gb           0 b            80  
                            aten::bernoulli_        30.07%        1.475s        30.21%        1.483s      18.534ms           0 b      -1.61 Gb            80  
                                aten::linear  

## Lightning Trainer

In [27]:
# Wrap the model in the project's AerMaeBolt module, which is what gets passed to the Trainer below.
l_mae = AerMaeBolt(mae)

Single Epoch Trainer

In [28]:
# Limit training to a single epoch.
trainer = Trainer(max_epochs=1)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Train on the training loader only (no validation loader is passed).
trainer.fit(l_mae, train_loader)

## Model Storage

Save and load

In [None]:
# Save only the model's state_dict (its weights), not the module object itself.
torch.save(mae.state_dict(), 'aermae.pth')

In [None]:
# 'aermae.pth' holds a state_dict (see the save cell), so restore it into the
# existing model. Assigning the result of torch.load to `mae` would replace the
# model with a plain dict of tensors.
mae.load_state_dict(torch.load('aermae.pth'))

## Cleanup

In [None]:
# Close the training dataset. The dataset in this notebook is named `ds_train`;
# `ds` is never defined and would raise NameError on a fresh kernel.
# NOTE(review): confirm the installed webdataset version exposes `close()`.
ds_train.close()