In [9]:
import cv2
import time
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
from scipy.ndimage import gaussian_filter
from scipy.optimize import curve_fit
from skimage.feature import peak_local_max
from scipy.spatial import distance
from scipy.optimize import minimize
import matplotlib.cm as cm
from concurrent.futures import ProcessPoolExecutor
from scipy.spatial import KDTree

In [10]:
# Simplified two-dimensional Gaussian (axis-aligned, no rotation term)
def gaussian2d(xy, A, x0, y0, sigma_x, sigma_y, B):
    """Evaluate an axis-aligned 2D Gaussian sitting on a constant offset.

    Parameters
    ----------
    xy : pair of array-likes (or scalars)
        ``(x, y)`` coordinates at which the model is evaluated.
    A : float
        Amplitude of the Gaussian above the offset.
    x0, y0 : float
        Centre of the Gaussian along x and y.
    sigma_x, sigma_y : float
        Standard deviations along x and y.
    B : float
        Constant background offset.

    Returns
    -------
    Same shape as the broadcast of ``x`` and ``y``: the model value at each point.
    """
    x, y = xy
    x_term = (x - x0) ** 2 / (2 * sigma_x ** 2)
    y_term = (y - y0) ** 2 / (2 * sigma_y ** 2)
    return A * np.exp(-(x_term + y_term)) + B

# Detect and localise atoms in a single frame
def process_frame(frame, roi=None, sigma=2, min_distance=5, threshold_abs=100, cut_size=10):
    """Locate atoms in one BGR frame by peak detection plus 2D Gaussian fitting.

    The frame is converted to grayscale and smoothed with a Gaussian filter.
    Local maxima are then detected, and a window of half-width ``cut_size``
    around each maximum is fitted with ``gaussian2d``. The fitted centre is
    reported as the atom position.

    Parameters
    ----------
    frame : ndarray
        Image passed to ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``.
    roi : tuple (x, y, w, h) or None
        If given, only positions inside this rectangle (edges inclusive) are
        returned. A negative ``w`` or ``h`` is accepted and treated as a
        rectangle spanning the same two corners.
    sigma : float
        Standard deviation of the smoothing filter.
    min_distance : int
        Minimum peak separation forwarded to ``peak_local_max``.
    threshold_abs : float
        Minimum peak intensity forwarded to ``peak_local_max``.
    cut_size : int
        Half-width of the square fitting window around each peak.

    Returns
    -------
    ndarray of float32, shape (N, 2)
        Fitted ``(x, y)`` centres. The shape is ``(0, 2)`` when nothing is
        found, so callers can always index columns.
    """
    gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Smooth the whole image before looking for peaks
    filtered_img = gaussian_filter(gray_frame, sigma=sigma)
    coordinates = peak_local_max(filtered_img, min_distance=min_distance, threshold_abs=threshold_abs)

    img_h, img_w = filtered_img.shape[:2]
    centres = []
    for (y, x) in coordinates:
        x0, x1 = x - cut_size, x + cut_size + 1
        y0, y1 = y - cut_size, y + cut_size + 1
        # Skip peaks whose fitting window would extend past the image border
        if x0 < 0 or x1 > img_w or y0 < 0 or y1 > img_h:
            continue
        sub_img = filtered_img[y0:y1, x0:x1]
        x_grid, y_grid = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))
        xdata = np.vstack((x_grid.ravel(), y_grid.ravel()))
        initial_guess = (sub_img.max(), x, y, 3, 3, sub_img.min())
        try:
            popt, _ = curve_fit(gaussian2d, xdata, sub_img.ravel(), p0=initial_guess)
        except RuntimeError:
            # curve_fit could not converge for this peak; drop it
            continue
        centres.append((popt[1], popt[2]))

    # reshape keeps the (N, 2) layout even when no atom was fitted
    positions = np.array(centres, dtype=np.float32).reshape(-1, 2)

    # If an ROI is set, keep only the atoms located inside it
    if roi is not None:
        x, y, w, h = roi
        x_lo, x_hi = min(x, x + w), max(x, x + w)
        y_lo, y_hi = min(y, y + h), max(y, y + h)
        inside = ((positions[:, 0] >= x_lo) & (positions[:, 0] <= x_hi) &
                  (positions[:, 1] >= y_lo) & (positions[:, 1] <= y_hi))
        positions = positions[inside]

    return positions

def average_nearest_neighbor(positions):
    """Return the mean distance from each point to its closest other point.

    Parameters
    ----------
    positions : array-like, shape (N, 2)
        Point coordinates.

    Returns
    -------
    The mean nearest-neighbour Euclidean distance, or ``0`` when fewer than
    two points are supplied.
    """
    if len(positions) >= 2:
        pairwise = distance.cdist(positions, positions, 'euclidean')
        # A point must not be counted as its own neighbour
        np.fill_diagonal(pairwise, np.inf)
        closest = pairwise.min(axis=1)
        return np.mean(closest)
    return 0

def find_matches(prev_positions, curr_positions, radius):
    """Greedily pair previous positions with their nearest current position.

    Each previous point is compared with its single nearest neighbour among
    ``curr_positions``. The pair is accepted when that neighbour lies within
    ``radius`` and has not already been claimed by an earlier previous point.

    Parameters
    ----------
    prev_positions : array-like, shape (N, 2)
    curr_positions : array-like, shape (M, 2)
    radius : float
        Maximum accepted distance for a match.

    Returns
    -------
    (ndarray, ndarray)
        Matched previous points and the corresponding current points, in the
        order the previous points were visited.
    """
    claimed = np.zeros(len(curr_positions), dtype=bool)
    lookup = KDTree(curr_positions)

    kept_prev, kept_curr = [], []
    for point in prev_positions:
        gap, nearest = lookup.query(point, k=1)
        # Reject candidates that are too far away or already taken
        if not gap <= radius or claimed[nearest]:
            continue
        claimed[nearest] = True
        kept_prev.append(point)
        kept_curr.append(curr_positions[nearest])

    return np.array(kept_prev), np.array(kept_curr)


def track_atoms(frames, initial_positions, sigma, search_radius):
    """Follow a fixed set of atoms through a sequence of frames.

    Starting from ``initial_positions`` (taken to belong to ``frames[0]``),
    every later frame is processed with ``process_frame`` and its detections
    are paired with the previous positions via ``find_matches``. A matched
    atom takes the newly detected position. An unmatched atom keeps its last
    known position and is counted as lost for that frame.

    Parameters
    ----------
    frames : sequence
        Video frames; ``frames[0]`` itself is not re-processed.
    initial_positions : array-like, shape (N, 2)
        Atom positions in the first frame.
    sigma : float
        Smoothing parameter forwarded to ``process_frame``.
    search_radius : float
        Maximum displacement accepted by ``find_matches``.

    Returns
    -------
    tracked_positions : ndarray, shape (len(frames), N, 2)
        Position of every atom in every frame.
    lost_atoms_info : dict[int, dict]
        Per atom: ``'lost_frames'`` is the length of the current (trailing)
        run of frames without a match, and ``'total_lost'`` is the total number
        of frames in which the atom was not matched, including that trailing
        run.
    """
    tracked_positions = [initial_positions]
    lost_atoms_info = {i: {'lost_frames': 0, 'total_lost': 0} for i in range(len(initial_positions))}

    for i in range(1, len(frames)):
        current_frame = frames[i]
        previous_positions = tracked_positions[-1]
        current_positions = process_frame(current_frame, sigma=sigma)

        if len(current_positions) == 0:
            # Nothing detected: every atom keeps its position and is lost for this frame
            tracked_positions.append(previous_positions)
            for info in lost_atoms_info.values():
                info['lost_frames'] += 1
            continue

        matches_prev, matches_curr = find_matches(previous_positions, current_positions, search_radius)
        # Recover the atom index of each matched previous position by exact comparison
        matched_indices = [int(np.where(np.all(previous_positions == match, axis=1))[0][0])
                           for match in matches_prev]
        matched_set = set(matched_indices)

        updated_positions = np.array(previous_positions, copy=True)
        for idx, match in zip(matched_indices, matches_curr):
            updated_positions[idx] = match
            info = lost_atoms_info[idx]
            # The atom was re-found: bank the finished lost run and reset it
            info['total_lost'] += info['lost_frames']
            info['lost_frames'] = 0

        for idx in range(len(updated_positions)):
            if idx not in matched_set:
                lost_atoms_info[idx]['lost_frames'] += 1

        tracked_positions.append(updated_positions)

    # Atoms that were still lost at the last frame never hit the "re-found"
    # branch above, so their trailing run would otherwise be missing from
    # 'total_lost' and they would look like atoms that were never lost.
    for info in lost_atoms_info.values():
        info['total_lost'] += info['lost_frames']

    # Every entry is either the previous array or a same-shaped copy of it,
    # so all frames already hold the same number of atoms and no padding is needed.
    return np.array(tracked_positions), lost_atoms_info

def plot_atom_trajectories(frame, atoms, title, save_path=None):
    """Draw each atom's trajectory on top of a frame.

    Parameters
    ----------
    frame : ndarray
        BGR image used as the background (converted to RGB for display).
    atoms : sequence
        One entry per atom; each entry is a sequence of ``(x, y)`` positions,
        one per frame.
    title : str
        Figure title.
    save_path : str or None
        If truthy, the figure is also written to this path.
    """
    plt.figure(figsize=(5, 5))
    plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    plt.axis('off')  # hide the axes

    # One colour per atom, taken from the rainbow colormap
    palette = cm.rainbow(np.linspace(0, 1, len(atoms)))

    for track, track_color in zip(atoms, palette):
        start = track[0]
        plt.plot(start[0], start[1], 'ro', markersize=3)  # starting position marked in red
        for step in range(1, len(track)):
            before, after = track[step - 1], track[step]
            # Steps where the position did not change are drawn dotted
            segment_style = {'linestyle': 'dotted'} if np.array_equal(before, after) else {}
            plt.plot([before[0], after[0]], [before[1], after[1]],
                     color=track_color, linewidth=1, **segment_style)

    plt.title(title)
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.show()

# Load a video file
def load_video(file_path):
    """Read every frame of a video into memory.

    Parameters
    ----------
    file_path : str
        Path handed to ``cv2.VideoCapture``.

    Returns
    -------
    frames : list
        Frames in the order they were read. The list is empty if the capture
        could not be opened or yields no frame.
    frame_rate : float
        Value reported by the capture for ``cv2.CAP_PROP_FPS``.
    """
    cap = cv2.VideoCapture(file_path)
    try:
        frame_rate = cap.get(cv2.CAP_PROP_FPS)
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the capture even if reading is interrupted or raises
        cap.release()
    return frames, frame_rate

# Callback invoked when the user has drawn the ROI rectangle
def onselect(eclick, erelease):
    """Store the selected rectangle in the module-level ``roi`` variable.

    The press and release events' data coordinates are truncated to integers
    and saved as ``(x, y, width, height)``, measured from the press corner.
    The selection is also printed.
    """
    global roi
    left, top = int(eclick.xdata), int(eclick.ydata)
    right, bottom = int(erelease.xdata), int(erelease.ydata)
    width = right - left
    height = bottom - top
    roi = (left, top, width, height)
    print(f"ROI selected: {roi}")

# Show a fixed frame and let the user select the ROI on it
def select_roi_jupyter(frame):
    """Display ``frame`` with an interactive rectangle selector attached.

    The selector calls ``onselect`` when a rectangle is drawn. It is returned
    so that the caller can keep a reference to it.
    """
    _, axes = plt.subplots(1, figsize=(5, 5))
    axes.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    selector = RectangleSelector(axes, onselect, interactive=True)
    plt.show()
    return selector

def display_atom_positions(frame, positions, roi=None, save_path=None):
    """Show detected atom positions (and optionally the ROI) on a frame.

    Parameters
    ----------
    frame : ndarray
        BGR image used as the background (converted to RGB for display).
    positions : iterable of (x, y)
        Atom positions, each drawn as a red dot.
    roi : tuple (x, y, w, h) or None
        If given, the rectangle is outlined in red. The value is also printed.
    save_path : str or None
        If truthy, the figure is also written to this path.
    """
    plt.figure(figsize=(5, 5))
    plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    plt.axis('off')  # hide the axes

    for px, py in positions:
        plt.plot(px, py, 'ro', markersize=5)

    print(roi)
    if roi is not None:
        left, top, width, height = roi
        outline = plt.Rectangle((left, top), width, height, fill=False, edgecolor='red', linewidth=2)
        plt.gca().add_patch(outline)

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.show()
    
def calculate_frame_cost(i, total_shifts, tracked_positions):
    """Return the mismatch between frame ``i - 1`` (after shifting) and frame ``i``.

    The positions of frame ``i - 1`` are translated by ``total_shifts[i - 1]``.
    The cost is the sum, over all atoms, of the Euclidean distance between the
    translated position and the position in frame ``i``.
    """
    shifted_prev = tracked_positions[i - 1] + total_shifts[i - 1]
    delta = shifted_prev - tracked_positions[i]
    per_atom = np.sqrt(delta[:, 0] ** 2 + delta[:, 1] ** 2)
    return np.sum(per_atom)

def calculate_shifts(frames, tracked_positions, cost_output_path=None):
    """Estimate the frame-to-frame translation that best aligns tracked atoms.

    One 2D shift is optimised per consecutive frame pair. ``shifts[k]`` is
    added to the positions of frame ``k`` before they are compared with
    frame ``k + 1`` (see ``calculate_frame_cost``). The summed cost is
    minimised with L-BFGS-B starting from all-zero shifts. The elapsed time is
    printed.

    Parameters
    ----------
    frames : sequence
        Only its length is used; it determines the number of shifts.
    tracked_positions : ndarray
        Per-frame atom positions, indexed as ``tracked_positions[frame]``.
    cost_output_path : str or None
        If given, the total cost recorded at every objective evaluation is
        written to this Excel file. With ``None`` (the default) nothing is
        written.

    Returns
    -------
    ndarray, shape (len(frames) - 1, 2)
        The optimised shifts.
    """
    cost_history = []
    n_frames = len(frames)

    def objective_function(shifts):
        total_shifts = np.reshape(shifts, (-1, 2))
        costs = [
            calculate_frame_cost(i, total_shifts, tracked_positions) for i in range(1, n_frames)
        ]
        total_cost = np.sum(costs)
        cost_history.append(total_cost)
        return total_cost

    initial_shifts = np.zeros((n_frames - 1, 2))
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
    result = minimize(objective_function, initial_shifts.flatten(), method='L-BFGS-B')
    elapsed = time.perf_counter() - start_time
    optimized_shifts = np.reshape(result.x, (-1, 2))

    # Report how long the optimisation took
    print(f"Calculating optimal shifts took {elapsed:.2f} seconds.")

    # Save the history of the total cost only when a destination was supplied;
    # passing the default None on to to_excel would raise.
    if cost_output_path is not None:
        cost_df = pd.DataFrame(cost_history, columns=['Total Cost'])
        cost_df.to_excel(cost_output_path, index=False)

    return optimized_shifts

def stabilize_video(frames, shifts):
    """Translate every frame so that accumulated drift is removed.

    ``shifts[k]`` is the displacement between frame ``k`` and frame ``k + 1``
    (this is how ``calculate_shifts`` defines it). The drift of frame ``i``
    relative to frame 0 is therefore ``shifts[0] + ... + shifts[i - 1]``.
    Frame 0 is the reference and is left in place.

    Parameters
    ----------
    frames : sequence of ndarray
        Input frames.
    shifts : array-like, shape (len(frames) - 1, 2)
        Per-step ``(dx, dy)`` displacements.

    Returns
    -------
    list of ndarray
        The translated frames, same size as the inputs.
    """
    stabilized_frames = []
    drift = np.array([0.0, 0.0])  # accumulated displacement of the current frame vs. frame 0
    for i, frame in enumerate(frames):
        translation_matrix = np.array([[1, 0, -drift[0]], [0, 1, -drift[1]]], dtype=np.float32)
        stabilized_frame = cv2.warpAffine(frame, translation_matrix, (frame.shape[1], frame.shape[0]))
        stabilized_frames.append(stabilized_frame)
        # shifts[i] describes the move from frame i to frame i + 1, so it is
        # added only after frame i has been warped.
        if i < len(shifts):
            drift += shifts[i]
    return stabilized_frames

# Save frames as a video file
def save_video(frames, output_path, frame_rate):
    """Write a sequence of frames to ``output_path`` using the 'mp4v' codec.

    Parameters
    ----------
    frames : sequence of ndarray
        Frames to write; the output size is taken from ``frames[0]``, which
        must be three-dimensional (height, width, channels).
    output_path : str
        Destination file.
    frame_rate : float
        Frames per second of the output video.

    Raises
    ------
    IOError
        If the video writer could not be opened for ``output_path``.
    """
    height, width, _ = frames[0].shape
    video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
    try:
        # A writer that failed to open would otherwise discard every frame silently
        if not video.isOpened():
            raise IOError(f"Could not open video writer for {output_path}")
        for frame in frames:
            video.write(frame)
    finally:
        # Always release the writer so the file is finalised and the handle freed
        video.release()

## Main program

In [11]:
# Show the CPU count reported by the OS (informational; the value is not used by the other visible cells)
os.cpu_count()

24

In [12]:
# 处理和显示视频
file_path = 'input.mp4'  # 替换为你的视频文件路径
save_path = 'stabilized.mp4'  # 替换为你想保存的文件路径

roi_atom_path = 'roi_atom_positions.png'
atom_path = 'atom_positions.png'

trajectory_path = 'origin_atoms_trajectory.png'
stablized_trajectory_path = 'stablized_atoms_trajectory.png'

shift_record = 'shift.xlsx'

frames, frame_rate = load_video(file_path)

# %matplotlib notebook
%matplotlib nbagg

# 显示固定帧并选择ROI
frame_index = 0  # 替换为你希望显示的帧的索引
frame = frames[frame_index]
roi = select_roi_jupyter(frame)  # 用户选择的ROI区域

<IPython.core.display.Javascript object>

In [13]:
# Detect the atom positions in the initial frame (restricted to the ROI).
# NOTE(review): `roi` is expected to be the (x, y, w, h) tuple set by the
# `onselect` callback, so a rectangle must have been drawn before this cell runs.
initial_positions = process_frame(frame, roi, sigma=2)
print(f"Initial positions: {initial_positions}")

# Show and save the image of the atom positions inside the ROI
display_atom_positions(frame, initial_positions, roi, roi_atom_path)


# Show and save the image of all atom positions (no ROI filtering)
all_position = process_frame(frame, sigma=2)
display_atom_positions(frame, all_position, save_path = atom_path)

# test_pos = process_frame(frame)
# display_atom_positions(frame,test_pos)

Initial positions: [[ 94.68695  174.35635 ]
 [ 72.299255 175.22832 ]]


<IPython.core.display.Javascript object>

(64, 163, 39, 21)


<IPython.core.display.Javascript object>

None


In [14]:
# Compute search_radius as the mean nearest-neighbour distance of the initial atoms
search_radius = average_nearest_neighbor(initial_positions)
print(f"Search radius: {search_radius}")

# Track the atom positions through all frames.
# Half of the mean nearest-neighbour distance is used as the matching radius.
tracked_positions, lost_atoms_info = track_atoms(frames, initial_positions, sigma=2, 
                                                         search_radius=search_radius/2)
# print(f"Tracked positions: {tracked_positions}")
# print(f"Shifts: {shifts}")

# Report the atoms with a non-zero 'total_lost' count
# (the original comment refers to these as atoms that went out of bounds)
lost_atoms = {k: v for k, v in lost_atoms_info.items() if v['total_lost'] > 0}
print("Lost Atoms Information:")
for atom_idx, info in lost_atoms.items():
    print(f"Atom {atom_idx} was lost for {info['total_lost']} frames.")
    
# Split the per-atom trajectories into never-lost ("inside") and lost ("outside") atoms
inside_atoms = [tracked_positions[:, i] for i in range(tracked_positions.shape[1]) 
                if lost_atoms_info[i]['total_lost'] == 0]
outside_atoms = [tracked_positions[:, i] for i in range(tracked_positions.shape[1]) 
                 if lost_atoms_info[i]['total_lost'] > 0]

# Show (and save) the trajectories of the atoms that were never lost
plot_atom_trajectories(frame, inside_atoms, "Inside Bounds Atoms Trajectories", 
                      save_path = trajectory_path)

# Show the trajectories of the atoms that were lost at some point
plot_atom_trajectories(frame, outside_atoms, "Outside Bounds Atoms Trajectories")

Search radius: 22.40466964167001
Lost Atoms Information:


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
# Compute the optimal per-frame shifts; the cost history is written to `shift_record`
optimized_shifts = calculate_shifts(frames, tracked_positions,shift_record)
# print(f"Optimized shifts: {optimized_shifts}")

# Stabilize the video frames using the optimized shifts
stabilized_frames = stabilize_video(frames, optimized_shifts)

# Save the corrected video
save_video(stabilized_frames, save_path, frame_rate)

print(f"Stabilized video saved to {save_path}")

Calculating optimal shifts took 53.91 seconds.
Stabilized video saved to stabilized.mp4


In [8]:
# Re-detect and track the atom positions on the corrected (stabilized) frames,
# starting from the same initial positions and search radius as before
stabilized_tracked_positions, _ = track_atoms(stabilized_frames, initial_positions, 
                                                                 sigma=2, search_radius=search_radius/2)

# Split the trajectories into never-lost ("inside") and lost ("outside") atoms.
# NOTE(review): the split uses `lost_atoms_info` from the tracking of the
# original video; the info returned by the stabilized run is discarded (`_`).
# Presumably this keeps the same atoms in both plots — confirm it is intended.
inside_stabilized_atoms = [stabilized_tracked_positions[:, i] 
                           for i in range(stabilized_tracked_positions.shape[1]) 
                           if lost_atoms_info[i]['total_lost'] == 0]
outside_stabilized_atoms = [stabilized_tracked_positions[:, i] 
                            for i in range(stabilized_tracked_positions.shape[1]) 
                            if lost_atoms_info[i]['total_lost'] > 0]

# Show (and save) the trajectories of the "inside" atoms on the corrected video
plot_atom_trajectories(stabilized_frames[0], inside_stabilized_atoms, 
                       "Inside Bounds Atoms Trajectories on Stabilized Video", stablized_trajectory_path)

# Show the trajectories of the "outside" atoms on the corrected video
plot_atom_trajectories(stabilized_frames[0], outside_stabilized_atoms,
                       "Outside Bounds Atoms Trajectories on Stabilized Video")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>