# üîê Invisible Watermark Training Notebook
### PhotoMaker CAP C6 Group 3 ‚Äî Paresh

This notebook trains the CNN-based invisible watermark encoder/decoder.
- Loads CIFAR10
- Applies robustness attacks
- Trains encoder + decoder jointly
- Saves `encoder_trained.pth` and `decoder_trained.pth`
- Includes verification test


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision
from PIL import Image
import numpy as np

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BIT_LENGTH = 64
print("Using device:", DEVICE)

Using device: cuda


## üìå Encoder & Decoder (same as your project)

In [7]:
class WatermarkEncoder(nn.Module):
    def __init__(self, bit_length=64):
        super().__init__()
        self.bit_length = bit_length

        self.embed = nn.Sequential(
            nn.Linear(bit_length, 256),
            nn.ReLU(),
            nn.Linear(256, 64 * 64),
            nn.ReLU()
        )

        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 1)
        )

    def forward(self, image, bits):
        B, C, H, W = image.shape
        wm = self.embed(bits).view(B, 1, 64, 64)
        wm = torch.nn.functional.interpolate(wm, size=(H, W), mode="bilinear")
        x = torch.cat([image, wm], dim=1)
        residual = self.conv(x)
        return torch.clamp(image + 0.01 * residual, 0, 1)


class WatermarkDecoder(nn.Module):
    def __init__(self, bit_length=64):
        super().__init__()
        self.bit_length = bit_length

        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )

        self.fc = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Linear(128, bit_length),
            nn.Sigmoid()
        )

    def forward(self, image):
        x = self.conv(image)
        x = x.view(x.size(0), -1)
        return self.fc(x)

## üìå Robustness Attacks

In [8]:
import io
from PIL import Image
import torchvision.transforms as T

to_pil = T.ToPILImage()
to_tensor = T.ToTensor()

def jpeg_compress_single(img, quality=70):
    """JPEG compress a single image tensor: C√óH√óW."""
    pil = to_pil(img.cpu())
    buffer = io.BytesIO()
    pil.save(buffer, format="JPEG", quality=quality)
    buffer.seek(0)
    return to_tensor(Image.open(buffer)).to(img.device)

def apply_attacks(batch):
    """Apply robustness attacks to a batch of images: B√óC√óH√óW."""
    attacked = []

    for img in batch:  # iterate over each image in the batch
        x = img.clone()

        # Gaussian noise
        x = x + 0.01 * torch.randn_like(x)

        # Random blur
        if torch.rand(1).item() < 0.3:
            x = torchvision.transforms.functional.gaussian_blur(x, kernel_size=5)

        # JPEG compression
        if torch.rand(1).item() < 0.5:
            x = jpeg_compress_single(x, quality=70)

        attacked.append(torch.clamp(x, 0, 1))

    return torch.stack(attacked)


## üìå Load CIFAR10 Dataset

In [9]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

len(dataset)

50000

## üöÄ Training Loop

In [None]:
encoder = WatermarkEncoder(BIT_LENGTH).to(DEVICE)
decoder = WatermarkDecoder(BIT_LENGTH).to(DEVICE)

opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)

lambda_img = 1.0
lambda_wm = 5.0

EPOCHS = 100

for epoch in range(EPOCHS):
    for imgs, _ in loader:
        imgs = imgs.to(DEVICE)

        bits = torch.randint(0, 2, (imgs.size(0), BIT_LENGTH), device=DEVICE).float()

        watermarked = encoder(imgs, bits)
        attacked = apply_attacks(watermarked)
        pred_bits = decoder(attacked)

        img_loss = F.mse_loss(watermarked, imgs)
        wm_loss = F.binary_cross_entropy(pred_bits, bits)

        loss = lambda_img * img_loss + lambda_wm * wm_loss

        opt.zero_grad()
        loss.backward()
        opt.step()

    print(f"Epoch {epoch+1}/{EPOCHS} ‚Äî img_loss={img_loss.item():.4f}, wm_loss={wm_loss.item():.4f}")

Epoch 1/30 ‚Äî img_loss=0.0000, wm_loss=0.6932
Epoch 2/30 ‚Äî img_loss=0.0000, wm_loss=0.6934
Epoch 3/30 ‚Äî img_loss=0.0000, wm_loss=0.6931
Epoch 4/30 ‚Äî img_loss=0.0000, wm_loss=0.6934


## üíæ Save Trained Models

In [6]:
torch.save(encoder.state_dict(), "encoder_trained.pth")
torch.save(decoder.state_dict(), "decoder_trained.pth")
print("Saved encoder_trained.pth and decoder_trained.pth")

Saved encoder_trained.pth and decoder_trained.pth


## üîç Verification Test

In [7]:
encoder.eval()
decoder.eval()

sample_img, _ = dataset[0]
sample_img = sample_img.unsqueeze(0).to(DEVICE)

test_bits = torch.randint(0, 2, (1, BIT_LENGTH), device=DEVICE).float()

wm_img = encoder(sample_img, test_bits)
decoded = decoder(wm_img)

decoded_bits = (decoded > 0.5).int().cpu().numpy().tolist()[0]
confidence = sum(decoded_bits) / BIT_LENGTH

print("Decoded bits:", decoded_bits)
print("Confidence:", confidence)

Decoded bits: [0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1]
Confidence: 0.65625
