In [None]:
# MedMeshCNN - TensorFlow Implementation
# Jupyter Notebook으로 각 셀별로 실행 가능하도록 구성

# =============================================================================
# 필요한 라이브러리 import
# =============================================================================

import json
import os
import warnings
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow import keras
from tensorflow.keras import layers, models
# NOTE(review): this silences every warning for the whole session, including
# deprecation warnings emitted by TensorFlow/Keras.
warnings.filterwarnings('ignore')

# Report the TensorFlow version and the GPUs TensorFlow can see.
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

TensorFlow version: 2.18.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [None]:

# =============================================================================
# 데이터 로딩 및 전처리 유틸리티
# =============================================================================

class MeshDataLoader:
    """Load triangle meshes from disk and derive per-face features.

    Loading is attempted with trimesh first, then Open3D, and finally (for
    ``.obj`` files only) a small built-in parser. Both third-party back-ends
    are imported lazily, so a missing package simply falls through to the
    next strategy instead of breaking the class definition.
    """

    def __init__(self):
        # File extensions (lower-case, with leading dot) accepted by
        # ``load_mesh_data``.
        self.supported_formats = ['.obj', '.ply', '.stl', '.off', '.3mf']

    def load_mesh_data(self, file_path: str) -> Dict:
        """Load a single mesh file.

        Args:
            file_path: Path to the mesh file.

        Returns:
            Dict with keys ``vertices`` (N x 3 float32), ``faces`` (M x 3
            int32), ``features`` (M x 5 float32), ``normals`` (M x 3 float32)
            and ``metadata``.

        Raises:
            FileNotFoundError: The file does not exist.
            ValueError: The extension is not in ``supported_formats``.
            RuntimeError: Every applicable loading strategy failed.
        """
        path = Path(file_path)

        if not path.exists():
            raise FileNotFoundError(f"메쉬 파일을 찾을 수 없습니다: {path}")

        suffix = path.suffix.lower()
        if suffix not in self.supported_formats:
            raise ValueError(f"지원하지 않는 파일 형식입니다. 지원 형식: {self.supported_formats}")

        # Strategy 1: trimesh (preferred). The broad except is deliberate:
        # any failure, including a missing package, moves on to the fallback.
        try:
            mesh_data = self._load_with_trimesh(path)
        except Exception as e:
            print(f"⚠️ Trimesh 로딩 실패: {e}")
        else:
            print(f"✅ Trimesh로 로딩 성공: {path.name}")
            return mesh_data

        # Strategy 2: Open3D (backup).
        try:
            mesh_data = self._load_with_open3d(path)
        except Exception as e:
            print(f"❌ Open3D 로딩도 실패: {e}")
            last_error = e
        else:
            print(f"✅ Open3D로 로딩 성공: {path.name}")
            return mesh_data

        # Strategy 3: built-in parser, which only understands OBJ.
        if suffix != '.obj':
            raise RuntimeError(f"모든 로딩 방법 실패: {path}") from last_error

        mesh_data = self._load_obj_manual(path)
        print(f"✅ 수동 파싱으로 로딩 성공: {path.name}")
        return mesh_data

    def _load_with_trimesh(self, file_path: Path) -> Dict:
        """Load a mesh through the optional ``trimesh`` package."""
        import trimesh  # optional dependency, imported on demand

        mesh = trimesh.load(str(file_path))

        # A multi-geometry file comes back as a Scene; use its first mesh.
        if isinstance(mesh, trimesh.Scene):
            geometries = list(mesh.geometry.values())
            if not geometries:
                raise ValueError("빈 메쉬 파일입니다.")
            mesh = geometries[0]

        vertices = np.array(mesh.vertices, dtype=np.float32)
        faces = np.array(mesh.faces, dtype=np.int32)
        face_normals = np.array(mesh.face_normals, dtype=np.float32)

        return {
            'vertices': vertices,  # N x 3
            'faces': faces,        # M x 3
            'features': self._compute_edge_features(vertices, faces),  # M x 5
            'normals': face_normals,  # M x 3
            'metadata': {
                'vertex_count': len(vertices),
                'face_count': len(faces),
                'is_watertight': mesh.is_watertight,
                'bounds': mesh.bounds
            }
        }

    def _load_with_open3d(self, file_path: Path) -> Dict:
        """Load a mesh through the optional ``open3d`` package."""
        import open3d as o3d  # optional dependency, imported on demand

        mesh = o3d.io.read_triangle_mesh(str(file_path))

        if len(mesh.vertices) == 0:
            raise ValueError("빈 메쉬 파일입니다.")

        vertices = np.asarray(mesh.vertices, dtype=np.float32)
        faces = np.asarray(mesh.triangles, dtype=np.int32)

        mesh.compute_triangle_normals()
        face_normals = np.asarray(mesh.triangle_normals, dtype=np.float32)

        return self._assemble_mesh_data(vertices, faces, face_normals)

    def _load_obj_manual(self, file_path: Path) -> Dict:
        """Parse a Wavefront OBJ file without third-party libraries.

        Only ``v`` and ``f`` records are read. Polygons with more than three
        corners are fan-triangulated; faces with fewer than three corners are
        ignored. Negative (relative) vertex references are resolved against
        the vertices read so far.

        Raises:
            ValueError: The file has no vertices or uses the invalid index 0.
        """
        vertices = []
        faces = []

        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                tokens = line.split()
                if not tokens:
                    continue

                tag = tokens[0]
                if tag == 'v':
                    vertices.append([float(c) for c in tokens[1:4]])
                elif tag == 'f':
                    corner_ids = []
                    for token in tokens[1:]:
                        # Accepts "v", "v/vt", "v//vn" and "v/vt/vn".
                        raw = int(token.split('/')[0])
                        if raw == 0:
                            raise ValueError("OBJ 면 인덱스는 0이 될 수 없습니다.")
                        # OBJ is 1-indexed; negative values count back from
                        # the most recently defined vertex.
                        corner_ids.append(raw - 1 if raw > 0 else len(vertices) + raw)

                    # Fan triangulation (a plain triangle yields one face).
                    for i in range(1, len(corner_ids) - 1):
                        faces.append([corner_ids[0], corner_ids[i], corner_ids[i + 1]])

        if not vertices:
            raise ValueError("빈 메쉬 파일입니다.")

        # reshape keeps the arrays 2-D even when there are no faces.
        vertices = np.array(vertices, dtype=np.float32).reshape(-1, 3)
        faces = np.array(faces, dtype=np.int32).reshape(-1, 3)

        face_normals = self._compute_face_normals(vertices, faces)
        return self._assemble_mesh_data(vertices, faces, face_normals)

    def _assemble_mesh_data(self, vertices: np.ndarray, faces: np.ndarray,
                            face_normals: np.ndarray) -> Dict:
        """Bundle arrays into the mesh dict shared by the fallback loaders."""
        return {
            'vertices': vertices,
            'faces': faces,
            'features': self._compute_edge_features(vertices, faces),
            'normals': face_normals,
            'metadata': {
                'vertex_count': len(vertices),
                'face_count': len(faces),
                'bounds': [vertices.min(axis=0), vertices.max(axis=0)]
            }
        }

    def _compute_face_normals(self, vertices: np.ndarray, faces: np.ndarray) -> np.ndarray:
        """Return unit normals (M x 3, float32), one per triangle.

        Degenerate (zero-area) triangles get the default normal (0, 0, 1).
        """
        faces = np.asarray(faces).reshape(-1, 3)
        corners = vertices[faces]  # M x 3 (corner) x 3 (xyz)
        raw = np.cross(corners[:, 1] - corners[:, 0], corners[:, 2] - corners[:, 0])

        lengths = np.linalg.norm(raw, axis=1, keepdims=True)
        valid = lengths > 0
        normals = raw / np.where(valid, lengths, 1.0)
        normals[~valid[:, 0]] = (0.0, 0.0, 1.0)

        return normals.astype(np.float32)

    def _compute_edge_features(self, vertices: np.ndarray, faces: np.ndarray) -> np.ndarray:
        """Return per-face features (M x 5, float32).

        Columns: area, perimeter, aspect ratio (longest edge / shortest
        edge), smallest interior angle in radians, and the distance from the
        face centroid to the origin.
        """
        faces = np.asarray(faces).reshape(-1, 3)
        corners = vertices[faces]
        v0, v1, v2 = corners[:, 0], corners[:, 1], corners[:, 2]

        e01 = v1 - v0
        e12 = v2 - v1
        e20 = v0 - v2
        len01 = np.linalg.norm(e01, axis=1)
        len12 = np.linalg.norm(e12, axis=1)
        len20 = np.linalg.norm(e20, axis=1)
        lengths = np.stack([len01, len12, len20], axis=1)

        area = 0.5 * np.linalg.norm(np.cross(e01, v2 - v0), axis=1)
        perimeter = lengths.sum(axis=1)
        # The epsilon guards against division by zero for collapsed edges.
        aspect_ratio = lengths.max(axis=1) / (lengths.min(axis=1) + 1e-8)

        def corner_angle(a, b, len_a, len_b):
            cos_angle = np.einsum('ij,ij->i', a, b) / (len_a * len_b + 1e-8)
            return np.arccos(np.clip(cos_angle, -1.0, 1.0))

        angles = np.stack([
            corner_angle(-e20, e01, len20, len01),  # angle at v0
            corner_angle(-e01, e12, len01, len12),  # angle at v1
            corner_angle(-e12, e20, len12, len20),  # angle at v2
        ], axis=1)
        min_angle = angles.min(axis=1)

        centroid = (v0 + v1 + v2) / 3.0
        distance_to_origin = np.linalg.norm(centroid, axis=1)

        features = np.stack(
            [area, perimeter, aspect_ratio, min_angle, distance_to_origin],
            axis=1,
        )
        return features.astype(np.float32)

    def validate_mesh_data(self, mesh_data: Dict) -> bool:
        """Check array shapes, face index range and vertex finiteness.

        Returns:
            True when every check passes. False otherwise, including when the
            check itself raises (for example a missing key or an empty face
            array); the reason is printed.
        """
        try:
            vertices = mesh_data['vertices']
            faces = mesh_data['faces']

            # Basic shape check
            if vertices.shape[1] != 3 or faces.shape[1] != 3:
                print("❌ 잘못된 데이터 형태")
                return False

            # Index range check
            if faces.max() >= len(vertices) or faces.min() < 0:
                print("❌ 면 인덱스가 정점 범위를 벗어남")
                return False

            # NaN/Inf check
            if not np.all(np.isfinite(vertices)):
                print("❌ 정점에 NaN 또는 Inf 값 존재")
                return False

            print("✅ 메쉬 데이터 유효성 검사 통과")
            return True

        except Exception as e:
            print(f"❌ 유효성 검사 중 오류: {e}")
            return False

# Usage example
if __name__ == "__main__":
    # The loader class defined above is MeshDataLoader; the former name
    # `MeshLoader` does not exist and raised NameError outside the try block.
    loader = MeshDataLoader()

    # Load a mesh file and report its basic statistics.
    try:
        mesh_data = loader.load_mesh_data("example.obj")

        if loader.validate_mesh_data(mesh_data):
            print(f"정점 수: {mesh_data['metadata']['vertex_count']}")
            print(f"면 수: {mesh_data['metadata']['face_count']}")
            print(f"특성 차원: {mesh_data['features'].shape}")

    except Exception as e:
        print(f"메쉬 로딩 실패: {e}")


In [None]:
# =============================================================================
# Cell 3: MeshConv Layer 정의
# =============================================================================

# Registering the class lets Keras resolve it by name when a saved model is
# deserialized; without it Keras reports "Could not locate class
# 'MeshConvLayer'".
@keras.saving.register_keras_serializable(package='MedMeshCNN')
class MeshConvLayer(layers.Layer):
    """Mesh convolution block: a Conv1D followed by batch normalization.

    Args:
        out_channels: Number of Conv1D filters.
        kernel_size: Conv1D kernel size.
        stride: Conv1D stride.
        padding: Conv1D padding mode.
        activation: Activation applied by the Conv1D (before batch norm).
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self,
                 out_channels: int,
                 kernel_size: int = 5,
                 stride: int = 1,
                 padding: str = 'same',
                 activation: str = 'relu',
                 **kwargs):
        super().__init__(**kwargs)

        # Stored so get_config() can reproduce the constructor call.
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.activation = activation

        # A 1D convolution processes the sequence of mesh feature vectors.
        self.conv1d = layers.Conv1D(
            filters=out_channels,
            kernel_size=kernel_size,
            strides=stride,
            padding=padding,
            activation=activation,
            name=f'mesh_conv_{out_channels}'
        )

        self.batch_norm = layers.BatchNormalization()

    def call(self, inputs, training=None):
        """Apply the convolution, then batch normalization."""
        x = self.conv1d(inputs)
        # `training` switches batch norm between batch and moving statistics.
        x = self.batch_norm(x, training=training)
        return x

    def get_config(self):
        """Return the constructor arguments for serialization."""
        config = super().get_config()
        config.update({
            'out_channels': self.out_channels,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'padding': self.padding,
            'activation': self.activation
        })
        return config



In [None]:
# =============================================================================
# Cell 4: MeshPooling Layer 정의
# =============================================================================

# Registered so Keras can locate the class by name when loading a saved model.
@keras.saving.register_keras_serializable(package='MedMeshCNN')
class MeshPoolingLayer(layers.Layer):
    """Mesh pooling layer that shortens the sequence axis with 1D pooling.

    Args:
        pool_size: Window size passed to the pooling layer.
        pool_type: ``'max'`` or ``'avg'``.
        **kwargs: Forwarded to ``keras.layers.Layer``.

    Raises:
        ValueError: ``pool_type`` is neither ``'max'`` nor ``'avg'``.
    """

    def __init__(self,
                 pool_size: int = 2,
                 pool_type: str = 'max',
                 **kwargs):
        super().__init__(**kwargs)

        self.pool_size = pool_size
        self.pool_type = pool_type

        if pool_type == 'max':
            self.pooling = layers.MaxPooling1D(pool_size=pool_size)
        elif pool_type == 'avg':
            self.pooling = layers.AveragePooling1D(pool_size=pool_size)
        else:
            raise ValueError(f"Unsupported pool_type: {pool_type}")

    def call(self, inputs):
        """Apply the configured pooling operation."""
        return self.pooling(inputs)

    def get_config(self):
        """Return the constructor arguments for serialization."""
        config = super().get_config()
        config.update({
            'pool_size': self.pool_size,
            'pool_type': self.pool_type
        })
        return config



In [None]:
# =============================================================================
# Cell 5: MedMeshCNN 모델 정의
# =============================================================================

# Registered so Keras can locate the class by name when loading a saved model.
@keras.saving.register_keras_serializable(package='MedMeshCNN')
class MedMeshCNN(keras.Model):
    """CNN for medical mesh classification.

    The network stacks (MeshConvLayer, MeshPoolingLayer) pairs, applies a
    global max pooling, then a stack of Dense+Dropout layers and a final
    classification Dense layer.

    Args:
        num_classes: Number of units in the classification layer.
        input_features: Feature count per mesh element (stored for the
            config only; the layers infer their input size when first called).
        conv_channels: Filter count of each convolution stage.
        fc_features: Unit count of each fully connected layer.
        dropout_rate: Dropout rate after each fully connected layer.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self,
                 num_classes: int,
                 input_features: int = 5,
                 conv_channels: List[int] = [64, 128, 256, 512],
                 fc_features: List[int] = [1024, 512],
                 dropout_rate: float = 0.5,
                 **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.input_features = input_features
        # Copy the sequences: the signature's default lists are shared between
        # calls, so storing them directly would let one instance's mutation
        # leak into every later instance.
        self.conv_channels = list(conv_channels)
        self.fc_features = list(fc_features)
        self.dropout_rate = dropout_rate

        # Convolution / pooling stages
        self.mesh_convs = []
        self.mesh_pools = []

        for i, channels in enumerate(self.conv_channels):
            self.mesh_convs.append(
                MeshConvLayer(
                    out_channels=channels,
                    kernel_size=5,
                    name=f'meshconv_{i+1}'
                )
            )
            self.mesh_pools.append(
                MeshPoolingLayer(
                    pool_size=2,
                    pool_type='max',
                    name=f'meshpool_{i+1}'
                )
            )

        # Global pooling collapses the sequence axis.
        self.global_pool = layers.GlobalMaxPooling1D()

        # Fully connected head
        self.fc_layers = []
        self.dropout_layers = []

        for i, features in enumerate(self.fc_features):
            self.fc_layers.append(
                layers.Dense(
                    features,
                    activation='relu',
                    name=f'fc_{i+1}'
                )
            )
            self.dropout_layers.append(
                layers.Dropout(dropout_rate, name=f'dropout_{i+1}')
            )

        # Classification layer: softmax for more than two classes, otherwise
        # sigmoid (still with `num_classes` units).
        self.classifier = layers.Dense(
            num_classes,
            activation='softmax' if num_classes > 2 else 'sigmoid',
            name='classifier'
        )

    def call(self, inputs, training=None):
        """Run the forward pass and return the classifier output."""
        x = inputs

        # Convolution and pooling stages
        for conv, pool in zip(self.mesh_convs, self.mesh_pools):
            x = conv(x, training=training)
            x = pool(x)

        # Global pooling
        x = self.global_pool(x)

        # Fully connected layers
        for fc, dropout in zip(self.fc_layers, self.dropout_layers):
            x = fc(x)
            x = dropout(x, training=training)

        # Classification
        outputs = self.classifier(x)

        return outputs

    def get_config(self):
        """Return the constructor arguments for serialization."""
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'input_features': self.input_features,
            'conv_channels': self.conv_channels,
            'fc_features': self.fc_features,
            'dropout_rate': self.dropout_rate
        })
        return config

model = MedMeshCNN(num_classes=10)  # Example: a model that classifies into 10 classes
# NOTE(review): the model has not been called on any input at this point —
# confirm the summary shows the information you expect.
model.summary()

In [None]:
# =============================================================================
# Cell 6: 모델 빌드 및 컴파일
# =============================================================================

def build_medmeshcnn_model(num_classes: int = 3,
                          input_shape: Tuple = (1800, 5),
                          conv_channels: Optional[List[int]] = None,
                          fc_features: Optional[List[int]] = None,
                          dropout_rate: float = 0.5,
                          learning_rate: float = 0.001) -> keras.Model:
    """Create and compile a MedMeshCNN wrapped in a functional model.

    Args:
        num_classes: Number of output units of the classifier.
        input_shape: Shape of one sample, without the batch axis.
        conv_channels: Filters per convolution stage; defaults to
            ``[64, 128, 256, 512]``.
        fc_features: Units per fully connected layer; defaults to
            ``[1024, 512]``.
        dropout_rate: Dropout rate of the fully connected head.
        learning_rate: Adam learning rate.

    Returns:
        The compiled ``keras.Model`` named ``'MedMeshCNN'``.
    """
    if conv_channels is None:
        conv_channels = [64, 128, 256, 512]
    if fc_features is None:
        fc_features = [1024, 512]

    # Input layer
    inputs = layers.Input(shape=input_shape, name='mesh_input')

    # MedMeshCNN instance that does the actual work
    backbone = MedMeshCNN(
        num_classes=num_classes,
        conv_channels=conv_channels,
        fc_features=fc_features,
        dropout_rate=dropout_rate
    )

    # Wrap it in a functional model so the input shape is fixed.
    outputs = backbone(inputs)
    full_model = keras.Model(inputs=inputs, outputs=outputs, name='MedMeshCNN')

    # MedMeshCNN emits `num_classes` units. With two or more units and
    # integer class labels the sparse categorical loss is the one whose
    # shapes line up; binary cross-entropy only fits a single-unit head.
    if num_classes == 1:
        loss = 'binary_crossentropy'
    else:
        loss = 'sparse_categorical_crossentropy'

    full_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss=loss,
        metrics=['accuracy']
    )

    return full_model

# Model creation example. The class count matches the 3-class sample data and
# the 3-entry `class_names` list used further down; a 10-unit head could
# predict an index that `class_names` cannot resolve.
model = build_medmeshcnn_model(num_classes=3, input_shape=(1800, 5))
model.summary()


In [None]:
# =============================================================================
# Cell 7: 훈련 설정 및 콜백
# =============================================================================

def setup_training_callbacks(model_name: str = 'medmeshcnn',
                           patience: int = 10) -> List:
    """Build the list of training callbacks.

    Args:
        model_name: Prefix for the checkpoint and CSV log file names.
        patience: Epochs without ``val_loss`` improvement before early
            stopping.

    Returns:
        ``[EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, CSVLogger]``.
    """
    # Early stopping on validation loss; restores the best weights seen.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=patience,
        restore_best_weights=True,
        verbose=1
    )

    # Halve the learning rate after 5 stagnant epochs.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    )

    # Keep only the best full model (by validation accuracy). The native
    # `.keras` format replaces legacy HDF5, whose save of the custom layers
    # failed in the recorded run.
    checkpoint = keras.callbacks.ModelCheckpoint(
        filepath=f'{model_name}_best.keras',
        monitor='val_accuracy',
        save_best_only=True,
        save_weights_only=False,
        verbose=1
    )

    # Per-epoch metrics log; the file is overwritten on each run.
    csv_logger = keras.callbacks.CSVLogger(
        filename=f'{model_name}_training_log.csv',
        separator=',',
        append=False
    )

    return [early_stopping, reduce_lr, checkpoint, csv_logger]


In [None]:
# =============================================================================
# Cell 8: 데이터 생성 및 전처리 (예시)
# =============================================================================

def generate_sample_data(num_samples: int = 1000,
                        num_classes: int = 3,
                        num_faces: int = 1800,
                        num_features: int = 5,
                        seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Generate random placeholder mesh data (replace with real data).

    Args:
        num_samples: Number of samples to generate.
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.
        num_faces: Length of the sequence axis of each sample.
        num_features: Feature count per sequence element.
        seed: When given, a dedicated ``RandomState`` seeded with this value
            is used so the output is reproducible; when ``None`` the global
            NumPy random state is used.

    Returns:
        ``(X, y)`` where ``X`` is float32 with shape
        ``(num_samples, num_faces, num_features)`` and ``y`` is an integer
        array with shape ``(num_samples,)``.
    """
    # The np.random module and RandomState expose the same randn/randint API.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Synthetic mesh features drawn from a standard normal distribution
    X = rng.randn(num_samples, num_faces, num_features).astype(np.float32)

    # Random labels
    y = rng.randint(0, num_classes, num_samples)

    print(f"Generated data shape: {X.shape}")
    print(f"Labels shape: {y.shape}")
    print(f"Number of classes: {num_classes}")

    return X, y

# Generate the sample data
X_sample, y_sample = generate_sample_data(num_samples=1000, num_classes=3)

# Split into training and validation sets (80/20, stratified on the labels)
X_train, X_val, y_train, y_val = train_test_split(
    X_sample, y_sample,
    test_size=0.2,
    random_state=42,
    stratify=y_sample
)

print(f"Training data: {X_train.shape}, {y_train.shape}")
print(f"Validation data: {X_val.shape}, {y_val.shape}")


Generated data shape: (1000, 1800, 5)
Labels shape: (1000,)
Number of classes: 3
Training data: (800, 1800, 5), (800,)
Validation data: (200, 1800, 5), (200,)


In [None]:
# =============================================================================
# Cell 9: 모델 훈련
# =============================================================================

def train_medmeshcnn(model: keras.Model,
                    X_train: np.ndarray,
                    y_train: np.ndarray,
                    X_val: np.ndarray,
                    y_val: np.ndarray,
                    epochs: int = 100,
                    batch_size: int = 32,
                    callbacks: Optional[List] = None) -> keras.callbacks.History:
    """Train a compiled MedMeshCNN model.

    Args:
        model: Compiled Keras model.
        X_train: Training inputs.
        y_train: Training labels.
        X_val: Validation inputs.
        y_val: Validation labels.
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size.
        callbacks: Callbacks to use; when ``None`` the defaults from
            ``setup_training_callbacks()`` are created.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Fall back to the default callback set only when none was supplied, so
    # callers can choose their own checkpoint name or disable callbacks ([]).
    if callbacks is None:
        callbacks = setup_training_callbacks()

    return model.fit(
        X_train, y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(X_val, y_val),
        callbacks=callbacks,
        verbose=1
    )

# Run the model training
print("Starting model training...")
history = train_medmeshcnn(
    model=model,
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    epochs=50,
    batch_size=16
)


Starting model training...
Epoch 1/50
[1m48/50[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 17ms/step - accuracy: 0.2925 - loss: 6.0628
Epoch 1: val_accuracy improved from -inf to 0.33000, saving model to medmeshcnn_best.h5




TypeError: Could not locate class 'MeshConvLayer'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. Full object config: {'module': None, 'class_name': 'MeshConvLayer', 'config': {'name': 'meshconv_1', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'out_channels': 64, 'kernel_size': 5, 'stride': 1, 'padding': 'same', 'activation': 'relu'}, 'registered_name': 'MeshConvLayer', 'build_config': {'input_shape': [None, 1800, 5]}}

In [None]:
# =============================================================================
# Cell 10: 훈련 결과 시각화
# =============================================================================

def plot_training_history(history: keras.callbacks.History):
    """Plot loss and accuracy curves and print the final accuracies.

    Series that are absent from ``history.history`` (for example the
    ``val_*`` entries when training ran without validation data) are skipped
    instead of raising ``KeyError``.

    Args:
        history: The object returned by ``model.fit``.
    """
    logs = history.history
    _, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (history key, axis label, panel title) for the two panels
    panels = (
        ('loss', 'Loss', 'Model Loss'),
        ('accuracy', 'Accuracy', 'Model Accuracy'),
    )

    for ax, (key, label, title) in zip(axes, panels):
        plotted = False
        if key in logs:
            ax.plot(logs[key], label=f'Training {label}')
            plotted = True
        if f'val_{key}' in logs:
            ax.plot(logs[f'val_{key}'], label=f'Validation {label}')
            plotted = True

        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        if plotted:
            ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

    # Final performance: the last recorded epoch of each available series
    for key, label in (('accuracy', 'Training'), ('val_accuracy', 'Validation')):
        series = logs.get(key)
        if series:
            print(f"Final {label} Accuracy: {series[-1]:.4f}")

# Visualize the training history
plot_training_history(history)


In [None]:
# =============================================================================
# Cell 11: 모델 평가 및 예측
# =============================================================================

def evaluate_model(model: keras.Model,
                  X_test: np.ndarray,
                  y_test: np.ndarray) -> Dict:
    """Evaluate a model and collect its predictions.

    Args:
        model: Compiled model whose ``evaluate`` returns ``[loss, accuracy]``.
        X_test: Evaluation inputs.
        y_test: Evaluation labels.

    Returns:
        Dict with ``test_loss``, ``test_accuracy``, ``predictions`` (class
        indices) and ``prediction_probabilities`` (raw model output).
    """
    # Loss and accuracy on the evaluation set
    test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)

    # Raw model output
    y_pred_proba = model.predict(X_test)

    if y_pred_proba.ndim == 1 or y_pred_proba.shape[-1] == 1:
        # Single-unit (sigmoid) head: argmax over one column is always 0, so
        # threshold the probability instead.
        y_pred = (y_pred_proba.reshape(-1) >= 0.5).astype(int)
    else:
        y_pred = np.argmax(y_pred_proba, axis=1)

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")

    return {
        'test_loss': test_loss,
        'test_accuracy': test_accuracy,
        'predictions': y_pred,
        'prediction_probabilities': y_pred_proba
    }

# Evaluate on the validation data
results = evaluate_model(model, X_val, y_val)


In [None]:
# =============================================================================
# Cell 12: 모델 저장 및 로딩
# =============================================================================

def save_model(model: keras.Model, model_path: str = 'medmeshcnn_model'):
    """Save the model as an HDF5 file and as a SavedModel directory.

    Args:
        model: Model to save.
        model_path: Path prefix; writes ``<model_path>.h5`` and
            ``<model_path>_savedmodel``.
    """
    # Whole model in a single HDF5 file
    model.save(f'{model_path}.h5')

    # SavedModel directory. Keras 3 rejects `model.save()` paths without a
    # `.keras`/`.h5` extension and provides `model.export()` for SavedModel;
    # fall back to `save()` where `export` is unavailable.
    export_dir = f'{model_path}_savedmodel'
    if hasattr(model, 'export'):
        model.export(export_dir)
    else:
        model.save(export_dir)

    print(f"Model saved to {model_path}")

def load_model(model_path: str) -> keras.Model:
    """Load a saved model, resolving the custom MedMeshCNN classes.

    Args:
        model_path: Path passed to ``keras.models.load_model``.

    Returns:
        The loaded model.
    """
    # Map each custom class name to its class so Keras can rebuild the layers.
    custom_classes = (MeshConvLayer, MeshPoolingLayer, MedMeshCNN)
    name_to_class = {cls.__name__: cls for cls in custom_classes}

    loaded = keras.models.load_model(model_path, custom_objects=name_to_class)
    print(f"Model loaded from {model_path}")
    return loaded

# Save the model
save_model(model, 'medmeshcnn_tensorflow')


In [None]:
# =============================================================================
# Cell 13: 추론 및 활용 예시
# =============================================================================

def predict_single_mesh(model: keras.Model,
                       mesh_features: np.ndarray,
                       class_names: Optional[List[str]] = None) -> Dict:
    """Predict the class of a single mesh.

    Args:
        model: Model exposing ``predict``.
        mesh_features: Features of one mesh; a 2-D array gets a leading batch
            axis added, anything else is passed through unchanged.
        class_names: Optional names indexed by class id.

    Returns:
        Dict with ``predicted_class``, ``confidence`` and ``probabilities``,
        plus ``predicted_class_name`` when ``class_names`` is given.

    Raises:
        IndexError: ``class_names`` has no entry for the predicted class.
    """
    # Add the batch dimension
    if mesh_features.ndim == 2:
        mesh_features = np.expand_dims(mesh_features, axis=0)

    # Run the prediction
    prediction_proba = model.predict(mesh_features)
    predicted_class = np.argmax(prediction_proba, axis=1)[0]
    confidence = np.max(prediction_proba, axis=1)[0]

    result = {
        'predicted_class': predicted_class,
        'confidence': confidence,
        'probabilities': prediction_proba[0]
    }

    if class_names:
        # Fail with an explicit message when the model has more outputs than
        # there are names, rather than a bare "list index out of range".
        if predicted_class >= len(class_names):
            raise IndexError(
                f"Predicted class {predicted_class} has no entry in "
                f"class_names (length {len(class_names)})"
            )
        result['predicted_class_name'] = class_names[predicted_class]

    return result

# Example inference
sample_mesh = X_val[0]  # First validation sample
class_names = ['Class_0', 'Class_1', 'Class_2']  # Replace with the real class names

prediction_result = predict_single_mesh(model, sample_mesh, class_names)
print("Prediction Result:")
print(f"Predicted Class: {prediction_result['predicted_class_name']}")
print(f"Confidence: {prediction_result['confidence']:.4f}")
print(f"All Probabilities: {prediction_result['probabilities']}")


In [None]:
# =============================================================================
# Cell 14: 하이퍼파라미터 튜닝 (선택사항)
# =============================================================================

def hyperparameter_search(X_train, y_train, X_val, y_val, num_classes: int = 3):
    """Run a small grid search over architecture and optimizer settings.

    Every combination of convolution widths, dropout rate and learning rate
    is trained for 10 epochs and the combination with the highest validation
    accuracy is reported.

    Args:
        X_train: Training inputs, shape ``(samples, ...)``.
        y_train: Integer training labels.
        X_val: Validation inputs.
        y_val: Integer validation labels.
        num_classes: Number of output classes of the candidate models.

    Returns:
        Dict with the best ``conv_channels``, ``dropout_rate`` and
        ``learning_rate`` (empty if no run scored above 0).
    """
    from itertools import product

    param_grid = {
        'conv_channels': [
            [32, 64, 128, 256],
            [64, 128, 256, 512],
            [128, 256, 512, 1024]
        ],
        'dropout_rate': [0.3, 0.5, 0.7],
        'learning_rate': [0.001, 0.0005, 0.0001]
    }

    best_accuracy = 0
    best_params = {}

    # Per-sample shape, taken from the data rather than hard-coded.
    input_shape = tuple(X_train.shape[1:])

    for conv_channels, dropout_rate, lr in product(
            param_grid['conv_channels'],
            param_grid['dropout_rate'],
            param_grid['learning_rate']):

        print(f"Testing: conv_channels={conv_channels}, dropout={dropout_rate}, lr={lr}")

        # Build the candidate directly so that conv_channels and dropout_rate
        # actually reach the network; previously every grid point trained the
        # same default architecture and only the learning rate varied.
        inputs = layers.Input(shape=input_shape, name='mesh_input')
        backbone = MedMeshCNN(
            num_classes=num_classes,
            conv_channels=conv_channels,
            dropout_rate=dropout_rate
        )
        test_model = keras.Model(inputs=inputs, outputs=backbone(inputs))

        test_model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=lr),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        # Short training run (few epochs)
        history = test_model.fit(
            X_train, y_train,
            batch_size=16,
            epochs=10,
            validation_data=(X_val, y_val),
            verbose=0
        )

        # Best validation accuracy reached during this run
        best_val_acc = max(history.history['val_accuracy'])

        if best_val_acc > best_accuracy:
            best_accuracy = best_val_acc
            best_params = {
                'conv_channels': conv_channels,
                'dropout_rate': dropout_rate,
                'learning_rate': lr
            }

        print(f"Best validation accuracy: {best_val_acc:.4f}")

    print(f"\nBest parameters: {best_params}")
    print(f"Best accuracy: {best_accuracy:.4f}")

    return best_params

# Run the hyperparameter search (may take a long time)
# best_params = hyperparameter_search(X_train, y_train, X_val, y_val)

print("MedMeshCNN TensorFlow implementation complete!")
print("All cells are ready to run in Jupyter Lab/Notebook environment.")