In [2]:
pip install torch

In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import rasterio

In [None]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pretrained DeepLabV3 with a ResNet-101 backbone, moved to `device` and
# switched to eval mode for inference.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
# favour of `weights=...`; kept as-is for compatibility with older releases.
model = models.segmentation.deeplabv3_resnet101(pretrained=True)
model = model.to(device)
model.eval()

# Preprocess the input image


def preprocess_image(image_path, bands=None, scale=None):
    """Read a raster with rasterio and turn it into a normalized model input.

    Parameters
    ----------
    image_path : str or path-like
        Path to a raster readable by ``rasterio.open`` (e.g. a GeoTIFF).
    bands : sequence of int, optional
        1-based band indices to read, in the order they should be fed to the
        model as its three input channels (e.g. ``(4, 3, 2)`` for a stack
        stored as Sentinel-2 B1..B12 — TODO confirm your file's band order).
        When ``None`` every band is read, and the raster must then have
        exactly three bands.
    scale : float, optional
        Raw value that should map to 1.0. When ``None`` it is inferred from
        the dtype: integer rasters are divided by the dtype's maximum (255 for
        uint8, matching ``transforms.ToTensor``); float rasters are left
        unscaled. NOTE(review): for uint16 Sentinel-2 reflectance a value such
        as ``10000`` is common — confirm for your product.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (1, 3, H, W) on the global ``device``,
        normalized with the ImageNet mean/std.

    Raises
    ------
    ValueError
        If the selected raster bands do not number exactly three.
    """
    # rasterio's read() returns a (bands, rows, cols) array, which is already
    # the channel-first layout torch expects, so no transpose is needed.
    with rasterio.open(image_path) as dataset:
        if bands is None:
            image_bands = dataset.read()
        else:
            image_bands = dataset.read(list(bands))

    # The 3-value mean/std below (and the RGB-pretrained model) only work with
    # three channels; fail clearly instead of with a broadcasting error.
    if image_bands.shape[0] != 3:
        raise ValueError(
            f"expected 3 bands for the RGB segmentation model, got "
            f"{image_bands.shape[0]} from {image_path!r}; pass bands=(r, g, b)"
        )

    if scale is None:
        if np.issubdtype(image_bands.dtype, np.integer):
            scale = float(np.iinfo(image_bands.dtype).max)
        else:
            scale = 1.0

    # Always hand torch float32: ToTensor only rescaled uint8, left other
    # integer dtypes as integer tensors (which Normalize rejects) and kept
    # float64 (which the float32 model weights reject).
    chw = torch.from_numpy(image_bands.astype(np.float32) / scale)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_tensor = normalize(chw).unsqueeze(0).to(device)
    return input_tensor


# Load and preprocess an example TIF image.
# Raw string: in a plain literal the backslashes of a Windows path are parsed
# as escape sequences (`\C`, `\T`, `\D`, `\S` are invalid escapes that emit a
# SyntaxWarning on Python 3.12+; a folder starting with `t` or `n` would
# silently corrupt the path).
# NOTE(review): absolute local path — consider moving it to a config constant.
tif_image_path = r"D:\College files\TISS\Data\Sentinal.tif"
input_image = preprocess_image(tif_image_path)

# Perform inference without building an autograd graph.
with torch.no_grad():
    # The model returns a dict; 'out' holds per-class scores — presumably
    # (batch, num_classes, H, W) per torchvision's segmentation API. Take the
    # single image in the batch.
    output = model(input_image)['out'][0]
# Per-pixel class index = argmax over the class dimension.
output_predictions = output.argmax(0).cpu().numpy()

# Display the input image and the segmentation mask.
# The input panel is rebuilt from `input_image` (the exact tensor the model
# saw) instead of re-opening the file with PIL: PIL does not reliably decode
# multi-band / 16-bit GeoTIFFs, and this guarantees the mask and the image
# share the same pixel grid.
display_rgb = input_image[0].cpu().numpy().transpose(1, 2, 0)
# Per-channel 2-98 percentile contrast stretch. Preprocessing applied a
# per-channel positive affine normalization, which a percentile stretch is
# invariant to, so there is no need to undo it first.
low, high = np.percentile(display_rgb, (2, 98), axis=(0, 1))
display_rgb = np.clip((display_rgb - low) / np.maximum(high - low, 1e-6), 0, 1)

num_classes = output.shape[0]
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))

ax_img.imshow(display_rgb)
ax_img.set_title("Input Image")
ax_img.axis("off")

# Overlay the mask on the image so alpha=0.6 blends with something visible;
# fixed vmin/vmax keep each class's colour stable across images.
ax_mask.imshow(display_rgb)
ax_mask.imshow(output_predictions, cmap='jet', alpha=0.6,
               vmin=0, vmax=num_classes - 1)
ax_mask.set_title("Segmentation Mask")
ax_mask.axis("off")

plt.show()