In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
from torchsummary import summary
import glob
import cv2
import numpy as np

In [14]:
# Pick the compute device once for the whole notebook: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# # CNN Encoder
# class CNNEncoder(nn.Module):
#     def __init__(self):
#         super(CNNEncoder, self).__init__()
#         self.encoder = nn.Sequential(
#             nn.Conv2d(1, 32, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#             nn.Conv2d(32, 64, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#             nn.Conv2d(64, 128, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Conv2d(128, 256, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#         )

#     def forward(self, x):
#         return self.encoder(x)

# # summary(CNNEncoder().to(device), (1, 256, 256))

In [15]:
# # Dilated Residual Network (DRN)
# class DRN(nn.Module):
#     def __init__(self, dilation):
#         super(DRN, self).__init__()
#         self.conv = nn.Conv2d(256, 256, kernel_size=3, padding=dilation, dilation=dilation)
#         self.relu = nn.ReLU()

#     def forward(self, x):
#         l1 = self.relu(self.conv(x))
#         l2 = self.relu(self.conv(l1))
#         return self.relu(self.conv(l2) + x)  # Residual connection

# # summary(DRN(1).to(device), (256, 32, 32))

In [16]:
# # Efficient Channel Attention (ECA)
# class ECA(nn.Module):
#     def __init__(self):
#         super(ECA, self).__init__()
#         self.conv = nn.Conv2d(256, 256, kernel_size=1)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, x):
#         return x * self.sigmoid(self.conv(x))

# summary(ECA().to(device), (256, 32, 32))

In [17]:
# # Simple UNet
# class UNet(nn.Module):
#     def __init__(self):
#         super(UNet, self).__init__()

#         self.encoder = nn.Sequential(
#             nn.Conv2d(256, 128, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#             nn.Conv2d(128, 128, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#             nn.Conv2d(128, 64, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2),
#         )

#         self.decoder = nn.Sequential(
#             nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Upsample(scale_factor=2),
#             nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Upsample(scale_factor=2),
#             nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Upsample(scale_factor=2),
#             nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Upsample(scale_factor=2),
#             nn.Upsample(scale_factor=2),
#             nn.ConvTranspose2d(16, 1, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Upsample(scale_factor=2),
#         )

#     def forward(self, x):
#         x = self.encoder(x)
#         return self.decoder(x)

# summary(UNet().to(device), (256, 32, 32))

# # Complete Model
# class SARDespecklingNet(nn.Module):
#     def __init__(self):
#         super(SARDespecklingNet, self).__init__()
#         self.encoder = CNNEncoder()
#         self.drn1 = DRN(dilation=1)
#         self.drn3 = DRN(dilation=3)
#         self.drn5 = DRN(dilation=5)
#         self.eca = ECA()
#         self.unet = UNet()

#     def forward(self, x):
#         x = self.encoder(x)
#         x1 = self.drn1(x)
#         x3 = self.drn3(x)
#         x5 = self.drn5(x)
#         x = self.eca(x1 + x3 + x5)
#         return self.unet(x)

# summary(SARDespecklingNet().to(device), (1, 256, 256))

In [18]:
# Dataset: paired speckled SAR images (./s1) and their clean counterparts (./s2)
def _load_gray_resized(path, size):
    """Read one image with OpenCV, resize it to ``size`` (width, height) and convert BGR -> grayscale.

    Raises:
        OSError: if OpenCV cannot read/decode ``path`` (``cv2.imread`` returns None instead of raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise OSError(f"could not read image: {path}")
    img = cv2.resize(img, size)
    return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)


def prepare_data(n, in_dir="./s1", out_dir="./s2", size=(256, 256)):
    """Load up to ``n`` paired grayscale images: speckled SAR inputs and clean targets.

    Inputs and targets are paired purely by position after sorting each directory's
    ``*.png`` file names, so both directories are assumed to use matching names — TODO confirm.

    Args:
        n: maximum number of pairs to load (``None`` loads every file).
        in_dir: directory holding the speckled SAR inputs.
        out_dir: directory holding the corresponding clean targets.
        size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``(in_data, out_data)``: two equally long lists of 2-D grayscale arrays.

    Raises:
        ValueError: if the two directories yield a different number of images (pairs would misalign).
        OSError: if OpenCV cannot read one of the files.
    """
    infiles = sorted(glob.glob(f"{in_dir}/*.png"))[:n]
    outfiles = sorted(glob.glob(f"{out_dir}/*.png"))[:n]
    if len(infiles) != len(outfiles):
        raise ValueError(
            f"unpaired data: {len(infiles)} images in {in_dir!r} vs {len(outfiles)} in {out_dir!r}"
        )

    in_data = [_load_gray_resized(path, size) for path in infiles]
    out_data = [_load_gray_resized(path, size) for path in outfiles]
    return in_data, out_data
    
class SARDataset(Dataset):
    """Map-style dataset of aligned (speckled SAR input, clean target) pairs.

    Args:
        sar_images: indexable collection of network inputs (in this notebook a float32
            array shaped ``(N, 1, H, W)``, but any sequence works).
        clean_images: indexable collection of targets, aligned index-by-index with ``sar_images``.
        transform: optional callable applied independently to the input and to the target of
            each pair; its output must be collatable by ``DataLoader``.

    Raises:
        ValueError: if the two collections have different lengths.
    """

    def __init__(self, sar_images, clean_images, transform=None):
        if len(sar_images) != len(clean_images):
            raise ValueError(
                f"sar_images and clean_images differ in length: "
                f"{len(sar_images)} != {len(clean_images)}"
            )
        self.sar_images = sar_images
        self.clean_images = clean_images
        self.transform = transform

    def __len__(self):
        return len(self.sar_images)

    def __getitem__(self, idx):
        sar = self.sar_images[idx]
        clean = self.clean_images[idx]
        if self.transform is not None:
            # NOTE: torchvision's ToTensor expects HWC arrays; items here are already CHW,
            # so only pass transforms that accept that layout.
            sar = self.transform(sar)
            clean = self.transform(clean)
        return sar, clean



# Load the paired images and scale 8-bit pixels to float32 in [0, 1].
# NOTE: despite the `dummy_` prefix these are the real speckled SAR inputs; the name is kept
# because later cells read `dummy_sar`.
dummy_sar, clean_sar = prepare_data(1000)  # speckled inputs, corresponding clean targets
assert len(dummy_sar) > 0, "no images found in ./s1 and ./s2"
assert len(dummy_sar) == len(clean_sar), "SAR and clean image counts differ"

# Stack to (N, H, W), then add the channel axis the network expects -> (N, 1, H, W).
dummy_sar = (np.stack(dummy_sar).astype(np.float32) / 255)[:, np.newaxis]
clean_sar = (np.stack(clean_sar).astype(np.float32) / 255)[:, np.newaxis]

# Hold out a fixed, seeded fraction of the pairs for validation so the reported validation
# loss measures generalisation instead of re-scoring the training images.
SEED = 42
VAL_FRACTION = 0.1
BATCH_SIZE = 10

dataset = SARDataset(dummy_sar, clean_sar)
n_val = max(1, int(len(dataset) * VAL_FRACTION))
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [len(dataset) - n_val, n_val],
    generator=torch.Generator().manual_seed(SEED),
)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

In [19]:
# Sanity check: inputs and targets are stacked into matching (N, 1, H, W) arrays.
# `dummy_sar` is already an ndarray, so read `.shape` directly instead of copying it via np.array().
dummy_sar.shape, clean_sar.shape


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL __main__.SARDespecklingNet was not an allowed global by default. Please use `torch.serialization.add_safe_globals([SARDespecklingNet])` or the `torch.serialization.safe_globals([SARDespecklingNet])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [None]:
# Training and Evaluation
def train_and_evaluate(model, train_loader, val_loader, epochs=10, lr=0.001,
                       checkpoint_path="./sar_model.pth", save_every=2, device=None):
    """Train ``model`` with Adam + MSE loss, printing mean train/validation loss per epoch.

    Args:
        model: ``nn.Module`` whose output for a ``sar`` batch is compared to the ``clean`` batch.
        train_loader: iterable of ``(sar, clean)`` batches used for optimisation.
        val_loader: iterable of ``(sar, clean)`` batches used only for evaluation.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        checkpoint_path: file the whole pickled model is written to with ``torch.save(model, ...)``.
        save_every: checkpoint after epochs 1, 1 + save_every, ... (1-based); ``<= 0`` disables
            periodic saves. The final epoch is always checkpointed.
        device: device to train on (str or ``torch.device``); defaults to CUDA when available, else CPU.

    Returns:
        dict with lists ``"train"`` and ``"val"`` holding the mean per-batch loss of each epoch.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = {"train": [], "val": []}

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for sar, clean in train_loader:
            sar, clean = sar.float().to(device), clean.float().to(device)
            optimizer.zero_grad()
            loss = criterion(model(sar), clean)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        mean_train = train_loss / max(len(train_loader), 1)
        history["train"].append(mean_train)
        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {mean_train}")

        # Checkpoint periodically and always after the last epoch, so the final weights are never lost.
        is_last_epoch = epoch == epochs - 1
        if is_last_epoch or (save_every > 0 and epoch % save_every == 0):
            torch.save(model, checkpoint_path)

        # Evaluation (same dtype/device handling as training).
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for sar, clean in val_loader:
                sar = sar.float().to(device, non_blocking=True)
                clean = clean.float().to(device, non_blocking=True)
                val_loss += criterion(model(sar), clean).item()

        mean_val = val_loss / max(len(val_loader), 1)
        history["val"].append(mean_val)
        print(f"Validation Loss: {mean_val}")

    return history


# Train the model, resuming from the last checkpoint.
# `train_and_evaluate` saves the whole pickled module with torch.save(model, ...), so loading it
# needs weights_only=False (PyTorch >= 2.6 defaults to True) AND the SARDespecklingNet class
# defined in this kernel. Its definition cell above is commented out, so this only works if the
# class is already in memory. To start from scratch use `model = SARDespecklingNet()` instead.
# SECURITY: weights_only=False unpickles arbitrary objects; only load checkpoints you created yourself.
# map_location lets a checkpoint saved on a GPU be resumed on a CPU-only machine (and vice versa).
model = torch.load("./sar_model.pth", map_location=device, weights_only=False)
history = train_and_evaluate(model, train_loader, val_loader, epochs=10, lr=0.0001)


Epoch 1/10, Training Loss: 0.03426116056740284


In [None]:
# First 10 SAR inputs as a float32 batch on the training device, for a quick inference check.
first_batch = dummy_sar[:10]
x = torch.from_numpy(first_batch).to(device=device, dtype=torch.float32)
x.shape

In [None]:
# Run the trained model on the inspection batch `x` (inference mode, no autograd graph).
model.eval()
with torch.no_grad():
    # Move predictions to the CPU so NumPy/matplotlib can use them even when the model runs on a GPU.
    p = model(x).cpu()

p.shape

In [None]:
# Compare one sample side by side: speckled input, network output and clean target.
SAMPLE_IDX = 2  # must be < len(p), i.e. within the 10-image inspection batch
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ("SAR input", dummy_sar[SAMPLE_IDX, 0]),
    # .detach().cpu() keeps this working even if `p` still lives on the GPU.
    ("Despeckled output", p[SAMPLE_IDX, 0].detach().cpu().numpy()),
    ("Clean target", clean_sar[SAMPLE_IDX, 0]),
]
for ax, (title, image) in zip(axes, panels):
    # Shared [0, 1] scale: the input and target pixels were divided by 255 during data prep.
    ax.imshow(image, cmap="gray", vmin=0, vmax=1)
    ax.set_title(title)
    ax.axis("off")
plt.show()