- eval:
    - test-time and train-time feed-forward
    - channel mask and recon
    - patch_embedding

In [169]:
from GeospatialFM.models.channel_vit import ChannelViTEncoder
from GeospatialFM.models.vision_transformer import ViTEncoder, ViTDecoder
import torch
from GeospatialFM.models.mae import CrossModalMAEViT
from GeospatialFM.loss.mae_loss import SpectralInterpolationLoss

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [61]:
optical_sample = torch.randn(2, 13, 224, 224)
radar_sample = torch.randn(2, 2, 224, 224)

In [155]:
optical_encoder = ChannelViTEncoder(in_chans=13)
decoder = ViTDecoder(out_chans=15)
radar_encoder = ViTEncoder(in_chans=2)

In [156]:
optical_encoder.num_patches

196

In [157]:
mae = CrossModalMAEViT(optical_encoder, radar_encoder, decoder, decoder).train()


In [158]:
out = mae(optical_sample, radar_sample)

In [171]:
loss = SpectralInterpolationLoss()
loss(**out, output_dict=True)

{'optical_channel_mse': tensor(0.8883, grad_fn=<MulBackward0>),
 'radar_channel_mse': tensor(0.8874, grad_fn=<MulBackward0>)}

In [160]:
for key in out.keys():
    print(key, out[key].shape)

optical_mask torch.Size([2, 196])
radar_mask torch.Size([2, 196])
optical_recon torch.Size([2, 196, 3840])
radar_recon torch.Size([2, 196, 3840])
optical_target torch.Size([2, 196, 3328])
radar_target torch.Size([2, 196, 512])
optical_cls_token torch.Size([2, 768])
radar_cls_token torch.Size([2, 768])
logit_scale torch.Size([])
optical_channel_mask torch.Size([2, 3328])


In [86]:
3840 / 15

256.0

In [112]:
optical_target = out['optical_target']
optical_recon = out['optical_recon']
radar_recon = out['radar_recon']
B, L, D = optical_target.shape

In [165]:
def _forward_channel_mse_one_modal(recon, target, mask):
    B, L, D = target.shape
    recon = recon[:, :, :D]
    loss = (recon - target).abs()
    loss = loss.mean(dim=1)
    if mask.sum() == 0:
        return loss.mean() # if no mask, mean loss on all channels
    loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
    return loss

In [162]:
optical_recon.shape, optical_target.shape

(torch.Size([2, 196, 3840]), torch.Size([2, 196, 3328]))

In [163]:
mask = out['optical_channel_mask']

In [166]:
_unimodal_channel_loss(optical_recon, optical_target, mask)

torch.Size([2, 3328]) tensor(0.8925, grad_fn=<DivBackward0>)


In [None]:
optical_loss = optical_recon[:, :, :D]
radar_recon_o = radar_recon[:, :, :D]

In [None]:
expanded_mask = mask.unsqueeze(1).unsqueeze(1).expand(B, L, 256, -1).reshape(B, L, D)


In [None]:
recon_channels = (expanded_mask * radar_recon_o)
recon_channels.shape

torch.Size([2, 196, 3328])