In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import Flickr8k
from diffusers import StableDiffusionPipeline
import os
from PIL import Image

In [None]:
# All tensors and the diffusion pipeline are placed on this device.
DEVICE_NAME = "cpu"
device = torch.device(DEVICE_NAME)

In [None]:
# Load the pretrained Stable Diffusion pipeline first, then move it onto `device`.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to(device)

In [None]:
# Preprocessing applied to every image: resize to 512x512, then convert to a tensor.
_preprocess_steps = [
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

In [None]:
# Mount Google Drive inside the Colab runtime so files under it can be read by path.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Directory (under the Drive mount point) whose files are listed and loaded by
# Flickr8kDataset below.
FLICKR8K_IMAGES_PATH = "/content/drive/MyDrive/Flickr8k dataset/Images"

In [None]:
class Flickr8kDataset(Dataset):
    """Image-only dataset over a flat directory of picture files.

    Indexing returns an ``(image, filename)`` pair, where ``image`` is the
    RGB-converted picture after the optional ``transform`` has been applied.

    Args:
        images_path: Directory that contains the image files.
        transform: Optional callable applied to each loaded PIL image.
    """

    # File extensions (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")

    def __init__(self, images_path, transform=None):
        super().__init__()
        self.images_path = images_path
        self.transform = transform
        # os.listdir returns entries in arbitrary order and includes whatever
        # else lives in the folder (sub-directories, hidden/metadata files).
        # Sorting makes indices stable across runs; filtering keeps entries
        # that would make Image.open fail out of the dataset.
        self.image_filenames = sorted(
            name
            for name in os.listdir(images_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(images_path, name))
        )

    def __len__(self):
        """Return the number of image files found in ``images_path``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, filename)``."""
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.images_path, img_name)
        # The context manager releases the file handle as soon as the pixel
        # data has been converted, instead of relying on garbage collection.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Compare against None so a callable that happens to be falsy
        # (e.g. an empty container-style transform) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, img_name

In [None]:
# Seed for the shuffling generator, so that a fresh "Restart & Run All" draws
# the same batches (given the same dataset ordering) instead of new random ones.
SHUFFLE_SEED = 42

dataset = Flickr8kDataset(FLICKR8K_IMAGES_PATH, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [None]:
# Pull two consecutive batches from ONE pass over the loader. Calling
# iter(dataloader) twice starts two independent passes, which can hand back
# overlapping (or, without shuffling, identical) batches for cover and secret.
# Note: this needs the loader to yield at least two batches.
batch_iter = iter(dataloader)
cover_image, _ = next(batch_iter)
secret_image, _ = next(batch_iter)

In [None]:
def encode_stegano(cover, secret, alpha=0.1):
    """Hide ``secret`` inside ``cover`` via a weighted sum.

    The result is ``(1 - alpha) * cover + alpha * secret``; ``alpha`` is the
    weight given to the secret.
    """
    cover_weight = 1 - alpha
    stego = cover_weight * cover + alpha * secret
    return stego

In [None]:
def decode_stegano(encoded_image, cover_image, alpha=0.1):
    """Recover the secret image by inverting the weighted blend.

    Solves ``encoded = (1 - alpha) * cover + alpha * secret`` for ``secret``
    and clamps the result to the range [0, 1].

    Args:
        encoded_image: Tensor produced by blending cover and secret.
        cover_image: The cover tensor used for the blend; must have the same
            shape as ``encoded_image``.
        alpha: Blend weight that was given to the secret. Must match the value
            used for encoding and must not be zero.

    Returns:
        The recovered secret tensor, clamped to [0, 1].

    Raises:
        ValueError: If the two tensors differ in shape, or if a scalar
            ``alpha`` is zero (the blend then contains nothing of the secret
            and the inversion would divide by zero).
    """
    if encoded_image.shape != cover_image.shape:
        raise ValueError("Encoded and Cover images must have the same shape!")
    # Only plain numbers are checked, so a tensor-valued alpha behaves as before.
    if isinstance(alpha, (int, float)) and alpha == 0:
        raise ValueError("alpha must be non-zero to recover the secret image!")
    secret_recovered = (encoded_image - (1 - alpha) * cover_image) / alpha
    secret_recovered = torch.clamp(secret_recovered, 0, 1)
    return secret_recovered

In [None]:
def show_images(images, titles):
    """Display images side by side in one row, one titled panel per image.

    Args:
        images: Sequence of images. Torch tensors are detached, moved to the
            CPU and have their first axis moved last (assumed channel-first,
            e.g. (C, H, W)) before plotting; anything else is handed to
            ``imshow`` unchanged.
        titles: Sequence of panel titles, paired with ``images`` by position.

    Nothing is drawn when ``images`` is empty.
    """
    if len(images) == 0:
        # plt.subplots rejects a zero-column grid; there is nothing to show.
        return

    # squeeze=False always yields a 2-D array of Axes, so a single image works
    # the same way as several (otherwise a bare Axes object is returned and
    # cannot be zipped over).
    _, axes = plt.subplots(1, len(images), figsize=(12, 4), squeeze=False)
    for ax, img, title in zip(axes.ravel(), images, titles):
        if isinstance(img, torch.Tensor):
            # detach()/cpu() make the conversion safe for tensors that track
            # gradients or live on an accelerator.
            img = img.detach().cpu().permute(1, 2, 0).numpy()
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off")
    plt.show()

In [None]:
# Blend the secret batch into the cover batch, invert the blend, and show the
# sample at index 1 of each batch side by side.
blend_alpha = 0.1
sample_idx = 1

encoded_image = encode_stegano(cover_image, secret_image, alpha=blend_alpha)
decoded_secret_image = decode_stegano(encoded_image, cover_image, alpha=blend_alpha)

panels = [
    batch[sample_idx]
    for batch in (cover_image, secret_image, encoded_image, decoded_secret_image)
]
panel_titles = ["Cover", "Secret", "Encoded", "Decoded Secret"]
show_images(panels, panel_titles)

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse

def _to_uint8_hwc(image):
    """Convert a tensor image to a contiguous uint8 array with its first axis moved last.

    Values are scaled by 255, rounded to the nearest integer and clipped to
    [0, 255] before the cast. A bare ``astype(np.uint8)`` truncates instead,
    so a value that lands just below an integer because of floating-point
    noise loses a whole intensity level, and out-of-range values wrap around.
    """
    # detach()/cpu() keep this working for tensors that track gradients or
    # live on an accelerator.
    array = image.detach().cpu().permute(1, 2, 0).numpy()
    quantized = np.clip(np.rint(array * 255), 0, 255).astype(np.uint8)
    return np.ascontiguousarray(quantized)


def calculate_metrics(original_image, decoded_image):
    """Calculate SSIM and MSE between an original and a decoded image.

    Args:
        original_image: Tensor image; assumed channel-first RGB with values
            in [0, 1] (it is scaled by 255 and converted RGB -> gray below).
        decoded_image: Tensor image in the same layout as ``original_image``.

    Returns:
        ``(ssim_index, mean_squared_error)``: SSIM is computed on the
        grayscale versions of the two images, MSE on the 0-255 RGB arrays.
    """
    original_np = _to_uint8_hwc(original_image)
    decoded_np = _to_uint8_hwc(decoded_image)

    # SSIM is evaluated on a single grayscale channel rather than per RGB channel.
    original_gray = cv2.cvtColor(original_np, cv2.COLOR_RGB2GRAY)
    decoded_gray = cv2.cvtColor(decoded_np, cv2.COLOR_RGB2GRAY)

    ssim_index = ssim(original_gray, decoded_gray)
    mean_squared_error = mse(original_np, decoded_np)

    return ssim_index, mean_squared_error

# Score the sample at index 1: the original secret versus what was recovered
# from the encoded image.
original_secret, decoded_secret = secret_image[1], decoded_secret_image[1]

metrics = calculate_metrics(original_secret, decoded_secret)
ssim_score, mse_score = metrics

print(f"SSIM: {ssim_score:.4f}")
print(f"MSE: {mse_score:.4f}")