In [1]:
# Clone official RVRT repo
# NOTE(review): the model code visible in the following cell defines its own
# classes and does not import anything from this clone -- confirm whether the
# clone/build steps are still required.
!git clone https://github.com/JingyunLiang/RVRT.git
# %cd (rather than !cd) changes the notebook kernel's working directory, so the
# commands below run inside the fresh clone.
%cd RVRT

# Install dependencies and build the package
# NOTE(review): none of these installs is version-pinned, so a later re-run may
# resolve different package versions than the ones this notebook was run with.
!apt-get update && apt-get install -y libgl1-mesa-glx
!pip install -r requirements.txt
!pip install torch torchvision torchaudio
!python setup.py develop

Cloning into 'RVRT'...
remote: Enumerating objects: 48, done.[K
remote: Counting objects: 100% (11/11), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 48 (delta 4), reused 3 (delta 3), pack-reused 37 (from 1)[K
Receiving objects: 100% (48/48), 2.87 MiB | 5.78 MiB/s, done.
Resolving deltas: 100% (5/5), done.
/content/RVRT
Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:5 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:6 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:7 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [1,776 kB]
Get:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Hit:9 https://pp

In [2]:
import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import argparse
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

class SpatialAttention(nn.Module):
    """Gate a feature map with a per-position sigmoid mask computed from itself.

    The mask is produced by a depthwise 3x3 convolution (``groups=dim``, one
    filter per channel, same spatial size), a single-group ``GroupNorm`` and a
    sigmoid. The output is the input multiplied element-wise by that mask, so
    it has the same shape as the input.

    Args:
        dim: Number of channels of the feature map.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim,
        )
        self.norm = nn.GroupNorm(num_groups=1, num_channels=dim)

    def forward(self, x):
        """Return ``x`` scaled by its sigmoid attention mask."""
        gate = torch.sigmoid(self.norm(self.conv(x)))
        return x * gate

class ChannelAttention(nn.Module):
    """Re-weight the channels of a feature map with learned per-channel gates.

    The gates are computed by global average pooling to 1x1, a 1x1 convolution
    down to ``dim // 4`` channels, ReLU, a 1x1 convolution back up to ``dim``
    channels and a sigmoid. The resulting values are multiplied onto the input
    via broadcasting.

    Args:
        dim: Number of channels of the feature map.
    """

    def __init__(self, dim):
        super().__init__()
        bottleneck = dim // 4
        stages = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, bottleneck, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, dim, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate values."""
        channel_gates = self.fc(x)
        return x * channel_gates

class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM step.

    One convolution over the channel-wise concatenation of the input and the
    previous hidden state produces all four gate pre-activations at once; they
    are laid out along the channel axis in the order input, forget, output,
    candidate.

    Args:
        input_dim: Channels of the input tensor.
        hidden_dim: Channels of the hidden and cell states.
        kernel_size: Size of the gate convolution kernel; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # A single conv emits the i, f, o and g gate maps stacked on dim 1.
        self.conv = nn.Conv2d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size,
            padding=kernel_size // 2,
            bias=True,
        )

    def forward(self, input_tensor, cur_state):
        """Advance the recurrence by one step.

        Args:
            input_tensor: Current input features.
            cur_state: Tuple ``(h_cur, c_cur)`` with the previous hidden and
                cell states.

        Returns:
            Tuple ``(h_next, c_next)`` with the updated hidden and cell states.
        """
        prev_hidden, prev_cell = cur_state

        stacked = torch.cat((input_tensor, prev_hidden), dim=1)
        raw_i, raw_f, raw_o, raw_g = self.conv(stacked).split(self.hidden_dim, dim=1)

        in_gate = torch.sigmoid(raw_i)
        forget_gate = torch.sigmoid(raw_f)
        out_gate = torch.sigmoid(raw_o)
        candidate = torch.tanh(raw_g)

        # Keep part of the old memory and add the gated candidate.
        c_next = forget_gate * prev_cell + in_gate * candidate
        h_next = out_gate * torch.tanh(c_next)
        return h_next, c_next

class FeatureExtractor(nn.Module):
    """Three stacked 3x3 conv + ReLU layers mapping a frame to a feature map.

    Channel progression is ``in_channels -> 32 -> 64 -> out_channels``. All
    convolutions use stride 1 and padding 1, so height and width are kept.

    Args:
        in_channels: Channels of the input frame.
        out_channels: Channels of the produced feature map.
    """

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the three conv + ReLU stages in order and return the features."""
        for stage in (self.conv1, self.conv2, self.conv3):
            x = self.relu(stage(x))
        return x

class FeatureRefinement(nn.Module):
    """Residual block: spatial attention, channel attention, then two convs.

    The input is passed through ``SpatialAttention``, ``ChannelAttention`` and
    a conv-ReLU-conv stack (3x3, stride 1, padding 1), and the original input
    is added back to the result.

    Args:
        channels: Channels of the feature map (unchanged by this block).
    """

    def __init__(self, channels=64):
        super().__init__()
        self.spatial_attn = SpatialAttention(channels)
        self.channel_attn = ChannelAttention(channels)
        refine_layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
        ]
        self.conv_refine = nn.Sequential(*refine_layers)

    def forward(self, x):
        """Return the refined features plus the skip connection from ``x``."""
        attended = self.channel_attn(self.spatial_attn(x))
        refined = self.conv_refine(attended)
        return refined + x

class Reconstructor(nn.Module):
    """Map a feature map back to an image with values in [-1, 1].

    Channel progression is ``in_channels -> 32 -> 16 -> out_channels`` using
    3x3 convolutions (stride 1, padding 1). The first two are followed by ReLU,
    the last one by ``tanh``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the reconstructed image.
    """

    def __init__(self, in_channels=64, out_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(16, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the reconstructed image; tanh bounds it to [-1, 1]."""
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        return torch.tanh(self.conv3(hidden))

class RVRT(nn.Module):
    """Frame-by-frame recurrent restoration network built around a ConvLSTM.

    Every frame is passed through ``FeatureExtractor`` -> ``ConvLSTMCell`` ->
    ``FeatureRefinement`` -> ``Reconstructor``. The recurrent state is stored on
    the module (``hidden_state`` / ``cell_state``) and is carried over from one
    ``forward`` call to the next until ``reset_states`` is called, so that
    consecutive clips of the same video can be fed in order.

    NOTE(review): despite its name, this class does not import anything from
    the cloned RVRT repository; it is a self-contained ConvLSTM model.

    Args:
        channels: Width of the feature maps and of the recurrent state.
    """

    def __init__(self, channels=64):
        super().__init__()
        self.channels = channels

        # Feature extraction
        self.feature_extractor = FeatureExtractor(3, channels)

        # ConvLSTM for temporal modeling
        self.conv_lstm = ConvLSTMCell(channels, channels)

        # Feature refinement
        self.feature_refinement = FeatureRefinement(channels)

        # Reconstruction
        self.reconstructor = Reconstructor(channels, 3)

        # Recurrent state; created lazily by forward() / init_hidden().
        self.hidden_state = None
        self.cell_state = None

    def init_hidden(self, batch_size, height, width, device, dtype=None):
        """Reset the recurrent state to zeros of shape (batch, channels, H, W).

        Args:
            batch_size: Batch dimension of the state.
            height: Spatial height of the state.
            width: Spatial width of the state.
            device: Device on which to allocate the state.
            dtype: Optional dtype of the state; ``None`` keeps torch's default.
        """
        state_shape = (batch_size, self.channels, height, width)
        self.hidden_state = torch.zeros(state_shape, device=device, dtype=dtype)
        self.cell_state = torch.zeros(state_shape, device=device, dtype=dtype)

    def _state_is_stale(self, batch_size, height, width, device):
        """Tell whether the stored state cannot be reused for the given input."""
        hidden = self.hidden_state
        if hidden is None or self.cell_state is None:
            return True
        return (
            hidden.shape[0] != batch_size
            or tuple(hidden.shape[-2:]) != (height, width)
            or hidden.device != device
        )

    def forward(self, x):
        """Restore a clip (or a single frame), continuing from the stored state.

        Args:
            x: Tensor of shape (B, T, C, H, W), or (B, C, H, W) for one frame.

        Returns:
            Tensor of shape (B, T, C, H, W); T is 1 for single-frame input.
        """
        if x.dim() == 4:
            # Single frame: add the time dimension.
            x = x.unsqueeze(1)

        B, T, C, H, W = x.shape

        # Start from a fresh state whenever the stored one does not fit this
        # input (different batch size, frame size or device). Checking only
        # the batch size would let a stale state reach the cell and fail there.
        if self._state_is_stale(B, H, W, x.device):
            self.init_hidden(B, H, W, x.device, dtype=x.dtype)

        outputs = []

        for t in range(T):
            frame = x[:, t]  # (B, C, H, W)

            # Extract features
            features = self.feature_extractor(frame)  # (B, channels, H, W)

            # Temporal modeling with ConvLSTM. The new state is stored as
            # returned by the cell (it is not detached here).
            self.hidden_state, self.cell_state = self.conv_lstm(
                features, (self.hidden_state, self.cell_state)
            )

            # Refine features
            refined_features = self.feature_refinement(self.hidden_state)

            # Reconstruct frame
            outputs.append(self.reconstructor(refined_features))

        return torch.stack(outputs, dim=1)  # (B, T, C, H, W)

    def reset_states(self):
        """Drop the recurrent state so the next forward() starts from zeros."""
        self.hidden_state = None
        self.cell_state = None

class VideoDeblurrer:
    """Run a video through the ``RVRT`` model clip by clip using OpenCV I/O.

    Args:
        model_path: Optional path to a state dict for the model. If it is
            missing or cannot be loaded, randomly initialized weights are used.
        device: Device string the model and input tensors are moved to.
    """

    def __init__(self, model_path=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model = RVRT(channels=64).to(device)

        if model_path and os.path.exists(model_path):
            try:
                # SECURITY(review): torch.load unpickles the file, which can
                # execute arbitrary code -- only load checkpoints you trust.
                self.model.load_state_dict(torch.load(model_path, map_location=device))
                print(f"‚úÖ Loaded pre-trained model from {model_path}")
            except Exception as e:
                print(f"‚ö†Ô∏è Could not load model: {e}")
                print("Using randomly initialized model")
        else:
            print("üîß Using randomly initialized model (for demonstration)")

        self.model.eval()

    def preprocess_frame(self, frame):
        """Convert a BGR uint8 frame to RGB float32 scaled to [-1, 1]."""
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = frame.astype(np.float32) / 127.5 - 1.0  # Normalize to [-1, 1]
        return frame

    def postprocess_frame(self, frame):
        """Convert an RGB frame in [-1, 1] back to a BGR uint8 frame.

        Values are rounded to the nearest integer before the cast; a plain
        cast truncates, so a pixel that round-trips to 9.99999 would become 9.
        """
        frame = (frame + 1.0) * 127.5
        frame = np.clip(np.rint(frame), 0, 255).astype(np.uint8)
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        return frame

    def _restore_batch(self, frames_buffer):
        """Return writer-ready BGR frames for one buffered clip.

        The clip is run through the model; if that fails for any reason the
        error is reported and the (un-restored) input frames are returned
        instead, so each input frame yields exactly one output frame.
        """
        try:
            clip = np.stack(frames_buffer, axis=0)  # (T, H, W, C)
            clip = torch.from_numpy(clip).permute(0, 3, 1, 2).unsqueeze(0)  # (1, T, C, H, W)
            clip = clip.to(self.device)

            print(f"Processing batch: {clip.shape}")

            with torch.no_grad():
                restored = self.model(clip)  # (1, T, C, H, W)

            restored = restored.squeeze(0).permute(0, 2, 3, 1).cpu().numpy()  # (T, H, W, C)
            return [self.postprocess_frame(restored[i]) for i in range(len(frames_buffer))]
        except Exception as e:
            print(f"‚ö†Ô∏è Error processing batch: {e}")
            # Fallback: hand back the original frames.
            return [self.postprocess_frame(frame_data) for frame_data in frames_buffer]

    def deblur_video(self, input_path, output_path, max_frames_per_batch=5):
        """Read ``input_path``, restore it clip by clip and write ``output_path``.

        Args:
            input_path: Path of the video to read.
            output_path: Path of the video to write (``mp4v`` codec).
            max_frames_per_batch: Number of frames per clip fed to the model.

        Raises:
            ValueError: If ``max_frames_per_batch`` is below 1 or the input
                cannot be opened.
            FileNotFoundError: If ``input_path`` does not exist.
            OSError: If the output video cannot be opened for writing.
        """
        if max_frames_per_batch < 1:
            raise ValueError(f"max_frames_per_batch must be >= 1, got {max_frames_per_batch}")
        if not os.path.exists(input_path):
            raise FileNotFoundError(f"Input video not found: {input_path}")

        cap = cv2.VideoCapture(input_path)
        out = None
        try:
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {input_path}")

            # Keep the frame rate as reported: truncating e.g. 29.97 to 29
            # changes the duration of the written video.
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # Used for progress display only; the loop below does not rely on
            # it being accurate.
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            print(f"üìπ Video Info: {width}x{height}, {fps:g} FPS, {total_frames} frames")

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            if not out.isOpened():
                raise OSError(f"Could not open output video for writing: {output_path}")

            frames_buffer = []
            processed_count = 0

            # Reset model states
            self.model.reset_states()

            print("üöÄ Starting video processing...")

            while True:
                ret, frame = cap.read()
                if ret:
                    frames_buffer.append(self.preprocess_frame(frame))

                # Flush a full buffer, or whatever is left once reading stops.
                if frames_buffer and (not ret or len(frames_buffer) >= max_frames_per_batch):
                    for output_frame in self._restore_batch(frames_buffer):
                        out.write(output_frame)
                        processed_count += 1

                        if processed_count % 30 == 0:
                            if total_frames > 0:
                                progress = processed_count / total_frames * 100
                                print(f"üìä Progress: {processed_count}/{total_frames} ({progress:.1f}%)")
                            else:
                                print(f"üìä Progress: {processed_count} frames")
                    frames_buffer.clear()

                if not ret:
                    break
        finally:
            # Release the handles even when processing is interrupted.
            cap.release()
            if out is not None:
                out.release()

        print(f"‚úÖ Video deblurring completed!")
        print(f"üìÅ Output saved to: {output_path}")

# Simple function for easy usage
def deblur_video_simple(input_path, output_path=None, model_path=None):
    """
    Deblur one video with default settings and report the outcome on stdout.

    Args:
        input_path (str): Path to the blurred input video
        output_path (str, optional): Path for the deblurred output video;
            defaults to ``<input stem>_deblurred.mp4`` next to the input
        model_path (str, optional): Path to pre-trained model weights

    Returns:
        bool: True when processing finished, False when the input file is
        missing or processing raised an exception.
    """
    separator = "=" * 50
    print("üé¨ RVRT Video Deblurring System")
    print(separator)

    input_found = os.path.exists(input_path)
    if not input_found:
        print(f"‚ùå Error: File not found - {input_path}")
        return False

    # Derive the output location from the input when none was given.
    if output_path is None:
        source = Path(input_path)
        output_path = str(source.parent.joinpath(f"{source.stem}_deblurred.mp4"))

    print(f"üìπ Input: {input_path}")
    print(f"üíæ Output: {output_path}")

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    print(f"üîß Device: {device}")

    # Initialize deblurrer
    deblurrer = VideoDeblurrer(model_path=model_path, device=device)

    try:
        deblurrer.deblur_video(input_path, output_path)
        print(f"\nüéâ SUCCESS! Deblurred video saved to:")
        print(f"üìÇ {output_path}")
    except Exception as e:
        print(f"‚ùå Error during processing: {e}")
        import traceback
        traceback.print_exc()
        return False
    return True

def main():
    """Command-line entry point: parse the arguments and deblur one video.

    ``--input`` is required. ``--output`` defaults to
    ``<input stem>_deblurred.mp4`` next to the input, and ``--device auto``
    picks CUDA when it is available and the CPU otherwise.
    """
    parser = argparse.ArgumentParser(description='RVRT Video Deblurring')
    option_specs = (
        ('--input', dict(type=str, required=True, help='Path to input blurred video')),
        ('--output', dict(type=str, default=None, help='Path to output deblurred video')),
        ('--model', dict(type=str, default=None, help='Path to pre-trained model (optional)')),
        ('--device', dict(type=str, default='auto', choices=['auto', 'cuda', 'cpu'])),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Resolve the 'auto' device choice.
    device = args.device
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Fall back to a path derived from the input when no output was given.
    if args.output is None:
        source = Path(args.input)
        args.output = str(source.parent.joinpath(f"{source.stem}_deblurred.mp4"))

    print(f"Using device: {device}")

    # Initialize and run deblurrer
    deblurrer = VideoDeblurrer(model_path=args.model, device=device)
    deblurrer.deblur_video(args.input, args.output)

if __name__ == "__main__":
    # Imported here so this block does not depend on `os.sys`, which only works
    # because the os module happens to import sys internally.
    import sys

    # Check if we're in an interactive environment: IPython/Jupyter define
    # get_ipython(); a plain Python interpreter raises NameError instead.
    try:
        get_ipython()
        interactive_mode = True
    except NameError:
        interactive_mode = False

    if interactive_mode or len(sys.argv) == 1:
        # Interactive mode: no usable command-line arguments, so prompt.
        print("üé¨ RVRT Video Deblurring System")
        print("=" * 50)
        print("üí° Usage: deblur_video_simple('/path/to/video.mp4')")
        print("üí° Or enter video path below:")

        try:
            # Strip surrounding whitespace and double quotes (e.g. a pasted,
            # quoted path).
            input_path = input("\nüìÅ Enter video path: ").strip().strip('"')
            if input_path:
                deblur_video_simple(input_path)
        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin is closed or not attached, so there is nothing
            # to read; leave as quietly as on Ctrl+C.
            print("\nüëã Goodbye!")
    else:
        # Script invoked with arguments: use the argparse-based CLI.
        main()

üé¨ RVRT Video Deblurring System
üí° Usage: deblur_video_simple('/path/to/video.mp4')
üí° Or enter video path below:

üìÅ Enter video path: /content/WhatsApp Video 2025-06-12 at 14.10.17_c5f5aa48.mp4
üé¨ RVRT Video Deblurring System
üìπ Input: /content/WhatsApp Video 2025-06-12 at 14.10.17_c5f5aa48.mp4
üíæ Output: /content/WhatsApp Video 2025-06-12 at 14.10.17_c5f5aa48_deblurred.mp4
üîß Device: cuda
üîß Using randomly initialized model (for demonstration)
üìπ Video Info: 1920x1080, 25 FPS, 372 frames
üöÄ Starting video processing...
Processing batch: torch.Size([1, 5, 3, 1080, 1920])
Processing batch: torch.Size([1, 5, 3, 1080, 1920])
Processing batch: torch.Size([1, 5, 3, 1080, 1920])
Processing batch: torch.Size([1, 5, 3, 1080, 1920])
Processing batch: torch.Size([1, 5, 3, 1080, 1920])
Processing batch: torch.Size([1, 5, 3, 1080, 1920])
üìä Progress: 30/372 (8.1%)
Processing batch: torch.Size([1, 5, 3, 1080, 1920])
Processing batch: torch.Size([1, 5, 3, 1080, 1920])
Proce