# In [1]:
import torch
import numpy as np
import cv2
from flownet2.models import FlowNet2
from flownet2.utils.frame_utils import read_gen

# Load pre-trained FlowNet2 model
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): NVIDIA's reference FlowNet2 constructor takes an ``args``
# object (rgb_max, fp16) — confirm this package's constructor needs none.
flownet = FlowNet2().to(device)
# map_location remaps tensors that were saved from a GPU onto the selected
# device, so the checkpoint also deserializes on CPU-only machines.
_checkpoint = torch.load("FlowNet2_checkpoint.pth", map_location=device)
# Some training scripts save a wrapper dict with the weights nested under
# "state_dict"; unwrap it when present, otherwise use the loaded object as-is.
if isinstance(_checkpoint, dict) and "state_dict" in _checkpoint:
    _checkpoint = _checkpoint["state_dict"]
flownet.load_state_dict(_checkpoint)
# Put layers such as dropout / batch-norm into inference behaviour.
flownet.eval()

def _frame_to_batch(frame):
    """Read one frame with ``read_gen`` and return it as a batched float tensor.

    The array returned by ``read_gen`` has its axes reordered with
    ``permute(2, 0, 1)``, gains a leading axis of size 1, is cast to float
    and moved to ``device``.
    """
    # assumes read_gen yields a 3-D (H, W, C) array — TODO confirm for
    # grayscale inputs, where permute(2, 0, 1) would fail.
    image = read_gen(frame)
    return torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float().to(device)


def estimate_drift(frame1, frame2):
    """Estimate the average optical-flow displacement between two frames.

    Both frames are loaded through ``read_gen``, concatenated along axis 1
    and passed to the module-level ``flownet`` with gradients disabled.

    Args:
        frame1: First frame, in whatever form ``read_gen`` accepts (the
            example usage in this file passes an image file path).
        frame2: Second frame, same form as ``frame1``.

    Returns:
        A ``(dx, dy)`` tuple: the mean of flow channel 0 and the mean of
        flow channel 1. Presumably these are horizontal and vertical
        displacement in pixels — confirm against the model's output
        convention.
    """
    # NOTE(review): the two frames are concatenated on axis 1 (6 channels for
    # RGB input). NVIDIA's reference FlowNet2 forward expects a 5-D
    # (B, 3, 2, H, W) tensor — confirm which layout this package's model takes.
    pair = torch.cat([_frame_to_batch(frame1), _frame_to_batch(frame2)], dim=1)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        flow = flownet(pair)

    # Drop only the leading batch axis. A bare squeeze() would also collapse
    # a spatial axis of size 1 and break the channel indexing below.
    flow_np = flow.squeeze(0).cpu().numpy()
    dx, dy = np.mean(flow_np[0]), np.mean(flow_np[1])

    return dx, dy

# Example usage: estimate the mean flow between two image files and report it.
dx, dy = estimate_drift(
    "frame1.png",
    "frame2.png",
)
print(
    f"Drone Drift: Δx = {dx} px, Δy = {dy} px"
)


# Output of the cell above:
# ModuleNotFoundError: No module named 'flownet2'