In [None]:
import os
import torch
import numpy as np
import rasterio
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CustomTiffDataset(Dataset):
    """Dataset of multi-band GeoTIFF image patches paired with single-band label patches.

    Image files are every ``*.tif`` entry of ``image_dir``. The label for an
    image is looked up in ``label_dir`` under the same file name with the
    substring ``'image_patch_'`` replaced by ``'label_patch_'``.

    Args:
        image_dir: Directory containing the image patches.
        label_dir: Directory containing the label patches.
        transform: Optional callable applied to the image tensor only
            (the label tensor is returned untouched).
    """

    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary, platform-dependent order;
        # sorting makes the index -> file mapping reproducible across runs.
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith('.tif'))

    def __len__(self):
        """Return the number of image patches found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the ``idx``-th image/label pair.

        Returns:
            Tuple ``(image, label)`` where ``image`` is a float32 tensor with
            all bands of the image file (``[C, H, W]``) and ``label`` is an
            int64 tensor holding band 1 of the label file (``[H, W]``).
        """
        img_name = self.images[idx]

        # Read every band of the image patch; the band count is whatever the file holds.
        with rasterio.open(os.path.join(self.image_dir, img_name)) as src_image:
            image = src_image.read().astype(np.float32)

        # Construct the corresponding label filename
        label_name = img_name.replace('image_patch_', 'label_patch_')

        # Only the first band of the label raster is used; int64 is the dtype
        # expected for class-index targets by the cross-entropy loss.
        with rasterio.open(os.path.join(self.label_dir, label_name)) as src_label:
            label = src_label.read(1).astype(np.int64)

        # astype() already produced fresh arrays, so wrap them without another copy.
        image = torch.from_numpy(image)  # Shape: [C, H, W]
        # NOTE: label values are passed through as-is; no range validation is done here.
        label = torch.from_numpy(label)  # Shape: [H, W]

        if self.transform:
            image = self.transform(image)

        return image, label

# Build the training dataset and dataloader from the patch directories.
# The names defined here (transform, dataset, dataloader, ...) are module-level
# notebook state that later cells rely on.
image_dir = '/home/yshao/unet/lc/newtrain/images'
label_dir = '/home/yshao/unet/lc/newtrain/labels'

# Per-band normalization statistics: one entry per input band.
# Both lists hold 5 values, matching the 5-channel model created later (in_channels=5).
# NOTE(review): the [0, 1] scaling step below is disabled, so these statistics are
# applied to the raw raster values; confirm they match the actual band value ranges.
mean = [0.485, 0.456, 0.406, 0.5,0.5]  # 5 values, one per band
std = [0.229, 0.224, 0.225, 0.25,0.25]  # 5 values, one per band

transform = transforms.Compose([
    # Scale the input values to [0, 1] if needed; otherwise, comment out the line below
    # transforms.Lambda(lambda x: x / 255.0),  
    transforms.Normalize(mean=mean, std=std)  # Normalize the tensors
])

dataset = CustomTiffDataset(image_dir, label_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Sanity check: pull a single batch and print its shapes, then stop.
for images, labels in dataloader:
    print("Images batch shape:", images.shape)
    print("Labels batch shape:", labels.shape)
    break


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """U-Net style encoder/decoder for per-pixel classification.

    Four encoder stages (64, 128, 256, 512 channels) separated by 2x2 max
    pooling, a 1024-channel bottleneck, and four decoder stages that upsample
    bilinearly and concatenate the matching encoder output before convolving.
    A final 1x1 convolution maps the 64 decoder channels to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of output channels (one logit map per class).

    The input is pooled four times, so each spatial dimension must be at
    least 16. Sizes that are not multiples of 16 are supported: every decoder
    stage upsamples to the exact size of its skip connection.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.encoder1 = self.conv_block(in_channels, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)

        self.bottleneck = self.conv_block(512, 1024)

        # Decoder inputs are the upsampled deeper features concatenated with
        # the skip connection, hence the summed channel counts.
        self.decoder4 = self.conv_block(1024 + 512, 512)
        self.decoder3 = self.conv_block(512 + 256, 256)
        self.decoder2 = self.conv_block(256 + 128, 128)
        self.decoder1 = self.conv_block(128 + 64, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> batch norm -> ReLU) layers; padding=1 keeps H and W."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        return block

    @staticmethod
    def _upsample_and_concat(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s spatial size and concatenate on channels.

        Resizing to the skip tensor's size (instead of a fixed factor of 2)
        keeps the concatenation valid when max pooling floored an odd
        dimension. For even dimensions the result is the same as doubling.
        """
        x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Compute per-pixel class logits.

        Args:
            x: Tensor of shape ``[N, in_channels, H, W]``.

        Returns:
            Tensor of shape ``[N, out_channels, H, W]`` (raw logits, no softmax).
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, 2))
        enc3 = self.encoder3(F.max_pool2d(enc2, 2))
        enc4 = self.encoder4(F.max_pool2d(enc3, 2))

        bottleneck = self.bottleneck(F.max_pool2d(enc4, 2))

        dec4 = self.decoder4(self._upsample_and_concat(bottleneck, enc4))
        dec3 = self.decoder3(self._upsample_and_concat(dec4, enc3))
        dec2 = self.decoder2(self._upsample_and_concat(dec3, enc2))
        dec1 = self.decoder1(self._upsample_and_concat(dec2, enc1))

        return self.final_conv(dec1)

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

# Define the device to use (GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model and move it to the appropriate device.
# 5 input channels (one per image band), 3 output classes.
model = UNet(in_channels=5, out_channels=3).to(device)

# Define the class weights (example weights, you need to calculate based on your dataset).
# These are read as a global by masked_weighted_cross_entropy_loss below. That loss drops
# every pixel labelled 0 before weighting, so the first weight (0.1) has no effect there.
class_weights = torch.tensor([0.1, 1.0, 1.0], device=device)
# NOTE: `criterion` is not used by the training loop in this notebook, which calls
# masked_weighted_cross_entropy_loss instead.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Define the optimizer with a lower learning rate
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Number of full passes over the dataloader.
num_epochs = 50

# Custom loss function to ignore 0s in the labels and apply class weights
def masked_weighted_cross_entropy_loss(outputs, labels, weight=None, ignore_index=0):
    """Class-weighted cross entropy computed only over pixels that are not ignored.

    Args:
        outputs: Logits of shape ``[N, C, H, W]``.
        labels: Integer class indices of shape ``[N, H, W]``.
        weight: Optional 1-D tensor of ``C`` per-class weights. When omitted,
            the module-level ``class_weights`` tensor is used (the original
            behaviour of this function).
        ignore_index: Label value excluded from the loss (default ``0``).

    Returns:
        Scalar loss tensor. If every pixel in the batch is ignored, a zero
        that is still attached to the graph is returned.
    """
    if weight is None:
        weight = class_weights

    # Move the class dimension last and flatten to [N*H*W, C] / [N*H*W].
    # reshape() also copes with non-contiguous inputs, unlike view().
    num_classes = outputs.size(1)
    logits = outputs.permute(0, 2, 3, 1).reshape(-1, num_classes)
    targets = labels.reshape(-1)

    # Keep only the pixels whose label is not the ignored value.
    valid = targets != ignore_index

    # An empty selection would make cross_entropy average over zero elements
    # and return NaN, which then corrupts the weights on backward(). Returning
    # a graph-connected zero lets the training step run as a no-op instead.
    if not valid.any():
        return logits.sum() * 0.0

    return F.cross_entropy(logits[valid], targets[valid], weight=weight)

# Training loop: one optimisation step per batch, mean batch loss printed per epoch.
# `epoch` stays defined after the loop and is read by the checkpoint cell below.
for epoch in range(num_epochs):
    model.train()  # enable training behaviour (e.g. batch-norm statistics updates)
    running_loss = 0.0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients accumulated by the previous step
        optimizer.zero_grad()
        
        outputs = model(images)
        # Loss over pixels whose label is not 0, weighted by class_weights
        loss = masked_weighted_cross_entropy_loss(outputs, labels)
        
        # Backpropagation
        loss.backward()
        
        # Gradient clipping: rescale gradients so their total norm is at most 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        running_loss += loss.item()
    
    # Average of the per-batch losses for this epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}')


In [None]:
# Save the trained weights (state_dict only) once training has finished.
# `epoch` is the loop variable left over from the training cell, so the file
# name reflects the last epoch that was started.
checkpoint_dir = '/home/yshao/unet/lc'
# Create the target directory if needed so torch.save cannot fail on a missing path.
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pth')
torch.save(model.state_dict(), checkpoint_path)
print(f'Model saved to {checkpoint_path}')

In [None]:
import os
import numpy as np
import torch
from torchvision import transforms
import rasterio

def load_image(image_path, transform=None):
    """Read every band of a raster file into a float32 tensor.

    Args:
        image_path: Path of the raster file to open with rasterio.
        transform: Optional callable applied to the tensor before returning.

    Returns:
        The (optionally transformed) float32 tensor of all bands.
    """
    with rasterio.open(image_path) as dataset:
        bands = dataset.read().astype(np.float32)

    tensor = torch.tensor(bands, dtype=torch.float32)  # numpy array -> tensor
    return transform(tensor) if transform else tensor

def predict_and_save(image_path, model, device, transform=None, output_dir=None):
    """Predict a class mask for one raster and optionally write it as a GeoTIFF.

    Args:
        image_path: Path of the input raster.
        model: Network returning logits of shape ``[1, C, H, W]``; it is
            switched to eval mode by this function.
        device: Torch device the input tensor is moved to.
        transform: Optional callable applied to the image tensor before
            inference (passed to ``load_image``).
        output_dir: Directory to write the mask to. It is created if missing.
            When falsy, nothing is written.

    The output file name is the input base name with ``'image_'`` replaced by
    ``'predicted_'``. The mask is a single uint8 band that reuses the input's
    georeferencing (affine transform and CRS).
    """
    model.eval()
    with torch.no_grad():
        image = load_image(image_path, transform).unsqueeze(0).to(device)

        # Make prediction
        output = model(image)

        # Softmax is monotonic, so the arg-max of the raw logits is already
        # the most probable class; no need to compute probabilities first.
        pred_mask = output.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

    # Load the original image metadata for saving. The affine transform gets
    # its own name so it does not shadow the `transform` (preprocessing) parameter.
    with rasterio.open(image_path) as src:
        meta = src.meta.copy()
        affine = src.transform
        crs = src.crs

    # Update metadata to describe a single-band uint8 raster
    meta.update({
        'count': 1,
        'dtype': 'uint8',
        'height': pred_mask.shape[0],
        'width': pred_mask.shape[1],
        'transform': affine,
        'crs': crs
    })

    # Drop the nodata value unless it is representable in uint8. The source
    # may have no nodata at all (None), and a NaN nodata fails the range test
    # too, so both are removed rather than compared or written.
    nodata = meta.get('nodata')
    if nodata is None or not (0 <= nodata <= 255):
        meta.pop('nodata', None)

    # Save the predicted mask as a TIFF file
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, os.path.basename(image_path).replace('image_', 'predicted_'))
        with rasterio.open(output_path, 'w', **meta) as dst:
            dst.write(pred_mask, 1)
        print(f'Saved predicted mask to {output_path}')

# Function to process all TIFF files in a directory
def process_folder(input_dir, model, device, transform=None, output_dir=None):
    """Run ``predict_and_save`` on every ``.tif`` entry of ``input_dir``.

    Args:
        input_dir: Directory whose ``.tif`` files are processed.
        model, device, transform, output_dir: Forwarded unchanged to
            ``predict_and_save`` for each file.
    """
    tif_names = (entry for entry in os.listdir(input_dir) if entry.endswith('.tif'))
    for tif_name in tif_names:
        predict_and_save(os.path.join(input_dir, tif_name), model, device, transform, output_dir)

# Example usage
input_dir = '/home/yshao/unet/lc/newtrain/images'
output_dir = '/home/yshao/unet/lc/predictions'  # Directory to save the predicted masks

# Normalization must match the one used for training: one value per band (5 bands).
mean = [0.485, 0.456, 0.406, 0.5, 0.5]
std = [0.229, 0.224, 0.225, 0.25, 0.25]
transform = transforms.Compose([
    transforms.Normalize(mean=mean, std=std)  # Normalize the tensors
])

# Build a fresh model and load the trained weights from the checkpoint file.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=5, out_channels=3).to(device)  # 5 input bands, 3 classes
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads when this cell runs on a CPU-only machine.
model.load_state_dict(torch.load('/home/yshao/unet/lc/model_epoch_50.pth', map_location=device))

# Process all TIFF files in the input directory
process_folder(input_dir, model, device, transform=transform, output_dir=output_dir)


In [None]:
import os
import rasterio
from rasterio.merge import merge
from rasterio.plot import show
import glob

def mosaic_tifs(input_dir, output_file):
    """Merge every ``*.tif`` raster in ``input_dir`` into a single GeoTIFF.

    Args:
        input_dir: Directory searched (non-recursively) for ``*.tif`` files.
        output_file: Path of the mosaic to write. If it lies inside
            ``input_dir`` it is excluded from the inputs, so re-running does
            not merge a previous mosaic into the new one.

    Raises:
        FileNotFoundError: If ``input_dir`` contains no input ``.tif`` files.
    """
    # Collect the inputs in a deterministic order, skipping the output itself.
    pattern = os.path.join(input_dir, "*.tif")
    output_abs = os.path.abspath(output_file)
    tif_files = sorted(
        fp for fp in glob.glob(pattern) if os.path.abspath(fp) != output_abs
    )
    if not tif_files:
        raise FileNotFoundError(f"No .tif files found in {input_dir}")

    # List to store opened datasets; they are closed in `finally` even if
    # opening, merging or writing fails part-way through.
    src_files_to_mosaic = []
    try:
        for fp in tif_files:
            src_files_to_mosaic.append(rasterio.open(fp))

        # Merge function returns a single array and the transformation info
        mosaic, out_trans = merge(src_files_to_mosaic)

        # Copy the metadata from the first dataset (merge derives the output
        # dtype and band count from the first dataset as well).
        reference = src_files_to_mosaic[0]
        out_meta = reference.meta.copy()

        # Update the metadata with the new dimensions, transform (affine) and CRS.
        # mosaic has shape (bands, height, width).
        out_meta.update({
            "driver": "GTiff",
            "height": mosaic.shape[1],
            "width": mosaic.shape[2],
            "transform": out_trans,
            "crs": reference.crs
        })

        # Write the mosaic raster to disk
        with rasterio.open(output_file, "w", **out_meta) as dest:
            dest.write(mosaic)
    finally:
        # Close all the opened files
        for opened in src_files_to_mosaic:
            opened.close()

    print(f"Mosaic saved to {output_file}")

# Example usage: mosaic the predicted masks written by the inference cell.
# NOTE: this rebinds the notebook-level name `input_dir`, which earlier cells
# used for the image directory; re-running those cells afterwards needs their
# own assignment to run first.
input_dir = '/home/yshao/unet/lc/predictions'
output_file = '/home/yshao/unet/lc/mosaic.tif'

mosaic_tifs(input_dir, output_file)
