In [1]:
import os

# Sanity check: confirm the checkpoint is reachable from the working directory
# before the heavier inference cell tries to load it.
model_path = "best_model_8_8.pth"  # or use the absolute path
if os.path.exists(model_path):
    print(f"Model found at {model_path}")
else:
    print(f"Error: The file at {model_path} does not exist.")


Model found at best_model_8_8.pth


In [None]:
import os
import cv2
import torch
import numpy as np
import time
from torchvision import transforms
import torch.nn as nn

# ------------------- Helper Function -------------------
def center_crop(enc_feat, target_size):
    """Crop the spatial centre of a 4-D feature map down to ``target_size``.

    Args:
        enc_feat: 4-D tensor; the last two dimensions are treated as height
            and width, the first two are left untouched.
        target_size: ``(target_h, target_w)`` of the crop, e.g.
            ``other.shape[2:]``.

    Returns:
        The slice of ``enc_feat`` with spatial size ``(target_h, target_w)``.
        When the size difference is odd, the larger margin is removed from
        the bottom/right (the start offset is floored).

    Raises:
        ValueError: if the target is larger than ``enc_feat`` in either
            spatial dimension. Without this check the start offset would be
            negative and Python's negative indexing would silently return a
            slice of the wrong size.
    """
    _, _, h, w = enc_feat.size()
    target_h, target_w = target_size
    if target_h > h or target_w > w:
        raise ValueError(
            f"center_crop target {(int(target_h), int(target_w))} exceeds "
            f"feature map size {(int(h), int(w))}"
        )
    start_h = (h - target_h) // 2
    start_w = (w - target_w) // 2
    return enc_feat[:, :, start_h:start_h + target_h, start_w:start_w + target_w]

# ------------------- Model Definition -------------------
class DeepDehazeNet8Layer(nn.Module):
    """Encoder-decoder dehazing network with concatenated skip connections.

    Three conv stages (64, 128, 256 channels), each followed by 2x2 max
    pooling, feed a 512-channel bottleneck. Three transposed-conv upsampling
    stages then restore resolution; after each upsampling the matching encoder
    feature map is centre-cropped and concatenated along the channel axis.
    A 1x1 convolution plus sigmoid produces a 3-channel output in [0, 1].

    The input must have 3 channels. The output's spatial size is the input's
    height/width rounded down to a multiple of 8 (three pool/upsample pairs).
    """

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Build one 3x3 Conv -> BatchNorm -> ReLU -> Dropout(0.2) stage."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        ]
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        block = self._conv_block

        # Encoder path. Attribute names and creation order are kept stable so
        # that state_dict keys continue to match saved checkpoints.
        self.enc1 = block(3, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = block(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = block(256, 512)

        # Decoder path: each decoder block takes twice the upsampled channel
        # count because the encoder skip is concatenated in forward().
        self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec1 = block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = block(256, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec3 = block(128, 64)

        self.final = nn.Conv2d(64, 3, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid-activated output."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        feat = self.bottleneck(self.pool3(skip3))

        # Walk back up, deepest skip first.
        stages = (
            (self.up1, self.dec1, skip3),
            (self.up2, self.dec2, skip2),
            (self.up3, self.dec3, skip1),
        )
        for upsample, decode, skip in stages:
            feat = upsample(feat)
            # The skip can be larger than the upsampled map when a pooled
            # dimension was odd, so crop it to match before concatenating.
            skip = center_crop(skip, feat.shape[2:])
            feat = decode(torch.cat([feat, skip], 1))

        return torch.sigmoid(self.final(feat))

# ------------------- Setup -------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepDehazeNet8Layer().to(device).half()
model.load_state_dict(torch.load("best_model_8_8.pth", map_location=device))
model.eval()

transform = transforms.Compose([transforms.ToTensor()])

video_path = "input_video.mp4"
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise IOError("Cannot open video")

resize_dim = (512, 512)
os.makedirs("results", exist_ok=True)

frame_count = 0
total_infer_time = 0
start_time = time.time()

video_fps = cap.get(cv2.CAP_PROP_FPS)
frame_width = resize_dim[0] * 2
frame_height = resize_dim[1]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
output_video_path = os.path.join("results", "dehazed_output.mp4")
out = cv2.VideoWriter(output_video_path, fourcc, video_fps, (frame_width, frame_height))

# For high-quality export
hq_frames = []

# ------------------- Real-time Processing -------------------
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    frame_resized = cv2.resize(frame, resize_dim)
    frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
    input_tensor = transform(frame_rgb).unsqueeze(0).to(device).half()

    infer_start = time.time()
    with torch.no_grad():
        output = model(input_tensor)
    infer_end = time.time()

    infer_time = infer_end - infer_start
    total_infer_time += infer_time
    frame_count += 1

    output_image = output.squeeze().float().cpu().numpy().transpose(1, 2, 0)
    output_image = np.clip(output_image * 255, 0, 255).astype(np.uint8)
    output_bgr = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)

    fps_current = 1 / infer_time if infer_time > 0 else 0
    combined = np.hstack((frame_resized, output_bgr))

    # cv2.putText(combined, f"FPS: {fps_current:.2f}",
                # (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

    out.write(combined)
    hq_frames.append(combined)

    cv2.imshow("Original | Dehazed", combined)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# ------------------- Final Stats -------------------
end_time = time.time()
avg_fps = frame_count / (end_time - start_time)
avg_infer = (total_infer_time / frame_count) * 1000

print(f"\nTotal Frames: {frame_count}")
print(f"Average FPS: {avg_fps:.2f}")
# print(f"Average Inference Time: {avg_infer:.2f} ms/frame")

# with open("results/benchmark_stats.txt", "w") as f:
    # f.write(f"Total Frames: {frame_count}\n")
    # f.write(f"Average FPS: {avg_fps:.2f}\n")
    # f.write(f"Average Inference Time: {avg_infer:.2f} ms/frame\n")

cap.release()
out.release()
cv2.destroyAllWindows()

# ------------------- High-Quality Export -------------------
# Export using OpenCV
hq_output_path = "results/output_video_hq.mp4"
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out_hq = cv2.VideoWriter(hq_output_path, fourcc, video_fps, (frame_width, frame_height))

for frame in hq_frames:
    out_hq.write(frame)

out_hq.release()
print(f"High-quality video saved to: {hq_output_path}")



Total Frames: 1340
Average FPS: 13.62
High-quality video saved to: results/output_video_hq.mp4
