In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0
from torchvision.models import EfficientNet_B0_Weights
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

class MvtecADDataset(Dataset):
    """MVTec-AD-style dataset yielding images, pixel masks and binary labels.

    Directory layout read by ``__init__``::

        <root_dir>/<object_type>/train/good/<image>
        <root_dir>/<object_type>/test/<defect_type>/<image>
        <root_dir>/<object_type>/ground_truth/<defect_type>/<mask>

    Args:
        root_dir: Folder containing one sub-folder per object type. Entries
            that are not directories are skipped.
        split: ``"train"`` collects only ``train/good`` images (label 0, no
            mask). ``"test"`` collects every ``test/<defect_type>`` folder;
            ``good`` gets label 0, every other defect type gets label 1.
            Any other value produces an empty dataset.
        transform: Optional callable applied to the image and, when a mask
            file exists, to the mask as well.

    ``image_paths``, ``mask_paths`` and ``labels`` are parallel lists: entry
    ``i`` of each describes sample ``i``.
    """

    def __init__(self, root_dir, split="train", transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        self.image_paths = []
        self.mask_paths = []
        self.labels = []
        self.images = []  # kept for compatibility; never populated by this class

        # sorted() makes the sample order independent of the filesystem's
        # directory listing order, so runs are reproducible.
        for object_type in tqdm(sorted(os.listdir(root_dir))):
            object_path = os.path.join(root_dir, object_type)
            if not os.path.isdir(object_path):
                continue

            if split == "train":
                train_dir = os.path.join(object_path, "train", "good")
                new_paths = [
                    os.path.join(train_dir, img_name)
                    for img_name in sorted(os.listdir(train_dir))
                ]
                self.image_paths.extend(new_paths)
                # Extend by the number of images just added, not by the
                # cumulative total, so the three lists stay the same length.
                self.labels.extend([0] * len(new_paths))
                self.mask_paths.extend([None] * len(new_paths))

            elif split == "test":
                test_dir = os.path.join(object_path, "test")
                ground_truth_dir = os.path.join(object_path, "ground_truth")

                for defect_type in sorted(os.listdir(test_dir)):
                    defect_dir = os.path.join(test_dir, defect_type)
                    if not os.path.isdir(defect_dir):
                        continue  # ignore stray files next to the defect folders
                    for img_name in sorted(os.listdir(defect_dir)):
                        self.image_paths.append(os.path.join(defect_dir, img_name))

                        if defect_type == "good":
                            self.labels.append(0)
                            self.mask_paths.append(None)
                        else:
                            self.labels.append(1)
                            self.mask_paths.append(
                                self._find_mask_path(ground_truth_dir, defect_type, img_name)
                            )

    @staticmethod
    def _find_mask_path(ground_truth_dir, defect_type, img_name):
        """Return the ground-truth mask path for a defect image, or None if no file exists."""
        stem, ext = os.path.splitext(img_name)
        # Try the image's own file name first (the original lookup), then the
        # "<stem>_mask<ext>" naming used by MVTec AD ground-truth files.
        for candidate in (img_name, f"{stem}_mask{ext}"):
            mask_path = os.path.join(ground_truth_dir, defect_type, candidate)
            if os.path.exists(mask_path):
                return mask_path
        return None

    @staticmethod
    def _empty_mask_like(image):
        """Return an all-zero single-channel mask with the same spatial size as ``image``."""
        if isinstance(image, torch.Tensor):
            # Tensor images are (C, H, W): the last two dims are height, width.
            height, width = image.shape[-2:]
            return torch.zeros((1, height, width), dtype=torch.float32)
        # No tensor transform was applied, so the image is still a PIL image;
        # mirror the PIL "L" masks that are loaded from disk.
        return Image.new("L", image.size, 0)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask, label, original_img_size, img_path)`` for sample ``idx``.

        ``original_img_size`` is PIL's ``Image.size`` of the untransformed
        image, i.e. ``(width, height)``.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        image = Image.open(img_path).convert("RGB")
        original_img_size = image.size
        if self.transform:
            image = self.transform(image)

        mask_path = self.mask_paths[idx]
        if mask_path is not None:
            mask = Image.open(mask_path).convert("L")
            if self.transform:
                mask = self.transform(mask)
        else:
            # No mask on disk: use zeros with the same spatial size as the image.
            mask = self._empty_mask_like(image)

        return image, mask, label, original_img_size, img_path


  from tqdm.autonotebook import tqdm


In [43]:
# EfficientNet-based autoencoder definition
class EfficientNetB0Autoencoder(nn.Module):
    """Autoencoder with an EfficientNet-B0 feature extractor and a transposed-conv decoder.

    ``forward`` runs only ``self.encoder.features`` and feeds its output to
    ``self.decoder``. The decoder expects 1280 input channels and consists of
    five ``ConvTranspose2d(kernel_size=4, stride=2, padding=1)`` layers, each
    of which doubles the spatial size (32x in total), ending in a 3-channel
    ``Sigmoid`` output.
    """

    def __init__(self):
        super().__init__()
        # EfficientNet-B0 is used as the encoder. `weights` is passed by
        # keyword: torchvision deprecated passing it positionally.
        # The whole network is kept as `self.encoder` (not just `.features`)
        # so state-dict keys stay compatible with previously saved checkpoints.
        self.encoder = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)

        # Decoder definition
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1280, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # keeps output pixel values in the [0, 1] range
        )

    def forward(self, x):
        """Encode ``x`` with the EfficientNet feature layers and decode it back to an image."""
        # Feature extraction through the encoder
        x = self.encoder.features(x)
        # Reconstruction through the decoder
        x = self.decoder(x)
        return x

In [48]:

# Data preprocessing and data-loader setup
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Both splits are read from the same root folder with the same preprocessing.
train_dataset = MvtecADDataset("mvtec_anomaly_detection", split="train", transform=transform)
test_dataset = MvtecADDataset("mvtec_anomaly_detection", split="test", transform=transform)

# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=16)

# Training setup: use the GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = EfficientNetB0Autoencoder()
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

  0%|          | 0/17 [00:00<?, ?it/s]

  0%|          | 0/17 [00:00<?, ?it/s]

In [49]:
# Training loop: the model learns to reconstruct its own input, so the loss
# compares each output batch with the input batch itself.
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader):
        # Only the image tensor (first element of each batch) is used for training.
        images = batch[0].to(device)
        outputs = model(images)
        loss = criterion(outputs, images)

        # Clear old gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    # Report the mean per-batch loss of this epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss / len(train_loader):.4f}")


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [1/20], Loss: 0.0180


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [2/20], Loss: 0.0081


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [3/20], Loss: 0.0065


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [4/20], Loss: 0.0067


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [5/20], Loss: 0.0058


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [6/20], Loss: 0.0055


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [7/20], Loss: 0.0053


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [8/20], Loss: 0.0051


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [9/20], Loss: 0.0050


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [10/20], Loss: 0.0042


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [11/20], Loss: 0.0040


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [12/20], Loss: 0.0039


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [13/20], Loss: 0.0036


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [14/20], Loss: 0.0035


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [15/20], Loss: 0.0033


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [16/20], Loss: 0.0031


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [17/20], Loss: 0.0030


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [18/20], Loss: 0.0036


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [19/20], Loss: 0.0034


  0%|          | 0/227 [00:00<?, ?it/s]

Epoch [20/20], Loss: 0.0030


In [50]:
# Ask PyTorch to release cached GPU memory it is no longer using
# (this cell sits between the training cell and the inference cells).
torch.cuda.empty_cache()

In [51]:
import os
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from PIL import Image
import tiffile as tiff

# Top-level folder for saved anomaly maps, named after the model's class
root_anomaly_map_dir = "anomaly_maps/" + type(model).__name__

def save_anomaly_map(anomaly_map, image_path, anomaly_root_dir, dataset_root='mvtec_anomaly_detection'):
    """
    Save an anomaly map as a TIFF file, mirroring the source image's folder layout.

    The image's path relative to ``dataset_root`` is re-created under
    ``anomaly_root_dir`` with the file extension replaced by ``.tiff``.
    Missing parent folders are created.

    Args:
        anomaly_map (ndarray or Tensor): anomaly map to write. A torch tensor
            is detached, moved to the CPU and converted to a NumPy array first.
        image_path (str): path of the source image.
        anomaly_root_dir (str): top-level output folder for anomaly maps.
        dataset_root (str): dataset folder that ``image_path`` is made
            relative to. Defaults to the folder name used in this notebook.
    """
    # Path of the image relative to the dataset root, with a .tiff extension
    relative_path = os.path.relpath(image_path, start=dataset_root)
    relative_path = os.path.splitext(relative_path)[0] + '.tiff'
    save_path = os.path.join(anomaly_root_dir, relative_path)

    parent_dir = os.path.dirname(save_path)
    if parent_dir:  # os.makedirs('') raises, so skip when saving into the CWD
        os.makedirs(parent_dir, exist_ok=True)

    if isinstance(anomaly_map, torch.Tensor):
        anomaly_map = anomaly_map.detach().cpu().numpy()

    # Write the anomaly map
    tiff.imwrite(save_path, anomaly_map)

# 테스트 시 anomaly map 저장
def test_and_save_anomaly_maps(model, test_loader, device, root_anomaly_map_dir):
    model.eval()
    with torch.no_grad():
        for batch_idx, (images, _, _, original_image_size, image_paths) in enumerate(tqdm(test_loader)):
            images = images.to(device)
            outputs = model(images)

            # 재구성 오차 기반 anomaly map 생성
            anomaly_maps = F.mse_loss(outputs, images, reduction='none').mean(dim=1, keepdim=True)

            # 배치의 각 이미지에 대해 anomaly map 저장 
            for i in range(images.size(0)):
                image_path = image_paths[i]

                # original_image_size should be in the format (height, width)
                height, width = original_image_size[0][i].item(), original_image_size[1][i].item()
                
                # Use 'bilinear' for 2D data
                # anomaly_map = F.interpolate(anomaly_maps[i].unsqueeze(0), size=(height, width), mode='bilinear', align_corners=False)
                
                # Convert anomaly_map to (H, W, C) format for saving as an image
                anomaly_map = anomaly_maps[i].squeeze().cpu().numpy()
                
                save_anomaly_map(anomaly_map, image_path, root_anomaly_map_dir)


In [52]:
# Run inference over the test loader and write one anomaly map per test image
# under root_anomaly_map_dir.
test_and_save_anomaly_maps(model, test_loader, device, root_anomaly_map_dir)

  0%|          | 0/108 [00:00<?, ?it/s]

In [53]:
# Persist the trained weights. The target folder is created first because
# torch.save fails when the parent directory does not exist.
os.makedirs('model', exist_ok=True)
torch.save(model.state_dict(), f'model/{model.__class__.__name__}')

In [3]:
def load(model_class, state_dict_path, map_location="cpu"):
    """Instantiate ``model_class`` and load a saved state dict into it.

    Args:
        model_class: zero-argument callable that returns a fresh model.
        state_dict_path: path of a file written with
            ``torch.save(model.state_dict(), ...)``.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from a GPU also loads on a machine without one;
            the values are copied into the freshly built model's parameters.

    Returns:
        The model with the loaded weights. Its train/eval mode is left as
        ``model_class()`` created it.
    """
    model = model_class()
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(state_dict_path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    return model

In [7]:
# Reload the trained autoencoder. The class name and checkpoint path match the
# class defined in this notebook and the file its torch.save cell writes
# (f'model/{model.__class__.__name__}').
model = load(EfficientNetB0Autoencoder, 'model/EfficientNetB0Autoencoder')

In [10]:
# Move the reloaded model onto `device`, which is defined in the setup cell.
model = model.to(device)