In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from tqdm import tqdm
import timm
import platform

In [None]:
# Detect Apple-Silicon (MPS) support. Torch builds that do not expose
# torch.backends.mps are treated as "MPS not available".
_mps_backend = getattr(getattr(torch, 'backends', None), 'mps', None)
mps_available = _mps_backend.is_available() if _mps_backend is not None else False

# Seed every random number generator used by this notebook with the same
# value so that repeated runs produce the same results.
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(42)
if torch.cuda.is_available():
    # Also seed all CUDA devices when a CUDA runtime is present.
    torch.cuda.manual_seed_all(42)

In [None]:
# Notebook configuration.
# CLASSES lists the class sub-folder names; OUTPUT_DIR receives every artefact
# written by this notebook and feature_extract_file_name prefixes their names.
CLASSES = ["normal", "oscc"]
SEGMENTED_ROOT = "./dataset/original"
OUTPUT_DIR = "2_Feature_Extraction/BEiT"
feature_extract_file_name = "feature_extract_BEiT"

# Create the output folder (including missing parents); a pre-existing folder is fine.
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

# Self-Attention Module

In [None]:
# Spatial self-attention for convolutional feature maps
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    Queries and keys come from 1x1 convolutions that reduce the channel count
    to ``in_channels // 8``; values keep all ``in_channels`` channels. The
    attended map is multiplied by the learnable scalar ``gamma`` (initialised
    to zero) and added back to the input, so a freshly built module returns
    its input unchanged.

    Args:
        in_channels: number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced_channels = in_channels // 8
        self.query = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``.

        Args:
            x: tensor unpacked as (batch, channels, height, width).
        """
        n_batch, n_channels, height, width = x.size()
        n_positions = height * width

        # Flatten the spatial grid so every position becomes one sequence item.
        queries = self.query(x).view(n_batch, -1, n_positions).transpose(1, 2)
        keys = self.key(x).view(n_batch, -1, n_positions)
        values = self.value(x).view(n_batch, -1, n_positions)

        # (positions x positions) similarity, normalised over the key axis.
        attention = F.softmax(torch.bmm(queries, keys), dim=-1)

        # Weighted sum of values, folded back into the spatial layout.
        attended = torch.bmm(values, attention.transpose(1, 2))
        attended = attended.view(n_batch, n_channels, height, width)

        # Learnable blend with the residual input.
        return self.gamma * attended + x

In [None]:
class SelfAttentionForSwin(nn.Module):
    """Single-head scaled dot-product self-attention with a residual connection.

    Works on token sequences whose last dimension is ``dim``. Inputs of other
    ranks are first brought into ``[batch, seq_len, dim]`` form:

    * 2-D ``[batch, dim]``        -> a sequence of length one,
    * 3-D ``[batch, seq_len, dim]`` -> used as is,
    * 4-D ``[batch, h, w, dim]``  -> flattened to ``h * w`` tokens.

    Args:
        dim: size of the channel (last) dimension of the input.
    """

    def __init__(self, dim):
        super(SelfAttentionForSwin, self).__init__()
        self.dim = dim
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention and add the (reshaped) input back.

        Args:
            x: 2-D, 3-D or 4-D tensor with channels in the last dimension.

        Returns:
            A 3-D tensor ``[batch, seq_len, dim]``; note that 2-D and 4-D
            inputs therefore come back as 3-D.

        Raises:
            ValueError: if ``x`` is not 2-, 3- or 4-dimensional.
        """
        rank = x.dim()
        if rank == 2:
            # Already pooled: add a dummy sequence dimension -> [batch, 1, dim]
            x = x.unsqueeze(1)
        elif rank == 4:
            batch_size, h, w, c = x.shape
            # reshape (not view) so non-contiguous inputs, e.g. the result of a
            # permute/transpose, are flattened instead of raising a RuntimeError.
            x = x.reshape(batch_size, h * w, c)
        elif rank != 3:
            raise ValueError(f"Unexpected tensor shape: {x.shape}")

        Q = self.query(x)  # [batch_size, seq_len, dim]
        K = self.key(x)    # [batch_size, seq_len, dim]
        V = self.value(x)  # [batch_size, seq_len, dim]

        # Scaled dot-product scores, normalised over the key axis
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.dim ** 0.5)
        attention_weights = self.softmax(attention_scores)

        # Weighted sum of values plus residual connection
        attended = torch.matmul(attention_weights, V)

        return attended + x

# Modified ConvNeXt Model with Attention

In [None]:
# ConvNeXt-Large backbone followed by spatial self-attention and a projection head
class ModifiedConvNeXt(nn.Module):
    """Feature extractor built on timm's pretrained ``convnext_large``.

    The backbone's classification head is replaced by ``nn.Identity``. The
    output of the fourth stage goes through a ``SelfAttention`` block, is
    averaged over its two trailing (spatial) axes and is finally projected to
    ``output_dim`` by a Linear -> ReLU -> BatchNorm1d head.

    Args:
        output_dim: length of the feature vector produced per image.
    """

    def __init__(self, output_dim=2048):
        super().__init__()
        # Pretrained backbone without its classifier
        backbone = timm.create_model('convnext_large', pretrained=True)
        backbone.head = nn.Identity()
        self.base_model = backbone

        # Channel count fed to the attention block and the head (hard-coded)
        feature_dim = 1536

        self.attention1 = SelfAttention(feature_dim)

        self.feature_extractor = nn.Sequential(
            nn.Linear(feature_dim, output_dim),
            nn.ReLU(),
            nn.BatchNorm1d(output_dim)
        )

    def forward(self, x):
        """Return the projected feature vectors for a batch of images."""
        backbone = self.base_model

        # Stem, then the four backbone stages in order
        x = backbone.stem(x)
        for stage_idx in range(4):
            x = backbone.stages[stage_idx](x)

        # Attention on the final feature map
        x = self.attention1(x)

        # Backbone's pre-pooling norm, then global average pooling
        x = backbone.norm_pre(x)
        pooled = x.mean([-2, -1])

        return self.feature_extractor(pooled)

# Modified SwinTransformer Model

In [None]:
class ModifiedSwinTransformer(nn.Module):
    """Feature extractor built on timm's pretrained ``swin_large_patch4_window7_224``.

    The backbone's classification head is replaced by ``nn.Identity``. The
    result of ``forward_features`` is passed through ``SelfAttentionForSwin``,
    reduced to one vector per image and projected to ``output_dim`` by a
    Linear -> ReLU -> BatchNorm1d head.

    Args:
        output_dim: length of the feature vector produced per image.
    """

    def __init__(self, output_dim=2048):
        super().__init__()
        # Pretrained backbone without its classifier
        backbone = timm.create_model('swin_large_patch4_window7_224', pretrained=True)
        backbone.head = nn.Identity()
        self.base_model = backbone

        # Channel count fed to the attention block and the head (hard-coded)
        feature_dim = 1536

        self.attention1 = SelfAttentionForSwin(feature_dim)

        self.feature_extractor = nn.Sequential(
            nn.Linear(feature_dim, output_dim),
            nn.ReLU(),
            nn.BatchNorm1d(output_dim)
        )

    def forward(self, x):
        """Return the projected feature vectors for a batch of images."""
        # Backbone features before classification, refined by self-attention
        tokens = self.attention1(self.base_model.forward_features(x))

        rank = tokens.dim()
        if rank == 3:
            # [B, seq_len, C] -> [B, C] by averaging over the sequence
            pooled = tokens.mean(dim=1)
        elif rank == 2:
            # Already [B, C]
            pooled = tokens
        else:
            # Any other rank: flatten everything after the batch axis
            pooled = tokens.view(tokens.size(0), -1)

        return self.feature_extractor(pooled)

# Modified BEiT Model

In [None]:
class CustomBEiT(nn.Module):
    """Feature extractor built on timm's pretrained ``beit_large_patch16_224``.

    The backbone's classification head is replaced by ``nn.Identity``. The
    first token of the backbone output is passed through a multi-head
    attention layer (with residual + LayerNorm) and a three-layer MLP that
    ends in ``output_dim`` features. The attention layer and the MLP are
    created here with PyTorch's default initialisation; only the backbone is
    loaded from pretrained weights.

    Args:
        output_dim: length of the feature vector produced per image.
    """

    def __init__(self, output_dim=1024):
        super().__init__()
        # Pretrained backbone without its classifier
        self.backbone = timm.create_model('beit_large_patch16_224', pretrained=True)
        self.backbone.head = nn.Identity()

        # Width of the backbone tokens (hard-coded)
        embed_dim = 1024

        # Attention block used to refine the pooled token
        self.multihead_attention = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)

        # Projection MLP: embed_dim -> 1536 -> 1280 -> output_dim
        self.feature_extractor = nn.Sequential(
            nn.Linear(embed_dim, 1536),
            nn.GELU(),
            nn.LayerNorm(1536),
            nn.Dropout(0.3),

            nn.Linear(1536, 1280),
            nn.GELU(),
            nn.LayerNorm(1280),
            nn.Dropout(0.2),

            nn.Linear(1280, output_dim),
            nn.GELU(),
            nn.LayerNorm(output_dim)
        )

    def forward(self, x):
        """Return the projected feature vectors for a batch of images."""
        tokens = self.backbone.forward_features(x)

        if tokens.dim() == 3:
            # [B, N, D] token sequence: keep the first token
            cls_token = tokens[:, 0, :]
        else:
            # Otherwise average over axes 2 and 3
            cls_token = tokens.mean(dim=[2, 3])

        # Query, key and value are each a length-1 sequence built from the
        # same token, so every sample attends over exactly one position.
        query = cls_token.unsqueeze(1)
        key = cls_token.unsqueeze(1)
        value = cls_token.unsqueeze(1)
        attended, _ = self.multihead_attention(query, key, value)

        # Residual connection followed by layer normalisation
        enhanced = self.norm1(attended.squeeze(1) + cls_token)

        return self.feature_extractor(enhanced)

In [None]:
# Two-stage dimensionality reduction head
class FactorizedConv(nn.Module):
    """Reduce feature vectors from ``in_dim`` to ``out_dim`` in two steps.

    Despite the name, the block is made of fully connected layers:
    ``in_dim -> in_dim // 2 -> out_dim``, each step being
    Linear -> ReLU -> BatchNorm1d.

    Args:
        in_dim: size of the incoming feature vectors.
        out_dim: size of the reduced feature vectors.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        hidden_dim = in_dim // 2

        first_stage = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
        ]
        second_stage = [
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
            nn.BatchNorm1d(out_dim),
        ]
        self.layers = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Return the reduced representation of ``x``."""
        reduced = self.layers(x)
        return reduced

In [None]:
# Define Dataset class
class OralCancerDataset(Dataset):
    """Image-folder dataset: one sub-directory of ``root_dir`` per class.

    Every file in ``root_dir/<class_name>`` whose extension is ``.png``,
    ``.jpg`` or ``.jpeg`` (matched case-insensitively) becomes one sample.
    The label of a sample is the position of its class in ``classes``.
    Files are collected in sorted order so the sample order is the same on
    every run and every platform.

    Args:
        root_dir: directory that contains one sub-directory per class.
        classes: class (sub-directory) names; list index is the label.
        transform: optional callable applied to each loaded PIL image.
    """

    # Accepted image file extensions, compared against the lower-cased name
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, classes, transform=None):
        self.root_dir = root_dir
        self.classes = classes
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Collect image paths and labels
        for class_idx, class_name in enumerate(classes):
            class_dir = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_dir):
                # Keep the original best-effort behaviour (skip), but say so:
                # a silently missing class would otherwise go unnoticed.
                print(f"Warning: class directory not found, skipping: {class_dir}")
                continue

            # os.listdir returns entries in arbitrary order; sort for reproducibility
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_dir, img_name))
                    self.labels.append(class_idx)

    def __len__(self):
        """Return the number of collected images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label, img_path)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was given.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed right away instead of lingering until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label, img_path

In [None]:
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel with the mean/std below.
# NOTE(review): these look like the usual ImageNet statistics; confirm they
# match the normalisation the selected pretrained checkpoint expects.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Build the dataset and an unshuffled loader over it
dataset = OralCancerDataset(SEGMENTED_ROOT, CLASSES, transform=transform)
# macOS loads in the main process; every other platform uses 4 worker processes
num_workers = 4 if platform.system() != 'Darwin' else 0
dataloader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=num_workers)

# Report how many images were found in total and per class
print(f"Total number of images: {len(dataset)}")
for class_idx, class_name in enumerate(CLASSES):
    print(f"Class {class_name}: {dataset.labels.count(class_idx)} images")

# Choose the compute device: MPS first, then CUDA, otherwise CPU
if mps_available:
    device_name, device_message = "mps", "Using MPS (Apple Silicon GPU)"
elif torch.cuda.is_available():
    device_name, device_message = "cuda", "Using CUDA GPU"
else:
    device_name, device_message = "cpu", "Using CPU"

device = torch.device(device_name)
print(device_message)

print(f"Device: {device}")

In [None]:
# Build the feature extractor, move it to the selected device and switch it
# to inference mode (disables dropout).
model = CustomBEiT(output_dim=1024)
model = model.to(device)
model.eval()

# The banner is derived from the model's class so it always names the model
# that is actually in use (it previously said "MODIFIED CONVNEXT" for BEiT).
arch_title = f"{type(model).__name__} MODEL ARCHITECTURE"
banner_rule = "=" * 50

# Print model architecture
print(banner_rule)
print(arch_title)
print(banner_rule)
print(model)

# Save model architecture to text file
model_arch_path = os.path.join(OUTPUT_DIR, f"{feature_extract_file_name}_architecture.txt")
with open(model_arch_path, 'w') as f:
    f.write(banner_rule + "\n")
    f.write(arch_title + "\n")
    f.write(banner_rule + "\n")
    f.write(str(model) + "\n\n")

    # One entry per named parameter, followed by the overall parameter count
    f.write(banner_rule + "\n")
    f.write("MODEL WEIGHTS SUMMARY\n")
    f.write(banner_rule + "\n")
    total_params = 0
    for name, param in model.named_parameters():
        f.write(f"Layer: {name}\n")
        f.write(f"Shape: {param.shape}\n")
        f.write(f"Parameters: {param.numel()}\n")
        f.write("-" * 30 + "\n")
        total_params += param.numel()

    f.write(f"Total parameters: {total_params:,}\n")
    f.write(f"Total parameters (M): {total_params / 1e6:.2f}M\n")

print(f"Model architecture and weights saved to {model_arch_path}")

In [None]:
# Per-batch accumulators; merged into single arrays after the loop
feature_batches = []
label_batches = []
all_filenames = []

# Run the model over the whole dataset with gradient tracking disabled
with torch.no_grad():
    for batch_images, batch_labels, batch_paths in tqdm(dataloader, desc="Extracting features"):
        batch_features = model(batch_images.to(device))

        # Move results to host memory and remember which file each row came from
        feature_batches.append(batch_features.cpu().numpy())
        label_batches.append(batch_labels.numpy())
        all_filenames += [os.path.basename(path) for path in batch_paths]

# Stack the batches: one row per image
all_features = np.vstack(feature_batches)
all_labels = np.concatenate(label_batches)

# Feature columns are named "1", "2", ... (plain numbers, no prefix)
feature_cols = [str(col_number) for col_number in range(1, all_features.shape[1] + 1)]
df = pd.DataFrame(all_features, columns=feature_cols)
df['label'] = all_labels
df['class'] = [CLASSES[label] for label in all_labels]
df['filename'] = all_filenames

# Write the feature table (features + label + class + filename) to CSV
csv_path = os.path.join(OUTPUT_DIR, f"{feature_extract_file_name}_features.csv")
df.to_csv(csv_path, index=False)
print(f"Features saved to {csv_path}")

In [None]:
# Persist the extractor's weights (state_dict only) alongside the other outputs
feature_extractor_path = os.path.join(OUTPUT_DIR, feature_extract_file_name + ".pt")
model_weights = model.state_dict()
torch.save(model_weights, feature_extractor_path)
print(f"Feature extractor model saved to {feature_extractor_path}")

In [None]:
# Global matplotlib defaults for every figure below: Times New Roman with
# enlarged text for titles, axis labels, tick labels and legends.
plt.rcParams.update({
    "font.family": "Times New Roman",
    "font.size": 18,         # base font size
    "axes.titlesize": 26,    # plot titles
    "axes.labelsize": 22,    # axis labels
    "xtick.labelsize": 18,   # x tick labels
    "ytick.labelsize": 18,   # y tick labels
    "legend.fontsize": 20,   # legend entries
})

In [None]:
# Helper that writes a figure to disk as both PNG and PDF
def save_plot(fig, base_path, dpi=1000):
    """Save ``fig`` as ``<base_path>.png`` and ``<base_path>.pdf``.

    Args:
        fig: object exposing ``savefig`` (a matplotlib figure).
        base_path: output path without file extension.
        dpi: resolution passed to both ``savefig`` calls (default 1000).
    """
    shared_options = {"dpi": dpi, "bbox_inches": 'tight'}

    png_path = "{}.png".format(base_path)
    pdf_path = "{}.pdf".format(base_path)

    # Raster copy first, then the PDF copy (format given explicitly)
    fig.savefig(png_path, **shared_options)
    fig.savefig(pdf_path, format='pdf', **shared_options)

    print(f"Plot saved to {png_path} and {pdf_path}")

In [None]:
# Summarise the extracted features in a plain-text report
stats_path = os.path.join(OUTPUT_DIR, f"{feature_extract_file_name}_statistics.txt")

# describe() transposed: one row per feature column, one column per statistic
feature_stats = df[feature_cols].describe().T

# Sample counts: overall and per class
report_lines = [
    "=== FEATURE EXTRACTION STATISTICS ===\n\n",
    f"Total samples: {len(df)}\n",
]
for class_name in CLASSES:
    n_class_samples = int((df['class'] == class_name).sum())
    report_lines.append(f"Class {class_name}: {n_class_samples} samples\n")

# Statistics over all samples (mean of per-feature means/stds, global min/max)
report_lines += [
    "\n=== FEATURE STATISTICS ===\n\n",
    "Global feature statistics:\n",
    f"Mean feature value: {feature_stats['mean'].mean():.4f}\n",
    f"Std of features: {feature_stats['std'].mean():.4f}\n",
    f"Min feature value: {feature_stats['min'].min():.4f}\n",
    f"Max feature value: {feature_stats['max'].max():.4f}\n\n",
]

# The same four figures restricted to each class
for class_name in CLASSES:
    class_features = df.loc[df['class'] == class_name, feature_cols]
    report_lines += [
        f"Class {class_name} statistics:\n",
        f"  Mean feature value: {class_features.mean().mean():.4f}\n",
        f"  Std of features: {class_features.std().mean():.4f}\n",
        f"  Min feature value: {class_features.min().min():.4f}\n",
        f"  Max feature value: {class_features.max().max():.4f}\n\n",
    ]

with open(stats_path, 'w') as f:
    f.writelines(report_lines)

print(f"Statistics saved to {stats_path}")

# Box Plot of Feature Values

In [None]:
# One box plot per class, restricted to positive standardized values

print("Generating separate box plots for each class with only positive values...")

# Standardize the first 20 feature columns; the scaler is fitted on all samples
selected_features = feature_cols[:20]
scaler = StandardScaler()
standardized_values = scaler.fit_transform(df[selected_features])
df_selected = pd.DataFrame(standardized_values, columns=selected_features)
df_selected['class'] = df['class']

# Fixed colour per class so plots are comparable
class_colors = {
    "normal": "#1f77b4",  # Blue
    "oscc": "#ff7f0e",   # Orange
}

for class_name in CLASSES:
    class_data = df_selected.loc[df_selected['class'] == class_name]

    # Long format: one row per (sample, feature) pair, then drop values <= 0
    long_form = class_data.melt(id_vars=['class'], value_vars=selected_features)
    positive_only = long_form.loc[long_form['value'] > 0]

    # Nothing left to draw for this class
    if positive_only.empty:
        print(f"Warning: No positive values found for {class_name} class. Skipping box plot.")
        continue

    plt.figure(figsize=(22, 12))
    sns.set_style("whitegrid")

    # Class colour, falling back to gray for classes missing from class_colors
    sns.boxplot(
        data=positive_only, x='variable', y='value',
        showfliers=False, color=class_colors.get(class_name, "gray")
    )

    plt.title(f'Box Plot of Standardized Features - {class_name} Class (Positive Values Only)',
              fontsize=28, pad=25, fontname="Times New Roman")
    plt.xlabel('Features', fontsize=24, labelpad=20, fontname="Times New Roman")
    plt.ylabel('Standardized Value', fontsize=24, labelpad=20, fontname="Times New Roman")

    plt.xticks(fontsize=18, fontname="Times New Roman")
    plt.yticks(fontsize=18, fontname="Times New Roman")

    # Footnote stating how many values the positive-only filter discarded
    original_count = len(class_data) * len(selected_features)
    removed_count = original_count - len(positive_only)
    if removed_count > 0:
        removed_percentage = (removed_count / original_count) * 100
        plt.figtext(0.5, 0.01,
                   f"Note: {removed_count} non-positive values ({removed_percentage:.1f}%) were removed from the visualization",
                   fontsize=16, fontname="Times New Roman", ha='center')

    plt.tight_layout()

    # Write PNG + PDF, then release the figure
    boxplot_path = os.path.join(OUTPUT_DIR, f"{feature_extract_file_name}_boxplot_{class_name}")
    save_plot(plt.gcf(), boxplot_path)
    plt.close()

In [None]:
# Combined box plot: all classes side by side, one colour per class
print("Creating combined box plot with different colors for each class...")

# Long format over every class, keeping strictly positive values only
all_long_form = df_selected.melt(id_vars=['class'], value_vars=selected_features)
all_melted_data = all_long_form.loc[all_long_form['value'] > 0]

plt.figure(figsize=(24, 14))
sns.set_style("whitegrid")

# hue='class' splits every feature's box by class
boxplot_options = {
    "x": 'variable',
    "y": 'value',
    "hue": 'class',
    "palette": class_colors,
    "showfliers": False,
}
boxplot = sns.boxplot(data=all_melted_data, **boxplot_options)

plt.title('Box Plot of Standardized Features by Class (Positive Values Only)',
          fontsize=28, pad=25, fontname="Times New Roman")
plt.xlabel('Features', fontsize=24, labelpad=20, fontname="Times New Roman")
plt.ylabel('Standardized Value', fontsize=24, labelpad=20, fontname="Times New Roman")

plt.xticks(fontsize=18, fontname="Times New Roman")
plt.yticks(fontsize=18, fontname="Times New Roman")

# Legend anchored just outside the upper-right corner of the axes
plt.legend(title='Class', title_fontsize=22, fontsize=20, bbox_to_anchor=(1.02, 1), loc='upper left')

# Footnote stating how many values the positive-only filter discarded
original_count = len(df_selected) * len(selected_features)
remaining_count = len(all_melted_data)
removed_count = original_count - remaining_count
removed_percentage = (removed_count / original_count) * 100

if removed_count > 0:
    removal_note = f"Note: {removed_count} non-positive values ({removed_percentage:.1f}%) were removed from the visualization"
    plt.figtext(0.5, 0.01, removal_note,
                fontsize=16, fontname="Times New Roman", ha='center')

plt.tight_layout()

# Write PNG + PDF, then release the figure
combined_boxplot_path = os.path.join(OUTPUT_DIR, f"{feature_extract_file_name}_boxplot_all_classes_positive_only")
save_plot(plt.gcf(), combined_boxplot_path)
plt.close()

In [None]:
plt.rcParams["font.family"] = "Times New Roman"

print("Computing t-SNE projection...")
# t-SNE requires perplexity < number of samples. Keep the usual value of 30
# for normal-sized datasets, but shrink it for small ones so the projection
# does not fail with a ValueError.
n_samples = all_features.shape[0]
tsne_perplexity = min(30, max(1, n_samples - 1))
if tsne_perplexity != 30:
    print(f"Only {n_samples} samples available; using perplexity={tsne_perplexity} instead of 30.")

tsne = TSNE(n_components=2, random_state=42, perplexity=tsne_perplexity, max_iter=1000)
features_embedded = tsne.fit_transform(all_features)

# One scatter series per class, selected through the numeric labels
plt.figure(figsize=(14, 12), dpi=100)
for i, class_name in enumerate(CLASSES):
    mask = all_labels == i
    plt.scatter(features_embedded[mask, 0], features_embedded[mask, 1],
                label=class_name, alpha=0.7, s=80)

plt.title('t-SNE Visualization of Feature Vectors', fontsize=28, pad=25)
plt.xlabel('t-SNE Component 1', fontsize=24, labelpad=20)
plt.ylabel('t-SNE Component 2', fontsize=24, labelpad=20)
plt.grid(alpha=0.3)

# Thick black frame around the plot area
for spine in plt.gca().spines.values():
    spine.set_linewidth(2)
    spine.set_color('black')

# Legend anchored just outside the upper-right corner of the axes
plt.legend(title='Class', title_fontsize=20, fontsize=20, bbox_to_anchor=(1.02, 1), loc='upper left')
plt.tight_layout()

# Write PNG + PDF, then release the figure
tsne_path = os.path.join(OUTPUT_DIR, f"{feature_extract_file_name}_tsne")
save_plot(plt.gcf(), tsne_path)
plt.close()

In [None]:
# Correlation heatmap of the first 20 features over all samples
print("Generating correlation heatmap...")
selected_features = feature_cols[:20]
corr_matrix = df[selected_features].corr()

plt.rcParams["font.family"] = "Times New Roman"
plt.figure(figsize=(20, 16))

# Diverging colour map centred on zero correlation; cells are drawn square
heatmap_options = {"cmap": "coolwarm", "center": 0, "annot": False, "square": True}
sns.heatmap(corr_matrix, **heatmap_options)

plt.title('Feature Correlation Heatmap (First 20 Features)', fontsize=28, pad=25)
plt.xlabel('Features', fontsize=26, labelpad=20)
plt.ylabel('Features', fontsize=26, labelpad=20)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

plt.tight_layout()

# Write PNG + PDF, then release the figure
corr_path = os.path.join(OUTPUT_DIR, f"{feature_extract_file_name}_correlation")
save_plot(plt.gcf(), corr_path)
plt.close()

In [None]:

# Feature correlation heatmaps: one per class plus one over all classes.
# The plotting steps are identical for every heatmap, so they live in one
# helper instead of being copy-pasted.
def _save_correlation_heatmap(frame, title, output_suffix):
    """Plot the pairwise correlation of ``frame``'s columns and save it.

    Args:
        frame: DataFrame whose column-to-column correlation is plotted.
        title: figure title.
        output_suffix: appended to ``feature_extract_file_name`` to build the
            output file name inside ``OUTPUT_DIR`` (extension added by save_plot).
    """
    plt.figure(figsize=(20, 16))

    # Fixed colour range [-1, 1] so all heatmaps share the same scale
    sns.heatmap(frame.corr(), cmap="coolwarm", center=0, annot=False, square=True, vmin=-1, vmax=1)

    plt.title(title, fontsize=28, pad=25)
    plt.xlabel('Features', fontsize=26, labelpad=20)
    plt.ylabel('Features', fontsize=26, labelpad=20)

    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)

    plt.tight_layout()

    output_path = os.path.join(OUTPUT_DIR, f"{feature_extract_file_name}_{output_suffix}")
    save_plot(plt.gcf(), output_path)

    # Release the figure so repeated calls do not accumulate open figures
    plt.close()


print("Generating correlation heatmaps for each class...")
selected_features = feature_cols[:20]  # First 20 features

# Set font to Times New Roman
plt.rcParams["font.family"] = "Times New Roman"

# One heatmap per class, computed from that class's rows only
for class_name in CLASSES:
    print(f"Creating correlation heatmap for {class_name} class...")
    class_df = df[df['class'] == class_name]
    _save_correlation_heatmap(
        class_df[selected_features],
        f'Feature Correlation Heatmap - {class_name} Class (First 20 Features)',
        f"correlation_{class_name}",
    )

# Combined heatmap over every sample, for comparison with the per-class ones
print("Creating combined correlation heatmap (all classes)...")
_save_correlation_heatmap(
    df[selected_features],
    'Feature Correlation Heatmap - All Classes (First 20 Features)',
    "correlation_All",
)

In [None]:
# Final status message marking the end of the notebook.
print("Feature extraction and visualization complete!")