In [11]:
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image
import cv2
import time

In [2]:
class MaskClassifier(nn.Module):
    """Compact CNN that maps a 3-channel image batch to 2 class logits.

    Layout (the module indices matter: they must line up with the keys of a
    saved ``state_dict``, e.g. ``features.0.weight`` ... ``classifier.2.bias``):

    * ``features`` - three identical stages (a deliberately shallow design),
      each Conv3x3(padding=1) -> BatchNorm -> ReLU -> MaxPool(2) -> Dropout2d(0.2),
      widening the channels 3 -> 32 -> 64 -> 128.
    * ``classifier`` - global average pool -> flatten -> Linear(128, 2)
      (a deliberately simple head). The adaptive pool makes the head
      independent of the input's spatial size.
    """

    def __init__(self):
        super().__init__()

        stages = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            stages.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
                nn.Dropout2d(0.2),
            ])
            in_channels = out_channels
        # One flat Sequential -> submodules are indexed 0..14, same as before.
        self.features = nn.Sequential(*stages)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, 2),
        )

    def forward(self, x):
        """Return raw (un-normalised) logits of shape (N, 2) for x of shape (N, 3, H, W)."""
        return self.classifier(self.features(x))

In [3]:
# 1. Build the model on the best available device (CUDA GPU if present, else CPU).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = MaskClassifier()
model = model.to(device)  # nn.Module.to moves parameters and returns the module itself

In [4]:
# Load the trained weights and switch the model to inference mode.
# weights_only=True restricts torch.load's unpickler to tensors and primitive
# containers, so a tampered checkpoint cannot execute arbitrary code on load.
# That is all a state_dict needs (the argument exists since PyTorch 1.13).
model.load_state_dict(
    torch.load("mask_classifier.pth", map_location=device, weights_only=True)
)
# eval() switches BatchNorm2d/Dropout2d to inference behaviour; as the cell's
# last expression it also displays the model repr.
model.eval()

MaskClassifier(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Dropout2d(p=0.2, inplace=False)
    (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Dropout2d(p=0.2, inplace=False)
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Dropout2d(p=0.2, inplace=False)
 

In [None]:
# 3. Image preprocessing
# PIL image -> tensor, then per-channel normalisation (x - mean) / std using
# ImageNet statistics. Images are converted to RGB before this, so the three
# entries are in R, G, B order.
# NOTE(review): there is no Resize/crop step, so inputs keep their native size
# (the model's AdaptiveAvgPool2d head tolerates this). Confirm this matches
# the preprocessing used when the checkpoint was trained.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet per-channel mean (R, G, B)
            std=[0.229, 0.224, 0.225],   # ImageNet per-channel std (R, G, B)
        ),
    ]
)

In [6]:
# 4. Single-image inference helper
def infer_image(image_path):
    """Classify one image file as masked / unmasked and print the result.

    Relies on the notebook-level ``model``, ``transform`` and ``device``
    defined in earlier cells.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        str: The predicted label, ``"With Mask"`` or ``"Without Mask"``.
    """
    # Open in a context manager so the file handle is always closed.
    # convert() returns a new, fully loaded image, so it stays valid after the with-block.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Preprocess and add a leading batch dimension.
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
        pred = output.argmax(dim=1)  # index of the highest logit per sample

    # NOTE(review): the original marks these names as placeholders. The
    # index -> name order must match the class order used at training time; confirm.
    dataset_classes = ["With Mask", "Without Mask"]
    label = dataset_classes[pred.item()]
    print(f"Prediction: {label}")
    return label

In [None]:
# # 테스트용 단일 이미지 추론
# image_path = "data/with_mask/example.jpg"  # 테스트 이미지 경로
# infer_image(image_path)

In [9]:
# 5. Real-time webcam inference
def infer_webcam(camera_index=0):
    """Classify live webcam frames until 'q' is pressed or the stream ends.

    Each frame is mirrored horizontally, classified with the notebook-level
    ``model`` / ``transform`` / ``device``, annotated with the predicted label
    and shown in an OpenCV window.

    Args:
        camera_index: Device index passed to ``cv2.VideoCapture``
            (default 0, the value previously hard-coded).
    """
    # NOTE(review): index -> name order must match the training class order; confirm.
    dataset_classes = ["With Mask", "Without Mask"]  # constant, so built once, not per frame

    cap = cv2.VideoCapture(camera_index)
    # try/finally ensures the camera is released and windows are closed even if
    # the loop raises, e.g. KeyboardInterrupt when the kernel is interrupted.
    try:
        if not cap.isOpened():
            print(f"Could not open camera {camera_index}.")
            return

        print("Press 'q' to quit.")
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Mirror the image horizontally so it behaves like a mirror.
            frame = cv2.flip(frame, 1)

            # OpenCV delivers BGR; the model pipeline expects an RGB PIL image.
            image_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            input_tensor = transform(image_pil).unsqueeze(0).to(device)

            with torch.no_grad():
                pred = model(input_tensor).argmax(dim=1)

            label = dataset_classes[pred.item()]
            cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.imshow("Webcam Inference", frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()

In [10]:
# Run real-time webcam inference; press 'q' to stop (checked via cv2.waitKey in infer_webcam).
infer_webcam()

Press 'q' to quit.
