In [1]:
import torch
import utils
import transformer
import os
import time
import cv2
import numpy as np
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms

In [2]:
# Notebook-wide configuration constants.
# Folder for style-model checkpoints. The example cell loads "transforms/wave.pth"
# from here, but this constant itself is not referenced by the code visible in
# this notebook.
STYLE_MODELS_DIR = "transforms/"
# Folder used to build the default input path ("sample.mp4") in
# region_based_style_transfer when no input_path is given.
VIDEOS_INPUT_DIR = "videos/input"
# Folder used to build the default output path in region_based_style_transfer
# when no output_path is given.
VIDEOS_OUTPUT_DIR = "videos/output"
# Fraction of the previous frame's stylized tensor blended into the current one
# in RegionBasedStyleTransfer.stylize_region (the current frame gets 1 - weight).
TEMPORAL_WEIGHT = 0.5  # Weight for temporal consistency

In [3]:
def ensure_dir(directory):
    """Create ``directory`` (including missing parents) if it doesn't exist.

    Args:
        directory: Path of the directory to create. An empty or ``None`` value
            is treated as "nothing to create": ``os.path.dirname`` returns ``""``
            for a bare filename, and ``os.makedirs("")`` would raise.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
    """
    if not directory:
        return
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(directory, exist_ok=True)

def load_image_from_array(img_array):
    """Wrap a numpy array in a PIL image and return it"""
    pil_image = Image.fromarray(img_array)
    return pil_image

class ObjectDetector:
    """Detect objects in OpenCV frames with a pre-trained torchvision Mask R-CNN.

    Each detection is reported as a dict with the keys ``'box'``, ``'label'``,
    ``'score'`` and ``'mask'`` (see ``detect_objects``).
    """

    def __init__(self, device=None, score_threshold=0.5, mask_threshold=0.5):
        """Initialize object detector model

        Args:
            device: Device to run the model on. When falsy, "cuda" is used if
                available, otherwise "cpu".
            score_threshold: Detections are kept only when their confidence is
                strictly greater than this value.
            mask_threshold: Mask values strictly greater than this value are
                treated as belonging to the object.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device for object detection: {self.device}")

        self.score_threshold = score_threshold
        self.mask_threshold = mask_threshold

        # Load pre-trained model
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is to match the installed version.
        print("Loading object detection model...")
        self.model = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Define transformation
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

        # COCO class names, indexed by the label id the model predicts.
        # 'N/A' entries are placeholders that keep the indices aligned.
        self.class_names = [
            '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
            'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
            'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
            'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
            'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
            'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
            'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
            'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
            'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
            'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
            'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
        ]

    def detect_objects(self, frame):
        """
        Detect objects in a frame

        Args:
            frame: OpenCV BGR frame

        Returns:
            List of dicts, one per detection whose score exceeds
            ``score_threshold``, in the order the model returned them:
                'box':   int32 numpy array unpacked by callers as (x1, y1, x2, y2)
                'label': class name looked up in ``self.class_names``
                'score': confidence as a Python float
                'mask':  boolean numpy array (mask > ``mask_threshold``), or
                         None when the model output has no 'masks' entry
        """
        # Convert BGR to RGB
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Convert to tensor
        img_tensor = self.transform(rgb_frame).to(self.device)

        # Perform inference on a single-image batch
        with torch.no_grad():
            prediction = self.model([img_tensor])[0]

        # Filter by confidence on the model's device, then copy each output to
        # the host once. This avoids one device->host sync per detection and
        # never transfers masks of detections that are going to be discarded.
        keep = prediction['scores'] > self.score_threshold
        scores = prediction['scores'][keep].cpu().numpy()
        boxes = prediction['boxes'][keep].cpu().numpy().astype(np.int32)
        label_ids = prediction['labels'][keep].cpu().numpy()

        masks = None
        if 'masks' in prediction:
            # Channel 0 of each mask is binarised before the transfer so only
            # booleans cross the device boundary.
            masks = (prediction['masks'][keep, 0] > self.mask_threshold).cpu().numpy()

        objects = []
        for idx in range(len(scores)):
            objects.append({
                'box': boxes[idx],
                'label': self.class_names[int(label_ids[idx])],
                'score': float(scores[idx]),
                'mask': masks[idx] if masks is not None else None
            })

        return objects

class RegionBasedStyleTransfer:
    """Stylize detected objects and the remaining background with separate networks.

    Objects found by ``ObjectDetector`` define a foreground mask. The whole
    frame is stylized with the background and/or foreground network and the
    results are blended back into the original frame through that mask. A
    region whose network is ``None`` keeps its original pixels.
    """

    def __init__(self, style_foreground_path, style_background_path, device=None, 
                preserve_color_fg=False, preserve_color_bg=False):
        """Initialize the region-based style transfer model

        Args:
            style_foreground_path: Checkpoint for the foreground style network,
                or None to leave foreground pixels unstyled.
            style_background_path: Checkpoint for the background style network,
                or None to leave background pixels unstyled.
            device: Device to run on. When falsy, "cuda" is used if available,
                otherwise "cpu".
            preserve_color_fg: Pass the foreground result through
                ``utils.transfer_color`` together with the original image.
            preserve_color_bg: Same as above for the background result.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device for style transfer: {self.device}")

        # Load transformer networks (None means "keep original pixels")
        self.net_fg = self._load_style_network(style_foreground_path, "foreground")
        self.net_bg = self._load_style_network(style_background_path, "background")

        # Initialize object detector on the same resolved device
        self.detector = ObjectDetector(self.device)

        # Style settings
        self.preserve_color_fg = preserve_color_fg
        self.preserve_color_bg = preserve_color_bg

        # Temporal consistency: last stylized tensor produced for each region
        self.prev_stylized_fg = None
        self.prev_stylized_bg = None

    def _load_style_network(self, weights_path, region_name):
        """Load one TransformerNetwork checkpoint, or return None if no path is given."""
        if weights_path is None:
            print(f"{region_name.capitalize()} will use original pixels (no style)")
            return None

        print(f"Loading {region_name} style model...")
        net = transformer.TransformerNetwork()
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        net.load_state_dict(torch.load(weights_path, map_location=self.device))
        net = net.to(self.device)
        net.eval()
        return net

    @staticmethod
    def _fit_to_frame(stylized, width, height):
        """Resize a stylized image to (width, height) if its size differs."""
        if stylized.shape[0] != height or stylized.shape[1] != width:
            stylized = cv2.resize(stylized, (width, height))
        return stylized

    def stylize_region(self, frame, is_foreground=True, apply_temporal=True):
        """
        Apply style transfer to a region of a frame

        Args:
            frame: OpenCV frame in BGR format
            is_foreground: Whether to use foreground or background style
            apply_temporal: Whether to apply temporal consistency

        Returns:
            Stylized frame in BGR format. The input frame is returned unchanged
            when the selected region has no style network.
        """
        # Select appropriate network and settings
        if is_foreground:
            net = self.net_fg
            prev_stylized = self.prev_stylized_fg
            preserve_color = self.preserve_color_fg
        else:
            net = self.net_bg
            prev_stylized = self.prev_stylized_bg
            preserve_color = self.preserve_color_bg

        # Check if the corresponding network exists
        if net is None:
            return frame  # Return original frame if no style network

        # Convert BGR to RGB
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Convert to PIL Image
        content_image = load_image_from_array(frame_rgb)

        # Apply style transfer
        content_tensor = utils.itot(content_image).to(self.device)

        with torch.no_grad():
            stylized_tensor = net(content_tensor)

            # Blend with the previous output for temporal consistency. The shape
            # check skips blending when the previous tensor came from a frame of
            # a different size, which would otherwise raise.
            if (apply_temporal and prev_stylized is not None
                    and prev_stylized.shape == stylized_tensor.shape):
                stylized_tensor = (1 - TEMPORAL_WEIGHT) * stylized_tensor + TEMPORAL_WEIGHT * prev_stylized

            # Store current (already blended) output for next frame
            if is_foreground:
                self.prev_stylized_fg = stylized_tensor.clone()
            else:
                self.prev_stylized_bg = stylized_tensor.clone()

        # Convert back to image
        stylized_image = utils.ttoi(stylized_tensor.detach())

        if preserve_color:
            stylized_image = utils.transfer_color(content_image, stylized_image)

        # Convert back to BGR for OpenCV
        stylized_frame = cv2.cvtColor(np.array(stylized_image), cv2.COLOR_RGB2BGR)

        return stylized_frame

    def process_frame(self, frame, apply_temporal=True, target_classes=None):
        """
        Process a frame with region-based style transfer

        Args:
            frame: OpenCV frame in BGR format
            apply_temporal: Whether to apply temporal consistency
            target_classes: List of classes to style as foreground (None for all)

        Returns:
            Stylized frame in BGR format as a uint8 array with the same height
            and width as ``frame``
        """
        # Get original frame dimensions
        height, width = frame.shape[:2]

        # Detect objects
        objects = self.detector.detect_objects(frame)

        # Create foreground mask (initialize with zeros)
        fg_mask = np.zeros((height, width), dtype=np.uint8)

        # Add detected objects to foreground mask if they match target classes
        for obj in objects:
            if target_classes is None or obj['label'] in target_classes:
                if obj['mask'] is not None:
                    # Use object mask
                    fg_mask = np.logical_or(fg_mask, obj['mask']).astype(np.uint8) * 255
                else:
                    # Use bounding box; plain ints keep cv2 happy regardless of
                    # the numpy integer type stored in the detection.
                    x1, y1, x2, y2 = (int(v) for v in obj['box'])
                    cv2.rectangle(fg_mask, (x1, y1), (x2, y2), 255, -1)

        # Dilate the mask slightly to avoid hard edges
        kernel = np.ones((5, 5), np.uint8)
        fg_mask = cv2.dilate(fg_mask, kernel, iterations=1)

        # Create background mask (inverse of foreground)
        bg_mask = 255 - fg_mask

        # Convert masks to 3-channel and proper scale for blending
        fg_mask_3c = cv2.cvtColor(fg_mask, cv2.COLOR_GRAY2BGR) / 255.0
        bg_mask_3c = cv2.cvtColor(bg_mask, cv2.COLOR_GRAY2BGR) / 255.0

        # Start with the original frame
        result = frame.copy()

        # Apply background style if provided
        if self.net_bg is not None:
            # Style the entire frame with background style
            bg_stylized = self.stylize_region(frame, is_foreground=False, apply_temporal=apply_temporal)
            # The network output is not guaranteed to match the input size
            # exactly; resize so the mask blend below cannot fail to broadcast.
            bg_stylized = self._fit_to_frame(bg_stylized, width, height)
            # Only apply to background areas (where bg_mask is non-zero)
            result = result * (1.0 - bg_mask_3c) + bg_stylized * bg_mask_3c

        # Apply foreground style if provided
        if self.net_fg is not None:
            # Style the entire frame with foreground style
            fg_stylized = self.stylize_region(frame, is_foreground=True, apply_temporal=apply_temporal)
            fg_stylized = self._fit_to_frame(fg_stylized, width, height)
            # Only apply to foreground areas (where fg_mask is non-zero)
            result = result * (1.0 - fg_mask_3c) + fg_stylized * fg_mask_3c

        # Clip before the cast: astype(np.uint8) wraps values outside [0, 255].
        # NOTE(review): assumes utils.ttoi may return unclipped floats - confirm.
        return np.clip(result, 0, 255).astype(np.uint8)

    def process_video_headless(self, input_path, output_path, target_classes=None):
        """Process an entire video with region-based style transfer (no GUI)

        Args:
            input_path: Path of the video to read.
            output_path: Path of the 'mp4v'-encoded video to write. Every 100th
                frame is also saved (original and stylized) in a
                "debug_frames" folder next to it.
            target_classes: Class names treated as foreground, forwarded to
                ``process_frame``.

        Returns:
            ``output_path`` on success, or None when the input cannot be opened
            or the output writer cannot be created.
        """
        # OpenCV runtime settings. No window is ever opened by this method.
        # NOTE(review): these environment variables are set after cv2 has been
        # imported, so they may have no effect - confirm.
        cv2.setUseOptimized(True)
        os.environ["OPENCV_VIDEOIO_PRIORITY_BACKEND"] = "0"
        os.environ["OPENCV_VIDEOIO_DEBUG"] = "0"

        # Open video file
        video = cv2.VideoCapture(input_path)
        if not video.isOpened():
            print(f"Error: Could not open video file {input_path}")
            return

        out = None
        frame_count = 0
        processing_times = []

        try:
            # Get video properties
            width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = video.get(cv2.CAP_PROP_FPS)
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

            print(f"Video properties: {width}x{height}, {fps} fps, {total_frames} frames")

            # Create video writer. A bare filename has an empty dirname, which
            # os.makedirs cannot create, so only create a real directory.
            output_dir = os.path.dirname(output_path)
            if output_dir:
                ensure_dir(output_dir)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            if not out.isOpened():
                print(f"Error: Could not create output video file {output_path}")
                return

            print("Processing video...")
            start_time = time.time()

            while True:
                ret, frame = video.read()
                if not ret:
                    break

                frame_start = time.time()

                # Process frame with region-based style transfer. The first
                # frame is never blended with state left from an earlier run.
                stylized_frame = self.process_frame(frame, apply_temporal=(frame_count > 0), 
                                                 target_classes=target_classes)

                # Ensure frame is 8-bit unsigned integer format (CV_8U)
                if stylized_frame.dtype != np.uint8:
                    stylized_frame = np.clip(stylized_frame, 0, 255).astype(np.uint8)

                # Write frame to output video
                out.write(stylized_frame)

                # Calculate processing time
                frame_time = time.time() - frame_start
                processing_times.append(frame_time)

                # Show progress
                frame_count += 1
                if frame_count % 10 == 0 or frame_count == total_frames:
                    elapsed = time.time() - start_time
                    avg_fps = frame_count / elapsed
                    estimated_total = elapsed / frame_count * total_frames
                    # Some containers report a missing or too-small frame
                    # count; never show a negative estimate.
                    remaining = max(estimated_total - elapsed, 0.0)

                    print(f"Processed frame {frame_count}/{total_frames} - "
                          f"Avg. FPS: {avg_fps:.2f} - "
                          f"Avg. time per frame: {np.mean(processing_times[-10:]):.3f}s - "
                          f"Remaining time: {remaining/60:.1f} min")

                    # Save the original and stylized frame every 100 frames
                    # (useful for debugging)
                    if frame_count % 100 == 0:
                        debug_dir = os.path.join(os.path.dirname(output_path), "debug_frames")
                        ensure_dir(debug_dir)

                        # Save current frame
                        cv2.imwrite(os.path.join(debug_dir, f"frame_{frame_count:06d}.jpg"), frame)
                        cv2.imwrite(os.path.join(debug_dir, f"stylized_{frame_count:06d}.jpg"), stylized_frame)
        finally:
            # Release resources even when a frame fails, so the capture is not
            # leaked and the partially written file is closed.
            video.release()
            if out is not None:
                out.release()

        # Print stats
        total_time = time.time() - start_time
        # Avoid np.mean([]) (nan plus a RuntimeWarning) when no frame was read.
        avg_frame_time = np.mean(processing_times) if processing_times else 0.0
        print(f"Video processing completed in {total_time:.2f} seconds")
        print(f"Average processing time per frame: {avg_frame_time:.3f} seconds")
        print(f"Total average FPS: {frame_count / total_time:.2f}")
        print(f"Stylized video saved to: {output_path}")

        return output_path

In [4]:
def region_based_style_transfer(input_path=None, output_path=None, 
                              fg_style_path=None, bg_style_path=None,
                              target_classes=None, preserve_color_fg=False, 
                              preserve_color_bg=False):
    """
    Run region-based style transfer over one video without opening any window.

    Args:
        input_path: Video to read; defaults to "sample.mp4" inside VIDEOS_INPUT_DIR
        output_path: Where to write the result; defaults to
            "stylized_regions_<input file name>" inside VIDEOS_OUTPUT_DIR
        fg_style_path: Foreground style model, or None to keep detected objects unstyled
        bg_style_path: Background style model, or None to keep the background unstyled
        target_classes: Object class names treated as foreground; defaults to
            ['person', 'dog', 'cat', 'horse'] when None
        preserve_color_fg: Keep the original colors in the foreground result
        preserve_color_bg: Keep the original colors in the background result

    Returns:
        Whatever RegionBasedStyleTransfer.process_video_headless returns for
        the resolved paths and classes.
    """
    # Resolve the source video
    source_video = input_path if input_path is not None else os.path.join(VIDEOS_INPUT_DIR, "sample.mp4")

    # Resolve the destination, deriving its name from the source when needed
    if output_path is not None:
        destination = output_path
    else:
        source_name = os.path.basename(source_video)
        destination = os.path.join(VIDEOS_OUTPUT_DIR, f"stylized_regions_{source_name}")

    # People and pets are the default foreground
    foreground_labels = ['person', 'dog', 'cat', 'horse'] if target_classes is None else target_classes

    # Build the stylizer and hand the whole video to its headless pipeline
    stylizer = RegionBasedStyleTransfer(
        style_foreground_path=fg_style_path,
        style_background_path=bg_style_path,
        preserve_color_fg=preserve_color_fg,
        preserve_color_bg=preserve_color_bg,
    )
    return stylizer.process_video_headless(source_video, destination, foreground_labels)

In [5]:
# Example usage with style only for background
result = region_based_style_transfer(
    input_path="videos/input/ship.mp4",  # Change to your video file
    output_path="videos/output/bg_only_stylized_ship2.mp4",
    fg_style_path=None,  # No style for foreground (keep original)
    bg_style_path="transforms/wave.pth",  # Style for background
    target_classes=['boat'],  # Objects to detect as foreground
    preserve_color_fg=False,
    preserve_color_bg=False
)

# The call returns None when the input video cannot be opened or the output
# writer cannot be created, so only report a saved file for a real path.
if result is None:
    print("Style transfer failed: no output video was written.")
else:
    print(f"Stylized video saved to: {result}")

Using device for style transfer: cuda
Foreground will use original pixels (no style)
Loading background style model...
Using device for object detection: cuda
Loading object detection model...




Video properties: 1920x1080, 29.97002997002997 fps, 528 frames
Processing video...
Processed frame 10/528 - Avg. FPS: 2.13 - Avg. time per frame: 0.459s - Remaining time: 4.1 min
Processed frame 20/528 - Avg. FPS: 2.40 - Avg. time per frame: 0.360s - Remaining time: 3.5 min
Processed frame 30/528 - Avg. FPS: 2.52 - Avg. time per frame: 0.356s - Remaining time: 3.3 min
Processed frame 40/528 - Avg. FPS: 2.59 - Avg. time per frame: 0.355s - Remaining time: 3.1 min
Processed frame 50/528 - Avg. FPS: 2.64 - Avg. time per frame: 0.345s - Remaining time: 3.0 min
Processed frame 60/528 - Avg. FPS: 2.68 - Avg. time per frame: 0.344s - Remaining time: 2.9 min
Processed frame 70/528 - Avg. FPS: 2.70 - Avg. time per frame: 0.349s - Remaining time: 2.8 min
Processed frame 80/528 - Avg. FPS: 2.73 - Avg. time per frame: 0.343s - Remaining time: 2.7 min
Processed frame 90/528 - Avg. FPS: 2.73 - Avg. time per frame: 0.357s - Remaining time: 2.7 min
Processed frame 100/528 - Avg. FPS: 2.73 - Avg. time 