In [1]:
# 기본 패키지 임포트
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os, trimesh, pickle
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

# 디바이스 설정 (GPU 사용 가능 시 GPU로 설정)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True


Using device: cuda


In [None]:
def generate_npz_cache(mesh_dir, cache_dir):
    os.makedirs(cache_dir, exist_ok=True)

    mesh_files = [f for f in os.listdir(mesh_dir)
                  if f.endswith('.obj') or f.endswith('.stl') or f.endswith('.ply')]

    for fname in mesh_files:
        mesh_path = os.path.join(mesh_dir, fname)
        cache_path = os.path.join(cache_dir, fname.replace('.obj', '.npz').replace('.stl', '.npz').replace('.ply', '.npz'))

        # 이미 캐시된 경우 생략
        if os.path.exists(cache_path):
            print(f"✅ 존재함: {cache_path} → 건너뜀")
            continue

        try:
            mesh = trimesh.load(mesh_path, process=False)
            vertices = np.array(mesh.vertices)
            faces = np.array(mesh.faces)
            face_normals = np.array(mesh.face_normals)

            np.savez(cache_path, vertices=vertices, faces=faces, face_normals=face_normals)
            print(f"💾 저장됨: {cache_path}")
        except Exception as e:
            print(f"❌ 실패: {fname} → {e}")

# 🔧 경로 지정
source_mesh_dir = r"C:\Users\konyang\Desktop\MeshCNN_TF\data\dataset\simplified_mesh\test"
target_cache_dir = r"C:\Users\konyang\Desktop\MeshCNN_TF\data\dataset\cached_mesh\test"

# 실행
generate_npz_cache(source_mesh_dir, target_cache_dir)


In [8]:
class LungNoduleDataset(Dataset):
    def __init__(self, data_dir, split='train', cache_dir=None):
        self.mesh_files = []
        self.face_labels = []
        self.class_labels = []
        self.cache_dir = cache_dir

        mesh_dir = os.path.join(data_dir, split)
        label_dir = os.path.join(data_dir.replace("simplified_mesh", "simplified_label"), split)

        if self.cache_dir:
            os.makedirs(os.path.join(self.cache_dir, split), exist_ok=True)

        for fname in os.listdir(mesh_dir):
            if fname.endswith('.obj') or fname.endswith('.stl') or fname.endswith('.ply'):
                mesh_path = os.path.join(mesh_dir, fname)
                label_path = os.path.join(label_dir, fname.split('.')[0] + "_label.npy")

                if os.path.exists(label_path):
                    face_label = np.load(label_path)
                    class_label = 1 if face_label.sum() > 0 else 0
                else:
                    face_label = None
                    class_label = 0

                self.mesh_files.append((mesh_path, label_path))
                self.class_labels.append(class_label)

    def __len__(self):
        return len(self.mesh_files)

    def __getitem__(self, idx):
        mesh_path, label_path = self.mesh_files[idx]
        fname = os.path.splitext(os.path.basename(mesh_path))[0]
        
        try:
            # -----------------------------------------
            # ✅ 1. 캐시 불러오기 or 생성
            # -----------------------------------------
            if self.cache_dir:
                cached_path = os.path.join(self.cache_dir, fname + ".npz")
                if os.path.exists(cached_path):
                    data = np.load(cached_path)
                    vertices = data['vertices']
                    faces = data['faces']
                    face_normals = data['face_normals']
                else:
                    mesh = trimesh.load(mesh_path)
                    vertices = np.array(mesh.vertices)
                    faces = np.array(mesh.faces)
                    face_normals = np.array(mesh.face_normals)
                    np.savez(cached_path, vertices=vertices, faces=faces, face_normals=face_normals)
            else:
                mesh = trimesh.load(mesh_path)
                vertices = np.array(mesh.vertices)
                faces = np.array(mesh.faces)
                face_normals = np.array(mesh.face_normals)

            # -----------------------------------------
            # ✅ 2. 레이블 불러오기
            # -----------------------------------------
            if os.path.exists(label_path):
                face_label = np.load(label_path)
            else:
                face_label = None

             # 2. 간선 및 간선->면 매핑 생성
            edge_to_faces = {}
            for f_idx, face in enumerate(faces):
                # 삼각형 face의 세 변 (정렬하여 tuple로 사용)
                for e in [(face[0], face[1]), (face[1], face[2]), (face[2], face[0])]:
                    e_sorted = tuple(sorted(e))
                    if e_sorted not in edge_to_faces:
                        edge_to_faces[e_sorted] = []
                    edge_to_faces[e_sorted].append(f_idx)
            edges = list(edge_to_faces.keys())               # 간선 리스트 (고유 간선)

            # 3. 간선 이웃 정보 계산 (각 간선당 최대 4개 이웃 간선)
            # 이웃 간선 정의: 하나의 꼭짓점을 공유하는 간선 (1-링 이웃)
            vert_to_edges = {v: [] for v in range(len(vertices))}
            for e_idx, (v1, v2) in enumerate(edges):
                vert_to_edges[v1].append(e_idx)
                vert_to_edges[v2].append(e_idx)
            neighbors_list = []
            for e_idx, (v1, v2) in enumerate(edges):
                neighbor_set = set(vert_to_edges[v1] + vert_to_edges[v2])
                neighbor_set.discard(e_idx)  # 자기 자신은 제외
                neighbors = list(neighbor_set)
                # 최대 4개까지 이웃 간선을 선택 (많으면 자르기)
                neighbors = neighbors[:4]
                # 4개 미만이면 자기 자신으로 패딩하여 크기 고정
                while len(neighbors) < 4:
                    neighbors.append(e_idx)
                neighbors_list.append(neighbors)
            neighbor_index = torch.tensor(neighbors_list, dtype=torch.long)

            # 4. 간선 특징 계산 (5차원: dihedral, inner1, inner2, ratio1, ratio2)
            edge_features = []
            for e_idx, (v1, v2) in enumerate(edges):
                p1, p2 = vertices[v1], vertices[v2]
                # 간선 길이
                edge_length = np.linalg.norm(p1 - p2)
                # 간선을 공유하는 면 목록
                faces_indices = edge_to_faces[(v1, v2)]
                # Dihedral angle 계산
                if len(faces_indices) == 2:
                    f1, f2 = faces_indices[0], faces_indices[1]
                    n1, n2 = face_normals[f1], face_normals[f2]
                    cos_theta = np.dot(n1, n2) / (np.linalg.norm(n1)*np.linalg.norm(n2) + 1e-8)
                    cos_theta = np.clip(cos_theta, -1.0, 1.0)
                    dihedral = np.arccos(cos_theta)
                else:
                    # 인접 면이 하나뿐인 경우 (경계 간선) - dihedral을 0으로 처리
                    dihedral = 0.0

                # 첫 번째 면의 대향각 및 비율
                inner1 = 0.0; ratio1 = 0.0
                if len(faces_indices) > 0:
                    f1 = faces_indices[0]
                    # 간선의 반대쪽 꼭짓점
                    face_v = faces[f1]
                    # 간선을 이루는 두 꼭짓점을 제외한 나머지 한 꼭짓점
                    opp_v = [v for v in face_v if v not in (v1, v2)][0]
                    vec1 = vertices[v1] - vertices[opp_v]
                    vec2 = vertices[v2] - vertices[opp_v]
                    cos_inner = np.dot(vec1, vec2) / ((np.linalg.norm(vec1)*np.linalg.norm(vec2)) + 1e-8)
                    cos_inner = np.clip(cos_inner, -1.0, 1.0)
                    inner1 = np.arccos(cos_inner)
                    # 높이 계산 (삼각형 면적 이용)
                    area = np.linalg.norm(np.cross(vec1, vec2)) / 2.0
                    height = (2.0 * area) / (edge_length + 1e-8)
                    ratio1 = edge_length / (height + 1e-8)
                # 두 번째 면의 대향각 및 비율 (없으면 첫 번째 면 값으로 대체)
                inner2 = inner1; ratio2 = ratio1
                if len(faces_indices) == 2:
                    f2 = faces_indices[1]
                    face_v2 = faces[f2]
                    opp_v2 = [v for v in face_v2 if v not in (v1, v2)][0]
                    vec3 = vertices[v1] - vertices[opp_v2]
                    vec4 = vertices[v2] - vertices[opp_v2]
                    cos_inner2 = np.dot(vec3, vec4) / ((np.linalg.norm(vec3)*np.linalg.norm(vec4)) + 1e-8)
                    cos_inner2 = np.clip(cos_inner2, -1.0, 1.0)
                    inner2 = np.arccos(cos_inner2)
                    area2 = np.linalg.norm(np.cross(vec3, vec4)) / 2.0
                    height2 = (2.0 * area2) / (edge_length + 1e-8)
                    ratio2 = edge_length / (height2 + 1e-8)
                # 특징 벡터 구성
                edge_features.append([dihedral, inner1, inner2, ratio1, ratio2])

            edge_features = torch.tensor(edge_features, dtype=torch.float32)

            # 5. 분할 레이블 (간선 단위) 생성
            if face_label is not None:
                face_labels_arr = face_label # numpy array of shape (F,)
                edge_labels = []
                for e_idx, (v1, v2) in enumerate(edges):
                    faces_indices = edge_to_faces[(v1, v2)]
                    # 해당 간선에 인접한 face 중 하나라도 결절(1)인 경우 간선 레이블 1
                    label = 0
                    for f_idx in faces_indices:
                        if face_labels_arr[f_idx] == 1:
                            label = 1
                            break
                    edge_labels.append(label)
                edge_labels = torch.tensor(edge_labels, dtype=torch.long)
            else:
                # face 레이블이 없는 경우 (분할 라벨이 없으면 모두 0으로 처리)
                edge_labels = torch.zeros(len(edges), dtype=torch.long)

            # 6. 클래스 레이블 (결절 유무)
            class_label = torch.tensor(self.class_labels[idx], dtype=torch.long)

            # 결과 반환: 특징, 이웃정보, 클래스 레이블, 간선 레이블, (필요하면 원본 메쉬 정보도 반환 가능)
            sample = {
                'edge_features': edge_features,          # 계산된 간선 특징 (E, 5)
                'neighbor_index': neighbor_index,        # 이웃 간선 인덱스 (E, 4)
                'class_label': class_label,              # 이진 클래스 레이블 (0 or 1)
                'edge_labels': edge_labels,              # 각 간선에 대한 분할 레이블 (E,)
                'vertices': vertices,                    # (V, 3) 시각화용
                'faces': faces                           # (F, 3) 시각화용
            }
            return sample
        except Exception as e:
            print(f"❌ __getitem__() error at idx {idx}: {e}")
            
            return {
                'edge_features': torch.empty(0),
                'neighbor_index': torch.empty(0, dtype=torch.long),
                'class_label': torch.tensor(0, dtype=torch.long),
                'edge_labels': torch.empty(0, dtype=torch.long),
                'vertices': np.zeros((0, 3)),
                'faces': np.zeros((0, 3))
            }

# 데이터셋 및 데이터로더 초기화 예시 (파일 경로는 실제 데이터셋에 맞게 수정해야 함)
# 경로 설정
data_dir = r"C:\Users\konyang\Desktop\MeshCNN_TF\data\dataset\simplified_mesh"
cache_dir = r"C:\Users\konyang\Desktop\MeshCNN_TF\data\dataset\cached_mesh"

# Dataset 객체 생성 (캐시 경로 포함)
train_dataset = LungNoduleDataset(
    data_dir=data_dir,
    split='train',
    cache_dir=cache_dir
)

# DataLoader 설정 (병렬 처리 권장)
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,  # ✅ 워커 병렬 처리 끄기
    pin_memory=True
)

# ✅ 검증용 데이터셋 생성
val_dataset = LungNoduleDataset(
    data_dir=data_dir,
    split='val',
    cache_dir=cache_dir
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,  # 검증은 shuffle하지 않음
    num_workers=0,
    pin_memory=True
)

In [9]:
class MeshConv(nn.Module):
    """메쉬 간선 합성곱 레이어: 각 간선과 이웃한 4개 간선 (총5개)의 특징을 통합해 출력."""
    def __init__(self, in_channels, out_channels):
        super(MeshConv, self).__init__()
        # 5 * in_channels 크기의 입력을 out_channels로 변환
        self.linear = nn.Linear(in_channels * 5, out_channels)
    
    def forward(self, x, neighbor_index):
        # x: (E, in_channels)
        # neighbor_index: (E, 4)

        E, in_channels = x.shape
        device = x.device

        # (E, 1, in_channels): 자기 자신
        x_self = x.unsqueeze(1)

        # (E, 4) → (E, 4, in_channels): neighbor 인덱스를 gather하기 위해 확장
        neighbor_index_expanded = neighbor_index.unsqueeze(-1).expand(-1, -1, in_channels)

        # x → (1, E, in_channels): gather dim=1을 위해 확장
        x_expanded = x.unsqueeze(0)

        # gather를 사용해 neighbor feature 추출: (1, E, C)에서 gather (E, 4, C)
        x_neighbors = torch.gather(x_expanded.expand(E, -1, -1), dim=1, index=neighbor_index_expanded)

        # concat: (E, 5, in_channels)
        x_combined = torch.cat([x_self, x_neighbors], dim=1)

        # Flatten: (E, 5 * in_channels)
        x_flat = x_combined.view(E, -1)

        out = self.linear(x_flat)
        return out  # (E, out_channels)

class MeshPool(nn.Module):
    """메쉬 풀링 레이어: 간선 수를 감소 (2:1 비율로 클러스터)."""
    def __init__(self):
        super(MeshPool, self).__init__()
        self.cluster_map = None  # 나중에 unpool에 사용하기 위한 매핑
    
    def forward(self, x, neighbor_index):
        # x: (N, C) 입력 간선 특징, N = 현재 간선 수
        N = x.size(0)
        # 클러스터 구성: 2개 간선씩 묶음 (홀수개일 경우 마지막은 단독 클러스터)
        if N % 2 == 0:
            new_N = N // 2
        else:
            new_N = N // 2 + 1
        # 출력 텐서 초기화
        device = x.device
        C = x.size(1)
        x_pooled = torch.zeros((new_N, C), dtype=x.dtype, device=device)
        # cluster_map: 길이 N 리스트로, 각 간선이 속한 클러스터 인덱스
        cluster_map = [-1] * N
        # 2개씩 평균
        for i in range(N // 2):
            idx1 = 2 * i
            idx2 = 2 * i + 1
            x_pooled[i] = 0.5 * (x[idx1] + x[idx2])
            cluster_map[idx1] = i
            cluster_map[idx2] = i
        if N % 2 == 1:
            # 마지막 간선은 단독 클러스터
            x_pooled[new_N - 1] = x[N - 1]
            cluster_map[N - 1] = new_N - 1
        # cluster_map을 Tensor로 저장
        cluster_map = torch.tensor(cluster_map, dtype=torch.long, device=device)
        self.cluster_map = cluster_map
        
        # 새로운 neighbor_index 계산 (클러스터 그래프의 이웃)
        neighbors_coarse = []
        # 원래 간선의 neighbor_index로부터 클러스터 간 이웃을 유추
        for edge_idx in range(N):
            cluster_idx = cluster_map[edge_idx].item()
            # 원래 간선의 이웃들의 클러스터 인덱스
            for nbr_edge in neighbor_index[edge_idx]:
                nbr_cluster = cluster_map[nbr_edge].item()
                if nbr_cluster != cluster_idx:
                    # 리스트 길이를 cluster_idx에 맞춰 확장
                    if cluster_idx >= len(neighbors_coarse):
                        neighbors_coarse.extend([set() for _ in range(cluster_idx - len(neighbors_coarse) + 1)])
                    neighbors_coarse[cluster_idx].add(nbr_cluster)
        # 집합을 리스트로 변환하고 4개로 패딩
        for i in range(len(neighbors_coarse)):
            nbrs = list(neighbors_coarse[i])
            nbrs = nbrs[:4]  # 최대 4개까지 사용
            while len(nbrs) < 4:
                nbrs.append(i)  # 자기 자신으로 패딩
            neighbors_coarse[i] = nbrs
        # 만약 어떤 클러스터에 대해 neighbor_set이 비어있으면 자기 자신 4개로
        if len(neighbors_coarse) < new_N:
            # 빈 클러스터 neighbor 세트 초기화
            for i in range(len(neighbors_coarse), new_N):
                neighbors_coarse.append([i, i, i, i])
        neighbor_index_coarse = torch.tensor(neighbors_coarse, dtype=torch.long, device=device)
        
        return x_pooled, neighbor_index_coarse

class MeshUnpool(nn.Module):
    """메쉬 업풀링 레이어: 풀링된 간선 특징을 원래 개수로 복원."""
    def forward(self, x_coarse, cluster_map):
        # x_coarse: (M, C) 풀링된 간선 특징, cluster_map: (N,) 각 원래 간선 -> 클러스터 인덱스 매핑
        device = x_coarse.device
        N = cluster_map.shape[0]  # 원래 간선 개수
        C = x_coarse.shape[1]
        # cluster_map을 이용해 각 원래 간선의 특징 할당
        x_reconstructed = torch.zeros((N, C), dtype=x_coarse.dtype, device=device)
        # cluster_map의 각 인덱스를 순회하며 할당
        for orig_edge_idx, cluster_idx in enumerate(cluster_map):
            x_reconstructed[orig_edge_idx] = x_coarse[cluster_idx]
        return x_reconstructed

# 이제 Encoder, Decoder, Classifier, Segmenter를 정의
class MeshEncoder(nn.Module):
    def __init__(self, in_channels=5):
        super(MeshEncoder, self).__init__()
        # 채널 설정: conv1_out=16, conv2_out=32, conv3_out=64 (예시)
        self.conv1 = MeshConv(in_channels, 16)
        self.pool1 = MeshPool()
        self.conv2 = MeshConv(16, 32)
        self.pool2 = MeshPool()
        self.conv3 = MeshConv(32, 64)
    
    def forward(self, x, neighbor_index):
        # conv1 + ReLU
        x1 = F.relu(self.conv1(x, neighbor_index))
        # pool1
        x_pooled1, neighbor_coarse1 = self.pool1(x1, neighbor_index)
        # conv2 + ReLU on pooled1
        x2 = F.relu(self.conv2(x_pooled1, neighbor_coarse1))
        # pool2
        x_pooled2, neighbor_coarse2 = self.pool2(x2, neighbor_coarse1)
        # conv3 + ReLU on pooled2 (encoder 최종)
        x3 = F.relu(self.conv3(x_pooled2, neighbor_coarse2))
        # Encoder 결과와 중간 결과 반환 (스킵 연결 및 Decoder에 필요)
        return {
            'x1': x1,                       # original level features (E edges, 16ch)
            'neighbor0': neighbor_index, 
            'x2': x2,                       # half level features (E/2 edges, 32ch)
            'neighbor1': neighbor_coarse1,
            'x3': x3,                       # quarter level features (E/4 edges, 64ch)
            'neighbor2': neighbor_coarse2,
            'pool1': self.pool1,
            'pool2': self.pool2
        }

class MeshDecoder(nn.Module):
    def __init__(self):
        super(MeshDecoder, self).__init__()
        self.conv2_up = MeshConv(32 + 64, 32)
        self.conv1_up = MeshConv(16 + 32, 16)

    def forward(self, enc_out):
        x2 = enc_out['x2']
        x3 = enc_out['x3']
        neighbor1 = enc_out['neighbor1']
        neighbor0 = enc_out['neighbor0']
        pool1_layer = enc_out['pool1']
        pool2_layer = enc_out['pool2']

        # Unpool quarter → half
        x2_up = MeshUnpool().forward(x3.detach(), pool2_layer.cluster_map)  # (E/2, 64)

        # Concat with skip connection
        x2_cat = torch.cat([x2.detach(), x2_up], dim=1)  # (E/2, 96)
        x2_up_refined = F.relu(self.conv2_up(x2_cat, neighbor1))  # (E/2, 32)

        # Unpool half → original
        x1_up = MeshUnpool().forward(x2_up_refined.detach(), pool1_layer.cluster_map)  # (E, 32)

        x1 = enc_out['x1']
        x1_cat = torch.cat([x1.detach(), x1_up], dim=1)  # (E, 48)

        x1_up_refined = F.relu(self.conv1_up(x1_cat, neighbor0))  # (E, 16)

        return x1_up_refined

class ClassificationHead(nn.Module):
    def __init__(self, in_channels, hidden_dim=100, num_classes=2):
        super(ClassificationHead, self).__init__()
        # 단순한 MLP 분류기: [in_channels] -> [hidden_dim] -> [num_classes]
        self.fc1 = nn.Linear(in_channels, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        # x: (in_channels) or (1, in_channels) 전역 특징 벡터
        x = F.relu(self.fc1(x))
        out = self.fc2(x)
        return out  # raw logits (num_classes)
    
class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes=2):
        super(SegmentationHead, self).__init__()
        self.linear = nn.Linear(in_channels, num_classes)
    def forward(self, x):
        # x: (E, in_channels) 모든 원래 간선에 대한 복원된 특징
        out = self.linear(x)  # (E, num_classes)
        return out

# 전체 모델 통합
class MedMeshNet(nn.Module):
    def __init__(self):
        super(MedMeshNet, self).__init__()
        self.encoder = MeshEncoder(in_channels=5)
        self.decoder = MeshDecoder()
        # classification은 encoder 최종 출력 채널 (예: 64)을 받아 이진 분류
        self.classifier = ClassificationHead(in_channels=64, hidden_dim=100, num_classes=2)
        # segmentation은 decoder 최종 출력 채널 (예: 16)을 받아 2-class 출력
        self.segmenter = SegmentationHead(in_channels=16, num_classes=2)
    
    def forward(self, edge_features, neighbor_index):
        # 1. Encoder: 특징 추출 및 다운샘플
        enc_out = self.encoder(edge_features, neighbor_index)
        # 2. Classification: Encoder 최종 특징들을 전역 요약하여 클래스 예측
        x_global = torch.max(enc_out['x3'], dim=0)[0]  # 전역 max 풀 (64차원 벡터)
        class_logits = self.classifier(x_global.unsqueeze(0))  # (1,2) 출력
        # 3. Decoder: 업샘플로 원래 해상도 특징 복원
        dec_out = self.decoder(enc_out)  # (E, 16)
        seg_logits = self.segmenter(dec_out)  # (E, 2)
        return class_logits, seg_logits


In [None]:
import time

# 모델 초기화
model = MedMeshNet().to(device)

# 손실 함수 정의
criterion_class = nn.CrossEntropyLoss()
class_weights = torch.tensor([1.0, 5.0], device=device)  # ✅ 결절 가중치 강화
criterion_seg = nn.CrossEntropyLoss(weight=class_weights)

# 옵티마이저 & 스케줄러
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# AMP용 스케일러
scaler = torch.cuda.amp.GradScaler()

# 검증용 val_loader 존재 확인
try:
    val_loader
except NameError:
    val_loader = None

# ---------------------------
# AMP 적용 학습 루프 시작
# ---------------------------
num_epochs = 10
total_steps = len(train_loader)
val_steps = len(val_loader)

for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss = 0.0
    correct_train = 0
    total_train = 0
    print(f"\n========== [Epoch {epoch}] ==========")

    start_time = time.time()

    for i, batch in enumerate(train_loader):
        step_start = time.time()

        features = batch['edge_features'].to(device).squeeze(0)
        if features.numel() == 0:
            continue

        neighbor_index = batch['neighbor_index'].to(device).squeeze(0)
        class_label = batch['class_label'].to(device)
        seg_label = batch['edge_labels'].to(device).squeeze(0)

        if features.shape[0] > 25000:
            continue

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            class_logits, seg_logits = model(features, neighbor_index)
            loss_class = criterion_class(class_logits, class_label)
            loss_seg = criterion_seg(seg_logits, seg_label)
            loss = loss_class + loss_seg

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        pred_class = torch.argmax(class_logits, dim=1)
        correct_train += (pred_class == class_label).sum().item()
        total_train += class_label.size(0)

        # Keras-style 출력
        avg_loss = total_loss / (i + 1)
        train_acc = correct_train / total_train if total_train > 0 else 0
        elapsed = time.time() - step_start
        print(f"{i+1}/{total_steps} [{(i+1)/total_steps:.0%}] - {elapsed:.1f}s/step - loss: {avg_loss:.4f} - acc: {train_acc:.4f}", end="\r")

    total_elapsed = time.time() - start_time
    print(f"\n✅ Epoch {epoch} 완료 - 평균 손실: {avg_loss:.4f}, 정확도: {train_acc*100:.2f}%, 시간: {total_elapsed:.1f}s")

    scheduler.step()

    # ---------------------------
    # 검증 단계 (no_grad만 사용)
    # ---------------------------
    if val_loader is not None:
        model.eval()
        correct = 0
        total_samples = 0
        inter = 0
        union = 0
        val_loss = 0.0
        val_start = time.time()

        for j, batch in enumerate(val_loader):
            step_start = time.time()

            features = batch['edge_features'].to(device).squeeze(0)
            if features.numel() == 0:
                continue
            neighbor_index = batch['neighbor_index'].to(device).squeeze(0)
            class_label = batch['class_label'].to(device)
            seg_label = batch['edge_labels'].to(device).squeeze(0)

            with torch.no_grad():
                class_logits, seg_logits = model(features, neighbor_index)
                loss_class = criterion_class(class_logits, class_label)
                loss_seg = criterion_seg(seg_logits, seg_label)
                loss = loss_class + loss_seg
                val_loss += loss.item()

                pred_class = torch.argmax(class_logits, dim=1)
                correct += (pred_class == class_label).sum().item()
                total_samples += class_label.size(0)

                pred_seg = torch.argmax(seg_logits, dim=1)
                num_classes = seg_logits.shape[1]
                for cls in range(num_classes):
                    inter += ((pred_seg == cls) & (seg_label == cls)).sum().item()
                    union += ((pred_seg == cls) | (seg_label == cls)).sum().item()

            # 실시간 출력
            avg_vloss = val_loss / (j + 1)
            val_acc = correct / total_samples if total_samples > 0 else 0
            elapsed = time.time() - step_start
            print(f"{j+1}/{val_steps} [{(j+1)/val_steps:.0%}] - {elapsed:.1f}s/step - val_loss: {avg_vloss:.4f} - val_acc: {val_acc:.4f}", end='\r')

        val_elapsed = time.time() - val_start
        val_acc = correct / total_samples if total_samples > 0 else 0
        val_iou = inter / (union + 1e-8) if union > 0 else 0
        print(f"\n🧪 검증 완료 - 평균 손실: {avg_vloss:.4f}, 정확도: {val_acc * 100:.2f}%, IoU: {val_iou:.4f}, 시간: {val_elapsed:.1f}s")



244/244 [100%] - 10.2s/step - loss: 0.7902 - acc: 0.6808
✅ Epoch 1 완료 - 평균 손실: 0.7902, 정확도: 68.08%, 시간: 3024.9s
70/70 [100%] - 8.3s/step - val_loss: 0.6880 - val_acc: 0.7286
🧪 검증 완료 - 평균 손실: 0.6880, 정확도: 72.86%, IoU: 0.9941, 시간: 790.5s

244/244 [100%] - 10.0s/step - loss: 0.6100 - acc: 0.7183
✅ Epoch 2 완료 - 평균 손실: 0.6100, 정확도: 71.83%, 시간: 3007.7s
70/70 [100%] - 8.3s/step - val_loss: 0.6813 - val_acc: 0.7286
🧪 검증 완료 - 평균 손실: 0.6813, 정확도: 72.86%, IoU: 0.9941, 시간: 791.4s

244/244 [100%] - 10.4s/step - loss: 0.5783 - acc: 0.7371
✅ Epoch 3 완료 - 평균 손실: 0.5783, 정확도: 73.71%, 시간: 3014.9s
70/70 [100%] - 8.3s/step - val_loss: 0.6613 - val_acc: 0.7286
🧪 검증 완료 - 평균 손실: 0.6613, 정확도: 72.86%, IoU: 0.9941, 시간: 791.5s

243/244 [100%] - 10.4s/step - loss: 0.5837 - acc: 0.7465
✅ Epoch 4 완료 - 평균 손실: 0.5837, 정확도: 74.65%, 시간: 3028.1s
70/70 [100%] - 8.3s/step - val_loss: 0.7452 - val_acc: 0.7143
🧪 검증 완료 - 평균 손실: 0.7452, 정확도: 71.43%, IoU: 0.9941, 시간: 791.9s

244/244 [100%] - 11.7s/step - loss: 0.5943 - acc: 0

In [None]:
def compute_classification_accuracy(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            features = batch['edge_features'].to(device)
            neighbor_index = batch['neighbor_index'].to(device)
            labels = batch['class_label'].to(device)
            class_logits, _ = model(features, neighbor_index)
            pred = torch.argmax(class_logits, dim=1)  # 예측 클래스
            if pred.item() == labels.item():
                correct += 1
            total += 1
    acc = correct / total if total > 0 else 0
    return acc

def compute_segmentation_iou(model, data_loader):
    model.eval()
    inter = 0
    union = 0
    with torch.no_grad():
        for batch in data_loader:
            features = batch['edge_features'].to(device)
            neighbor_index = batch['neighbor_index'].to(device)
            true_seg = batch['edge_labels'].to(device)
            _, seg_logits = model(features, neighbor_index)
            pred_seg = torch.argmax(seg_logits, dim=1)
            # 결절 클래스(1)에 대한 IoU 계산
            inter += torch.logical_and(pred_seg == 1, true_seg == 1).sum().item()
            union += torch.logical_or(pred_seg == 1, true_seg == 1).sum().item()
    iou = inter / (union + 1e-8) if union > 0 else 0
    return iou

# 예시: 학습 완료 후 테스트 세트에 대한 성능 측정

test_acc = compute_classification_accuracy(model, test_loader)
test_iou = compute_segmentation_iou(model, test_loader)
print(f"테스트 정확도: {test_acc:.4f}, 테스트 IoU: {test_iou:.4f}")


In [None]:
import open3d as o3d
import trimesh
import numpy as np
import torch
import os

# ✅ 모델 준비
model.eval()

# ✅ 파일 이름 지정
target_filename = "A0173_abnormal.obj"

# ✅ 경로 설정
original_mesh_dir = r"C:\Users\konyang\Desktop\MeshCNN_TF\data\dataset\mesh"
cache_dir = r"C:\Users\konyang\Desktop\MeshCNN_TF\data\dataset\cached_mesh\train"
npz_path = os.path.join(cache_dir, target_filename.replace(".obj", ".npz"))

# ✅ 캐시된 특징 불러오기
if not os.path.exists(npz_path):
    raise FileNotFoundError(f"{npz_path} 캐시 파일이 없습니다.")

data = np.load(npz_path)
vertices = data['vertices']
faces = data['faces']
face_normals = data['face_normals']

# ✅ 간선 및 이웃 인덱스 계산
edge_to_faces = {}
for f_idx, face in enumerate(faces):
    for e in [(face[0], face[1]), (face[1], face[2]), (face[2], face[0])]:
        e_sorted = tuple(sorted(e))
        edge_to_faces.setdefault(e_sorted, []).append(f_idx)
edges = list(edge_to_faces.keys())

# 이웃 인덱스 계산
vert_to_edges = {v: [] for v in range(len(vertices))}
for e_idx, (v1, v2) in enumerate(edges):
    vert_to_edges[v1].append(e_idx)
    vert_to_edges[v2].append(e_idx)

neighbors_list = []
for e_idx, (v1, v2) in enumerate(edges):
    neighbor_set = set(vert_to_edges[v1] + vert_to_edges[v2])
    neighbor_set.discard(e_idx)
    neighbors = list(neighbor_set)[:4]
    while len(neighbors) < 4:
        neighbors.append(e_idx)
    neighbors_list.append(neighbors)

neighbor_index = torch.tensor(neighbors_list, dtype=torch.long)

# 간선 특징 계산
edge_features = []
for e_idx, (v1, v2) in enumerate(edges):
    p1, p2 = vertices[v1], vertices[v2]
    edge_length = np.linalg.norm(p1 - p2)
    faces_indices = edge_to_faces[(v1, v2)]
    if len(faces_indices) == 2:
        f1, f2 = faces_indices[0], faces_indices[1]
        n1, n2 = face_normals[f1], face_normals[f2]
        cos_theta = np.dot(n1, n2) / (np.linalg.norm(n1)*np.linalg.norm(n2) + 1e-8)
        dihedral = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    else:
        dihedral = 0.0

    def compute_inner_ratio(f_idx):
        face_v = faces[f_idx]
        opp_v = [v for v in face_v if v not in (v1, v2)][0]
        vec1 = vertices[v1] - vertices[opp_v]
        vec2 = vertices[v2] - vertices[opp_v]
        cos_inner = np.dot(vec1, vec2) / (np.linalg.norm(vec1)*np.linalg.norm(vec2) + 1e-8)
        inner = np.arccos(np.clip(cos_inner, -1.0, 1.0))
        area = np.linalg.norm(np.cross(vec1, vec2)) / 2.0
        height = (2.0 * area) / (edge_length + 1e-8)
        ratio = edge_length / (height + 1e-8)
        return inner, ratio

    inner1, ratio1 = compute_inner_ratio(faces_indices[0]) if faces_indices else (0.0, 0.0)
    inner2, ratio2 = compute_inner_ratio(faces_indices[1]) if len(faces_indices) == 2 else (inner1, ratio1)

    edge_features.append([dihedral, inner1, inner2, ratio1, ratio2])

edge_features = torch.tensor(edge_features, dtype=torch.float32)

# ✅ 모델 예측
edge_features = edge_features.to(device)
neighbor_index = neighbor_index.to(device)
with torch.no_grad():
    class_logits, seg_logits = model(edge_features, neighbor_index)
    pred_class = torch.argmax(class_logits, dim=1).item()
    pred_seg = torch.argmax(seg_logits, dim=1).cpu().numpy()

print("모델 예측:", "폐 결절 있음" if pred_class == 1 else "결절 없음")

# ✅ 결절로 예측된 face 추출
pred_nodule_faces = set()
for e_idx, edge in enumerate(edges):
    if pred_seg[e_idx] == 1:
        for f_idx in edge_to_faces[edge]:
            pred_nodule_faces.add(f_idx)

pred_nodule_faces = list(pred_nodule_faces)
print(f"예측된 결절 영역 면 개수: {len(pred_nodule_faces)}개")

# ✅ 원본 메쉬 로드 (고해상도 obj)
obj_path = os.path.join(original_mesh_dir, target_filename)
tm = trimesh.load(obj_path)
vertices = np.array(tm.vertices)
faces = np.array(tm.faces)

# 하이라이트할 정점
highlighted_vertices = set()
for f_idx in pred_nodule_faces:
    if f_idx < len(faces):
        highlighted_vertices.update(faces[f_idx])

# 색상 할당
colors = np.tile([0.7, 0.7, 0.7], (len(vertices), 1))  # 회색
for v in highlighted_vertices:
    if v < len(colors):
        colors[v] = [1.0, 0.0, 0.0]  # 빨강

# Open3D 메쉬 변환
mesh_o3d = o3d.geometry.TriangleMesh()
mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices)
mesh_o3d.triangles = o3d.utility.Vector3iVector(faces)
mesh_o3d.vertex_colors = o3d.utility.Vector3dVector(colors)
mesh_o3d.compute_vertex_normals()

# ✅ 시각화
o3d.visualization.draw_geometries([mesh_o3d], window_name="결절 예측 시각화")


In [None]:
import torch
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, jaccard_score, accuracy_score

def evaluate_model(model, dataloader, device='cuda'):
    model.eval()
    
    all_preds_cls = []
    all_labels_cls = []

    all_preds_seg = []
    all_labels_seg = []

    with torch.no_grad():
        for batch in dataloader:
            features = batch['edge_features'].to(device).squeeze(0)
            neighbor_index = batch['neighbor_index'].to(device).squeeze(0)
            class_label = batch['class_label'].to(device)
            seg_label = batch['edge_labels'].to(device).squeeze(0)

            if features.numel() == 0 or features.shape[0] > 25000:
                continue

            class_logits, seg_logits = model(features, neighbor_index)

            # ✅ 분류 결과
            pred_class = torch.argmax(class_logits, dim=1)
            all_preds_cls.extend(pred_class.cpu().numpy())
            all_labels_cls.extend(class_label.cpu().numpy())

            # ✅ 세분화 결과
            pred_seg = torch.argmax(seg_logits, dim=1)
            all_preds_seg.extend(pred_seg.cpu().numpy())
            all_labels_seg.extend(seg_label.cpu().numpy())

    # ▶ Classification 성능
    print("🔎 [Classification Results]")
    print(classification_report(all_labels_cls, all_preds_cls, target_names=["Normal", "Nodule"]))
    print("Confusion Matrix:\n", confusion_matrix(all_labels_cls, all_preds_cls))

    # ▶ Segmentation 성능 (IoU, 정확도)
    print("\n🔎 [Segmentation Results]")
    acc_seg = accuracy_score(all_labels_seg, all_preds_seg)
    iou_seg = jaccard_score(all_labels_seg, all_preds_seg, average='binary')
    print(f"Segmentation Accuracy: {acc_seg * 100:.2f}%")
    print(f"Segmentation IoU: {iou_seg:.4f}")

evaluate_model(model, val_loader, device=device)

하이퍼파라미터 조정, 데이터 증강, 평가 지표 추가 및 최적화 등 추가 작업 필요