- eval:
    - test-time and train-time feed-forward
    - channel mask and recon
    - patch_embedding

In [1]:
from GeospatialFM.models.channel_vit import ChannelViTEncoder
from GeospatialFM.models.vision_transformer import ViTEncoder, ViTDecoder
import torch
from GeospatialFM.models.mae import CrossModalMAEViT
from GeospatialFM.loss.mae_loss import SpectralInterpolationLoss, MAELoss

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
optical_sample = torch.randn(2, 13, 224, 224)
radar_sample = torch.randn(2, 2, 224, 224)

In [3]:
optical_encoder = ChannelViTEncoder(in_chans=13, channel_pool='max')
decoder = ViTDecoder(out_chans=15)
radar_encoder = ViTEncoder(in_chans=2)

In [4]:
output = optical_encoder.forward_encoder(optical_sample, 0.75, 0.75)

In [6]:
output[0].shape

torch.Size([100, 50, 768])

In [5]:
ret = decoder.forward_decoder(output[0], output[2])

In [6]:
optical_encoder.num_patches / 13

15.076923076923077

In [7]:
mae = CrossModalMAEViT(optical_encoder, radar_encoder, decoder, decoder).train()

In [8]:
out = mae(optical_sample, radar_sample, 0.75, 0.5)

In [9]:
out.keys()

dict_keys(['optical_mask', 'radar_mask', 'optical_recon', 'radar_recon', 'optical_target', 'radar_target', 'optical_cls_token', 'radar_cls_token', 'logit_scale', 'optical_channel_mask'])

In [10]:
out['optical_recon'].shape, out['optical_target'].shape

(torch.Size([2, 196, 3840]), torch.Size([2, 196, 3328]))

In [11]:
optical_sample[:, 1]

tensor([[[ 1.1977e+00, -5.5180e-01, -2.4375e-01,  ...,  6.4409e-01,
           7.9837e-01, -5.1097e-01],
         [ 1.7075e+00,  6.0347e-02, -8.4704e-03,  ..., -1.2083e+00,
           4.8129e-02,  3.1424e-01],
         [ 1.8514e-01, -1.2083e+00,  8.6367e-01,  ..., -1.3438e-01,
          -1.0370e-01,  7.4731e-01],
         ...,
         [-4.1498e-01, -2.0958e+00,  1.4300e-04,  ...,  1.3914e+00,
          -6.9633e-01, -7.6774e-01],
         [ 8.8051e-01,  2.4801e-01,  8.6115e-01,  ...,  1.2676e+00,
          -2.8543e-02,  8.5104e-01],
         [-6.4284e-01,  5.7437e-01, -1.4108e+00,  ..., -5.7488e-01,
          -1.0415e+00,  1.1898e+00]],

        [[-2.7266e-01, -5.6392e-01,  5.7947e-01,  ...,  3.4807e-01,
           5.6105e-02,  2.2259e+00],
         [-3.9405e-01,  1.3019e+00, -6.7917e-01,  ...,  1.1674e+00,
          -1.5613e+00, -1.2703e+00],
         [ 3.2976e-01, -1.6983e+00, -2.2353e-01,  ..., -2.8816e-01,
           3.7322e-01,  8.0427e-02],
         ...,
         [ 8.8603e-01, -9

In [12]:
out['optical_channel_mask'].shape, out['optical_mask'].shape

(torch.Size([2, 3328]), torch.Size([2, 196]))

In [13]:
out['optical_channel_mask']

tensor([[1., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 1., 0., 1.]])

In [15]:
out['optical_channel_mask'][0].reshape(256, 13)[0]

tensor([1., 0., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 0.])

In [16]:
loss = SpectralInterpolationLoss()
loss(**out, output_dict=True)

torch.Size([2, 196, 3328])
torch.Size([2, 196, 3328])
torch.Size([2, 3328])
tensor(3186.9939, grad_fn=<SumBackward0>) tensor(3584.)
torch.Size([2, 196, 3328])
torch.Size([2, 196, 3328])
torch.Size([2, 3328])
tensor(3181.5608, grad_fn=<SumBackward0>) tensor(3584.)


{'optical_channel_mse': tensor(0.8892, grad_fn=<MulBackward0>),
 'radar_channel_mse': tensor(0.8877, grad_fn=<MulBackward0>)}

In [17]:
loss = MAELoss()
loss(**out, output_dict=True)

torch.Size([2, 196, 3840])
torch.Size([2, 196])
tensor(261.1608, grad_fn=<SumBackward0>) tensor(294.)
torch.Size([2, 196, 3840])
torch.Size([2, 196])
tensor(260.6951, grad_fn=<SumBackward0>) tensor(294.)


{'optical_mse': tensor(0.8883, grad_fn=<MulBackward0>),
 'radar_mse': tensor(0.8867, grad_fn=<MulBackward0>)}

In [160]:
for key in out.keys():
    print(key, out[key].shape)

optical_mask torch.Size([2, 196])
radar_mask torch.Size([2, 196])
optical_recon torch.Size([2, 196, 3840])
radar_recon torch.Size([2, 196, 3840])
optical_target torch.Size([2, 196, 3328])
radar_target torch.Size([2, 196, 512])
optical_cls_token torch.Size([2, 768])
radar_cls_token torch.Size([2, 768])
logit_scale torch.Size([])
optical_channel_mask torch.Size([2, 3328])


In [86]:
3840 / 15

256.0

In [112]:
optical_target = out['optical_target']
optical_recon = out['optical_recon']
radar_recon = out['radar_recon']
B, L, D = optical_target.shape

In [165]:
def _forward_channel_mse_one_modal(recon, target, mask):
    B, L, D = target.shape
    recon = recon[:, :, :D]
    loss = (recon - target).abs()
    loss = loss.mean(dim=1)
    if mask.sum() == 0:
        return loss.mean() # if no mask, mean loss on all channels
    loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
    return loss

In [162]:
optical_recon.shape, optical_target.shape

(torch.Size([2, 196, 3840]), torch.Size([2, 196, 3328]))

In [163]:
mask = out['optical_channel_mask']

In [166]:
_unimodal_channel_loss(optical_recon, optical_target, mask)

torch.Size([2, 3328]) tensor(0.8925, grad_fn=<DivBackward0>)


In [None]:
optical_loss = optical_recon[:, :, :D]
radar_recon_o = radar_recon[:, :, :D]

In [None]:
expanded_mask = mask.unsqueeze(1).unsqueeze(1).expand(B, L, 256, -1).reshape(B, L, D)


In [None]:
recon_channels = (expanded_mask * radar_recon_o)
recon_channels.shape

torch.Size([2, 196, 3328])