In [1]:
from mmedit.apis import init_model
print("BasicVSR++ is ready!")


  from .autonotebook import tqdm as notebook_tqdm
  from pkg_resources import packaging  # type: ignore[attr-defined]


BasicVSR++ is ready!


In [2]:
import os
import cv2
import torch
import numpy as np
from mmedit.apis import init_model, restoration_video_inference
from IPython.display import Video, display


In [24]:
# Paths
config_path = "/home/patrick/BasicVSR_PlusPlus/configs/basicvsr_plusplus_reds4.py"
checkpoint_path = "checkpoints/basicvsr_plusplus_c64n7_8x1_600k_reds4_20210217-db622b2f.pth"

print("Config and checkpoint ready!")

Config and checkpoint ready!


In [38]:
input_video = "input.mp4"       # <-- your video filename
output_video = "output_vsr.mp4" # final upscaled result

##############################################
# PREP
##############################################

device = "cuda" if torch.cuda.is_available() else "cpu"
frames_dir = "lq/video1"
output_frames_dir = "vsr_frames"

os.makedirs(frames_dir, exist_ok=True)
os.makedirs(output_frames_dir, exist_ok=True)

##############################################
# STEP 1 — Extract frames from input video
##############################################

cap = cv2.VideoCapture(input_video)
idx = 0

print("Extracting frames...")
while True:
    ret, frame = cap.read()
    if not ret:
        break
    cv2.imwrite(f"{frames_dir}/{idx:08d}.png", frame)
    idx += 1

cap.release()
print(f"Extracted {idx} frames into {frames_dir}")

##############################################
# STEP 2 — Load BasicVSR++ model
##############################################

print("Loading BasicVSR++ model...")
model = init_model(config_path, checkpoint_path, device=device)
print("Model loaded on:", device)

##############################################
# STEP 3 — Run super-resolution inference
##############################################

print("Running VSR inference...")

# Force the entire sequence to be processed at once
result = restoration_video_inference(
    model=model,
    img_dir=frames_dir,
    window_size=0,                # no sliding window
    start_idx=0,
    filename_tmpl="{:08d}.png",
    max_seq_len=None              # <<< IMPORTANT FIX
)

print("Inference complete. Saving output frames...")
print(type(result), getattr(result, 'shape', None))


# result shape = [1, T, 3, H, W]
if isinstance(result, torch.Tensor):
    frames = result.squeeze(0)            # -> [T, 3, H, W]
elif isinstance(result, list):
    frames = torch.stack(result).squeeze(0)
else:
    raise RuntimeError(f"Unexpected model output type: {type(result)}")

# Convert to [T, H, W, 3]
frames = frames.permute(0, 2, 3, 1)       # -> [T, H, W, 3]
frames = frames.cpu().numpy()

print("Model returned", frames.shape[0], "frames")

# Save frames to output directory
for i in range(frames.shape[0]):
    img = frames[i]

    # Normalize + convert RGB -> BGR for OpenCV
    img = img.clip(0, 1)
    img_bgr = (img * 255).astype("uint8")[:, :, ::-1]

    cv2.imwrite(f"{output_frames_dir}/{i:08d}.png", img_bgr)

print(f"Saved {frames.shape[0]} upscaled frames into {output_frames_dir}")

##############################################
# STEP 4 — Reassemble output video
##############################################

print("Reassembling upscaled video...")

frame_files = sorted(os.listdir(output_frames_dir))
print("Found frames:", len(frame_files))

if len(frame_files) == 0:
    raise RuntimeError("ERROR: No output frames were generated!")

first = cv2.imread(os.path.join(output_frames_dir, frame_files[0]))
h, w, _ = first.shape

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(output_video, fourcc, 30, (w, h))

for f in frame_files:
    frame = cv2.imread(os.path.join(output_frames_dir, f))
    out.write(frame)

out.release()

print("Saved upscaled video to:", output_video)

##############################################
# STEP 5 — Display result inline (optional)
##############################################

print("Showing result:")
display(Video(output_video, width=512))


2025-11-13 17:59:51,537 - mmedit - INFO - load checkpoint from http path: https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth


Extracting frames...
Extracted 91 frames into lq/video1
Loading BasicVSR++ model...
load checkpoint from local path: checkpoints/basicvsr_plusplus_c64n7_8x1_600k_reds4_20210217-db622b2f.pth
Model loaded on: cuda
Running VSR inference...
Inference complete. Saving output frames...
<class 'torch.Tensor'> torch.Size([1, 91, 3, 576, 704])
Model returned 91 frames
Saved 91 upscaled frames into vsr_frames
Reassembling upscaled video...
Found frames: 91
Saved upscaled video to: output_vsr.mp4
Showing result:
