In [21]:
import pandas as pd
import numpy as np
import nibabel as nib
from nilearn import input_data, datasets
import networkx as nx
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, DataLoader
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from fastdtw import fastdtw
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from torch.nn import BatchNorm1d, Dropout
import os
from nilearn import plotting
from nilearn import image
import matplotlib.pyplot as plt
import warnings
import pprint
# Ignore all warnings (not recommended unless you know what you are doing)
warnings.filterwarnings("ignore")
from tqdm import tqdm
            


train_data_eda = "test_dwp_train_eda/"
os.makedirs(train_data_eda, exist_ok=True)

test_result_dir = "test_dwp_test_result/"
os.makedirs(test_result_dir, exist_ok=True)

# Load the CSV file
csv_file = pd.read_csv(r"/Users/vinoth/PycharmProjects/paper_implementation/Dataset/source/mri_images/ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv")
csv_file['DX_GROUP'].replace({1: 0, 2: 1}, inplace=True)
train_df, test_df = train_test_split(csv_file, test_size=0.1, random_state=42)

harvard_oxford_atlas = [
    "cort-maxprob-thr25-2mm",
    "cort-maxprob-thr50-2mm",
    "cort-prob-2mm",
    "cort-maxprob-thr0-2mm"
    
]


'''harvard_oxford_atlas = [
    "cort-maxprob-thr0-1mm",
    "cort-maxprob-thr0-2mm",
    "cort-maxprob-thr25-1mm",
    "cort-maxprob-thr25-2mm",
    "cort-maxprob-thr50-1mm",
    "cort-maxprob-thr50-2mm",
    "cort-prob-1mm",
    "cort-prob-2mm",
    "cortl-maxprob-thr0-1mm",
    "cortl-maxprob-thr0-2mm",
    "cortl-maxprob-thr25-1mm",
    "cortl-maxprob-thr25-2mm",
    "cortl-maxprob-thr50-1mm",
    "cortl-maxprob-thr50-2mm",
    "cortl-prob-1mm",
    "cortl-prob-2mm",
    "sub-maxprob-thr0-1mm",
    "sub-maxprob-thr0-2mm",
    "sub-maxprob-thr25-1mm",
    "sub-maxprob-thr25-2mm",
    "sub-maxprob-thr50-1mm",
    "sub-maxprob-thr50-2mm",
    "sub-prob-1mm",
    "sub-prob-2mm"
]'''


results = {}
atlas_threshold = None

for data in tqdm(harvard_oxford_atlas):
    atlas_threshold = data
    print("----Threshold----")
    print(atlas_threshold)
    results[atlas_threshold] = {}
    atlas = datasets.fetch_atlas_harvard_oxford(data)
    masker = input_data.NiftiLabelsMasker(labels_img=atlas.maps, standardize=True)
    mri_dir = r"/Users/vinoth/PycharmProjects/paper_implementation/Dataset/source/mri_images/ABIDE_pcp/cpac/nofilt_noglobal/"

    # Placeholder for Graph Neural Network Data
    graph_data_list = []

    # Data Preprocessing
    for idx, row in tqdm(enumerate(train_df.itertuples()), total=len(train_df)):
        
        if idx == 2:
            break
        # Combine the parent and nested folder paths
        '''file_dir = os.path.join(train_data_eda, row.FILE_ID)
        os.makedirs(file_dir, exist_ok=True)'''
        mri_filename = os.path.join(mri_dir, row.FILE_ID + "_func_preproc.nii.gz")
        
        try:
            mri_img = nib.load(mri_filename)
            
            #mri_img_dir = os.path.join(file_dir, 'mri_image')
            #os.makedirs(mri_img_dir, exist_ok=True)
            
            # Select the first time point
            first_volume = mri_img.slicer[:,:,:,0]
            
            image_shape = mri_img.shape

            # The total number of volumes in the 4D dimension is the size of the fourth dimension
            total_volumes = image_shape[3]

            print("Total number of volumes in the 4D image for file " + row.FILE_ID + " : ", total_volumes)

            '''# Plot the image
            plotting.plot_img(first_volume, cmap='gray')  # grayscale often works well for MRIs
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_img.png')
            plt.savefig(filename)
            plt.close()  # Close the plot to avoid overlaps

            # Plot the EPI
            plotting.plot_epi(first_volume, display_mode='z', cut_coords=5, cmap='viridis')  # viridis is a perceptually uniform colormap
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_epi_img.png')
            plt.savefig(filename)
            plt.close()

            # Plot the anatomy
            plotting.plot_anat(first_volume, cmap='gray')  # grayscale again for anatomical images
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_anat_img.png')
            plt.savefig(filename)
            plt.close()

            # Plot the statistical map
            plotting.plot_stat_map(first_volume, bg_img=None, threshold=3.0, cmap='cold_hot')  # cold_hot is often used for stat maps
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_stat_map_img.png')
            plt.savefig(filename)
            plt.close()

            # Plot the probabilistic atlas
            plotting.plot_prob_atlas(mri_img, bg_img=None, colorbar=True)  # default colormap should work for probabilistic atlas
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_atlas_map_img.png')
            plt.savefig(filename)
            plt.close()'''
            
            '''from concurrent.futures import ProcessPoolExecutor

            def calculate_distance(i, j, time_series):
                distance, _ = fastdtw(time_series[i, :], time_series[j, :])
                return i, j, distance
            
            def wrapper(args):
                return calculate_distance(*args)

            time_series = masker.fit_transform(mri_img)
            n_regions, n_time_points = time_series.shape
            print("****************************")
            print(row.FILE_ID, " --> ", "n_regions --> ", n_regions, "n_time_points --> ", n_time_points)
            print("****************************")
            distance_matrix = np.zeros((n_regions, n_regions))

            # Creating a list of arguments to pass to the function
            args = [(i, j, time_series) for i in range(n_regions) for j in range(i + 1, n_regions)]

            # Using ProcessPoolExecutor to execute the function in parallel
            with ProcessPoolExecutor() as executor:
                results = list(executor.map(wrapper, args))

            # Filling the distance_matrix with the results
            for i, j, distance in results:
                distance_matrix[i, j] = distance_matrix[j, i] = distance

            distance_matrix = distance_matrix / distance_matrix.max()
            similarity_matrix = 1 - distance_matrix
            threshold = 0.5
            similarity_matrix[similarity_matrix < threshold] = 0
            
            print(similarity_matrix)
            
            G = nx.from_numpy_matrix(similarity_matrix)'''

            
            
            from concurrent.futures import ThreadPoolExecutor

            def calculate_distance(i, j, time_series):
                distance, _ = fastdtw(time_series[i, :], time_series[j, :])
                return i, j, distance

            def wrapper(args):
                return calculate_distance(*args)

            
            time_series = masker.fit_transform(mri_img)
            n_regions, n_time_points = time_series.shape
            print("****************************")
            print(row.FILE_ID, " --> ", "n_regions --> ", n_regions, "n_time_points --> ", n_time_points)
            print("****************************")
            distance_matrix = np.zeros((n_regions, n_regions))

            # Creating a list of arguments to pass to the function
            args = [(i, j, time_series) for i in range(n_regions) for j in range(i + 1, n_regions)]

            # Using ThreadPoolExecutor to execute the function in parallel
            with ThreadPoolExecutor() as executor:
                train_thread_results = list(executor.map(wrapper, args))

            # Filling the distance_matrix with the train_thread_results
            for i, j, distance in train_thread_results:
                distance_matrix[i, j] = distance_matrix[j, i] = distance

            distance_matrix = distance_matrix / distance_matrix.max()
            similarity_matrix = 1 - distance_matrix
            threshold = 0.3
            similarity_matrix[similarity_matrix < threshold] = 0

            #print(similarity_matrix)

            G = nx.from_numpy_matrix(similarity_matrix)

                

# Here you can call the function with the appropriate row
# graph = process_time_series(row)



            #if idx == 0:  # Only for the first iteration
            # Plot the time series for the regions
            '''plt.figure(figsize=(35, 15))
            for i in range(min(n_regions, time_series.shape[0])):
                plt.plot(time_series[i, :], label=f'Region {i + 1}')
            plt.xlabel('Time point')
            plt.ylabel('Blood Oxygen Level(BOLD) - Normalized signal')
            plt.title('Time series of the regions')
            plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
            plt.tight_layout()'''
        
            # Save the plot to the existing folder
            '''time_series_filename = row.FILE_ID+'_'+'time_series_plot.png'
            
            plt.savefig(os.path.join(file_dir, time_series_filename))
            
            plt.figure(figsize=(10, 10))
            sns.heatmap(similarity_matrix, annot=False, cmap='turbo')
            plt.title('Similarity Matrix')
            similarity_matrix_adj_img_filename = row.FILE_ID + '_similarity_matrix.png'
            plt.savefig(os.path.join(file_dir, similarity_matrix_adj_img_filename))
            plt.close() # Close the plot
            similarity_matrix_npy_filename = row.FILE_ID + '_similarity_matrix.npy'
            similarity_matrix_npy_path = os.path.join(file_dir, similarity_matrix_npy_filename)
            np.save(similarity_matrix_npy_path, similarity_matrix)'''

            '''# Visualize the graph
            plt.figure(figsize=(45, 25))
            pos = nx.spring_layout(G)'''

            # Extract the edge weights from the graph
            weights = [G[u][v].get('weight', 1) for u, v in G.edges()]

            # Normalize the weights to fit your desired range of thickness
            normalized_weights = [5 * weight / max(weights) for weight in weights]

            # Draw the edges with the thickness determined by the normalized weights
            '''nx.draw_networkx_edges(G, pos, width=normalized_weights)

            # Draw the nodes and labels
            nx.draw_networkx_nodes(G, pos)
            nx.draw_networkx_labels(G, pos)'''
            
            # Define the path and filename where you want to save the plot
            '''graph_plot_filename = row.FILE_ID + '_graph_plot.png'
            graph_plot_path = os.path.join(file_dir, graph_plot_filename)

            # Save the plot to the specified path
            plt.savefig(graph_plot_path)'''

            edge_index = torch.tensor(list(G.edges), dtype=torch.long)
            x = torch.tensor(time_series, dtype=torch.float)
            y = torch.tensor([row.DX_GROUP], dtype=torch.float)
            data = Data(x=x, edge_index=edge_index.t().contiguous(), y=y)
            graph_data_list.append(data)
        except FileNotFoundError:
            pass
        
    print("Graph Data List -----> ")
    #print(graph_data_list)

    # Neural Network Model with Regularization, Batch Normalization, and Dropout
    class Net(torch.nn.Module):
        def __init__(self, num_node_features, num_classes):
            super(Net, self).__init__()
            self.conv1 = GCNConv(num_node_features, 16)
            self.bn1 = BatchNorm1d(16)
            self.conv2 = GCNConv(16, 32)
            self.bn2 = BatchNorm1d(32)
            self.fc = torch.nn.Linear(32, num_classes)
            self.dropout = Dropout(0.5)

        def forward(self, data):
            x, edge_index, batch = data.x, data.edge_index, data.batch
            x = self.conv1(x, edge_index)
            x = self.bn1(x)
            x = F.relu(x)
            x = self.dropout(x)
            x = self.conv2(x, edge_index)
            x = self.bn2(x)
            x = global_mean_pool(x, batch)
            x = self.fc(x)
            return F.log_softmax(x, dim=1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_features = graph_data_list[0].num_node_features
    num_classes = 2
    model = Net(num_features, num_classes).to(device)
    loader = DataLoader(graph_data_list, batch_size=32, shuffle=True)
    
    # Hyperparameter Tuning (Example: Adjusting Learning Rate)
    learning_rates = [0.01, 0.001, 0.0001]
    l_rate = None
    for lr in learning_rates:
        l_rate = str(lr)
        results[atlas_threshold][l_rate] = {}
        print("Learning Rate --> ", lr)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4) # L2 Regularization
        for epoch in tqdm(range(100)):
            total_loss = 0
            model.train()
            for data in loader:
                data = data.to(device)
                optimizer.zero_grad()
                out = model(data)
                loss = F.nll_loss(out, data.y.long())
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f'Epoch: {epoch+1}, Loss: {total_loss/len(loader)}')
        results[atlas_threshold][l_rate]['loss'] = total_loss/len(loader)

        # Placeholder for time series data
        time_series_list = []
        successful_indices = []

        # Testing Data Preprocessing
        # for idx, row in enumerate(test_df.itertuples()):
        for idx, row in tqdm(enumerate(test_df.itertuples()), total=len(test_df)):
            mri_filename = os.path.join(mri_dir, row.FILE_ID + "_func_preproc.nii.gz")
            try:
                mri_img = nib.load(mri_filename)
                time_series = masker.fit_transform(mri_img)
                time_series_list.append(time_series)
                successful_indices.append(idx)
            except FileNotFoundError:
                pass

        
        # Placeholder for Graph Neural Network Data for testing
        graph_data_test_list = []

        #for idx, successful_idx in tqdm(enumerate(successful_indices, total=len(successful_indices))):
        '''for idx, successful_idx in tqdm(enumerate(successful_indices), total=len(successful_indices)):
            row = test_df.iloc[successful_idx]
            time_series = time_series_list[idx]
            n_regions = time_series.shape[0]
            distance_matrix = np.zeros((n_regions, n_regions))
            for i in range(n_regions):
                for j in range(i + 1, n_regions):
                    distance, _ = fastdtw(time_series[i, :], time_series[j, :])
                    distance_matrix[i, j] = distance_matrix[j, i] = distance
            distance_matrix = distance_matrix / distance_matrix.max()
            similarity_matrix = 1 - distance_matrix
            threshold = 0.5
            similarity_matrix[similarity_matrix < threshold] = 0
            G = nx.from_numpy_matrix(similarity_matrix)
            edge_index = torch.tensor(list(G.edges), dtype=torch.long)
            x = torch.tensor(time_series, dtype=torch.float)
            y = torch.tensor([row.DX_GROUP], dtype=torch.float)
            data = Data(x=x, edge_index=edge_inde'.t().contiguous(), y=y)
            graph_data_test_list.append(data)'''

        
        
        from concurrent.futures import ThreadPoolExecutor

        def calculate_distance(i, j, time_series):
            distance, _ = fastdtw(time_series[i, :], time_series[j, :])
            return i, j, distance

        def process_row(successful_idx, time_series):
            row = test_df.iloc[successful_idx]
            n_regions = time_series.shape[0]
            distance_matrix = np.zeros((n_regions, n_regions))

            # Creating a list of arguments to pass to the function
            args = [(i, j, time_series) for i in range(n_regions) for j in range(i + 1, n_regions)]

            # Using ThreadPoolExecutor to execute the function in parallel
            with ThreadPoolExecutor() as executor:
                thread_results = list(executor.map(calculate_distance, args))

            # Filling the distance_matrix with the thread_results
            for i, j, distance in thread_results:
                distance_matrix[i, j] = distance_matrix[j, i] = distance

            distance_matrix = distance_matrix / distance_matrix.max()
            similarity_matrix = 1 - distance_matrix
            threshold = 0.3
            similarity_matrix[similarity_matrix < threshold] = 0
            G = nx.from_numpy_matrix(similarity_matrix)
            edge_index = torch.tensor(list(G.edges), dtype=torch.long)
            x = torch.tensor(time_series, dtype=torch.float)
            y = torch.tensor([row.DX_GROUP], dtype=torch.float)
            data = Data(x=x, edge_index=edge_index.t().contiguous(), y=y)
            return data

        graph_data_test_list = []
        with ThreadPoolExecutor() as executor:
            args = [(successful_idx, time_series_list[idx]) for idx, successful_idx in enumerate(successful_indices)]
            graph_data_test_list = list(tqdm(executor.map(process_row, args), total=len(successful_indices)))


        
        
            



        
        
        
        # Create a data loader for testing data
        test_loader = DataLoader(graph_data_test_list, batch_size=32, shuffle=False)

        # Testing
        model.eval()
        correct = 0
        all_preds = []
        all_labels = []

        for data in tqdm(test_loader):
            data = data.to(device)
            with torch.no_grad():
                output = model(data)
                _, pred = output.max(dim=1)
            all_preds.append(pred.cpu().numpy())
            all_labels.append(data.y.cpu().numpy())
            correct += int((pred == data.y.long()).sum())

        accuracy = correct / len(test_loader.dataset)

        print(f'Test Accuracy: {accuracy:.4f}')
        
        results[atlas_threshold][l_rate]['accuracy'] = accuracy

        # Flatten the list of predictions and labels
        all_preds = np.concatenate(all_preds)
        all_labels = np.concatenate(all_labels)
        
        # Specify the parent folder
        parent_folder = test_result_dir

        # Specify the nested folder names
        nested_folder1 = atlas_threshold
        nested_folder2 = 'learning_rate_'+l_rate
        
        # Combine the parent and nested folder pa'hs
        validation_result_dir = os.path.join(parent_folder, nested_folder1, nested_folder2)
        
        # Create the nested folders, including any necessary parent directories
        os.makedirs(validation_result_dir, exist_ok=True)

        # Confusion Matrix
        cm = confusion_matrix(all_labels, all_preds)
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  # Normalize

        plt.figure(figsize=(10, 7))
        sns.heatmap(cm, annot=True, cmap='Blues', fmt=".2%")
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title('Confusion Matrix (Normalized)')
        
        # Save the image inside the nested folder
        plt.savefig(os.path.join(validation_result_dir, 'confusion_matrix.png'))
        
        # Print actual vs predicted
        actual_vs_predicted = pd.DataFrame({'Actual': all_labels, 'Predicted': all_preds})
        print(actual_vs_predicted)

        # Classification report
        report = classification_report(all_labels, all_preds, target_names=['Non-Autistic', 'Autistic'], output_dict=True)
        #print(classification_report(all_labels, all_preds, target_names=['Non-Autistic', 'Autistic']))
        report_text = classification_report(all_labels, all_preds, target_names=['Non-Autistic', 'Autistic'])

        plt.figure(figsize=(10, 7))
        plt.text(0.01, 0.05, report_text, {'fontsize': 12}, fontproperties='monospace') # Adjust text size and position accordingly
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(validation_result_dir, 'classification_report.png'))
        
        with open(os.path.join(validation_result_dir, 'classification_report.txt'), 'w') as file:
            file.write(report_text)
  
        # Individual class metrics
        # Access individual values
        non_autistic_precision = report['Non-Autistic']['precision']
        autistic_precision = report['Autistic']['precision']

        non_autistic_recall = report['Non-Autistic']['recall']
        autistic_recall = report['Autistic']['recall']

        non_autistic_f1_score = report['Non-Autistic']['f1-score']
        autistic_f1_score = report['Autistic']['f1-score']

        non_autistic_support = report['Non-Autistic']['support']
        autistic_support = report['Autistic']['support']

        '''# Aggregated metrics
        accuracy = report['accuracy'''

        macro_avg_precision = report['macro avg']['precision']
        weighted_avg_precision = report['weighted avg']['precision']

        macro_avg_recall = report['macro avg']['recall']
        weighted_avg_recall = report['weighted avg']['recall']

        macro_avg_f1_score = report['macro avg']['f1-score']
        weighted_avg_f1_score = report['weighted avg']['f1-score']

        macro_avg_support = report['macro avg']['support']
        weighted_avg_support = report['weighted avg']['support']
        
        results[atlas_threshold][l_rate]['non_autistic_precision'] = non_autistic_precision
        results[atlas_threshold][l_rate]['autistic_precision'] = autistic_precision
        results[atlas_threshold][l_rate]['non_autistic_recall'] = non_autistic_recall
        results[atlas_threshold][l_rate]['autistic_recall'] = autistic_recall
        results[atlas_threshold][l_rate]['non_autistic_f1_score'] = non_autistic_f1_score
        results[atlas_threshold][l_rate]['autistic_f1_score'] = autistic_f1_score
        results[atlas_threshold][l_rate]['non_autistic_support'] = non_autistic_support
        results[atlas_threshold][l_rate]['autistic_support'] = autistic_support
        results[atlas_threshold][l_rate]['macro_avg_precision'] = macro_avg_precision
        results[atlas_threshold][l_rate]['weighted_avg_precision'] = weighted_avg_precision
        results[atlas_threshold][l_rate]['macro_avg_recall'] = macro_avg_recall
        results[atlas_threshold][l_rate]['weighted_avg_recall'] = weighted_avg_recall
        results[atlas_threshold][l_rate]['macro_avg_f1_score'] = macro_avg_f1_score
        results[atlas_threshold][l_rate]['weighted_avg_f1_score'] = weighted_avg_f1_score
        results[atlas_threshold][l_rate]['macro_avg_support'] = macro_avg_support
        results[atlas_threshold][l_rate]['weighted_avg_support'] = weighted_avg_support
        
        atlas = None
        l_rate = None
        
print("----Final Result----")
pprint.pprint(results)

sorted_data = [(key, subkey, values['accuracy']) for key, subdata in results.items() for subkey, values in subdata.items()]
sorted_data.sort(key=lambda x: x[2], reverse=True)

print("----Sorted Accuracy----")
for key, subkey, accuracy in sorted_data:
    print(f"Key: {key}, Subkey: {subkey}, Accuracy: {accuracy}")
    
# Create a list of tuples containing key, subkey, and corresponding details
sorted_data = [(key, subkey, values) for key, subdata in results.items() for subkey, values in subdata.items()]

# Sort the list based on the accuracy
sorted_data.sort(key=lambda x: x[2]['accuracy'], reverse=True)

# Create a new dictionary with the sorted order
sorted_data_dict = {f"{key}-{subkey}": values for key, subkey, values in sorted_data}
print("----Sorted Dict----")
pprint.pprint(sorted_data_dict)



  0%|          | 0/4 [00:00<?, ?it/s]

----Threshold----
cort-maxprob-thr25-2mm



  0%|          | 0/1000 [00:00<?, ?it/s][A

Total number of volumes in the 4D image for file USM_0050446 :  236
****************************
USM_0050446  -->  n_regions -->  236 n_time_points -->  48
****************************



  0%|          | 2/1000 [00:20<2:54:00, 10.46s/it][A


Graph Data List -----> 
Learning Rate -->  0.01



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:00<00:19,  4.96it/s][A
  4%|▍         | 4/100 [00:00<00:06, 14.59it/s][A

Epoch: 1, Loss: 0.6519943475723267
Epoch: 2, Loss: 0.6239675283432007
Epoch: 3, Loss: 0.5927113890647888
Epoch: 4, Loss: 0.5578501224517822
Epoch: 5, Loss: 0.519758403301239



  7%|▋         | 7/100 [00:00<00:04, 19.30it/s][A
 10%|█         | 10/100 [00:00<00:04, 20.79it/s][A

Epoch: 6, Loss: 0.4789934456348419
Epoch: 7, Loss: 0.43625032901763916
Epoch: 8, Loss: 0.3923397660255432
Epoch: 9, Loss: 0.3481598198413849
Epoch: 10, Loss: 0.3046555519104004
Epoch: 11, Loss: 0.2627648711204529



 14%|█▍        | 14/100 [00:00<00:03, 23.74it/s][A
 17%|█▋        | 17/100 [00:00<00:03, 25.04it/s][A

Epoch: 12, Loss: 0.22335362434387207
Epoch: 13, Loss: 0.18714982271194458
Epoch: 14, Loss: 0.1546863317489624
Epoch: 15, Loss: 0.12626507878303528
Epoch: 16, Loss: 0.10194984823465347
Epoch: 17, Loss: 0.08158980309963226



 20%|██        | 20/100 [00:00<00:03, 24.94it/s][A
 25%|██▌       | 25/100 [00:01<00:02, 30.20it/s][A

Epoch: 18, Loss: 0.06486798077821732
Epoch: 19, Loss: 0.0513605959713459
Epoch: 20, Loss: 0.04059751331806183
Epoch: 21, Loss: 0.032110415399074554
Epoch: 22, Loss: 0.025467440485954285
Epoch: 23, Loss: 0.020291481167078018
Epoch: 24, Loss: 0.016266483813524246
Epoch: 25, Loss: 0.01313568465411663
Epoch: 26, Loss: 0.010695278644561768



 31%|███       | 31/100 [00:01<00:01, 36.77it/s][A
 37%|███▋      | 37/100 [00:01<00:01, 40.67it/s][A

Epoch: 27, Loss: 0.00878635048866272
Epoch: 28, Loss: 0.00728604756295681
Epoch: 29, Loss: 0.00610013073310256
Epoch: 30, Loss: 0.0051570250652730465
Epoch: 31, Loss: 0.004401875659823418
Epoch: 32, Loss: 0.0037930700927972794
Epoch: 33, Loss: 0.0032987960148602724
Epoch: 34, Loss: 0.002894618781283498
Epoch: 35, Loss: 0.002561979927122593
Epoch: 36, Loss: 0.002286202972754836
Epoch: 37, Loss: 0.002056271303445101



 43%|████▎     | 43/100 [00:01<00:01, 45.29it/s][A
 48%|████▊     | 48/100 [00:01<00:01, 44.73it/s]

Epoch: 38, Loss: 0.0018630543490871787
Epoch: 39, Loss: 0.0016999093350023031
Epoch: 40, Loss: 0.0015613758005201817
Epoch: 41, Loss: 0.0014429405564442277
Epoch: 42, Loss: 0.001341396477073431
Epoch: 43, Loss: 0.0012536532012745738
Epoch: 44, Loss: 0.001177571015432477
Epoch: 45, Loss: 0.001111366436816752
Epoch: 46, Loss: 0.001053493469953537
Epoch: 47, Loss: 0.0010027624666690826
Epoch: 48, Loss: 0.0009579836623743176


[A
 54%|█████▍    | 54/100 [00:01<00:00, 47.35it/s][A

Epoch: 49, Loss: 0.0009184433147311211
Epoch: 50, Loss: 0.0008834273321554065
Epoch: 51, Loss: 0.0008522216230630875
Epoch: 52, Loss: 0.0008244690834544599
Epoch: 53, Loss: 0.0007995745982043445
Epoch: 54, Loss: 0.0007773000397719443
Epoch: 55, Loss: 0.0007572882459498942
Epoch: 56, Loss: 0.0007391819381155074
Epoch: 57, Loss: 0.0007228621980175376
Epoch: 58, Loss: 0.0007082099909894168
Epoch: 59, Loss: 0.0006948678637854755



 60%|██████    | 60/100 [00:01<00:00, 48.38it/s][A
 66%|██████▌   | 66/100 [00:01<00:00, 51.10it/s][A

Epoch: 60, Loss: 0.0006827168981544673
Epoch: 61, Loss: 0.0006716379430145025
Epoch: 62, Loss: 0.0006615119054913521
Epoch: 63, Loss: 0.0006523388437926769
Epoch: 64, Loss: 0.0006438804557546973
Epoch: 65, Loss: 0.000636255950666964
Epoch: 66, Loss: 0.0006291079334914684
Epoch: 67, Loss: 0.0006225554971024394
Epoch: 68, Loss: 0.0006165986997075379
Epoch: 69, Loss: 0.000610999355558306
Epoch: 70, Loss: 0.0006058764411136508
Epoch: 71, Loss: 0.0006011109799146652



 72%|███████▏  | 72/100 [00:01<00:00, 51.57it/s][A
 78%|███████▊  | 78/100 [00:02<00:00, 49.19it/s][A

Epoch: 72, Loss: 0.0005965837044641376
Epoch: 73, Loss: 0.0005925330333411694
Epoch: 74, Loss: 0.0005886013968847692
Epoch: 75, Loss: 0.0005850272136740386
Epoch: 76, Loss: 0.0005816913326270878
Epoch: 77, Loss: 0.000578474544454366
Epoch: 78, Loss: 0.0005754960584454238
Epoch: 79, Loss: 0.0005727558163926005
Epoch: 80, Loss: 0.0005700155161321163
Epoch: 81, Loss: 0.0005675135762430727



 83%|████████▎ | 83/100 [00:02<00:00, 48.75it/s][A
 88%|████████▊ | 88/100 [00:02<00:00, 48.13it/s][A

Epoch: 82, Loss: 0.0005651307292282581
Epoch: 83, Loss: 0.0005628670332953334
Epoch: 84, Loss: 0.0005607224884442985
Epoch: 85, Loss: 0.0005586970364674926
Epoch: 86, Loss: 0.0005567907355725765
Epoch: 87, Loss: 0.0005548844928853214
Epoch: 88, Loss: 0.0005530973430722952
Epoch: 89, Loss: 0.000551310193259269
Epoch: 90, Loss: 0.0005496421363204718
Epoch: 91, Loss: 0.0005479741375893354



 93%|█████████▎| 93/100 [00:02<00:00, 47.45it/s][A
100%|██████████| 100/100 [00:02<00:00, 39.07it/s][A


Epoch: 92, Loss: 0.000546425289940089
Epoch: 93, Loss: 0.0005449955351650715
Epoch: 94, Loss: 0.000543446687515825
Epoch: 95, Loss: 0.0005420169327408075
Epoch: 96, Loss: 0.0005407063290476799
Epoch: 97, Loss: 0.0005392765742726624
Epoch: 98, Loss: 0.0005379660287871957
Epoch: 99, Loss: 0.000536655425094068
Epoch: 100, Loss: 0.0005353448214009404



  0%|          | 0/112 [00:00<?, ?it/s][A
  1%|          | 1/112 [00:03<06:15,  3.38s/it][A
  2%|▏         | 2/112 [00:06<06:08,  3.35s/it][A
  3%|▎         | 3/112 [00:09<05:26,  3.00s/it][A
  4%|▎         | 4/112 [00:11<04:51,  2.70s/it][A
  4%|▍         | 5/112 [00:13<04:22,  2.46s/it][A
  5%|▌         | 6/112 [00:15<03:54,  2.21s/it][A
  6%|▋         | 7/112 [00:16<03:32,  2.03s/it][A
  8%|▊         | 9/112 [00:20<03:23,  1.98s/it][A
  9%|▉         | 10/112 [00:22<03:05,  1.82s/it][A
 10%|▉         | 11/112 [00:23<02:50,  1.69s/it][A
 11%|█         | 12/112 [00:26<03:20,  2.01s/it][A
 12%|█▏        | 13/112 [00:30<04:07,  2.50s/it][A
 12%|█▎        | 14/112 [00:32<04:13,  2.59s/it][A
 13%|█▎        | 15/112 [00:34<03:54,  2.42s/it][A
 14%|█▍        | 16/112 [00:37<03:45,  2.35s/it][A
 15%|█▌        | 17/112 [00:40<04:06,  2.60s/it][A
 16%|█▌        | 18/112 [00:41<03:37,  2.31s/it][A
 17%|█▋        | 19/112 [00:43<03:17,  2.12s/it][A
 18%|█▊        | 20/112 [00:

TypeError: process_row() missing 1 required positional argument: 'time_series'

In [None]:
import pandas as pd
import numpy as np
import nibabel as nib
from nilearn import input_data, datasets
import networkx as nx
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, DataLoader
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from fastdtw import fastdtw
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from torch.nn import BatchNorm1d, Dropout
import os
from nilearn import plotting
from nilearn import image
import matplotlib.pyplot as plt
import warnings
import pprint
# Ignore all warnings (not recommended unless you know what you are doing)
warnings.filterwarnings("ignore")
from tqdm import tqdm
            


'''train_data_eda = "test_dwp_train_eda/"
os.makedirs(train_data_eda, exist_ok=True'''

test_result_dir = "test_dwp_test_result/"
os.makedirs(test_result_dir, exist_ok=True)

# Load the CSV file
csv_file = pd.read_csv(r"/Users/vinoth/PycharmProjects/paper_implementation/Dataset/source/mri_images/ABIDE_pcp/Phenotypic_V1_0b_preprocessed1.csv")
csv_file['DX_GROUP'].replace({1: 0, 2: 1}, inplace=True)
train_df, test_df = train_test_split(csv_file, test_size=0.1, random_state=42)

harvard_oxford_atlas = [
    "cort-maxprob-thr25-2mm"
    #"cort-maxprob-thr50-2mm",
    #"cort-prob-2mm",
    #"cort-maxprob-thr0-2mm"
    
]


'''harvard_oxford_atlas = [
    "cort-maxprob-thr0-1mm",
    "cort-maxprob-thr0-2mm",
    "cort-maxprob-thr25-1mm",
    "cort-maxprob-thr25-2mm",
    "cort-maxprob-thr50-1mm",
    "cort-maxprob-thr50-2mm",
    "cort-prob-1mm",
    "cort-prob-2mm",
    "cortl-maxprob-thr0-1mm",
    "cortl-maxprob-thr0-2mm",
    "cortl-maxprob-thr25-1mm",
    "cortl-maxprob-thr25-2mm",
    "cortl-maxprob-thr50-1mm",
    "cortl-maxprob-thr50-2mm",
    "cortl-prob-1mm",
    "cortl-prob-2mm",
    "sub-maxprob-thr0-1mm",
    "sub-maxprob-thr0-2mm",
    "sub-maxprob-thr25-1mm",
    "sub-maxprob-thr25-2mm",
    "sub-maxprob-thr50-1mm",
    "sub-maxprob-thr50-2mm",
    "sub-prob-1mm",
    "sub-prob-2mm"
]'''


results = {}
atlas_threshold = None

for data in tqdm(harvard_oxford_atlas):
    atlas_threshold = data
    print("----Threshold----")
    print(atlas_threshold)
    results[atlas_threshold] = {}
    atlas = datasets.fetch_atlas_harvard_oxford(data)
    masker = input_data.NiftiLabelsMasker(labels_img=atlas.maps, standardize=True)
    mri_dir = r"/Users/vinoth/PycharmProjects/paper_implementation/Dataset/source/mri_images/ABIDE_pcp/cpac/nofilt_noglobal/"

    # Placeholder for Graph Neural Network Data
    graph_data_list = []

    # Data Preprocessing
    for idx, row in tqdm(enumerate(train_df.itertuples()), total=len(train_df)):
        # Combine the parent and nested folder paths
        #file_dir = os.path.join(train_data_eda, row.FILE_ID)
        #os.makedirs(file_dir, exist_ok=True)
        mri_filename = os.path.join(mri_dir, row.FILE_ID + "_func_preproc.nii.gz")
        
        try:
            mri_img = nib.load(mri_filename)
            
            #mri_img_dir = os.path.join(file_dir, 'mri_image')
            #os.makedirs(mri_img_dir, exist_ok=True)
            
            # Select the first time point
            #first_volume = mri_img.slicer[:,:,:,0]
            
            #image_shape = mri_img.shape

            # The total number of volumes in the 4D dimension is the size of the fourth dimension
            #total_volumes = image_shape[3]

            #print("Total number of volumes in the 4D image for file " + row.FILE_ID + " : ", total_volumes)

            '''# Plot the image
            plotting.plot_img(first_volume, cmap='gray')  # grayscale often works well for MRIs
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_img.png')
            plt.savefig(filename)
            plt.close()  # Close the plot to avoid overlaps

            # Plot the EPI
            plotting.plot_epi(first_volume, display_mode='z', cut_coords=5, cmap='viridis')  # viridis is a perceptually uniform colormap
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_epi_img.png')
            plt.savefig(filename)
            plt.close()

            # Plot the anatomy
            plotting.plot_anat(first_volume, cmap='gray')  # grayscale again for anatomical images
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_anat_img.png')
            plt.savefig(filename)
            plt.close()

            # Plot the statistical map
            plotting.plot_stat_map(first_volume, bg_img=None, threshold=3.0, cmap='cold_hot')  # cold_hot is often used for stat maps
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_stat_map_img.png')
            plt.savefig(filename)
            plt.close()

            # Plot the probabilistic atlas
            plotting.plot_prob_atlas(mri_img, bg_img=None, colorbar=True)  # default colormap should work for probabilistic atlas
            filename = os.path.join(mri_img_dir, row.FILE_ID+'_atlas_map_img.png')
            plt.savefig(filename)
            plt.close()'''
            
            time_series = masker.fit_transform(mri_img)
            n_regions, n_time_points = time_series.shape
            '''print("****************************")
            print(row.FILE_ID, " --> ", "n_regions --> ", n_regions, "n_time_points --> ", n_time_points)
            print("****************************")'''
            distance_matrix = np.zeros((n_regions, n_regions))
            for i in range(n_regions):
                for j in range(i + 1, n_regions):
                    distance, _ = fastdtw(time_series[i, :], time_series[j, :])
                    distance_matrix[i, j] = distance_matrix[j, i] = distance
            distance_matrix = distance_matrix / distance_matrix.max()
            similarity_matrix = 1 - distance_matrix
            threshold = 0.3
            similarity_matrix[similarity_matrix < threshold] = 0
            G = nx.from_numpy_matrix(similarity_matrix)


            #if idx == 0:  # Only for the first iteration
            # Plot the time series for the regions
            '''plt.figure(figsize=(35, 15))
            for i in range(min(n_regions, time_series.shape[0])):
                plt.plot(time_series[i, :], label=f'Region {i + 1}')
            plt.xlabel('Time point')
            plt.ylabel('Blood Oxygen Level(BOLD) - Normalized signal')
            plt.title('Time series of the regions')
            plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
            plt.tight_layout()'''
        
            # Save the plot to the existing folder
            '''time_series_filename = row.FILE_ID+'_'+'time_series_plot.png'
            
            plt.savefig(os.path.join(file_dir, time_series_filename))
            
            plt.figure(figsize=(10, 10))
            sns.heatmap(similarity_matrix, annot=False, cmap='turbo')
            plt.title('Similarity Matrix')
            similarity_matrix_adj_img_filename = row.FILE_ID + '_similarity_matrix.png'
            plt.savefig(os.path.join(file_dir, similarity_matrix_adj_img_filename))
            plt.close() # Close the plot
            similarity_matrix_npy_filename = row.FILE_ID + '_similarity_matrix.npy'
            similarity_matrix_npy_path = os.path.join(file_dir, similarity_matrix_npy_filename)
            np.save(similarity_matrix_npy_path, similarity_matrix)'''

            '''# Visualize the graph
            plt.figure(figsize=(45, 25))
            pos = nx.spring_layout(G)'''

            # Extract the edge weights from the graph
            weights = [G[u][v].get('weight', 1) for u, v in G.edges()]

            # Normalize the weights to fit your desired range of thickness
            normalized_weights = [5 * weight / max(weights) for weight in weights]

            # Draw the edges with the thickness determined by the normalized weights
            '''nx.draw_networkx_edges(G, pos, width=normalized_weights)

            # Draw the nodes and labels
            nx.draw_networkx_nodes(G, pos)
            nx.draw_networkx_labels(G, pos)'''
            
            # Define the path and filename where you want to save the plot
            '''graph_plot_filename = row.FILE_ID + '_graph_plot.png'
            graph_plot_path = os.path.join(file_dir, graph_plot_filename)

            # Save the plot to the specified path
            plt.savefig(graph_plot_path)'''

            edge_index = torch.tensor(list(G.edges), dtype=torch.long)
            x = torch.tensor(time_series, dtype=torch.float)
            y = torch.tensor([row.DX_GROUP], dtype=torch.float)
            data = Data(x=x, edge_index=edge_index.t().contiguous(), y=y)
            graph_data_list.append(data)
        except FileNotFoundError:
            pass
        
    print("Graph Data List -----> ")
    print(graph_data_list)

    # Neural Network Model with Regularization, Batch Normalization, and Dropout
    class Net(torch.nn.Module):
        def __init__(self, num_node_features, num_classes):
            super(Net, self).__init__()
            self.conv1 = GCNConv(num_node_features, 16)
            self.bn1 = BatchNorm1d(16)
            self.conv2 = GCNConv(16, 32)
            self.bn2 = BatchNorm1d(32)
            self.fc = torch.nn.Linear(32, num_classes)
            self.dropout = Dropout(0.5)

        def forward(self, data):
            x, edge_index, batch = data.x, data.edge_index, data.batch
            x = self.conv1(x, edge_index)
            x = self.bn1(x)
            x = F.relu(x)
            x = self.dropout(x)
            x = self.conv2(x, edge_index)
            x = self.bn2(x)
            x = global_mean_pool(x, batch)
            x = self.fc(x)
            return F.log_softmax(x, dim=1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_features = graph_data_list[0].num_node_features
    num_classes = 2
    model = Net(num_features, num_classes).to(device)
    loader = DataLoader(graph_data_list, batch_size=32, shuffle=True)
    
    # Hyperparameter Tuning (Example: Adjusting Learning Rate)
    learning_rates = [0.01, 0.001, 0.0001]
    l_rate = None
    for lr in learning_rates:
        l_rate = str(lr)
        results[atlas_threshold][l_rate] = {}
        print("Learning Rate --> ", lr)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4) # L2 Regularization
        for epoch in tqdm(range(100)):
            total_loss = 0
            model.train()
            for data in loader:
                data = data.to(device)
                optimizer.zero_grad()
                out = model(data)
                loss = F.nll_loss(out, data.y.long())
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f'Epoch: {epoch+1}, Loss: {total_loss/len(loader)}')
        results[atlas_threshold][l_rate]['loss'] = total_loss/len(loader)

        # Placeholder for time series data
        time_series_list = []
        successful_indices = []

        # Testing Data Preprocessing
        # for idx, row in enumerate(test_df.itertuples()):
        for idx, row in tqdm(enumerate(test_df.itertuples()), total=len(test_df)):
            mri_filename = os.path.join(mri_dir, row.FILE_ID + "_func_preproc.nii.gz")
            try:
                mri_img = nib.load(mri_filename)
                time_series = masker.fit_transform(mri_img)
                time_series_list.append(time_series)
                successful_indices.append(idx)
            except FileNotFoundError:
                pass

        
        # Placeholder for Graph Neural Network Data for testing
        graph_data_test_list = []

        #for idx, successful_idx in tqdm(enumerate(successful_indices, total=len(successful_indices))):
        for idx, successful_idx in tqdm(enumerate(successful_indices), total=len(successful_indices)):
            row = test_df.iloc[successful_idx]
            time_series = time_series_list[idx]
            n_regions = time_series.shape[0]
            distance_matrix = np.zeros((n_regions, n_regions))
            for i in range(n_regions):
                for j in range(i + 1, n_regions):
                    distance, _ = fastdtw(time_series[i, :], time_series[j, :])
                    distance_matrix[i, j] = distance_matrix[j, i] = distance
            distance_matrix = distance_matrix / distance_matrix.max()
            similarity_matrix = 1 - distance_matrix
            threshold = 0.3
            similarity_matrix[similarity_matrix < threshold] = 0
            G = nx.from_numpy_matrix(similarity_matrix)
            edge_index = torch.tensor(list(G.edges), dtype=torch.long)
            x = torch.tensor(time_series, dtype=torch.float)
            y = torch.tensor([row.DX_GROUP], dtype=torch.float)
            data = Data(x=x, edge_index=edge_index.t().contiguous(), y=y)
            graph_data_test_list.append(data)

        # Create a data loader for testing data
        test_loader = DataLoader(graph_data_test_list, batch_size=32, shuffle=False)

        # Testing
        model.eval()
        correct = 0
        all_preds = []
        all_labels = []

        for data in tqdm(test_loader):
            data = data.to(device)
            with torch.no_grad():
                output = model(data)
                _, pred = output.max(dim=1)
            all_preds.append(pred.cpu().numpy())
            all_labels.append(data.y.cpu().numpy())
            correct += int((pred == data.y.long()).sum())

        accuracy = correct / len(test_loader.dataset)

        print(f'Test Accuracy: {accuracy:.4f}')
        
        results[atlas_threshold][l_rate]['accuracy'] = accuracy

        # Flatten the list of predictions and labels
        all_preds = np.concatenate(all_preds)
        all_labels = np.concatenate(all_labels)
        
        # Specify the parent folder
        parent_folder = test_result_dir

        # Specify the nested folder names
        nested_folder1 = atlas_threshold
        nested_folder2 = 'learning_rate_'+l_rate
        
        # Combine the parent and nested folder pa'hs
        validation_result_dir = os.path.join(parent_folder, nested_folder1, nested_folder2)
        
        # Create the nested folders, including any necessary parent directories
        os.makedirs(validation_result_dir, exist_ok=True)

        # Confusion Matrix
        cm = confusion_matrix(all_labels, all_preds)
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  # Normalize

        plt.figure(figsize=(10, 7))
        sns.heatmap(cm, annot=True, cmap='Blues', fmt=".2%")
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title('Confusion Matrix (Normalized)')
        
        # Save the image inside the nested folder
        plt.savefig(os.path.join(validation_result_dir, 'confusion_matrix.png'))
        
        # Print actual vs predicted
        actual_vs_predicted = pd.DataFrame({'Actual': all_labels, 'Predicted': all_preds})
        print(actual_vs_predicted)

        # Classification report
        report = classification_report(all_labels, all_preds, target_names=['Non-Autistic', 'Autistic'], output_dict=True)
        #print(classification_report(all_labels, all_preds, target_names=['Non-Autistic', 'Autistic']))
        report_text = classification_report(all_labels, all_preds, target_names=['Non-Autistic', 'Autistic'])

        plt.figure(figsize=(10, 7))
        plt.text(0.01, 0.05, report_text, {'fontsize': 12}, fontproperties='monospace') # Adjust text size and position accordingly
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(validation_result_dir, 'classification_report.png'))
        
        with open(os.path.join(validation_result_dir, 'classification_report.txt'), 'w') as file:
            file.write(report_text)
  
        # Individual class metrics
        # Access individual values
        non_autistic_precision = report['Non-Autistic']['precision']
        autistic_precision = report['Autistic']['precision']

        non_autistic_recall = report['Non-Autistic']['recall']
        autistic_recall = report['Autistic']['recall']

        non_autistic_f1_score = report['Non-Autistic']['f1-score']
        autistic_f1_score = report['Autistic']['f1-score']

        non_autistic_support = report['Non-Autistic']['support']
        autistic_support = report['Autistic']['support']

        '''# Aggregated metrics
        accuracy = report['accuracy'''

        macro_avg_precision = report['macro avg']['precision']
        weighted_avg_precision = report['weighted avg']['precision']

        macro_avg_recall = report['macro avg']['recall']
        weighted_avg_recall = report['weighted avg']['recall']

        macro_avg_f1_score = report['macro avg']['f1-score']
        weighted_avg_f1_score = report['weighted avg']['f1-score']

        macro_avg_support = report['macro avg']['support']
        weighted_avg_support = report['weighted avg']['support']
        
        results[atlas_threshold][l_rate]['non_autistic_precision'] = non_autistic_precision
        results[atlas_threshold][l_rate]['autistic_precision'] = autistic_precision
        results[atlas_threshold][l_rate]['non_autistic_recall'] = non_autistic_recall
        results[atlas_threshold][l_rate]['autistic_recall'] = autistic_recall
        results[atlas_threshold][l_rate]['non_autistic_f1_score'] = non_autistic_f1_score
        results[atlas_threshold][l_rate]['autistic_f1_score'] = autistic_f1_score
        results[atlas_threshold][l_rate]['non_autistic_support'] = non_autistic_support
        results[atlas_threshold][l_rate]['autistic_support'] = autistic_support
        results[atlas_threshold][l_rate]['macro_avg_precision'] = macro_avg_precision
        results[atlas_threshold][l_rate]['weighted_avg_precision'] = weighted_avg_precision
        results[atlas_threshold][l_rate]['macro_avg_recall'] = macro_avg_recall
        results[atlas_threshold][l_rate]['weighted_avg_recall'] = weighted_avg_recall
        results[atlas_threshold][l_rate]['macro_avg_f1_score'] = macro_avg_f1_score
        results[atlas_threshold][l_rate]['weighted_avg_f1_score'] = weighted_avg_f1_score
        results[atlas_threshold][l_rate]['macro_avg_support'] = macro_avg_support
        results[atlas_threshold][l_rate]['weighted_avg_support'] = weighted_avg_support
        
        atlas = None
        l_rate = None
        
print("----Final Result----")
pprint.pprint(results)

sorted_data = [(key, subkey, values['accuracy']) for key, subdata in results.items() for subkey, values in subdata.items()]
sorted_data.sort(key=lambda x: x[2], reverse=True)

print("----Sorted Accuracy----")
for key, subkey, accuracy in sorted_data:
    print(f"Key: {key}, Subkey: {subkey}, Accuracy: {accuracy}")
    
# Create a list of tuples containing key, subkey, and corresponding details
sorted_data = [(key, subkey, values) for key, subdata in results.items() for subkey, values in subdata.items()]

# Sort the list based on the accuracy
sorted_data.sort(key=lambda x: x[2]['accuracy'], reverse=True)

# Create a new dictionary with the sorted order
sorted_data_dict = {f"{key}-{subkey}": values for key, subkey, values in sorted_data}
print("----Sorted Dict----")
pprint.pprint(sorted_data_dict)

