In [None]:
import torch

# Pick the compute device once: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report which device was selected (messages are user-facing and kept as-is).
if device.type == "cuda":
    print("✅ CUDA 사용 가능:", torch.cuda.get_device_name(0))
else:
    print("⚠️ CUDA 불가능, CPU 사용")

In [None]:
from google.colab import files
import zipfile
import os

# When part3 is the target, the training dataset is on the small side (about 1,800 samples).
# Ask the user to upload files through the Colab widget, then unpack every
# uploaded .zip archive into a directory named after the archive.
uploaded = files.upload()
for fname in uploaded:
    if not fname.endswith(".zip"):
        continue  # only archives are unpacked; other uploads are left untouched
    extract_dir = fname.replace(".zip", "")
    with zipfile.ZipFile(fname, mode='r') as archive:
        archive.extractall(extract_dir)
print("✅ 압축 해제 완료")

In [None]:
from PIL import Image
from torchvision import transforms

def load_image_and_label(image_path, label_path, target_size=(224, 224)):
    """Load one image as a tensor together with the first entry of its label file.

    Args:
        image_path: Path of the image file; it is opened with PIL and converted to RGB.
        label_path: Path of a text file whose first line holds five
            whitespace-separated numbers: ``cls x y w h``.
        target_size: Size handed to ``transforms.Resize`` before tensor conversion.

    Returns:
        A tuple ``(image_tensor, (x, y, w, h), cls)`` where the box values are
        floats and ``cls`` is an int. Only the first line of the label file is
        read; any further lines are ignored.

    Raises:
        ValueError: If the first label line does not contain exactly five
            numeric fields.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    preprocess = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])
    image_tensor = preprocess(rgb_image)

    with open(label_path, "r") as label_file:
        first_line = label_file.readline()

    # Lazy conversion keeps the same failure behaviour as unpacking a map().
    cls, x, y, w, h = (float(token) for token in first_line.strip().split())
    return image_tensor, (x, y, w, h), int(cls)

In [None]:
def box_iou(box1, box2, image_size):
    """Compute the intersection-over-union of two center-format boxes.

    Args:
        box1: ``(x, y, w, h)`` with the box center and size, expressed as
            fractions that are scaled by ``image_size``.
        box2: Second box in the same format.
        image_size: ``(W, H)`` used to scale the fractional coordinates.

    Returns:
        ``intersection / union`` as a float, or ``0`` when the union area is
        not positive.
    """
    W, H = image_size

    def corners(box):
        # Convert (center, size) into scaled (left, top, right, bottom).
        cx, cy, bw, bh = box
        return (cx - bw/2)*W, (cy - bh/2)*H, (cx + bw/2)*W, (cy + bh/2)*H

    a_left, a_top, a_right, a_bottom = corners(box1)
    b_left, b_top, b_right, b_bottom = corners(box2)

    # Overlap extents are clamped at zero so disjoint boxes intersect in nothing.
    overlap_w = max(min(a_right, b_right) - max(a_left, b_left), 0)
    overlap_h = max(min(a_bottom, b_bottom) - max(a_top, b_top), 0)
    inter = overlap_w * overlap_h

    area_a = (a_right - a_left) * (a_bottom - a_top)
    area_b = (b_right - b_left) * (b_bottom - b_top)
    union = area_a + area_b - inter

    if union > 0:
        return inter / union
    return 0

In [None]:
from PIL import ImageDraw
import matplotlib.pyplot as plt

def draw_boxes(image_tensor, true_box, pred_box=None):
    """Render an image tensor with the ground-truth box and an optional prediction.

    Args:
        image_tensor: Image tensor; a leading dimension of size 1 is squeezed
            away and the tensor is moved to the CPU before conversion to PIL.
        true_box: ``(x, y, w, h)`` center-format box as fractions of the image
            size; drawn in green.
        pred_box: Optional box in the same format (list, tuple or tensor);
            drawn in red. ``None`` or an empty sequence draws nothing.

    Returns:
        The PIL image with the rectangle(s) drawn on it.
    """
    image = transforms.ToPILImage()(image_tensor.squeeze(0).cpu())
    draw = ImageDraw.Draw(image)
    W, H = image.size

    def to_pixel_rect(box):
        # float() lets tensors and numpy scalars be used as coordinates too.
        x, y, w, h = (float(v) for v in box)
        # A raw network output can have negative width/height, which would put
        # the "right" edge left of the "left" edge. ImageDraw.rectangle is
        # documented to require x1 >= x0 and y1 >= y0, so order the corners.
        left, right = sorted(((x - w/2)*W, (x + w/2)*W))
        top, bottom = sorted(((y - h/2)*H, (y + h/2)*H))
        return [left, top, right, bottom]

    draw.rectangle(to_pixel_rect(true_box), outline="green", width=2)

    # `is not None` instead of truthiness: a multi-element tensor has no
    # unambiguous truth value. Empty sequences are still skipped as before.
    if pred_box is not None and len(pred_box) > 0:
        draw.rectangle(to_pixel_rect(pred_box), outline="red", width=2)
    return image

In [None]:
import torch.nn as nn

class TinyYOLO(nn.Module):
    """Small convolutional grid detector.

    Three conv / LeakyReLU / 2x2 max-pool stages are followed by two linear
    layers. The output is reshaped to ``(batch, S, S, B * 5)``; the five values
    per box are presumably ``(x, y, w, h, confidence)``, matching how
    ``FaceDataset`` lays out its targets.

    Args:
        S: Number of grid cells per side.
        B: Number of boxes predicted per cell.
        input_size: ``(height, width)`` of the input images, or a single int
            for square inputs. Defaults to ``(224, 224)``, which reproduces
            the original fixed ``28 * 28 * 64`` flattened feature size.

    Raises:
        ValueError: If either input dimension is smaller than 8, i.e. nothing
            would be left after the three 2x2 poolings.
    """

    def __init__(self, S=14, B=1, input_size=(224, 224)):
        super().__init__()
        self.S, self.B = S, B
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        self.input_size = tuple(input_size)
        self.output_size = S * S * (B * 5)

        # Each of the three MaxPool2d(2, 2) layers floors the spatial size in
        # half, so the final feature map is (H // 8) x (W // 8).
        in_h, in_w = self.input_size
        feat_h, feat_w = in_h // 8, in_w // 8
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size must be at least 8x8, got {self.input_size}"
            )

        self.model = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(2,2),
            # BatchNorm2d turned out to be counterproductive here.

            nn.Conv2d(16, 32, 3, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(2,2),

            nn.Flatten(),
            nn.Linear(feat_h * feat_w * 64, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, self.output_size)
        )

    def forward(self, x):
        """Run the network and reshape the flat output to ``(-1, S, S, B * 5)``."""
        return self.model(x).reshape(-1, self.S, self.S, self.B*5)

In [None]:
import glob, random
from torch.utils.data import Dataset, DataLoader

class FaceDataset(Dataset):
    """Dataset of (image tensor, S x S x 5 grid target) pairs.

    Each sample is loaded with ``load_image_and_label``. The single box of a
    sample is written into the grid cell that contains its center: channels
    0-3 hold ``(x, y, w, h)`` unchanged and channel 4 is set to 1.0. All other
    cells stay zero.

    Args:
        image_paths: Sequence of image file paths.
        label_paths: Sequence of label file paths, index-aligned with
            ``image_paths``.
        size: Target size forwarded to ``load_image_and_label``.
        S: Number of grid cells per side.
    """

    def __init__(self, image_paths, label_paths, size=(224,224), S=14):
        self.images = image_paths
        self.labels = label_paths
        self.size = size
        self.S = S

    def __len__(self):
        return len(self.images)

    def _cell_index(self, coord):
        """Map a fractional center coordinate to a grid index in ``[0, S - 1]``."""
        # A center of exactly 1.0 would give index S (out of range) and a
        # sufficiently negative one would wrap around to the last cell, so
        # clamp to the valid range.
        return min(max(int(coord * self.S), 0), self.S - 1)

    def _encode_target(self, box):
        """Build the ``(S, S, 5)`` target tensor for one ``(x, y, w, h)`` box."""
        target = torch.zeros((self.S, self.S, 5))
        grid_x = self._cell_index(box[0])
        grid_y = self._cell_index(box[1])
        # Row index is y, column index is x.
        target[grid_y, grid_x, 0:4] = torch.tensor(box)
        target[grid_y, grid_x, 4] = 1.0
        return target

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]
        # The class id is not used: the target only encodes box and objectness.
        image_tensor, box, cls = load_image_and_label(img, label, self.size)
        return image_tensor, self._encode_target(box)

# Pair every image in part3/ with the label file of the same base name.
img_list = sorted(glob.glob("part3/*.jpg"))
lbl_list = ["part3_labels/" + os.path.basename(f).replace(".jpg", ".txt") for f in img_list]
pairs = list(zip(img_list, lbl_list))

# Shuffle with a dedicated, fixed-seed generator so the train/test split is the
# same on every run. With an unseeded shuffle, re-running this cell after
# training could move training images into `test` and inflate the reported
# accuracy. A private Random instance leaves the global RNG state untouched.
random.Random(42).shuffle(pairs)

# The first 100 shuffled pairs are held out for evaluation; the rest are used for training.
test = pairs[:100]
train = pairs[100:]

train_dataset = FaceDataset([p[0] for p in train], [p[1] for p in train], size=(224,224))
test_dataset = FaceDataset([p[0] for p in test], [p[1] for p in test], size=(224,224))
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [None]:
# Model, loss and optimizer for the 14x14 grid detector.
model = TinyYOLO(S=14).to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train for 10 epochs and report the mean batch loss after each one.
for epoch in range(10):
    model.train()
    total_loss = 0
    for batch_images, batch_targets in train_loader:
        batch_images = batch_images.to(device)
        batch_targets = batch_targets.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch_images), batch_targets)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
    mean_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}: Loss = {mean_loss:.4f}")

In [None]:
# Evaluate on the held-out pairs: a sample counts as correct when the box of the
# most confident grid cell overlaps the ground truth with IoU >= 0.5.
model.eval()
correct = 0
for img_path, lbl_path in test:
    img_tensor, gt_box, _ = load_image_and_label(img_path, lbl_path, target_size=(224,224))
    input_tensor = img_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(input_tensor).cpu().squeeze(0)

    # Pick the cell with the highest confidence (channel 4). Using argmax instead
    # of a "> 0" scan means a box is always selected, even when every confidence
    # is <= 0 (previously that left best_box as None and crashed on .tolist()).
    # The grid size comes from the prediction itself rather than a hard-coded 14.
    conf_map = pred[..., 4]
    row, col = divmod(int(torch.argmax(conf_map)), conf_map.shape[1])
    best_conf = conf_map[row, col]
    best_box = pred[row, col, 0:4]

    iou = box_iou(gt_box, best_box.tolist(), (224,224))
    if iou >= 0.5:
        correct += 1

# Report against the real number of evaluated samples; the percentage is only
# equal to `correct` when exactly 100 samples were evaluated.
total = len(test)
accuracy = 100 * correct / total if total else 0.0
print(f"✅ 정확도: {correct} / {total} = {accuracy:g}%")

In [None]:
# Show up to 12 random test images with the ground-truth box (green) and the
# most confident predicted box (red).
# random.sample raises if asked for more items than exist, so cap the count.
examples = random.sample(test, min(12, len(test)))
plt.figure(figsize=(12, 9))
for idx, (img_path, lbl_path) in enumerate(examples):
    img_tensor, gt_box, _ = load_image_and_label(img_path, lbl_path, target_size=(224,224))
    input_tensor = img_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(input_tensor).cpu().squeeze(0)

    # Select the highest-confidence cell (channel 4). argmax always yields a
    # cell, so best_box can no longer be None when all confidences are <= 0,
    # and the grid size is read from the prediction instead of a fixed 14.
    conf_map = pred[..., 4]
    row, col = divmod(int(torch.argmax(conf_map)), conf_map.shape[1])
    best_conf = conf_map[row, col]
    best_box = pred[row, col, 0:4]

    plt.subplot(3, 4, idx+1)
    vis_img = draw_boxes(img_tensor.unsqueeze(0), gt_box, best_box.tolist())
    plt.imshow(vis_img)
    plt.axis('off')
    plt.title(os.path.basename(img_path))
plt.tight_layout()
plt.show()