In [2]:
import torch
import rasterio
import numpy as np
import os
from Model import mlp

In [3]:
# Rebuild the MLP and load its trained weights.
# input_size=8 must match the feature width fed to the model (the 8 raster bands per pixel);
# num_class=5 is presumably the number of output classes (the axis argmax'ed during inference).
model_name = 'MLP'
Model_MLP = mlp.MLP(input_size=8, num_class=5)
# map_location='cpu': inference in this notebook runs on CPU, and without it a checkpoint
# saved from a GPU session fails to load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
Model_MLP.load_state_dict(torch.load('MLP_best.pth', map_location='cpu'))


<All keys matched successfully>

In [4]:
# Load every band of the input scene into memory, along with its metadata profile.
raster_path = "./Data/crop_planet.tif"

with rasterio.open(raster_path) as src:
    raster_profile = src.profile
    # read() with no band indexes returns all bands as a (count, height, width) array.
    raster_data = src.read()
    num_bands, height, width = src.count, src.height, src.width

In [6]:
# One row per pixel, one column per band: (bands, rows, cols) -> (rows*cols, bands).
flat_data = raster_data.reshape(num_bands, height * width).transpose()
print(flat_data.shape)
# float32 copy for the model, with a singleton axis inserted -> (pixels, 1, bands).
flat_data_tensor = torch.tensor(flat_data, dtype=torch.float32)[:, None, :]

(1469904, 8)


In [7]:
Model_MLP.eval()

# Classify every pixel in fixed-size batches so only one batch of activations exists at a time.
with torch.no_grad():
    batch_size = 1024
    num_samples = flat_data_tensor.shape[0]

    predictions = []
    for start_index in range(0, num_samples, batch_size):
        # Slicing clamps at the end, so the final (shorter) batch needs no special case.
        batch_data = flat_data_tensor[start_index:start_index + batch_size]

        # Collapse everything after the batch axis into one feature axis:
        # (batch, 1, bands) -> (batch, bands). Also accepts an already-2D (batch, bands) tensor.
        batch_data = batch_data.flatten(start_dim=1)

        outputs = Model_MLP(batch_data)
        # Assumes the model returns (batch, num_classes) scores -- the argmax is the predicted class.
        predicted_batch = torch.argmax(outputs, dim=1)
        predictions.append(predicted_batch.numpy())

# Concatenate once after the loop; doing it inside the loop re-copied all previous
# batches on every iteration (quadratic) and left the name unbound for empty input.
predicted_classes = (
    np.concatenate(predictions) if predictions else np.empty(0, dtype=np.int64)
)

print(f"1D Model: {model_name}, Predictions shape: {predicted_classes.shape}")

1D Model: MLP, Predictions shape: (1469904,)


In [8]:
# Put the per-pixel predictions back on the raster's (rows, cols) grid. The output file is
# uint8, so check the class indices fit before casting instead of relying on silent conversion.
if predicted_classes.size and predicted_classes.max() > np.iinfo(np.uint8).max:
    raise ValueError("Class indices exceed the uint8 range of the output raster")
class_map = predicted_classes.reshape(height, width).astype(np.uint8)

output_file = f"./classified_{model_name}.tif"
# Georeferencing comes from the profile captured while the source was open, not from `src`,
# which was already closed when its `with` block ended.
with rasterio.open(
    output_file,
    'w',
    driver='GTiff',
    height=height,
    width=width,
    count=1,
    dtype=rasterio.uint8,
    crs=raster_profile['crs'],
    transform=raster_profile['transform'],
) as dst:
    dst.write(class_map, 1)

print(f"1D Predictions saved to {output_file}")


1D Predictions saved to ./classified_MLP.tif
