<a href="https://colab.research.google.com/github/woojung02/SSAC_AI/blob/main/U_net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [79]:
#1 임포트
from google.colab import drive
drive.mount('/content/drive')

import os
import time
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision import transforms, utils

#2.RLE  이미지로 복원하는 과정
def rle2mask(rle, shape):
    """
    RLE 문자열 → 2D 이진 마스크
    shape = (height, width)
    """
    h, w = shape
    mask = np.zeros(h * w, dtype=np.uint8)
    if isinstance(rle, str):
        vals = np.array(rle.split(), dtype=int)
        starts, lengths = vals[0::2] - 1, vals[1::2]
        for s, l in zip(starts, lengths):
            mask[s : s + l] = 1
    return mask.reshape((h, w), order='F')

#3. 데이터 로드및 전처리
CSV_PATH   = "/content/train.csv"
IMG_FOLDER = "/content/drive/MyDrive/train_images"

df = pd.read_csv(CSV_PATH)
# Mask 없는(EncodedPixels NaN) 행 제거
df = df[df['EncodedPixels'].notnull()].reset_index(drop=True)

# 존재하지 않는 파일 필터링
exists = df['ImageId'].apply(lambda fn: os.path.exists(os.path.join(IMG_FOLDER, fn)))
df = df[exists].reset_index(drop=True)
print(f"총 샘플 수: {len(df)}")


# 4) Dataset 정의

class SteelDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        fn  = self.df.loc[idx, "ImageId"]
        rle = self.df.loc[idx, "EncodedPixels"]
        # 이미지 로드
        img = Image.open(os.path.join(self.img_dir, fn)).convert("RGB")
        # 마스크 생성
        mask = rle2mask(rle, (256, 1600))
        mask = Image.fromarray((mask * 255).astype(np.uint8))
        # transform 적용
        if self.transform:
            img  = self.transform(img)
            mask = self.transform(mask)
        return img, mask


# 5) Transform & DataLoader

data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

train_ds = SteelDataset(df, IMG_FOLDER, transform=data_transforms)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)

# 샘플 확인
imgs, masks = next(iter(train_loader))
print("배치 이미지 크기:", imgs.shape, "마스크 크기:", masks.shape)
utils.make_grid(imgs, nrow=4).permute(1,2,0)


# 6) U-Net with ResNet-18 백본 정의
def convrelu(in_ch, out_ch, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=kernel, padding=padding),
        nn.ReLU(inplace=True),
    )

class UNetResNet18(nn.Module):
    def __init__(self, n_class=1):
        super().__init__()
        # pretrained ResNet-18
        backbone = models.resnet18(pretrained=True)
        layers   = list(backbone.children())
        # encoder
        self.layer0 = nn.Sequential(*layers[:3])   # conv1, bn1, relu
        self.layer1 = nn.Sequential(layers[3], layers[4])  # maxpool, layer1
        self.layer2 = layers[5]  # layer2
        self.layer3 = layers[6]  # layer3
        self.layer4 = layers[7]  # layer4
        # 1×1 conv for skip
        self.l4_1x1 = convrelu(512, 512, 1, 0)
        self.l3_1x1 = convrelu(256, 256, 1, 0)
        self.l2_1x1 = convrelu(128, 128, 1, 0)
        self.l1_1x1 = convrelu(64,  64,  1, 0)
        self.l0_1x1 = convrelu(64,  64,  1, 0)
        # upsample + conv
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up3 = convrelu(512+256, 256, 3, 1)
        self.up2 = convrelu(256+128, 128, 3, 1)
        self.up1 = convrelu(128+64,  64,  3, 1)
        self.up0 = convrelu(64+64,   64,  3, 1)
        # original skip
        self.orig0 = convrelu(3,   64, 3, 1)
        self.orig1 = convrelu(64, 64, 3, 1)
        self.orig2 = convrelu(64+64, 64, 3, 1)
        # final
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x):
        x0 = self.orig0(x); x0 = self.orig1(x0)
        l0 = self.layer0(x)
        l1 = self.layer1(l0)
        l2 = self.layer2(l1)
        l3 = self.layer3(l2)
        l4 = self.layer4(l3)
        # decoder
        d4 = self.l4_1x1(l4)
        d4 = self.upsample(d4)
        l3s = self.l3_1x1(l3)
        d3 = self.up3(torch.cat([d4, l3s], dim=1))
        d3 = self.upsample(d3)
        l2s = self.l2_1x1(l2)
        d2 = self.up2(torch.cat([d3, l2s], dim=1))
        d2 = self.upsample(d2)
        l1s = self.l1_1x1(l1)
        d1 = self.up1(torch.cat([d2, l1s], dim=1))
        d1 = self.upsample(d1)
        l0s = self.l0_1x1(l0)
        d0 = self.up0(torch.cat([d1, l0s], dim=1))
        d0 = self.upsample(d0)
        cat = torch.cat([d0, x0], dim=1)
        cat = self.orig2(cat)
        return self.conv_last(cat)

# 7) 모델/손실/옵티마이저/스케줄러/얼리스토핑 설정

import torch.optim.lr_scheduler as sched

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model     = UNetResNet18(n_class=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3,
                            momentum=0.9, weight_decay=1e-4)

# ReduceLROnPlateau: 검증(or학습) 손실이 개선되지 않으면 LR *= factor
scheduler    = sched.ReduceLROnPlateau(optimizer,
                                      mode='min',
                                      factor=0.5,
                                      patience=3,
                                      verbose=True)

best_loss    = float('inf')
patience_cnt = 0
max_patience = 5
max_epochs   = 50   # 최대 에포치 수


# 8) 학습 루프 (스케줄러, 얼리스토핑 포함)

start_time = time.time()
for epoch in range(1, max_epochs+1):
    model.train()
    train_loss = 0.0

    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        preds = model(imgs)
        loss  = criterion(preds, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    avg_loss = train_loss / len(train_loader)
    print(f"[Epoch {epoch:02d}/{max_epochs}] Train Loss: {avg_loss:.6f}")

    # 1) 스케줄러 단계 (손실 기준)
    scheduler.step(avg_loss)

    # 2) EarlyStopping 체크
    if avg_loss < best_loss:
        best_loss    = avg_loss
        patience_cnt = 0
        torch.save(model.state_dict(), "best_unet.pth")

    else:
        patience_cnt += 1
        print(f"  → 개선 없음. Patience {patience_cnt}/{max_patience}")
        if patience_cnt >= max_patience:
            print(" Early stopping")
            break

total_time = time.time() - start_time
print(f"\n총 학습 시간: {total_time:.2f}초")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
총 샘플 수: 2103
배치 이미지 크기: torch.Size([4, 3, 256, 256]) 마스크 크기: torch.Size([4, 1, 256, 256])




KeyboardInterrupt: 