# ☁️ INSAT-3DR Cloud Motion Forecasting

This notebook demonstrates short-term, multi-channel (TIR1 + WV) cloud motion prediction using a 3D U-Net with squeeze-and-excitation (SE) blocks, trained with a combined MSE + SSIM + LPIPS loss.

In [None]:
!pip install lpips scikit-image pytorch-msssim
import os, zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
import lpips
from pytorch_msssim import ssim as ssim_loss_fn


In [None]:
# Ask the user to upload the data archive through google.colab.
from google.colab import files
uploaded = files.upload()

zip_path, extract_dir = "/content/data.zip", "/content/data"

# Unpack only once: an existing target path is taken to mean the archive
# has already been extracted on a previous run.
if os.path.exists(extract_dir) is False:
    with zipfile.ZipFile(zip_path, mode='r') as archive:
        archive.extractall(path=extract_dir)


In [None]:
base_path = "/content/data"
channels = ["TIR1", "WV"]
# Sanity check: report how many .tif frames each channel directory holds.
# The list is named `tif_files` (not `files`) so it does not shadow the
# `google.colab.files` module imported in the upload cell; shadowing it made
# re-running that cell fail with `'list' object has no attribute 'upload'`.
for ch in channels:
    ch_path = os.path.join(base_path, ch)
    tif_files = sorted(f for f in os.listdir(ch_path) if f.endswith(".tif"))
    print(f"{ch}: {len(tif_files)} files")


In [None]:
RESIZE = (128, 128)          # target size handed to transforms.Resize
CHANNELS = ["TIR1", "WV"]    # one sub-directory of BASE_PATH per input band
BASE_PATH = "/content/data"  # root the dataset reads channel folders from

def normalize_tensor(t):
    """Return *t* standardized to zero mean and unit sample standard deviation.

    The statistics are taken over the whole tensor. A small epsilon is added
    to the standard deviation so a constant image maps to zeros instead of
    dividing by zero. Note the result is not bounded to [-1, 1].
    """
    centre = t.mean()
    spread = t.std()
    return t.sub(centre).div(spread + 1e-5)

class CloudDataset(Dataset):
    """Sliding-window dataset of stacked multi-channel satellite frames.

    For each start index ``i`` in ``starts``, frames ``i .. i+n_in-1`` of every
    channel in ``CHANNELS`` form the input and the following ``n_out`` frames
    form the target. Frames are matched across channels by their position in
    sorted ``.tif`` filename order. Each frame is converted to greyscale,
    resized to ``RESIZE`` and standardized with ``normalize_tensor``.

    Each stored sample has shape ``(T, C, 1, H, W)``: time, channel (in
    ``CHANNELS`` order), the singleton image-channel axis produced by
    ``ToTensor``, then the spatial size.

    Args:
        base_path: Directory containing one sub-directory per channel.
        starts: Window start indices; the default reproduces the original
            hard-coded ``[0, 2, 4]``.
        n_in: Number of input frames per sample.
        n_out: Number of target frames per sample.

    Raises:
        ValueError: If a channel directory has too few ``.tif`` files for the
            requested windows.
    """

    def __init__(self, base_path, starts=(0, 2, 4), n_in=3, n_out=2):
        self.X, self.Y = [], []
        self.starts = tuple(starts)
        self.n_in, self.n_out = n_in, n_out
        self._load(base_path)

    def _load(self, base_path):
        window = self.n_in + self.n_out
        files = {
            ch: sorted(
                os.path.join(base_path, ch, f)
                for f in os.listdir(os.path.join(base_path, ch))
                if f.endswith(".tif")
            )
            for ch in CHANNELS
        }
        needed = max(self.starts, default=-window) + window
        for ch, paths in files.items():
            if len(paths) < needed:
                raise ValueError(
                    f"channel {ch!r} has {len(paths)} .tif files; "
                    f"at least {needed} are needed for starts={self.starts}"
                )

        to_tensor = transforms.Compose([transforms.Resize(RESIZE), transforms.ToTensor()])
        cache = {}

        def frame(ch, j):
            # Windows overlap, so decode each file at most once.
            key = (ch, j)
            if key not in cache:
                # Pillow keeps multi-frame formats such as TIFF open after
                # load(); the with-block guarantees the handle is closed.
                with Image.open(files[ch][j]) as img:
                    cache[key] = normalize_tensor(to_tensor(img.convert("L")))
            return cache[key]

        for i in self.starts:
            x_stack, y_stack = [], []
            for ch in CHANNELS:
                tensors = [frame(ch, j) for j in range(i, i + window)]
                x_stack.append(torch.stack(tensors[:self.n_in]))
                y_stack.append(torch.stack(tensors[self.n_in:]))
            self.X.append(torch.stack(x_stack, dim=1))
            self.Y.append(torch.stack(y_stack, dim=1))

    def __len__(self): return len(self.X)
    def __getitem__(self, idx): return self.X[idx], self.Y[idx]


In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation channel gate for 5-D ``(N, C, D, H, W)`` tensors.

    Each channel is global-average-pooled to a scalar. The channel vector
    passes through a bottleneck MLP ending in a sigmoid, and the input
    channels are rescaled by the resulting weights.

    Args:
        channels: Number of input channels ``C``.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channels // reduction`` but never less than 1, so blocks with
            fewer than ``reduction`` channels still get a working bottleneck.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Clamp to 1: a width of 0 would build zero-size Linear layers and
        # silently collapse the gate to sigmoid(bias), ignoring the input.
        hidden = max(1, channels // reduction)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _, _ = x.size()
        # Squeeze: (N, C, D, H, W) -> (N, C)
        y = self.pool(x).view(b, c)
        # Excite: per-channel weights in (0, 1), broadcast over D, H, W.
        y = self.fc(y).view(b, c, 1, 1, 1)
        return x * y

class UNet3D(nn.Module):
    """Encoder/decoder 3-D CNN with SE attention and a ``tanh`` output.

    Maps ``(N, in_channels, T, H, W)`` to ``(N, out_channels, T', H, W')``
    with values in (-1, 1). The time and spatial axes are halved once by
    ``MaxPool3d(2)`` and doubled once by trilinear ``Upsample``, so
    ``T' = 2 * floor(T / 2)`` (a 3-frame input yields 2 output frames). H and
    W only come back to their input size when they are even.

    NOTE(review): despite the name there are no skip connections. ``x1`` only
    feeds the pooling layer, and the decoder never concatenates encoder
    features.
    NOTE(review): ``tanh`` bounds outputs to (-1, 1), while the targets built
    by ``normalize_tensor`` are z-scores that can exceed that range. Confirm
    this is intended.
    """

    def __init__(self, in_channels: int = 2, out_channels: int = 2):
        super().__init__()
        # Encoder: full-resolution features, then features at half resolution.
        self.down1 = nn.Sequential(nn.Conv3d(in_channels, 32, 3, padding=1), nn.BatchNorm3d(32), nn.ReLU(), SEBlock(32))
        self.down2 = nn.Sequential(nn.Conv3d(32, 64, 3, padding=1), nn.BatchNorm3d(64), nn.ReLU(), SEBlock(64))
        # Halves T, H and W (kernel 2, stride 2; odd sizes are floored).
        self.pool = nn.MaxPool3d(2)
        self.middle = nn.Sequential(nn.Conv3d(64, 128, 3, padding=1), nn.BatchNorm3d(128), nn.ReLU(), SEBlock(128))
        # Doubles T, H and W back; this is the only upsampling step.
        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)
        # Decoder: plain conv/BN/ReLU stages (no SE, no skip inputs).
        self.up1 = nn.Sequential(nn.Conv3d(128, 64, 3, padding=1), nn.BatchNorm3d(64), nn.ReLU())
        self.up2 = nn.Sequential(nn.Conv3d(64, 32, 3, padding=1), nn.BatchNorm3d(32), nn.ReLU())
        # 1x1x1 projection to the requested number of output channels.
        self.out = nn.Conv3d(32, out_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.down1(x)
        x2 = self.down2(self.pool(x1))
        m = self.middle(x2)
        x3 = self.up1(self.up(m))
        x4 = self.up2(x3)
        return torch.tanh(self.out(x4))


In [None]:
# Standalone LPIPS (AlexNet) instance. CombinedLoss below builds its own copy,
# so this one is not used by the loss.
# NOTE(review): presumably kept for evaluation cells later in the notebook —
# confirm before removing.
lp = lpips.LPIPS(net='alex').to("cuda" if torch.cuda.is_available() else "cpu")

class CombinedLoss(nn.Module):
    """MSE plus weighted (1 - SSIM) and LPIPS terms for video tensors.

    ``forward`` takes ``pred`` and ``target`` of shape ``(N, C, T, H, W)``.
    For the SSIM and LPIPS terms, every ``(sample, channel, time)`` slice is
    treated as an independent single-channel image, and each term is averaged
    over all slices.

    Args:
        ssim_weight: Weight of the ``1 - SSIM`` term (default 0.1).
        lpips_weight: Weight of the LPIPS term (default 0.1).
        data_range: Value range passed to SSIM (default 1.0).
            NOTE(review): the model ends in ``tanh`` (range (-1, 1)) and the
            dataset z-scores targets, so 1.0 may not match the real data
            range. Confirm before relying on the SSIM term's scale.

    Returns (from ``forward``):
        A scalar tensor ``mse + ssim_weight * ssim_term + lpips_weight * lpips_term``.
    """

    def __init__(self, ssim_weight=0.1, lpips_weight=0.1, data_range=1.0):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lpips = lpips.LPIPS(net='alex').to("cuda" if torch.cuda.is_available() else "cpu")
        self.ssim_weight = ssim_weight
        self.lpips_weight = lpips_weight
        self.data_range = data_range

    def forward(self, pred, target):
        mse_loss = self.mse(pred, target)
        N, C, T, H, W = pred.shape
        # Flatten (N, C, T) into one batch of 1-channel frames. Because every
        # (c, t) group has the same N, the mean over this batch equals the
        # average over (c, t) of per-group batch means.
        pr = pred.reshape(N * C * T, 1, H, W)
        gt = target.reshape(N * C * T, 1, H, W)

        ssim_term = 1 - ssim_loss_fn(pr, gt, data_range=self.data_range)

        # LPIPS needs (B, 3, H, W): replicate the grey channel along dim 1.
        # The frames must already be 4-D here. Calling repeat(1, 3, 1, 1) on a
        # 3-D (N, H, W) slice yields (1, 3N, H, W), which breaks for N > 1.
        lpips_term = self.lpips(pr.repeat(1, 3, 1, 1), gt.repeat(1, 3, 1, 1)).mean()

        return mse_loss + self.ssim_weight * ssim_term + self.lpips_weight * lpips_term


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS, LR, BATCH_SIZE = 200, 1e-4, 1
dataset = CloudDataset(BASE_PATH)
loader = DataLoader(dataset, batch_size=BATCH_SIZE)
model = UNet3D().to(device)
opt = torch.optim.Adam(model.parameters(), lr=LR)
# The scheduler's `verbose=` argument is deprecated (PyTorch >= 2.2), so LR
# reductions are reported explicitly at the end of each epoch instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', patience=10, factor=0.5)
loss_fn = CombinedLoss()

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        # CloudDataset samples are (T, C, 1, H, W), so batches are
        # (N, T, C, 1, H, W). Drop the singleton axis and move channels ahead
        # of time to get the (N, C, T, H, W) layout Conv3d expects.
        xb = xb.squeeze(3).permute(0, 2, 1, 3, 4)
        yb = yb.squeeze(3).permute(0, 2, 1, 3, 4)
        opt.zero_grad()
        out = model(xb)
        loss = loss_fn(out, yb)
        loss.backward()
        opt.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.4f}")
    prev_lr = opt.param_groups[0]["lr"]
    scheduler.step(avg_loss)
    new_lr = opt.param_groups[0]["lr"]
    if new_lr < prev_lr:
        print(f"  Reducing learning rate to {new_lr:.2e}")
