In [17]:
# Video Style Transfer Implementation
import glob
import os
import re
import shutil
import time

import cv2
import numpy as np
import torch
from PIL import Image

import transformer
import utils

In [18]:
# --- Configuration ---------------------------------------------------------
# Style model: state dict for transformer.TransformerNetwork.
STYLE_TRANSFORM_PATH = "transforms/starry.pth"
PRESERVE_COLOR = False   # re-apply the content frame's colors to the stylized output
TEMPORAL_WEIGHT = 0.5    # share of the previous stylized frame blended into each new frame

# Frame folders (defaults for stylize_frames)
CONTENT_FRAMES_DIR = "frames/content_folder"
OUTPUT_FRAMES_DIR = "frames/output_folder"
STYLE_FRAMES_DIR = "style_frames"  # not referenced by the functions defined below

# Video folders (defaults for stylize_video and the examples)
VIDEOS_INPUT_DIR = "videos/input"
VIDEOS_OUTPUT_DIR = "videos/output"

In [19]:
def natural_sort_key(s):
    """
    Natural sort key for file names containing numbers ("frame_9" < "frame_10").

    ``re.split`` with a capturing group places every captured digit run at an
    odd index, so chunks strictly alternate text/number. Using that parity
    (instead of ``str.isdigit``, which also accepts characters like '²' that
    ``\\d`` never captures and ``int`` cannot parse) guarantees every odd chunk
    is int-convertible and keys never compare str against int.
    Text chunks are lower-cased for case-insensitive ordering.
    """
    chunks = re.split(r'(\d+)', s)
    return [int(chunk) if index % 2 else chunk.lower() for index, chunk in enumerate(chunks)]

def ensure_dir(directory):
    """Create ``directory`` (and missing parents) if it does not exist yet.

    An empty string means "the current directory" -- e.g.
    ``os.path.dirname("out.mp4")`` returns "" -- so nothing is created instead
    of ``os.makedirs("")`` raising FileNotFoundError. ``exist_ok=True`` closes
    the race between the existence check and creation.
    """
    if not directory or os.path.exists(directory):
        return
    os.makedirs(directory, exist_ok=True)

def load_image_from_array(img_array):
    """Wrap a numpy image array as a PIL Image (thin wrapper over ``Image.fromarray``)."""
    pil_image = Image.fromarray(img_array)
    return pil_image

In [24]:
class VideoStyleTransfer:
    """Stylize videos or folders of frames with a pre-trained style transformer.

    To reduce flicker, each stylized frame can be blended with the previous
    (already blended) output: ``out_t = (1 - w) * stylized_t + w * out_{t-1}``.
    ``w`` is the ``temporal_weight`` given to the constructor, or -- when that
    is None -- the module-level ``TEMPORAL_WEIGHT`` read at blend time.
    """

    DEFAULT_FPS = 30.0  # used when the capture reports a non-positive / NaN frame rate
    VIDEO_FOURCC = 'mp4v'

    def __init__(self, style_model_path, device=None, preserve_color=False, temporal_weight=None):
        """Load the transformer weights and put the network in inference mode.

        Args:
            style_model_path: Path to a state dict for ``transformer.TransformerNetwork``.
            device: Torch device; defaults to "cuda" when available, else "cpu".
            preserve_color: If True, transfer the content frame's colors onto the output.
            temporal_weight: Optional blend weight in [0, 1] for the previous frame.
                None (default) falls back to the module-level ``TEMPORAL_WEIGHT``.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        print("Loading transformer network...")
        self.net = transformer.TransformerNetwork()
        # map_location lets checkpoints saved on a GPU load on CPU-only machines.
        state_dict = torch.load(style_model_path, map_location=self.device)
        self.net.load_state_dict(state_dict)
        self.net = self.net.to(self.device)
        self.net.eval()  # inference mode

        self.preserve_color = preserve_color
        self.temporal_weight = temporal_weight
        self.prev_stylized = None  # last (blended) output tensor, used for temporal consistency

    # ------------------------------------------------------------------ helpers

    def _current_temporal_weight(self):
        """Return the instance blend weight, or the module-level TEMPORAL_WEIGHT if unset."""
        weight = getattr(self, "temporal_weight", None)
        return TEMPORAL_WEIGHT if weight is None else weight

    def _apply_temporal(self, stylized_tensor, apply_temporal):
        """Blend with the previous output when enabled and shapes match; remember the result.

        A shape mismatch (e.g. a folder mixing resolutions) skips blending instead of
        failing on broadcasting.
        """
        previous = self.prev_stylized
        if apply_temporal and previous is not None and previous.shape == stylized_tensor.shape:
            weight = self._current_temporal_weight()
            stylized_tensor = (1 - weight) * stylized_tensor + weight * previous
        self.prev_stylized = stylized_tensor.clone()
        return stylized_tensor

    @staticmethod
    def _to_writable_frame(frame, size=None):
        """Coerce a frame to uint8 (clipping to [0, 255]) and optionally resize to ``size``.

        Args:
            frame: Image array as produced by ``stylize_frame`` / ``cv2.imread``.
            size: Optional (width, height). OpenCV's VideoWriter does not write frames
                whose size differs from the size it was opened with, so frames bound
                for a writer are resized to match.
        """
        if frame.dtype != np.uint8:
            frame = np.clip(frame, 0, 255).astype(np.uint8)
        if size is not None and (frame.shape[1], frame.shape[0]) != tuple(size):
            frame = cv2.resize(frame, tuple(size), interpolation=cv2.INTER_LINEAR)
        return frame

    @staticmethod
    def _mean_time(processing_times):
        """Mean of the recorded times, or 0.0 when nothing was processed (avoids nan)."""
        return float(np.mean(processing_times)) if processing_times else 0.0

    @staticmethod
    def _should_report(done, total):
        """Print progress every 10 frames and on the final expected frame."""
        return done % 10 == 0 or done == total

    @staticmethod
    def _progress_message(done, total, elapsed, processing_times):
        """Format the progress line; remaining time is clamped at zero.

        ``total`` may come from CAP_PROP_FRAME_COUNT, which can be 0, negative or
        inaccurate, so a non-positive total or ``done > total`` reports 0 remaining.
        """
        average = VideoStyleTransfer._mean_time(processing_times)
        if done > 0 and total > 0:
            remaining = max(0.0, elapsed / done * total - elapsed)
        else:
            remaining = 0.0
        return (f"Processed frame {done}/{total} - "
                f"Avg. time per frame: {average:.3f}s - "
                f"Remaining time: {remaining/60:.1f} min")

    @classmethod
    def _video_properties(cls, video):
        """Read (width, height, fps, total_frames) from an opened cv2.VideoCapture."""
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = video.get(cv2.CAP_PROP_FPS)
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        if not fps > 0:  # also catches NaN
            print(f"Warning: invalid FPS ({fps}) reported; using {cls.DEFAULT_FPS}")
            fps = cls.DEFAULT_FPS
        return width, height, fps, total_frames

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of ``path`` unless it is the current directory."""
        parent = os.path.dirname(path)
        if parent:
            ensure_dir(parent)

    def _open_writer(self, output_path, fps, width, height):
        """Open an mp4v VideoWriter, creating the parent directory first."""
        self._ensure_parent_dir(output_path)
        fourcc = cv2.VideoWriter_fourcc(*self.VIDEO_FOURCC)
        return cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    def _stylize_frame_files(self, frame_paths, output_dir):
        """Stylize each image in ``frame_paths`` into ``output_dir`` (same basename).

        Unreadable images are skipped with a warning (cv2.imread returns None
        rather than raising). The first frame stylized in this call is never
        blended, so state left over from a previous call cannot leak in.

        Returns:
            (written_paths, processing_times): successfully written output paths in
            input order, and per-frame processing times in seconds.
        """
        written_paths = []
        processing_times = []
        total = len(frame_paths)
        have_previous = False
        start_time = time.time()

        for index, frame_path in enumerate(frame_paths):
            frame_start = time.time()
            frame = cv2.imread(frame_path)
            if frame is None:
                print(f"Warning: Could not read frame {frame_path}; skipping")
            else:
                stylized = self.stylize_frame(frame, apply_temporal=have_previous)
                have_previous = True
                target = os.path.join(output_dir, os.path.basename(frame_path))
                if cv2.imwrite(target, self._to_writable_frame(stylized)):
                    written_paths.append(target)
                else:
                    print(f"Warning: Could not write stylized frame {target}")
                processing_times.append(time.time() - frame_start)

            done = index + 1
            if self._should_report(done, total):
                elapsed = time.time() - start_time
                print(self._progress_message(done, total, elapsed, processing_times))

        return written_paths, processing_times

    # --------------------------------------------------------------- public API

    def stylize_frame(self, frame, apply_temporal=True):
        """
        Stylize a single video frame.

        Args:
            frame: OpenCV frame in BGR format
            apply_temporal: Whether to blend with the previous stylized output

        Returns:
            Stylized frame in BGR format
        """
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        content_image = load_image_from_array(frame_rgb)

        content_tensor = utils.itot(content_image).to(self.device)
        with torch.no_grad():
            stylized_tensor = self.net(content_tensor)

        stylized_tensor = self._apply_temporal(stylized_tensor, apply_temporal)

        stylized_image = utils.ttoi(stylized_tensor.detach())
        if self.preserve_color:
            stylized_image = utils.transfer_color(content_image, stylized_image)

        return cv2.cvtColor(np.array(stylized_image), cv2.COLOR_RGB2BGR)

    def process_video(self, input_path, output_path):
        """Stylize a video frame by frame, streaming straight into ``output_path``.

        Returns:
            ``output_path`` on success, None if the input or output cannot be opened.
        """
        video = cv2.VideoCapture(input_path)
        if not video.isOpened():
            print(f"Error: Could not open video file {input_path}")
            return

        width, height, fps, total_frames = self._video_properties(video)
        print(f"Video properties: {width}x{height}, {fps} fps, {total_frames} frames")

        out = self._open_writer(output_path, fps, width, height)
        if not out.isOpened():
            print(f"Error: Could not create output video file {output_path}")
            video.release()
            return

        frame_count = 0
        processing_times = []

        print("Processing video...")
        start_time = time.time()

        # try/finally: release both handles even if stylization raises or is interrupted,
        # so the partial output file is finalized and the input is not left open.
        try:
            while True:
                ret, frame = video.read()
                if not ret:
                    break

                frame_start = time.time()
                stylized = self.stylize_frame(frame, apply_temporal=(frame_count > 0))
                out.write(self._to_writable_frame(stylized, (width, height)))
                processing_times.append(time.time() - frame_start)

                frame_count += 1
                if self._should_report(frame_count, total_frames):
                    elapsed = time.time() - start_time
                    print(self._progress_message(frame_count, total_frames, elapsed, processing_times))
        finally:
            video.release()
            out.release()

        total_time = time.time() - start_time
        print(f"Video processing completed in {total_time:.2f} seconds")
        print(f"Average processing time per frame: {self._mean_time(processing_times):.3f} seconds")
        print(f"Stylized video saved to: {output_path}")

        return output_path

    def process_frame_directory(self, input_dir, output_dir):
        """Stylize every .jpg/.png in ``input_dir`` (natural order) into ``output_dir``.

        Returns:
            ``output_dir``, or None when no frames are found.
        """
        frame_files = sorted(
            glob.glob(os.path.join(input_dir, "*.jpg")) + glob.glob(os.path.join(input_dir, "*.png")),
            key=natural_sort_key,
        )
        if not frame_files:
            print(f"Error: No frames found in {input_dir}")
            return

        ensure_dir(output_dir)

        print(f"Processing {len(frame_files)} frames...")
        start_time = time.time()
        _, processing_times = self._stylize_frame_files(frame_files, output_dir)

        total_time = time.time() - start_time
        print(f"Frame processing completed in {total_time:.2f} seconds")
        print(f"Average processing time per frame: {self._mean_time(processing_times):.3f} seconds")
        print(f"Stylized frames saved to: {output_dir}")

        return output_dir

    def extract_process_combine(self, input_path, output_path, temp_dir=None, delete_temp=None):
        """
        Extract frames to disk, stylize them, and recombine them into a video.

        Keeps only one frame in memory at a time, at the cost of disk space and a
        lossy JPEG round-trip for the intermediate frames.

        Args:
            input_path: Video to stylize.
            output_path: Where to write the stylized video.
            temp_dir: Scratch directory (default "temp_frames").
            delete_temp: True/False to delete/keep ``temp_dir`` afterwards; None
                (default) asks interactively via ``input()``.

        Returns:
            ``output_path`` on success, None if the input or output cannot be opened.
        """
        if temp_dir is None:
            temp_dir = "temp_frames"
        ensure_dir(temp_dir)

        video = cv2.VideoCapture(input_path)
        if not video.isOpened():
            print(f"Error: Could not open video file {input_path}")
            return

        width, height, fps, total_frames = self._video_properties(video)
        print(f"Video properties: {width}x{height}, {fps} fps, {total_frames} frames")
        print(f"Extracting frames to {temp_dir}...")

        # Track exactly the frames written in THIS run: a reused temp_dir may still
        # hold frame_*.jpg files from an earlier, longer video, which a directory
        # listing would wrongly splice into the output.
        frame_paths = []
        try:
            while True:
                ret, frame = video.read()
                if not ret:
                    break
                frame_path = os.path.join(temp_dir, f"frame_{len(frame_paths):06d}.jpg")
                cv2.imwrite(frame_path, frame)
                frame_paths.append(frame_path)
                if len(frame_paths) % 100 == 0:
                    print(f"Extracted {len(frame_paths)}/{total_frames} frames")
        finally:
            video.release()
        print(f"Extracted {len(frame_paths)} frames")

        print("Processing frames...")
        output_dir = os.path.join(temp_dir, "stylized")
        ensure_dir(output_dir)
        stylized_paths, _ = self._stylize_frame_files(frame_paths, output_dir)

        print("Recombining frames into video...")
        out = self._open_writer(output_path, fps, width, height)
        if not out.isOpened():
            print(f"Error: Could not create output video file {output_path}")
            return

        try:
            for frame_path in stylized_paths:
                frame = cv2.imread(frame_path)
                if frame is None:
                    print(f"Warning: Could not read stylized frame {frame_path}; skipping")
                    continue
                out.write(self._to_writable_frame(frame, (width, height)))
        finally:
            out.release()

        print(f"Video saved to {output_path}")

        if delete_temp is None:
            response = input("Delete temporary frames? (y/n): ")
            delete_temp = response.strip().lower() == 'y'
        if delete_temp:
            shutil.rmtree(temp_dir)
            print(f"Deleted temporary frames directory: {temp_dir}")

        return output_path

In [25]:
def stylize_video(input_path=None, output_path=None, style_path=None, preserve_color=False, 
                 temporal_weight=0.5, method="direct"):
    """
    Top-level function to stylize a video.

    Args:
        input_path: Path to input video (default: VIDEOS_INPUT_DIR/boat.mp4)
        output_path: Path to save stylized video (default: VIDEOS_OUTPUT_DIR/stylized_<name>)
        style_path: Path to style model (default: STYLE_TRANSFORM_PATH)
        preserve_color: Whether to preserve original colors
        temporal_weight: Weight for temporal consistency, in [0, 1]
        method: "direct" (stream frames through a VideoWriter) or
            "extract" (extract frames to disk, stylize, recombine)

    Returns:
        The output path on success, None if processing could not start.

    Raises:
        ValueError: if ``method`` is unknown or ``temporal_weight`` is outside [0, 1].
            Previously any unknown method silently fell through to the
            interactive "extract" path.

    Side effects:
        Sets the module-level PRESERVE_COLOR and TEMPORAL_WEIGHT globals, which
        VideoStyleTransfer reads at blend time.
    """
    # Validate before loading the (expensive) model.
    if method not in ("direct", "extract"):
        raise ValueError(f"method must be 'direct' or 'extract', got {method!r}")
    if not 0 <= temporal_weight <= 1:  # also rejects NaN
        raise ValueError(f"temporal_weight must be in [0, 1], got {temporal_weight!r}")

    if input_path is None:
        input_path = os.path.join(VIDEOS_INPUT_DIR, "boat.mp4")

    if output_path is None:
        video_name = os.path.basename(input_path)
        output_path = os.path.join(VIDEOS_OUTPUT_DIR, f"stylized_{video_name}")

    if style_path is None:
        style_path = STYLE_TRANSFORM_PATH

    global PRESERVE_COLOR, TEMPORAL_WEIGHT
    PRESERVE_COLOR = preserve_color
    TEMPORAL_WEIGHT = temporal_weight

    model = VideoStyleTransfer(style_path, preserve_color=preserve_color)

    if method == "direct":
        return model.process_video(input_path, output_path)
    return model.extract_process_combine(input_path, output_path)

def stylize_frames(input_dir=None, output_dir=None, style_path=None, preserve_color=False,
                  temporal_weight=0.5):
    """
    Top-level function to stylize frames in a directory.

    Args:
        input_dir: Directory containing input frames (default: CONTENT_FRAMES_DIR)
        output_dir: Directory to save stylized frames (default: OUTPUT_FRAMES_DIR)
        style_path: Path to style model (default: STYLE_TRANSFORM_PATH)
        preserve_color: Whether to preserve original colors
        temporal_weight: Weight for temporal consistency, in [0, 1]

    Returns:
        The output directory, or None when no frames were found.

    Raises:
        ValueError: if ``temporal_weight`` is outside [0, 1].

    Side effects:
        Sets the module-level PRESERVE_COLOR and TEMPORAL_WEIGHT globals.
    """
    # Validate before loading the (expensive) model.
    if not 0 <= temporal_weight <= 1:  # also rejects NaN
        raise ValueError(f"temporal_weight must be in [0, 1], got {temporal_weight!r}")

    if input_dir is None:
        input_dir = CONTENT_FRAMES_DIR

    if output_dir is None:
        output_dir = OUTPUT_FRAMES_DIR

    if style_path is None:
        style_path = STYLE_TRANSFORM_PATH

    global PRESERVE_COLOR, TEMPORAL_WEIGHT
    PRESERVE_COLOR = preserve_color
    TEMPORAL_WEIGHT = temporal_weight

    model = VideoStyleTransfer(style_path, preserve_color=preserve_color)

    return model.process_frame_directory(input_dir, output_dir)

def frames_to_video(frames_dir, output_path, fps=30):
    """
    Convert a directory of .jpg/.png frames (natural filename order) to an mp4v video.

    The output size is taken from the first frame. Later frames with a different
    size are resized to match, because OpenCV's VideoWriter does not write frames
    whose size differs from the one it was opened with. Unreadable frames are
    skipped with a warning.

    Args:
        frames_dir: Directory containing frame images
        output_path: Path to save output video
        fps: Frames per second for output video

    Returns:
        ``output_path`` on success, None if no usable frames or the writer fails to open.
    """
    frame_files = sorted(
        glob.glob(os.path.join(frames_dir, "*.jpg")) + glob.glob(os.path.join(frames_dir, "*.png")),
        key=natural_sort_key,
    )

    if not frame_files:
        print(f"Error: No frames found in {frames_dir}")
        return

    first_frame = cv2.imread(frame_files[0])
    if first_frame is None:  # cv2.imread signals failure with None, not an exception
        print(f"Error: Could not read frame {frame_files[0]}")
        return
    height, width = first_frame.shape[:2]

    # os.path.dirname("video.mp4") is "" (current directory): nothing to create.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        ensure_dir(out_dir)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    if not out.isOpened():
        print(f"Error: Could not create output video file {output_path}")
        return

    print(f"Converting {len(frame_files)} frames to video...")

    try:
        for frame_path in frame_files:
            frame = cv2.imread(frame_path)
            if frame is None:
                print(f"Warning: Could not read frame {frame_path}; skipping")
                continue
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            out.write(frame)
    finally:
        out.release()  # finalize the container even if a frame raises

    print(f"Video saved to: {output_path}")
    return output_path

In [26]:
# Ensure directories exist
ensure_dir(VIDEOS_INPUT_DIR)
ensure_dir(VIDEOS_OUTPUT_DIR)

# Example 1: Stylize a video directly
print("EXAMPLE 1: Direct video stylization")
input_video = os.path.join(VIDEOS_INPUT_DIR, "coffee_shop.mp4")
output_video = os.path.join(VIDEOS_OUTPUT_DIR, "coffee_shop_starry.mp4")

# Check the input first so a missing file doesn't cost a model load.
if os.path.isfile(input_video):
    stylized_path = stylize_video(
        input_path=input_video,
        output_path=output_video,
        style_path=STYLE_TRANSFORM_PATH,
        preserve_color=False,
        temporal_weight=0.5,
        method="direct"
    )
else:
    print(f"Input video not found: {input_video} - place a video there to run this example.")
    stylized_path = None

# stylize_video returns None on failure; don't report "saved to: None".
if stylized_path is not None:
    print(f"Stylized video saved to: {stylized_path}")
else:
    print("Stylization did not produce an output video.")

EXAMPLE 1: Direct video stylization
Using device: cuda
Loading transformer network...
Video properties: 1920x1080, 29.97002997002997 fps, 301 frames
Processing video...
Processed frame 10/301 - Avg. time per frame: 0.154s - Remaining time: 0.8 min
Processed frame 20/301 - Avg. time per frame: 0.155s - Remaining time: 0.8 min
Processed frame 30/301 - Avg. time per frame: 0.156s - Remaining time: 0.7 min
Processed frame 40/301 - Avg. time per frame: 0.157s - Remaining time: 0.7 min
Processed frame 50/301 - Avg. time per frame: 0.157s - Remaining time: 0.7 min
Processed frame 60/301 - Avg. time per frame: 0.156s - Remaining time: 0.6 min
Processed frame 70/301 - Avg. time per frame: 0.158s - Remaining time: 0.6 min
Processed frame 80/301 - Avg. time per frame: 0.158s - Remaining time: 0.6 min
Processed frame 90/301 - Avg. time per frame: 0.160s - Remaining time: 0.6 min
Processed frame 100/301 - Avg. time per frame: 0.160s - Remaining time: 0.5 min
Processed frame 110/301 - Avg. time per 