# Generate optical flow from an image

In [2]:
import os
# Debugging aid: makes CUDA kernel launches synchronous so errors surface at the
# failing call (slower). Set before the imports below so it is in place before
# torch initialises CUDA.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import datetime
from diffusers import UNet2DModel, DDPMScheduler
from models.motion_synthesis import VQModel_, generate_spectrum
# NOTE(review): star import presumably supplies `torch`, FrameSpectrumProcessing,
# get_image, save_npy and visualize_sample used below — prefer explicit imports.
from utils import *

# Choose the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

transforms = FrameSpectrumProcessing(num_freq=16)
# NOTE(review): scheduler settings presumably must match the ones used to train
# data/models/unet — confirm before changing.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.0015, beta_end=0.0195, beta_schedule="scaled_linear")
vae = VQModel_.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae").to(DEVICE).eval()
unet = UNet2DModel.from_pretrained("data/models/unet").to(DEVICE).eval()

out_dir = "data/unet_samples"
os.makedirs(out_dir, exist_ok=True)

# Environment diagnostics. CUDA-specific queries are guarded so the cell also
# runs on MPS/CPU (get_device_name raises when CUDA is unavailable).
print("Device:", DEVICE)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device name:", torch.cuda.get_device_name(0))
    print("Allocated:", round(torch.cuda.memory_allocated(0) / 1024**2, 2), "MB")
    print("Reserved:", round(torch.cuda.memory_reserved(0) / 1024**2, 2), "MB")
else:
    print("Device name:", "N/A")

  from .autonotebook import tqdm as notebook_tqdm
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


Quadro RTX 6000
CUDA available: True
Device name: Quadro RTX 6000
Allocated: 1280.12 MB
Reserved: 1312.0 MB


# Web server on port 5001

In [9]:
import os
import datetime
import numpy as np
from flask import Flask, request, render_template, send_from_directory, redirect, url_for
from werkzeug.utils import secure_filename

app = Flask(__name__)

# Uploaded source images and generated results (.npy/.png/.mp4). Both live under
# ./static so the browser can fetch the results directly.
UPLOAD_FOLDER = 'static/uploads'
OUT_FOLDER = 'static/results'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUT_FOLDER, exist_ok=True)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Uploads come from untrusted clients: reject request bodies larger than this
# with HTTP 413 instead of accepting arbitrarily large files.
# NOTE(review): 16 MiB is a conservative default for single photos — adjust if needed.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024

def _to_uint8_rgb(frame):
    """Convert a frame to an H x W x 3 uint8 numpy array for visualisation.

    Torch tensors are detached and moved to CPU. A 3-channel-first array
    (3, H, W) is transposed to (H, W, 3), and a 2-D array is replicated to
    three channels. Non-uint8 data is scaled by 255 when its maximum is <= 1.0
    (treated as [0, 1] intensities) and cast directly otherwise. In both cases
    it is clipped to [0, 255] first so out-of-range values saturate instead of
    wrapping around.
    """
    if isinstance(frame, torch.Tensor):
        frame = frame.detach().cpu().numpy()
    if frame.ndim == 3 and frame.shape[0] == 3:
        frame = np.transpose(frame, (1, 2, 0))
    if frame.ndim == 2:
        frame = np.stack([frame] * 3, axis=-1)
    if frame.dtype != np.uint8:
        scaled = frame * 255 if frame.max() <= 1.0 else frame
        frame = np.clip(scaled, 0, 255).astype(np.uint8)
    return frame


@app.route('/', methods=['GET', 'POST'])
def index():
    """Upload an image and show the generated motion spectrum and video.

    POST:
      - Expects a multipart form field ``image``.
      - The upload is saved (with a timestamp suffix) under UPLOAD_FOLDER.
      - It is loaded at 256x160 via ``get_image`` and encoded with ``transforms``.
      - The encoded frame is passed to ``generate_spectrum`` with batch size 1.
      - The spectrum is written to OUT_FOLDER as ``.npy``. The ``visualize_sample``
        output is written there as ``.png`` + ``.mp4``.
      - The view then redirects to ``/``, passing the result paths as the
        ``image``/``video`` query parameters.
      - Returns 400 if no usable file was uploaded.

    GET: renders ``index.html`` with the optional ``image``/``video`` query
    parameters (None when absent).
    """
    if request.method == 'POST':
        upload = request.files.get('image')
        # secure_filename() returns '' for empty or fully unsafe names; saving
        # to the bare upload directory would then fail with a 500.
        filename = secure_filename(upload.filename) if upload is not None and upload.filename else ''
        if not filename:
            return "No image file uploaded.", 400

        stem, ext = os.path.splitext(filename)
        timestamp = datetime.datetime.now().isoformat().replace(":", "_")

        # Timestamp the stored upload so two requests with the same client
        # filename cannot overwrite each other's input before it is read.
        path = os.path.join(app.config['UPLOAD_FOLDER'], f"{stem}_{timestamp}{ext}")
        upload.save(path)

        frame_np = get_image(path, width=256, height=160, crop=True)
        frame = transforms.process_frame(frame_np).unsqueeze(0).to(DEVICE)

        num_steps = 100  # number of DDPM denoising steps; also recorded in the output name
        sample_name = "_".join([stem, "ddpm" + str(num_steps), timestamp])

        spec_np = generate_spectrum(vae, unet, noise_scheduler, frame, num_steps=num_steps, batch_size=1)

        npy_path = os.path.join(OUT_FOLDER, sample_name + ".npy")
        save_npy(spec_np, npy_path)

        spec_image, video = visualize_sample(
            _to_uint8_rgb(frame_np), spec_np, transforms, magnification=5.0, include_flow=True
        )

        image_path = os.path.join(OUT_FOLDER, sample_name + ".png")
        video_path = os.path.join(OUT_FOLDER, sample_name + ".mp4")

        spec_image.save(image_path)
        try:
            video.write_videofile(video_path, logger=None)
        finally:
            # Release the clip's reader/ffmpeg resources. This assumes a
            # moviepy-style clip; it is skipped if the object has no close().
            close = getattr(video, "close", None)
            if callable(close):
                close()

        # Post/Redirect/Get: refreshing the result page does not resubmit the upload.
        return redirect(url_for('index', image=image_path, video=video_path))

    # GET: show previous results if their paths were passed in the query string.
    image_file = request.args.get('image')
    video_file = request.args.get('video')
    return render_template("index.html", image_file=image_file, video_file=video_file)

if __name__ == '__main__':
    # Starts Flask's development server (not intended for production use).
    # Host/port can be overridden via FLASK_RUN_HOST / FLASK_RUN_PORT; the
    # defaults keep the original 0.0.0.0:5001 (reachable from other machines).
    # NOTE(review): inside a notebook this call blocks the kernel until
    # interrupted. The dev server may also serve requests concurrently while
    # they all share the global vae/unet/noise_scheduler — confirm that is safe.
    app.run(
        host=os.environ.get('FLASK_RUN_HOST', '0.0.0.0'),
        port=int(os.environ.get('FLASK_RUN_PORT', '5001')),
    )


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5001
 * Running on http://137.112.104.31:5001
Press CTRL+C to quit
137.112.211.153 - - [23/May/2025 15:34:49] "GET / HTTP/1.1" 200 -
100%|██████████| 16/16 [01:14<00:00,  4.65s/it]
137.112.211.153 - - [23/May/2025 15:36:40] "POST / HTTP/1.1" 302 -
137.112.211.153 - - [23/May/2025 15:36:40] "GET /?image=static/results/Domestic_Goose_ddpm100_2025-05-23T15_35_22.062839.png&video=static/results/Domestic_Goose_ddpm100_2025-05-23T15_35_22.062839.mp4 HTTP/1.1" 200 -
137.112.211.153 - - [23/May/2025 15:36:40] "GET /static/results/Domestic_Goose_ddpm100_2025-05-23T15_35_22.062839.png HTTP/1.1" 200 -
137.112.211.153 - - [23/May/2025 15:36:40] "GET /static/results/Domestic_Goose_ddpm100_2025-05-23T15_35_22.062839.mp4 HTTP/1.1" 206 -
