In [14]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import cv2
# from evaluator import ModelEvaluator
from tqdm import tqdm
from torchsummary import summary
from fvcore.nn import FlopCountAnalysis, parameter_count
from ptflops import get_model_complexity_info
import time

In [None]:
# Device configuration: the models and batches in this notebook are moved to this device (CPU).
device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [17]:
# Teacher model definition
class MaskClassifier(nn.Module):
    """Teacher CNN that maps a batch of 3-channel images to 2 class logits.

    ``features`` stacks three stages of Conv3x3 -> BatchNorm -> ReLU ->
    MaxPool(2) -> Dropout2d(0.2) with 32, 64 and 128 output channels.
    ``classifier`` average-pools the last feature map down to 1x1, flattens it
    and applies one linear layer with 2 outputs.
    """

    # (in_channels, out_channels) of each convolutional stage, in order.
    _STAGE_CHANNELS = ((3, 32), (32, 64), (64, 128))

    def __init__(self):
        super().__init__()

        # All layers go into one flat list so that their positions inside
        # nn.Sequential (and therefore the state_dict keys) remain 0..14.
        stage_layers = []
        for in_ch, out_ch in self._STAGE_CHANNELS:
            stage_layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
                nn.Dropout2d(0.2),
            ]
        self.features = nn.Sequential(*stage_layers)

        # The head consumes the channel count produced by the last stage.
        head_in = self._STAGE_CHANNELS[-1][1]
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(head_in, 2),
        )

    def forward(self, x):
        """Run the feature extractor and then the head; returns the 2 logits per sample."""
        return self.classifier(self.features(x))

In [18]:
# Student model definition (about 93% fewer parameters than the teacher)
class SmallMaskClassifier(nn.Module):
    """Compact student CNN that maps a batch of 3-channel images to 2 class logits.

    ``features`` stacks three stages of Conv3x3 -> BatchNorm -> ReLU ->
    MaxPool(2) with 8, 16 and 32 output channels (no dropout).
    ``classifier`` average-pools the last feature map down to 1x1, flattens it
    and applies one linear layer with 2 outputs.
    """

    def __init__(self):
        super().__init__()

        # Channel widths: the input image followed by each stage's output.
        widths = (3, 8, 16, 32)

        # Stages are flattened into a single Sequential so layer indices
        # (and state_dict keys) run 0..11.
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.MaxPool2d(2, 2))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(widths[-1], 2),
        )

    def forward(self, x):
        """Run the feature extractor and then the head; returns the 2 logits per sample."""
        feature_map = self.features(x)
        return self.classifier(feature_map)



In [19]:
# Knowledge Distillation Loss
class DistillationLoss:
    """Weighted sum of a hard-label loss and a temperature-softened KL loss.

    loss = alpha * CE(student, labels)
           + (1 - alpha) * T^2 * KL(softmax(teacher / T) || softmax(student / T))

    Args:
        temperature: Divisor T applied to both sets of logits before softmax.
        alpha: Weight of the hard-label cross-entropy term; the soft term
            gets ``1 - alpha``.
    """

    def __init__(self, temperature=4.0, alpha=0.5):
        self.temperature = temperature
        self.alpha = alpha
        self.criterion = nn.CrossEntropyLoss()

    def __call__(self, student_outputs, teacher_outputs, labels):
        """Return the combined loss for one batch of student/teacher logits."""
        temp = self.temperature

        # Soft targets: F.kl_div expects log-probabilities as input and
        # probabilities as target.
        student_log_probs = F.log_softmax(student_outputs / temp, dim=1)
        teacher_probs = F.softmax(teacher_outputs / temp, dim=1)
        soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temp ** 2)

        # Hard targets: plain cross-entropy on the unscaled student logits.
        hard_term = self.criterion(student_outputs, labels)

        return self.alpha * hard_term + (1 - self.alpha) * soft_term

In [20]:
# Dataset root directory (relative path, used as the ImageFolder root)
data_dir = "data"

In [21]:
# Preprocessing: convert each image to a tensor, then normalize every channel
# with the ImageNet mean / standard deviation.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

In [22]:
# Load the dataset from data_dir with the preprocessing above and print its class names
dataset = ImageFolder(root=data_dir, transform=transform)
print(f"Classes: {dataset.classes}")

Classes: ['with_mask', 'without_mask']


In [23]:
# Split into Train:Val:Test = 70:15:15
# The split is drawn from a seeded generator so that the same samples land in
# the same subset on every run; without it the reported validation/test numbers
# change from run to run.
# NOTE(review): if the teacher checkpoint was trained on a differently drawn
# split, some of these test images may have been seen by the teacher - confirm.
SPLIT_SEED = 42
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
# The test set takes the remainder so the three sizes always sum to len(dataset).
test_size = len(dataset) - train_size - val_size
train_set, val_set, test_set = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [24]:
# Build the DataLoaders: only the training loader shuffles its samples
batch_size = 16
loader_kwargs = {"batch_size": batch_size}
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

In [25]:
# Test function
def test_model(model, test_loader, class_names=None):
    """Evaluate ``model`` on ``test_loader`` and print a classification report.

    Args:
        model (torch.nn.Module): Model to evaluate; it is switched to eval mode.
        test_loader (DataLoader): Loader yielding ``(images, labels)`` batches.
        class_names (list[str] | None): Display names indexed by class id.
            When omitted, the notebook-level ``dataset.classes`` is used, which
            matches the previous behaviour.

    Returns:
        None. The report is printed.
    """
    if class_names is None:
        # Backward-compatible fallback to the notebook-level dataset.
        class_names = dataset.classes
    # Passing the label ids explicitly keeps the report aligned with
    # class_names even when a class does not occur in this split; without it
    # classification_report rejects the mismatching target_names.
    class_ids = list(range(len(class_names)))

    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Testing", unit="batch")
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = index of the largest logit.
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

            # Accuracy of the current batch only, shown on the progress bar.
            progress_bar.set_postfix(batch_accuracy=(preds == labels).float().mean().item())

    print("\nTest Classification Report:")
    print(classification_report(all_labels, all_preds, labels=class_ids, target_names=class_names))

In [26]:
# Model saving function
def save_model(model, path="mask_classifier.pth"):
    """Save ``model.state_dict()`` to ``path`` and print a confirmation.

    Args:
        model (torch.nn.Module): Model whose weights are saved.
        path: Destination file. Missing parent directories are created.

    Note:
        The default path is the same file name this notebook loads the teacher
        checkpoint from, so pass an explicit path when saving any other model.
    """
    # torch.save does not create directories; make sure the parent exists.
    # Non-path targets (e.g. file-like objects) are passed through untouched.
    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

In [27]:
# Student model training function
def _evaluate_accuracy(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader`` in percent, or None if the loader is empty."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    if total == 0:
        return None
    return 100. * correct / total


def train_student_model(teacher_model, student_model, train_loader, val_loader, epochs=20):
    """Train ``student_model`` by knowledge distillation from ``teacher_model``.

    Every batch is passed through the teacher under ``torch.no_grad()`` and
    through the student; the student is optimized with Adam (lr=0.001,
    weight_decay=0.0001) on ``DistillationLoss(temperature=4.0, alpha=0.5)``.
    The running training accuracy of the current epoch is shown on the progress
    bar. After the last epoch the student's accuracy on ``val_loader`` is
    printed once.

    Args:
        teacher_model (torch.nn.Module): Source of the soft targets. Its
            parameters are not updated.
        student_model (torch.nn.Module): Model being trained.
        train_loader (DataLoader): Yields ``(images, labels)`` training batches.
        val_loader (DataLoader): Yields ``(images, labels)`` validation batches.
        epochs (int): Number of passes over ``train_loader``.

    Returns:
        None. Progress and the final validation accuracy are printed.

    Side effects:
        ``teacher_model`` and ``student_model`` are both left in eval mode.
    """
    optimizer = optim.Adam(student_model.parameters(), lr=0.001, weight_decay=0.0001)
    distill_loss = DistillationLoss(temperature=4.0, alpha=0.5)

    # The teacher must produce deterministic targets: in train mode any dropout
    # layer stays active and batch-norm running statistics are updated even
    # under no_grad. Do not rely on the caller having called eval().
    teacher_model.eval()

    for epoch in range(epochs):
        student_model.train()
        correct = 0
        total = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")

        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)

            # Teacher logits are targets only; no gradient is needed for them.
            with torch.no_grad():
                teacher_outputs = teacher_model(images)

            optimizer.zero_grad()
            student_outputs = student_model(images)

            loss = distill_loss(student_outputs, teacher_outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = student_outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Running accuracy over the batches seen so far in this epoch.
            accuracy = 100. * correct / total
            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{accuracy:.2f}%'
            })

    # Validation, once after the final epoch.
    final_acc = _evaluate_accuracy(student_model, val_loader)
    print("\n=== Final Training Results ===")
    if final_acc is None:
        # An empty validation loader previously raised ZeroDivisionError.
        print("Final Validation Accuracy: n/a (validation loader is empty)")
    else:
        print(f"Final Validation Accuracy: {final_acc:.2f}%")

In [28]:
# Load the saved teacher model
teacher_model = MaskClassifier().to(device)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed PyTorch supports it - TODO confirm).
teacher_model.load_state_dict(torch.load("mask_classifier.pth", map_location=device))
# Switch the teacher to eval mode; as the cell's last expression this also displays the module structure.
teacher_model.eval()

MaskClassifier(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Dropout2d(p=0.2, inplace=False)
    (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Dropout2d(p=0.2, inplace=False)
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Dropout2d(p=0.2, inplace=False)
 

In [29]:
# Initialize the student model (freshly constructed weights) on the configured device
student_model = SmallMaskClassifier().to(device)

In [30]:
# Compare parameter counts of the teacher and the student
teacher_params, student_params = (
    sum(param.numel() for param in net.parameters())
    for net in (teacher_model, student_model)
)
# Percentage of the teacher's parameters that the student does without.
reduction = 100 * (1 - student_params / teacher_params)

print("\nModel Parameters Comparison:")
print(f"Teacher Model: {teacher_params:,} parameters")
print(f"Student Model: {student_params:,} parameters")
print(f"Parameter Reduction: {reduction:.2f}%")


Model Parameters Comparison:
Teacher Model: 93,954 parameters
Student Model: 6,210 parameters
Parameter Reduction: 93.39%


In [31]:
# Run training: distill the teacher into the student for num_epoch epochs
num_epoch = 20
print("\nTraining Student Model with Knowledge Distillation...")
train_student_model(teacher_model, student_model, train_loader, val_loader, num_epoch)


Training Student Model with Knowledge Distillation...


Epoch 1/20: 100%|██████████| 144/144 [00:52<00:00,  2.76batch/s, loss=0.5105, acc=83.78%]
Epoch 2/20: 100%|██████████| 144/144 [00:51<00:00,  2.81batch/s, loss=0.7343, acc=87.66%]
Epoch 3/20: 100%|██████████| 144/144 [00:49<00:00,  2.89batch/s, loss=0.1657, acc=88.88%]
Epoch 4/20: 100%|██████████| 144/144 [00:49<00:00,  2.88batch/s, loss=0.6538, acc=90.19%]
Epoch 5/20: 100%|██████████| 144/144 [00:49<00:00,  2.88batch/s, loss=0.0969, acc=90.36%]
Epoch 6/20: 100%|██████████| 144/144 [00:52<00:00,  2.76batch/s, loss=0.0872, acc=91.67%]
Epoch 7/20: 100%|██████████| 144/144 [00:50<00:00,  2.83batch/s, loss=0.4927, acc=91.67%]
Epoch 8/20: 100%|██████████| 144/144 [00:49<00:00,  2.90batch/s, loss=0.3837, acc=91.98%]
Epoch 9/20: 100%|██████████| 144/144 [00:50<00:00,  2.86batch/s, loss=0.1875, acc=92.11%]
Epoch 10/20: 100%|██████████| 144/144 [00:50<00:00,  2.85batch/s, loss=0.1031, acc=92.67%]
Epoch 11/20: 100%|██████████| 144/144 [00:53<00:00,  2.71batch/s, loss=0.5745, acc=91.89%]
Epoch 12


=== Final Training Results ===
Final Validation Accuracy: 93.48%


In [32]:
# Run the test: evaluate the trained student on the test loader
print("\nEvaluating Student Model...")
test_model(student_model, test_loader)


Evaluating Student Model...


Testing: 100%|██████████| 31/31 [00:02<00:00, 11.18batch/s, batch_accuracy=0.917]



Test Classification Report:
              precision    recall  f1-score   support

   with_mask       0.98      0.95      0.96       249
without_mask       0.95      0.98      0.96       243

    accuracy                           0.96       492
   macro avg       0.96      0.96      0.96       492
weighted avg       0.96      0.96      0.96       492



In [34]:
# Save the student model (explicit path, so the save_model default "mask_classifier.pth" is not used)
save_model(student_model, "student_model.pth")

Model saved to student_model.pth
