In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [3]:
from dataset import PneumothoraxDataset, get_train_transforms, get_val_transforms
from utils import predict_with_tiles, dice_coefficient, iou_score
from models.unet import UNet
from dataset import build_file_paths_dict

In [4]:
class Config:
    """Experiment settings shared by every cell of the notebook.

    All values are plain class attributes, so they can be read either from
    the class itself or from the ``config`` instance created below.
    """

    # Paths
    dicom_path = Path('./pneumothorax_data/dicom-images-train')
    rle_path = Path('./pneumothorax_data/train-rle.csv')
    save_dir = Path('./results')

    # Model
    model_name = 'unet'  # change the model name here
    in_channels = 3
    n_classes = 1

    # Training
    batch_size = 2  # edited
    num_epochs = 1  # edited
    learning_rate = 1e-4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Augmentation
    use_mirroring = False  # edited
    pad_size = 0  # 92

    # Tiled inference
    use_tile_strategy = True
    tile_size = 128  # 512
    tile_overlap = 92


config = Config()
# parents=True keeps this working if save_dir is changed to a nested path
# whose parent directories do not exist yet.
config.save_dir.mkdir(parents=True, exist_ok=True)

In [5]:
# Index the DICOM files under config.dicom_path once; the result is handed to
# both datasets below as `file_paths_dict`.
# NOTE(review): presumably maps image id -> file path — confirm in dataset.py.
file_paths_dict = build_file_paths_dict(config.dicom_path)

전체 DICOM 파일 수: 10712


인덱싱: 100%|██████████| 10712/10712 [00:00<00:00, 1016639.92it/s]


In [None]:
# Annotation table: one row per (ImageId, encoded mask) pair.
train_rle = pd.read_csv(config.rle_path)

# Debug subset: keep only the first 100 distinct image ids, then shuffle them
# in place with a fixed NumPy seed so the split is repeatable.
all_image_ids = train_rle['ImageId'].unique()[:100]
np.random.seed(42)
np.random.shuffle(all_image_ids)

# 80/20 train/validation split over the shuffled ids.
split_idx = int(len(all_image_ids) * 0.8)
train_ids, val_ids = all_image_ids[:split_idx], all_image_ids[split_idx:]

print(f"Train: {len(train_ids)}, Val: {len(val_ids)}")

Train: 80, Val: 20


In [7]:
# Training split. Mirroring is taken from Config so that `use_mirroring` and
# `pad_size` can never disagree (padding is only requested when mirroring is on).
train_dataset = PneumothoraxDataset(
    image_ids=train_ids,
    train_rle=train_rle,
    file_paths_dict=file_paths_dict,
    transform=get_train_transforms(),
    use_mirroring=config.use_mirroring,
    pad_size=config.pad_size if config.use_mirroring else 0
)

# Validation split: never mirrored or padded.
val_dataset = PneumothoraxDataset(
    image_ids=val_ids,
    train_rle=train_rle,
    file_paths_dict=file_paths_dict,
    transform=get_val_transforms(),
    use_mirroring=False,
    pad_size=0
)

# Settings shared by both loaders. Pinned host memory only helps host-to-GPU
# copies, so it is requested only when training on CUDA.
loader_kwargs = dict(
    batch_size=config.batch_size,
    num_workers=0,
    pin_memory=(config.device == 'cuda'),
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

사용 가능한 이미지: 80개

사용 가능한 이미지: 20개



In [8]:
print(f"모델 명 : {config.model_name}")
model = UNet(in_channels=config.in_channels, n_classes=config.n_classes)
model = model.to(config.device)

# Loss is computed on raw logits; sigmoid is applied separately for metrics.
criterion = nn.BCEWithLogitsLoss()

# Pass the configured learning rate explicitly: without `lr`, SGD either falls
# back to its library default or (on older PyTorch releases) refuses to
# construct, and Config.learning_rate would be silently ignored.
# optim.Adam(model.parameters(), lr=config.learning_rate) is a drop-in alternative.
optimizer = optim.SGD(model.parameters(), lr=config.learning_rate, momentum=0.99)

# Lower the learning rate when the monitored value (validation loss, see the
# training loop) has stopped decreasing for `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)


모델 명 : unet


In [9]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; put into training mode here.
        loader: Iterable of ``(images, masks)`` batches that supports ``len()``.
        criterion: Loss callable applied to the raw model outputs and the masks.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_dice)``: the per-batch loss and Dice score, each
        averaged over the number of batches, as Python floats.
    """
    model.train()
    running_loss = 0.0
    running_dice = 0.0

    with tqdm(loader, desc="Training") as pbar:
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)

            # Forward / backward / update
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            # Metrics on thresholded probabilities. float() turns the score
            # into a plain number whether the helper returns a float or a
            # 0-d tensor, so the running sum never holds a device tensor.
            with torch.no_grad():
                preds = torch.sigmoid(outputs)
                dice = float(dice_coefficient(preds > 0.5, masks))

            batch_loss = loss.item()
            running_loss += batch_loss
            running_dice += dice

            pbar.set_postfix({'loss': batch_loss, 'dice': dice})

    n_batches = len(loader)
    epoch_loss = running_loss / n_batches
    epoch_dice = running_dice / n_batches

    return epoch_loss, epoch_dice

In [10]:
def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; put into eval mode here.
        loader: Iterable of ``(images, masks)`` batches that supports ``len()``.
        criterion: Loss callable applied to the raw model outputs and the masks.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_dice, epoch_iou)``: the per-batch loss, Dice and
        IoU, each averaged over the number of batches, as Python floats.
    """
    model.eval()
    running_loss = 0.0
    running_dice = 0.0
    running_iou = 0.0

    with torch.no_grad(), tqdm(loader, desc="Validation") as pbar:
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, masks).item()

            # Both metrics are computed on the same thresholded prediction.
            # float() turns each score into a plain number whether the helper
            # returns a float or a 0-d tensor.
            binary_preds = torch.sigmoid(outputs) > 0.5
            dice = float(dice_coefficient(binary_preds, masks))
            iou = float(iou_score(binary_preds, masks))

            running_loss += batch_loss
            running_dice += dice
            running_iou += iou

            pbar.set_postfix({'loss': batch_loss, 'dice': dice, 'iou': iou})

    n_batches = len(loader)
    epoch_loss = running_loss / n_batches
    epoch_dice = running_dice / n_batches
    epoch_iou = running_iou / n_batches

    return epoch_loss, epoch_dice, epoch_iou

In [None]:
# Start below any reachable score so the first epoch always writes a
# checkpoint. With an initial value of 0.0 and a strict ">" comparison, a run
# whose validation Dice stays at 0 would never save a model, and the inference
# cell would then fail trying to load a file that does not exist.
best_dice = float('-inf')
history = {
    'train_loss': [], 'train_dice': [],
    'val_loss': [], 'val_dice': [], 'val_iou': []
}

for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch+1}/{config.num_epochs}")

    train_loss, train_dice = train_epoch(model, train_loader, criterion, optimizer, config.device)
    val_loss, val_dice, val_iou = validate_epoch(model, val_loader, criterion, config.device)

    # The plateau scheduler monitors the validation loss.
    scheduler.step(val_loss)

    history['train_loss'].append(train_loss)
    history['train_dice'].append(train_dice)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)
    history['val_iou'].append(val_iou)

    print(f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, Val IoU: {val_iou:.4f}")

    # Keep only the weights with the best validation Dice seen so far.
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), config.save_dir / f'best_{config.model_name}.pth')
        print(f"best model saved. (Dice: {best_dice:.4f})")


Epoch 1/1


Training:   0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# One entry per panel: (title, y-axis label, [(history key, legend label), ...]),
# with curves listed in drawing order.
panel_specs = [
    ('Loss', 'Loss', [
        ('train_loss', 'Train Loss'),
        ('val_loss', 'Val Loss'),
    ]),
    ('Metrics', 'Score', [
        ('train_dice', 'Train Dice'),
        ('val_dice', 'Val Dice'),
        ('val_iou', 'Val IoU'),
    ]),
]

for panel_ax, (panel_title, panel_ylabel, panel_curves) in zip(axes, panel_specs):
    for history_key, curve_label in panel_curves:
        panel_ax.plot(history[history_key], label=curve_label)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.legend()
    panel_ax.set_title(panel_title)

# Persist the figure next to the checkpoints before showing it inline.
plt.tight_layout()
plt.savefig(config.save_dir / f'{config.model_name}_history.png')
plt.show()

In [None]:
# Imported once per cell run instead of on every loop iteration.
import pydicom
from dataset import rle_decode

print("\n테스트 추론 중...")
# map_location lets a checkpoint written on a GPU machine be restored when
# only a CPU is available.
model.load_state_dict(
    torch.load(config.save_dir / f'best_{config.model_name}.pth', map_location=config.device)
)
model.eval()

# Predict a handful of samples drawn from the validation set.
test_samples = val_ids[:5]  # only 5 samples

for img_id in test_samples:
    # Load the image and scale it to [0, 1] (left untouched if it is all zeros).
    file_path = val_dataset.file_paths[img_id]
    dcm = pydicom.dcmread(file_path)
    image = dcm.pixel_array

    if image.max() > 0:
        image = image.astype(np.float32) / image.max()
    else:
        image = image.astype(np.float32)

    # Replicate the single channel three times to form an H x W x 3 array.
    image = np.stack([image, image, image], axis=-1)

    # Predict
    if config.use_tile_strategy:
        prediction = predict_with_tiles(
            model=model,
            image=image,
            tile_size=config.tile_size,
            overlap=config.tile_overlap,
            device=config.device
        )
    else:
        # Simple whole-image inference with per-channel mean/std normalisation.
        image_norm = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
        image_tensor = torch.from_numpy(image_norm).permute(2, 0, 1).unsqueeze(0).float().to(config.device)

        with torch.no_grad():
            output = model(image_tensor)
            prediction = torch.sigmoid(output).cpu().numpy()[0, 0]

    # Threshold
    binary_pred = (prediction > 0.5).astype(np.uint8)

    # Ground truth: an image may have several annotation rows, so merge every
    # decoded mask rather than using only the first row. Rows that are missing
    # or hold the "-1" marker (with or without surrounding whitespace)
    # contribute nothing.
    gt_mask = np.zeros((1024, 1024), dtype=np.uint8)
    for rle in train_rle.loc[train_rle['ImageId'] == img_id, ' EncodedPixels'].values:
        if pd.isna(rle) or str(rle).strip() == '-1':
            continue
        gt_mask = np.maximum(gt_mask, rle_decode(rle, 1024, 1024))

    # Visualize: input, ground truth and thresholded prediction side by side.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(image[:, :, 0], cmap='gray')
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(gt_mask, cmap='gray')
    axes[1].set_title('Ground Truth')
    axes[1].axis('off')

    axes[2].imshow(binary_pred, cmap='gray')
    axes[2].set_title('Prediction')
    axes[2].axis('off')

    plt.tight_layout()
    plt.savefig(config.save_dir / f'prediction_{img_id}.png')
    # Close this specific figure so figures do not accumulate across iterations.
    plt.close(fig)

print(f"\n완료! 결과는 {config.save_dir}에 저장되었습니다.")