In [None]:
from PIL import Image
from PIL import ExifTags

# Load the photo and its depth map (both are expected next to the notebook).
img = Image.open("sample.jpg")
depth = Image.open("sample-depth.jpg")

# focal length
# `_getexif()` returns None when the file carries no EXIF block; fall back to
# an empty mapping so the failure below is a clear message, not an
# AttributeError on None.
raw_exif = img._getexif() or {}
# Translate numeric EXIF tag ids to their names, skipping ids Pillow doesn't know.
exif = {ExifTags.TAGS[k]: v for k, v in raw_exif.items() if k in ExifTags.TAGS}
if "FocalLength" not in exif:
    raise ValueError('sample.jpg has no EXIF FocalLength tag; cannot derive the original focal length')
old_focal_mm = exif["FocalLength"]

# The simulated lens must be strictly longer than the original one.
new_focal_mm_multiplier = 3.
if new_focal_mm_multiplier <= 1.:
    raise ValueError('new_focal_mm_multiplier cannot be <= 1')
new_focal_mm = float(new_focal_mm_multiplier * old_focal_mm)
print(f'{old_focal_mm} -> {new_focal_mm}')

# resize
# Bring the photo to the depth map's resolution so both can be indexed pixel-for-pixel.
img = img.resize(size=depth.size)

In [None]:
# Display the photo after it has been resized to the depth map's dimensions.
img

In [None]:
# Display the depth map image loaded from "sample-depth.jpg".
depth

In [None]:
import numpy as np

# Photo as an (H, W, 4) RGBA array.
img_arr = np.asarray(img.convert("RGBA"))
# Depth as a 2-D array taken from the first band. `getchannel(0)` works for
# both a single-band ("L") and a multi-band depth map, whereas indexing
# `[:, :, 0]` fails on a single-band image.
# any channel is the same since it's black and white
depth_arr = np.asarray(depth.getchannel(0))
img_arr.shape, depth_arr.shape

In [None]:
import matplotlib.pyplot as plt

# Histogram of the raw depth-map values, drawn through the explicit axes API.
fig, ax = plt.subplots(figsize=(4, 3))
ax.hist(depth_arr.ravel(), density=True)
ax.set_xlabel('depth')
ax.set_title('Distribution of depths')

# the depth map is expected to span the full 8-bit range, i.e. range(0, 256)
assert depth_arr.min() == 0 and 255 == depth_arr.max()

In [None]:
# Flatten the alpha plane before plotting: a 2-D input makes `plt.hist` treat
# every column as its own dataset instead of producing one histogram.
plt.hist(img_arr[:, :, 3].reshape(-1), density=True)
plt.xlabel('alpha')
plt.title('Distribution of transparency')
# The source photo must be fully opaque; later cells use alpha == 0 to mean
# "not painted yet".
assert 255 == img_arr[:, :, 3].min()

In [None]:
# create depth for effect
# d > old focal length and d > new focal length
# decrease depth_min_multiplier to exaggerate close-up effect
# Physical depth range (in mm) assigned to the depth map.
# Both bounds are multiples of the longer focal length, so every depth stays
# beyond the old and the new focal length.
# Lower depth_min_multiplier to exaggerate the close-up effect.
depth_min_multiplier = 2.
depth_max_multiplier = 10.
longest_focal_mm = max(new_focal_mm, old_focal_mm)
depth_min = depth_min_multiplier * longest_focal_mm
depth_max = depth_max_multiplier * longest_focal_mm
depth_min, depth_max

In [None]:
# Linearly map the 8-bit depth values onto [depth_min, depth_max].
# The values top out at 255, so normalise by 255 (not 256); otherwise the
# largest depth value would never reach depth_max.
depth_arr_transformed = depth_min + (depth_arr / 255) * (depth_max - depth_min)
depth_arr_transformed.min(), depth_arr_transformed.max()

In [None]:
def get_scaling(old_f_in_mm, new_f_in_mm, depth_in_mm) -> float:
    """Return the image-size ratio when switching focal length at a given depth.

    For each focal length ``f`` the term ``f / (depth - f)`` is evaluated (this
    has the form of the thin-lens magnification); the result is the new term
    divided by the old term.

    Args:
        old_f_in_mm: Original focal length.
        new_f_in_mm: Target focal length.
        depth_in_mm: Subject distance; must be strictly greater than both
            focal lengths.

    Returns:
        The ratio ``(new_f / (depth - new_f)) / (old_f / (depth - old_f))``.

    Raises:
        ValueError: If ``depth_in_mm`` is not strictly greater than both focal
            lengths.
    """
    # The depth has to clear the LARGER focal length. Between the two focal
    # lengths one term turns negative, and at equality it divides by zero.
    if depth_in_mm <= max(new_f_in_mm, old_f_in_mm):
        raise ValueError('Depth cannot be smaller than new_f_in_mm or old_f_in_mm')
    res = (new_f_in_mm / (depth_in_mm - new_f_in_mm))
    res = res / (old_f_in_mm / (depth_in_mm - old_f_in_mm))
    return res


# Sanity checks: a longer focal length enlarges the image, a shorter one shrinks it.
assert 1 < get_scaling(2, 3, 5)
assert 1 > get_scaling(3, 2, 5)
# At very large distances the size change approaches the plain focal-length ratio (i.e. zooming).
assert abs(get_scaling(3, 2, 1e9) - 2 / 3) < 1e-3


In [None]:
from tqdm import tqdm

size_x, size_y = img.size

# Output canvas in RGBA format; alpha == 0 marks pixels no layer has painted yet.
new_img_arr = np.zeros(shape=img_arr.shape, dtype=np.uint8)

# Paint layers back to front so nearer layers overwrite farther ones.
# `np.unique` returns ascending values, so reversing it yields farthest first.
for d in tqdm(np.unique(depth_arr_transformed)[::-1]):

    # this modified scaling factor preserves unity at infinity
    scaling = get_scaling(old_f_in_mm=old_focal_mm, new_f_in_mm=new_focal_mm, depth_in_mm=d)
    scaling = scaling / (new_focal_mm / old_focal_mm)
    new_size_x, new_size_y = int(size_x * scaling), int(size_y * scaling)

    # Isolate the pixels at this depth on an otherwise transparent layer.
    mask_arr = depth_arr_transformed == d
    img_arr_d = np.zeros(shape=img_arr.shape, dtype=np.uint8)
    img_arr_d[mask_arr] = img_arr[mask_arr]

    # Scale the layer about the image centre. `crop` expects
    # (left, upper, right, lower) COORDINATES, so the right/lower edges are
    # the offsets plus the output size. That makes the crop exactly
    # (size_x, size_y), so no further resize is needed.
    left = (new_size_x - size_x) // 2
    upper = (new_size_y - size_y) // 2
    img_d = (
        Image
        .fromarray(img_arr_d)
        .resize((new_size_x, new_size_y))
        .crop(box=(left, upper, left + size_x, upper + size_y))
    )
    img_arr_d = np.asarray(img_d)

    # Composite whole pixels wherever the layer has any opacity. A per-channel
    # `> 0` test would let zero-valued colour channels from farther layers
    # show through.
    visible = img_arr_d[:, :, 3] > 0
    new_img_arr[visible] = img_arr_d[visible]


In [None]:
# Pixels whose alpha is still 0 were never covered by any rescaled layer.
# Count them directly rather than building a list of coordinates.
num_transparent_pixels = int(np.count_nonzero(new_img_arr[:, :, 3] == 0))

f'Need to fill {num_transparent_pixels} ({100*num_transparent_pixels/depth_arr.size:.2f}%) pixels'

In [None]:
# Display the composited RGBA result; pixels that no layer covered keep alpha == 0.
Image.fromarray(new_img_arr)

In [None]:
# Build the RGB version of the composite (alpha dropped) and display it.
rgba_composite = Image.fromarray(new_img_arr)
new_img = rgba_composite.convert("RGB")
new_img

In [None]:
# Mask image: 255 (white) wherever the composite is still transparent, 0
# (black) everywhere else. It is built as a single-band image and then
# expanded to RGB.
new_img_mask = Image.fromarray(
    np.where(new_img_arr[:, :, 3] == 0, 255, 0).astype(np.uint8)
).convert("RGB")
new_img_mask

In [None]:
%pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

In [16]:
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionInpaintPipeline
import torch

# Pick the best available torch device instead of hard-coding Apple's Metal
# backend, so the cell also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the Stable Diffusion inpainting pipeline (needs a Hugging Face login, see above).
# NOTE(review): `use_auth_token` and `init_image` follow the diffusers API this
# notebook was written against; newer releases may expect different argument
# names. Confirm against the installed version.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    # torch_dtype=torch.float16,
    use_auth_token=True
).to(device)

# Fill the regions marked in `new_img_mask` and display the first generated image.
image = pipe(prompt='a photograph', init_image=new_img, mask_image=new_img_mask).images[0]
image