In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:


# There are two available configurations, "post-fire" and "pre-post-fire."
dataset = load_dataset("DarthReca/california_burned_areas", name="pre-post-fire")


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [3]:
dataset

DatasetDict({
    0: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 78
    })
    1: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 55
    })
    2: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 69
    })
    3: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 85
    })
    4: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 69
    })
    chabud: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 68
    })
})

In [4]:
X_0 = zip(dataset['0']['post_fire'], dataset['0']['pre_fire'])
print('done')
X_1 = zip(dataset['1']['post_fire'], dataset['1']['pre_fire'])
print('done')
X_2 =zip(dataset['2']['post_fire'], dataset['2']['pre_fire'])
print('done')
X_3 = zip(dataset['3']['post_fire'], dataset['3']['pre_fire'])
print('done')
X_4 = zip(dataset['4']['post_fire'], dataset['4']['pre_fire'])
print('done')
X_c = zip(dataset['chabud']['post_fire'], dataset['chabud']['pre_fire'])

done
done
done
done
done


In [5]:
m_0 = dataset['0']['mask']
print('done')
m_1 = dataset['1']['mask']
print('done')
m_2 = dataset['2']['mask']
print('done')
m_3 = dataset['3']['mask']
print('done')
m_4 = dataset['4']['mask']
print('done')
m_c = dataset['chabud']['mask']

done
done
done
done
done


In [6]:
X = []
for x in [X_0, X_1, X_2, X_3, X_4, X_c]: 
    X.extend(list(x))  # Combine pre_fire and post_fire as features


train_X, test_X, train_Y, test_Y = train_test_split(X, m_0+m_1+m_2+m_3+m_4+m_c, test_size=0.3, random_state=42)


In [7]:
len(train_X), len(train_Y)

(296, 296)

In [8]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # UNET repicated of https://arxiv.org/pdf/2211.12979 page 5.
        self.w1 = nn.Parameter(torch.tensor(0.5))
        self.w2 = nn.Parameter(torch.tensor(0.5))
        # self.b = nn.Parameter(torch.tensor(0.5))

        ## START OF ENCODING BLOCK ##
        self.orange = nn.Conv2d(in_channels=12, out_channels=64, kernel_size=7, padding='same') #-> change this to how many channels you will use
        self.red1 = nn.MaxPool2d(2)
        self.blue1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(64),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same')
        )
        self.red2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.blue2 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(128),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same')
        )
        self.red3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.blue3 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(256),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same')
        )
        self.red4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)
        self.blue4 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(512),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding='same')
        )
        ## END OF ENCODER BLOCK

        ## START of ENCODER FOR Pre FIRE
        # self.orangeps = nn.Conv2d(in_channels=12, out_channels=64, kernel_size=7, padding='same') #-> change this to how many channels you will use
        # self.red1ps = nn.MaxPool2d(2)
        # self.blue1ps = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same'), 
        #                            nn.BatchNorm2d(64),
        #                            nn.ReLU(inplace=True),
        #                            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same')
        # )
        # self.red2ps = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        # self.blue2ps = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same'), 
        #                            nn.BatchNorm2d(128),
        #                            nn.ReLU(inplace=True),
        #                            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same')
        # )
        # self.red3ps = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        # self.blue3ps = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same'), 
        #                            nn.BatchNorm2d(256),
        #                            nn.ReLU(inplace=True),
        #                            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same')
        # )
        # self.red4ps = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)
        # self.blue4ps = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding='same'), 
        #                            nn.BatchNorm2d(512),
        #                            nn.ReLU(inplace=True),
        #                            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding='same')
        # )

        ## END pf Pre Fire ENCODER

        ## START OF DECODER BLOCK ##

        self.green1 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2)
        self.upblue1 = nn.Sequential(nn.Conv2d(in_channels=768, out_channels=256, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(256),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same'))
        self.green2 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)
        self.upblue2 = nn.Sequential(nn.Conv2d(in_channels=384, out_channels=128, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(128),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same'))
        self.green3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=2, stride=2)
        self.upblue3 = nn.Sequential(nn.Conv2d(in_channels=192, out_channels=64, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(64),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same'))
        self.green4 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)
        self.upblue4 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(32),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'))
        self.upblue5 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(32),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'))
        self.upblue6 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'), 
                                   nn.BatchNorm2d(32),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding='same'))
        
        ## END OF DECODER BLOCK
        # self.meta_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7,stride=2, padding=1)
        # self.meta_2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        # self.meta_3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        # self.meta_4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)

        ## Final conv layer
        self.final_conv = nn.Conv2d(16, 1, kernel_size=1)


        ###
    def forward(self, pre_fire, post_fire):
        #shape of x -> 12, 512, 512
        weights = torch.softmax(torch.stack([self.w1, self.w2]), dim=0)
        fused_feature = torch.cat((weights[0] * pre_fire, weights[1] * post_fire), dim=1)
        orange_op = self.orange(fused_feature) 
        #print('orange_op : ',orange_op.shape) -> torch.Size([1, 64, 512, 512])
        red1_op = self.red1(orange_op)
        # print('red1_op : ',red1_op.shape) -> torch.Size([1, 64, 256, 256])
        blue1_op = self.blue1(self.blue1(red1_op)) + red1_op
        # print('blue1_op : ',blue1_op.shape) -> torch.Size([1, 64, 256, 256])
        red2_op = self.red2(blue1_op)
        # print('red2_op : ',red2_op.shape) -> torch.Size([1, 128, 128, 128])
        blue2_op = self.blue2(self.blue2(self.blue2(red2_op))) + red2_op
        # print('blue2_op : ',blue2_op.shape) -> torch.Size([1, 128, 128, 128])
        red3_op = self.red3(blue2_op)
        # print('red3_op : ',red3_op.shape) -> torch.Size([1, 256, 64, 64])
        blue3_op = self.blue3(self.blue3(self.blue3(self.blue3(self.blue3(red3_op))))) + red3_op
        # print('blue3_op : ',blue3_op.shape) -> torch.Size([1, 256, 64, 64])
        red4_op = self.red4(blue3_op)
        # print('red4_op : ',red4_op.shape) -> torch.Size([1, 512, 32, 32])
        blue4_op = self.blue4(self.blue4(red4_op)) + red4_op
        # print('blue4_op : ',blue4_op.shape) -> torch.Size([1, 512, 32, 32])

        # orange_op1 = self.orangeps(post_fire) 
        # #print('orange_op : ',orange_op.shape) -> torch.Size([1, 64, 512, 512])
        # red1_op1 = self.red1ps(orange_op1)
        # # print('red1_op : ',red1_op.shape) -> torch.Size([1, 64, 256, 256])
        # blue1_op1 = self.blue1ps(self.blue1ps(red1_op1)) + red1_op1
        # # print('blue1_op : ',blue1_op.shape) -> torch.Size([1, 64, 256, 256])
        # red2_op1 = self.red2ps(blue1_op1)
        # # print('red2_op : ',red2_op.shape) -> torch.Size([1, 128, 128, 128])
        # blue2_op1 = self.blue2ps(self.blue2ps(self.blue2ps(red2_op1))) + red2_op1
        # # print('blue2_op : ',blue2_op.shape) -> torch.Size([1, 128, 128, 128])
        # red3_op1 = self.red3ps(blue2_op1)
        # # print('red3_op : ',red3_op.shape) -> torch.Size([1, 256, 64, 64])
        # blue3_op1 = self.blue3ps(self.blue3ps(self.blue3ps(self.blue3ps(self.blue3ps(red3_op1))))) + red3_op1
        # # print('blue3_op : ',blue3_op.shape) -> torch.Size([1, 256, 64, 64])
        # red4_op1 = self.red4ps(blue3_op1)
        # # print('red4_op : ',red4_op.shape) -> torch.Size([1, 512, 32, 32])
        # blue4_op1 = self.blue4ps(self.blue4ps(red4_op1)) + red4_op1
        # # print('blue4_op : ',blue4_op.shape) -> torch.Size([1, 512, 32, 32])
        # weights = torch.softmax(torch.stack([self.w1, self.w2]), dim=0)
        # fused_feature = weights[0] * blue4_op1 + weights[1] * blue4_op

        up1_op = self.upblue1(torch.cat((self.green1(blue4_op), blue3_op), dim=1))
        # print('up1_op : ',up1_op.shape) -> torch.Size([1, 256, 64, 64])
        up2_op = self.upblue2(torch.cat((self.green2(up1_op), blue2_op), dim=1))
        # print('up2_op : ',up2_op.shape) -> torch.Size([1, 128, 128, 128])
        up3_op = self.upblue3(torch.cat((self.green3(up2_op), blue1_op), dim=1))
        # print('up3_op : ',up3_op.shape) -> torch.Size([1, 64, 256, 256])
        up4_op = self.upblue6(self.upblue5(self.upblue4(torch.cat((self.green4(up3_op), orange_op), dim=1))))
        # print('up4_op : ',up4_op.shape) -> torch.Size([1, 16, 512, 512])
        return torch.sigmoid(self.final_conv(up4_op))


model = UNet()

In [9]:
from torch.utils.data import DataLoader, Dataset


import torch

class ImageData(Dataset):
    def __init__(self, images, masks):
        self.images = images
        self.masks = masks
        self.means = [
            0.033349706741586264, 0.05701185520536176, 0.05889748132001316,
            0.2323245113436119, 0.1972854853760658, 0.11944914225186566
        ]
        self.stds = [
            0.02269135568823774, 0.026807560223070237, 0.04004109844362779,
            0.07791732423672691, 0.08708738838140137, 0.07241979477437814
        ]
    
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        pre_fire_image = image[1]
        post_fire_image = image[0]
        mask = self.masks[idx]

        # Convert to tensors and permute the dimensions to (C, H, W)
        tensor_image_pre = torch.tensor(pre_fire_image).float().permute(2, 0, 1)[:6]
        tensor_image_post = torch.tensor(post_fire_image).float().permute(2, 0, 1)[:6]
        tensor_mask = torch.tensor(mask).float().permute(2, 0, 1)

        # Normalize each channel
        for i in range(6):
            tensor_image_pre[i] = (tensor_image_pre[i] - self.means[i]) / self.stds[i]
            tensor_image_post[i] = (tensor_image_post[i] - self.means[i]) / self.stds[i]

        return tensor_image_pre, tensor_image_post, tensor_mask

    
train_dataset = ImageData(images=train_X, masks=train_Y)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# val_dataset = ImageData(images=val_X, masks=val_Y)
# val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

test_dataset = ImageData(images=test_X, masks=test_Y)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=True)
        

In [10]:
def precision_score_(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    total_pixel_pred = np.sum(pred_mask)
    precision = np.mean(intersect/total_pixel_pred)
    return round(precision, 3)

def recall_score_(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    total_pixel_truth = np.sum(groundtruth_mask)
    recall = np.mean(intersect/total_pixel_truth)
    return round(recall, 3)


In [11]:
recall_score_(np.array(m_0[0:5]), np.array(m_0[0:5]))

1.0

In [12]:
def dice_loss(groundtruth_mask, pred_mask):
    intersect = torch.sum(pred_mask * groundtruth_mask)
    total_sum = torch.sum(pred_mask) + torch.sum(groundtruth_mask)
    dice = 1 - (2 * intersect / (total_sum + 1e-6))  # Avoid division by zero
    return dice

In [13]:
import torch.optim as optim
criterion = nn.BCELoss()  # Use BCEWithLogitsLoss for binary segmentation
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Loop
num_epochs = 10
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    all_preds = []
    all_masks = []

    for prefires, postfires, masks in tqdm(train_loader):
        prefires, postfires, masks = prefires.to(device), postfires.to(device), masks.to(device)
        
        # Forward pass
        outputs = model(prefires, postfires)
        loss = criterion(outputs, masks)  # Cast masks to long if needed
        dice = dice_loss(masks, outputs)
        total_loss = loss + dice
        epoch_loss += total_loss.item()


        # Backpropagation and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        preds = (outputs > 0.5).float()
        all_preds.extend(preds.squeeze(1).cpu().numpy())
        all_masks.extend(masks.squeeze(1).cpu().numpy())

    recall = recall_score_(np.array(all_preds), np.array(all_masks))
    precision = precision_score_(np.array(all_preds), np.array(all_masks))
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader)}")
    print(f"Precision: {recall:.4f}, Recall: {precision:.4f}")
    if recall != 0. and precision != 0.:
        print(f'F1 - score : {2/((1/recall) + (1/precision))}')


100%|██████████| 37/37 [05:45<00:00,  9.34s/it]


Epoch [1/10], Loss: 1.3783846964707245
Precision: 0.2390, Recall: 0.7780
F1 - score : 0.3656676409841659


100%|██████████| 37/37 [05:54<00:00,  9.58s/it]


Epoch [2/10], Loss: 1.1601060677219082
Precision: 0.3970, Recall: 0.6510
F1 - score : 0.49321948362564416


100%|██████████| 37/37 [05:55<00:00,  9.62s/it]


Epoch [3/10], Loss: 1.0325032053767025
Precision: 0.4140, Recall: 0.6230
F1 - score : 0.4974387771481549


100%|██████████| 37/37 [06:11<00:00, 10.05s/it]


Epoch [4/10], Loss: 0.933425606908025
Precision: 0.4720, Recall: 0.6260
F1 - score : 0.5382003614004442


100%|██████████| 37/37 [06:02<00:00,  9.80s/it]


Epoch [5/10], Loss: 0.8648145021619024
Precision: 0.5480, Recall: 0.5600
F1 - score : 0.5539350080108536


100%|██████████| 37/37 [06:03<00:00,  9.83s/it]


Epoch [6/10], Loss: 0.8211970957549842
Precision: 0.5270, Recall: 0.6060
F1 - score : 0.5637458160020952


100%|██████████| 37/37 [06:08<00:00,  9.97s/it]


Epoch [7/10], Loss: 0.8202610321947046
Precision: 0.5970, Recall: 0.5660
F1 - score : 0.581086837681739


100%|██████████| 37/37 [06:00<00:00,  9.73s/it]


Epoch [8/10], Loss: 0.7494347538496997
Precision: 0.6380, Recall: 0.5930
F1 - score : 0.6146775004626259


100%|██████████| 37/37 [06:20<00:00, 10.27s/it]


Epoch [9/10], Loss: 0.7779548748119457
Precision: 0.5890, Recall: 0.6020
F1 - score : 0.5954290434622462


100%|██████████| 37/37 [06:15<00:00, 10.14s/it]


Epoch [10/10], Loss: 0.7882125804553161
Precision: 0.5720, Recall: 0.5660
F1 - score : 0.5689841882777665


In [14]:
model.eval()
epoch_loss = 0
all_preds = []
all_masks = []

for prefires, postfires, masks in tqdm(test_loader):
    prefires, postfires, masks = prefires.to(device), postfires.to(device), masks.to(device)
    
    # Forward pass
    outputs = model(prefires, postfires)
    loss = criterion(outputs, masks)  # Cast masks to long if needed
    dice = dice_loss(masks, outputs)
    total_loss = loss + dice
    epoch_loss += total_loss.item()


    # Backpropagation and optimization
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    preds = (outputs > 0.5).float()
    all_preds.extend(preds.squeeze(1).cpu().numpy())
    all_masks.extend(masks.squeeze(1).cpu().numpy())

recall = recall_score_(np.array(all_preds), np.array(all_masks))
precision = precision_score_(np.array(all_preds), np.array(all_masks))

print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader)}")
print(f"Precision: {recall:.4f}, Recall: {precision:.4f}")
if recall != 0. and precision != 0.:
    print(f'F1 - score : {2/((1/recall) + (1/precision))}')

100%|██████████| 16/16 [03:02<00:00, 11.41s/it]


Epoch [10/10], Loss: 0.3269296038795162
Precision: 0.5650, Recall: 0.6670
F1 - score : 0.6117775941997713


In [15]:
np.array(all_preds).shape

(128, 512, 512)

In [16]:
from sklearn.metrics import classification_report
print(classification_report(np.array(all_preds).flatten(), np.array(all_masks).flatten()))

              precision    recall  f1-score   support

         0.0       0.92      0.95      0.94  28468997
         1.0       0.67      0.56      0.61   5085435

    accuracy                           0.89  33554432
   macro avg       0.80      0.76      0.77  33554432
weighted avg       0.89      0.89      0.89  33554432

