In [None]:
# !pip install datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from sklearn.model_selection import train_test_split

In [None]:
# to store and read data from google drive
from google.colab import drive
from datasets import load_from_disk
import shutil
import os

In [None]:
# Load the "post-fire" configuration of the California burned-areas dataset.
# The repository also offers a "pre-post-fire" configuration; change `name`
# to switch to it.
# NOTE(review): trust_remote_code=True lets the dataset repository's loading
# script run locally - only keep it for sources you trust.
dataset = load_dataset(
    "DarthReca/california_burned_areas",
    name="post-fire",
    trust_remote_code=True,
)

In [None]:
# Pull the 'post_fire' column out of each numbered split, printing a progress
# marker after every one, then bind the results to the X_<split> names that
# the later cells read.
post_fire_columns = []
for split_name in ('0', '1', '2', '3', '4'):
    post_fire_columns.append(dataset[split_name]['post_fire'])
    print('done')
X_0, X_1, X_2, X_3, X_4 = post_fire_columns

# The 'chabud' split is read last and without a progress marker.
X_c = dataset['chabud']['post_fire']

In [None]:
# Pull the 'mask' column out of each numbered split, printing a progress
# marker after every one, then bind the results to the m_<split> names that
# the later cells read.
mask_columns = []
for split_name in ('0', '1', '2', '3', '4'):
    mask_columns.append(dataset[split_name]['mask'])
    print('done')
m_0, m_1, m_2, m_3, m_4 = mask_columns

# The 'chabud' split is read last and without a progress marker.
m_c = dataset['chabud']['mask']

In [None]:
# Add the spectral-index bands to every image before splitting the data

# Small constant added to every denominator below so that an all-zero pixel
# yields 0 instead of a division by zero.
epsilon = 1e-10


def ndvi(b4, b8):
    """Return the normalised difference (b8 - b4) / (b8 + b4 + epsilon)."""
    numerator = b8 - b4
    denominator = b8 + b4 + epsilon
    return numerator / denominator


def abai(b3, b11, b12):
    """Return (3*b12 - 2*b11 - 3*b3) / (3*b12 + 2*b11 + 3*b3 + epsilon)."""
    weighted_b12 = 3 * b12
    weighted_b11 = 2 * b11
    weighted_b3 = 3 * b3
    numerator = weighted_b12 - weighted_b11 - weighted_b3
    denominator = weighted_b12 + weighted_b11 + weighted_b3 + epsilon
    return numerator / denominator


def nbr(b2, b3, b8a, b12):
    """Return (b12 - b8a - b3 - b2) / (b12 + b8a + b3 + b2 + epsilon).

    NOTE(review): this four-band form looks like the "NBR+" variant rather
    than the classic two-band normalised burn ratio - confirm the intent.
    """
    numerator = b12 - b8a - b3 - b2
    denominator = b12 + b8a + b3 + b2 + epsilon
    return numerator / denominator

def add_indices(image):
    """Append NDVI, ABAI and NBR bands to a multi-band image.

    Parameters
    ----------
    image : array-like, shape (H, W, C) with C >= 12
        Channels-last image. Channels 1, 2, 3, 7, 8, 10 and 11 are read as
        b2, b3, b4, b8, b8a, b11 and b12 respectively.
        # assumes the Sentinel-2 band order B01..B12 - TODO confirm

    Returns
    -------
    numpy.ndarray, shape (H, W, C + 3)
        The input channels followed by the ndvi, abai and nbr bands, in that
        order.

    Raises
    ------
    ValueError
        If the input is not 3-D or has fewer than 12 channels.
    """
    image = np.asarray(image)

    # Integer inputs would wrap around (unsigned subtraction) in the index
    # arithmetic, so promote them to float first. Floating inputs keep their
    # dtype.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    if image.ndim != 3 or image.shape[-1] < 12:
        raise ValueError(
            f"expected an (H, W, C>=12) image, got shape {image.shape}"
        )

    # Extract the required bands
    b2, b3, b4 = image[..., 1], image[..., 2], image[..., 3]
    b8, b8a = image[..., 7], image[..., 8]
    b11, b12 = image[..., 10], image[..., 11]

    ndvi_band = ndvi(b4, b8)
    abai_band = abai(b3, b11, b12)
    nbr_band = nbr(b2, b3, b8a, b12)

    # dstack promotes each 2-D index band to (H, W, 1) and concatenates it
    # after the original channels.
    return np.dstack((image, ndvi_band, abai_band, nbr_band))

# Group the per-split image and mask columns in matching order
all_images = [X_0, X_1, X_2, X_3, X_4, X_c]
all_masks = [m_0, m_1, m_2, m_3, m_4, m_c]

# Flatten the splits into one list, adding the index bands to each image
processed_images = [add_indices(img) for subset in all_images for img in subset]

# Flatten the masks the same way so that masks[i] belongs to images[i]
processed_masks = [mask for subset in all_masks for mask in subset]

# Split into train and test sets. The fixed random_state makes the split
# identical on every run, so results are reproducible after a kernel restart.
train_X, test_X, train_Y, test_Y = train_test_split(
    processed_images, processed_masks, test_size=0.3, random_state=42
)

# Check the final shape of train_X
#print(np.array(train_X).shape)  # Should be (421, 512, 512, 15)

In [None]:
class AttentionBlock(nn.Module):
    """Additive attention gate applied to a skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``inter_channels``, summed, passed through ReLU, and reduced to a
    one-channel sigmoid map. That map rescales ``x`` and the result is
    batch-normalised.

    Parameters
    ----------
    in_channels : int
        Channels of the skip features ``x`` (and of the output).
    gating_channels : int
        Channels of the gating signal ``g``.
    inter_channels : int
        Channels of the intermediate projections.
    """

    def __init__(self, in_channels, gating_channels, inter_channels):
        super().__init__()

        # 1x1 convolution projecting the gating signal
        self.phi_g = nn.Conv2d(gating_channels, inter_channels, kernel_size=1, padding='same')

        # 3x3 convolution projecting the skip features
        self.theta_x = nn.Conv2d(in_channels, inter_channels, kernel_size=3, stride=1, padding='same')

        # 1x1 convolution collapsing the combined features to one attention channel
        self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1, padding='same')

        # Normalisation of the gated skip features
        self.bn = nn.BatchNorm2d(in_channels)

    def forward(self, x, g):
        """Return ``x`` rescaled by the attention map computed from ``x`` and ``g``.

        ``x`` is (N, in_channels, H, W) and ``g`` is (N, gating_channels, H', W').
        The two projections are added directly, so their spatial sizes must
        be broadcast-compatible. The output has the same shape as ``x``.
        """
        # Project both inputs to the shared intermediate width and combine them
        combined = F.relu(self.phi_g(g) + self.theta_x(x))

        # One-channel attention coefficients in (0, 1)
        attention = torch.sigmoid(self.psi(combined))

        # Resize the attention map to the spatial size of x
        attention = F.interpolate(attention, size=(x.shape[2], x.shape[3]),
                                  mode='bilinear', align_corners=True)

        # Gate the skip features (the single channel broadcasts over all
        # channels of x), then batch-normalise
        return self.bn(attention * x)

In [None]:
class UNetWithAttention(nn.Module):
    """Residual U-Net encoder/decoder with attention-gated skip connections.

    Parameters
    ----------
    in_channels : int, default 12
        Number of channels of the input image.
    out_channels : int, default 1
        Number of channels of the output map.

    Notes
    -----
    The encoder halves the resolution four times (one max-pool followed by
    three stride-2 convolutions) and the decoder doubles it four times, so
    the input height and width must be divisible by 16 for the skip
    connections to line up. ``forward`` returns sigmoid probabilities, not
    logits - pair it with ``nn.BCELoss`` rather than ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, in_channels=12, out_channels=1):
        super().__init__()

        # Encoder: stem, then four (downsample -> residual conv block) stages
        self.orange = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding='same')
        self.red1 = nn.MaxPool2d(2)
        self.blue1 = self._conv_block(64, 64)
        self.red2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.blue2 = self._conv_block(128, 128)
        self.red3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.blue3 = self._conv_block(256, 256)
        self.red4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)
        self.blue4 = self._conv_block(512, 512)

        # One attention gate per skip connection (deepest first)
        self.att1 = AttentionBlock(in_channels=256, gating_channels=512, inter_channels=128)
        self.att2 = AttentionBlock(in_channels=128, gating_channels=256, inter_channels=64)
        self.att3 = AttentionBlock(in_channels=64, gating_channels=128, inter_channels=32)
        self.att4 = AttentionBlock(in_channels=64, gating_channels=64, inter_channels=32)

        # Decoder: transposed-conv upsampling followed by a conv block whose
        # input is the upsampled map concatenated with the gated skip features
        self.green1 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2)
        self.upblue1 = self._conv_block(768, 256)   # 512 (up) + 256 (skip)
        self.green2 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)
        self.upblue2 = self._conv_block(384, 128)   # 256 (up) + 128 (skip)
        self.green3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=2, stride=2)
        self.upblue3 = self._conv_block(192, 64)    # 128 (up) + 64 (skip)
        self.green4 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)
        self.upblue4 = self._conv_block(128, 32)    # 64 (up) + 64 (skip)

        self.final_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Build the conv3x3 -> BatchNorm -> ReLU -> conv3x3 stack used throughout."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding='same'),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding='same'),
        )

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W) probabilities."""
        # Encoder path; every conv block is wrapped in a residual connection
        orange_op = self.orange(x)
        red1_op = self.red1(orange_op)
        blue1_op = self.blue1(red1_op) + red1_op
        red2_op = self.red2(blue1_op)
        blue2_op = self.blue2(red2_op) + red2_op
        red3_op = self.red3(blue2_op)
        blue3_op = self.blue3(red3_op) + red3_op
        red4_op = self.red4(blue3_op)
        blue4_op = self.blue4(red4_op) + red4_op

        # Decoder path: upsample, gate the matching encoder features with the
        # upsampled map, concatenate, then convolve
        g1 = self.green1(blue4_op)
        att1 = self.att1(blue3_op, g1)
        up1_op = self.upblue1(torch.cat([g1, att1], dim=1))

        g2 = self.green2(up1_op)
        att2 = self.att2(blue2_op, g2)
        up2_op = self.upblue2(torch.cat([g2, att2], dim=1))

        g3 = self.green3(up2_op)
        att3 = self.att3(blue1_op, g3)
        up3_op = self.upblue3(torch.cat([g3, att3], dim=1))

        g4 = self.green4(up3_op)
        att4 = self.att4(orange_op, g4)
        up4_op = self.upblue4(torch.cat([g4, att4], dim=1))

        # Final 1x1 projection squashed to (0, 1)
        return torch.sigmoid(self.final_conv(up4_op))


model = UNetWithAttention()

In [None]:
from torch.utils.data import DataLoader, Dataset

class ImageData(Dataset):
    """Dataset pairing channels-last images with their segmentation masks.

    Parameters
    ----------
    images : sequence of array-like, each (H, W, C)
        Channels-last images.
    masks : sequence of array-like, each (H, W, 1) or (H, W)
        One mask per image, in the same order.
    channels : sequence of int, optional
        Indices of the image channels to keep. Defaults to
        ``DEFAULT_CHANNELS`` (12 channels), so the default behaviour is
        unchanged.

    Raises
    ------
    ValueError
        If ``images`` and ``masks`` have different lengths.
    """

    # Channel selection used when `channels` is not given.
    DEFAULT_CHANNELS = (1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14)

    def __init__(self, images, masks, channels=None):
        if len(images) != len(masks):
            raise ValueError(
                f"images and masks differ in length: {len(images)} vs {len(masks)}"
            )
        self.images = images
        self.masks = masks
        self.channels = list(self.DEFAULT_CHANNELS if channels is None else channels)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` as float32 tensors in channels-first layout."""
        image = np.asarray(self.images[idx], dtype=np.float32)
        # np.array copies, so the returned tensor never aliases the stored mask
        mask = np.array(self.masks[idx], dtype=np.float32)
        if mask.ndim == 2:
            # Give a plain (H, W) mask the trailing channel axis permute expects
            mask = mask[:, :, np.newaxis]

        # Fancy indexing copies the selected channels; permute moves
        # (H, W, C) to (C, H, W)
        tensor_image = torch.from_numpy(image[:, :, self.channels]).permute(2, 0, 1)
        tensor_mask = torch.from_numpy(mask).permute(2, 0, 1)
        return tensor_image, tensor_mask

# Training batches are reshuffled every epoch
train_dataset = ImageData(images=train_X, masks=train_Y)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Evaluation does not need shuffling; a fixed order keeps the collected
# predictions aligned with test_X / test_Y across runs
test_dataset = ImageData(images=test_X, masks=test_Y)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [None]:
def precision_score_(groundtruth_mask, pred_mask):
    """Return pixel-wise precision, rounded to 3 decimals.

    Precision is ``sum(pred * truth) / sum(pred)``. Both arguments are
    arrays of the same shape.
    # assumes binary {0, 1} masks - TODO confirm

    Returns 0.0 when nothing is predicted positive, instead of dividing by
    zero and returning NaN.
    """
    intersect = float(np.sum(pred_mask * groundtruth_mask))
    total_pixel_pred = float(np.sum(pred_mask))
    if total_pixel_pred == 0:
        return 0.0
    return round(intersect / total_pixel_pred, 3)

def recall_score_(groundtruth_mask, pred_mask):
    """Return pixel-wise recall, rounded to 3 decimals.

    Recall is ``sum(pred * truth) / sum(truth)``. Both arguments are arrays
    of the same shape.
    # assumes binary {0, 1} masks - TODO confirm

    Returns 0.0 when the ground truth has no positive pixel, instead of
    dividing by zero and returning NaN.
    """
    intersect = float(np.sum(pred_mask * groundtruth_mask))
    total_pixel_truth = float(np.sum(groundtruth_mask))
    if total_pixel_truth == 0:
        return 0.0
    return round(intersect / total_pixel_truth, 3)

def dice_loss(groundtruth_mask, pred_mask):
    """Return the Dice loss ``1 - 2*|P*G| / (|P| + |G| + 1e-6)`` as a scalar tensor.

    Both arguments are torch tensors of the same shape. The small constant
    in the denominator keeps the ratio finite when both masks are empty.
    """
    overlap = (pred_mask * groundtruth_mask).sum()
    denominator = pred_mask.sum() + groundtruth_mask.sum() + 1e-6
    return 1 - 2 * overlap / denominator

def iou_(groundtruth_mask, pred_mask):
    """Return the intersection-over-union of two masks, rounded to 3 decimals.

    Both arguments are arrays of the same shape.
    # assumes binary {0, 1} masks - TODO confirm

    Returns 0.0 when the union is empty (neither mask has a positive pixel),
    instead of dividing by zero and returning NaN.
    """
    intersect = float(np.sum(pred_mask * groundtruth_mask))
    union = float(np.sum(pred_mask)) + float(np.sum(groundtruth_mask)) - intersect
    if union == 0:
        return 0.0
    return round(intersect / union, 3)

In [None]:
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score

# Define loss and optimizer.
# UNetWithAttention.forward already ends with a sigmoid, so the model emits
# probabilities. BCELoss is the matching criterion; BCEWithLogitsLoss would
# apply a second sigmoid on top of them.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Loop
num_epochs = 10

# Keep using the second GPU when one exists, but fall back to the default
# GPU on single-GPU hosts (where "cuda:1" is an invalid device) and to the
# CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda")
else:
    device = torch.device("cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    epoch_preds = []
    epoch_truths = []

    for images, masks in tqdm(train_loader):
        images, masks = images.to(device), masks.to(device)

        # Forward pass (outputs are probabilities in (0, 1))
        outputs = model(images)

        # BCE only for now; dice_loss(masks, outputs) can be added here
        loss = criterion(outputs, masks)
        epoch_loss += loss.item()

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Threshold the probabilities directly - no second sigmoid
        preds = (outputs.detach() > 0.5).float()

        # Store predictions and ground truths for the epoch-level metrics
        epoch_preds.extend(preds.squeeze(1).cpu().numpy())
        epoch_truths.extend(masks.squeeze(1).cpu().numpy())

    # The metric helpers take (groundtruth_mask, pred_mask), in that order
    truth_arr = np.array(epoch_truths)
    pred_arr = np.array(epoch_preds)
    precision = precision_score_(truth_arr, pred_arr)
    recall = recall_score_(truth_arr, pred_arr)
    iou = iou_(truth_arr, pred_arr)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader)}")
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}")
    # `> 0` is also False for NaN, so an undefined score never reaches the
    # harmonic mean below
    if precision > 0 and recall > 0:
        print(f'F1 - score : {2/((1/recall) + (1/precision))}')
    print(f"IOU : {iou}")


In [None]:
# Evaluate on the held-out test set. Nothing here may update the model: no
# backward pass and no optimizer step, and gradient tracking is switched off.
model.eval()
epoch_loss = 0.0
all_preds = []
all_masks = []

with torch.no_grad():
    for images, masks in tqdm(test_loader):
        images, masks = images.to(device), masks.to(device)

        # Forward pass; the model ends with a sigmoid, so `outputs` are
        # probabilities and plain binary cross-entropy is the matching loss
        outputs = model(images)
        loss = F.binary_cross_entropy(outputs, masks)
        dice = dice_loss(masks, outputs)
        total_loss = loss + dice
        epoch_loss += total_loss.item()

        preds = (outputs > 0.5).float()
        all_preds.extend(preds.squeeze(1).cpu().numpy())
        all_masks.extend(masks.squeeze(1).cpu().numpy())

# The metric helpers take (groundtruth_mask, pred_mask), in that order
truth_arr = np.array(all_masks)
pred_arr = np.array(all_preds)
precision = precision_score_(truth_arr, pred_arr)
recall = recall_score_(truth_arr, pred_arr)
iou = iou_(truth_arr, pred_arr)

# Average over the number of test batches, not training batches
print(f"Test Loss: {epoch_loss/len(test_loader)}")
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}")
# `> 0` is also False for NaN, so an undefined score never reaches the
# harmonic mean below
if precision > 0 and recall > 0:
    print(f'F1 - score : {2/((1/recall) + (1/precision))}')
print(f"IOU : {iou}")

In [None]:
from sklearn.metrics import classification_report
# classification_report takes (y_true, y_pred): ground-truth masks first,
# predictions second. Swapping them exchanges per-class precision and recall.
print(classification_report(np.array(all_masks).flatten(), np.array(all_preds).flatten()))