# Vesuvius Challenge - Pure PyTorch Implementation
完全なPyTorch実装（TensorFlow依存なし）

In [None]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Pick the compute device once: CUDA when a GPU is visible, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    # Report the name and total memory (in GiB) of GPU 0.
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')

## Configuration

In [None]:
class Config:
    """Hyper-parameters, paths and switches shared by the notebook cells."""

    # --- Data ---------------------------------------------------------
    num_classes = 3
    input_shape = (64, 64, 64)  # patch size as (D, H, W); shrink if GPU memory is tight

    # --- Optimisation -------------------------------------------------
    epochs = 100
    batch_size = 2
    learning_rate = 3e-4
    weight_decay = 1e-5

    # --- Files --------------------------------------------------------
    data_path = "./data"  # point this at your dataset directory
    checkpoint_path = "best_model_pytorch.pth"

    # --- Feature switches ---------------------------------------------
    use_augmentation = True  # random crop/flip/rotate/intensity shift for training samples
    use_amp = True           # mixed-precision training

    # --- Sliding-window inference -------------------------------------
    overlap = 0.5
    sw_batch_size = 4


# Single shared settings object read by the other cells.
config = Config()

## Data Augmentation

In [None]:
class Augmentation3D:
    """Random augmentations for a 3D image volume and its voxel-wise label.

    All methods are static and work on NumPy arrays indexed as (D, H, W).
    Geometric transforms are applied identically to image and label; the
    intensity transform only touches the image.
    """

    @staticmethod
    def random_flip(image, label, axis):
        """Mirror image and label along ``axis`` with probability 0.5.

        Returns:
            The (possibly flipped) ``(image, label)`` pair.
        """
        if np.random.rand() > 0.5:
            # .copy() turns the reversed view into a regular contiguous array.
            image = np.flip(image, axis=axis).copy()
            label = np.flip(label, axis=axis).copy()
        return image, label

    @staticmethod
    def random_rotate90(image, label):
        """Rotate image and label by a random multiple of 90 degrees in the H-W plane.

        The rotation is applied with probability 0.5. When H and W differ,
        only the 180-degree rotation is used, because a 90/270-degree turn
        would swap the two axis lengths and change the sample's shape.

        Returns:
            The (possibly rotated) ``(image, label)`` pair, same shape as the input.
        """
        if np.random.rand() > 0.5:
            k = np.random.randint(1, 4)
            axes = (1, 2)  # Rotate in H-W plane
            if image.shape[1] != image.shape[2]:
                # Odd quarter-turns would transpose H and W; keep the shape stable.
                k = 2
            image = np.rot90(image, k, axes).copy()
            label = np.rot90(label, k, axes).copy()
        return image, label

    @staticmethod
    def random_intensity_shift(image, shift_range=0.1, clip_range=(0, 1)):
        """Add a random constant offset to the image with probability 0.5.

        Args:
            image: Intensity volume.
            shift_range: Offset is drawn uniformly from ``[-shift_range, shift_range]``.
            clip_range: ``(low, high)`` bounds applied after the shift, or
                ``None`` to skip clipping (use this for standardised data
                whose values are legitimately negative).

        Returns:
            The (possibly shifted) image.
        """
        if np.random.rand() > 0.5:
            shift = np.random.uniform(-shift_range, shift_range)
            image = image + shift
            if clip_range is not None:
                image = np.clip(image, clip_range[0], clip_range[1])
        return image

    @staticmethod
    def random_crop(image, label, crop_size):
        """Cut the same randomly positioned ``crop_size`` window from image and label.

        If the volume is smaller than ``crop_size`` along an axis, the crop
        starts at 0 on that axis and the result is smaller than requested.

        Returns:
            ``(image_crop, label_crop)`` views into the inputs.
        """
        d, h, w = image.shape
        cd, ch, cw = crop_size

        # max(1, ...) keeps randint's upper bound valid when the axis is shorter than the crop.
        d_start = np.random.randint(0, max(1, d - cd + 1))
        h_start = np.random.randint(0, max(1, h - ch + 1))
        w_start = np.random.randint(0, max(1, w - cw + 1))

        window = (
            slice(d_start, d_start + cd),
            slice(h_start, h_start + ch),
            slice(w_start, w_start + cw),
        )
        return image[window], label[window]

    @staticmethod
    def apply(image, label, crop_size):
        """Run the full pipeline: crop, three axis flips, H-W rotation, intensity shift.

        Returns:
            The augmented ``(image, label)`` pair.
        """
        # Random crop
        image, label = Augmentation3D.random_crop(image, label, crop_size)

        # Geometric augmentations
        for axis in (0, 1, 2):
            image, label = Augmentation3D.random_flip(image, label, axis=axis)
        image, label = Augmentation3D.random_rotate90(image, label)

        # Intensity augmentation. No clipping here: the dataset standardises
        # the image (zero mean, unit std) before augmenting, so clamping to
        # [0, 1] would erase every negative voxel.
        image = Augmentation3D.random_intensity_shift(image, clip_range=None)

        return image, label

## Dataset

In [None]:
class VesuviusDataset(Dataset):
    """PyTorch Dataset for Vesuvius Challenge.

    NOTE: sample loading is still a placeholder that generates random
    volumes; replace the marked section in ``__getitem__`` with real I/O.

    Args:
        data_files: Sequence of data file identifiers.
        config: Object providing ``input_shape``, ``num_classes`` and
            ``use_augmentation``.
        is_train: When True, larger volumes are produced and (optionally)
            augmented/cropped to ``config.input_shape``.
        samples_per_file: Number of samples each file is assumed to hold
            (used only to compute the dataset length).
    """

    def __init__(self, data_files, config, is_train=True, samples_per_file=6):
        self.data_files = data_files
        self.config = config
        self.is_train = is_train
        self.samples_per_file = samples_per_file

    def __len__(self):
        """Return ``len(data_files) * samples_per_file``."""
        return len(self.data_files) * self.samples_per_file

    def __getitem__(self, idx):
        """Return one ``(image, label)`` pair.

        Returns:
            image: float tensor of shape ``(1, *config.input_shape)``.
            label: int64 tensor of shape ``config.input_shape``.
        """
        # This is a placeholder - replace with actual data loading.
        # In a real implementation, load from your data files:
        # file_idx = idx // self.samples_per_file
        # sample_idx = idx % self.samples_per_file
        # image, label = self.load_sample(self.data_files[file_idx], sample_idx)
        target_shape = tuple(self.config.input_shape)

        # Generate dummy data (replace with actual loading)
        if self.is_train:
            # Larger volume so a random crop has room to move
            image = np.random.randn(128, 128, 128).astype(np.float32)
            label = np.random.randint(0, self.config.num_classes, (128, 128, 128))
        else:
            # Fixed size for validation
            image = np.random.randn(*target_shape).astype(np.float32)
            label = np.random.randint(0, self.config.num_classes, target_shape)

        # Standardise the image to zero mean / unit std (epsilon avoids division by zero)
        image = (image - image.mean()) / (image.std() + 1e-8)

        # Apply augmentation if training
        if self.is_train and self.config.use_augmentation:
            image, label = Augmentation3D.apply(image, label, target_shape)

        # Ensure correct shape: centre-crop or zero-pad whatever is still off-size
        if image.shape != target_shape or label.shape != target_shape:
            image = self._resize_volume(image, target_shape)
            label = self._resize_volume(label, target_shape)

        # Convert to torch tensors
        image = torch.from_numpy(image).unsqueeze(0)  # Add channel dimension
        label = torch.from_numpy(label).long()

        return image, label

    def _resize_volume(self, volume, target_shape):
        """Centre-crop and/or zero-pad ``volume`` so its shape equals ``target_shape``.

        Each axis is handled independently: axes that are too short are
        padded with zeros (split as evenly as possible, extra voxel at the
        end), axes that are too long are cropped around the centre.

        NOTE(review): labels are padded with 0 as well; confirm that class 0
        is an acceptable fill value for padded label voxels.

        Args:
            volume: N-d NumPy array.
            target_shape: Desired length for each axis of ``volume``.

        Returns:
            Array of shape ``target_shape``.
        """
        pad_widths = []
        crop_slices = []
        for cur, tar in zip(volume.shape, target_shape):
            diff = tar - cur
            if diff > 0:
                # Axis too short: pad, nothing to crop.
                before = diff // 2
                pad_widths.append((before, diff - before))
                crop_slices.append(slice(None))
            else:
                # Axis long enough: take the central `tar` voxels, no padding.
                start = (-diff) // 2
                pad_widths.append((0, 0))
                crop_slices.append(slice(start, start + tar))

        if any(before or after for before, after in pad_widths):
            volume = np.pad(volume, pad_widths, mode='constant', constant_values=0)

        # Cropped axes were not padded, so their slice offsets are still valid.
        return volume[tuple(crop_slices)]

## Model Architecture

In [None]:
class ConvBlock3D(nn.Module):
    """Two consecutive Conv3d -> BatchNorm3d -> ReLU stages.

    The first convolution maps ``in_channels`` to ``out_channels``; the
    second keeps ``out_channels``. Both share ``kernel_size`` and ``padding``.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        conv_kwargs = {'padding': padding}
        # Stage 1: change the channel count
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(out_channels)
        # Stage 2: refine at the same width
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(out_channels)
        # One in-place activation module is reused after both stages
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run both conv/norm/activation stages on ``x`` and return the result."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x

class UNet3D(nn.Module):
    """3D U-Net for volumetric segmentation.

    Four encoder levels (each followed by 2x max-pooling), a bottleneck, and
    four decoder levels that upsample, concatenate the matching encoder
    output along the channel axis and convolve.

    Args:
        in_channels: Number of channels of the input volume.
        num_classes: Number of output channels (per-class logits).
        features: Channel widths of the four encoder levels; the bottleneck
            uses twice the last width.
    """
    def __init__(self, in_channels=1, num_classes=3, features=(32, 64, 128, 256)):
        super().__init__()
        # Immutable default above avoids sharing one list object between instances.
        features = list(features)

        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)

        # Encoder
        self.encoder1 = ConvBlock3D(in_channels, features[0])
        self.encoder2 = ConvBlock3D(features[0], features[1])
        self.encoder3 = ConvBlock3D(features[1], features[2])
        self.encoder4 = ConvBlock3D(features[2], features[3])

        # Bottleneck
        self.bottleneck = ConvBlock3D(features[3], features[3] * 2)

        # Decoder: input width = upsampled channels + skip-connection channels
        self.decoder4 = ConvBlock3D(features[3] * 2 + features[3], features[3])
        self.decoder3 = ConvBlock3D(features[3] + features[2], features[2])
        self.decoder2 = ConvBlock3D(features[2] + features[1], features[1])
        self.decoder1 = ConvBlock3D(features[1] + features[0], features[0])

        # Final 1x1x1 convolution producing per-class logits
        self.final = nn.Conv3d(features[0], num_classes, kernel_size=1)

    def _upsample_like(self, x, skip):
        """Upsample ``x`` to the spatial size of ``skip``.

        Uses the fixed 2x upsampler when that lands exactly on ``skip``'s
        size; otherwise interpolates straight to ``skip``'s size. The second
        path is needed when a pooled axis had odd length, because pooling
        floors the size and a plain 2x upsample would come out one voxel short.
        """
        target = tuple(skip.shape[2:])
        if tuple(2 * s for s in x.shape[2:]) == target:
            return self.up(x)
        return F.interpolate(x, size=target, mode='trilinear', align_corners=False)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, D, H, W) to logits (N, num_classes, D, H, W)."""
        # Encoder
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool(e1))
        e3 = self.encoder3(self.pool(e2))
        e4 = self.encoder4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder: upsample, concatenate the skip connection, convolve
        d4 = self.decoder4(torch.cat([self._upsample_like(b, e4), e4], dim=1))
        d3 = self.decoder3(torch.cat([self._upsample_like(d4, e3), e3], dim=1))
        d2 = self.decoder2(torch.cat([self._upsample_like(d3, e2), e2], dim=1))
        d1 = self.decoder1(torch.cat([self._upsample_like(d2, e1), e1], dim=1))

        return self.final(d1)

## Loss Functions

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss for 3D multi-class segmentation.

    Args:
        num_classes: Number of classes (channels in the logits).
        smooth: Constant added to numerator and denominator so empty classes
            do not divide by zero.
        ignore_index: Label value to ignore, or ``None``. Voxels carrying
            this label are removed from every class's statistics, and the
            class with this index (if in range) is left out of the average.
            The value may lie outside ``[0, num_classes)``.
    """
    def __init__(self, num_classes=3, smooth=1e-5, ignore_index=None):
        super().__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        """Return ``1 - mean Dice``, averaged over batch items and kept classes.

        Args:
            pred: Logits of shape (N, num_classes, D, H, W).
            target: Integer labels of shape (N, D, H, W).

        Raises:
            ValueError: If ``ignore_index`` leaves no class to evaluate.
        """
        probs = torch.softmax(pred, dim=1)

        if self.ignore_index is None:
            valid = None
            safe_target = target
        else:
            valid = target != self.ignore_index
            # Swap ignored labels for a legal class id so one_hot accepts them;
            # the mask below removes their contribution again.
            safe_target = target.masked_fill(~valid, 0)

        # Create one-hot encoding: (N, D, H, W, C) -> (N, C, D, H, W)
        target_one_hot = F.one_hot(safe_target, self.num_classes)
        target_one_hot = target_one_hot.permute(0, 4, 1, 2, 3).float()

        if valid is not None:
            # Zero out ignored voxels on both sides so predictions made there
            # are neither rewarded nor punished.
            mask = valid.unsqueeze(1).float()
            probs = probs * mask
            target_one_hot = target_one_hot * mask

        kept_classes = [c for c in range(self.num_classes) if c != self.ignore_index]
        if not kept_classes:
            raise ValueError("DiceLoss: ignore_index excludes every class")

        # Per-sample, per-class statistics over the three spatial axes
        spatial_dims = (2, 3, 4)
        intersection = (probs * target_one_hot).sum(dim=spatial_dims)
        cardinality = probs.sum(dim=spatial_dims) + target_one_hot.sum(dim=spatial_dims)
        dice = (2. * intersection + self.smooth) / (cardinality + self.smooth)

        return 1 - dice[:, kept_classes].mean()

class CombinedLoss(nn.Module):
    """Weighted sum of a Dice term and a cross-entropy term.

    ``ignore_index`` is forwarded to both the Dice loss and
    ``nn.CrossEntropyLoss``; ``dice_weight`` and ``ce_weight`` scale the two
    terms before they are added.
    """
    def __init__(self, num_classes=3, dice_weight=0.5, ce_weight=0.5, ignore_index=2):
        super().__init__()
        # Sub-losses (registered as child modules)
        self.dice_loss = DiceLoss(num_classes, ignore_index=ignore_index)
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # Mixing coefficients
        self.dice_weight, self.ce_weight = dice_weight, ce_weight

    def forward(self, pred, target):
        """Return ``dice_weight * dice(pred, target) + ce_weight * ce(pred, target)``."""
        weighted_dice = self.dice_weight * self.dice_loss(pred, target)
        weighted_ce = self.ce_weight * self.ce_loss(pred, target)
        return weighted_dice + weighted_ce

## Metrics

In [None]:
class DiceMetric:
    """Accumulate hard-label Dice scores over several batches.

    Args:
        num_classes: Number of classes.
        ignore_index: Label value to ignore, or ``None``. Voxels whose target
            equals this value are dropped before counting, and the class with
            this index is left out of the reported average.
    """
    def __init__(self, num_classes=3, ignore_index=2):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.reset()

    def reset(self):
        """Clear all accumulated counts."""
        self.intersection = np.zeros(self.num_classes)
        # Despite the name this holds |pred| + |target| (the Dice denominator),
        # not the set union.
        self.union = np.zeros(self.num_classes)

    def update(self, pred, target):
        """Add one batch to the running counts.

        Args:
            pred: Per-class scores with the class axis at dim 1; reduced with argmax.
            target: Integer labels with the same shape as ``pred`` minus the class axis.
        """
        pred = torch.argmax(pred, dim=1).cpu().numpy()
        target = target.cpu().numpy()

        if self.ignore_index is not None:
            # Drop ignored voxels first so predictions made there are not
            # counted as false positives of the remaining classes.
            valid = target != self.ignore_index
            pred = pred[valid]
            target = target[valid]

        for i in range(self.num_classes):
            if i == self.ignore_index:
                continue

            pred_i = pred == i
            target_i = target == i

            self.intersection[i] += np.count_nonzero(pred_i & target_i)
            self.union[i] += np.count_nonzero(pred_i) + np.count_nonzero(target_i)

    def compute(self):
        """Return the mean Dice over the non-ignored classes.

        A class that was never predicted and never present in the targets
        scores 1.0.
        """
        dice_scores = []
        for i in range(self.num_classes):
            if i == self.ignore_index:
                continue
            if self.union[i] == 0:
                dice_scores.append(1.0)
            else:
                dice_scores.append((2 * self.intersection[i]) / self.union[i])
        return np.mean(dice_scores)

## Sliding Window Inference

In [None]:
class SlidingWindowInference:
    """Sliding window inference for large volumes"""
    
    def __init__(self, roi_size, sw_batch_size=4, overlap=0.5, mode='gaussian'):
        self.roi_size = roi_size
        self.sw_batch_size = sw_batch_size
        self.overlap = overlap
        self.mode = mode
        
    def __call__(self, model, image):
        """Run sliding window inference"""
        device = next(model.parameters()).device
        
        # Get image dimensions
        batch_size, channels, d, h, w = image.shape
        
        # Calculate stride
        stride = [int(r * (1 - self.overlap)) for r in self.roi_size]
        
        # Calculate number of windows
        num_windows_d = max(1, int(np.ceil((d - self.roi_size[0]) / stride[0] + 1)))
        num_windows_h = max(1, int(np.ceil((h - self.roi_size[1]) / stride[1] + 1)))
        num_windows_w = max(1, int(np.ceil((w - self.roi_size[2]) / stride[2] + 1)))
        
        # Initialize output
        num_classes = config.num_classes
        output = torch.zeros(batch_size, num_classes, d, h, w).to(device)
        count_map = torch.zeros(batch_size, 1, d, h, w).to(device)
        
        # Generate gaussian importance map if needed
        if self.mode == 'gaussian':
            importance_map = self._get_gaussian_map(self.roi_size).to(device)
        else:
            importance_map = torch.ones(1, 1, *self.roi_size).to(device)
        
        # Sliding window
        for d_idx in range(num_windows_d):
            for h_idx in range(num_windows_h):
                for w_idx in range(num_windows_w):
                    # Calculate window position
                    d_start = min(d_idx * stride[0], d - self.roi_size[0])
                    h_start = min(h_idx * stride[1], h - self.roi_size[1])
                    w_start = min(w_idx * stride[2], w - self.roi_size[2])
                    
                    d_end = d_start + self.roi_size[0]
                    h_end = h_start + self.roi_size[1]
                    w_end = w_start + self.roi_size[2]
                    
                    # Extract window
                    window = image[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
                    
                    # Run inference
                    with torch.no_grad():
                        window_output = model(window)
                        window_output = torch.softmax(window_output, dim=1)
                    
                    # Add to output with importance weighting
                    output[:, :, d_start:d_end, h_start:h_end, w_start:w_end] += \
                        window_output * importance_map
                    count_map[:, :, d_start:d_end, h_start:h_end, w_start:w_end] += \
                        importance_map
        
        # Normalize by count map
        output = output / count_map.clamp(min=1e-5)
        
        return output
    
    def _get_gaussian_map(self, size):
        """Generate gaussian importance map"""
        gaussians = []
        for s in size:
            gaussian_1d = torch.exp(-4 * torch.linspace(-1, 1, s) ** 2)
            gaussians.append(gaussian_1d)
        
        gaussian_3d = gaussians[0].unsqueeze(-1).unsqueeze(-1) * \
                      gaussians[1].unsqueeze(0).unsqueeze(-1) * \
                      gaussians[2].unsqueeze(0).unsqueeze(0)
        
        return gaussian_3d.unsqueeze(0).unsqueeze(0)

## Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Train for one epoch.

    Args:
        model: Network to optimise; switched to train mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.
        scaler: Gradient scaler for mixed precision; ``None`` trains in
            full precision.

    Returns:
        Mean of the per-batch losses, or NaN if the loader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(loader, desc='Training')
    for images, labels in progress_bar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        if scaler is not None:
            # Mixed precision: forward pass and loss under autocast, then a
            # scaled backward pass and a scaler-managed optimizer step.
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Full-precision training
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Read the scalar once; each .item() call waits for the device.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix({'loss': batch_loss})

    # Count batches explicitly so an empty loader cannot divide by zero.
    return total_loss / num_batches if num_batches else float('nan')

def validate_epoch(model, loader, criterion, metric, device):
    """Validate for one epoch.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        metric: Object with ``reset()``, ``update(outputs, labels)`` and
            ``compute()``; it is reset before the pass.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, dice_score)``: mean per-batch loss (NaN if the loader
        yielded no batches) and ``metric.compute()``.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    metric.reset()

    with torch.no_grad():
        progress_bar = tqdm(loader, desc='Validation')
        for images, labels in progress_bar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Read the scalar once; each .item() call waits for the device.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            metric.update(outputs, labels)

            progress_bar.set_postfix({'loss': batch_loss})

    # Count batches explicitly so an empty loader cannot divide by zero.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    dice_score = metric.compute()

    return avg_loss, dice_score

## Main Training Loop

In [None]:
def _seed_worker(worker_id):
    """Give each DataLoader worker process its own NumPy random stream."""
    # Inside a worker torch.initial_seed() is already unique per worker and
    # per epoch; fold it into the 32-bit range accepted by np.random.seed.
    # Without this, workers inherit one NumPy state and repeat augmentations.
    np.random.seed(torch.initial_seed() % 2**32)


def train():
    """Main training function.

    Builds the datasets, loaders, model, loss, optimizer, scheduler and
    metric from the module-level ``config`` and ``device``, trains for
    ``config.epochs`` epochs and saves a checkpoint to
    ``config.checkpoint_path`` whenever the validation Dice improves.

    Returns:
        ``(model, history)`` where ``history`` is a dict with per-epoch
        lists under ``'train_loss'``, ``'val_loss'`` and ``'val_dice'``.
    """
    on_cuda = device.type == 'cuda'

    # Create dummy data files (replace with actual data)
    train_files = [f"train_{i}" for i in range(10)]  # Replace with actual file paths
    val_files = [f"val_{i}" for i in range(2)]  # Replace with actual file paths

    # Create datasets
    train_dataset = VesuviusDataset(train_files, config, is_train=True)
    val_dataset = VesuviusDataset(val_files, config, is_train=False)

    # Create data loaders. Pinned memory only helps host-to-GPU copies, so it
    # is requested only when training on CUDA.
    loader_kwargs = dict(
        batch_size=config.batch_size,
        num_workers=4,
        pin_memory=on_cuda,
        worker_init_fn=_seed_worker,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Create model
    model = UNet3D(in_channels=1, num_classes=config.num_classes)
    model = model.to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params/1e6:.2f}M")

    # Create loss and optimizer
    criterion = CombinedLoss(num_classes=config.num_classes)
    optimizer = optim.AdamW(model.parameters(),
                           lr=config.learning_rate,
                           weight_decay=config.weight_decay)

    # Learning rate scheduler: one cosine cycle over the whole run
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)

    # Mixed precision scaler. The CUDA autocast/GradScaler pair is pointless
    # on CPU, so fall back to full precision there.
    scaler = GradScaler() if (config.use_amp and on_cuda) else None

    # Metric
    metric = DiceMetric(num_classes=config.num_classes)

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_dice': []
    }

    best_dice = 0

    # Training loop
    for epoch in range(config.epochs):
        print(f"\nEpoch {epoch+1}/{config.epochs}")
        print("-" * 50)

        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device, scaler)

        # Validate
        val_loss, val_dice = validate_epoch(model, val_loader, criterion, metric, device)

        # Update scheduler (once per epoch)
        scheduler.step()

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_dice'].append(val_dice)

        # Print metrics (the LR shown is the one set for the next epoch)
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Val Dice: {val_dice:.4f}")
        print(f"LR: {scheduler.get_last_lr()[0]:.6f}")

        # Save best model
        if val_dice > best_dice:
            best_dice = val_dice
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_dice': best_dice,
            }, config.checkpoint_path)
            print(f"Saved best model with Dice: {best_dice:.4f}")

    return model, history

In [None]:
# Run training: executes the full train() loop and keeps the returned model
# and per-epoch history dict at module level for the cells below.
model, history = train()

## Visualization

In [None]:
# Draw the recorded training curves: losses on the left, validation Dice on the right.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))

# Left panel: train vs. validation loss per epoch
axes[0].set_title('Training and Validation Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].grid(True)
axes[0].plot(history['train_loss'], label='Train Loss')
axes[0].plot(history['val_loss'], label='Val Loss')
axes[0].legend()  # after plotting, so both labelled curves are picked up

# Right panel: validation Dice per epoch
axes[1].set_title('Validation Dice Score')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Dice Score')
axes[1].grid(True)
axes[1].plot(history['val_dice'], label='Val Dice', color='green')
axes[1].legend()

plt.tight_layout()
plt.show()

## Inference Example

In [None]:
def inference_with_sliding_window(model, image_volume):
    """Run inference with sliding window on a large volume.

    The model is switched to eval mode and left in that mode.

    Args:
        model: Segmentation network.
        image_volume: NumPy array (or array-like) or torch tensor of rank 3
            (D, H, W), rank 4 (C, D, H, W) or rank 5 (N, C, D, H, W).
            Any numeric dtype is accepted; it is converted to float32.

    Returns:
        NumPy array of predicted class indices with shape (N, D, H, W).

    Raises:
        ValueError: If ``image_volume`` has a rank other than 3, 4 or 5.
    """
    model.eval()

    # Create sliding window inference
    swi = SlidingWindowInference(
        roi_size=config.input_shape,
        sw_batch_size=config.sw_batch_size,
        overlap=config.overlap
    )

    # Prepare input. ascontiguousarray also copes with flipped/strided views
    # that torch.from_numpy cannot wrap directly.
    if not torch.is_tensor(image_volume):
        image_volume = torch.from_numpy(np.ascontiguousarray(image_volume))
    # Cast tensors too (not only arrays) so float64/integer input matches the model weights.
    image_volume = image_volume.float()

    # Add batch and channel dimensions if needed
    if image_volume.ndim == 3:
        image_volume = image_volume.unsqueeze(0).unsqueeze(0)
    elif image_volume.ndim == 4:
        image_volume = image_volume.unsqueeze(0)
    elif image_volume.ndim != 5:
        raise ValueError(
            f"image_volume must have 3, 4 or 5 dimensions, got {image_volume.ndim}"
        )

    # Windows are passed straight to the model, so the input must live on the
    # model's own device.
    image_volume = image_volume.to(next(model.parameters()).device)

    # Run inference
    with torch.no_grad():
        output = swi(model, image_volume)

    # Most probable class per voxel
    predictions = torch.argmax(output, dim=1)

    return predictions.cpu().numpy()

In [None]:
# Example inference on a test volume.
# The volume here is random noise of shape (128, 128, 128); the trained
# `model` from the training cell is reused.
test_volume = np.random.randn(128, 128, 128)  # Replace with actual test data
predictions = inference_with_sliding_window(model, test_volume)
# Report the prediction array's shape and which class indices occur in it.
print(f"Predictions shape: {predictions.shape}")
print(f"Unique classes: {np.unique(predictions)}")

In [None]:
def visualize_slice(volume, prediction, slice_idx=None):
    """Show one slice of the input next to the matching prediction slice.

    Args:
        volume: Input volume, indexed along its first axis.
        prediction: Prediction array with a leading batch axis; element 0 is shown.
        slice_idx: Index of the slice to draw; defaults to the middle of
            ``volume``'s first axis.
    """
    slice_idx = volume.shape[0] // 2 if slice_idx is None else slice_idx

    fig, (input_ax, pred_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Left: raw intensities in grayscale
    input_ax.imshow(volume[slice_idx], cmap='gray')
    input_ax.set_title(f'Input (Slice {slice_idx})')

    # Right: class indices on a fixed 0..2 colour scale
    pred_ax.imshow(prediction[0, slice_idx], cmap='jet', vmin=0, vmax=2)
    pred_ax.set_title(f'Prediction (Slice {slice_idx})')

    for panel in (input_ax, pred_ax):
        panel.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize result: middle slice of the test volume beside its prediction
# (slice_idx is left at its default).
visualize_slice(test_volume, predictions)

## Save Final Model

In [None]:
# Save final model: weights only (state_dict) as they are after the last
# epoch. This file is separate from the best-Dice checkpoint written during
# training.
torch.save(model.state_dict(), 'final_model_pytorch.pth')
print("Model saved to final_model_pytorch.pth")