In [None]:
import torch
import os, pickle
import numpy as np
import evaluate
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader, random_split
from transformers import SegformerForSemanticSegmentation, SegformerConfig, SegformerImageProcessor
from torchinfo import summary
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torch import nn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from tqdm.notebook import tqdm
from matplotlib.colors import ListedColormap, BoundaryNorm

In [None]:
# Fetch the pre-trained SegFormer checkpoint whose weights seed the new model.
base_model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0")

# Reuse the checkpoint's configuration, adapted to 18 input channels and 2 labels.
config = base_model.config
config.num_channels = 18
config.num_labels = 2

# Build a freshly initialised model from the adapted configuration.
model = SegformerForSemanticSegmentation(config)

# Copy over every pre-trained tensor whose name and shape still match the new
# architecture; tensors that do not match keep their fresh initialisation.
model_dict = model.state_dict()
pretrained_dict = {}
for key, tensor in base_model.state_dict().items():
    if key in model_dict and tensor.size() == model_dict[key].size():
        pretrained_dict[key] = tensor

# Merge the transferable weights into the full state dict and load the result.
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

# Layer-by-layer summary for a single 18-channel 128x128 input (batch size 1).
summary(model, input_size=(1, 18, 128, 128))

In [None]:
# Show the configuration the model was built with (includes the num_channels /
# num_labels overrides applied when the model was constructed).
print(model.config)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model has not been moved off the CPU yet; put it on the same device as the
# input so the forward pass does not fail with a device mismatch on CUDA machines.
model.to(device)

# Dummy batch: 1 sample, 18 channels, 128x128 pixels.
test_input = torch.randn(1, 18, 128, 128).to(device)

# This is only a shape check, so skip building the autograd graph.
with torch.no_grad():
    output = model(test_input)

print(output.logits.shape)

In [None]:
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset

def get_indices(arr):
    """Append seven spectral indices to an ``(H, W, C)`` image as extra channels.

    The indices are stacked after the original channels in this order:
    NDVI, EVI, SAVI, MSAVI, NDMI, NBR, NBR2.  Channel roles are inferred from
    the index formulas used below: 1 = blue, 3 = red, 4 = NIR, 5 = SWIR1,
    6 = SWIR2 (verify against the data source).

    Args:
        arr: 3-D NumPy array with at least 7 channels on the last axis.

    Returns:
        Array of shape ``(H, W, C + 7)``.  Non-floating inputs are converted to
        float64 before any arithmetic.

    Raises:
        ValueError: if ``arr`` is not 3-D or has fewer than 7 channels.
    """
    if arr.ndim != 3 or arr.shape[2] < 7:
        raise ValueError("Input array must be 3-dimensional with at least 7 channels.")

    # Band differences on unsigned integer imagery (e.g. uint16) wrap around
    # instead of going negative, so do all arithmetic in floating point.
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)

    blue = arr[:, :, 1]
    red = arr[:, :, 3]
    nir = arr[:, :, 4]
    swir1 = arr[:, :, 5]
    swir2 = arr[:, :, 6]

    # Zero denominators (EVI) and negative square roots (MSAVI) are cleaned up
    # below, so the corresponding NumPy warnings are silenced here.
    with np.errstate(divide="ignore", invalid="ignore"):
        bands = {
            "ndvi": (nir - red) / (nir + red + 1e-7),
            "evi": 2.5 * (nir - red) / (nir + 6 * red - 7.5 * blue + 1),
            "savi": 1.5 * (nir - red) / (nir + red + 0.5),
            "msavi": 0.5 * (2 * nir + 1 - np.sqrt((2 * nir + 1) ** 2 - 8 * (nir - red))),
            "ndmi": (nir - swir1) / (nir + swir1 + 1e-7),
            "nbr": (nir - swir2) / (nir + swir2 + 1e-7),
            "nbr2": (swir1 - swir2) / (swir1 + swir2 + 1e-7),
        }

    # Map NaN and +/-inf to 0.  Leaving inf to nan_to_num's default would turn it
    # into the largest finite float and dominate any later min-max scaling.
    cleaned = [
        np.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
        for value in bands.values()
    ]

    # Stack once instead of re-allocating the whole cube for every index.
    return np.dstack([arr, *cleaned])

def normalize_image(image):
    """Min-max scale ``image`` using its global minimum and maximum.

    One minimum and one maximum are taken over the whole array (all channels
    together).  A small epsilon in the denominator keeps a constant image from
    dividing by zero.
    """
    lowest = np.min(image)
    value_span = np.max(image) - lowest
    return (image - lowest) / (value_span + 1e-7)

class SlidingWindowDataset(Dataset):
    """Dataset of fixed-size windows cut from pickled ``(image, mask)`` pairs.

    Every ``.pkl`` file in ``pickle_dir`` is unpickled to an ``(img, mask)``
    pair.  Images with at least 7 channels are extended with ``get_indices``,
    rescaled with ``normalize_image`` and cut into ``window_size`` x
    ``window_size`` windows taken every ``stride`` pixels.  A window is kept
    only when less than half of its mask pixels are class 0; windows in which
    more than 40 % of the mask pixels are class 2 are additionally augmented
    with their three 90-degree rotations.

    Args:
        pickle_dir: directory containing the ``.pkl`` files.
        window_size: side length of the square windows, in pixels.
        stride: step between consecutive windows, in pixels.
        reduce_indices: when True, ``__getitem__`` shifts every label down by
            one and turns the former label 0 into 255.
    """

    def __init__(self, pickle_dir, window_size=128, stride=64, reduce_indices=False):
        self.pickle_dir = pickle_dir
        self.window_size = window_size
        self.stride = stride
        self.reduce_indices = reduce_indices
        # All windows are extracted eagerly and held in memory.
        self.processed_images, self.processed_masks = self._process_data()

    def _process_data(self):
        """Load every pickle file and return ``(images, masks)`` lists of windows."""
        processed_images = []
        processed_masks = []

        # Sorted so that the sample order (and therefore any seeded split of the
        # dataset) does not depend on the filesystem's listing order.
        for file_name in sorted(os.listdir(self.pickle_dir)):
            if not file_name.endswith('.pkl'):
                continue

            # NOTE(security): pickle.load can execute arbitrary code; only point
            # this at trusted data directories.
            with open(os.path.join(self.pickle_dir, file_name), 'rb') as f:
                img, mask = pickle.load(f, encoding='latin1')

            if img.ndim != 3 or img.shape[2] < 7:
                print(f"Skipping image with shape {img.shape} in file {file_name}")
                continue

            img = normalize_image(get_indices(img))
            for window_img, window_mask in self._iter_windows(img, mask):
                processed_images.append(window_img)
                processed_masks.append(window_mask)

        return processed_images, processed_masks

    def _iter_windows(self, img, mask):
        """Yield the kept ``(window_img, window_mask)`` pairs of one image."""
        h, w, _ = img.shape
        size = self.window_size
        for i in range(0, h - size + 1, self.stride):
            for j in range(0, w - size + 1, self.stride):
                window_img = img[i:i + size, j:j + size]
                window_mask = mask[i:i + size, j:j + size]

                # Drop windows in which at least half of the pixels are class 0.
                class_0_ratio = np.sum(window_mask == 0) / window_mask.size
                if class_0_ratio >= 0.5:
                    continue

                # The un-rotated window is always part of the dataset.
                yield window_img, window_mask

                # Windows rich in class 2 also get their 90/180/270-degree rotations.
                class_2_ratio = np.sum(window_mask == 2) / window_mask.size
                if class_2_ratio > 0.4:
                    for _ in range(3):
                        window_img = np.rot90(window_img).copy()
                        window_mask = np.rot90(window_mask).copy()
                        yield window_img, window_mask

    def __len__(self):
        """Number of windows, augmented copies included."""
        return len(self.processed_images)

    def __getitem__(self, idx):
        """Return ``{'pixel_values': float32 CxHxW, 'labels': int64 HxW}`` for one window."""
        image = self.processed_images[idx]
        mask = self.processed_masks[idx]

        image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)  # Convert to CxHxW
        # Cast on the NumPy side so every integer mask dtype (uint16 included)
        # converts cleanly; astype() copies, so the stored mask is never aliased.
        mask = torch.from_numpy(np.asarray(mask).astype(np.int64))

        if self.reduce_indices:
            # Shift labels down by one; the former label 0 becomes 255.
            mask = mask - 1
            mask[mask == -1] = 255

        encoded_data = {
            'pixel_values': image,
            'labels': mask
        }

        return encoded_data

In [None]:
# Experiment configuration.
SEED = 123
# NOTE(review): LEARNING_RATE is not read by any visible cell; the optimizer cell
# hard-codes its own learning rate. Confirm which value is intended.
LEARNING_RATE = 1e-5
BATCH_SIZE = 16
TRAIN_DEV_TEST_SPLIT = (0.8, 0.1, 0.1)

root_dir = 'data/train/training_2015_pickled_data'

# Windows are extracted eagerly here; labels are shifted so the former class 0
# becomes the ignore value 255 (reduce_indices=True).
dataset = SlidingWindowDataset(pickle_dir=root_dir, window_size=128, reduce_indices=True)

# Seeded split so train / validation / test membership is reproducible.
generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset, test_dataset = random_split(dataset, TRAIN_DEV_TEST_SPLIT, generator)

# Seed the training shuffle as well; without an explicit generator the batch
# order depends on the global RNG state and changes from run to run.
shuffle_generator = torch.Generator().manual_seed(SEED)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, generator=shuffle_generator)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Pull a single batch from the training loader for the sanity checks below.
# Each call builds a new iterator, so with shuffling enabled the batch differs
# between executions of this cell.
batch = next(iter(train_dataloader))

In [None]:
# Sanity-check the sampled batch: list the label values that occur (to spot
# anything other than 0 and 1), the label tensor shape, and the value range of
# the input pixels.
mask = batch['labels']
image = batch['pixel_values']

print(mask.unique())
print(mask.shape)
print(image.max(), image.min())

In [None]:
# Take the label masks of the sampled batch and list their distinct values to
# check for anything other than 0 and 1. With reduce_indices=True the dataset
# also emits 255 for pixels that were originally labelled 0.
# NOTE(review): this repeats the unique-value check of the previous cell.
mask = batch['labels']

print(torch.unique(mask))

In [None]:
# Total number of windows held by the dataset (augmented copies included).
print(len(dataset))

# 2458  (presumably the count printed by an earlier run; re-check after changing the data or the windowing)

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

# Resolve the compute device (GPU when available) and place the model on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)

# Define the metric functions
def compute_accuracy_ignore_background(preds, labels, ignore_index=255):
    preds = preds.cpu().numpy().flatten()
    labels = labels.cpu().numpy().flatten()
    mask = labels != ignore_index
    preds = preds[mask]
    labels = labels[mask]
    accuracy = (preds == labels).mean()
    return accuracy

def compute_miou_ignore_background(preds, labels, num_classes, ignore_index=255):
    preds = preds.cpu().numpy().flatten()
    labels = labels.cpu().numpy().flatten()
    mask = labels != ignore_index
    preds = preds[mask]
    labels = labels[mask]
    
    ious = []
    for cls in range(num_classes):
        if cls == ignore_index:
            continue
        pred_inds = preds == cls
        label_inds = labels == cls
        intersection = np.logical_and(pred_inds, label_inds).sum()
        union = np.logical_or(pred_inds, label_inds).sum()
        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append(intersection / union)
    return np.nanmean(ious)

def compute_precision_per_class_ignore_background(preds, labels, num_classes, ignore_index=255):
    preds = preds.cpu().numpy().flatten()
    labels = labels.cpu().numpy().flatten()
    mask = labels != ignore_index
    preds = preds[mask]
    labels = labels[mask]
    
    precision_per_class = precision_score(labels, preds, average=None, labels=range(num_classes), zero_division=0)
    return {f'class_{i}': precision for i, precision in enumerate(precision_per_class) if i != ignore_index}

def compute_recall_per_class_ignore_background(preds, labels, num_classes, ignore_index=255):
    preds = preds.cpu().numpy().flatten()
    labels = labels.cpu().numpy().flatten()
    mask = labels != ignore_index
    preds = preds[mask]
    labels = labels[mask]
    
    recall_per_class = recall_score(labels, preds, average=None, labels=range(num_classes), zero_division=0)
    return {f'class_{i}': recall for i, recall in enumerate(recall_per_class) if i != ignore_index}

def compute_f1_per_class_ignore_background(preds, labels, num_classes, ignore_index=255):
    preds = preds.cpu().numpy().flatten()
    labels = labels.cpu().numpy().flatten()
    mask = labels != ignore_index
    preds = preds[mask]
    labels = labels[mask]
    
    f1_per_class = f1_score(labels, preds, average=None, labels=range(num_classes), zero_division=0)
    return {f'class_{i}': f1 for i, f1 in enumerate(f1_per_class) if i != ignore_index}

# Training-loop settings.
NUM_EPOCHS = 10
NUM_CLASSES = 2

# Exponentially weighted moving average of the batch loss (shown in the progress bar).
ewma_loss = None
alpha = 0.1  # Smoothing factor

for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    # Set training mode at the start of every epoch so the loop stays correct
    # even if model.eval() was called in between.
    model.train()

    losses = []
    accuracies = []
    mious = []
    precisions = []
    recalls = []
    f1_scores = []
    print("Epoch:", epoch + 1)
    t = tqdm(train_dataloader, total=len(train_dataloader))
    for batch in t:
        # Get the inputs
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()

        # Read the scalar loss once and reuse it below.
        batch_loss = loss.item()

        # Update the exponentially weighted average loss
        if ewma_loss is None:
            ewma_loss = batch_loss
        else:
            ewma_loss = alpha * batch_loss + (1 - alpha) * ewma_loss

        # Evaluate on the training batch
        with torch.no_grad():
            # Resize the logits to the label resolution before taking the argmax.
            upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            predicted = upsampled_logits.argmax(dim=1)

            # Calculate metrics
            accuracy = compute_accuracy_ignore_background(predicted, labels)
            miou = compute_miou_ignore_background(predicted, labels, num_classes=NUM_CLASSES)
            precision_per_class = compute_precision_per_class_ignore_background(predicted, labels, num_classes=NUM_CLASSES)
            recall_per_class = compute_recall_per_class_ignore_background(predicted, labels, num_classes=NUM_CLASSES)
            f1_per_class = compute_f1_per_class_ignore_background(predicted, labels, num_classes=NUM_CLASSES)

        losses.append(batch_loss)
        accuracies.append(accuracy)
        mious.append(miou)
        precisions.append(precision_per_class)
        recalls.append(recall_per_class)
        f1_scores.append(f1_per_class)

        # Show loss and metrics for every batch. The bar advances by itself
        # because the loop iterates over it; an extra t.update() would count
        # every batch twice.
        t.set_postfix(Loss=f"{ewma_loss:.4f}", Accuracy=f"{accuracy:.4f}", mIoU=f"{miou:.4f}")

    # Print loss and metrics every epoch. nanmean is used for every metric so a
    # batch consisting only of ignored pixels (NaN) does not poison the average.
    print(f"Loss: {np.mean(losses):.4f}, Accuracy: {np.nanmean(accuracies):.4f}, mIoU: {np.nanmean(mious):.4f}")
    print(f"Precision: {np.nanmean([list(p.values()) for p in precisions], axis=0)}")
    print(f"Recall: {np.nanmean([list(r.values()) for r in recalls], axis=0)}")
    print(f"F1 Score: {np.nanmean([list(f.values()) for f in f1_scores], axis=0)}")

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Evaluate on the validation set
model.eval()

accuracies = []
mious = []
precisions = []
recalls = []
f1_scores = []

with torch.no_grad():
    # Iterate the validation split, as the heading says; the test split is kept
    # for the final prediction / visualisation cell.
    for batch in tqdm(val_dataloader):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # Only the logits are needed here, so no labels are passed and no loss is computed.
        logits = model(pixel_values=pixel_values).logits

        # Resize the logits to the label resolution before taking the argmax.
        upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
        predicted = upsampled_logits.argmax(dim=1)

        # Calculate metrics
        accuracy = compute_accuracy_ignore_background(predicted, labels)
        miou = compute_miou_ignore_background(predicted, labels, num_classes=2)
        precision_per_class = compute_precision_per_class_ignore_background(predicted, labels, num_classes=2)
        recall_per_class = compute_recall_per_class_ignore_background(predicted, labels, num_classes=2)
        f1_per_class = compute_f1_per_class_ignore_background(predicted, labels, num_classes=2)

        accuracies.append(accuracy)
        mious.append(miou)
        precisions.append(precision_per_class)
        recalls.append(recall_per_class)
        f1_scores.append(f1_per_class)

# Report every collected metric as the mean over batches (NaN batches skipped).
print(f"Accuracy: {np.nanmean(accuracies):.4f}, mIoU: {np.nanmean(mious):.4f}")
print(f"Precision: {np.nanmean([list(p.values()) for p in precisions], axis=0)}")
print(f"Recall: {np.nanmean([list(r.values()) for r in recalls], axis=0)}")
print(f"F1 Score: {np.nanmean([list(f.values()) for f in f1_scores], axis=0)}")

In [None]:
model.eval()

# Label value that marks pixels excluded from training and evaluation.
ignore_index = 255

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_dataloader):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits
        # Resize the logits to the label resolution before taking the argmax.
        upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
        predicted = upsampled_logits.argmax(dim=1)

        # Set predictions to 255 wherever the ignore class is present in the labels
        predicted[labels == ignore_index] = ignore_index

        all_preds.append(predicted.detach().cpu())
        all_labels.append(labels.detach().cpu())

# One (N, H, W) tensor each for predictions and labels.
all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

# Flatten the tensors
all_preds_flat = all_preds.flatten()
all_labels_flat = all_labels.flatten()

# Mask to ignore the ignore_index
mask = all_labels_flat != ignore_index

# Filter out the ignore_index
filtered_preds = all_preds_flat[mask]
filtered_labels = all_labels_flat[mask]

# Define the colors for each specific value
colors = ['brown', 'green', 'black']

# Create a colormap
cmap = ListedColormap(colors)

# Define the boundaries for the values:
# 0 -> brown, 1 -> green, everything from 2 upwards -> black. The ignore value
# 255 lies above the last boundary and is clipped into the last (black) bin.
boundaries = [0, 1, 2, 3]

# Create a normalization (clip=True maps out-of-range values to the end bins)
norm = BoundaryNorm(boundaries, cmap.N, clip=True)

# Plot `rows` consecutive test masks next to the model's predictions
idx = 4
rows = 3

first_sample = idx * rows
if first_sample + rows > len(all_labels):
    raise IndexError(
        f"Samples {first_sample}..{first_sample + rows - 1} requested, "
        f"but the test set only has {len(all_labels)} samples."
    )

# squeeze=False keeps `ax` two-dimensional even when rows == 1.
fig, ax = plt.subplots(rows, 2, figsize=(10, 10), squeeze=False)

ax[0][0].set_title("Test Image Mask")
ax[0][1].set_title("Model Prediction")
for i in range(rows):
    ax[i][0].imshow(all_labels[first_sample + i].numpy(), cmap=cmap, norm=norm)
    ax[i][1].imshow(all_preds[first_sample + i].numpy(), cmap=cmap, norm=norm)

plt.show()

In [None]:
# Display the raw prediction tensor at position `idx` (class ids, with 255
# where the label was the ignore value).
all_preds[idx]

In [None]:
def _checkpoint_number(name):
    """Return the trailing number of a ``model_<n>.pth`` entry, or None if it has none."""
    stem = name.split('_')[-1].split('.')[0]
    return int(stem) if stem.isdigit() else None


# Make sure the target directory exists before listing it.
os.makedirs('Model', exist_ok=True)

# Keep only numbered checkpoints (stray entries such as hidden files are
# skipped) and order them by their number.
model_list = sorted(
    (name for name in os.listdir('Model') if _checkpoint_number(name) is not None),
    key=_checkpoint_number,
)

# Save the model with the next number (start at 1 when nothing has been saved yet)
next_model_number = _checkpoint_number(model_list[-1]) + 1 if model_list else 1

MODEL_SAVE_PATH = f"Model/model_{next_model_number}.pth"

# save_pretrained writes a directory (config + weights) at this path; despite
# the ".pth" suffix it is the layout from_pretrained expects when loading.
model.save_pretrained(MODEL_SAVE_PATH)

In [None]:
# Load a previously saved model from the "Model" directory.

# Number of the checkpoint to load (the <n> in Model/model_<n>.pth).
model_num = 1

# Replaces the in-memory `model` with the one stored at that path.
# NOTE(review): the loaded model is not moved to `device` here — confirm
# whether later cells expect it on the GPU.
model = SegformerForSemanticSegmentation.from_pretrained(f"Model/model_{model_num}.pth")

In [None]:
from transformers import SegformerImageProcessor