In [1]:
!pip install rasterio
!pip install torch
!pip install torchvision
!pip install torchsummary

Collecting rasterio
  Obtaining dependency information for rasterio from https://files.pythonhosted.org/packages/09/b9/169a76e257e527d352da021da6602480a829eac03b0ab3045639c3f80fb6/rasterio-1.4.2-cp311-cp311-macosx_14_0_arm64.whl.metadata
  Downloading rasterio-1.4.2-cp311-cp311-macosx_14_0_arm64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Obtaining dependency information for affine from https://files.pythonhosted.org/packages/0b/f7/85273299ab57117850cc0a936c64151171fac4da49bc6fba0dad984a7c5f/affine-2.4.0-py3-none-any.whl.metadata
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Obtaining dependency information for cligj>=0.5 from https://files.pythonhosted.org/packages/73/86/43fa9f15c5b9fb6e82620428827cd3c284aa933431405d1bcf5231ae3d3e/cligj-0.7.2-py3-none-any.whl.metadata
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Obtaining dependency information for click-plu

In [2]:
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.autograd import Variable
import torch.nn.functional as F
from PIL import Image
from glob import glob
import rasterio


In [10]:


class SEN12MSDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Path to the dataset folder containing image files.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [
            f for f in os.listdir(root_dir) if f.endswith('.tif') or f.endswith('.TIF')
        ]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        with rasterio.open(img_path) as src:
            data = src.read()  # Load all bands
            rgb = np.stack([data[3], data[2], data[1]], axis=-1)  # Red, Green, Blue (1-based indexing)
            nir = data[7]  # Near-Infrared band

        # Normalize bands to [0, 1] range
        rgb = rgb / 10000.0
        nir = nir / 10000.0

        # Apply separate transformations for RGB and NIR
        if self.transform:
            rgb = self.transform(rgb)
            nir = transforms.ToPILImage()(nir)
            nir = transforms.Resize((256, 256))(nir)
            nir = transforms.ToTensor()(nir)
              # Add channel dimension

        return rgb, nir  # Convert NIR to tensor directly




# Example usage
root_dir = "/Users/pgt/Downloads/Sentinel train data"
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = SEN12MSDataset(root_dir=root_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)



In [11]:
#execute this cell when you have local dataset

import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split

def split_dataset(source_dir, train_dir, test_dir, test_size=0.2, random_state=42):
    # Create train and test directories if they don't exist
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Get all image files
    image_files = [f for f in os.listdir(source_dir) if f.endswith(('.tif', '.TIF'))]

    # Split files
    train_files, test_files = train_test_split(
        image_files,
        test_size=test_size,
        random_state=random_state
    )

    # Copy train files
    for file in train_files:
        src_path = os.path.join(source_dir, file)
        dst_path = os.path.join(train_dir, file)
        shutil.copy2(src_path, dst_path)

    # Copy test files
    for file in test_files:
        src_path = os.path.join(source_dir, file)
        dst_path = os.path.join(test_dir, file)
        shutil.copy2(src_path, dst_path)

    print(f"Total images: {len(image_files)}")
    print(f"Train images: {len(train_files)}")
    print(f"Test images: {len(test_files)}")

# Example usage
source_dir = "/Users/pgt/Downloads/Sentinel train data"
train_dir = "/Users/pgt/Downloads/Sentinel train data/train"
test_dir = "/Users/pgt/Downloads/Sentinel train data/test"

split_dataset(source_dir, train_dir, test_dir)

train_dataset = SEN12MSDataset(root_dir=train_dir, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

Total images: 17
Train images: 13
Test images: 4


In [12]:
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNetGenerator, self).__init__()

        def down_block(in_channels, out_channels, normalize=True):
            layers = [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        def up_block(in_channels, out_channels, dropout=False):
            layers = [
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ]
            if dropout:
                layers.append(nn.Dropout(0.5))
            return nn.Sequential(*layers)

        self.encoder = nn.ModuleList([
            down_block(3, 64, normalize=False),
            down_block(64, 128),
            down_block(128, 256),
            down_block(256, 512),
            down_block(512, 512),
            down_block(512, 512),
            down_block(512, 512)
        ])

        self.decoder = nn.ModuleList([
            up_block(512, 512, dropout=True),
            up_block(1024, 512, dropout=True),
            up_block(1024, 512, dropout=True),
            up_block(1024, 256),
            up_block(512, 128),
            up_block(256, 64),
            up_block(128, out_channels)
        ])

    def forward(self, x):
        # Ensure input is the correct size
        x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=True)

        skips = []
        # Encoder path
        for down in self.encoder:
            x = down(x)
            skips.append(x)

        skips = skips[:-1][::-1]  # Reversing to align with decoder

        # Decoder path
        for idx, up in enumerate(self.decoder):
            x = up(x)

            if idx < len(skips):
                # Resize decoder output to match skip connection dimensions
                x = F.interpolate(x, size=(skips[idx].shape[2], skips[idx].shape[3]), mode='bilinear', align_corners=True)

                # Concatenate with skip connection
                x = torch.cat((x, skips[idx]), dim=1)

        return x  # Remove torch.tanh() to keep original dynamic range

In [13]:
# PatchGAN Discriminator
class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=4):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, x):
        return self.model(x)

# Loss Functions
adversarial_loss = nn.MSELoss()
l1_loss = nn.L1Loss()

# Hyperparameters
batch_size = 16
epochs = 200
lr = 0.0002
lambda_l1 = 100




device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models
generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Training Loop
for epoch in range(epochs):
    for i, (rgb, nir) in enumerate(dataloader):
        # Move data to GPU
        rgb, nir = rgb.to(device), nir.to(device)

        # Train Generator
        optimizer_G.zero_grad()
        gen_nir = generator(rgb)
        pred_fake = discriminator(torch.cat((rgb, gen_nir), 1))
        g_loss = adversarial_loss(pred_fake, torch.ones_like(pred_fake).to(device)) + lambda_l1 * l1_loss(gen_nir, nir)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        pred_real = discriminator(torch.cat((rgb, nir), 1))
        loss_real = adversarial_loss(pred_real, torch.ones_like(pred_real).to(device))
        pred_fake = discriminator(torch.cat((rgb, gen_nir.detach()), 1))
        loss_fake = adversarial_loss(pred_fake, torch.zeros_like(pred_fake).to(device))
        d_loss = 0.5 * (loss_real + loss_fake)
        d_loss.backward()
        optimizer_D.step()

        print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")


[Epoch 0/200] [Batch 0/2] [D loss: 0.7590886354446411] [G loss: 42.422401428222656]
[Epoch 0/200] [Batch 1/2] [D loss: 0.4605754315853119] [G loss: 49.69865417480469]
[Epoch 1/200] [Batch 0/2] [D loss: 0.8940381407737732] [G loss: 30.385292053222656]
[Epoch 1/200] [Batch 1/2] [D loss: 2.0108044147491455] [G loss: 33.72285461425781]
[Epoch 2/200] [Batch 0/2] [D loss: 1.8384779691696167] [G loss: 22.906776428222656]
[Epoch 2/200] [Batch 1/2] [D loss: 2.559666633605957] [G loss: 30.818586349487305]
[Epoch 3/200] [Batch 0/2] [D loss: 1.7806153297424316] [G loss: 19.144041061401367]
[Epoch 3/200] [Batch 1/2] [D loss: 1.5153573751449585] [G loss: 25.133817672729492]
[Epoch 4/200] [Batch 0/2] [D loss: 1.130022644996643] [G loss: 15.114501953125]
[Epoch 4/200] [Batch 1/2] [D loss: 0.40101975202560425] [G loss: 24.463926315307617]
[Epoch 5/200] [Batch 0/2] [D loss: 0.4148412048816681] [G loss: 13.188583374023438]
[Epoch 5/200] [Batch 1/2] [D loss: 0.30613285303115845] [G loss: 17.83596038818359

[Epoch 48/200] [Batch 1/2] [D loss: 0.23168474435806274] [G loss: 13.307082176208496]
[Epoch 49/200] [Batch 0/2] [D loss: 0.051124073565006256] [G loss: 7.755587577819824]
[Epoch 49/200] [Batch 1/2] [D loss: 0.06720001995563507] [G loss: 14.554376602172852]
[Epoch 50/200] [Batch 0/2] [D loss: 0.05883710831403732] [G loss: 7.334502696990967]
[Epoch 50/200] [Batch 1/2] [D loss: 0.03935374319553375] [G loss: 14.061630249023438]
[Epoch 51/200] [Batch 0/2] [D loss: 0.040714412927627563] [G loss: 8.403129577636719]
[Epoch 51/200] [Batch 1/2] [D loss: 0.06321173906326294] [G loss: 14.143854141235352]
[Epoch 52/200] [Batch 0/2] [D loss: 0.09086836874485016] [G loss: 9.023553848266602]
[Epoch 52/200] [Batch 1/2] [D loss: 0.3188236355781555] [G loss: 36.29304885864258]
[Epoch 53/200] [Batch 0/2] [D loss: 0.16816116869449615] [G loss: 8.03206729888916]
[Epoch 53/200] [Batch 1/2] [D loss: 0.0715629830956459] [G loss: 18.70524787902832]
[Epoch 54/200] [Batch 0/2] [D loss: 0.050013456493616104] [G l

[Epoch 96/200] [Batch 1/2] [D loss: 0.020789634436368942] [G loss: 10.261397361755371]
[Epoch 97/200] [Batch 0/2] [D loss: 0.009716411121189594] [G loss: 6.574942111968994]
[Epoch 97/200] [Batch 1/2] [D loss: 0.011857820674777031] [G loss: 9.08332633972168]
[Epoch 98/200] [Batch 0/2] [D loss: 0.009121004492044449] [G loss: 6.798091411590576]
[Epoch 98/200] [Batch 1/2] [D loss: 0.08070866763591766] [G loss: 12.27329158782959]
[Epoch 99/200] [Batch 0/2] [D loss: 0.06415308266878128] [G loss: 7.053114414215088]
[Epoch 99/200] [Batch 1/2] [D loss: 0.10944517701864243] [G loss: 13.944321632385254]
[Epoch 100/200] [Batch 0/2] [D loss: 0.08620849251747131] [G loss: 6.1417388916015625]
[Epoch 100/200] [Batch 1/2] [D loss: 0.2049182802438736] [G loss: 10.09762191772461]
[Epoch 101/200] [Batch 0/2] [D loss: 0.23692889511585236] [G loss: 7.245144367218018]
[Epoch 101/200] [Batch 1/2] [D loss: 0.050322890281677246] [G loss: 9.921953201293945]
[Epoch 102/200] [Batch 0/2] [D loss: 0.0913575440645217

[Epoch 144/200] [Batch 0/2] [D loss: 0.004576793871819973] [G loss: 7.202338218688965]
[Epoch 144/200] [Batch 1/2] [D loss: 0.004687233362346888] [G loss: 11.968530654907227]
[Epoch 145/200] [Batch 0/2] [D loss: 0.00337788974866271] [G loss: 6.568627834320068]
[Epoch 145/200] [Batch 1/2] [D loss: 0.006350407842546701] [G loss: 9.677936553955078]
[Epoch 146/200] [Batch 0/2] [D loss: 0.005940022878348827] [G loss: 6.705404281616211]
[Epoch 146/200] [Batch 1/2] [D loss: 0.003294383641332388] [G loss: 15.154261589050293]
[Epoch 147/200] [Batch 0/2] [D loss: 0.0036590155214071274] [G loss: 6.344959735870361]
[Epoch 147/200] [Batch 1/2] [D loss: 0.0062318602576851845] [G loss: 12.458914756774902]
[Epoch 148/200] [Batch 0/2] [D loss: 0.004316940903663635] [G loss: 8.857450485229492]
[Epoch 148/200] [Batch 1/2] [D loss: 0.07059997320175171] [G loss: 30.944690704345703]
[Epoch 149/200] [Batch 0/2] [D loss: 0.01243998110294342] [G loss: 8.499937057495117]
[Epoch 149/200] [Batch 1/2] [D loss: 0.0

[Epoch 191/200] [Batch 1/2] [D loss: 0.005612166132777929] [G loss: 8.882675170898438]
[Epoch 192/200] [Batch 0/2] [D loss: 0.002020070794969797] [G loss: 6.806778907775879]
[Epoch 192/200] [Batch 1/2] [D loss: 0.0031106616370379925] [G loss: 9.629294395446777]
[Epoch 193/200] [Batch 0/2] [D loss: 0.002445833059027791] [G loss: 7.338988304138184]
[Epoch 193/200] [Batch 1/2] [D loss: 0.008347188122570515] [G loss: 8.762150764465332]
[Epoch 194/200] [Batch 0/2] [D loss: 0.003834763076156378] [G loss: 6.399412631988525]
[Epoch 194/200] [Batch 1/2] [D loss: 0.006034809164702892] [G loss: 11.856558799743652]
[Epoch 195/200] [Batch 0/2] [D loss: 0.00313965929672122] [G loss: 6.4900312423706055]
[Epoch 195/200] [Batch 1/2] [D loss: 0.003711316268891096] [G loss: 10.692161560058594]
[Epoch 196/200] [Batch 0/2] [D loss: 0.0022548669949173927] [G loss: 7.064061164855957]
[Epoch 196/200] [Batch 1/2] [D loss: 0.006882842630147934] [G loss: 13.37191104888916]
[Epoch 197/200] [Batch 0/2] [D loss: 0.

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def evaluate_model(model, test_dataloader, device):
    model.eval()
    total_l1_loss = 0.0
    with torch.no_grad():
        for rgb, nir_true in test_dataloader:
            rgb, nir_true = rgb.to(device), nir_true.to(device)
            nir_pred = model(rgb)
            l1_loss = torch.nn.L1Loss()(nir_pred, nir_true)
            total_l1_loss += l1_loss.item()

    avg_l1_loss = total_l1_loss / len(test_dataloader)
    print(f"Test L1 Loss: {avg_l1_loss}")
    return avg_l1_loss

test_dir='/Users/pgt/Downloads/Sentinel train data'
test_dataset = SEN12MSDataset(root_dir=test_dir, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=16)
evaluate_model(generator, test_dataloader, device)

Test L1 Loss: 0.09811070933938026


0.09811070933938026

In [16]:
# Saving the generator model
torch.save(generator.state_dict(), '/Users/pgt/Downloads/unet_generator.pth')


In [17]:
# Initialize the model architecture again
generator = UNetGenerator()

# Load the model weights
generator.load_state_dict(torch.load('/Users/pgt/Downloads/unet_generator.pth',map_location=torch.device('cpu')))

# Set the model to evaluation mode (important for inference)
generator.eval()


  generator.load_state_dict(torch.load('/Users/pgt/Downloads/unet_generator.pth',map_location=torch.device('cpu')))


UNetGenerator(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
  

In [21]:
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

def generate_nir_from_rgb(model, image_path, device='cuda'):
    """
    Generate NIR band from RGB image using trained model

    Args:
        model: Trained generator model
        image_path: Path to input RGB image
        device: Device to run inference on
    """
    # Load and preprocess the RGB image
    rgb_image = Image.open(image_path).convert('RGB')

    # Create transform pipeline
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Transform the image
    rgb_tensor = transform(rgb_image).unsqueeze(0).to(device)

    # Generate NIR band
    model.eval()
    with torch.no_grad():
        generated_nir = model(rgb_tensor)

    # Post-process the generated NIR
    # Remove normalization and rescale back to original range
    generated_nir = generated_nir.squeeze(0).cpu().numpy()
    generated_nir = (generated_nir + 1) * 0.5 * 10000  # Denormalize and rescale

    # Ensure values are in valid range
    generated_nir = np.clip(generated_nir, 0, 10000)

    # Convert to 16-bit unsigned integer
    generated_nir = generated_nir.astype(np.uint16)

    return generated_nir[0]  # Return single channel

# Usage example
if __name__ == "__main__":
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Path to your test image
    rgb_image_path = '/Users/pgt/Downloads/test_nir.jpeg'

    # Generate NIR band
    nir_band = generate_nir_from_rgb(generator, rgb_image_path, device)

    # Save the generated NIR band
    # For visualization, we'll normalize to 8-bit range
    nir_visualization = ((nir_band / 10000) * 255).astype(np.uint8)
    nir_image = Image.fromarray(nir_visualization)
    nir_image.save('generated_nir_1.png')

    # Save the full-range NIR band (16-bit TIFF)
    nir_image_16bit = Image.fromarray(nir_band)
    nir_image_16bit.save('generated_nir_16bit_1.tiff')