# Training a Swin transformer model to predict layer annotations

import h5py
import numpy as np
import tensorflow as tf
from transformers import TFSwinModel
import datetime
from sklearn.model_selection import train_test_split
print(tf.__version__)  # NOTE(review): this setup block is repeated verbatim in cell In [2] below; one copy can be dropped

In [2]:
import h5py
import numpy as np
import tensorflow as tf
from transformers import TFSwinModel
import datetime
from sklearn.model_selection import train_test_split
print(tf.__version__)  # record the TensorFlow version in the notebook output (2.18.1 in the captured run)

2025-07-16 16:42:17.196672: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752698537.209644   19734 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752698537.214017   19734 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-16 16:42:17.227286: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


2.18.1


In [3]:
# Load the preprocessed images and their layer annotations into memory.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
h5_path = '/home/suraj/Git/SCR-Progression/Duke_Control_processed.h5'
with h5py.File(h5_path, 'r') as h5_file:
    images = h5_file['images'][:]          # expected (N, 224, 224)
    layer_maps = h5_file['layer_maps'][:]  # expected (N, 224, 2) or (N, 224, 3)

In [4]:
# Add a trailing channel axis so grayscale images become channels-last
# (N, H, W, 1).
if images.ndim == 3:
    images = np.expand_dims(images, axis=-1)

# Keep only ILM and BM, i.e. the first and last boundary columns.
# Indexing with -1 (instead of a hard-coded 2) works whether layer_maps stores
# two boundaries (N, 224, 2) or three (N, 224, 3). A literal 2 raises
# IndexError in the two-column case. It also makes re-running this cell safe:
# after the first run the array already has 2 columns, and [0, -1] is a no-op.
layer_maps = layer_maps[:, :, [0, -1]]

In [5]:
# While testing, restrict the experiment to the first 1000 samples.
subset_size = 1000
images = images[:subset_size]
layer_maps = layer_maps[:subset_size]

In [6]:
# 80/20 random split; the fixed seed keeps the split reproducible across runs.
# NOTE(review): if neighbouring images come from the same scan/subject, a
# per-image random split can put near-duplicate slices in both sets —
# consider a subject-level split if that metadata exists.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    layer_maps,
    test_size=0.2,
    random_state=42,
)

In [None]:
# Shuffle individual samples *before* batching so batch composition changes
# every epoch. Shuffling after .batch() only reorders the same fixed batches.
# The buffer covers the whole (in-memory) training set for a full shuffle.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=len(X_train))
    .batch(16)
)
# Evaluation order does not matter, so the test set is only batched.
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(16)

In [None]:
# Swin Transformer approach for layer annotation prediction
print("Setting up Swin Transformer approach...")

def extract_swin_features(images, base_model, batch_size=8):
    """
    Run images through a Swin backbone and return its final token features.

    Args:
        images: Channels-last image array of shape (N, H, W, C). Single-channel
            input is repeated to 3 channels to match the RGB backbone.
        base_model: A Hugging Face ``TFSwinModel`` (or any callable that accepts
            ``pixel_values=`` / ``training=`` and returns an object whose
            ``last_hidden_state`` has a ``.numpy()`` method).
        batch_size: Number of images per forward pass.

    Returns:
        np.ndarray stacking ``last_hidden_state`` for all N images along axis 0.
        For an empty input, ``np.array([])``.

    Notes:
        Hugging Face TF vision models take ``pixel_values`` channels-first,
        (batch, channels, height, width), and transpose internally. Passing
        channels-last data makes the model read H (224) as the channel count
        and reject the input, so each batch is transposed to NCHW here.

        NOTE(review): pixel values are passed through without rescaling or
        mean/std normalisation. The checkpoint's image processor normalises its
        inputs. Confirm the intensity range of the HDF5 images and normalise
        accordingly if it differs.
    """
    images = np.asarray(images)
    batch_outputs = []

    for i in range(0, len(images), batch_size):
        # float32 so integer image dtypes (e.g. uint8) are accepted by the conv layers.
        batch = images[i:i + batch_size].astype(np.float32, copy=False)

        # Convert grayscale to RGB by repeating the single channel.
        if batch.shape[-1] == 1:
            batch = np.repeat(batch, 3, axis=-1)

        # NHWC -> NCHW, the layout TFSwinModel expects for pixel_values.
        batch = np.transpose(batch, (0, 3, 1, 2))

        outputs = base_model(pixel_values=batch, training=False)
        batch_outputs.append(outputs.last_hidden_state.numpy())

        if i % 100 == 0:
            print(f"Processed {i}/{len(images)} images")

    if not batch_outputs:
        return np.array([])
    # One concatenate instead of building a per-sample Python list.
    return np.concatenate(batch_outputs, axis=0)

# Pretrained Swin-Tiny backbone. It is only run forward to extract features;
# the regression head is trained separately on its outputs.
swin_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
swin_model = TFSwinModel.from_pretrained(swin_checkpoint)
print("Swin model loaded successfully!")

In [None]:
# Run the backbone once over both splits and cache the token features; the
# regression head below is trained on these instead of raw pixels.
print("Extracting features from training data...")
X_train_features = extract_swin_features(X_train, swin_model, batch_size=8)

print("Extracting features from test data...")
X_test_features = extract_swin_features(X_test, swin_model, batch_size=8)

for split_name, split_features in (("Training", X_train_features), ("Test", X_test_features)):
    print(f"{split_name} features shape: {split_features.shape}")
print("Feature extraction completed!")

In [None]:
# Create regression head for Swin features
def create_swin_regression_head(feature_shape, num_points=224, num_layers=2):
    """
    Build a dense regression head mapping Swin token features to layer curves.

    Args:
        feature_shape: Shape of one feature sample, (num_tokens, hidden_size),
            e.g. ``X_train_features.shape[1:]``.
        num_points: Positions predicted per layer (one per image column; 224 for
            the 224-pixel-wide inputs used in this notebook).
        num_layers: Boundaries predicted per position (2 = ILM and BM).

    Returns:
        An uncompiled ``tf.keras.Model`` mapping (batch, *feature_shape) to
        (batch, num_points, num_layers) with a linear output.
    """
    input_layer = tf.keras.layers.Input(shape=feature_shape)

    # Global average pooling collapses the token (sequence) dimension.
    # NOTE(review): this discards where each token sits in the image, so the
    # head must recover per-column positions from one pooled vector. If
    # accuracy is poor, Flatten() would preserve the spatial layout.
    x = tf.keras.layers.GlobalAveragePooling1D()(input_layer)

    # Dense layers for regression
    x = tf.keras.layers.Dense(1024, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    x = tf.keras.layers.Dense(256, activation='relu')(x)

    # Linear output: one coordinate per (position, layer), reshaped to match
    # the label layout (num_points, num_layers).
    x = tf.keras.layers.Dense(num_points * num_layers, activation='linear')(x)
    output = tf.keras.layers.Reshape((num_points, num_layers))(x)

    return tf.keras.Model(inputs=input_layer, outputs=output)

# Build the head for the extracted per-sample feature shape and compile it for
# coordinate regression: MSE as the loss, MAE reported as a metric.
feature_shape = X_train_features.shape[1:]
swin_regression_model = create_swin_regression_head(feature_shape)

adam = tf.keras.optimizers.Adam(learning_rate=1e-4)
swin_regression_model.compile(
    optimizer=adam,
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
)

print("Swin regression model created successfully!")
print(f"Input feature shape: {feature_shape}")
print("Output shape: (224, 2)")
swin_regression_model.summary()

In [None]:
# Create datasets for Swin features.
# Shuffle individual samples *before* batching so batch composition changes
# each epoch. `.batch(16).shuffle(100)` only reorders the same fixed batches.
swin_train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train_features, y_train))
    .shuffle(buffer_size=len(X_train_features))
    .batch(16)
)
swin_test_dataset = tf.data.Dataset.from_tensor_slices((X_test_features, y_test)).batch(16)

# TensorBoard callback for Swin model: one timestamped log directory per run.
swin_log_dir = "logs/swin_fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
swin_tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=swin_log_dir, histogram_freq=1)

print("Starting Swin model training...")
# Train Swin regression model.
# NOTE(review): the held-out test split doubles as validation data here; there
# is no separate validation set.
swin_history = swin_regression_model.fit(
    swin_train_dataset,
    validation_data=swin_test_dataset,
    epochs=15,  # Slightly more epochs since we have rich features
    callbacks=[swin_tensorboard_callback],
)

In [None]:
# Evaluate Swin model. evaluate() returns [loss, metric]: index 0 is the MSE
# loss and index 1 the MAE metric.
swin_test_loss = swin_regression_model.evaluate(swin_test_dataset)
print(f"Swin Test MSE: {swin_test_loss[0]:.6f}")
print(f"Swin Test MAE: {swin_test_loss[1]:.6f}")

# Save the Swin model *before* the comparison. Previously a NameError on the
# CNN results below aborted the cell and the trained model was never saved.
swin_regression_model.save('Swin_regression_model.h5')
print("Swin model saved as 'Swin_regression_model.h5'")

# Compare with CNN model performance. `test_loss` must come from a CNN
# evaluation cell that is not present in this notebook as shown, so skip the
# comparison instead of raising NameError when it has not been run.
cnn_test_loss = globals().get("test_loss")
if cnn_test_loss is None:
    print("\nCNN results (test_loss) not found; skipping model comparison.")
else:
    print("\n=== Model Comparison ===")
    print(f"CNN Test MSE: {cnn_test_loss[0]:.6f}")
    print(f"CNN Test MAE: {cnn_test_loss[1]:.6f}")
    print(f"Swin Test MSE: {swin_test_loss[0]:.6f}")
    print(f"Swin Test MAE: {swin_test_loss[1]:.6f}")

In [None]:
import matplotlib.pyplot as plt

def _draw_layer_panel(ax, image_2d, cols, true_layers, pred=None, model_name=None):
    """
    Draw one comparison panel: the image, optional predicted curves, and ground truth.

    Without ``pred``, the ground-truth ILM/BM curves are drawn bold under the
    title 'Ground Truth'. With ``pred``, the predictions are drawn dashed and the
    ground truth thin and semi-transparent for reference.
    """
    ax.imshow(image_2d, cmap='gray')
    if pred is None:
        truth_style = {'linewidth': 2}
        ax.set_title('Ground Truth')
    else:
        ax.plot(cols, pred[:, 0], 'r--', linewidth=2, label=f'{model_name} ILM')
        ax.plot(cols, pred[:, 1], 'm--', linewidth=2, label=f'{model_name} BM')
        truth_style = {'linewidth': 1, 'alpha': 0.7}
        ax.set_title(f'{model_name} Predictions')
    ax.plot(cols, true_layers[:, 0], 'g-', label='True ILM', **truth_style)
    ax.plot(cols, true_layers[:, 1], 'b-', label='True BM', **truth_style)
    ax.legend()
    ax.grid(True, alpha=0.3)


def compare_model_predictions(image_idx=0, cnn_model=None):
    """
    Compare predictions from CNN and Swin models on a test image.

    Reads the notebook-level X_test, X_test_features, y_test and
    swin_regression_model. It shows a figure with a ground-truth panel plus one
    panel per model, then prints per-layer mean absolute error. Column 0 is ILM
    and column 1 is BM.

    Args:
        image_idx: Index of the test sample to visualise.
        cnn_model: Optional CNN model to include. If None, the notebook-level
            ``model`` is used when it exists. Otherwise the CNN panel and error
            line are skipped instead of raising NameError.
    """
    if cnn_model is None:
        cnn_model = globals().get('model')

    # Get a test image and ground truth
    test_image = X_test[image_idx:image_idx + 1]
    test_features = X_test_features[image_idx:image_idx + 1]
    true_layers = y_test[image_idx]
    image_2d = test_image[0, :, :, 0]
    # One predicted point per image column; derive from the labels instead of
    # hard-coding 224.
    cols = np.arange(true_layers.shape[0])

    # Get predictions from the available models (CNN first, as before).
    predictions = []
    if cnn_model is not None:
        predictions.append(('CNN', cnn_model.predict(test_image)[0]))
    predictions.append(('Swin', swin_regression_model.predict(test_features)[0]))

    # Ground-truth panel followed by one panel per model.
    n_panels = 1 + len(predictions)
    _, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6))
    _draw_layer_panel(axes[0], image_2d, cols, true_layers)
    for ax, (name, pred) in zip(axes[1:], predictions):
        _draw_layer_panel(ax, image_2d, cols, true_layers, pred, name)

    plt.tight_layout()
    plt.show()

    # Per-layer mean absolute error; axis=0 averages over columns.
    print(f"\nSample {image_idx} Error Analysis:")
    for name, pred in predictions:
        ilm_mae, bm_mae = np.mean(np.abs(pred - true_layers), axis=0)
        print(f"{name:<4} - ILM MAE: {ilm_mae:.4f}, BM MAE: {bm_mae:.4f}")
    if cnn_model is None:
        print("(CNN model not available; showing Swin only)")

# Visual comparison on a few fixed test samples
print("Comparing model predictions...")
for sample_idx in (0, 5, 10):
    compare_model_predictions(sample_idx)