# Assessing data merging techniques

The CIMR retrieval is supposed to combine observations from different sensors. The underlying network must therefore be able to merge the branches that process the inputs. This notebook explores different approaches for doing this.

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt

We use a synthetic data set to test the ability of different NNs to merge information from different sources. The retrieval inputs are random fields whose spatial variability is limited to wavelengths from separate spectral bands. The retrieval output is simply the sum of these inputs. In addition to that, the input for each sensor is corrupted with random noise.

In [None]:
from cimr.data.training_data import SuperpositionDataset
from cimr.data.training_data import StreamData, sparse_collate
from torch.utils.data import DataLoader
# Training split: 1000 synthetic samples with 8 steps each.
# NOTE(review): the three availability values presumably apply per input
# source -- confirm against SuperpositionDataset.
training_data = SuperpositionDataset(
    size=128,
    n_samples=1000,
    n_steps=8,
    availability=[0.1, 1.0, 0.1],
)
# The dataset's init_rng hook is handed to the loader as worker_init_fn so
# that it is invoked inside every worker process.
training_loader = DataLoader(
    training_data,
    batch_size=2,
    shuffle=True,
    num_workers=8,
    worker_init_fn=training_data.init_rng,
)

# Validation split: 100 samples, created with a single availability of 1.0.
validation_data = SuperpositionDataset(
    size=128,
    n_samples=100,
    n_steps=8,
    availability=1.0,
)
validation_loader = DataLoader(
    validation_data,
    batch_size=4,
    shuffle=True,
    num_workers=8,
    worker_init_fn=validation_data.init_rng,
)

In [None]:
# Pull the first validation sample; the dataset item unpacks into the
# retrieval input (x) and the corresponding target (y).
x, y = validation_data[0]

In [None]:
from IPython.display import HTML
# Visualize the sample drawn above. plot_sample presumably returns a
# matplotlib animation (it exposes to_jshtml) -- confirm in the dataset class.
ani = validation_data.plot_sample(x, y)
# Render the animation inline as the cell output via its JS/HTML form.
HTML(ani.to_jshtml())

In [None]:
from cimr.data.training_data import SuperpositionDataset
from cimr.data.training_data import StreamData, sparse_collate
from cimr.models import CIMRNaive
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch import optim
from quantnn.qrnn import QRNN
from quantnn import metrics
    
def run_training(availability, aggregation):
    """
    Train a quantile regression network on the synthetic superposition data.

    A ``CIMRNaive`` model using the requested aggregation is wrapped in a
    ``QRNN`` and trained for five epochs with PyTorch Lightning. Metrics are
    logged through the tensorboard logger of the lightning module.

    Args:
        availability: Name of the availability scenario for the training
            data. ``"full"`` (case-insensitive) selects
            ``[1.0, 1.0, 1.0]``; any other string selects
            ``[0.1, 1.0, 0.1]``. The validation data always uses an
            availability of ``1.0``.
        aggregation: Aggregation method passed on to ``CIMRNaive``. It is
            also part of the run name, so the logged name matches the
            model that is actually trained.

    Returns:
        None. The function is called for its side effects (training and
        logging).
    """
    # The run name is built from the raw arguments before 'availability'
    # is replaced by its numeric form below.
    name = f"Aggregation: {aggregation}, availability: {availability}"
    if availability.lower() == "full":
        availability = [1.0, 1.0, 1.0]
    else:
        availability = [0.1, 1.0, 0.1]

    training_data = SuperpositionDataset(
        size=128,
        n_samples=1000,
        availability=availability,
        n_steps=1
    )
    training_loader = DataLoader(
        training_data,
        num_workers=8,
        batch_size=2,
        shuffle=True,
        worker_init_fn=training_data.init_rng,
    )
    validation_data = SuperpositionDataset(
        size=128,
        n_samples=100,
        availability=1.0,
        n_steps=1
    )
    validation_loader = DataLoader(
        validation_data,
        num_workers=8,
        batch_size=4,
        shuffle=True,
        worker_init_fn=validation_data.init_rng,
    )

    # Use the aggregation requested by the caller. It was previously
    # overwritten with "average" here, so every run trained the same model
    # regardless of the argument and the run name was misleading.
    model = CIMRNaive(4, 2, aggregation=aggregation, block_type="convnext")
    # 66 evenly spaced points on [0, 1] without the end points: 64 quantiles.
    quantiles = np.linspace(0, 1, 66)[1:-1]
    qrnn = QRNN(model=model, quantiles=quantiles)

    mtrx = [
        metrics.Bias(),
        metrics.Correlation(),
        metrics.MeanSquaredError(),
        metrics.ScatterPlot(),
        metrics.CalibrationPlot()
    ]

    # One cosine annealing period spans the whole training run, so the
    # scheduler period and the number of epochs share a single value.
    n_epochs = 5

    lm = qrnn.lightning(mask=-100, metrics=mtrx, name=name)
    optimizer = optim.Adam(model.parameters(), lr=0.0005)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    lm.optimizer = optimizer
    lm.scheduler = scheduler

    trainer = pl.Trainer(
        max_epochs=n_epochs,
        accelerator="gpu",
        devices=-1,
        precision=32,
        gradient_clip_val=1,
        logger=lm.tensorboard,
        replace_sampler_ddp=True,
    )
    trainer.fit(
        model=lm,
        train_dataloaders=training_loader,
        val_dataloaders=validation_loader
    )

In [None]:

# Launch one training run for the "full" availability scenario with the
# "linear" aggregation option.
run_training(availability="full", aggregation="linear")