In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:


# There are two available configurations, "post-fire" and "pre-post-fire."
dataset = load_dataset("DarthReca/california_burned_areas", name="pre-post-fire")


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [3]:
dataset

DatasetDict({
    0: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 78
    })
    1: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 55
    })
    2: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 69
    })
    3: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 85
    })
    4: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 69
    })
    chabud: Dataset({
        features: ['post_fire', 'pre_fire', 'mask'],
        num_rows: 68
    })
})

In [4]:
X_0 = zip(dataset['0']['post_fire'], dataset['0']['pre_fire'])
print('done')
X_1 = zip(dataset['1']['post_fire'], dataset['1']['pre_fire'])
print('done')
X_2 =zip(dataset['2']['post_fire'], dataset['2']['pre_fire'])
print('done')
X_3 = zip(dataset['3']['post_fire'], dataset['3']['pre_fire'])
print('done')
X_4 = zip(dataset['4']['post_fire'], dataset['4']['pre_fire'])
print('done')
X_c = zip(dataset['chabud']['post_fire'], dataset['chabud']['pre_fire'])

done
done
done
done
done


In [5]:
m_0 = dataset['0']['mask']
print('done')
m_1 = dataset['1']['mask']
print('done')
m_2 = dataset['2']['mask']
print('done')
m_3 = dataset['3']['mask']
print('done')
m_4 = dataset['4']['mask']
print('done')
m_c = dataset['chabud']['mask']

done
done
done
done
done


In [6]:
X = []
for x in [X_0, X_1, X_2, X_3, X_4, X_c]: 
    X.extend(list(x))  # Combine pre_fire and post_fire as features


train_X, test_X, train_Y, test_Y = train_test_split(X, m_0+m_1+m_2+m_3+m_4+m_c, test_size=0.3, random_state=42)


In [7]:
len(train_X), len(train_Y)

(296, 296)

In [8]:
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
 
        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )
 
    def forward(self, x):
        return self.conv(x)
    
class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
 
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))
 
    def forward(self, x):
        s = self.conv(x)
        p = self.pool(s)
        return s, p
class attention_gate(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
 
        self.Wg = nn.Sequential(
            nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.Ws = nn.Sequential(
            nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid()
        )
 
    def forward(self, g, s):
        Wg = self.Wg(g)
        Ws = self.Ws(s)
        out = self.relu(Wg + Ws)
        out = self.output(out)
        return out 


class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
 
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.ag = attention_gate(in_c, out_c)
        self.c1 = conv_block(in_c[0]+out_c, out_c)
 
    def forward(self, x, s):
        x = self.up(x)
        s = self.ag(x, s)
        x = torch.cat([x, s], axis=1)
        x = self.c1(x)
        return x

class attention_unet(nn.Module):
    def __init__(self):
        super().__init__()
 
        self.e1 = encoder_block(12, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
 
        self.b1 = conv_block(256, 512)
 
        self.d1 = decoder_block([512, 256], 256)
        self.d2 = decoder_block([256, 128], 128)
        self.d3 = decoder_block([128, 64], 64)
 
        self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)
        self.w1 = nn.Parameter(torch.tensor(0.5))
        self.w2 = nn.Parameter(torch.tensor(0.5))
 
    def forward(self, pre_fire, post_fire):
        weights = torch.softmax(torch.stack([self.w1, self.w2]), dim=0)
        x = weights[0] * pre_fire + weights[1] * post_fire
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
 
        b1 = self.b1(p3)
 
        d1 = self.d1(b1, s3)
        d2 = self.d2(d1, s2)
        d3 = self.d3(d2, s1)
 
        output = self.output(d3)
        return torch.sigmoid(output)

In [11]:
model = attention_unet()

In [17]:
from torch.utils.data import DataLoader, Dataset

class ImageData(Dataset):
    def __init__(self, images, masks):
        self.images = images
        self.masks = masks
    
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        pre_fire_image = image[1]
        post_fire_image = image[0]
        mask = self.masks[idx]
        tensor_image_pre = torch.tensor(pre_fire_image).float().permute(2, 0, 1) #-> make changes to what channels you want to include
        tensor_image_post = torch.tensor(post_fire_image).float().permute(2, 0, 1)
        # ndvi_val = self.ndvi(tensor_image)
        # abai_val = self.abai(tensor_image)
        # nbr_val = self.nbr(tensor_image)
        # image_with_indices = torch.cat((tensor_image, ndvi_val.unsqueeze(0), abai_val.unsqueeze(0), nbr_val.unsqueeze(0)), dim=0)
        tensor_mask = torch.tensor(mask).float().permute(2, 0, 1)
        return tensor_image_pre, tensor_image_post, tensor_mask
    # def ndvi(self, image):
    #     b4 = image[3, :, :]
    #     b8 = image[7, :, :]
    #     return (b8 - b4) / (b8 + b4 + 1e-6)  # Added small value to avoid division by zero

    # def abai(self, image):
    #     b3 = image[2, :, :]
    #     b11 = image[10, :, :]
    #     b12 = image[11, :, :]
    #     return (3 * b12 - 2 * b11 - 3 * b3) / (3 * b12 + 2 * b11 + 3 * b3 + 1e-6)  # Avoid division by zero

    # def nbr(self, image):
    #     b2 = image[1, :, :]
    #     b3 = image[2, :, :]
    #     b8a = image[9, :, :]
    #     b12 = image[11, :, :]
    #     return (b12 - b8a - b3 - b2) / (b12 + b8a + b3 + b2 + 1e-6)
    
train_dataset = ImageData(images=train_X, masks=train_Y)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# val_dataset = ImageData(images=val_X, masks=val_Y)
# val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

test_dataset = ImageData(images=test_X, masks=test_Y)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=True)
        

In [18]:
def precision_score_(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    total_pixel_pred = np.sum(pred_mask)
    precision = np.mean(intersect/total_pixel_pred)
    return round(precision, 3)

def recall_score_(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    total_pixel_truth = np.sum(groundtruth_mask)
    recall = np.mean(intersect/total_pixel_truth)
    return round(recall, 3)
def iou_(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    union = np.sum(pred_mask) + np.sum(groundtruth_mask) - intersect
    return round(np.mean(intersect/union), 3)
def dice_loss(groundtruth_mask, pred_mask):
    intersect = torch.sum(pred_mask * groundtruth_mask)
    total_sum = torch.sum(pred_mask) + torch.sum(groundtruth_mask)
    dice = 1 - (2 * intersect / (total_sum + 1e-6))  # Avoid division by zero
    return dice


In [19]:
iou_(np.array(m_0[0:5]), np.array(m_0[0:5]))

1.0

In [20]:
def dice_loss(groundtruth_mask, pred_mask):
    intersect = torch.sum(pred_mask * groundtruth_mask)
    total_sum = torch.sum(pred_mask) + torch.sum(groundtruth_mask)
    dice = 1 - (2 * intersect / (total_sum + 1e-6))  # Avoid division by zero
    return dice

In [21]:
import torch.optim as optim
criterion = nn.BCELoss()  # Use BCEWithLogitsLoss for binary segmentation
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Loop
num_epochs = 10
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
# model.to(device)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    all_preds = []
    all_masks = []

    for prefires, postfires, masks in tqdm(train_loader):
        prefires, postfires, masks = prefires.to(device), postfires.to(device), masks.to(device)
        
        # Forward pass
        outputs = model(prefires, postfires)
        loss = criterion(outputs, masks)  # Cast masks to long if needed
        dice = dice_loss(masks, outputs)
        total_loss = loss + dice
        epoch_loss += total_loss.item()


        # Backpropagation and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        preds = (outputs > 0.5).float()
        all_preds.extend(preds.squeeze(1).cpu().numpy())
        all_masks.extend(masks.squeeze(1).cpu().numpy())

    recall = recall_score_(np.array(all_preds), np.array(all_masks))
    precision = precision_score_(np.array(all_preds), np.array(all_masks))
    iou = iou_(np.array(all_preds), np.array(all_masks))
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader)}")
    print(f"Precision: {recall:.4f}, Recall: {precision:.4f}")
    if recall != 0. and precision != 0.:
        print(f'F1 - score : {2/((1/recall) + (1/precision))}')
    print(f"IOU : {iou}")


100%|██████████| 37/37 [05:44<00:00,  9.31s/it]


Epoch [1/10], Loss: 1.199951245978072
Precision: 0.3280, Recall: 0.4410
F1 - score : 0.3761976705485518
IOU : 0.23199999332427979


100%|██████████| 37/37 [05:42<00:00,  9.25s/it]


Epoch [2/10], Loss: 1.1331337819228302
Precision: 0.3810, Recall: 0.5720
F1 - score : 0.457359933351378
IOU : 0.296999990940094


100%|██████████| 37/37 [05:44<00:00,  9.30s/it]


Epoch [3/10], Loss: 1.0830409269075136
Precision: 0.3900, Recall: 0.6650
F1 - score : 0.49165876226972843
IOU : 0.32600000500679016


100%|██████████| 37/37 [05:43<00:00,  9.29s/it]


Epoch [4/10], Loss: 1.0458001397751473
Precision: 0.4740, Recall: 0.6400
F1 - score : 0.5446319562964186
IOU : 0.37400001287460327


100%|██████████| 37/37 [05:44<00:00,  9.30s/it]


Epoch [5/10], Loss: 1.0346989148371928
Precision: 0.4180, Recall: 0.6970
F1 - score : 0.5225937393543508
IOU : 0.3540000021457672


100%|██████████| 37/37 [05:43<00:00,  9.28s/it]


Epoch [6/10], Loss: 1.0389837815954879
Precision: 0.4250, Recall: 0.6220
F1 - score : 0.5049665726560411
IOU : 0.33799999952316284


100%|██████████| 37/37 [05:45<00:00,  9.35s/it]


Epoch [7/10], Loss: 0.9736422622526014
Precision: 0.4880, Recall: 0.6650
F1 - score : 0.5629141483697085
IOU : 0.39100000262260437


100%|██████████| 37/37 [05:41<00:00,  9.22s/it]


Epoch [8/10], Loss: 0.9939295939497046
Precision: 0.4370, Recall: 0.5900
F1 - score : 0.5021032079945903
IOU : 0.33500000834465027


100%|██████████| 37/37 [05:44<00:00,  9.31s/it]


Epoch [9/10], Loss: 0.9297552527608098
Precision: 0.4790, Recall: 0.7200
F1 - score : 0.5752794101794938
IOU : 0.40400001406669617


100%|██████████| 37/37 [05:48<00:00,  9.41s/it]


Epoch [10/10], Loss: 0.9346185503779231
Precision: 0.4630, Recall: 0.6890
F1 - score : 0.5538316002701457
IOU : 0.382999986410141


In [22]:
model.eval()
epoch_loss = 0
all_preds = []
all_masks = []

for prefires, postfires, masks in tqdm(test_loader):
    prefires, postfires, masks = prefires.to(device), postfires.to(device), masks.to(device)
    
    # Forward pass
    outputs = model(prefires, postfires)
    # loss = criterion(outputs, masks)  # Cast masks to long if needed
    # dice = dice_loss(masks, outputs)
    # total_loss = loss + dice
    # epoch_loss += total_loss.item()


    # # Backpropagation and optimization
    # optimizer.zero_grad()
    # total_loss.backward()
    # optimizer.step()
    preds = (outputs > 0.5).float()
    all_preds.extend(preds.squeeze(1).cpu().numpy())
    all_masks.extend(masks.squeeze(1).cpu().numpy())

recall = recall_score_(np.array(all_preds), np.array(all_masks))
precision = precision_score_(np.array(all_preds), np.array(all_masks))
iou = iou_(np.array(all_preds), np.array(all_masks))
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader)}")
print(f"Precision: {recall:.4f}, Recall: {precision:.4f}")
if recall != 0. and precision != 0.:
    print(f'F1 - score : {2/((1/recall) + (1/precision))}')
print(f"IOU : {iou}")

100%|██████████| 16/16 [02:25<00:00,  9.09s/it]


Epoch [10/10], Loss: 0.0
Precision: 0.6740, Recall: 0.7620
F1 - score : 0.7153036458484732
IOU : 0.5569999814033508


In [23]:
from sklearn.metrics import classification_report
print(classification_report(np.array(all_preds).flatten(), np.array(all_masks).flatten()))

              precision    recall  f1-score   support

         0.0       0.95      0.96      0.95  28691371
         1.0       0.76      0.67      0.72   4863061

    accuracy                           0.92  33554432
   macro avg       0.85      0.82      0.84  33554432
weighted avg       0.92      0.92      0.92  33554432

