In [None]:
from google.colab import drive
# Mount Google Drive so every /content/drive/... path used below is reachable.
drive.mount(mountpoint='/content/drive')

In [None]:
!pip install pynrrd

In [None]:
!pip install torch_geometric

In [None]:
import torch
import torch_geometric
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from torch_geometric.data import Data, Dataset
from sklearn.preprocessing import StandardScaler
from torch_geometric.nn import GCNConv, global_mean_pool
import os

In [None]:
# Attach the metastasis-staging target to the clinical table as a 'label' column.
# NOTE: this rewrites phenotypic_information.xlsx in place; later cells read 'label' back from it.
PHENOTYPE_PATH = "/content/drive/MyDrive/GNN_/phenotypic_information.xlsx"
LABEL_PATH = "/content/drive/MyDrive/GNN_/Label.xlsx"
LABEL_COLUMN = 'Staging(Metastasis)#(Mx -replaced by -1)[M]'

clinical_features = pd.read_excel(PHENOTYPE_PATH)
label = pd.read_excel(LABEL_PATH)

# Labels are matched by row position (both frames get a default RangeIndex), so both sheets must
# list patients in the same order -- TODO confirm, or merge on a shared ID column if Label.xlsx has
# one. At minimum the row counts must agree: otherwise pandas silently fills NaN or drops rows.
if len(label) != len(clinical_features):
    raise ValueError(
        f"Label.xlsx has {len(label)} rows but phenotypic_information.xlsx has "
        f"{len(clinical_features)}; labels cannot be aligned by position."
    )
clinical_features["label"] = label[LABEL_COLUMN].to_numpy()
clinical_features.to_excel(PHENOTYPE_PATH, index=False)

In [None]:
# Re-read the (label-augmented) clinical sheet and show which columns it now has.
dataa = pd.read_excel(io="/content/drive/MyDrive/GNN_/phenotypic_information.xlsx")
dataa.columns

In [None]:
dataa.loc[:, "label"]

In [None]:
# Distinct label values and how many patients carry each one.
np.unique(clinical_features["label"].to_numpy(), return_counts=True)

In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch_geometric.data import DataLoader, Data
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
import nrrd
from scipy.ndimage import zoom
import glob
class VolumetricGraphDataset(Dataset):
    """One voxel-grid graph per subject, annotated with that subject's clinical features.

    Each row of ``feature_file`` is identified by its 'Image Data ID'. For every ID, the first
    file (sorted by name) inside ``<root>/<Image Data ID>/`` is read as an NRRD volume. The volume
    is resized to 32x32x32 and min-max normalized. It then becomes a graph with one node per voxel,
    whose features are the voxel intensity followed by the row's standardized clinical columns.
    Each node has directed edges to its 6 face-adjacent neighbours whose intensity differs by less
    than 0.5.

    Args:
        root: directory containing one sub-directory per 'Image Data ID'.
        feature_file: Excel sheet with an 'Image Data ID' column, a 'label' column and the
            clinical feature columns (all remaining columns are standardized and used as features).
        classes: stored on ``self.classes`` but not otherwise used; defaults to ["0", "1", "2"].
        transform: accepted for compatibility; not used.
    """

    # Offsets of the 6 face-adjacent neighbours, in the order their edges are emitted.
    _NEIGHBOR_OFFSETS = ((-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1))

    def __init__(self, root, feature_file, classes=None, transform=None):
        # A fresh list per instance (a shared mutable default could leak edits between instances).
        self.classes = list(classes) if classes is not None else ["0", "1", "2"]
        self.root = root
        self.data_list = []

        self.additional_features = pd.read_excel(feature_file)
        self.name_ids = self.additional_features['Image Data ID']
        self.feature_scaler = StandardScaler()
        feature_columns = [col for col in self.additional_features.columns
                           if col not in ('Image Data ID', 'label')]

        # NOTE(review): the scaler is fit on every row before any train/validation split, so
        # validation statistics influence the scaling of training features.
        self.additional_features[feature_columns] = self.feature_scaler.fit_transform(
            self.additional_features[feature_columns]
        )

        self.process()

    def volume_to_graph(self, volume_path, additional_features, target_size=(32, 32, 32), threshold=0.5):
        """Load an NRRD volume and turn it into ``(x, edge_index)``.

        Args:
            volume_path: path of the NRRD file.
            additional_features: 1-D array of per-subject features appended to every node.
            target_size: shape the volume is resampled to before graph construction.
            threshold: neighbours are linked when their normalized intensities differ by less
                than this.

        Returns:
            x: float tensor of shape (num_voxels, 1 + len(additional_features)).
            edge_index: long tensor of shape (2, num_edges).
        """
        volume = nrrd.read(volume_path)[0]
        volume = resize_volume(volume, target_size)

        # Min-max normalize to [0, 1]; a constant volume is left as-is to avoid dividing by zero.
        volume_min = volume.min()
        volume_max = volume.max()
        if volume_max > volume_min:
            volume = (volume - volume_min) / (volume_max - volume_min)

        return self._grid_graph(volume, additional_features, threshold)

    @classmethod
    def _grid_graph(cls, volume, additional_features, threshold=0.5):
        """Vectorized 6-connected voxel graph for a 3-D array.

        Nodes are voxels in C (row-major) order. Node ``n`` has features
        ``[volume.flat[n], *additional_features]``. A directed edge ``n -> m`` exists for every
        in-bounds face neighbour ``m`` with ``|volume[n] - volume[m]| < threshold``. Edges are
        ordered by source node first, then by ``_NEIGHBOR_OFFSETS``.
        """
        volume = np.asarray(volume)
        shape = volume.shape
        num_nodes = volume.size
        flat = volume.reshape(-1).astype(np.float64)

        # Per-subject features are identical on every node, so broadcast one row to all voxels.
        extra = np.asarray(additional_features, dtype=np.float64).reshape(1, -1)
        node_features = np.hstack([flat.reshape(-1, 1), np.repeat(extra, num_nodes, axis=0)])
        x = torch.tensor(node_features, dtype=torch.float)

        coords = np.indices(shape).reshape(len(shape), -1)  # (3, N) voxel coordinates, C order
        upper = np.asarray(shape).reshape(-1, 1)
        dst_per_offset = []
        keep_per_offset = []
        for offset in cls._NEIGHBOR_OFFSETS:
            neighbor = coords + np.asarray(offset).reshape(-1, 1)
            in_bounds = np.all((neighbor >= 0) & (neighbor < upper), axis=0)
            # Out-of-bounds neighbours are pointed at voxel 0 only so the lookup below is valid;
            # they are removed by the in_bounds mask.
            safe_neighbor = np.where(in_bounds, neighbor, 0)
            dst = np.ravel_multi_index(tuple(safe_neighbor), shape)
            similar = np.abs(flat - flat[dst]) < threshold
            dst_per_offset.append(dst)
            keep_per_offset.append(in_bounds & similar)

        # (N, 6) tables; boolean masking reads them row-major => edges grouped by source node.
        dst_table = np.stack(dst_per_offset, axis=1)
        keep_table = np.stack(keep_per_offset, axis=1)
        src_table = np.repeat(np.arange(num_nodes).reshape(-1, 1), len(cls._NEIGHBOR_OFFSETS), axis=1)
        # np.stack keeps shape (2, 0) when no edge survives (a list-built tensor would be (0,)).
        edge_index = torch.tensor(np.stack([src_table[keep_table], dst_table[keep_table]]), dtype=torch.long)

        return x, edge_index

    def process(self):
        """Build one graph per 'Image Data ID' and append it to ``self.data_list``.

        Raises:
            FileNotFoundError: if an ID's directory is missing or contains no file.
        """
        feature_columns = [col for col in self.additional_features.columns
                           if col not in ('Image Data ID', 'label')]
        for name_id in self.name_ids:
            class_dir = os.path.join(self.root, str(name_id))

            # escape: '[', '*', '?' in directory names must match literally.
            # sorted: glob order is arbitrary, so make the chosen file deterministic.
            vol_paths = sorted(glob.glob(os.path.join(glob.escape(class_dir), "*")))
            if not vol_paths:
                raise FileNotFoundError(
                    f"No volume file found in {class_dir!r} (Image Data ID {name_id!r})"
                )
            vol_path = vol_paths[0]

            vol_features = self.additional_features[
                self.additional_features['Image Data ID'] == name_id
            ]
            if len(vol_features) > 0:
                additional_features = vol_features[feature_columns].values[0]
                label = vol_features['label'].values[0]
                x, edge_index = self.volume_to_graph(vol_path, additional_features)
                self.data_list.append(Data(x=x, edge_index=edge_index, y=torch.tensor([label])))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        return self.data_list[idx]

def resize_volume(volume, target_shape):
    """Resample ``volume`` with linear interpolation so its shape becomes ``target_shape``.

    The per-axis zoom factor is ``target / current``. NOTE(review): ``scipy.ndimage.zoom`` keeps
    the input dtype by default, so integer volumes come back as integers -- confirm that is intended.
    """
    scale = tuple(float(new_len) / float(old_len)
                  for new_len, old_len in zip(target_shape, volume.shape))
    return zoom(volume, scale, order=1)

import torch.nn as nn
from torch_geometric.nn import GINConv, global_mean_pool

def make_convolution(in_channels, out_channels):
    """GIN layer whose update MLP is two (Linear -> BatchNorm1d -> ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    """
    mlp_layers = []
    for stage_in in (in_channels, out_channels):
        mlp_layers += [nn.Linear(stage_in, out_channels), nn.BatchNorm1d(out_channels), nn.ReLU()]
    return GINConv(nn.Sequential(*mlp_layers))

class GINClassification(nn.Module):
    """Graph classifier: three GIN layers, per-graph mean pooling, then a linear head.

    Args:
        in_channels: node feature size.
        hidden_channels: width of the first two GIN layers.
        out_channels: width of the last GIN layer, which is also the embedding size.
        num_classes: number of logits produced per graph.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_classes):
        super().__init__()
        self.conv1 = make_convolution(in_channels, hidden_channels)
        self.conv2 = make_convolution(hidden_channels, hidden_channels)
        self.conv3 = make_convolution(hidden_channels, out_channels)
        self.classifier = nn.Linear(out_channels, num_classes)

    def extract_embedding(self, x, edge_index, batch):
        """Return one pooled ``out_channels`` vector per graph in ``batch``."""
        for layer in (self.conv1, self.conv2, self.conv3):
            x = layer(x, edge_index)
        return global_mean_pool(x, batch=batch)

    def forward(self, x, edge_index, batch):
        """Return per-graph class logits."""
        graph_embedding = self.extract_embedding(x, edge_index, batch)
        return self.classifier(graph_embedding)
import os
import torch
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
import itertools


def _evaluate(model, loader, criterion, device=None):
    """Run ``model`` over ``loader`` without gradients.

    Returns:
        (average loss per sample, accuracy, true labels, predicted labels); the label lists hold
        plain Python ints in loader order.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    true_labels = []
    predicted_labels = []
    with torch.no_grad():
        for data in loader:
            if device is not None:
                data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            total_loss += criterion(out, data.y).item() * data.num_graphs
            predicted = out.argmax(dim=1)
            total += data.y.size(0)
            correct += (predicted == data.y).sum().item()
            # .cpu() so label collection also works when batches live on a GPU.
            true_labels.extend(data.y.cpu().tolist())
            predicted_labels.extend(predicted.cpu().tolist())
    return total_loss / total, correct / total, true_labels, predicted_labels


def _weighted_scores(true_labels, predicted_labels):
    """Return support-weighted (precision, recall, F1); ill-defined per-class values count as 0."""
    options = {'average': 'weighted', 'zero_division': 0}
    return (
        precision_score(true_labels, predicted_labels, **options),
        recall_score(true_labels, predicted_labels, **options),
        f1_score(true_labels, predicted_labels, **options),
    )


def _plot_history(log_data):
    """Plot per-epoch train/test loss and train/test accuracy side by side."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    loss_ax.plot(log_data['Epoch'], log_data['Train Average Loss'], label='Train Loss', color='blue')
    loss_ax.plot(log_data['Epoch'], log_data['Test Average Loss'], label='Test Loss', color='red')
    loss_ax.set(title='Training and Test Loss', xlabel='Epoch', ylabel='Loss')
    loss_ax.legend()

    acc_ax.plot(log_data['Epoch'], log_data['Train Accuracy'], label='Train Accuracy', color='blue')
    acc_ax.plot(log_data['Epoch'], log_data['Test Accuracy'], label='Test Accuracy', color='green')
    acc_ax.set(title='Train and Test Accuracy', xlabel='Epoch', ylabel='Accuracy')
    acc_ax.legend()

    fig.tight_layout()
    plt.show()


def _plot_confusion_matrix(true_labels, predicted_labels):
    """Draw an annotated confusion matrix whose ticks are the actual class ids present."""
    # Use the union of true and predicted ids so the tick labels always match the matrix rows/cols.
    class_ids = sorted(set(true_labels) | set(predicted_labels))
    if not class_ids:
        return
    cm = confusion_matrix(true_labels, predicted_labels, labels=class_ids)

    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title('Confusion Matrix')
    fig.colorbar(image, ax=ax)
    tick_marks = range(len(class_ids))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_ids)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_ids)

    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], 'd'),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()
    plt.show()


def train_model(model, train_loader, test_loader, optimizer, criterion, fold, num_epochs=500,
                model_save_path='models/', device=None):
    """Train ``model``, keep the best checkpoint, then report and plot its results.

    After every epoch the model is evaluated on ``test_loader``. The weights with the highest test
    accuracy (ties broken by weighted F1) are saved to ``<model_save_path>/best_model.pth``, and the
    final weights to ``last_model.pth``. A per-epoch log goes to ``train_log_<fold>.csv``. At the end
    the best checkpoint is loaded back into ``model`` and re-evaluated, so the printed "Overall"
    metrics and the confusion matrix describe the best model rather than the last epoch.

    Args:
        model: network called as ``model(x, edge_index, batch)`` returning per-graph logits.
        train_loader, test_loader: loaders yielding batches with ``x``, ``edge_index``, ``y``,
            ``batch`` and ``num_graphs``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(logits, y)``.
        fold: identifier used in the CSV log file name.
        num_epochs: number of training epochs (must be >= 1).
        model_save_path: output directory, created if missing.
        device: optional device; when given, the model and each batch are moved to it.

    Raises:
        ValueError: if ``num_epochs`` < 1 (no checkpoint would exist to reload).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")

    os.makedirs(model_save_path, exist_ok=True)
    best_model_path = os.path.join(model_save_path, 'best_model.pth')
    last_model_path = os.path.join(model_save_path, 'last_model.pth')
    if device is not None:
        model.to(device)

    # Start below any reachable score so the first epoch always writes a checkpoint. Starting at
    # 0.0 meant an all-wrong run never saved one, and torch.load below then crashed or silently
    # reloaded a stale best_model.pth from an earlier run or fold.
    best_test_accuracy = -1.0
    best_f1 = -1.0

    log_data = {
        'Epoch': [],
        'Train Average Loss': [],
        'Test Average Loss': [],
        'Train Accuracy': [],
        'Test Accuracy': [],
        'Precision': [],
        'Recall': [],
        'F1 Score': [],
        'Confusion Matrix': []
    }

    for epoch in tqdm(range(num_epochs), desc='Training'):
        model.train()
        total_loss = 0.0
        total_samples = 0
        correct_train = 0
        total_train = 0
        for data in train_loader:
            if device is not None:
                data = data.to(device)
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * data.num_graphs
            total_samples += data.num_graphs
            correct_train += (out.argmax(dim=1) == data.y).sum().item()
            total_train += data.y.size(0)

        avg_loss = total_loss / total_samples
        train_accuracy = correct_train / total_train

        test_avg_loss, test_accuracy, true_labels, predicted_labels = _evaluate(
            model, test_loader, criterion, device)
        precision, recall, f1 = _weighted_scores(true_labels, predicted_labels)
        cm = confusion_matrix(true_labels, predicted_labels)

        tqdm.write(f'Epoch [{epoch + 1}/{num_epochs}], Train Avg Loss: {avg_loss:.5f}, Train Accuracy: {train_accuracy:.5f}, '
                   f'Test Avg Loss: {test_avg_loss:.5f}, Test Accuracy: {test_accuracy:.5f}, '
                   f'Precision: {precision:.5f}, Recall: {recall:.5f}, F1 Score: {f1:.5f}')

        log_data['Epoch'].append(epoch + 1)
        log_data['Train Average Loss'].append(avg_loss)
        log_data['Test Average Loss'].append(test_avg_loss)
        log_data['Train Accuracy'].append(train_accuracy)
        log_data['Test Accuracy'].append(test_accuracy)
        log_data['Precision'].append(precision)
        log_data['Recall'].append(recall)
        log_data['F1 Score'].append(f1)
        log_data['Confusion Matrix'].append(cm.tolist())

        # Best = highest test accuracy; ties broken by weighted F1.
        if test_accuracy > best_test_accuracy or (test_accuracy == best_test_accuracy and f1 > best_f1):
            tqdm.write("$$$ best model is updated according to accuracy! $$$")
            best_test_accuracy = test_accuracy
            best_f1 = f1
            torch.save(model.state_dict(), best_model_path)

    torch.save(model.state_dict(), last_model_path)
    pd.DataFrame(log_data).to_csv(os.path.join(model_save_path, f'train_log_{fold}.csv'), index=False)
    _plot_history(log_data)

    # Report on the best checkpoint, not on whatever the last epoch happened to produce.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    _, best_accuracy, true_labels, predicted_labels = _evaluate(model, test_loader, criterion, device)
    overall_precision, overall_recall, overall_f1 = _weighted_scores(true_labels, predicted_labels)
    print(f'Overall Test Accuracy: {best_accuracy:.5f}, Precision: {overall_precision:.5f}, Recall: {overall_recall:.5f}, F1 Score: {overall_f1:.5f}')

    _plot_confusion_matrix(true_labels, predicted_labels)



def test_model(model, loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in loader:
            data.to(device)  # Move data to device

            output = model(data)
            pred = output.max(dim=1)[1]

            correct += pred.eq(data.y).sum().item()
            total += data.y.size(0)

    return correct / total



In [None]:
# Pick the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fused MRI volumes (one sub-directory per 'Image Data ID') and the labelled clinical sheet.
data_dir = "/content/drive/MyDrive/breast_MRI_DUKE_image/breast MRI cropped/IHS_fuse/fused_data"
feature_file = "/content/drive/MyDrive/GNN_/phenotypic_information.xlsx"

dataset = VolumetricGraphDataset(data_dir, feature_file)
print(dataset)
# Node feature width = 1 intensity channel + the clinical columns.
num_node_features = dataset[0].x.shape[1]

In [None]:
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import numpy as np
from imblearn.over_sampling import SMOTE
from torch_geometric.data import DataLoader

# Graph-level label of every subject, then a per-class tally.
labels = [graph.y.item() for graph in dataset]
unique_labels, counts = np.unique(labels, return_counts=True)
print("Class distribution before oversampling:")
for class_id, n_samples in zip(unique_labels, counts):
    print(f"Class {class_id}: {n_samples} samples")

In [None]:
oversampler = SMOTE()
# Every node's feature row from every graph, stacked into one matrix.
# NOTE(review): node_features / node_labels are not used by any later cell visible here, and with
# 32^3 nodes per graph this matrix is large -- consider dropping this cell.
node_features = np.concatenate([graph.x.numpy() for graph in dataset], axis=0)
node_labels = np.asarray(labels)


In [None]:
from collections import Counter
from imblearn.over_sampling import SMOTE

# One vector per graph: the mean of its node features. The clinical columns are identical on
# every node of a graph, so effectively only the intensity channel is averaged.
graph_features = np.array([graph.x.mean(dim=0).numpy() for graph in dataset])
graph_labels = np.array([graph.y.item() for graph in dataset])
class_counts = Counter(graph_labels)

# Drop classes with 3 or fewer graphs (presumably too small to oversample / split reliably).
valid_indices = [i for i, lbl in enumerate(graph_labels) if class_counts[lbl] > 3]
graph_features = graph_features[valid_indices]
graph_labels = graph_labels[valid_indices]
if len(graph_labels) == 0:
    raise ValueError("every class has 3 or fewer samples; nothing left to oversample")

# SMOTE needs more samples than k_neighbors in each class it resamples. The default k=5 raised a
# ValueError whenever a kept class had only 4-5 graphs, so cap k by the smallest class size.
smallest_class_size = min(Counter(graph_labels).values())
initial_smote = SMOTE(k_neighbors=min(5, smallest_class_size - 1), random_state=42)
oversampled_features, oversampled_labels = initial_smote.fit_resample(graph_features, graph_labels)

In [None]:
import numpy as np

unique_labels = np.unique(graph_labels)

# Map the sorted distinct label values onto contiguous class ids 0..K-1, as CrossEntropyLoss
# expects. Per the source column name, Mx was encoded as -1.
label_mapping = dict(zip(unique_labels, range(len(unique_labels))))

graph_labels = np.array([label_mapping[value] for value in graph_labels])

In [None]:
print("Class distribution after remapping:")
unique_labels, counts = np.unique(graph_labels, return_counts=True)
for class_id, n_samples in zip(unique_labels, counts):
    print(f"Class {class_id}: {n_samples} samples")

In [None]:
from imblearn.over_sampling import SMOTE

# k_neighbors=1 works for every kept class (each has more than 3 samples). Seeded so the
# synthetic samples, and every result derived from them, are reproducible across runs.
oversampler = SMOTE(k_neighbors=1, random_state=42)
oversampled_features, oversampled_labels = oversampler.fit_resample(graph_features, graph_labels)

In [None]:
print("Class distribution after oversampling:")
unique_labels, counts = np.unique(oversampled_labels, return_counts=True)
for class_id, n_samples in zip(unique_labels, counts):
    print(f"Class {class_id}: {n_samples} samples")

In [None]:
# Wrap each graph-level vector as a tiny graph: the vector copied onto 10 nodes that are fully
# connected (self-loops included).
n_nodes = 10
oversampled_data_list = []
for feature_row, label_value in zip(oversampled_features, oversampled_labels):
    node_x = torch.tensor(feature_row, dtype=torch.float).unsqueeze(0).repeat(n_nodes, 1)
    node_ids = torch.arange(n_nodes)
    complete_edges = torch.cartesian_prod(node_ids, node_ids).t().contiguous()
    target = torch.tensor([label_value], dtype=torch.long)
    oversampled_data_list.append(Data(x=node_x, edge_index=complete_edges, y=target))


In [None]:
from tqdm import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report
from torch_geometric.data import DataLoader
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import StratifiedKFold
from torch_geometric.loader import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 150
k_folds = 5
SEED = 42
NODES_PER_GRAPH = 10  # each graph-level vector is copied onto this many fully connected nodes
MODEL_ROOT = "/content/drive/MyDrive/GNN_/Staging(Metastasis)#(Mx -replaced by -1)[M]/fusion/model"

# Derive the model's input/output sizes from the data instead of hard-coding 16 and 3.
if graph_labels.min() < 0:
    raise ValueError("graph_labels must be remapped to non-negative class ids before training")
num_node_features = graph_features.shape[1]
num_classes = int(graph_labels.max()) + 1


def _to_fully_connected_graph(feature_vector, label_value, num_nodes=NODES_PER_GRAPH):
    """Wrap one graph-level feature vector as a ``num_nodes``-node complete graph (with self-loops)."""
    x = torch.tensor(feature_vector, dtype=torch.float).unsqueeze(0).repeat(num_nodes, 1)
    node_ids = torch.arange(num_nodes)
    edge_index = torch.cartesian_prod(node_ids, node_ids).t().contiguous()
    return Data(x=x, edge_index=edge_index, y=torch.tensor([label_value], dtype=torch.long))


# Split the REAL samples (stratified, so every fold sees every class) and oversample only the
# training part. SMOTE points are interpolations of real samples. Oversampling before the split,
# as before, put synthetic copies of validation patients into training and inflated validation
# scores.
skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=SEED)

for fold, (train_idx, val_idx) in enumerate(skf.split(graph_features, graph_labels)):
    print(f"Fold {fold + 1}/{k_folds}")
    torch.manual_seed(SEED)  # reproducible weight init and batch shuffling per fold

    train_x, train_y = SMOTE(k_neighbors=1, random_state=SEED).fit_resample(
        graph_features[train_idx], graph_labels[train_idx])
    train_data = [_to_fully_connected_graph(feat, lbl) for feat, lbl in zip(train_x, train_y)]
    val_data = [_to_fully_connected_graph(graph_features[i], graph_labels[i]) for i in val_idx]

    train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
    test_loader = DataLoader(val_data, shuffle=False)
    model = GINClassification(in_channels=num_node_features, hidden_channels=1000, out_channels=100, num_classes=num_classes)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    # One sub-directory per fold: train_model always writes best_model.pth / last_model.pth, so a
    # shared directory let each fold overwrite the previous fold's checkpoints.
    train_model(model, train_loader, test_loader, optimizer, criterion, fold=fold, num_epochs=num_epochs,
                model_save_path=os.path.join(MODEL_ROOT, f"fold_{fold}"))


