In [1]:
import glob
import cv2
import torch
import torch.backends.cudnn as cudnn
from l2cs import select_device, Pipeline, render
import time, pickle
import pathlib
from tqdm import tqdm
import numpy as np

CWD = pathlib.Path.cwd()
path = 'C:/Users/wnsdh/Downloads/snuwet_test'
# Sorted so recordings are processed in a reproducible order
# (glob's result order depends on the filesystem).
webcam_records = sorted(glob.glob(path + '/webcam_record_*.mp4'))

# Initialise the gaze-tracking pipeline. Use CUDA when it is available and
# fall back to CPU instead of failing at start-up on a machine without a GPU.
gaze_pipeline = Pipeline(
    weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
    arch='ResNet50',
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)

BATCH_SIZE = 32  # number of frames passed to a single step_batch call


def _gaze_output_path(record_path):
    """Map ``<dir>/webcam_record_<id>.mp4`` to ``<dir>/gaze_data_<id>.pkl``.

    Only the file name is rewritten, so directory names that happen to
    contain ``.mp4`` or ``webcam_record_`` are left untouched.
    """
    record = pathlib.Path(record_path)
    new_name = record.with_suffix('.pkl').name.replace('webcam_record_', 'gaze_data_', 1)
    return record.with_name(new_name)


def _detection_rate(detected, total):
    """Return ``detected`` as a percentage of ``total`` (0.0 when total is 0)."""
    return detected / max(total, 1) * 100


def _run_batch(pipeline, frames, timestamps, gaze_data, record_path):
    """Run ``pipeline.step_batch`` on buffered frames and record the results.

    Appends one ``(timestamp, results)`` tuple per frame to ``gaze_data``.
    A failing batch is reported and skipped (best-effort), so one bad batch
    does not abort processing of the whole recording.

    Returns:
        The number of appended results that count as detections, i.e. whose
        ``pitch`` array is non-empty.
    """
    try:
        results_batch = pipeline.step_batch(np.stack(frames))
    except Exception as e:
        print(f'{record_path}: skipped {len(frames)} frames '
              f'(t={timestamps[0]:.3f}s-{timestamps[-1]:.3f}s): {e!r}')
        return 0
    detected = 0
    for timestamp, results in zip(timestamps, results_batch):
        gaze_data.append((timestamp, results))
        if results.pitch.size > 0:
            detected += 1
    return detected


for webcam_record in webcam_records:
    print(webcam_record)

    gazeData = []  # list of (timestamp in seconds, pipeline result) per frame
    frame_buffer = []  # frames waiting to be processed as one batch
    timestamp_buffer = []  # timestamps matching frame_buffer
    detected_count = 0  # running count of detections, avoids rescanning gazeData

    cap = cv2.VideoCapture(webcam_record)
    if not cap.isOpened():
        print(f'{webcam_record}: could not open video, skipping')
        cap.release()
        continue

    try:
        # CAP_PROP_FRAME_COUNT is only an estimate for many containers, so it
        # sizes the progress bar but does not bound the read loop.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with torch.no_grad(), tqdm(total=total_frames if total_frames > 0 else None) as pbar:
            while True:
                success, frame = cap.read()
                if success:
                    # CAP_PROP_POS_MSEC after the read, converted to seconds.
                    timestamp_buffer.append(cap.get(cv2.CAP_PROP_POS_MSEC) / 1000)
                    frame_buffer.append(frame)
                    pbar.update(1)

                # Process when the batch is full, or flush the final partial
                # batch once the stream is exhausted.
                if frame_buffer and (not success or len(frame_buffer) >= BATCH_SIZE):
                    detected_count += _run_batch(
                        gaze_pipeline, frame_buffer, timestamp_buffer, gazeData, webcam_record
                    )
                    frame_buffer = []
                    timestamp_buffer = []
                    # Update the detection-rate display (only changes per batch).
                    rate = _detection_rate(detected_count, len(gazeData))
                    pbar.set_description(f"검출률: {rate:.1f}%")

                if not success:
                    break
    finally:
        cap.release()

    with open(_gaze_output_path(webcam_record), 'wb') as f:
        pickle.dump(gazeData, f)

  self.model.load_state_dict(torch.load(self.weights, map_location=device))


C:/Users/wnsdh/Downloads/snuwet_test\webcam_record_250417-2235.mp4


검출률: 99.8%:   1%|▏         | 543/38333 [00:07<08:20, 75.46it/s]


KeyboardInterrupt: 