In [1]:
pip install rasterio

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
from sklearn.model_selection import train_test_split

# Paths
img_dir = "/kaggle/input/images/images/images"
label_dir = "/kaggle/input/images/labels/labels"

# List images (*.tif) and masks (*.png); a tile and its mask share a file stem.
img_list = sorted(f for f in os.listdir(img_dir) if f.endswith(".tif"))
label_list = sorted(f for f in os.listdir(label_dir) if f.endswith(".png"))

# Keep only matching pairs. A set gives O(1) lookups, and swapping just the ".tif"
# suffix (rather than str.replace) cannot touch ".tif" text elsewhere in a name.
label_set = set(label_list)
img_paths, label_paths = [], []
for tile_name in img_list:
    mask_name = tile_name[: -len(".tif")] + ".png"
    if mask_name in label_set:
        img_paths.append(os.path.join(img_dir, tile_name))
        label_paths.append(os.path.join(label_dir, mask_name))
if not img_paths:
    raise FileNotFoundError(f"No matching .tif/.png pairs under {img_dir} and {label_dir}")

# Train-test split (80/20, fixed seed so the split is reproducible)
train_img_paths, test_img_paths, train_label_paths, test_label_paths = train_test_split(
    img_paths, label_paths, test_size=0.2, random_state=42
)


In [3]:
# import torch
# import torch.nn as nn
# import tifffile as tiff
# import numpy as np
# import cv2
# import matplotlib.pyplot as plt
# from PIL import Image  # For loading PNG ground truth

# # **Step 1: Load the 6-Channel Image**
# image_path = "/kaggle/input/images/images/images/3.tif"
# label_path = "/kaggle/input/images/labels/labels/3.png"  # Update this path

# image = tiff.imread(image_path).astype(np.float32)  # Shape: (H, W, C)
# image = np.transpose(image, (2, 0, 1))  # Convert to (C, H, W)
# assert image.shape[0] == 6, f"Expected 6 channels, got {image.shape[0]}"

# # **Step 2: Load the Ground Truth PNG Label**
# label = Image.open(label_path).convert("L")  # Convert to grayscale
# label = np.array(label)  # Convert to NumPy array

# # **Step 3: Normalize Continuous Channels ([0,1])**
# def normalize_channel(channel):
#     return (channel - np.min(channel)) / (np.max(channel) - np.min(channel) + 1e-6)

# # Normalize all channels except LULC (5th channel)
# for i in [0, 1, 2, 3, 5]:  # Skip index 4 (LULC)
#     image[i] = normalize_channel(image[i])

# # **Step 4: Resize Low-Resolution Channels**
# target_size = (image.shape[1], image.shape[2])  # Target resolution (H, W) = 10m
# image[3] = cv2.resize(image[3], target_size, interpolation=cv2.INTER_CUBIC)  # DEM 30m
# image[5] = cv2.resize(image[5], target_size, interpolation=cv2.INTER_CUBIC)  # Water Occurrence

# # **Step 5: One-Hot Encode LULC (Categorical Channel)**
# unique_classes = np.unique(image[4])  # Find unique class labels
# num_classes = len(unique_classes)
# print(f"LULC Unique Classes: {unique_classes}")

# # Create empty one-hot encoded array
# lulc_one_hot = np.zeros((num_classes, image.shape[1], image.shape[2]))

# for i, class_value in enumerate(unique_classes):
#     lulc_one_hot[i] = (image[4] == class_value).astype(np.float32)

# # **Step 6: Combine Channels for CNN**
# # Continuous Channels: [SAR1, SAR2, DEM1, DEM2, Water Occurrence]
# continuous_channels = np.array([image[i] for i in [0, 1, 2, 3, 5]])
# combined_input = np.concatenate((continuous_channels, lulc_one_hot), axis=0)  # Shape: (C, H, W)

# # Convert to PyTorch Tensor
# image_tensor = torch.tensor(combined_input, dtype=torch.float32).unsqueeze(0)  # Shape: (1, C, H, W)
# print("Updated Input Image Shape:", image_tensor.shape)  # Should be (1, C, H, W)

# # **Step 7: Define CNN-Based Fusion Model**
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class FusionCNN(nn.Module):
#     def __init__(self, num_input_channels):
#         super(FusionCNN, self).__init__()

#         # First block: Extract low-level spatial features
#         self.conv1 = nn.Conv2d(in_channels=num_input_channels, out_channels=64, kernel_size=3, padding=1)
#         self.bn1 = nn.BatchNorm2d(64)  # Batch normalization
#         self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
#         self.bn2 = nn.BatchNorm2d(64)

#         # Second block: Extract mid-level features
#         self.conv3 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
#         self.bn3 = nn.BatchNorm2d(32)
#         self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
#         self.bn4 = nn.BatchNorm2d(32)

#         # Third block: Extract high-level semantic features
#         self.conv5 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1)
#         self.bn5 = nn.BatchNorm2d(16)
#         self.conv6 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
#         self.bn6 = nn.BatchNorm2d(16)

#         # Final layer: Reduce to 1 feature map
#         self.conv_final = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
        
#         # Dropout for regularization
#         self.dropout = nn.Dropout(0.3)  

#     def forward(self, x):
#         # First block
#         x = F.relu(self.bn1(self.conv1(x)))
#         x = F.relu(self.bn2(self.conv2(x)))
#         x = self.dropout(x)  # Apply dropout

#         # Second block
#         x = F.relu(self.bn3(self.conv3(x)))
#         x = F.relu(self.bn4(self.conv4(x)))
#         x = self.dropout(x)

#         # Third block
#         x = F.relu(self.bn5(self.conv5(x)))
#         x = F.relu(self.bn6(self.conv6(x)))
#         x = self.dropout(x)

#         # Final output
#         x = self.conv_final(x)
#         return x.squeeze(0)  # Shape: (1, H, W)

# # **Initialize the Model**
# num_input_channels = combined_input.shape[0]  
# print(num_input_channels)
# fusion_model = FusionCNN(num_input_channels)

# # **Perform Feature Fusion**
# fused_feature_map = fusion_model(image_tensor)

# # **Print Model Summary**
# print(fusion_model)

# # **Step 9: Visualize Fused Feature Map & Ground Truth**
# fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# # **Plot CNN-Fused Feature Map**
# axes[0].imshow(fused_feature_map.detach().numpy().squeeze(), cmap="viridis")
# axes[0].set_title("CNN-Based Fused Feature Map")
# axes[0].axis("off")

# # **Plot Ground Truth Label**
# axes[1].imshow(label, cmap="gray")
# axes[1].set_title("Ground Truth (Label)")
# axes[1].axis("off")

# plt.show()

# # **Step 10: Print Shape & Tensor**
# print("Fused Feature Map Shape:", fused_feature_map.shape)  # Should be (H, W)
# print("Fused Feature Map Tensor:", fused_feature_map)
# print("Ground Truth Shape:", label.shape)  # Should be (H, W)


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import tifffile as tiff
import cv2
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split


class FusionCNN(nn.Module):
    """Fuse a stack of input bands into a single-channel feature map.

    Three stages (64 -> 32 -> 16 channels), each made of two 3x3 Conv -> BatchNorm ->
    ReLU layers followed by dropout (p=0.3), then a 3x3 conv down to one channel.
    Every conv uses padding=1, so height and width are preserved.

    Args:
        num_input_channels: number of channels in the input tensor.

    Note:
        ``forward`` ends with ``squeeze(0)``, which only removes the batch axis when
        the batch size is 1: a (1, C, H, W) input gives (1, H, W), while a (B, C, H, W)
        input with B > 1 gives (B, 1, H, W).
    """

    def __init__(self, num_input_channels):
        super().__init__()

        # Modules are created in this exact order so parameter initialization draws from
        # the RNG identically and the state_dict keys stay conv1/bn1/.../conv_final.
        self.conv1 = nn.Conv2d(num_input_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(32)

        self.conv5 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(16)
        self.conv6 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(16)

        self.conv_final = nn.Conv2d(16, 1, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.conv2, self.bn2),
            (self.conv3, self.bn3, self.conv4, self.bn4),
            (self.conv5, self.bn5, self.conv6, self.bn6),
        )
        for first_conv, first_bn, second_conv, second_bn in stages:
            x = F.relu(first_bn(first_conv(x)))
            x = F.relu(second_bn(second_conv(x)))
            x = self.dropout(x)

        fused = self.conv_final(x)
        # Drops the batch axis only for a batch of one (see class docstring).
        return fused.squeeze(0)


class FusionDataset(Dataset):
    """Dataset yielding ``(fused_feature_map, label)`` pairs for segmentation training.

    Each item reads a multi-band ``.tif`` tile and its ``.png`` mask. It builds the fusion
    input: the min-max normalized continuous bands (0, 1, 2, 3, 5), followed by a one-hot
    encoding of band 4. That input is run through ``fusion_model`` under
    ``torch.no_grad()``.

    Args:
        img_paths: paths to the ``.tif`` tiles, read as (H, W, C) with C >= 6.
        label_paths: paths to the matching ``.png`` masks, in the same order.
        fusion_model: module applied to the (1, 5 + num_lulc_classes, H, W) input. It is
            switched to eval mode here. Because it runs under no_grad, it receives no
            gradient from anything trained on this dataset's output.
        num_lulc_classes: number of one-hot planes; plane k marks pixels whose band-4
            value equals k.
        device: device the model input and the returned tensors are placed on.
    """

    # Band indices in the (C, H, W) tile: band 4 is categorical, the others continuous.
    CONTINUOUS_BANDS = (0, 1, 2, 3, 5)
    LULC_BAND = 4

    def __init__(self, img_paths, label_paths, fusion_model, num_lulc_classes, device):
        self.img_paths = img_paths
        self.label_paths = label_paths
        self.fusion_model = fusion_model.eval()
        self.num_lulc_classes = num_lulc_classes
        self.device = device

    def normalize_channel(self, channel):
        """Min-max scale ``channel`` to [0, 1]; the 1e-6 keeps a constant band finite (all 0)."""
        return (channel - np.min(channel)) / (np.max(channel) - np.min(channel) + 1e-6)

    def _build_input(self, img):
        """Return the (5 + num_lulc_classes, H, W) float32 fusion input for a (C, H, W) tile."""
        # No resampling is needed: every band is a slice of the same (C, H, W) array, so
        # all bands already share the tile's H x W grid.
        continuous = np.stack(
            [self.normalize_channel(img[i]) for i in self.CONTINUOUS_BANDS]
        ).astype(np.float32, copy=False)

        # One-hot: plane k is 1 where the band-4 value equals k.
        # TODO(review): assumes band-4 codes are integers 0..num_lulc_classes-1.
        lulc = img[self.LULC_BAND]
        classes = np.arange(self.num_lulc_classes, dtype=lulc.dtype)
        lulc_one_hot = (lulc[None, :, :] == classes[:, None, None]).astype(np.float32)

        return np.concatenate((continuous, lulc_one_hot), axis=0)

    def __getitem__(self, idx):
        img = tiff.imread(self.img_paths[idx]).astype(np.float32)
        if img.ndim != 3 or img.shape[2] < 6:
            raise ValueError(
                f"{self.img_paths[idx]}: expected an (H, W, C>=6) tile, got shape {img.shape}"
            )
        img = np.transpose(img, (2, 0, 1))  # (H, W, C) -> (C, H, W)

        # Context manager closes the PNG file handle once the pixels are read.
        # NOTE(review): raw grayscale values are returned unchanged — confirm masks are 0/1.
        with Image.open(self.label_paths[idx]) as mask:
            label = np.array(mask.convert("L"), dtype=np.float32)

        combined_input = self._build_input(img)
        img_tensor = torch.tensor(combined_input, dtype=torch.float32).unsqueeze(0).to(self.device)

        with torch.no_grad():
            fused_feature_map = self.fusion_model(img_tensor)

        label_tensor = torch.tensor(label, dtype=torch.float32).unsqueeze(0).to(self.device)

        return fused_feature_map, label_tensor

    def __len__(self):
        return len(self.img_paths)


# Prefer the GPU, but fall back to CPU so the notebook still runs without one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

img_dir = "/kaggle/input/images/images/images"
label_dir = "/kaggle/input/images/labels/labels"

img_list = sorted(f for f in os.listdir(img_dir) if f.endswith(".tif"))
label_list = sorted(f for f in os.listdir(label_dir) if f.endswith(".png"))

# Pair each tile with the mask of the same stem (set lookup; suffix-only swap).
label_set = set(label_list)
img_paths, label_paths = [], []
for tile_name in img_list:
    mask_name = tile_name[: -len(".tif")] + ".png"
    if mask_name in label_set:
        img_paths.append(os.path.join(img_dir, tile_name))
        label_paths.append(os.path.join(label_dir, mask_name))
if not img_paths:
    raise FileNotFoundError(f"No matching .tif/.png pairs under {img_dir} and {label_dir}")

train_img_paths, test_img_paths, train_label_paths, test_label_paths = train_test_split(
    img_paths, label_paths, test_size=0.2, random_state=42)

# tifffile returns tiles as (H, W, C) (FusionDataset transposes them to (C, H, W)), so
# band 4 is index 4 on the LAST axis; `[4]` would select pixel row 4 across all bands.
lulc_img = tiff.imread(train_img_paths[0])[..., 4]
# TODO(review): K is counted from one training tile, while FusionDataset one-hot encodes
# codes 0..K-1 — confirm the codes are contiguous from 0 and all appear in this tile.
num_lulc_classes = len(np.unique(lulc_img))

num_input_channels = 5 + num_lulc_classes
# FusionCNN is never trained, so its random initial weights define the fused features
# the segmentation model learns from; seed them so runs are reproducible.
torch.manual_seed(42)
fusion_model = FusionCNN(num_input_channels).to(device)

train_dataset = FusionDataset(train_img_paths, train_label_paths, fusion_model, num_lulc_classes, device)
test_dataset = FusionDataset(test_img_paths, test_label_paths, fusion_model, num_lulc_classes, device)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

fused_features, labels = next(iter(train_loader))
print("Fused Feature Map Shape:", fused_features.shape)
print("Label Shape:", labels.shape)


Using device: cuda
Fused Feature Map Shape: torch.Size([4, 1, 512, 512])
Label Shape: torch.Size([4, 1, 512, 512])


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU layers (padding 1 keeps H and W).

    The layers live in ``self.conv`` (an ``nn.Sequential``) at indices 0-5, which fixes
    the state_dict keys (``conv.0.weight``, ``conv.1.running_mean``, ...).
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = []
        # First layer maps ch_in -> ch_out, the second ch_out -> ch_out.
        for layer_in in (ch_in, ch_out):
            layers.extend([
                nn.Conv2d(layer_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.conv(x)
        return features

class UpConvBlock(nn.Module):
    """2x bilinear upsampling followed by a 3x3 Conv -> BatchNorm -> ReLU.

    Doubles H and W and maps ``ch_in`` to ``ch_out`` channels. Submodules sit in
    ``self.up`` at indices 0-3 (state_dict keys ``up.1.*`` and ``up.2.*``).
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        project = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(ch_out)
        self.up = nn.Sequential(upsample, project, norm, nn.ReLU(inplace=True))

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled

class AttentionBlock(nn.Module):
    """Additive attention gate.

    The gating signal ``g`` and the skip features ``x`` are each projected to ``f_int``
    channels by a 1x1 Conv + BatchNorm; the projections are added and passed through
    ReLU. A 1x1 Conv + BatchNorm + Sigmoid then gives a one-channel gate in (0, 1) that
    rescales every channel of ``x``. ``g`` and ``x`` must share H and W.
    """

    def __init__(self, f_g, f_l, f_int):
        super().__init__()

        def conv1x1_bn(c_in, c_out):
            # Index 0 = conv, index 1 = batch norm, matching the state_dict layout.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(c_out),
            ]

        self.w_g = nn.Sequential(*conv1x1_bn(f_g, f_int))
        self.w_x = nn.Sequential(*conv1x1_bn(f_l, f_int))
        self.psi = nn.Sequential(*conv1x1_bn(f_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate = self.psi(self.relu(self.w_g(g) + self.w_x(x)))
        return x * gate

class AttentionUNet(nn.Module):
    """Attention U-Net producing a per-pixel probability map.

    Encoder: five ConvBlocks (64 -> 1024 channels) separated by 2x2 max-pooling.
    Decoder: at each level the coarser features are upsampled once. The upsampled
    features gate the matching encoder skip connection through an AttentionBlock.
    ``[gated skip, upsampled]`` is then concatenated and refined by a ConvBlock.
    A final 1x1 conv maps to ``out_channels``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output map.

    Returns:
        ``torch.sigmoid`` of the final conv, i.e. probabilities in (0, 1). Consumers
        should threshold it directly rather than applying sigmoid again.

    Note:
        Four 2x poolings mean H and W should be divisible by 16, so that upsampled
        features line up with their skip connections in ``torch.cat``.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: channel width doubles at each level.
        self.conv1 = ConvBlock(ch_in=in_channels, ch_out=64)
        self.conv2 = ConvBlock(ch_in=64, ch_out=128)
        self.conv3 = ConvBlock(ch_in=128, ch_out=256)
        self.conv4 = ConvBlock(ch_in=256, ch_out=512)
        self.conv5 = ConvBlock(ch_in=512, ch_out=1024)

        # Decoder: upsample, attention gate on the skip, refine the concatenation.
        self.up5 = UpConvBlock(ch_in=1024, ch_out=512)
        self.att5 = AttentionBlock(f_g=512, f_l=512, f_int=256)
        self.upconv5 = ConvBlock(ch_in=1024, ch_out=512)

        self.up4 = UpConvBlock(ch_in=512, ch_out=256)
        self.att4 = AttentionBlock(f_g=256, f_l=256, f_int=128)
        self.upconv4 = ConvBlock(ch_in=512, ch_out=256)

        self.up3 = UpConvBlock(ch_in=256, ch_out=128)
        self.att3 = AttentionBlock(f_g=128, f_l=128, f_int=64)
        self.upconv3 = ConvBlock(ch_in=256, ch_out=128)

        self.up2 = UpConvBlock(ch_in=128, ch_out=64)
        self.att2 = AttentionBlock(f_g=64, f_l=64, f_int=32)
        self.upconv2 = ConvBlock(ch_in=128, ch_out=64)

        self.conv_1x1 = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _decode_level(up, att, refine, coarse, skip):
        """Upsample ``coarse`` once, gate ``skip`` with it, and refine [gated skip, upsampled]."""
        # Computing the upsampling a single time avoids double work, and keeps BatchNorm
        # running statistics from being updated twice per training step.
        upsampled = up(coarse)
        gated_skip = att(upsampled, skip)
        return refine(torch.cat((gated_skip, upsampled), dim=1))

    def forward(self, x):
        # Encoder
        x1 = self.conv1(x)
        x2 = self.conv2(self.maxpool(x1))
        x3 = self.conv3(self.maxpool(x2))
        x4 = self.conv4(self.maxpool(x3))
        x5 = self.conv5(self.maxpool(x4))

        # Decoder + attention
        d5 = self._decode_level(self.up5, self.att5, self.upconv5, x5, x4)
        d4 = self._decode_level(self.up4, self.att4, self.upconv4, d5, x3)
        d3 = self._decode_level(self.up3, self.att3, self.upconv3, d4, x2)
        d2 = self._decode_level(self.up2, self.att2, self.upconv2, d3, x1)

        return torch.sigmoid(self.conv_1x1(d2))  # probability map

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

# Prefer the GPU, but fall back to CPU so the notebook still runs without one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The loaders yield single-channel fused feature maps (shape (B, 1, H, W)),
# hence in_channels=1.
attention_unet = AttentionUNet(in_channels=1, out_channels=1)
attention_unet.to(device)


def dice_coef_metric(inputs, target):
    """Dice coefficient ``2 * sum(inputs * target) / (sum(inputs) + sum(target))``.

    Intended for binary masks, e.g. thresholded predictions vs. ground truth. Sums run
    over every element, so a batch is scored as one pooled mask.

    Returns:
        A 0-dim tensor. When both masks are empty the result is 1.0 (perfect agreement).
        This is also returned as a tensor, on ``inputs``' device, so callers can treat
        every result the same way (``.cpu()``, ``.item()``).
    """
    target_sum = target.sum()
    inputs_sum = inputs.sum()
    if target_sum == 0 and inputs_sum == 0:
        return torch.tensor(1.0, device=inputs.device)
    return 2.0 * (target * inputs).sum() / (target_sum + inputs_sum)

class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.

    Args:
        weight, size_average: accepted for signature compatibility; unused.
        from_logits: if True, ``inputs`` are raw scores and sigmoid is applied first.
            Defaults to False because AttentionUNet already ends with ``torch.sigmoid``.
            A second sigmoid would squash every prediction into (0.5, 0.73), and the loss
            could never approach 0.
    """

    def __init__(self, weight=None, size_average=True, from_logits=False):
        super().__init__()
        self.from_logits = from_logits

    def forward(self, inputs, targets, smooth=1.0):
        probs = torch.sigmoid(inputs) if self.from_logits else inputs

        # reshape (unlike view) also accepts non-contiguous tensors.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

        return 1 - dice  # Dice loss is (1 - dice coefficient)

def train_model(model_name, model, train_loader, val_loader, train_loss, optimizer, lr_scheduler=None, num_epochs=50):
    """Train ``model`` on BCE + ``train_loss`` and track Dice every epoch.

    ``model`` is expected to return probabilities in [0, 1], as AttentionUNet does (it
    ends with ``torch.sigmoid``). The BCE term is therefore ``nn.BCELoss``, and
    predictions are thresholded at 0.5 directly. A second sigmoid would map every
    probability into (0.5, 0.73), so the threshold would mark every pixel positive.
    Targets are binarized (any non-zero value -> 1), so masks stored as 0/1 or 0/255
    both work.

    Args:
        model_name: label printed when training starts.
        model: the network; batches are moved to its parameters' device.
        train_loader: yields ``(data, target)`` training batches.
        val_loader: scored with ``compute_iou`` after every epoch.
        train_loss: extra loss ``f(outputs, target)`` added to the BCE term.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: optional; ``step()`` is called once per epoch.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``(loss_history, train_history, val_history)``: per-epoch mean training loss,
        mean training Dice, and the validation score from ``compute_iou``.
    """
    print(f"[INFO] Training Model: {model_name}")

    model_device = next(model.parameters()).device
    loss_history, train_history, val_history = [], [], []
    bce_loss_fn = nn.BCELoss()  # model outputs are already probabilities

    for epoch in range(num_epochs):
        model.train()
        epoch_losses = []
        epoch_dice = []

        for data, target in tqdm(train_loader):
            data = data.to(model_device)
            target = (target.to(model_device) > 0).float()

            outputs = model(data)
            loss = bce_loss_fn(outputs, target) + train_loss(outputs, target)

            with torch.no_grad():
                pred_mask = (outputs >= 0.5).float()
                # float() accepts a tensor or a plain Python number from the metric.
                epoch_dice.append(float(dice_coef_metric(pred_mask, target)))
            epoch_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_dice = compute_iou(model, val_loader)

        # Learning rate scheduling (if applicable)
        if lr_scheduler:
            lr_scheduler.step()

        mean_loss = float(np.mean(epoch_losses))
        mean_dice = float(np.mean(epoch_dice))
        loss_history.append(mean_loss)
        train_history.append(mean_dice)
        val_history.append(val_dice)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"Train Loss: {mean_loss:.4f}, Train Dice: {mean_dice:.4f}, Val Dice: {val_dice:.4f}")

    return loss_history, train_history, val_history

def compute_iou(model, loader, threshold=0.5):
    """Mean per-batch Dice score of ``model`` on ``loader``.

    Despite the name, each batch is scored with ``dice_coef_metric`` (Dice, not IoU).
    The name is kept for existing callers. Model outputs are treated as probabilities,
    since AttentionUNet already ends with ``torch.sigmoid``, so they are compared with
    ``threshold`` directly. Targets are binarized (non-zero -> 1).

    The model is left in eval mode.

    Returns:
        Mean Dice over batches as a float, or 0.0 for an empty loader.
    """
    model_device = next(model.parameters()).device
    model.eval()

    total_dice = 0.0
    count = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(model_device)
            target = (target.to(model_device) > 0).float()

            pred_mask = (model(data) >= threshold).float()
            total_dice += float(dice_coef_metric(pred_mask, target))
            count += 1

    return total_dice / count if count > 0 else 0.0

# Optimizer over the segmentation model's parameters only.
opt = torch.optim.Adamax(attention_unet.parameters(), lr=1e-3)
# Train Model
# NOTE(review): test_loader doubles as the validation set here.
num_epochs = 1
aun_lh, aun_th, aun_vh = train_model("Attention UNet", attention_unet, train_loader, test_loader, DiceLoss(), opt, None, num_epochs)

[INFO] Training Model: Attepresntion UNet


100%|██████████| 326/326 [1:04:25<00:00, 11.86s/it]


In [None]:
# Persist the segmentation weights (state_dict only). Note that this cell does not save
# the FusionCNN weights that produced the model's training inputs.
checkpoint_path = "flood_segmentation_model.pth"
torch.save(attention_unet.state_dict(), checkpoint_path)
print("Model saved successfully!")

In [None]:
model = AttentionUNet(1, 1)  # Use the correct model class
# map_location lets a checkpoint written on GPU load on a CPU-only machine (and vice versa).
# torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load("flood_segmentation_model.pth", map_location=device))
model.to(device)
model.eval()  # Set to evaluation mode

In [None]:
import tifffile as tiff
import torch
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import cv2

from PIL import Image  # same mask reader FusionDataset uses for training labels

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CHECKPOINT_PATH = "flood_segmentation_model.pth"
image_path = "/kaggle/input/images/images/images/4.tif"
label_path = "/kaggle/input/images/labels/labels/4.png"
# Threshold applied to the segmentation model's probabilities (it already ends with sigmoid).
PRED_THRESHOLD = 0.6
# NOTE(review): this tile may belong to the training split; one from test_img_paths would
# give an unbiased look at the model.

# Reuse the FusionCNN instance the datasets were built with. Its weights are random and
# never trained, so the segmentation model only understands features from that exact
# instance; building a new FusionCNN here would produce unrelated features. It expects
# 5 + num_lulc_classes input channels, which is what the input below has.
fusion_model = fusion_model.to(device).eval()

# Load the trained segmentation weights; a freshly constructed AttentionUNet is random.
segmentation_model = AttentionUNet(in_channels=1).to(device)
segmentation_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
segmentation_model.eval()

# Load and preprocess the image exactly as FusionDataset does.
image = tiff.imread(image_path).astype(np.float32)  # Shape: (H, W, C)
image = np.transpose(image, (2, 0, 1))  # Convert to (C, H, W)


def normalize_channel(channel):
    """Min-max scale ``channel`` to [0, 1]; the 1e-6 keeps a constant band finite."""
    return (channel - np.min(channel)) / (np.max(channel) - np.min(channel) + 1e-6)


for i in [0, 1, 2, 3, 5]:  # continuous bands; band 4 (LULC) is categorical
    image[i] = normalize_channel(image[i])

# No resampling: every band is a slice of the same array, so all bands already share H x W.

# One-hot encode LULC (band 4): plane k marks pixels whose value equals k.
lulc_one_hot = np.zeros((num_lulc_classes, image.shape[1], image.shape[2]), dtype=np.float32)
for class_value in range(num_lulc_classes):
    lulc_one_hot[class_value] = image[4] == class_value

# Combine continuous bands and one-hot LULC -> (5 + num_lulc_classes, H, W)
combined_input = np.concatenate((image[[0, 1, 2, 3, 5]], lulc_one_hot), axis=0)

# Convert to tensor and move to the device
image_tensor = torch.tensor(combined_input, dtype=torch.float32).unsqueeze(0).to(device)

# Run FusionCNN (a batch of one is squeezed to (1, H, W))
with torch.no_grad():
    fused_feature_map = fusion_model(image_tensor)

# Prepare for AttentionUNet: add the batch dimension -> (1, 1, H, W)
fused_feature_map = fused_feature_map.unsqueeze(0)

# Run segmentation model; output is already a probability map of shape (1, 1, H, W)
with torch.no_grad():
    output = segmentation_model(fused_feature_map)

# Threshold the probabilities directly (no second sigmoid) -> binary mask (0 or 1)
pred_mask = (output.cpu().numpy().squeeze() > PRED_THRESHOLD).astype(np.uint8)

# Load ground truth with the same reader used for training labels; any non-zero pixel
# counts as positive.
with Image.open(label_path) as label_img:
    gt_mask = np.array(label_img.convert("L"))
gt_mask = (gt_mask > 0).astype(np.uint8)

# Visualize prediction vs ground truth
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

ax[0].imshow(gt_mask, cmap="gray")
ax[0].set_title("Ground Truth")
ax[0].axis("off")

ax[1].imshow(pred_mask, cmap="gray")
ax[1].set_title("Predicted Mask")
ax[1].axis("off")

plt.show()