In [None]:
import os

import ijson
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms.functional import resize

In [None]:
# Display the original (un-enhanced) image, resized and scaled to [0, 1].

# NOTE(review): torch.load unpickles the file; only load .pt files from trusted
# sources (or pass weights_only=True on torch versions that support it).
image_path = 'trainA_nonorm/0be3a23f-e329a79b.pt'  # the name of the corresponding .pt file
image_tensor = torch.load(image_path)  # presumably [C, H, W] raw 0-255 pixels — TODO confirm

# Convert to float *before* resizing so interpolated values are not rounded to integers.
image_tensor = image_tensor.to(torch.float32)
image_tensor = F.resize(image_tensor, (700, 900))  # Shape: [C, 700, 900]

# Map 8-bit pixel values onto [0, 1] (255 -> 1.0) and clip anything out of range.
image_tensor = (image_tensor / 255.0).clamp(0, 1).cpu()

# matplotlib expects (H, W, 3) for colour images and (H, W) for single-channel ones;
# an (H, W, 1) array is rejected by imshow, so the channel axis is dropped for C == 1.
if image_tensor.dim() == 3 and image_tensor.shape[0] == 3:
    image_np = image_tensor.permute(1, 2, 0).numpy()
elif image_tensor.dim() == 3 and image_tensor.shape[0] == 1:
    image_np = image_tensor[0].numpy()
elif image_tensor.dim() == 2:
    image_np = image_tensor.numpy()
else:
    raise ValueError(f"Unexpected tensor shape: {tuple(image_tensor.shape)}")

# Plot the original image (values are already guaranteed to lie in [0, 1]).
fig, ax = plt.subplots(1)
ax.imshow(image_np, cmap='gray' if image_np.ndim == 2 else None)
ax.set_title(f"Original: {os.path.basename(image_path)}")
ax.axis("off")
plt.show()

In [None]:
# Function that applies contrast stretching to enhance image brightness and contrast

def contrast_stretch(image, low_percentile=0, high_percentile=85, verbose=True):
    """
    Linearly stretch intensities so the chosen percentile range maps onto [0, 1].

    Values at or below the ``low_percentile`` value become 0, and values at or above
    the ``high_percentile`` value become 1. With the default of 85, the brightest 15%
    of pixels saturate, which brightens the image. Percentiles are computed over all
    elements of ``image`` jointly (every channel at once), so each channel is shifted
    and scaled by the same amount.

    :param image: PyTorch tensor. The notebook passes (C, H, W) images, but any
        shape works because percentiles are taken over all elements.
    :param low_percentile: Lower percentile (0-100) mapped to 0.
    :param high_percentile: Upper percentile (0-100) mapped to 1.
    :param verbose: If True (default), print the percentile values and any warning.
    :return: float32 tensor with the same shape and device as ``image`` and values
        in [0, 1]. If the high percentile value does not exceed the low one by at
        least 1e-6 (a near-constant image, or ``low_percentile > high_percentile``),
        ``image`` itself is returned unchanged to avoid an all-black result.
    """
    # detach() lets tensors that require grad be converted to NumPy as well.
    image_np = image.detach().cpu().numpy()

    # Compute percentiles over the whole tensor.
    min_val = np.percentile(image_np, low_percentile)
    max_val = np.percentile(image_np, high_percentile)

    if verbose:
        print(f"Debug: Min percentile value = {min_val}, Max percentile value = {max_val}")

    if max_val - min_val < 1e-6:
        if verbose:
            print("Warning: Min and max values are too close! Returning original image.")
        return image  # Return original image to avoid black output

    # The guard above keeps the denominator >= 1e-6, so no epsilon is needed;
    # this also lets the high-percentile value map to exactly 1.0.
    stretched = (image_np - min_val) / (max_val - min_val)

    # Clip values to avoid over-brightening.
    stretched = np.clip(stretched, 0, 1)

    # Return on the caller's device instead of silently moving the result to CPU.
    return torch.tensor(stretched, dtype=torch.float32, device=image.device)


In [None]:
# Sample usage: enhance an image with contrast_stretch and display the result.

# NOTE(review): this cell reads 'trainA_700nonorm' while the display cell reads
# 'trainA_nonorm' — confirm both directories are intended.
# NOTE(review): torch.load unpickles the file; only load trusted .pt files.
image_path = 'trainA_700nonorm/0be3a23f-e329a79b.pt'  # the name of the corresponding .pt file
image_tensor = torch.load(image_path)  # presumably [C, H, W] — TODO confirm

# contrast_stretch expects numeric data and already rescales to [0, 1],
# so no separate /255 normalisation is needed here.
image_tensor = contrast_stretch(image_tensor.to(torch.float32))

# Defensive clip: contrast_stretch returns its input unchanged for near-constant
# images, which may not lie in [0, 1].
image_tensor = image_tensor.clamp(0, 1)
image_tensor = F.resize(image_tensor, (700, 900))  # Shape: [C, 700, 900]

# Move to CPU as float32 for plotting.
image_tensor = image_tensor.to(torch.float32).cpu()

# Show the enhanced image. matplotlib rejects (H, W, 1) arrays, so a single
# channel axis is dropped and rendered in grayscale.
display_img = image_tensor.permute(1, 2, 0).squeeze(-1)
fig, ax = plt.subplots(1)
ax.imshow(display_img, cmap='gray' if display_img.dim() == 2 else None)
ax.set_title(f"Contrast-stretched: {os.path.basename(image_path)}")
ax.axis("off")
plt.show()