In [None]:
import torch
import os
from IPython.display import clear_output
import torch.nn.functional as F
import PIL.Image
from io import BytesIO
import IPython.display
import numpy as np
import skimage.io

In [None]:
# Filesystem locations: folder scanned for input images, and root folder for outputs.
img_dir = "/home/shehzeen/Datasets/sample_target_images"
out_dir = "/home/shehzeen/Datasets/FaceSignsRough"

# Checkpoint file names for the encoder and decoder models.
encoder_model_path = "encoder_model.pth"
decoder_model_path = "decoder_model.pth"

# Use the GPU when torch reports one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Initialize Models

In [None]:
# Load the TorchScript models from the configured checkpoint paths (instead of
# repeating the file names here). map_location remaps the stored tensors onto
# `device` at load time, so a checkpoint saved from a GPU session can still be
# opened on a machine without that GPU.
encoder_model = torch.jit.load(encoder_model_path, map_location=device)
decoder_model = torch.jit.load(decoder_model_path, map_location=device)

# The notebook only runs inference: switch both models to eval mode.
encoder_model.eval().to(device)
decoder_model.eval().to(device)

# Drop whatever the cell printed while loading.
clear_output()

### Utility functions

In [None]:
def showarray(a, fmt='png'):
    """
    Display an image array inline in the notebook.

    a:   numpy array of size h, w, 3 with values expected in the range 0 to 1
    fmt: image format name handed to PIL when encoding the picture

    Values are clipped to [0, 1] before the 8-bit conversion so that pixels
    slightly outside the range saturate instead of wrapping around.
    """
    a = np.uint8(np.clip(a, 0.0, 1.0) * 255.)
    f = BytesIO()
    PIL.Image.fromarray(a).save(f, fmt)
    IPython.display.display(IPython.display.Image(data=f.getvalue()))
    
def text_to_bits(text, encoding='utf-8', errors='surrogatepass'):
    """
    Encode text as a string of '0'/'1' characters.

    The text is encoded to bytes, read as one big-endian integer, written in
    binary and left-padded with zeros to a whole number of bytes.
    """
    raw_bytes = text.encode(encoding, errors)
    as_integer = int.from_bytes(raw_bytes, 'big')
    digits = format(as_integer, 'b')
    # Round the digit count up to the next multiple of 8.
    padded_width = ((len(digits) + 7) // 8) * 8
    return digits.rjust(padded_width, '0')

def text_from_bits(bits, encoding='utf-8', errors='surrogatepass'):
    """
    Decode a string of '0'/'1' characters back into text.

    The bits are parsed as one integer, converted to the minimal number of
    big-endian bytes and decoded. An empty decoding result is reported as a
    single NUL character.
    """
    as_integer = int(bits, 2)
    byte_count = (as_integer.bit_length() + 7) // 8
    decoded = as_integer.to_bytes(byte_count, 'big').decode(encoding, errors)
    if decoded:
        return decoded
    return '\0'

def load_images(image_filepaths, img_size=256):
    """
    Load image files into a single float tensor batch.

    Every file is read with skimage and divided by 255 (so 8-bit images land
    in [0, 1]), center-cropped to a square and bilinearly resized to
    img_size x img_size. Cropping and resizing are done per image, so the
    files do not need to share a resolution. A 2-D (grayscale) image is
    repeated over three channels and a fourth (alpha) channel is dropped, so
    these inputs also produce three channels.

    image_filepaths: iterable of image file paths
    img_size:        side length of the square output images

    Returns a float32 CPU tensor laid out as (batch, channels, img_size, img_size).
    Raises ValueError if image_filepaths yields no paths.
    """
    image_tensors = []
    for file_path in image_filepaths:
        image_np = skimage.io.imread(file_path) / 255.0

        if image_np.ndim == 2:
            # Grayscale: replicate the single plane into three channels.
            image_np = np.stack([image_np] * 3, axis=-1)
        elif image_np.ndim == 3 and image_np.shape[2] == 4:
            # RGBA: keep the colour channels, drop alpha.
            image_np = image_np[:, :, :3]

        # (h, w, c) -> (1, c, h, w) so F.interpolate can be applied.
        image = torch.from_numpy(image_np).float().permute(2, 0, 1).unsqueeze(0)

        # Center-crop the longer side so the image becomes square.
        h, w = image.shape[2:]
        if h > w:
            top = (h - w) // 2
            image = image[:, :, top:top + w, :]
        elif w > h:
            left = (w - h) // 2
            image = image[:, :, :, left:left + h]

        image = F.interpolate(image, size=(img_size, img_size), mode='bilinear', align_corners=True)
        image_tensors.append(image)

    if not image_tensors:
        raise ValueError("load_images needs at least one image path")

    return torch.cat(image_tensors, dim=0)

def save_images(image_batch, out_dir, prefix=""):
    """
    Write every image of a tensor batch to out_dir as a PNG file.

    image_batch: tensor laid out as (batch, channels, h, w) with values
                 expected in [0, 1]
    out_dir:     destination folder, created if it does not exist
    prefix:      file name prefix; files are named "<prefix>_<index>.png"

    Values are clipped to [0, 1] before the 8-bit conversion so that pixels
    slightly outside the range saturate instead of wrapping around.

    Returns the list of written file paths, in batch order.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)

    image_paths = []
    for img_idx in range(image_batch.shape[0]):
        # detach() lets this also work on tensors that still track gradients.
        image_np = image_batch[img_idx].detach().permute(1, 2, 0).cpu().numpy()
        image_np = np.uint8(np.clip(image_np, 0.0, 1.0) * 255.)
        file_path = os.path.join(out_dir, "{}_{}.png".format(prefix, img_idx))
        PIL.Image.fromarray(image_np).save(file_path)
        image_paths.append(file_path)

    return image_paths

def find_image_paths(image_dir):
    """
    List the .png / .jpg files directly inside image_dir.

    The extension check is case-insensitive and the result is sorted by file
    name: os.listdir returns entries in arbitrary order, so without sorting
    any "first N images" selection would differ between runs and machines.

    Returns a list of paths, each joined with image_dir.
    """
    image_extensions = (".png", ".jpg")
    return [
        os.path.join(image_dir, img_file)
        for img_file in sorted(os.listdir(image_dir))
        if img_file.lower().endswith(image_extensions)
    ]

def decode_images(image_paths, secret_numpy, decoder_model, secret_num_bits=None):
    """
    Decode the secret from each image and compare it with the expected secret.

    image_paths:     paths of the images to decode (loaded with load_images)
    secret_numpy:    expected secret bits as a numpy array of shape (1, secret_size)
    decoder_model:   model called as decoder_model(image_batch); the first of
                     its two outputs is treated as per-bit logits
    secret_num_bits: how many leading bits of the prediction carry the text.
                     When omitted, the notebook-level `secrete_num_bits`
                     variable is used, as in earlier versions of this function.

    Returns (secret_accuracy, decoding_results): the mean bit accuracy over
    the whole batch, and one dict per image with the keys "image_path",
    "image_predicted_secret_text" and "bit_accuracy".
    """
    if secret_num_bits is None:
        # Backward-compatible fallback; requires the signing cell to have run.
        secret_num_bits = secrete_num_bits

    image_batch = load_images(image_paths)
    with torch.no_grad():
        image_batch = image_batch.to(device)
        decoded_secrets, _ = decoder_model(image_batch)
        # torch.sigmoid replaces F.sigmoid, which older PyTorch releases flag as deprecated.
        predicted_secrets = (torch.sigmoid(decoded_secrets) > 0.5).long()

    secrets = torch.from_numpy(secret_numpy).repeat(predicted_secrets.shape[0], 1).to(device)
    secret_accuracy = (predicted_secrets == secrets).float().mean().item()

    decoding_results = []
    for img_idx, image_path in enumerate(image_paths):
        image_predicted_secret = predicted_secrets[img_idx].cpu().numpy().tolist()
        image_predicted_secret_bits = "".join(str(b) for b in image_predicted_secret[:secret_num_bits])
        try:
            image_predicted_secret_text = text_from_bits(image_predicted_secret_bits)
        except Exception:
            # Wrongly predicted bits may not form valid text; report that
            # instead of failing, but let KeyboardInterrupt and friends through.
            image_predicted_secret_text = "could not decode"

        decoding_results.append({
            "image_path": image_path,
            "image_predicted_secret_text": image_predicted_secret_text,
            "bit_accuracy": (predicted_secrets[img_idx] == secrets[img_idx]).float().mean().item()
        })

    return secret_accuracy, decoding_results

## Sign images with a secret and visualize the encoded images

In [None]:
# Text to embed and the fixed number of bits the secret vector holds.
secret_text = "sample"
secret_size = 128

# Bit string of the text, and how many of the secret's bits it occupies.
secret_bits = text_to_bits(secret_text)
secrete_num_bits = len(secret_bits)

assert secrete_num_bits <= secret_size

# Right-pad with zeros up to secret_size, then turn into a (1, secret_size) integer array.
secret_bits = secret_bits.ljust(secret_size, "0")
secret_numpy = np.array([list(map(int, secret_bits))])




# Sign a small sample: the first three images found in img_dir.
original_image_paths = find_image_paths(img_dir)
original_image_paths = original_image_paths[:3]
images = load_images(original_image_paths)

images = images.to(device)
# One copy of the secret per image in the batch.
secrets = torch.from_numpy(secret_numpy).repeat(images.shape[0], 1).to(device)

with torch.no_grad():
    encoded_images, secret_images = encoder_model(images, secrets)

signed_image_dir = os.path.join(out_dir, "signed_images")
encoded_image_paths = save_images(encoded_images, signed_image_dir)

# Show original | signed | residual side by side for every signed image.
for sidx in range(len(encoded_image_paths)):
    original_image_numpy = images[sidx].permute(1, 2, 0).cpu().numpy()
    encoded_image_numpy = encoded_images[sidx].permute(1, 2, 0).cpu().numpy()
    residual = (encoded_image_numpy - original_image_numpy)
    rmin, rmax = np.min(residual), np.max(residual)
    residual_range = rmax - rmin
    if residual_range > 0:
        # Stretch the residual to [0, 1] so small differences become visible.
        residual_scaled = (residual - rmin) / residual_range
    else:
        # Constant residual: avoid dividing by zero and show a black panel.
        residual_scaled = np.zeros_like(residual)
    original_encoded_image = np.concatenate((original_image_numpy, encoded_image_numpy, residual_scaled), axis=1)
    showarray(original_encoded_image)

## Apply benign transformations on signed images and decode

In [None]:
import pilgram

def apply_benign_transforms(image_filepaths, out_dir):
    """
    Write filtered and JPEG-compressed copies of every image into out_dir.

    For each input image, one PNG is written per pilgram filter and one JPEG
    per quality in jpeg_qualities. Output files are named
    "<input file name>_<transform name>.png" / ".jpeg".

    image_filepaths: iterable of image file paths
    out_dir:         destination folder, created if it does not exist

    Returns (transform_list, transformed_image_filepaths): the transform
    names in order ("None" first, then the filters, then "JPEG-<quality>"),
    and a dict mapping each name to the list of resulting file paths. The
    "None" entry holds the untransformed input paths.
    """
    os.makedirs(out_dir, exist_ok=True)

    filters = [
        ('aden', pilgram.aden),
        ('brooklyn', pilgram.brooklyn),
        ('clarendon', pilgram.clarendon),
        ('toaster', pilgram.toaster),
        ('nashville', pilgram.nashville),
    ]
    # Single source of truth for the JPEG qualities: used both for the result
    # keys and for the files that are actually written.
    jpeg_qualities = [75, 50]

    # Materialise once so a generator argument is not exhausted, and so the
    # returned dict does not alias the caller's list.
    image_filepaths = list(image_filepaths)

    transformed_image_filepaths = {"None": list(image_filepaths)}
    transform_list = ["None"]
    for filter_name, _ in filters:
        transformed_image_filepaths[filter_name] = []
        transform_list.append(filter_name)
    for jpeg_quality in jpeg_qualities:
        jpeg_key = "JPEG-{}".format(jpeg_quality)
        transformed_image_filepaths[jpeg_key] = []
        transform_list.append(jpeg_key)

    for fp in image_filepaths:
        original_filename = os.path.basename(fp)
        # Read each file once; every transform below starts from this array.
        image_np = skimage.io.imread(fp)
        img = PIL.Image.fromarray(image_np)

        for filter_name, filter_fn in filters:
            filtered_image = filter_fn(img)
            filtered_path = os.path.join(out_dir, "{}_{}.png".format(original_filename, filter_name))
            filtered_image.save(filtered_path)
            transformed_image_filepaths[filter_name].append(filtered_path)

        jpeg_source = PIL.Image.fromarray(np.uint8(image_np))
        if jpeg_source.mode not in ("RGB", "L", "CMYK"):
            # JPEG cannot store e.g. an alpha channel; convert before saving.
            jpeg_source = jpeg_source.convert("RGB")
        for jpeg_quality in jpeg_qualities:
            jpeg_key = "JPEG-{}".format(jpeg_quality)
            jpeg_path = os.path.join(out_dir, "{}_{}.jpeg".format(original_filename, jpeg_key))
            jpeg_source.save(jpeg_path, "JPEG", quality=jpeg_quality)
            transformed_image_filepaths[jpeg_key].append(jpeg_path)

    return transform_list, transformed_image_filepaths

In [None]:
# Create the transformed copies of the signed images.
benign_tranform_list, benign_transformed_image_filepaths = apply_benign_transforms(
    encoded_image_paths,
    os.path.join(out_dir, "benign_transformed_images"),
)

# Decode every transform's images; show the first image of each as a sample.
for key in benign_tranform_list:
    secret_accuracy, decoding_results = decode_images(
        benign_transformed_image_filepaths[key], secret_numpy, decoder_model
    )
    for row in decoding_results[:1]:
        IPython.display.display(IPython.display.Image(row['image_path']))
        report_lines = [
            "Transform : {}".format(key),
            "Predicted secret: {}".format(row['image_predicted_secret_text']),
            "Bit accuracy: {}".format(secret_accuracy),
            "Image path: {}".format(row['image_path']),
        ]
        print("\n".join(report_lines))

    print("-----------------------------------------------------\n")


### Apply malicious (face-swap) transform on signed images and decode

In [None]:
import face_swap

def apply_malicious_transforms(signed_image_paths, target_image_paths, out_dir):
    """
    Face-swap every (target image, signed image) pair and collect the outputs.

    signed_image_paths: paths of the signed images
    target_image_paths: paths of the target images
    out_dir:            destination folder, created if it does not exist

    Each pair is passed to face_swap.swap_faces(target, signed, output_path),
    where the output file is named
    "faseswap_<target base name>_<signed base name>.jpg". A pair whose swap
    raises is reported and skipped, so an output file left over from an
    earlier run is never counted as a success.

    Returns (["face_swap"], {"face_swap": [output paths]}).
    Raises AssertionError if not a single swap produced an output file.
    """
    os.makedirs(out_dir, exist_ok=True)

    transformed_image_filepaths = {
        "face_swap" : []
    }
    for target_image_path in target_image_paths:
        # splitext keeps dots inside the base name ("a.b.png" -> "a.b").
        pic_a_name = os.path.splitext(os.path.basename(target_image_path))[0]
        for signed_image_path in signed_image_paths:
            pic_b_name = os.path.splitext(os.path.basename(signed_image_path))[0]

            output_file_name_fs = "faseswap_{}_{}.jpg".format(pic_a_name, pic_b_name)
            output_file_name_fs = os.path.join(out_dir, output_file_name_fs)
            try:
                face_swap.swap_faces(target_image_path, signed_image_path, output_file_name_fs)
            except Exception as exc:
                # Best effort: one failing pair must not abort the others,
                # but say what went wrong instead of hiding it.
                print("Error in shallowfakes: {}".format(exc))
                continue

            if os.path.exists(output_file_name_fs):
                transformed_image_filepaths["face_swap"].append(output_file_name_fs)
                print ("face swap success")

    assert len(transformed_image_filepaths['face_swap']) >= 1, "no face swap produced an output image"

    return ["face_swap"], transformed_image_filepaths


In [None]:
# Face-swap the signed images onto the original images, then decode the results.
mal_tranform_list, mal_transformed_image_filepaths = apply_malicious_transforms(
    encoded_image_paths,
    original_image_paths,
    os.path.join(out_dir, "mal_transformed_images"),
)

for key in mal_tranform_list:
    secret_accuracy, decoding_results = decode_images(mal_transformed_image_filepaths[key], secret_numpy, decoder_model)
    for row in decoding_results:
        print("Transform : {}".format(key))
        print("Predicted secret: {}".format(row['image_predicted_secret_text']))
        # Report this image's own accuracy; the batch mean is printed once below.
        print("Bit accuracy: {}".format(row['bit_accuracy']))
        IPython.display.display(IPython.display.Image(row['image_path']))
        print ("-"*100)
    print("Mean bit accuracy ({}): {}".format(key, secret_accuracy))
