# Satvision-TOA Reconstruction Notebook

Version: 04.30.24

Env: `Python [conda env:ilab-pytorch]`

In [None]:
!pip3 install yacs timm segmentation-models-pytorch termcolor webdataset==0.2.86

In [None]:
import os
import sys
import time
import random
import datetime
from tqdm import tqdm
import numpy as np
import logging

import torch
import torch.cuda.amp as amp

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import netCDF4 as nc

import warnings

warnings.filterwarnings('ignore') 

In [None]:
sys.path.append('./pytorch-caney')

from pytorch_caney.config import get_config

from pytorch_caney.models.build import build_model

from pytorch_caney.ptc_logging import create_logger

from pytorch_caney.data.datasets.mim_modis_22m_dataset import MODIS22MDataset

from pytorch_caney.data.transforms import SimmimTransform, SimmimMaskGenerator

from pytorch_caney.config import _C, _update_config_from_file

## 1. Configuration

### Clone model ckpt from huggingface

```bash
# On prism/explore
module load git-lfs

git lfs install

git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-huge-patch8-window12-192
```

Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth.
https://huggingface.co/docs/hub/security-git-ssh

```bash
eval $(ssh-agent)

# If this outputs as anon, follow the next steps.
ssh -T git@hf.co

# Check if ssh-agent is using the proper key
ssh-add -l

# If not
ssh-add ~/.ssh/your-key

# Or if you want to use the default id_* key, just do
ssh-add

```

In [None]:
# Directory that holds both the checkpoint and its YAML config.
_CKPT_DIR: str = '/explore/nobackup/people/szhang16/satvision-toa-huge-patch8-window8-128'

# Weights file passed to torch.load, and the YAML file merged into the config.
MODEL_PATH: str = f'{_CKPT_DIR}/mp_rank_00_model_states.pt'
CONFIG_PATH: str = f'{_CKPT_DIR}/mim_pretrain_swinv2_satvision_huge_128_window8_patch8_100ep.yaml'

# Output directory and run tag copied into the config below.
OUTPUT: str = '.'
TAG: str = 'satvision-huge-toa-reconstruction'

# Pre-extracted chip file; wrapped in a list because the config field
# (config.DATA.DATA_PATHS) takes a list of paths.
DATA_PATH: str = '/home/szhang16/modis-toa-samples_03_31/selected_satvision_toa_2m_128_chips.npy'
DATA_PATHS: list = [DATA_PATH]

In [None]:
# Update config given configurations
# Work on a clone of the package defaults (_C), then apply the YAML file
# at CONFIG_PATH on top of it.

config = _C.clone()
_update_config_from_file(config, CONFIG_PATH)

# Notebook-level overrides. The config is unlocked with defrost() for the
# assignments and locked again with freeze() once they are done.
# NOTE(review): _update_config_from_file presumably leaves the config frozen,
# which is why defrost() is needed here - confirm in pytorch_caney.config.
config.defrost()
config.MODEL.RESUME = MODEL_PATH
config.DATA.DATA_PATHS = DATA_PATHS
config.OUTPUT = OUTPUT
config.TAG = TAG
config.freeze()

In [None]:
# Message layout shared by the log file and the console handler
LOG_FORMAT = '%(asctime)s [%(levelname)s] %(message)s'

# Root logger: records at INFO and above go to app.log, with an explicit
# timestamp format.
logging.basicConfig(
    filename='app.log',
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Mirror the same records to the console. StreamHandler() with no argument
# writes to sys.stderr. No datefmt is given here, so console timestamps use
# the logging default rather than the file's format.
console = logging.StreamHandler()
console.setFormatter(logging.Formatter(LOG_FORMAT))
console.setLevel(logging.INFO)

# '' names the root logger, i.e. the one configured by basicConfig above.
logger = logging.getLogger('')
logger.addHandler(console)

## 2. Load model weights from checkpoint

In [None]:
# Deserialize the checkpoint onto the CPU. Without map_location, tensors are
# restored to the device they were saved from; for a checkpoint saved from a
# GPU that would keep a second full copy of the weights on the GPU next to
# the model moved there by model.cuda() below.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')

# Build the network described by the config in pre-training mode and copy
# the checkpoint weights into it.
model = build_model(config, pretrain=True)
model.load_state_dict(checkpoint['module'])  # If 'module' not working, try 'model'

# Report the number of trainable parameters (those with requires_grad set).
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"number of params: {n_parameters}")

# Move the model to the GPU and switch it to evaluation mode.
model.cuda()
model.eval()

## 3. Load evaluation set (from numpy file)

In [None]:
def find_files(directory, prefix):
    """
    List the entries of a directory whose names begin with a given prefix.

    Args:
        directory (str): Directory whose entries are examined (not recursive).
        prefix (str): Leading string an entry name must have to be selected.

    Returns:
        list: ``directory`` joined with each matching entry name, in the
            order ``os.listdir`` reports them.
    """
    return [
        os.path.join(directory, entry)
        for entry in os.listdir(directory)
        if entry.startswith(prefix)
    ]

In [None]:
def gather_files(YYYY, DDD, HH, root_path="/css/geostationary/BackStage/GOES-16-ABI-L1B-FULLD/"):
    """
    Index the ABI files of one hour directory by the minute in their names.

    The directory ``<root_path>/<YYYY>/<DDD>/<HH>`` is listed and every entry
    is filed under the two characters found at positions 36-37 of its name.
    Year, day and hour in the result are read back from fixed positions of
    the first file name seen (27-30, 31-33, 34-35), not from the arguments.

    Args:
        YYYY (str): year in YYYY format
        DDD (str): day of year in DDD format
        HH (str): hour in HH format
        root_path (str): base directory of the ABI archive; defaults to the
            location that used to be hard-coded.

    Returns:
        dict: A dictionary containing -
                ROOT_PATH: the hour directory that was listed
                    (None if the directory is empty)

                YYYY: year
                DDD: day of year
                HH: hour

                "00", "15", "30", "45": file names (not full paths) for that
                    minute; always present, possibly empty
                any other minute found in a file name: added on demand

                L200 ... L250: placeholders that this function never
                    fills in (always None)

    Raises:
        OSError: if the hour directory cannot be listed.
    """

    ABI_ = {
        "ROOT_PATH": None,

        "YYYY": None,
        "DDD": None,
        "HH": None,

        "00": [],
        "15": [],
        "30": [],
        "45": [],

        "L200": None,
        "L210": None,
        "L220": None,
        "L230": None,
        "L240": None,
        "L250": None,
    }

    _ABI_PATH_ = os.path.join(root_path, YYYY, DDD, HH)

    for filename in os.listdir(_ABI_PATH_):
        if ABI_["ROOT_PATH"] is None:
            ABI_["ROOT_PATH"] = _ABI_PATH_
            ABI_["YYYY"] = filename[27:31]
            ABI_["DDD"] = filename[31:34]
            ABI_["HH"] = filename[34:36]
        MM = filename[36:38]
        # A minute outside the four pre-seeded keys used to raise KeyError;
        # create its list on first use instead.
        ABI_.setdefault(MM, []).append(filename)

    return ABI_

In [None]:
def get_L1B_L2(abipaths, l2path, YYYY, DDD, HH):
    """
    Load the "Rad" variable of several ABI files and stack them by channel.

    Args:
        abipaths (list): file names (not full paths) inside the hour
            directory; characters 19-20 of each name are parsed as the
            integer channel number used for ordering.
        l2path: accepted for interface compatibility; not used.
        YYYY (str): year component of the hour directory.
        DDD (str): day-of-year component of the hour directory.
        HH (str): hour component of the hour directory.

    Returns:
        numpy.ndarray: the per-file arrays, ordered by ascending channel
            number and stacked along a new axis 2.

    Raises:
        ValueError: from ``np.stack`` when ``abipaths`` is empty or the
            arrays still differ in shape after subsampling.

    NOTE(review): nothing here checks that every expected channel file is
    present; a missing file simply yields fewer channels in the output.
    """
    ROOT = "/css/geostationary/BackStage/GOES-16-ABI-L1B-FULLD/"
    hour_dir = os.path.join(ROOT, YYYY, DDD, HH)

    # LOAD EACH ABI CHANNEL IMAGE
    CHANNELS = []
    for file in abipaths:
        # Copy the data out and close the file straight away; the handle was
        # previously left open for every channel.
        with nc.Dataset(os.path.join(hour_dir, file), 'r') as dataset:
            L1B = np.array(dataset["Rad"])
        CHANNELS.append((L1B, int(file[19:21])))

    # SORT CHANNELS
    CHANNELS.sort(key=lambda x: x[1])

    # RESIZE ALL CHANNELS TO SAME SIZE
    # Each array is subsampled by the integer ratio of its row count to 5424
    # (presumably the coarsest grid size - TODO confirm), taking every S-th
    # row and column. An array with fewer than 5424 rows gives S == 0, which
    # makes the slice raise ValueError.
    resized = []
    for L1B, _ in CHANNELS:
        S = L1B.shape[0] // 5424
        resized.append(L1B if S == 1 else L1B[::S, ::S])

    # STACK ABI CHANNELS INTO SINGLE IMAGE
    return np.stack(resized, axis=2)

In [None]:
# Index the ABI files of year 2017, day-of-year 218, hour 18 by minute.
DATA = gather_files("2017", "218", "18")

In [None]:
# Stack the minute-00 files into one multi-channel image.
# NOTE: DATA["L200"] is never filled in by gather_files (it stays None) and
# get_L1B_L2 does not use its l2path argument.
ABI = get_L1B_L2(DATA["00"], DATA["L200"], DATA["YYYY"], DATA["DDD"], DATA["HH"])

In [None]:
# Visual sanity check: display the first channel of the stacked image.
plt.imshow(ABI[..., 0])

In [None]:
# Draw 2 * N candidate chip centres, uniformly at random inside BOUNDS on
# both image axes.

# Candidates keep a 1000-pixel margin from either edge of a 5424-wide grid.
BOUNDS = (1000, 5424-1000)
N = 100

# Twice N candidates are drawn (presumably as headroom for chips that are
# rejected later - TODO confirm). The draws are not seeded.
lower, upper = BOUNDS
n_candidates = int(2 * N)
Xs = np.random.uniform(lower, upper, n_candidates).astype(int)
Ys = np.random.uniform(lower, upper, n_candidates).astype(int)
coordinate_pairs = [(cx, cy) for cx, cy in zip(Xs, Ys)]

In [None]:
def _vis_calibrate(data):
    """Calibrate visible channels to reflectance."""
    solar_irradiance = np.array(2017)
    esd = np.array(0.99)
    factor = np.pi * esd * esd / solar_irradiance

    return data * np.float32(factor) * 100
 
def _ir_calibrate(data):
    """Calibrate IR channels to BT."""
    fk1 = np.array(13432.1),
    fk2 = np.array(1497.61),
    bc1 = np.array(0.09102),
    bc2 = np.array(0.99971),

    # if self.clip_negative_radiances:
    #     min_rad = self._get_minimum_radiance(data)
    #     data = data.clip(min=data.dtype.type(min_rad))

    res = (fk2 / np.log(fk1 / data + 1) - bc1) / bc2
    return res

In [None]:
abiData = []

# translation: for each of the 14 output channels, the index of the ABI
# channel to take (the channel reordering).
# mask: True where the output channel goes through _vis_calibrate, False
# where it goes through _ir_calibrate.
translation = [1, 2, 0, 4, 5, 6, 3, 8, 9, 10, 11, 13, 14, 15]
mask = np.array([True, True, True, True, True, False, True, False, False, False, False, False, False, False])

# Maximum amount of chips (counted down as chips are accepted)
MAXCHIPS = 100

for (x, y) in coordinate_pairs:
    # 128 x 128 window centred on the candidate, all channels
    chip = ABI[x-64:x+64, y-64:y+64, :]

    if MAXCHIPS < 1:
        break

    # Reject the chip if any channel holds a "no value" marker; these are
    # taken to be 2**k - 1, checked here for k = 10 .. 15.
    if any(np.isin(2**exp - 1, chip) for exp in range(10, 16)):
        continue

    MAXCHIPS -= 1

    # Reorder the channels. List indexing returns a copy, so the in-place
    # calibration below does not write back into ABI.
    chip = chip[..., translation]

    # Convert the radiances to their respective units
    chip[..., mask] = _vis_calibrate(chip[..., mask])
    chip[..., ~mask] = _ir_calibrate(chip[..., ~mask])

    abiData.append(chip)

# Stack the accepted chips along a new leading axis
abiData = np.array(abiData)

In [None]:
# Inspect the result: one entry per accepted chip (at most 100), each chip
# being a 128 x 128 window with 14 channels.
abiData.shape

In [None]:
# Persist the calibrated chips as a single .npy file.
chips_path = '/explore/nobackup/people/szhang16/2017-218-18-abichips-100.npy'
with open(chips_path, mode='wb') as chips_file:
    np.save(chips_file, abiData)

In [None]:
# Masked-Image-Modeling transform. Applied to one chip it returns a pair:
# element 0 is stacked directly as a torch tensor (the image), element 1 is
# a numpy array (the mask generated inside the transform).
transform = SimmimTransform(config)

# Evaluate on the ABI chips built above.
# Alternative source (disabled): the single .npy file listed in the config,
#   validation_dataset = np.load(config.DATA.DATA_PATHS[0])
validation_dataset = abiData

# Index range over the chips (to evaluate a subset, slice validation_dataset
# before this line, e.g. validation_dataset[:5]).
len_batch = range(validation_dataset.shape[0])

# Run the transform on every chip, walking the leading axis.
imgMasks = [transform(chip) for chip in validation_dataset]

# Separate images and masks; masks are converted to torch tensors.
# NOTE: this rebinds the notebook-level name `mask`, which earlier held the
# boolean visible/IR channel selector.
img = torch.stack([pair[0] for pair in imgMasks])
mask = torch.stack([torch.from_numpy(pair[1]) for pair in imgMasks])

## 4. Prediction helper functions

In [None]:
def predict(model, dataloader, num_batches=5):
    """
    Run the model on the first ``num_batches`` batches of a dataloader.

    Args:
        model: network exposing ``encoder(img, mask)`` and ``decoder(z)``,
            and callable as ``model(img, mask)``.
        dataloader: iterable of batches; element 0 of each batch must be a
            sequence of (image tensor, mask tensor) pairs.
        num_batches (int): maximum number of batches to process.

    Returns:
        tuple: ``(inputs, outputs, masks, losses)`` - per-sample input
            tensors, per-sample reconstructions, per-sample masks (all moved
            to the CPU), and one ``model(img, mask)`` result per batch.

    Requires a CUDA device and reads the notebook-level ``config.ENABLE_AMP``
    to decide whether autocast is enabled.
    """
    inputs, outputs, masks, losses = [], [], [], []

    with tqdm(total=num_batches) as pbar:
        for idx, img_mask in enumerate(dataloader):
            # Stop after exactly num_batches batches. The previous test
            # (idx > num_batches) let one extra batch through and advanced
            # the progress bar past its total.
            if idx >= num_batches:
                break

            img_mask = img_mask[0]

            img = torch.stack([pair[0] for pair in img_mask])
            mask = torch.stack([pair[1] for pair in img_mask])

            img = img.cuda(non_blocking=True)
            mask = mask.cuda(non_blocking=True)

            with torch.no_grad():
                with amp.autocast(enabled=config.ENABLE_AMP):
                    # Reconstruction via the encoder/decoder halves.
                    z = model.encoder(img, mask)
                    img_recon = model.decoder(z)
                    # Separate full forward pass; its result is what gets
                    # recorded as the loss (presumably the reconstruction
                    # loss - confirm against the model).
                    loss = model(img, mask)

            inputs.extend(img.cpu())
            masks.extend(mask.cpu())
            outputs.extend(img_recon.cpu())
            losses.append(loss.cpu())

            pbar.update(1)

    return inputs, outputs, masks, losses


def minmax_norm(img_arr):
    """
    Rescale an array linearly to 0..255 and return it as uint8.

    The array minimum maps to 0 and the maximum to 255; fractional results
    are truncated by the uint8 cast.

    Args:
        img_arr (numpy.ndarray): array of any shape.

    Returns:
        numpy.ndarray: uint8 array with the same shape as ``img_arr``. A
            constant array (max == min) gives all zeros.

    Raises:
        ValueError: from ``min()`` / ``max()`` when ``img_arr`` is empty.
    """
    arr_min = img_arr.min()
    arr_max = img_arr.max()
    value_range = arr_max - arr_min

    # A constant image would otherwise divide 0 by 0 and cast NaN to uint8,
    # whose result is undefined.
    if value_range == 0:
        return np.zeros(img_arr.shape, dtype=np.uint8)

    img_arr_scaled = (img_arr - arr_min) / value_range
    img_arr_scaled = img_arr_scaled * 255
    return img_arr_scaled.astype(np.uint8)


def process_mask(mask, upscale=4):
    """
    Expand a mask tensor to a larger grid and replicate it to three channels.

    Every element is repeated ``upscale`` times along the first two axes, and
    the result is stacked three times along a new last axis.

    Args:
        mask (torch.Tensor): CPU mask tensor with at least two dimensions.
        upscale (int): repeat factor per mask element along each of the first
            two axes. Defaults to 4, the value that used to be hard-coded.

    Returns:
        numpy.ndarray: for a 2-D input of shape (H, W), an array of shape
            (H * upscale, W * upscale, 3).

    NOTE(review): process_prediction compares the result element-wise with
    the image, so H * upscale must equal the image side - confirm that 4 is
    right for the model configuration in use.
    """
    mask_img = (
        mask.repeat_interleave(upscale, 0)
        .repeat_interleave(upscale, 1)
        .contiguous()
    )
    return np.stack([mask_img, mask_img, mask_img], axis=-1)


def process_prediction(image, img_recon, mask, rgb_index):
    """
    Build three-band uint8 composites of a chip and its reconstruction.

    Args:
        image (torch.Tensor): original chip, indexed as [band, row, col].
        img_recon (torch.Tensor): reconstruction, indexed the same way.
        mask (torch.Tensor): mask passed through ``process_mask``.
        rgb_index (sequence): three band indices; the composite channels
            follow this order exactly. (The notebook describes the order as
            red, blue, green.)

    Returns:
        tuple:
            rgb_image: min-max normalised composite of the original,
            rgb_image_masked: the same with pixels set to 0 where mask == 1,
            rgb_recon_masked: the original where mask == 0, otherwise the
                normalised reconstruction,
            mask: the expanded three-channel mask.

    Each composite is normalised with its own minimum and maximum.
    """
    mask = process_mask(mask)

    band_order = (rgb_index[0], rgb_index[1], rgb_index[2])

    def _composite(tensor):
        # Pick the three bands, move them to the last axis, scale to uint8.
        arr = tensor.numpy()
        bands = [arr[band, :, :] for band in band_order]
        return minmax_norm(np.stack(bands, axis=-1))

    rgb_image = _composite(image)
    rgb_image_recon = _composite(img_recon)

    rgb_recon_masked = np.where(mask == 0, rgb_image, rgb_image_recon)
    rgb_image_masked = np.where(mask == 1, 0, rgb_image)

    return rgb_image, rgb_image_masked, rgb_recon_masked, mask

# One list per channel (14 channels); plot_export_pdf appends a per-chip MSE
# value to each of them.
olosses = [[] for _ in range(14)]

# Mean-squared-error criterion used for the per-channel comparison
lossMSE = torch.nn.MSELoss()

def plot_export_pdf(path, inputs, outputs, masks, rgb_index):
    """
    Accumulate per-channel MSE between each chip and its reconstruction.

    For every chip, the MSE of each of the first 14 channels is appended to
    the notebook-level ``olosses[channel]`` list. Calling this twice appends
    the values twice.

    The per-chip page rendering is currently disabled (kept below as
    comments), so no pages are added to the PDF opened at ``path``.

    Args:
        path (str): destination of the PDF.
        inputs (sequence): original chips, indexed as [channel, ...].
        outputs (sequence): reconstructions, same length and layout.
        masks (sequence): per-chip masks; only needed by the disabled
            rendering code.
        rgb_index (sequence): band indices; only needed by the disabled
            rendering code.

    Returns:
        None
    """
    # The context manager closes the PDF object when the block exits, also
    # on error; it was previously opened and never closed.
    with PdfPages(path) as pdf_plot_obj:

        # Loop through each chip
        for idx, image in enumerate(inputs):
            img_recon = outputs[idx]

            # Get the MSE loss per channel
            for i in range(14):
                oloss = lossMSE(image[i, ...], img_recon[i, ...])
                olosses[i].append(oloss.item())

            # CODE FOR PLOTTING ALL 14 CHANNELS AND THEIR RECONSTRUCTIONS
            # (disabled; re-enable together with `mask = masks[idx]`)
            #
            # rgb_image, rgb_image_masked, rgb_recon_masked, mask = \
            #     process_prediction(image, img_recon, mask, rgb_index)
            #
            # # matplotlib code
            # # Loop through each channel
            # fig, axs = plt.subplots(14, 2, figsize=(15, 105))
            # for i in range(14):
            #     im0 = axs[i][0].imshow(img_recon[i, :, :])
            #     im1 = axs[i][1].imshow(image[i, :, :])
            #     plt.colorbar(im0)
            #     plt.colorbar(im1)
            #     axs[i][0].set_title(f"Image {idx} Channel Index {i} Model Reconstruction")
            #     axs[i][1].set_title(f"Image {idx} Channel Index {i} Original")
            # plt.show()
            # pdf_plot_obj.savefig(fig)

## 5. Predict

In [None]:
inputs, outputs, masks, losses = [], [], [], []

# One chip per forward pass: a single batch would also work, but the loss is
# wanted per image rather than per batch.
# NOTE: unlike predict(), this loop does not wrap the forward pass in
# amp.autocast.
for i in tqdm(range(img.shape[0])):
    # Add a leading batch axis of size 1 and move the chip and its mask to
    # the GPU.
    single_img = img[i].unsqueeze(0).cuda(non_blocking=True)
    single_mask = mask[i].unsqueeze(0).cuda(non_blocking=True)

    with torch.no_grad():
        # Reconstruction from the encoder/decoder halves ...
        z = model.encoder(single_img, single_mask)
        img_recon = model.decoder(z)
        # ... and a separate full forward pass whose result is recorded as
        # the loss.
        loss = model(single_img, single_mask)

    # extend() on the size-1 batch appends the single chip itself.
    inputs.extend(single_img.cpu())
    masks.extend(single_mask.cpu())
    outputs.extend(img_recon.cpu())
    losses.append(loss.cpu())

## 6. Plot and write to PDF

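Writes out all of the predictions to a PDF file. Note: the page-rendering code in `plot_export_pdf` is currently commented out, so this step only collects the per-channel MSE values.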

In [None]:
# Destination of the PDF report
pdf_path = './selected_satvision-toa-reconstruction-14-channels.pdf'
rgb_index = [0, 2, 1] # Indices of [Red band, Blue band, Green band]

# NOTE: with the page-rendering code in plot_export_pdf commented out, this
# call only appends per-channel MSE values to `olosses`; masks and rgb_index
# are passed in but not used.
plot_export_pdf(pdf_path, inputs, outputs, masks, rgb_index)

In [None]:
# Summary of the per-image results recorded in `losses`:
# count, mean, standard deviation, minimum, maximum.
print("Model Losses")
print(len(losses))
print(np.mean(losses))
print(np.std(losses))
print(np.min(losses))
print(np.max(losses))

# Convert to an array under a separate name. Rebinding `olosses` itself to an
# ndarray made any later plot_export_pdf call fail, because that function
# appends to the per-channel lists and array rows have no append().
channel_mse = np.array(olosses)

# Per-channel MSE statistics in the same order: mean, std, min, max.
print("MSE")
for i, per_chip_mse in enumerate(channel_mse):
    print(f"Channel Idx: {i}")
    print(f"{np.mean(per_chip_mse)}")
    print(f"{np.std(per_chip_mse)}")
    print(f"{np.min(per_chip_mse)}")
    print(f"{np.max(per_chip_mse)}")