In [None]:
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image
import cv2

In [None]:
# 1. Device setup: run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class MaskClassifier(nn.Module):
    """Small three-stage CNN that produces two class scores per input image.

    ``features`` is a flat ``nn.Sequential`` of three identical stages
    (Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2) -> Dropout2d(0.2)) that widen
    the channels 3 -> 32 -> 64 -> 128. ``classifier`` global-average-pools the
    feature map, flattens it, and applies a single ``Linear(128, 2)`` layer.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) for each convolutional stage, in order.
        stage_channels = [(3, 32), (32, 64), (64, 128)]

        # Feature extraction: collect the layers of every stage into one list
        # so the resulting Sequential stays flat (layer indices 0..14), which
        # keeps the parameter names in the state_dict unchanged.
        stage_layers = []
        for in_ch, out_ch in stage_channels:
            stage_layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
                nn.Dropout2d(0.2),
            ])
        self.features = nn.Sequential(*stage_layers)

        # Classifier head: global average pooling followed by one linear layer.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, 2),
        )

    def forward(self, x):
        """Return the two class scores for a batch of 3-channel images ``x``."""
        return self.classifier(self.features(x))

In [None]:
# Load the model: build the network on the target device, restore the saved
# weights, then switch to evaluation mode.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model = MaskClassifier()
model = model.to(device)
state_dict = torch.load("mask_classifier.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

In [None]:
# 3. Image preprocessing
# Per-channel normalization statistics (ImageNet stats)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Convert the image to a tensor, then normalize each channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
# 4. Real-time inference using the Jetson Nano CSI camera
def infer_csi_camera(model):
    """
    Run live mask classification on frames from the Jetson Nano CSI camera.

    Each captured frame is resized to 112x112, converted from OpenCV's BGR
    channel order to RGB, passed through the module-level ``transform`` and
    classified by ``model`` on the module-level ``device``. The predicted
    label is drawn on the full-size frame, which is shown in an OpenCV window.
    The loop ends when 'q' is pressed or a frame cannot be read. The camera
    and the window are released even if the loop is interrupted by an
    exception (e.g. a kernel interrupt).

    Args:
        model: trained PyTorch model that returns one score per class for a
            batch of images. Class index 0 is displayed as "With Mask", any
            other index as "Without Mask". It is expected to already be in
            eval mode and on ``device``.

    Returns:
        None. If the camera cannot be opened, a message is printed and the
        function returns immediately.
    """
    # GStreamer pipeline definition (width=640, height=480); the final caps
    # make appsink deliver frames in BGR channel order.
    gst_pipeline = (
        "nvarguscamerasrc ! "
        "video/x-raw(memory:NVMM), width=640, height=480, format=(string)NV12, framerate=30/1 ! "
        "nvvidconv flip-method=0 ! "
        "video/x-raw, width=640, height=480, format=(string)BGRx ! "
        "videoconvert ! "
        "video/x-raw, format=(string)BGR ! appsink"
    )

    cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER)
    if not cap.isOpened():
        print("CSI 카메라를 열 수 없습니다.")
        return

    print("Press 'q' to quit.")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("카메라 프레임을 읽을 수 없습니다.")
                break

            # Preprocess the whole frame: shrink to the model input size.
            resized_frame = cv2.resize(frame, (112, 112))
            # The normalization statistics in ``transform`` are in RGB order,
            # while OpenCV frames are BGR, so swap the channels first.
            # NOTE(review): assumes the model was trained on RGB images
            # (e.g. loaded through PIL) - confirm against the training code.
            rgb_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
            input_tensor = transform(rgb_frame).unsqueeze(0).to(device)

            # Model inference (no gradients needed).
            with torch.no_grad():
                output = model(input_tensor)
                _, pred = torch.max(output, 1)

            # Show the prediction; read the class index from the tensor once.
            has_mask = pred.item() == 0
            label = "With Mask" if has_mask else "Without Mask"
            color = (0, 255, 0) if has_mask else (0, 0, 255)  # Green: Mask, Red: No Mask

            cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
            cv2.imshow("CSI Camera Inference", frame)

            # Exit condition
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the camera and close the window, otherwise the CSI
        # device can stay locked until the kernel is restarted.
        cap.release()
        cv2.destroyAllWindows()

In [None]:
# 5. Run real-time inference with the loaded model
infer_csi_camera(model=model)