In [1]:
import os
from PIL import Image
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import rasterio
from torchvision.io import read_image
from torchvision.transforms import ToTensor, Normalize
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data.dataset import Subset
from sklearn.metrics import classification_report
import os
import matplotlib.pyplot as plt
import random
import pdb
# NOTE(review): pandas is not imported in the import cell above and no pandas
# calls appear in the cells visible here -- confirm this setup call is still needed.
tqdm.pandas()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the processed satellite tiles.
# NOTE(review): this path has no leading "/" (so it resolves relative to the
# current working directory), whereas `root_dir` in the chipping cell points at
# "/gypsum/..." -- confirm which is intended. `path_to_tiles` is not referenced
# by any of the cells visible here.
path_to_tiles = "gypsum/eguide/projects/scorreacardo/urbano/data/worldpop_centered/processed/"


In [None]:

# Cut every PNG tile found under `root_dir` into non-overlapping
# chip_size x chip_size chips and write each chip to `output_dir` as a TIFF.

# Define the root directory containing the satellite image tiles
root_dir = "/gypsum/eguide/projects/scorreacardo/urbano/data/worldpop_centered/processed/"

# Define the output directory for the chips
output_dir = "/work/scorreacardo_umass_edu/DeepSatGSD/data/processed/inference"

# Define the desired chip size (chips are chip_size x chip_size pixels)
chip_size = 256

# Create the output directory if it doesn't exist (no exists/mkdir race)
os.makedirs(output_dir, exist_ok=True)

# Initialize a list to store the chip filenames and dates
chip_filenames = []
chip_dates = []

# Walk through the root directory and its subdirectories
for root, dirs, files in os.walk(root_dir):
    # Iterate over the files in the current directory
    for file in tqdm(files):
        # Only PNG tiles are chipped
        if not file.lower().endswith(".png"):
            continue

        # The date and tile name are read from the "_"-separated filename
        # (parts 0, 1, 2 and 4). Parse once per tile instead of once per chip.
        name_parts = file.split("_")
        if len(name_parts) < 5:
            print(f"Skipping {file}: filename does not have the expected '_'-separated parts")
            continue
        year, month, day = name_parts[0], name_parts[1], name_parts[2]
        tilename = name_parts[4]

        image_path = os.path.join(root, file)

        # The context manager closes the file handle once the tile is chipped;
        # without it every visited tile stays open for the whole walk.
        with Image.open(image_path) as image:
            image_width, image_height = image.size

            # Step by chip_size so chips do not overlap; partial chips at the
            # right/bottom edges are dropped.
            for y in range(0, image_height - chip_size + 1, chip_size):
                for x in range(0, image_width - chip_size + 1, chip_size):
                    chip = image.crop((x, y, x + chip_size, y + chip_size))

                    # The third slot is the day. It used to repeat the year, so
                    # chips from different days of the same month/tile/offset
                    # overwrote each other. Chips written with the old naming
                    # are not replaced by this run -- clear `output_dir` first
                    # if they should not be mixed in.
                    chip_filename = f"{year}_{month}_{day}_{tilename}_{x}_{y}_.tif"

                    # Save the chip as a TIFF image
                    chip_path = os.path.join(output_dir, chip_filename)
                    chip.save(chip_path, format="TIFF")

                    # Record the chip filename and its date
                    chip_filenames.append(chip_filename)
                    chip_dates.append(f"{year}/{month}/{day}")

# Print every 500th chip as a spot check instead of dumping the whole list
count = 0
for count, (filename, date) in enumerate(zip(chip_filenames, chip_dates), start=1):
    if count % 500 == 0:
        print(f"Chip Filename: {filename}, Date: {date}")

In [22]:
# Dataset used to run inference over the chipped tiles.
class InferenceDataset(Dataset):
    """Random subset of the chip files in a directory, paired with their dates.

    Each item is ``(image, date)``: ``image`` holds the first three bands of
    the chip read with rasterio, rearranged to (rows, cols, bands) and passed
    through ``transform`` when one is given; ``date`` is the string
    ``"<year>/<month>"`` taken from the first two "_"-separated parts of the
    chip filename.

    Args:
        root_dir: Directory whose entries are the chips to sample from.
        transform: Optional callable applied to the (rows, cols, bands) array;
            its result is converted with ``.float()``.
        class_dir: Directory whose entries starting with "GSD" define
            ``self.classes``. Defaults to ``DEFAULT_CLASS_DIR``.
        sample_fraction: Fraction of the entries in ``root_dir`` to keep,
            between 0 and 1. Defaults to 0.20.
        seed: When given, the sample is drawn from a private
            ``random.Random(seed)`` so it is reproducible; when ``None`` the
            global ``random`` module state is used.

    Raises:
        ValueError: If ``sample_fraction`` is outside ``[0, 1]``.
    """

    DEFAULT_CLASS_DIR = "/work/scorreacardo_umass_edu/DeepSatGSD/data/processed/random_synthetic"

    def __init__(self, root_dir, transform=None, class_dir=None,
                 sample_fraction=0.20, seed=None):
        if not 0.0 <= sample_fraction <= 1.0:
            raise ValueError(
                f"sample_fraction must be between 0 and 1, got {sample_fraction}"
            )

        self.root_dir = root_dir
        # Store the directory actually used (this attribute used to stay None).
        self.class_dir = self.DEFAULT_CLASS_DIR if class_dir is None else class_dir
        self.transform = transform
        self.filepaths = []
        self.dates = []

        # Class names look like "GSD_<number><two-char suffix>"; order them by
        # the number rather than alphabetically.
        self.classes = sorted(
            [name for name in os.listdir(self.class_dir) if name.startswith("GSD")],
            key=lambda x: int(x.split('_')[1][:-2]),
        )

        # os.listdir order is arbitrary; sort so a fixed seed always selects
        # the same files.
        file_list = sorted(os.listdir(root_dir))
        rng = random if seed is None else random.Random(seed)
        sampled_files = rng.sample(file_list, k=int(len(file_list) * sample_fraction))

        for filename in sampled_files:
            self.filepaths.append(os.path.join(root_dir, filename))
            name_parts = filename.split("_")
            year = name_parts[0]
            month = name_parts[1]
            self.dates.append(f"{year}/{month}")

    def __len__(self):
        """Return the number of sampled chips."""
        return len(self.filepaths)

    def __getitem__(self, index):
        """Return ``(image, date)`` for the sampled chip at ``index``."""
        filepath = self.filepaths[index]
        date = self.dates[index]

        with rasterio.open(filepath, 'r') as img:
            image = img.read()
            # Keep the first three bands only, then move the band axis last.
            image = image[:3, :, :]
            image = image.transpose(1, 2, 0)
        if self.transform:
            image = self.transform(image).float()

        return image, date

In [23]:
# Preprocessing for every chip: convert to a tensor, then normalize each of
# the three channels with the statistics below.
# NOTE(review): presumably these are the training-set mean/std -- confirm.
channel_mean = [0.46619912981987, 0.4138015806674957, 0.2945951819419861]
channel_std = [0.19115719199180603, 0.1479424238204956, 0.13974712789058685]
transform = torchvision.transforms.Compose(
    [ToTensor(), Normalize(channel_mean, channel_std)]
)

# `root_dir` is rebound here from the raw-tile directory used by the chipping
# cell to the directory that holds the chips.
root_dir = "/work/scorreacardo_umass_edu/DeepSatGSD/data/processed/inference"
batch_size = 16

# Batches are served in dataset order (no shuffling).
inference_dataset = InferenceDataset(root_dir, transform=transform)
inference_loader = DataLoader(
    inference_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [10]:
# Map each entry of `dataloader_classes` to the entry of `dataset.classes` at
# the same position.
# NOTE(review): neither `dataloader_classes` nor `dataset` is defined in the
# cells visible here -- confirm both exist before this cell on a fresh kernel.
class_mapping = dict(zip(dataloader_classes, dataset.classes))

In [None]:
# Rebuild a ResNet-18 whose final layer has one output per class, then load
# the saved weights into it.
import torchvision.models as models

model = models.resnet18(pretrained=False)
num_classes = len(inference_dataset.classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Load the saved model state dictionary.
# map_location="cpu" lets the checkpoint load on a machine without CUDA; the
# inference cell moves the model to the selected device afterwards.
model.load_state_dict(
    torch.load(
        '/work/scorreacardo_umass_edu/DeepSatGSD/models/trained_model_gsd_resnet18_epochs_3_training_40_perc.pt',
        map_location="cpu",
    )
)
model.eval()  # Set the model to evaluation mode

# Now you can use the model for inference

In [12]:
# Run the model over every batch from `inference_loader`, collecting for each
# chip its date string and the index of the highest-scoring class.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device).eval()  # on the chosen device, in evaluation mode

dates_results = []
predicted_results = []

# Gradients are not needed for inference
with torch.no_grad():
    for inputs, dates in tqdm(inference_loader):
        inputs = inputs.to(device)
        outputs = model(inputs)
        # Index of the maximum along the class dimension
        predicted = torch.max(outputs, dim=1).indices
        dates_results.extend(dates)
        predicted_results.extend(predicted.tolist())

100%|██████████| 1505/1505 [15:00<00:00,  1.67it/s]
