In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:


# There are two available configurations, "post-fire" and "pre-post-fire."
dataset = load_dataset("DarthReca/california_burned_areas", name="pre-post-fire")


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [3]:
dataset

DatasetDict({
    0: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 78
    })
    1: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 55
    })
    2: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 69
    })
    3: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 85
    })
    4: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 69
    })
    chabud: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 68
    })
})

In [4]:
X_0 = zip(dataset['0']['post_fire'], dataset['0']['pre_fire'])
print('done')
X_1 = zip(dataset['1']['post_fire'], dataset['1']['pre_fire'])
print('done')
X_2 =zip(dataset['2']['post_fire'], dataset['2']['pre_fire'])
print('done')
X_3 = zip(dataset['3']['post_fire'], dataset['3']['pre_fire'])
print('done')
X_4 = zip(dataset['4']['post_fire'], dataset['4']['pre_fire'])
print('done')
X_c = zip(dataset['chabud']['post_fire'], dataset['chabud']['pre_fire'])

done
done
done
done
done


In [5]:
m_0 = dataset['0']['mask']
print('done')
m_1 = dataset['1']['mask']
print('done')
m_2 = dataset['2']['mask']
print('done')
m_3 = dataset['3']['mask']
print('done')
m_4 = dataset['4']['mask']
print('done')
m_c = dataset['chabud']['mask']

done
done
done
done
done


In [6]:
X = []
for x in [X_0, X_1, X_2, X_3, X_4, X_c]: 
    X.extend(list(x))  # Combine pre_fire and post_fire as features


train_X, test_X, train_Y, test_Y = train_test_split(X, m_0+m_1+m_2+m_3+m_4+m_c, test_size=0.3, random_state=42)


In [7]:
len(train_X), len(train_Y)

(296, 296)

In [8]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # UNET repicated of https://arxiv.org/pdf/2211.12979 page 5.
        self.w1 = nn.Parameter(torch.tensor(0.33))
        self.w2 = nn.Parameter(torch.tensor(0.33))
        self.w3 = nn.Parameter(torch.tensor(0.33))

        ## START OF ENCODING BLOCK ##
        self.orange = nn.Conv2d(in_channels=27, out_channels=64, kernel_size=7, padding='same') #-> change this to how many channels you will use
        self.red1 = nn.MaxPool2d(2)
        self.blue1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(64),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same')
        )
        self.red2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.blue2 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(128),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same')
        )
        self.red3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.blue3 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(256),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same')
        )
        self.red4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)
        self.blue4 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(512),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding='same')
        )
        ## END OF ENCODER BLOCK

        ## START OF DECODER BLOCK ##

        self.green1 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2)
        self.upblue1 = nn.Sequential(nn.Conv2d(in_channels=768, out_channels=256, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(256),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same'))
        self.green2 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)
        self.upblue2 = nn.Sequential(nn.Conv2d(in_channels=384, out_channels=128, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(128),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same'))
        self.green3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=2, stride=2)
        self.upblue3 = nn.Sequential(nn.Conv2d(in_channels=192, out_channels=64, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(64),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same'))
        self.green4 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)
        self.upblue4 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(32),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'))
        self.upblue5 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(32),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'))
        self.upblue6 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(32),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding='same'))
        
        ## END OF DECODER BLOCK

        ## Final conv layer
        self.final_conv = nn.Conv2d(16, 1, kernel_size=1)


        ###
    def forward(self, pre_fire, post_fire, diff_indices):
        #shape of x -> 12, 512, 512
        weights = torch.softmax(torch.stack([self.w1, self.w2, self.w3]), dim=0)
        fused_feature = torch.cat((weights[0] * pre_fire, weights[1] * post_fire, weights[2] * diff_indices), dim=1)
        orange_op = self.orange(fused_feature) 
        #print('orange_op : ',orange_op.shape) -> torch.Size([1, 64, 512, 512])
        red1_op = self.red1(orange_op)
        # print('red1_op : ',red1_op.shape) -> torch.Size([1, 64, 256, 256])
        blue1_op = self.blue1(self.blue1(red1_op)) + red1_op
        # print('blue1_op : ',blue1_op.shape) -> torch.Size([1, 64, 256, 256])
        red2_op = self.red2(blue1_op)
        # print('red2_op : ',red2_op.shape) -> torch.Size([1, 128, 128, 128])
        blue2_op = self.blue2(self.blue2(self.blue2(red2_op))) + red2_op
        # print('blue2_op : ',blue2_op.shape) -> torch.Size([1, 128, 128, 128])
        red3_op = self.red3(blue2_op)
        # print('red3_op : ',red3_op.shape) -> torch.Size([1, 256, 64, 64])
        blue3_op = self.blue3(self.blue3(self.blue3(self.blue3(self.blue3(red3_op))))) + red3_op
        # print('blue3_op : ',blue3_op.shape) -> torch.Size([1, 256, 64, 64])
        red4_op = self.red4(blue3_op)
        # print('red4_op : ',red4_op.shape) -> torch.Size([1, 512, 32, 32])
        blue4_op = self.blue4(self.blue4(red4_op)) + red4_op
        # print('blue4_op : ',blue4_op.shape) -> torch.Size([1, 512, 32, 32])


        up1_op = self.upblue1(torch.cat((self.green1(blue4_op), blue3_op), dim=1))
        # print('up1_op : ',up1_op.shape) -> torch.Size([1, 256, 64, 64])
        up2_op = self.upblue2(torch.cat((self.green2(up1_op), blue2_op), dim=1))
        # print('up2_op : ',up2_op.shape) -> torch.Size([1, 128, 128, 128])
        up3_op = self.upblue3(torch.cat((self.green3(up2_op), blue1_op), dim=1))
        # print('up3_op : ',up3_op.shape) -> torch.Size([1, 64, 256, 256])
        up4_op = self.upblue6(self.upblue5(self.upblue4(torch.cat((self.green4(up3_op), orange_op), dim=1))))
        # print('up4_op : ',up4_op.shape) -> torch.Size([1, 16, 512, 512])
        return torch.sigmoid(self.final_conv(up4_op))


model = UNet()

In [9]:
from torch.utils.data import DataLoader, Dataset

class ImageData(Dataset):
    def __init__(self, images, masks):
        self.images = images
        self.masks = masks
    
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        pre_fire_image = image[1]
        post_fire_image = image[0]
        mask = self.masks[idx]
        tensor_image_pre = torch.tensor(pre_fire_image).float().permute(2, 0, 1) #-> make changes to what channels you want to include
        tensor_image_post = torch.tensor(post_fire_image).float().permute(2, 0, 1)
        ndvi_val_pre = self.ndvi(tensor_image_pre)
        abai_val_pre = self.abai(tensor_image_pre)
        nbr_val_pre = self.nbr(tensor_image_pre)

        ndvi_val_post = self.ndvi(tensor_image_post)
        abai_val_post = self.abai(tensor_image_post)
        nbr_val_post = self.nbr(tensor_image_post)

        ndvi_val = ndvi_val_post - ndvi_val_pre
        abai_val = abai_val_post - abai_val_pre
        nbr_val = nbr_val_post - nbr_val_pre
        diff_indices =  torch.cat((ndvi_val.unsqueeze(0), abai_val.unsqueeze(0), nbr_val.unsqueeze(0)), dim=0)
        # image_with_indices = torch.cat((tensor_image, ndvi_val.unsqueeze(0), abai_val.unsqueeze(0), nbr_val.unsqueeze(0)), dim=0)
        tensor_mask = torch.tensor(mask).float().permute(2, 0, 1)
        return tensor_image_pre, tensor_image_post, diff_indices, tensor_mask
    def ndvi(self, image):
        b4 = image[3, :, :]
        b8 = image[7, :, :]
        return (b8 - b4) / (b8 + b4 + 1e-6)  # Added small value to avoid division by zero

    def abai(self, image):
        b3 = image[2, :, :]
        b11 = image[10, :, :]
        b12 = image[11, :, :]
        return (3 * b12 - 2 * b11 - 3 * b3) / (3 * b12 + 2 * b11 + 3 * b3 + 1e-6)  # Avoid division by zero

    def nbr(self, image):
        b2 = image[1, :, :]
        b3 = image[2, :, :]
        b8a = image[9, :, :]
        b12 = image[11, :, :]
        return (b12 - b8a - b3 - b2) / (b12 + b8a + b3 + b2 + 1e-6)
    
train_dataset = ImageData(images=train_X, masks=train_Y)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# val_dataset = ImageData(images=val_X, masks=val_Y)
# val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

test_dataset = ImageData(images=test_X, masks=test_Y)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=True)
        

In [10]:
def precision_score_(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    total_pixel_pred = np.sum(pred_mask)
    precision = np.mean(intersect/total_pixel_pred)
    return round(precision, 3)

def recall_score_(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    total_pixel_truth = np.sum(groundtruth_mask)
    recall = np.mean(intersect/total_pixel_truth)
    return round(recall, 3)
def iou_(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    union = np.sum(pred_mask) + np.sum(groundtruth_mask) - intersect
    return round(np.mean(intersect/union), 3)
def dice_loss(groundtruth_mask, pred_mask):
    intersect = torch.sum(pred_mask * groundtruth_mask)
    total_sum = torch.sum(pred_mask) + torch.sum(groundtruth_mask)
    dice = 1 - (2 * intersect / (total_sum + 1e-6))  # Avoid division by zero
    return dice


In [11]:
iou_(np.array(m_0[0:5]), np.array(m_0[0:5]))

1.0

In [12]:
def dice_loss(groundtruth_mask, pred_mask):
    intersect = torch.sum(pred_mask * groundtruth_mask)
    total_sum = torch.sum(pred_mask) + torch.sum(groundtruth_mask)
    dice = 1 - (2 * intersect / (total_sum + 1e-6))  # Avoid division by zero
    return dice

In [13]:
import torch.optim as optim
criterion = nn.BCELoss()  # Use BCEWithLogitsLoss for binary segmentation
optimizer = optim.Adam(model.parameters(), lr=1e-4)
losses = []
# Training Loop
num_epochs = 10
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    all_preds = []
    all_masks = []

    for prefires, postfires, diff_indices, masks in tqdm(train_loader):
        prefires, postfires, diff_indices, masks = prefires.to(device), postfires.to(device), diff_indices.to(device), masks.to(device)
        
        # Forward pass
        outputs = model(prefires, postfires, diff_indices)
        loss = criterion(outputs, masks)  # Cast masks to long if needed
        dice = dice_loss(masks, outputs)
        total_loss = loss + dice
        epoch_loss += total_loss.item()
        losses.append(total_loss.item())


        # Backpropagation and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        preds = (outputs > 0.5).float()
        all_preds.extend(preds.squeeze(1).cpu().numpy())
        all_masks.extend(masks.squeeze(1).cpu().numpy())

    recall = recall_score_(np.array(all_preds), np.array(all_masks))
    precision = precision_score_(np.array(all_preds), np.array(all_masks))
    iou = iou_(np.array(all_preds), np.array(all_masks))
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader)}")
    print(f"Precision: {recall:.4f}, Recall: {precision:.4f}")
    if recall != 0. and precision != 0.:
        print(f'F1 - score : {2/((1/recall) + (1/precision))}')
    print(f"IOU : {iou}")


 14%|█▎        | 5/37 [00:46<04:58,  9.32s/it]

100%|██████████| 37/37 [05:42<00:00,  9.25s/it]


Epoch [1/10], Loss: 1.2329383573016606
Precision: 0.4890, Recall: 0.5850
F1 - score : 0.5327094839244912
IOU : 0.3630000054836273


100%|██████████| 37/37 [05:39<00:00,  9.18s/it]


Epoch [2/10], Loss: 0.9869396606007138
Precision: 0.5610, Recall: 0.6630
F1 - score : 0.607749988635381
IOU : 0.4359999895095825


100%|██████████| 37/37 [06:00<00:00,  9.74s/it]


Epoch [3/10], Loss: 0.8160587533100231
Precision: 0.5920, Recall: 0.7030
F1 - score : 0.6427428654261997
IOU : 0.4740000069141388


100%|██████████| 37/37 [05:53<00:00,  9.55s/it]


Epoch [4/10], Loss: 0.7507019212117066
Precision: 0.6330, Recall: 0.6510
F1 - score : 0.6418738512362817
IOU : 0.4729999899864197


100%|██████████| 37/37 [05:57<00:00,  9.67s/it]


Epoch [5/10], Loss: 0.7191718172382664
Precision: 0.6320, Recall: 0.6650
F1 - score : 0.6480802105252526
IOU : 0.4790000021457672


100%|██████████| 37/37 [05:45<00:00,  9.34s/it]


Epoch [6/10], Loss: 0.6794819042489335
Precision: 0.7010, Recall: 0.6290
F1 - score : 0.6630511212315913
IOU : 0.4959999918937683


100%|██████████| 37/37 [05:48<00:00,  9.41s/it]


Epoch [7/10], Loss: 0.6533555605927029
Precision: 0.7050, Recall: 0.6620
F1 - score : 0.6828236939619317
IOU : 0.5189999938011169


100%|██████████| 37/37 [06:09<00:00,  9.98s/it]


Epoch [8/10], Loss: 0.6469286011682974
Precision: 0.7090, Recall: 0.6720
F1 - score : 0.6900043358371746
IOU : 0.5270000100135803


100%|██████████| 37/37 [06:01<00:00,  9.78s/it]


Epoch [9/10], Loss: 0.6402338719045794
Precision: 0.7050, Recall: 0.6310
F1 - score : 0.6659505815307691
IOU : 0.49900001287460327


100%|██████████| 37/37 [05:43<00:00,  9.27s/it]


Epoch [10/10], Loss: 0.6140320510477633
Precision: 0.7270, Recall: 0.6320
F1 - score : 0.676179559605104
IOU : 0.5109999775886536


In [17]:
model.eval()
epoch_loss = 0
all_preds = []
all_masks = []

for prefires, postfires, diff_indices, masks in tqdm(test_loader):
    prefires, postfires, diff_indices, masks = prefires.to(device), postfires.to(device), diff_indices.to(device),  masks.to(device)
    
    # Forward pass
    outputs = model(prefires, postfires, diff_indices)
    # loss = criterion(outputs, masks)  # Cast masks to long if needed
    # dice = dice_loss(masks, outputs)
    # total_loss = loss + dice
    # epoch_loss += total_loss.item()


    # # Backpropagation and optimization
    # optimizer.zero_grad()
    # total_loss.backward()
    # optimizer.step()
    preds = (outputs > 0.5).float()
    all_preds.extend(preds.squeeze(1).cpu().numpy())
    all_masks.extend(masks.squeeze(1).cpu().numpy())

recall = recall_score_(np.array(all_preds), np.array(all_masks))
precision = precision_score_(np.array(all_preds), np.array(all_masks))
iou = iou_(np.array(all_preds), np.array(all_masks))
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader)}")
print(f"Precision: {recall:.4f}, Recall: {precision:.4f}")
if recall != 0. and precision != 0.:
    print(f'F1 - score : {2/((1/recall) + (1/precision))}')
print(f"IOU : {iou}")

100%|██████████| 16/16 [02:23<00:00,  8.98s/it]


Epoch [10/10], Loss: 0.0
Precision: 0.7910, Recall: 0.8470
F1 - score : 0.8180427409670945
IOU : 0.6919999718666077


In [15]:
np.array(all_preds).shape

(128, 512, 512)

In [18]:
from sklearn.metrics import classification_report
print(classification_report(np.array(all_preds).flatten(), np.array(all_masks).flatten()))

              precision    recall  f1-score   support

         0.0       0.97      0.98      0.97  28946031
         1.0       0.85      0.79      0.82   4608401

    accuracy                           0.95  33554432
   macro avg       0.91      0.88      0.89  33554432
weighted avg       0.95      0.95      0.95  33554432

