In [None]:
'''
runs stable diffusion via diffusers, creates a series of frames, with simplex noise distortions, and saves them as a video
from original script by Dr47. Apr/May 2024
'''
#import libraries
import argparse
import os
import shutil
import subprocess
import sys      # provides access to some vars used or maintained by Python interpreter, and to functions that interact with it.
import time

import imageio
import numpy as np
import torch
from diffusers import AutoPipelineForText2Image
from diffusers import StableDiffusionImg2ImgPipeline
from opensimplex import OpenSimplex # for noise generation
from PIL import Image, ImageChops
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp # Enable memory efficient attention from xFormers.(not working atm)

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)


In [None]:
#define model and pipeline

# Optimisation trick #1 - use TensorFloat-32.
# PyTorch enables TF32 mode for convolutions, but NOT for matmul. Enabling it can
# significantly (eg: 24% on RTX3070) speed up computations with minimal loss in
# numerical accuracy.
torch.backends.cuda.matmul.allow_tf32 = True

# Placeholder for the global settings object read by fn_generate; it is replaced
# by an Args() instance in a later cell.
args = None

# Path to the SD1.5 checkpoint. Set the SD15_MODEL environment variable to point
# at a checkpoint elsewhere; otherwise the original local default path is used.
SD15_MODEL = os.environ.get(
    "SD15_MODEL",
    "D:/SD_CKPTS_used_by_Auto1111/realisticVisionV51_v51VAE.safetensors",
)
LOOP_COUNT = 4 # Number of times to loop the final animation (results in loop count +1)

# Define pipeline (half-precision weights on CUDA for faster execution)
def get_pipeline():
    '''Load the two Stable Diffusion pipelines used by this notebook.

    The img2img pipeline is loaded from the single-file checkpoint at SD15_MODEL;
    the text-to-image pipeline is loaded from "runwayml/stable-diffusion-v1-5".

    Both are moved to CUDA with float16 weights when CUDA is available. Without
    CUDA they are placed on the CPU with float32 weights instead of failing on
    a hard-coded "cuda" device.

    Returns:
        tuple: (img2img_pipeline, text2img_pipeline), in that order.
    '''
    use_cuda = torch.cuda.is_available()
    target_device = "cuda" if use_cuda else "cpu"
    weights_dtype = torch.float16 if use_cuda else torch.float32

    img2img = StableDiffusionImg2ImgPipeline.from_single_file(
        SD15_MODEL, torch_dtype=weights_dtype
    ).to(target_device)

    text2img = AutoPipelineForText2Image.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=weights_dtype,
        variant="fp16",
        use_safetensors=True,
        safety_checker=None,
    ).to(target_device)

    return img2img, text2img

In [None]:
# function to Generate the image

def fn_generate(input_path, prompt, pipeline):
    '''Run one diffusion pass and return the first generated image.

    Sampler settings (strength, num_inference_steps, guidance_scale, seed) are
    read from the module-level `args` object.

    Args:
        input_path: path of the image to start from, or a falsy value to pass
            no initial image to the pipeline.
        prompt: text prompt forwarded to the pipeline.
        pipeline: callable pipeline; its result must expose `.images`.

    Returns:
        The first entry of the pipeline result's `images`.
    '''
    # Load the input image only when a path was supplied
    init_image = Image.open(input_path).convert("RGB") if input_path else None

    # Compare against None so that a seed of 0 is honoured rather than being
    # treated as "no seed".
    generator = None
    if args.seed is not None:
        generator_device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=generator_device).manual_seed(args.seed)

    # Generate the output image
    result = pipeline(
        prompt=prompt,
        image=init_image,
        strength=args.strength,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        generator=generator,
    )
    return result.images[0]

In [None]:
def fn_process_image(im, distort_intensity=9, distort_scale=60.0, seed=123):
    '''Return a copy of `im` warped by a smooth OpenSimplex noise field.

    Every output pixel (i, j) is copied from a source pixel displaced by the
    noise value at (i, j). The same displacement is applied along both axes,
    and the mean displacement is subtracted so the picture stays (approximately)
    in its original position. Source coordinates wrap around the image edges.

    Args:
        im: image (or array) convertible with np.array; the first two
            dimensions are treated as height and width.
        distort_intensity: maximum displacement in pixels before re-centring.
        distort_scale: noise wavelength divisor; higher is smoother.
        seed: seed for the OpenSimplex noise generator.

    Returns:
        A new image built with Image.fromarray from the distorted pixel data.
    '''
    np_im = np.array(im)
    h, w = np_im.shape[:2]  # any trailing channel dimension is carried along

    # Create a noise generator
    gen = OpenSimplex(seed=seed)

    # Evaluate the noise for the whole grid in one call instead of per pixel.
    # noise2array(x, y) returns an array of shape (len(y), len(x)), so transpose
    # to get noise[i, j] == noise2(i / distort_scale, j / distort_scale).
    row_coords = np.arange(h) / distort_scale
    col_coords = np.arange(w) / distort_scale
    noise = np.ascontiguousarray(gen.noise2array(row_coords, col_coords).T)

    # Map the noise from [-1, 1] to [0, 1], then scale to pixel displacements
    distortion = (noise + 1) / 2 * distort_intensity
    mean_distortion = np.mean(distortion)

    # Displaced source coordinates, wrapped around the image edges. The trailing
    # integer modulo folds an index that float rounding pushed to exactly h (or w)
    # back to 0, which would otherwise be out of range.
    rows = np.arange(h)[:, None]
    cols = np.arange(w)[None, :]
    src_rows = np.mod(rows + distortion - mean_distortion, h).astype(np.intp) % h
    src_cols = np.mod(cols + distortion - mean_distortion, w).astype(np.intp) % w

    # Gather the pixel data in one indexing operation
    distorted_data = np_im[src_rows, src_cols]

    # Create a new image from distorted data
    return Image.fromarray(distorted_data)


In [None]:
def fn_check_pipeline():
    '''Smoke-test that both pipelines returned by get_pipeline() can be loaded.

    get_pipeline() returns a tuple, so each element is unpacked and checked on
    its own; testing the tuple itself against None can never fail. The loaded
    pipelines are not returned, so they are discarded once this function ends.

    Raises:
        ValueError: if either pipeline is None.
    '''
    img2img_pipeline, text_pipeline = get_pipeline()

    for label, pipe in (("img2img", img2img_pipeline), ("text2img", text_pipeline)):
        # print only the first 100 characters of the pipeline description
        print(f"{label}_pipeline: {str(pipe)[:100]}")
        if pipe is None:
            raise ValueError(f"pipeline for {label} has issues")
    return

fn_check_pipeline()     #test the pipeline

In [None]:
# Run settings, kept on a plain class that stands in for parsed command-line arguments
class Args:
    '''Container for the notebook's run settings (edit the values below).'''

    # --- what to generate ---
    prompt = "a box of rocks"  # replace with your value
    init_image = None
    model = "SD15"
    seed = None

    # --- sampler settings ---
    strength = 0.5 # was 0.65
    num_inference_steps = 15
    guidance_scale = 7

    # --- frames and output ---
    output_path = "output-frames"
    count = 50  # number of frames to generate
    skip_frame_generation = False
    gif = True
    no_interpolate = False # disable motion interpolation for final ffmpeg post-processing


args = Args()

In [None]:
def replace_black_with_random_color(pixel):
    '''Return a random RGB colour for a pure-black pixel, else the pixel itself.

    Args:
        pixel: a 3-channel pixel as a tuple, list or numpy array.

    Returns:
        A tuple of three ints in [0, 255] when `pixel` equals (0, 0, 0);
        otherwise `pixel` unchanged.
    '''
    # Compare channel values; works for tuples, lists and numpy arrays alike.
    if np.array_equal(pixel, (0, 0, 0)):
        # randint's upper bound is exclusive, so 256 is needed to allow 255
        return tuple(int(channel) for channel in np.random.randint(0, 256, size=3))
    return pixel

In [None]:
# main function to generate the images

if not args.skip_frame_generation:
    img2img_pipeline, text_pipeline = get_pipeline()
    print('pipelines have been gotten!')
    os.makedirs(args.output_path, exist_ok=True)

    # Frame 0 is always written here, whichever way it is produced.
    first_frame_path = os.path.join(args.output_path, f'frame_{"0".zfill(4)}.png')

    if args.init_image:
        # copy the initial image (resized to 512x512) to the output path as frame 0
        init_image = Image.open(args.init_image).convert("RGB")
        init_image = init_image.resize((512,512))
        init_image.save(first_frame_path)

    else:
        # generate first image from prompt
        if torch.cuda.is_available():
            text_pipeline.enable_model_cpu_offload() # this MASSIVELY speeds up inference
        # NOTE: enabling xformers memory-efficient attention
        # (attention_op=MemoryEfficientAttentionFlashAttentionOp) did not work here.
        outimage = fn_generate(None, args.prompt, text_pipeline)
        outimage.save(first_frame_path)

    # Always continue from the saved frame 0, so the first img2img pass sees the
    # resized 512x512 image rather than the untouched original file.
    src_image = first_frame_path

    # Generate the images
    print ('starting the clock!...')
    tick = time.time_ns()

    # Frames 1..count: together with frame 0 that is count + 1 frames, exactly
    # the range frame_0000..frame_<count> that the post-processing cell reads.
    for frame_num in range(1, args.count + 1):
        frame_id = str(frame_num).zfill(4)
        output_file_path_and_name = os.path.join(args.output_path, f"frame_{frame_id}.png")

        print(f"Generating image for {src_image} to {output_file_path_and_name}...")
        outimage = fn_generate(src_image, args.prompt, img2img_pipeline)

        #APPLY TRANSFORMS HERE
        #=====================
        outimage = fn_process_image(outimage)

        outimage.save(output_file_path_and_name)       # save the noise-distorted image
        src_image = output_file_path_and_name          # set the source image for the next iteration

    tock = time.time_ns()
    baseline = f"{(tock - tick) / 1e9:.1f}" # convert to seconds
    print(f"Execution time -- {baseline} seconds\n")

In [None]:
# post-processing

if args.gif:
    # maybe skip first frame, since it can be overpowering? (all frames used for now)
    frame_paths = [
        os.path.join(args.output_path, f"frame_{str(i).zfill(4)}.png")
        for i in range(0, args.count + 1)
    ]
    images = [imageio.imread(frame_path) for frame_path in frame_paths]

    imageio.mimsave("output.gif", images, duration=1.0)
    print("GIF saved to output.gif")
    imageio.mimsave("output.mp4", images, fps=2)
    print("MP4 saved to output.mp4")

    if not args.no_interpolate:
        if shutil.which("ffmpeg") is None:
            # Only skip this step: calling exit() here would end the whole
            # session instead of just skipping the interpolation.
            print("ffmpeg not found, skipping interpolation")
        else:
            # Each step is (progress message, ffmpeg argument list). Argument
            # lists are passed without a shell, so the prompt used in the final
            # file name cannot break quoting. "-y" overwrites outputs from an
            # earlier run instead of waiting for an interactive confirmation.
            ffmpeg_steps = [
                (
                    "Interpolating frames to 4fps...",
                    ["ffmpeg", "-y", "-i", "output.mp4",
                     "-vf", "minterpolate=mi_mode=mci:mc_mode=aobmc:vsbmc=1:fps=4",
                     "output_4fps.mp4"],
                ),
                (
                    "Interpolating frames to 8fps...",
                    ["ffmpeg", "-y", "-i", "output_4fps.mp4",
                     "-vf", "minterpolate=mi_mode=mci:mc_mode=aobmc:vsbmc=1:fps=8",
                     "output_8fps.mp4"],
                ),
                (
                    f"Looping final video a total of {LOOP_COUNT+1} times...",
                    ["ffmpeg", "-y", "-stream_loop", str(LOOP_COUNT), "-i", "output_8fps.mp4",
                     "-c", "copy", f"{args.prompt}.mp4"],
                ),
            ]

            for message, command in ffmpeg_steps:
                print(message)
                # Each step consumes the previous step's output, so stop at the
                # first failure instead of running ffmpeg on a missing file.
                if subprocess.run(command).returncode != 0:
                    print(f"ffmpeg step failed, stopping: {' '.join(command)}")
                    break
            else:
                print("Interpolation done. files saved.")