In [None]:
# @title 🎬 **Video Interpolator (Supports 2x, 4x, 8x)**
# @markdown Run this to interpolate video while **preserving original duration**.

import os
import shutil
import time
import sys
import glob
from IPython.display import Video, display, clear_output

# ================= CONFIGURATION =================
# Colab form fields. Keep every `# @param` line in the exact
# `NAME = value  # @param {...}` shape or the form widget stops working.

# Source video path; empty means "use the newest .mp4" (see EXECUTION).
INPUT_VIDEO = ""  # @param {type:"string"}
# @markdown Leave blank to auto-detect the newest video in `/content` or `/content/ComfyUI/output`.

# Output fps = source fps x FRAME_MULTIPLIER, and the frame count grows by the
# same factor, so playback duration is unchanged.
FRAME_MULTIPLIER = 2   # @param {type:"number"}
# @markdown Output fps = source fps × multiplier (examples assume a 16 fps source):
# @markdown * **2** = 32fps (Standard Smooth)
# @markdown * **4** = 64fps (Ultra Smooth)
# @markdown * **8** = 128fps (Extreme)

# File name of the final video, written to /content. Auto-detect skips files
# whose path contains "interp" or "final", so keeping one of those in the name
# stops a rerun from picking this output up as its input.
# NOTE(review): the default name says "4x" while FRAME_MULTIPLIER defaults to
# 2 — rename one of them so they agree.
OUTPUT_NAME = "interpolated_4x.mp4" # @param {type:"string"}
# libx264 CRF for the final H.264 encode: lower = higher quality, larger file.
CRF_QUALITY = 17 # @param {type:"slider", min:0, max:51, step:1}

# ================= SETUP =================
print("üöÄ Initializing Environment...")
rife_dir = "/content/Practical-RIFE"

if os.path.exists(rife_dir):
    shutil.rmtree(rife_dir)
os.makedirs(rife_dir, exist_ok=True)

# 1. Download Model Files
!wget -q https://huggingface.co/Isi99999/Frame_Interpolation_Models/resolve/main/4.25/train_log/IFNet_HDv3.py -O {rife_dir}/IFNet_HDv3.py
!wget -q https://huggingface.co/Isi99999/Frame_Interpolation_Models/resolve/main/4.25/train_log/RIFE_HDv3.py -O {rife_dir}/RIFE_HDv3.py
!wget -q https://huggingface.co/Isi99999/Frame_Interpolation_Models/resolve/main/4.25/train_log/refine.py -O {rife_dir}/refine.py
!wget -q https://huggingface.co/Isi99999/Frame_Interpolation_Models/resolve/main/4.25/train_log/flownet.pkl -O {rife_dir}/flownet.pkl

# 2. Dependencies
!pip install -q git+https://github.com/rk-exxec/scikit-video.git@numpy_deprecation
!apt -y install -qq aria2 ffmpeg > /dev/null

# ================= SHIMS & PATCHES =================
print("üîß Applying Multi-Frame Logic...")

# Stand-in for the upstream `model/loss.py` that RIFE_HDv3 / IFNet_HDv3 import.
# Every forward() returns 0, so these only satisfy the import and construction;
# presumably the real losses are only needed for training — confirm against
# RIFE_HDv3.py if a newer model version is downloaded.
LOSS_SHIM = """\
import torch.nn as nn


class EPE(nn.Module):
    def __init__(self, a=None):
        super(EPE, self).__init__()

    def forward(self, f, g):
        return 0


class Sobel(nn.Module):
    def __init__(self):
        super(Sobel, self).__init__()

    def forward(self, i, g):
        return 0


class SOBEL(Sobel):
    pass
"""

# Stand-in for the upstream `model/warplayer.py`: backward-warps tenInput by
# the optical flow tenFlow with bilinear grid sampling. The base sampling grid
# is cached at module level per (device, dtype, shape), so it is built once
# per resolution instead of on every call, and it lives on the flow's own
# device, so CPU fallback works.
WARP_SHIM = """\
import torch
import torch.nn.functional as F

# Kept for modules that read `warplayer.device`; warp() itself follows the
# device of its inputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    k = (tenFlow.device, tenFlow.dtype, tuple(tenFlow.shape))
    if k not in backwarp_tenGrid:
        N, _, H, W = tenFlow.shape
        tenHorizontal = torch.linspace(-1.0, 1.0, W, device=tenFlow.device, dtype=tenFlow.dtype).view(1, 1, 1, W).expand(N, -1, H, -1)
        tenVertical = torch.linspace(-1.0, 1.0, H, device=tenFlow.device, dtype=tenFlow.dtype).view(1, 1, H, 1).expand(N, -1, -1, W)
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)
    # Convert pixel displacements to the [-1, 1] normalized coordinates
    # that grid_sample expects.
    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return F.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
"""

with open(f"{rife_dir}/loss.py", "w", encoding="utf-8") as f:
    f.write(LOSS_SHIM)

with open(f"{rife_dir}/warplayer.py", "w", encoding="utf-8") as f:
    f.write(WARP_SHIM)

def patch_file(filename, old, new, base_dir=None):
    """Replace every occurrence of ``old`` with ``new`` in a downloaded file.

    Used to rewrite the upstream package-style imports (``model.*``,
    ``train_log.*``) into the flat layout the files are downloaded into.

    Args:
        filename: Name of the file inside ``base_dir``.
        old: Exact substring to search for.
        new: Replacement text.
        base_dir: Directory holding ``filename``; defaults to the
            notebook-level ``rife_dir``.

    Returns:
        True if the file was rewritten. False if the file does not exist
        (a warning is printed) or ``old`` does not occur in it, e.g. because
        the import is already in the flat form.
    """
    directory = rife_dir if base_dir is None else base_dir
    path = os.path.join(directory, filename)
    if not os.path.exists(path):
        # Previously a silent no-op; a missing file here almost always means
        # a failed download, so make it visible.
        print(f"Warning: patch_file could not find {path}, skipping")
        return False

    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    if old not in content:
        return False

    with open(path, "w", encoding="utf-8") as f:
        f.write(content.replace(old, new))
    return True

# The upstream RIFE sources import their siblings through `model.` /
# `train_log.` packages, but everything here sits flat in rife_dir, so point
# those imports at the sibling modules (including the loss / warplayer shims).
# Order matters only per file and is kept as-is.
for _target, _old, _new in (
    ("RIFE_HDv3.py", "from train_log.IFNet_HDv3", "from IFNet_HDv3"),
    ("RIFE_HDv3.py", "from model.loss", "from loss"),
    ("RIFE_HDv3.py", "from model.warplayer", "from warplayer"),
    ("IFNet_HDv3.py", "from model.warplayer", "from warplayer"),
    ("IFNet_HDv3.py", "from model.loss", "from loss"),
):
    patch_file(_target, _old, _new)
del _target, _old, _new

# ================= INFERENCE SCRIPT =================
# Written to rife_dir and run as a separate process. Between each pair of
# consecutive source frames it inserts (multi - 1) synthetic frames and writes
# everything at fps * multi, so playback duration matches the source.
inference_script = """
import os
import sys
sys.path.append('.')
import cv2
import torch
import argparse
import numpy as np
import warnings
import time
from torch.nn import functional as F
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def make_inference(model, I0, I1, n):
    # Return the n - 1 frames that split the I0 -> I1 gap into n equal steps
    # (n == 2 -> t=0.5; n == 4 -> t=0.25, 0.5, 0.75; n == 1 -> nothing).
    if n <= 1:
        return []

    if getattr(model, 'version', 0) >= 3.9:
        # These models take an explicit timestep, so each frame is predicted
        # directly from the two real frames. This is exact for any n (3x, 6x,
        # ...) and avoids compounding errors from interpolating between
        # already-synthesized frames.
        return [model.inference(I0, I1, i / n, 1.0) for i in range(1, n)]

    # Older models are called without a timestep (midpoint only), so recurse
    # on each half. Only exact for power-of-two n, which __main__ enforces.
    mid = model.inference(I0, I1)
    return make_inference(model, I0, mid, n // 2) + [mid] + make_inference(model, mid, I1, n // 2)


def to_bgr_frame(t, h, w):
    # (1, 3, H_pad, W_pad) float tensor in [0, 1] -> (h, w, 3) uint8 image with
    # the 64-pixel alignment padding cropped off. Clamp + round because values
    # slightly outside [0, 1] would otherwise wrap in the uint8 cast, and plain
    # truncation biases every pixel darker.
    img = (t[0].clamp(0.0, 1.0) * 255.0).round().byte()
    return img.permute(1, 2, 0)[:h, :w].contiguous().cpu().numpy()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--video', type=str, required=True)
    parser.add_argument('--output', type=str, required=True)
    parser.add_argument('--multi', type=int, default=2)
    parser.add_argument('--scale', type=float, default=1.0)
    args = parser.parse_args()

    if args.multi < 1:
        print(f"Error: --multi must be >= 1, got {args.multi}")
        sys.exit(1)

    # Load Model
    try:
        from RIFE_HDv3 import Model
        model = Model()
        model.load_model('.', -1)
        model.eval()
        model.device()
        print(f"‚úÖ RIFE Model Loaded")
    except Exception as e:
        print(f"‚ùå Error: {e}")
        sys.exit(1)

    # Midpoint-only models cannot produce e.g. 3x without distorting timing.
    if getattr(model, 'version', 0) < 3.9 and args.multi & (args.multi - 1):
        print(f"Error: this RIFE model only supports power-of-two --multi, got {args.multi}")
        sys.exit(1)

    # Input Setup
    cap = cv2.VideoCapture(args.video)
    if not cap.isOpened():
        print(f"Error: cannot open input video {args.video}")
        sys.exit(1)
    fps = cap.get(cv2.CAP_PROP_FPS)
    tot = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

    # Containers without usable fps metadata report 0 (or garbage).
    if not fps > 0: fps = 24.0

    # 64-Pixel Padding
    s_h, s_w = int(h * args.scale), int(w * args.scale)
    ph = ((s_h - 1) // 64 + 1) * 64
    pw = ((s_w - 1) // 64 + 1) * 64
    padding = (0, pw - s_w, 0, ph - s_h)

    # Calculate Target FPS (Maintains original duration)
    target_fps = fps * args.multi

    print(f"‚ú® Input: {w}x{h} @ {fps} FPS")
    print(f"‚ú® Output: {args.multi}x Frames @ {target_fps} FPS (Duration Locked)")

    writer = cv2.VideoWriter(args.output, cv2.VideoWriter_fourcc(*'mp4v'), target_fps, (s_w, s_h))
    if not writer.isOpened():
        print(f"Error: cannot open output video {args.output}")
        sys.exit(1)

    last = None
    last_frame = None
    cnt = 0
    start_time = time.time()

    with torch.inference_mode():
        while True:
            ret, frame = cap.read()
            if not ret: break

            frame = cv2.resize(frame, (s_w, s_h))
            I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
            I1 = F.pad(I1, padding)

            if last is not None:
                # Synthetic frames strictly between the previous and current frame.
                for mid in make_inference(model, last, I1, args.multi):
                    writer.write(to_bgr_frame(mid, s_h, s_w))

            # Write Real Frame
            writer.write(frame)
            last = I1
            last_frame = frame
            cnt += 1
            if cnt % 5 == 0:
                print(f"   Frame {cnt}/{tot}...", end='\\r')

    # Hold the final real frame for one full source-frame interval so the
    # output has exactly cnt * multi frames, i.e. exactly the source duration.
    if last_frame is not None:
        for _ in range(args.multi - 1):
            writer.write(last_frame)

    cap.release()
    writer.release()

    if cnt == 0:
        print("Error: no frames could be read from the input video")
        if os.path.exists(args.output):
            os.remove(args.output)  # don't leave an empty video behind
        sys.exit(1)

    print(f"\\n‚úÖ Processed {cnt} input frames in {time.time() - start_time:.2f}s")
"""
# The script contains non-ASCII text and python3 reads sources as UTF-8,
# so write it as UTF-8 regardless of the notebook's locale.
with open(f"{rife_dir}/inference_video.py", "w", encoding="utf-8") as f:
    f.write(inference_script)

# ================= EXECUTION =================
if not INPUT_VIDEO:
    mp4s = sorted(glob.glob("/content/*.mp4") + glob.glob("/content/ComfyUI/output/*.mp4"), key=os.path.getctime)
    mp4s = [f for f in mp4s if "interp" not in f and "final" not in f]
    INPUT_VIDEO = mp4s[-1] if mp4s else None

if not INPUT_VIDEO:
    print("‚ùå No input video found.")
    sys.exit()

print(f"üé¨ Source: {INPUT_VIDEO}")
temp_out = f"{rife_dir}/temp_out.mp4"
final_out = f"/content/{OUTPUT_NAME}"

os.environ["XDG_RUNTIME_DIR"] = "/tmp"
os.environ["SDL_AUDIODRIVER"] = "dummy"

%cd {rife_dir}
!python3 inference_video.py --multi={FRAME_MULTIPLIER} --video="{INPUT_VIDEO}" --scale=1.0 --output="{temp_out}"
%cd /content

if os.path.exists(temp_out):
    print("‚öôÔ∏è Finalizing (H.264)...")
    !ffmpeg -i "{temp_out}" -c:v libx264 -crf {CRF_QUALITY} -preset fast -y "{final_out}" -loglevel error
    print(f"üíæ Saved: {final_out}")
    display(Video(final_out, embed=True))
else:
    print("‚ùå Failed.")