In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
from tqdm import tqdm
import os
import cv2
from PIL import Image

# U-Net model definition
class UNet(nn.Module):
    """Compact four-level U-Net for per-pixel segmentation.

    Each encoder level is a single 3x3 convolution followed by ReLU, and levels
    are separated by 2x2 max pooling. The decoder upsamples with transposed
    convolutions, resizes the result to the matching encoder map if pooling
    dropped an odd row/column, concatenates that encoder map (skip connection)
    and fuses the pair with a 3x3 convolution + ReLU. A final 1x1 convolution
    produces ``out_channels`` raw logits (no sigmoid), as expected by
    ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of output logit maps.

    Input height and width must each be at least 16 so that four poolings
    leave a non-empty bottleneck. Sizes that are not multiples of 16 (including
    odd sizes) are supported; the output always has the input's spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        # Encoder
        self.encoder1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.encoder3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.encoder4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)

        # Decoder: 2x upsampling with learned transposed convolutions.
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv1 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)

        # Fusion convs after concatenation; input width = upsampled + skip channels.
        self.conv4 = nn.Conv2d(768, 256, kernel_size=3, padding=1)  # 256 + 512 (enc4)
        self.conv3 = nn.Conv2d(384, 128, kernel_size=3, padding=1)  # 128 + 256 (enc3)
        self.conv2 = nn.Conv2d(192, 64, kernel_size=3, padding=1)   # 64 + 128 (enc2)
        self.conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)   # 64 + 64 (enc1)

        # Final 1x1 projection to logits.
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Map an image batch ``x`` to per-pixel logits of the same spatial size."""
        # Encoder path: each pooling halves H and W (floor division for odd sizes).
        enc1 = F.relu(self.encoder1(x))
        enc2 = F.relu(self.encoder2(self.pool(enc1)))
        enc3 = F.relu(self.encoder3(self.pool(enc2)))
        enc4 = F.relu(self.encoder4(self.pool(enc3)))

        # Bottleneck: one more pooling step, no extra convolution.
        mid = self.pool(enc4)

        # Decoder path: upsample, align to the skip map, concatenate, fuse.
        dec4 = self._up_and_merge(mid, enc4, self.upconv4, self.conv4)
        dec3 = self._up_and_merge(dec4, enc3, self.upconv3, self.conv3)
        dec2 = self._up_and_merge(dec3, enc2, self.upconv2, self.conv2)
        dec1 = self._up_and_merge(dec2, enc1, self.upconv1, self.conv1)

        # Raw logits: no activation on the final layer.
        return self.final(dec1)

    @staticmethod
    def _up_and_merge(x, skip, upconv, conv):
        """Upsample ``x``, match ``skip``'s spatial size, concatenate and fuse."""
        up = upconv(x)
        # Floor-division pooling loses a row/column on odd sizes, so a 2x upsample
        # can be one pixel smaller than the skip map; resize before concatenating.
        if up.shape[2:] != skip.shape[2:]:
            up = F.interpolate(up, size=skip.shape[2:], mode='bilinear', align_corners=False)
        return F.relu(conv(torch.cat([up, skip], dim=1)))

# Custom dataset class
class GlacierDataset(Dataset):
    """Paired image/mask segmentation dataset read from two directories.

    Images and masks are paired by position after sorting each directory's file
    names: the i-th image (alphabetically) is matched with the i-th mask. Both
    directories must therefore contain the same number of entries, and matching
    files should sort into the same order (e.g. identical file names).

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the segmentation masks.
        transform: Optional callable applied separately to the image and to the
            grayscale mask. Because it is called twice per sample, random
            augmentations inside it are NOT synchronized between image and mask.
            To batch with the default DataLoader collate function it should
            return tensors.

    Raises:
        ValueError: If the two directories hold different numbers of entries.
    """

    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.image_paths = sorted(os.listdir(images_dir))
        self.mask_paths = sorted(os.listdir(masks_dir))
        # Pairing is positional, so a count mismatch (e.g. a stray hidden file in
        # one directory) would silently misalign pairs or raise IndexError later.
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} entries in {images_dir!r} but "
                f"{len(self.mask_paths)} in {masks_dir!r}; images and masks must pair one-to-one."
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at position ``idx``."""
        image = Image.open(os.path.join(self.images_dir, self.image_paths[idx]))
        # Masks are reduced to a single grayscale channel.
        mask = Image.open(os.path.join(self.masks_dir, self.mask_paths[idx])).convert('L')

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # No reshaping here: with a ToTensor-style transform the mask is already
        # [1, H, W]; train_model normalizes batch shapes to [B, 1, H, W].
        return image, mask


# Helper function to calculate IOU
# def calculate_iou(pred_mask, true_mask):
#     # Resize pred_mask to match true_mask shape
#     pred_mask_resized = np.resize(pred_mask, true_mask.shape)
#     #pred_mask_resized = cv2.resize(pred_mask, (true_mask.shape[2], true_mask.shape[1]))

#     if pred_mask.shape[1] < true_mask.shape[1]:
#         pad_height = true_mask.shape[1] - pred_mask.shape[1]
#         pred_mask = np.pad(pred_mask, ((0, 0), (0, pad_height), (0, 0)), mode='constant', constant_values=0)

#     if pred_mask.shape[2] < true_mask.shape[2]:
#         pad_width = true_mask.shape[2] - pred_mask.shape[2]
#         pred_mask = np.pad(pred_mask, ((0, 0), (0, 0), (0, pad_width)), mode='constant', constant_values=0)

#     intersection = np.logical_and(pred_mask, true_mask)
#     union = np.logical_or(pred_mask, true_mask)
#     return np.sum(intersection) / np.sum(union)





# def calculate_iou(pred_mask, true_mask):
#     # Resize pred_mask to match true_mask shape using OpenCV resize (bilinear interpolation)
#     pred_mask_resized = cv2.resize(pred_mask[0], (true_mask.shape[2], true_mask.shape[1]), interpolation=cv2.INTER_LINEAR)
    
#     # Ensure the pred_mask is a binary mask (0 or 1)
#     pred_mask_resized = (pred_mask_resized > 0.5).astype(np.uint8)
    
#     # Perform logical operations
#     intersection = np.logical_and(pred_mask_resized, true_mask[0])  # Assumes both are 3D (batch, height, width)
#     union = np.logical_or(pred_mask_resized, true_mask[0])
    
#     # Calculate IoU
#     iou = np.sum(intersection) / np.sum(union)
#     return iou


def _strip_leading_singleton_axes(mask):
    """Drop leading size-1 axes so e.g. (1, H, W) or (1, 1, H, W) becomes (H, W)."""
    while mask.ndim > 2 and mask.shape[0] == 1:
        mask = mask[0]
    return mask


def calculate_iou(pred_mask, true_mask, threshold=0.5):
    """Return the intersection-over-union (Jaccard index) of two masks.

    Both inputs are converted to arrays and binarized with ``> threshold``.
    Leading singleton axes are removed first, so ``(1, H, W)`` arrays (a model
    output with only the batch axis squeezed) are treated as ``(H, W)``. If the
    spatial sizes differ, the prediction is resized to the reference size with
    nearest-neighbour interpolation so labels remain strictly 0/1.

    Args:
        pred_mask: Predicted mask or probability map (array-like).
        true_mask: Reference mask (array-like).
        threshold: Values strictly greater than this count as foreground.

    Returns:
        IoU as a float in [0, 1]. When both masks are empty the IoU is defined
        as 1.0 (perfect agreement) instead of the ``nan`` that 0/0 would give.
    """
    pred = _strip_leading_singleton_axes(np.asarray(pred_mask) > threshold).astype(np.uint8)
    true = _strip_leading_singleton_axes(np.asarray(true_mask) > threshold).astype(np.uint8)

    if pred.shape[:2] != true.shape[:2]:
        # cv2 takes dsize as (width, height). Nearest-neighbour keeps the mask binary.
        pred = cv2.resize(pred, (true.shape[1], true.shape[0]), interpolation=cv2.INTER_NEAREST)

    intersection = np.count_nonzero(np.logical_and(pred, true))
    union = np.count_nonzero(np.logical_or(pred, true))
    if union == 0:
        return 1.0
    return float(intersection / union)





# Training function
def train_model(model, train_loader, device, num_epochs=2, lr=0.001):
    """Train ``model`` for binary segmentation, printing the mean loss per epoch.

    Minimizes ``nn.BCEWithLogitsLoss`` with Adam, so ``model`` must return raw
    logits shaped like the mask targets. The model is updated in place and
    nothing is returned.

    Args:
        model: Network mapping an image batch to per-pixel logits. This function
            does not move it; it must already be on ``device``.
        train_loader: Sized iterable yielding ``(images, masks)`` batches.
        device: Device each batch is moved to.
        num_epochs: Number of full passes over ``train_loader``.
        lr: Adam learning rate.
    """
    loss_fn = nn.BCEWithLogitsLoss()
    opt = optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch_idx in range(num_epochs):
        label = f'Epoch {epoch_idx+1}/{num_epochs}'
        total = 0.0
        for batch_images, batch_masks in tqdm(train_loader, desc=label):
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            # Insert a channel axis. If the masks already had one (e.g. [B, 1, H, W]),
            # that yields a 5-D tensor, so the extra axis is removed again.
            batch_masks = batch_masks.unsqueeze(1)
            if batch_masks.dim() > 4:
                batch_masks = batch_masks.squeeze(1)

            logits = model(batch_images)
            loss = loss_fn(logits, batch_masks.float())

            opt.zero_grad()
            loss.backward()
            opt.step()

            total += loss.item()

        print(f"Epoch {epoch_idx+1}/{num_epochs}, Loss: {total/len(train_loader)}")


# Visualization and Calculation Function
# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()


# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         # Apply sigmoid and squeeze batch dimension
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         # Thresholding to create binary masks
#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()


# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         # Apply sigmoid and squeeze batch dimension
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         # Thresholding to create binary masks
#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()


# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         # Apply sigmoid and squeeze batch dimension
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         # Ensure that before_mask and after_mask are 2D (height, width)
#         if before_mask.ndim == 3:
#             before_mask = before_mask[0]  # If there's an extra channel dimension
#         if after_mask.ndim == 3:
#             after_mask = after_mask[0]  # If there's an extra channel dimension

#         # Thresholding to create binary masks
#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()




# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         # Apply sigmoid and squeeze batch dimension
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         # Ensure both masks are binary
#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Resize after_mask to match before_mask shape
#         if before_mask.shape != after_mask.shape:
#             after_mask = cv2.resize(after_mask, (before_mask.shape[1], before_mask.shape[0]), interpolation=cv2.INTER_LINEAR)

#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()




# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         # Apply sigmoid and squeeze batch dimension
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         # Ensure both masks are binary
#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Check and print shapes of the masks
#         print("Before Mask Shape:", before_mask.shape)
#         print("After Mask Shape:", after_mask.shape)

#         # Ensure both masks have valid shapes before attempting resize
#         if before_mask.shape != after_mask.shape:
#             target_shape = (before_mask.shape[1], before_mask.shape[0])
#             if target_shape[0] > 0 and target_shape[1] > 0:
#                 print(f"Resizing after_mask to: {target_shape}")
#                 after_mask = cv2.resize(after_mask, target_shape, interpolation=cv2.INTER_LINEAR)
#             else:
#                 print("Invalid target shape. Skipping resize.")
#                 return  # Exit function if invalid target shape

#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()



# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         # Apply sigmoid and squeeze batch dimension
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         # Thresholding to create binary masks
#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Ensure both masks are the same shape
#         if before_mask.shape != after_mask.shape:
#             print(f"Before Mask Shape: {before_mask.shape}")
#             print(f"After Mask Shape: {after_mask.shape}")
#             # Resize after_mask to match before_mask shape
#             after_mask = cv2.resize(after_mask, (before_mask.shape[1], before_mask.shape[0]), interpolation=cv2.INTER_LINEAR)
#             print(f"Resized After Mask Shape: {after_mask.shape}")

#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()


# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         # Apply sigmoid and squeeze batch dimension
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         # Thresholding to create binary masks
#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Print out shapes for debugging
#         print(f"Before Mask Shape: {before_mask.shape}")
#         print(f"After Mask Shape: {after_mask.shape}")

#         # Ensure both masks are the same shape
#         if before_mask.shape != after_mask.shape:
#             print(f"Resizing needed, Before Mask: {before_mask.shape}, After Mask: {after_mask.shape}")
            
#             # Check if after_mask is not empty and before resizing
#             if after_mask.shape[0] > 0 and after_mask.shape[1] > 0:
#                 try:
#                     # Resize after_mask to match before_mask shape
#                     after_mask = cv2.resize(after_mask, (before_mask.shape[1], before_mask.shape[0]), interpolation=cv2.INTER_LINEAR)
#                     print(f"Resized After Mask Shape: {after_mask.shape}")
#                 except Exception as e:
#                     print(f"Error resizing after_mask: {e}")
#             else:
#                 print("After mask has an invalid shape. Cannot resize.")
        
#         # Ensure the final masks are not empty
#         if before_mask.shape[0] == 0 or before_mask.shape[1] == 0:
#             print("Error: Before mask has invalid shape.")
#             return
#         if after_mask.shape[0] == 0 or after_mask.shape[1] == 0:
#             print("Error: After mask has invalid shape.")
#             return

#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()


# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         # Apply sigmoid and squeeze batch dimension
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         # Thresholding to create binary masks
#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Print out shapes for debugging
#         print(f"Before Mask Shape: {before_mask.shape}")
#         print(f"After Mask Shape: {after_mask.shape}")

#         # Resize both masks to the same shape (use before_mask shape as target)
#         if before_mask.shape != after_mask.shape:
#             print(f"Resizing both masks to match the shape of the before_mask: {before_mask.shape}")
            
#             # Resize after_mask to match before_mask shape
#             after_mask = cv2.resize(after_mask, (before_mask.shape[1], before_mask.shape[0]), interpolation=cv2.INTER_LINEAR)
#             print(f"Resized After Mask Shape: {after_mask.shape}")
        
#         # Ensure the final masks are not empty
#         if before_mask.shape[0] == 0 or before_mask.shape[1] == 0:
#             print("Error: Before mask has invalid shape.")
#             return
#         if after_mask.shape[0] == 0 or after_mask.shape[1] == 0:
#             print("Error: After mask has invalid shape.")
#             return

#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()



# def visualize_and_calculate(before_image_path, after_image_path, model, device):
#     before_image = Image.open(before_image_path)
#     after_image = Image.open(after_image_path)

#     # Convert images to RGB (in case they are in RGBA or another mode)
#     before_image = before_image.convert('RGB')
#     after_image = after_image.convert('RGB')

#     # Preprocessing and transforming images to tensor
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ])

#     before_image = transform(before_image).unsqueeze(0).to(device)
#     after_image = transform(after_image).unsqueeze(0).to(device)

#     model.eval()
#     with torch.no_grad():
#         # Forward pass for before and after images
#         before_output = model(before_image)
#         after_output = model(after_image)
        
#         # Apply sigmoid and squeeze batch dimension
#         before_mask = torch.sigmoid(before_output).squeeze(0).cpu().numpy()
#         after_mask = torch.sigmoid(after_output).squeeze(0).cpu().numpy()

#         # Thresholding to create binary masks
#         before_mask = (before_mask > 0.5).astype(np.uint8)
#         after_mask = (after_mask > 0.5).astype(np.uint8)

#         # Print out shapes for debugging
#         print(f"Before Mask Shape: {before_mask.shape}")
#         print(f"After Mask Shape: {after_mask.shape}")

#         # Check if both masks are non-zero in shape
#         if before_mask.shape[0] == 0 or before_mask.shape[1] == 0:
#             print("Error: Before mask has invalid shape.")
#             return
#         if after_mask.shape[0] == 0 or after_mask.shape[1] == 0:
#             print("Error: After mask has invalid shape.")
#             return

#         # Resize both masks to the same shape (use before_mask shape as target)
#         if before_mask.shape != after_mask.shape:
#             print(f"Resizing both masks to match the shape of the before_mask: {before_mask.shape}")
            
#             # Resize after_mask to match before_mask shape
#             try:
#                 after_mask = cv2.resize(after_mask, (before_mask.shape[1], before_mask.shape[0]), interpolation=cv2.INTER_LINEAR)
#                 print(f"Resized After Mask Shape: {after_mask.shape}")
#             except Exception as e:
#                 print(f"Error resizing after_mask: {e}")
#                 return
        
#         # Calculate IOU
#         iou_before_after = calculate_iou(before_mask, after_mask)
#         print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

#         # Visualize before and after masks
#         fig, ax = plt.subplots(1, 3, figsize=(12, 6))
#         ax[0].imshow(before_mask, cmap='gray')
#         ax[0].set_title("Before Mask")
#         ax[1].imshow(after_mask, cmap='gray')
#         ax[1].set_title("After Mask")
#         ax[2].imshow(np.abs(before_mask - after_mask), cmap='hot')
#         ax[2].set_title("Change Mask")
#         plt.show()


def _load_image_tensor(image_path, transform, device):
    """Open ``image_path`` as RGB, apply ``transform`` and add a batch axis.

    The file handle is closed before returning; ``convert`` produces an
    independent, fully loaded image, so the tensor does not depend on it.
    """
    with Image.open(image_path) as img:
        # Convert to RGB in case the file is RGBA, palette, greyscale, ...
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0).to(device)


def _logits_to_binary_mask(logits, threshold=0.5):
    """Turn a single-image model output into a 0/1 ``np.uint8`` mask.

    Sigmoid is applied, the batch axis is dropped and, if the remaining
    array has a singleton channel axis (``(1, H, W)``), that axis is dropped
    too so the result is a plain ``(H, W)`` array. A multi-channel output is
    left 3-D so the caller's shape check can reject it.
    """
    probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
    # cv2.resize and plt.imshow both need a 2-D (H, W) array; a leftover
    # (1, H, W) shape makes shape[0]/shape[1] refer to the wrong axes.
    if probs.ndim == 3 and probs.shape[0] == 1:
        probs = probs[0]
    return (probs > threshold).astype(np.uint8)


def _is_valid_mask(mask):
    """Return True if ``mask`` is a non-empty 2-D array."""
    return mask.ndim == 2 and mask.size > 0


def visualize_and_calculate(before_image_path, after_image_path, model, device,
                            transform=None, threshold=0.5):
    """Segment a before/after image pair, print their IoU and plot the masks.

    Args:
        before_image_path: Path to the "before" image.
        after_image_path: Path to the "after" image.
        model: Segmentation model; it is switched to ``eval()`` mode here.
        device: Torch device the input tensors are moved to.
        transform: Optional preprocessing applied to each PIL RGB image.
            ``None`` uses ``ToTensor()`` + ImageNet ``Normalize``.
            NOTE(review): the training script in ``__main__`` uses
            ``Resize((400, 400))`` + ``ToTensor()`` without ``Normalize``;
            passing that same transform is likely what the model expects.
        threshold: Sigmoid probability above which a pixel is foreground.

    Returns:
        None. Problems are reported with ``print`` and an early return.
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    before_tensor = _load_image_tensor(before_image_path, transform, device)
    after_tensor = _load_image_tensor(after_image_path, transform, device)

    model.eval()
    with torch.no_grad():
        before_output = model(before_tensor)
        after_output = model(after_tensor)

    before_mask = _logits_to_binary_mask(before_output, threshold)
    after_mask = _logits_to_binary_mask(after_output, threshold)

    # Print out shapes for debugging
    print(f"Before Mask Shape: {before_mask.shape}")
    print(f"After Mask Shape: {after_mask.shape}")

    # Both masks must be non-empty 2-D arrays before resizing / comparing.
    if not _is_valid_mask(before_mask):
        print("Error: Before mask has invalid shape.")
        return
    if not _is_valid_mask(after_mask):
        print("Error: After mask has invalid shape.")
        return

    # If shapes differ, resize after_mask onto before_mask's grid.
    if before_mask.shape != after_mask.shape:
        print(f"Resizing both masks to match the shape of the before_mask: {before_mask.shape}")
        height, width = before_mask.shape
        try:
            # cv2 expects dsize as (width, height); nearest-neighbour keeps
            # the mask strictly 0/1 instead of interpolating labels.
            after_mask = cv2.resize(after_mask, (width, height), interpolation=cv2.INTER_NEAREST)
        except cv2.error as e:
            print(f"Error resizing after_mask: {e}")
            return
        print(f"Resized After Mask Shape: {after_mask.shape}")

    # Calculate IOU
    iou_before_after = calculate_iou(before_mask, after_mask)
    print(f"IoU between Before and After Masks: {iou_before_after:.4f}")

    # Pixels whose label differs. Comparing instead of subtracting avoids
    # uint8 wrap-around (0 - 1 == 255), which made losses and gains render
    # with very different intensities.
    change_mask = (before_mask != after_mask).astype(np.uint8)

    # Visualize before and after masks
    _, ax = plt.subplots(1, 3, figsize=(12, 6))
    ax[0].imshow(before_mask, cmap='gray')
    ax[0].set_title("Before Mask")
    ax[1].imshow(after_mask, cmap='gray')
    ax[1].set_title("After Mask")
    ax[2].imshow(change_mask, cmap='hot')
    ax[2].set_title("Change Mask")
    plt.show()




# Main Script
if __name__ == "__main__":
    # Device configuration: prefer CUDA, then Apple-silicon MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Paths to the dataset (replace with your actual paths)
    images_dir = 'data/images'  # Your images path
    masks_dir = 'data/masks'  # Your masks path

    # Dataset and DataLoader setup.
    # NOTE(review): if GlacierDataset applies this same transform to the
    # masks, Resize's default bilinear interpolation will blur mask labels;
    # confirm against GlacierDataset and use nearest-neighbour for masks.
    transform = transforms.Compose([
        transforms.Resize((400, 400)),  # Ensure images are of the same size (adjust as needed)
        transforms.ToTensor(),
    ])

    dataset = GlacierDataset(images_dir, masks_dir, transform=transform)
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Initialize the model
    model = UNet(in_channels=3, out_channels=1).to(device)

    # Train the model
    train_model(model, train_loader, device, num_epochs=1, lr=0.001)

    # After training, visualize and calculate.
    # NOTE(review): visualize_and_calculate's own preprocessing does not use
    # the Resize/ToTensor transform above; check that it matches training.
    before_image_path = "before.png"  # Example image path
    after_image_path = "after.png"   # Example image path
    visualize_and_calculate(before_image_path, after_image_path, model, device)


Epoch 1/1: 100%|██████████| 103/103 [00:32<00:00,  3.14it/s]


Epoch 1/1, Loss: 0.4883266683414723
Before Mask Shape: (1, 1490, 1016)
After Mask Shape: (1, 1492, 1008)
Resizing both masks to match the shape of the before_mask: (1, 1490, 1016)
Error resizing after_mask: OpenCV(4.10.0) /Users/runner/miniforge3/conda-bld/libopencv_1727648921144/work/modules/imgproc/src/resize.cpp:3789: error: (-215:Assertion failed) !dsize.empty() in function 'resize'

