In [21]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [22]:
def weights_init(m):
    """DCGAN-style initialisation for one module; meant for ``nn.Module.apply``.

    * class name containing "Conv" (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02)
    * class name containing "BatchNorm": weight ~ N(1, 0.02), bias = 0
    * anything else (Linear, activations, containers): left untouched
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [23]:
# Hyperparameters and architecture sizes shared by every cell below.
# (The dataset directory is configured as `img_dir` in the training cell.)

# Seed torch's RNG (weight init, DataLoader shuffling) so that
# Restart & Run All reproduces the same run.
manual_seed = 999
torch.manual_seed(manual_seed)

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images are resized/cropped to this size
# by the transform. The models' Linear layers (ngc*8*4*4 inputs) assume 64.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (the code each image is compressed to)
nz = 128

# Base channel count of the transposed-conv image generators
ngf = 64

# Base channel count of the convolutional feature extractors
# (encoder branches, decoder front-end and discriminator)
ngc = 4

# Size of feature maps in discriminator
# NOTE(review): not referenced by the models in this notebook
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5

In [24]:
class StegEncoder(nn.Module):
    """Hide a secret image inside a cover image.

    The cover and the secret (each nc x 64 x 64) are compressed to nz-dimensional
    codes by two separate, identically shaped convolutional branches. The two codes
    are concatenated and expanded by a transposed-convolution generator into one
    nc x 64 x 64 stego image with a Tanh output.
    """

    def __init__(self):
        super(StegEncoder, self).__init__()
        # Build order (cover, secret, stego) fixes the parameter-init RNG sequence.
        self.cover = self._make_branch()
        self.secret = self._make_branch()
        self.stego = self._make_generator()

    @staticmethod
    def _make_branch():
        """nc x 64 x 64 image -> flat nz-dimensional code."""
        layers = [nn.Conv2d(nc, ngc, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True)]
        # Three stride-2 stages (32 -> 16 -> 8 -> 4), doubling channels each time.
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(ngc * mult, ngc * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngc * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [nn.Flatten(), nn.Linear(ngc * 8 * 4 * 4, nz)]
        return nn.Sequential(*layers)

    @staticmethod
    def _make_generator():
        """(2*nz) x 1 x 1 concatenated code -> nc x 64 x 64 image."""
        layers = [
            nn.ConvTranspose2d(2 * nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        # Three stride-2 stages (4 -> 8 -> 16 -> 32), halving channels each time.
        for mult in (8, 4, 2):
            layers += [
                nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * mult // 2),
                nn.ReLU(True),
            ]
        layers += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()]
        return nn.Sequential(*layers)

    def forward(self, secret, cover):
        """Return the stego image produced from ``secret`` hidden in ``cover``."""
        codes = torch.cat((self.cover(cover), self.secret(secret)), 1)
        return self.stego(codes.view(-1, 2 * nz, 1, 1))

In [1]:
class StegDecoder(nn.Module):
    """Recover the hidden secret image from a stego image.

    The stego image (nc x 64 x 64) is compressed to an nz-dimensional code. A
    transposed-convolution stack then expands that code back into an
    nc x 64 x 64 image with a Tanh output.
    """

    def __init__(self):
        super(StegDecoder, self).__init__()

        # stego image -> nz code: 64 -> 32 -> 16 -> 8 -> 4, then flatten + Linear
        down = [nn.Conv2d(nc, ngc, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True)]
        for width in (ngc, ngc * 2, ngc * 4):
            down.extend((
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        down.extend((nn.Flatten(), nn.Linear(ngc * 8 * 4 * 4, nz)))
        self.reverseStego = nn.Sequential(*down)

        # nz x 1 x 1 code -> image: 1 -> 4 -> 8 -> 16 -> 32 -> 64
        up = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        for width in (ngf * 8, ngf * 4, ngf * 2):
            up.extend((
                nn.ConvTranspose2d(width, width // 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width // 2),
                nn.ReLU(True),
            ))
        up.extend((nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()))
        self.regenerate = nn.Sequential(*up)

    def forward(self, stego):
        """Return the image reconstructed from ``stego``."""
        code = self.reverseStego(stego).view(-1, nz, 1, 1)
        return self.regenerate(code)

In [None]:
class Discriminator(nn.Module):
    """Score an nc x 64 x 64 image with a probability in (0, 1), shape (batch, 1).

    The image is compressed to an nz-dimensional code by a strided-conv stack.
    A two-layer MLP with a Sigmoid head then maps it to one probability. In this
    notebook's training loop, cover images are labelled 0 and stego images 1.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.compress = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ngc, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (ngc) x 32 x 32
            nn.Conv2d(ngc, ngc * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngc * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (ngc*2) x 16 x 16
            nn.Conv2d(ngc * 2, ngc * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngc * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (ngc*4) x 8 x 8
            nn.Conv2d(ngc * 4, ngc * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngc * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (ngc*8) x 4 x 4
            nn.Flatten(),
            nn.Linear(ngc * 8 * 4 * 4, nz),
        )

        # nn.Linear needs integer feature counts; `nz / 4` is a float and made
        # construction fail, so use floor division.
        hidden = nz // 4
        self.classify = nn.Sequential(
            nn.Linear(nz, hidden),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, image):
        """Return the (batch, 1) probability for each image in ``image``."""
        return self.classify(self.compress(image))
            

In [34]:
# Loss functions

# Weights of the two reconstruction terms in EncDecLoss:
# a scales the cover-vs-stego error, b scales the secret-vs-decoded error.
# They are read at call time, so changing them later takes effect immediately.
a = 1
b = 1


class CoverMSELoss(nn.Module):
    """Mean squared error between the cover image and the stego image."""

    def __init__(self):
        super(CoverMSELoss, self).__init__()

    def forward(self, cover, stego):
        return torch.mean((cover - stego) ** 2)


class SecretMSELoss(nn.Module):
    """Mean squared error between the secret image and its decoded reconstruction."""

    def __init__(self):
        super(SecretMSELoss, self).__init__()

    def forward(self, secret, stego):
        return torch.mean((secret - stego) ** 2)


class EncDecLoss(nn.Module):
    """Weighted sum ``a * MSE(cover, stego) + b * MSE(secret, stego_decoded)``."""

    def __init__(self):
        super(EncDecLoss, self).__init__()
        # Instantiate the component losses once. Calling the classes with tensors
        # (CoverMSELoss(cover, stego)) passed them to __init__ and raised TypeError.
        self.cover_loss = CoverMSELoss()
        self.secret_loss = SecretMSELoss()

    def forward(self, cover, secret, stego, stego_decoded):
        return a * self.cover_loss(cover, stego) + b * self.secret_loss(secret, stego_decoded)


encDecCriterion = EncDecLoss()
criterion = nn.BCELoss()

In [35]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def train(img_dataloader, enc, dec, disc, encDecOptim, discOptim, encDecCriterion, criterion, num_epochs, mix_coeff, disc_steps=4):
    """Adversarially train the steganography encoder/decoder against a discriminator.

    Batches are consumed in a repeating cycle: one encoder/decoder update, then
    ``disc_steps`` discriminator updates. The default of 4 reproduces the
    original "every 5th batch updates the generator" schedule.

    Label convention (unchanged from before): cover images are labelled 0
    ("real") and stego images 1 ("fake").

    Args:
        img_dataloader: iterable yielding ``(cover, secret)`` image batches.
        enc: encoder module, called as ``enc(secret, cover)``.
        dec: decoder module, called as ``dec(stego)``.
        disc: discriminator returning probabilities shaped like ``(batch, 1)``.
        encDecOptim: optimizer for the encoder (and decoder) parameters.
        discOptim: optimizer for the discriminator parameters.
        encDecCriterion: callable ``(cover, secret, stego, stego_decoded) -> scalar``.
        criterion: probability loss for the adversarial terms (e.g. ``nn.BCELoss``).
        num_epochs: number of passes over ``img_dataloader``.
        mix_coeff: weight of the adversarial term in the encoder/decoder loss.
        disc_steps: discriminator updates per encoder/decoder update.

    Returns:
        dict with lists ``"enc_dec"`` and ``"disc"`` of per-step loss values.
    """
    # put in training mode
    enc.train()
    dec.train()
    disc.train()

    history = {"enc_dec": [], "disc": []}
    cycle = disc_steps + 1

    for _epoch in range(num_epochs):
        for i, (cover, secret) in enumerate(img_dataloader):
            cover = cover.to(device)
            secret = secret.to(device)
            # Size labels from the actual batch: the final batch of an epoch can
            # be smaller than batch_size.
            n = cover.size(0)
            real = torch.zeros((n, 1), device=device)
            fake = torch.ones((n, 1), device=device)

            if i % cycle == 0:
                # update encoder/decoder
                encDecOptim.zero_grad()
                stego = enc(secret, cover)
                stego_decoded = dec(stego)
                # Adversarial term: the encoder wants its output classified as a
                # genuine cover. Adding raw probabilities (as before) gave a
                # non-scalar loss, and backward() failed.
                adv_loss = criterion(disc(stego), real)
                loss = encDecCriterion(cover, secret, stego, stego_decoded) + mix_coeff * adv_loss
                loss.backward()
                encDecOptim.step()
                history["enc_dec"].append(loss.item())
            else:
                # update discriminator
                discOptim.zero_grad()
                # The encoder is not optimised in this step; don't build its graph.
                with torch.no_grad():
                    stego = enc(secret, cover)
                real_loss = criterion(disc(cover), real)
                fake_loss = criterion(disc(stego), fake)
                loss = real_loss + fake_loss
                loss.backward()
                discOptim.step()
                history["disc"].append(loss.item())

    return history

In [None]:
# Dataset of (cover, secret) pairs. A DataLoader batches these into
# (cover_batch, secret_batch), which is the format train() unpacks.


class Dataset(torch.utils.data.Dataset):
    """Every ordered (cover, secret) pair drawn from a list of images.

    Pairs include i == j, so an image can be hidden in itself.

    Args:
        images: indexable sequence of images (e.g. PIL images).
        transform: optional callable applied to each image before it is returned.
    """

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform
        self.pairs = []
        self.createPairs()

    def createPairs(self):
        """Append ``(cover_idx, secret_idx)`` for every ordered index pair.

        Indices, not images, are stored: ``__getitem__`` uses them to index
        ``self.images``. Storing images made that lookup raise TypeError.
        """
        n = len(self.images)
        self.pairs.extend((i, j) for i in range(n) for j in range(n))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        cover_idx, secret_idx = self.pairs[idx]
        cover = self.images[cover_idx]
        secret = self.images[secret_idx]
        if self.transform is not None:
            # Transforms return new objects; the result must be kept.
            cover = self.transform(cover)
            secret = self.transform(secret)
        return cover, secret

In [None]:
import os

import torchvision.transforms as transforms
from PIL import Image

# Preprocessing:
# - Resize the shorter side, then centre-crop, so every tensor is exactly
#   image_size x image_size. The models' Linear layers assume 64x64.
# - Normalize pixels to [-1, 1], the range of the encoder/decoder Tanh outputs.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

img_dir = "path/to/directory"


def _load_rgb(path):
    """Load an image fully into memory as RGB (nc == 3) and close its file handle."""
    with Image.open(path) as im:
        return im.convert("RGB")


# Sorted so the pair ordering is deterministic across runs; skip sub-directories.
image_paths = [
    os.path.join(img_dir, name)
    for name in sorted(os.listdir(img_dir))
    if os.path.isfile(os.path.join(img_dir, name))
]
images = [_load_rgb(p) for p in image_paths]

train_dataloader = torch.utils.data.DataLoader(
    Dataset(images, transform=transform),
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

enc = StegEncoder().to(device)
dec = StegDecoder().to(device)
disc = Discriminator().to(device)

enc.apply(weights_init)
dec.apply(weights_init)
disc.apply(weights_init)

# The encoder and decoder are trained jointly through encDecCriterion, so both
# parameter sets belong in this optimizer. Previously only enc's were included
# and the decoder never learned.
encDecOptim = torch.optim.Adam(
    list(enc.parameters()) + list(dec.parameters()), lr=lr, betas=(beta1, 0.999)
)
discOptim = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta1, 0.999))

train(train_dataloader, enc, dec, disc, encDecOptim, discOptim, encDecCriterion, criterion, num_epochs, 0.5)

# save the model
torch.save(enc.state_dict(), "encoder.pth")
torch.save(dec.state_dict(), "decoder.pth")
torch.save(disc.state_dict(), "discriminator.pth")