- eval:
    - test-time and train-time feed-forward
    - channel mask and recon
    - patch_embedding

In [1]:
from GeospatialFM.models.channel_vit import ChannelViTEncoder
from GeospatialFM.models.vision_transformer import ViTEncoder, ViTDecoder
import torch
from GeospatialFM.models.mae import CrossModalMAEViT
from GeospatialFM.loss.mae_loss import SpectralInterpolationLoss

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
optical_sample = torch.randn(2, 13, 224, 224)
radar_sample = torch.randn(2, 2, 224, 224)

In [23]:
optical_encoder = ChannelViTEncoder(in_chans=13)
decoder = ViTDecoder(out_chans=15)
radar_encoder = ViTEncoder(in_chans=2)

In [19]:
optical_encoder.num_patches

196

In [28]:
mae = CrossModalMAEViT(optical_encoder, radar_encoder, decoder, decoder).train()


In [30]:
out = mae(optical_sample, radar_sample, 0.75, 11/12)

tensor([[[[ 1.5124e+00, -4.4097e-01, -3.9417e-01,  ...,  1.2068e-01,
            7.9087e-01, -7.2115e-01],
          [-5.2588e-02, -3.8726e-01,  1.8179e-01,  ...,  3.6736e-01,
            9.6093e-01, -1.0706e-01],
          [ 4.0375e-01,  3.8697e-01, -1.4966e+00,  ...,  1.4193e-01,
           -9.0827e-01,  5.1122e-01],
          ...,
          [-1.3595e+00, -9.9193e-01, -1.1967e+00,  ...,  7.8559e-01,
            3.7458e-01,  4.3930e-01],
          [ 2.4402e-01,  2.1408e-01, -9.8310e-01,  ...,  5.6128e-01,
            6.3088e-01,  8.9061e-01],
          [ 2.0977e-01, -8.5673e-01, -9.0036e-01,  ...,  1.3263e-01,
           -1.1322e+00,  1.3727e+00]]],


        [[[ 1.2815e+00, -4.1533e-01, -2.7620e-01,  ...,  2.7691e+00,
            1.2333e+00, -1.4539e-01],
          [-1.0320e+00, -1.4013e+00,  1.3554e+00,  ...,  8.9287e-01,
           -8.3992e-02, -6.8253e-01],
          [ 6.3759e-01, -9.6324e-01,  8.3419e-01,  ...,  2.0086e+00,
            5.8619e-01,  4.2173e-01],
          ...,
   

In [40]:
optical_sample[:, 1]

tensor([[[-1.2231e+00,  2.6847e-01,  1.5599e+00,  ..., -4.2973e-01,
          -4.7082e-02, -5.9798e-01],
         [ 2.2316e-01, -1.9303e+00, -1.0594e+00,  ..., -2.9924e-02,
          -8.9992e-02, -7.4097e-01],
         [-3.6767e-01,  9.3067e-01,  8.7845e-01,  ...,  2.3672e-01,
          -5.8664e-01, -1.8526e-02],
         ...,
         [ 1.0902e-01,  5.7744e-01, -1.5360e+00,  ...,  1.7027e+00,
          -6.1190e-01, -1.3782e+00],
         [ 1.7373e+00,  1.2053e-01, -1.4495e-01,  ..., -1.9491e+00,
           3.8866e-01,  9.9199e-01],
         [-1.1468e+00,  1.2461e+00,  1.0754e+00,  ...,  1.4846e+00,
          -3.6569e-01, -7.6072e-01]],

        [[ 1.2815e+00, -4.1533e-01, -2.7620e-01,  ...,  2.7691e+00,
           1.2333e+00, -1.4539e-01],
         [-1.0320e+00, -1.4013e+00,  1.3554e+00,  ...,  8.9287e-01,
          -8.3992e-02, -6.8253e-01],
         [ 6.3759e-01, -9.6324e-01,  8.3419e-01,  ...,  2.0086e+00,
           5.8619e-01,  4.2173e-01],
         ...,
         [ 1.0543e+00,  3

In [39]:
out['optical_channel_mask'][1].reshape(256, 13)[0]

tensor([1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [171]:
loss = SpectralInterpolationLoss()
loss(**out, output_dict=True)

{'optical_channel_mse': tensor(0.8883, grad_fn=<MulBackward0>),
 'radar_channel_mse': tensor(0.8874, grad_fn=<MulBackward0>)}

In [160]:
for key in out.keys():
    print(key, out[key].shape)

optical_mask torch.Size([2, 196])
radar_mask torch.Size([2, 196])
optical_recon torch.Size([2, 196, 3840])
radar_recon torch.Size([2, 196, 3840])
optical_target torch.Size([2, 196, 3328])
radar_target torch.Size([2, 196, 512])
optical_cls_token torch.Size([2, 768])
radar_cls_token torch.Size([2, 768])
logit_scale torch.Size([])
optical_channel_mask torch.Size([2, 3328])


In [86]:
3840 / 15

256.0

In [112]:
optical_target = out['optical_target']
optical_recon = out['optical_recon']
radar_recon = out['radar_recon']
B, L, D = optical_target.shape

In [165]:
def _forward_channel_mse_one_modal(recon, target, mask):
    B, L, D = target.shape
    recon = recon[:, :, :D]
    loss = (recon - target).abs()
    loss = loss.mean(dim=1)
    if mask.sum() == 0:
        return loss.mean() # if no mask, mean loss on all channels
    loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
    return loss

In [162]:
optical_recon.shape, optical_target.shape

(torch.Size([2, 196, 3840]), torch.Size([2, 196, 3328]))

In [163]:
mask = out['optical_channel_mask']

In [166]:
_unimodal_channel_loss(optical_recon, optical_target, mask)

torch.Size([2, 3328]) tensor(0.8925, grad_fn=<DivBackward0>)


In [None]:
optical_loss = optical_recon[:, :, :D]
radar_recon_o = radar_recon[:, :, :D]

In [None]:
expanded_mask = mask.unsqueeze(1).unsqueeze(1).expand(B, L, 256, -1).reshape(B, L, D)


In [None]:
recon_channels = (expanded_mask * radar_recon_o)
recon_channels.shape

torch.Size([2, 196, 3328])