In [1]:
#IMPORTS

#File IO
import os
import glob

#Data manipulation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

#Pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.nn.functional as F
from torch.amp import GradScaler, autocast
from segmentation_models_pytorch import Unet

#Scikit learn
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import mean_squared_error, r2_score

#Misc
from tqdm import tqdm


In [2]:
#HYPERPARMETERS

train_proportion = .8
val_proportion = .1

batch_size = 128
learning_rate = .00002
num_epochs = 50
loss_weights = (1.0, 1.0, .01, .01)

In [3]:
#LOAD DATASET

dataset = torch.load("dataset.pt", weights_only=False)

In [4]:
#CREATE DATALOADERS

train_size = int(train_proportion * len(dataset))
val_size = int(val_proportion * len(dataset))
test_size = len(dataset) - train_size - val_size

generator = torch.Generator().manual_seed(1)
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
#CREATE MODEL

class MultiTaskCNN(nn.Module):
    def __init__(self, in_channels=16):
        super().__init__()

        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        # Heads
        self.mask_head = nn.Conv2d(256, 1, kernel_size=1)      # Binary classification
        self.phase_head = nn.Conv2d(256, 5, kernel_size=1)     # Multi-class classification
        self.cod_head = nn.Conv2d(256, 1, kernel_size=1)       # Regression
        self.cps_head = nn.Conv2d(256, 1, kernel_size=1)       # Regression

    def forward(self, x):
        shared_feats = self.shared(x)

        cloud_mask_logits = self.mask_head(shared_feats)          # [B, 1, H, W]
        cloud_mask = torch.sigmoid(cloud_mask_logits)             # Convert to probabilities

        # Create binary mask: 1 where mask > 0.5, else 0
        binary_mask = (cloud_mask > 0.5).float()                  # [B, 1, H, W]

        # Other heads
        cloud_phase_logits = self.phase_head(shared_feats)        # [B, 5, H, W]
        cod_pred = self.cod_head(shared_feats)                    # [B, 1, H, W]
        cps_pred = self.cps_head(shared_feats)                    # [B, 1, H, W]

        # Apply mask logic: zero out other outputs where mask == 0
        cloud_phase_logits = cloud_phase_logits * binary_mask     # broadcast on [B, 5, H, W]
        cod_pred = cod_pred * binary_mask
        cps_pred = cps_pred * binary_mask

        return cloud_mask_logits, cloud_phase_logits, cod_pred, cps_pred




In [6]:
#FINAL SETUP

dev_str = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(dev_str)
model = MultiTaskCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
scaler = GradScaler(device = device)

def unpack_labels(labels):
    return (
        labels[:, 0:1, :, :],             # cloud_mask → [B, H, W]
        labels[:, 1, :, :].long(),        # cloud_phase → [B, H, W]
        labels[:, 2:3, :, :],             # cod → [B, 1, H, W]
        labels[:, 3:4, :, :]              # cps → [B, 1, H, W]
    )


train_mask_losses, train_phase_losses, train_cod_losses, train_cps_losses, train_all_losses = [], [], [], [], []
val_mask_losses, val_phase_losses, val_cod_losses, val_cps_losses, val_all_losses = [], [], [], [], []

train_mask_acc, train_phase_acc, train_cod_r2, train_cps_r2 = [], [], [], []
val_mask_acc, val_phase_acc, val_cod_r2, val_cps_r2 = [], [], [], []

In [7]:
#TRAIN and EVALUATE FUNCTIONS

def train(model, train_loader, loss_weights=(1,1,1,1)):
    model.train()

    total_instances = 0 # Count number of instances in the epoch
    total_loss = total_loss_mask = total_loss_phase = total_loss_cod = total_loss_cps = 0 # Total loss and sublosses
    mask_correct = phase_correct = 0 # Number of correct guesses for cloud_mask and cloud_phase
    cod_preds, cod_labels = [], [] # Cod labels and predictions for calculating r2 
    cps_preds, cps_labels = [], [] # Cps labels and predictions for calculating r2

    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        cloud_mask_target, cloud_phase_target, cod_target, cps_target = unpack_labels(labels) # Get individual targets

        optimizer.zero_grad()
        B, _, H, W = labels.shape
        total_instances += (B * H * W) # Count instances in batch
        
        with autocast(device_type=dev_str):
            preds = model(inputs) # Model predictions

            # Loss in batch
            loss_mask = nn.BCEWithLogitsLoss()(preds[0], cloud_mask_target)
            loss_phase = nn.CrossEntropyLoss()(preds[1], cloud_phase_target)
            loss_cod = nn.MSELoss()(preds[2], cod_target)
            loss_cps = nn.MSELoss()(preds[3], cps_target)
            total_batch_loss = (
                loss_weights[0] * loss_mask +
                loss_weights[1] * loss_phase +
                loss_weights[2] * loss_cod +
                loss_weights[3] * loss_cps
            )

        # Get correct guesses for mask and phase
        mask_preds = (torch.sigmoid(preds[0]) > 0.5).long()
        mask_correct += (mask_preds == cloud_mask_target).sum().item()
        phase_preds = torch.argmax(preds[1], dim=1)
        phase_correct += (phase_preds == cloud_phase_target).sum().item()

        # Get predicted and actual cod and cps
        cod_preds.append(preds[2].cpu().detach().numpy())
        cod_labels.append(cod_target.cpu().numpy())
        cps_preds.append(preds[3].cpu().detach().numpy())
        cps_labels.append(cps_target.cpu().numpy())
            

        # Update model
        scaler.scale(total_batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Loss in epoch
        total_loss += total_batch_loss.item()
        total_loss_mask += loss_mask.item()
        total_loss_phase += loss_phase.item()
        total_loss_cod += loss_cod.item()
        total_loss_cps += loss_cps.item()

    # ------------------------------------ BATCH LOOP END -------------------------------------------------------------

    #Accuracy for mask and phase
    mask_accuracy = mask_correct/total_instances
    phase_accuracy = phase_correct/total_instances

    #Cod R2
    cod_preds = np.concatenate(cod_preds).ravel()
    cod_labels = np.concatenate(cod_labels).ravel()
    cod_r2 = r2_score(cod_labels, cod_preds)

    #Cps R2
    cps_preds = np.concatenate(cps_preds).ravel()
    cps_labels = np.concatenate(cps_labels).ravel()
    cps_r2 = r2_score(cps_labels, cps_preds)

    return {
        'loss_total': total_loss / len(train_loader),
        'loss_mask': total_loss_mask / len(train_loader),
        'loss_phase': total_loss_phase / len(train_loader),
        'loss_cod': total_loss_cod / len(train_loader),
        'loss_cps': total_loss_cps / len(train_loader),
        'acc_mask': mask_accuracy,
        'acc_phase': phase_accuracy,
        'r2_cod': cod_r2,
        'r2_cps': cps_r2
    }



def eval(model, val_loader, loss_weights=(1,1,1,1)):
    model.eval()

    total_instances = 0 # Count number of instances in the epoch
    total_loss = total_loss_mask = total_loss_phase = total_loss_cod = total_loss_cps = 0 # Total loss and sublosses
    mask_correct = phase_correct = 0 # Number of correct guesses for cloud_mask and cloud_phase
    cod_preds, cod_labels = [], [] # Cod labels and predictions for calculating r2 
    cps_preds, cps_labels = [], [] # Cps labels and predictions for calculating r2

    for inputs, labels in (val_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        cloud_mask_target, cloud_phase_target, cod_target, cps_target = unpack_labels(labels) # Get individual targets

        B, _, H, W = labels.shape
        total_instances += (B * H * W) # Count instances in batch
        
        
        with torch.no_grad(), autocast(device_type=dev_str):
            preds = model(inputs) # Model predictions

            # Loss in batch
            loss_mask = nn.BCEWithLogitsLoss()(preds[0], cloud_mask_target)
            loss_phase = nn.CrossEntropyLoss()(preds[1], cloud_phase_target)
            loss_cod = nn.MSELoss()(preds[2], cod_target)
            loss_cps = nn.MSELoss()(preds[3], cps_target)
            total_batch_loss = (
                loss_weights[0] * loss_mask +
                loss_weights[1] * loss_phase +
                loss_weights[2] * loss_cod +
                loss_weights[3] * loss_cps
            )

        # Get correct guesses for mask and phase
        mask_preds = (torch.sigmoid(preds[0]) > 0.5).long()
        mask_correct += (mask_preds == cloud_mask_target).sum().item()
        phase_preds = torch.argmax(preds[1], dim=1)
        phase_correct += (phase_preds == cloud_phase_target).sum().item()

        # Get predicted and actual cod and cps
        cod_preds.append(preds[2].cpu().detach().numpy())
        cod_labels.append(cod_target.cpu().numpy())
        cps_preds.append(preds[3].cpu().detach().numpy())
        cps_labels.append(cps_target.cpu().numpy())
            
        # Loss in epoch
        total_loss += total_batch_loss.item()
        total_loss_mask += loss_mask.item()
        total_loss_phase += loss_phase.item()
        total_loss_cod += loss_cod.item()
        total_loss_cps += loss_cps.item()

    # ------------------------------------ BATCH LOOP END -------------------------------------------------------------

    #Accuracy for mask and phase
    mask_accuracy = mask_correct/total_instances
    phase_accuracy = phase_correct/total_instances

    #Cod R2
    cod_preds = np.concatenate(cod_preds).ravel()
    cod_labels = np.concatenate(cod_labels).ravel()
    cod_r2 = r2_score(cod_labels, cod_preds)

    #Cps R2
    cps_preds = np.concatenate(cps_preds).ravel()
    cps_labels = np.concatenate(cps_labels).ravel()
    cps_r2 = r2_score(cps_labels, cps_preds)

    return {
        'loss_total': total_loss / len(val_loader),
        'loss_mask': total_loss_mask / len(val_loader),
        'loss_phase': total_loss_phase / len(val_loader),
        'loss_cod': total_loss_cod / len(val_loader),
        'loss_cps': total_loss_cps / len(val_loader),
        'acc_mask': mask_accuracy,
        'acc_phase': phase_accuracy,
        'r2_cod': cod_r2,
        'r2_cps': cps_r2
    }

In [None]:
#TRAIN MODEL

for e in range(1, num_epochs+1):
    train_results = train(model, train_loader,) #loss_weights = loss_weights)
    val_results = eval(model, val_loader,) #loss_weights = loss_weights)
    
    train_mask_losses.append(train_results['loss_mask'])
    train_phase_losses.append(train_results['loss_phase'])
    train_cod_losses.append(train_results['loss_cod'])
    train_cps_losses.append(train_results['loss_cps'])
    train_all_losses.append(train_results['loss_total'])

    train_mask_acc.append(train_results['acc_mask'])
    train_phase_acc.append(train_results['acc_phase'])
    train_cod_r2.append(train_results['r2_cod'])
    train_cps_r2.append(train_results['r2_cps'])

    val_mask_losses.append(val_results['loss_mask'])
    val_phase_losses.append(val_results['loss_phase'])
    val_cod_losses.append(val_results['loss_cod'])
    val_cps_losses.append(val_results['loss_cps'])
    val_all_losses.append(val_results['loss_total'])

    val_mask_acc.append(val_results['acc_mask'])
    val_phase_acc.append(val_results['acc_phase'])
    val_cod_r2.append(val_results['r2_cod'])
    val_cps_r2.append(val_results['r2_cps'])

    print(f"Epoch: {e} | Train Loss: {train_results['loss_total']:.4f} | Val Loss: {val_results['loss_total']:.4f}")

100%|██████████| 94/94 [01:30<00:00,  1.04it/s]


Epoch: 1 | Train Loss: 271.1431 | Val Loss: 254.7181


100%|██████████| 94/94 [01:39<00:00,  1.05s/it]


Epoch: 2 | Train Loss: 235.0202 | Val Loss: 205.6650


100%|██████████| 94/94 [01:35<00:00,  1.02s/it]


Epoch: 3 | Train Loss: 220.3540 | Val Loss: 187.7062


100%|██████████| 94/94 [01:36<00:00,  1.02s/it]


Epoch: 4 | Train Loss: 210.2730 | Val Loss: 206.2323


100%|██████████| 94/94 [01:31<00:00,  1.03it/s]


Epoch: 5 | Train Loss: 202.9861 | Val Loss: 178.1313


100%|██████████| 94/94 [01:24<00:00,  1.11it/s]


Epoch: 6 | Train Loss: 196.8585 | Val Loss: 194.2941


100%|██████████| 94/94 [01:29<00:00,  1.05it/s]


Epoch: 7 | Train Loss: 191.4619 | Val Loss: 173.4021


100%|██████████| 94/94 [01:33<00:00,  1.00it/s]


Epoch: 8 | Train Loss: 186.6996 | Val Loss: 185.8151


100%|██████████| 94/94 [01:17<00:00,  1.21it/s]


Epoch: 9 | Train Loss: 182.3122 | Val Loss: 176.3728


100%|██████████| 94/94 [01:16<00:00,  1.23it/s]


Epoch: 10 | Train Loss: 177.7419 | Val Loss: 154.1308


100%|██████████| 94/94 [01:18<00:00,  1.20it/s]


Epoch: 11 | Train Loss: 173.3193 | Val Loss: 156.9480


100%|██████████| 94/94 [01:16<00:00,  1.23it/s]


Epoch: 12 | Train Loss: 168.9732 | Val Loss: 132.8507


100%|██████████| 94/94 [01:19<00:00,  1.18it/s]


Epoch: 13 | Train Loss: 164.7519 | Val Loss: 164.3162


100%|██████████| 94/94 [01:18<00:00,  1.20it/s]


Epoch: 14 | Train Loss: 161.0266 | Val Loss: 158.5956


100%|██████████| 94/94 [01:20<00:00,  1.17it/s]


Epoch: 15 | Train Loss: 156.8339 | Val Loss: 140.5637


100%|██████████| 94/94 [01:21<00:00,  1.15it/s]


Epoch: 16 | Train Loss: 152.9026 | Val Loss: 130.6007


100%|██████████| 94/94 [01:13<00:00,  1.29it/s]


Epoch: 17 | Train Loss: 149.1331 | Val Loss: 138.6572


100%|██████████| 94/94 [01:17<00:00,  1.21it/s]


Epoch: 18 | Train Loss: 145.9188 | Val Loss: 122.7628


100%|██████████| 94/94 [01:21<00:00,  1.15it/s]


Epoch: 19 | Train Loss: 142.7246 | Val Loss: 118.0048


100%|██████████| 94/94 [01:15<00:00,  1.24it/s]


Epoch: 20 | Train Loss: 139.3440 | Val Loss: 136.5994


100%|██████████| 94/94 [01:32<00:00,  1.01it/s]


Epoch: 21 | Train Loss: 136.1711 | Val Loss: 100.6048


100%|██████████| 94/94 [01:26<00:00,  1.09it/s]


Epoch: 22 | Train Loss: 133.2170 | Val Loss: 121.0267


100%|██████████| 94/94 [01:28<00:00,  1.06it/s]


Epoch: 23 | Train Loss: 130.2481 | Val Loss: 150.5054


100%|██████████| 94/94 [01:30<00:00,  1.04it/s]


Epoch: 24 | Train Loss: 127.6388 | Val Loss: 165.2654


100%|██████████| 94/94 [01:08<00:00,  1.37it/s]


Epoch: 25 | Train Loss: 125.2474 | Val Loss: 124.6097


100%|██████████| 94/94 [00:55<00:00,  1.70it/s]


Epoch: 26 | Train Loss: 122.4383 | Val Loss: 144.2065


100%|██████████| 94/94 [01:28<00:00,  1.06it/s]


Epoch: 27 | Train Loss: 120.1821 | Val Loss: 96.8946


100%|██████████| 94/94 [01:27<00:00,  1.07it/s]


Epoch: 28 | Train Loss: 117.7636 | Val Loss: 115.2979


100%|██████████| 94/94 [00:59<00:00,  1.59it/s]


Epoch: 29 | Train Loss: 115.6592 | Val Loss: 105.7356


100%|██████████| 94/94 [01:38<00:00,  1.05s/it]


Epoch: 30 | Train Loss: 113.7465 | Val Loss: 130.0305


100%|██████████| 94/94 [01:20<00:00,  1.17it/s]


Epoch: 31 | Train Loss: 111.4708 | Val Loss: 101.5216


100%|██████████| 94/94 [01:21<00:00,  1.16it/s]


Epoch: 32 | Train Loss: 109.5046 | Val Loss: 95.2156


100%|██████████| 94/94 [01:35<00:00,  1.01s/it]


Epoch: 33 | Train Loss: 107.5189 | Val Loss: 128.8489


100%|██████████| 94/94 [01:44<00:00,  1.11s/it]


Epoch: 34 | Train Loss: 106.2419 | Val Loss: 106.7997


100%|██████████| 94/94 [01:36<00:00,  1.03s/it]


Epoch: 35 | Train Loss: 104.2420 | Val Loss: 101.1791


100%|██████████| 94/94 [01:41<00:00,  1.08s/it]


Epoch: 36 | Train Loss: 102.6823 | Val Loss: 127.2844


100%|██████████| 94/94 [01:27<00:00,  1.08it/s]


Epoch: 37 | Train Loss: 101.2586 | Val Loss: 94.7942


100%|██████████| 94/94 [01:31<00:00,  1.02it/s]


Epoch: 38 | Train Loss: 99.5064 | Val Loss: 102.3397


100%|██████████| 94/94 [01:36<00:00,  1.03s/it]


Epoch: 39 | Train Loss: 98.2907 | Val Loss: 160.3808


100%|██████████| 94/94 [01:33<00:00,  1.01it/s]


Epoch: 40 | Train Loss: 97.0845 | Val Loss: 115.2522


100%|██████████| 94/94 [01:19<00:00,  1.18it/s]


Epoch: 41 | Train Loss: 95.8348 | Val Loss: 102.8815


100%|██████████| 94/94 [01:35<00:00,  1.02s/it]


Epoch: 42 | Train Loss: 94.5568 | Val Loss: 98.0194


100%|██████████| 94/94 [01:36<00:00,  1.02s/it]


Epoch: 43 | Train Loss: 93.7314 | Val Loss: 97.7114


100%|██████████| 94/94 [01:39<00:00,  1.06s/it]


Epoch: 44 | Train Loss: 92.4806 | Val Loss: 99.9631


100%|██████████| 94/94 [00:57<00:00,  1.63it/s]


Epoch: 45 | Train Loss: 91.6275 | Val Loss: 92.6852


100%|██████████| 94/94 [01:14<00:00,  1.26it/s]


Epoch: 46 | Train Loss: 90.8914 | Val Loss: 100.5348


100%|██████████| 94/94 [01:28<00:00,  1.06it/s]


Epoch: 47 | Train Loss: 90.1255 | Val Loss: 93.8911


100%|██████████| 94/94 [01:18<00:00,  1.20it/s]


Epoch: 48 | Train Loss: 89.6174 | Val Loss: 86.1848


100%|██████████| 94/94 [01:42<00:00,  1.09s/it]


Epoch: 49 | Train Loss: 89.2823 | Val Loss: 116.7973


 21%|██▏       | 20/94 [00:24<01:47,  1.45s/it]

In [None]:
#PLOT LOSS

plt.figure(figsize=(20,8))

plt.subplot(2, 3, 1)
plt.plot(train_mask_losses, label='Train Loss')
plt.plot(val_mask_losses, label='Val Loss')
plt.title('Cloud Mask Loss')
plt.legend()

plt.subplot(2, 3, 2)
plt.plot(train_phase_losses, label='Train Loss')
plt.plot(val_phase_losses, label='Val Loss')
plt.title('Cloud Phase Loss')
plt.legend()

plt.subplot(2, 3, 3)
plt.plot(train_cod_losses, label='Train Loss')
plt.plot(val_cod_losses, label='Val Loss')
plt.title('Cod Loss')
plt.legend()

plt.subplot(2, 3, 4)
plt.plot(train_cps_losses, label='Train Loss')
plt.plot(val_cps_losses, label='Val Loss')
plt.title('Cps Loss')
plt.legend()

plt.subplot(2, 3, 5)
plt.plot(train_all_losses, label='Train Loss')
plt.plot(val_all_losses, label='Val Loss')
plt.title('All Loss')
plt.legend()

#plt.savefig("./graphs/cloud_mask_unet_loss_and_acc.png")  
#plt.savefig("./graphs/cloud_mask_unet_loss_and_acc.png")  


plt.show()

In [None]:
#PLOT ACCURACY

plt.figure(figsize=(16,8))

plt.subplot(2, 2, 1)
plt.plot(train_mask_acc, label='Train Acc')
plt.plot(val_mask_acc, label='Val Acc')
plt.title('Cloud Mask Acc')
plt.legend()

plt.subplot(2, 2, 2)
plt.plot(train_phase_acc, label='Train Acc')
plt.plot(val_phase_acc, label='Val Acc')
plt.title('Cloud Phase Acc')
plt.legend()

plt.subplot(2, 2, 3)
plt.plot(train_cod_r2, label='Train R2')
plt.plot(val_cod_r2, label='Val R2')
plt.title('Cod R2')
plt.legend()

plt.subplot(2, 2, 4)
plt.plot(train_cps_r2, label='Train R2')
plt.plot(val_cps_r2, label='Val R2')
plt.title('Cps R2')
plt.legend()

plt.show()

In [None]:
# MODEL EVALUATION — CLOUD MASK

all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        cloud_mask_target, _, _, _ = unpack_labels(labels)
        cloud_mask_pred, _, _, _ = model(images)

        probs = torch.sigmoid(cloud_mask_pred)
        preds = (probs > 0.5).long()

        all_preds.append(preds.cpu().numpy())
        all_labels.append(cloud_mask_target.cpu().numpy())

# Flatten predictions and labels
all_preds = np.concatenate([p.flatten() for p in all_preds])
all_labels = np.concatenate([l.flatten() for l in all_labels])

# Classification report and IoU
report = classification_report(
    all_labels, all_preds,
    labels=[0, 1],
    digits=3,
    output_dict=True,
    zero_division=0
)


# Output
print("CLOUD MASK REPORT:\n", classification_report(all_labels, all_preds, labels=[0, 1], digits=3))
print("CONFUSION MATRIX:\n", confusion_matrix(all_labels, all_preds, labels=[0, 1]))

In [None]:
# MODEL EVALUATION — CLOUD PHASE

num_classes = 5

all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        _, cloud_phase_target, _, _ = unpack_labels(labels)
        _, cloud_phase_pred, _, _ = model(images)
        preds = torch.argmax(cloud_phase_pred, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_labels.append(cloud_phase_target.cpu().numpy())
        
# Flatten predictions and labels
all_preds = np.concatenate([p.flatten() for p in all_preds])
all_labels = np.concatenate([l.flatten() for l in all_labels])

# Classification report and IoU
report = classification_report(all_labels, all_preds, digits=3, output_dict=True)
f1_scores = np.array([report[str(i)]['f1-score'] for i in range(num_classes)])
supports = np.array([report[str(i)]['support'] for i in range(num_classes)])
iou = f1_scores / (2 - f1_scores)

# Output
print("CLOUD PHASE REPORT:\n", classification_report(all_labels, all_preds, digits=3))
print("CONFUSION MATRIX:\n", confusion_matrix(all_labels, all_preds))
print("\nIOU:", iou)
print("Unweighted IoU:", np.mean(iou))
print("Weighted IoU:", np.average(iou, weights=supports))


In [None]:
# MODEL EVALUATION — CLOUD OPTICAL DISTANCE

all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device).float()
        labels = labels.to(device)

        _, _, cod_target, _ = unpack_labels(labels)
        _, _, cod_pred, _ = model(images)
        
        all_preds.append(cod_pred.cpu().numpy().reshape(-1))
        all_labels.append(cod_target.cpu().numpy().reshape(-1))

all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

print("r2:", r2_score(all_labels, all_preds))
print("MSE:", mean_squared_error(all_labels, all_preds))

In [None]:
# MODEL EVALUATION — CLOUD PARTICLE SIZE

all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device).float()
        labels = labels.to(device)

        _, _, _, cps_target = unpack_labels(labels)
        _, _, _, cps_pred = model(images)
        
        all_preds.append(cps_pred.cpu().numpy().reshape(-1))
        all_labels.append(cps_target.cpu().numpy().reshape(-1))

all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

print("r2:", r2_score(all_labels, all_preds))
print("MSE:", mean_squared_error(all_labels, all_preds))

In [None]:
# data_iter = iter(test_loader)
# images, labels = next(data_iter)

# images = images.to(device)
# labels = labels.to(device)
# _, cloud_phase_target, _, _ = unpack_labels(labels)
# _, cloud_phase_pred, _, _ = model(images)
# preds = torch.argmax(cloud_phase_pred, dim=1)


In [None]:
# my_image = images[0].cpu().numpy()
# my_image = np.transpose(my_image, (1, 2, 0))
# phase_pred = preds[0].cpu().numpy()
# phase_target = cloud_phase_target[0].cpu().numpy()



# from netCDF4 import Dataset
# import numpy as np


# with Dataset('image1.nc', 'w', format='NETCDF4') as ds:
#     ds.createDimension('x', my_image.shape[0])
#     ds.createDimension('y', my_image.shape[1])
#     ds.createDimension('band', my_image.shape[2])

#     var = ds.createVariable('radiance', 'f4', ('x', 'y', 'band'))
#     var[:] = my_image

#     var.units = 'unknown'  # optional metadata

# with Dataset('image2.nc', 'w', format='NETCDF4') as ds:
#     ds.createDimension('x', phase_pred.shape[0])
#     ds.createDimension('y', phase_pred.shape[1])

#     var = ds.createVariable('prediction', 'f4', ('x', 'y'))
#     var[:] = phase_pred

#     var.units = 'unknown'  # optional metadata


# with Dataset('image3.nc', 'w', format='NETCDF4') as ds:
#     ds.createDimension('x', phase_target.shape[0])
#     ds.createDimension('y', phase_target.shape[1])

#     var = ds.createVariable('target', 'f4', ('x', 'y'))
#     var[:] = phase_target

#     var.units = 'unknown'  # optional metadata

