In [None]:
import re
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from upsetplot import UpSet, from_memberships, plot
import itertools

# Add Path to HOMER functions:
# HOMER has to be installed in your system first.
# Each directory is listed once and joined with os.pathsep, so the first new
# directory is not fused onto the last existing PATH entry and no empty entry
# (which a POSIX shell treats as the current directory) is appended.
_BIOINFORMATICS_PATH_DIRS = [
    "/bioinformatics/anaconda3_052020/condabin",
    "/bioinformatics/A.C.Rsuite",
    "/bioinformatics/bowtie2",
    "/bioinformatics/BSseeker2",
    "/bioinformatics/adapterremoval",
    "/bioinformatics/idr/bin",
    "/bioinformatics/glassutils/scripts",
    "/bioinformatics/STAR/bin/Linux_x86_64_static",
    "/bioinformatics/scripts",
    "/bioinformatics/sratoolkit/bin",
    "/bioinformatics/FastQC",
    "/bioinformatics/bedtools/bin",
    "/bioinformatics/samtools/bin",
    "/bioinformatics/anaconda3_052020/bin",
    "/bioinformatics/homer/bin",
    "/usr/local/sbin",
    "/usr/sbin",
    "/usr/bin",
    "/usr/local/bin",
    "/usr/local/lib",
]
os.environ["PATH"] = os.pathsep.join(
    [entry for entry in [os.environ.get("PATH", "")] + _BIOINFORMATICS_PATH_DIRS if entry]
)

#import networkx as nx
from collections import defaultdict
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
 
def read_motif_headers(motif_file):
    """Return the motif names found in the header lines of a motif file.

    The file is first classified with ``is_denovo_motif_file``; the names are
    then extracted by ``read_motif_headers_homer`` (de novo HOMER output) or
    ``read_motif_headers_jaspar`` (everything else). Which parser was chosen
    is printed to stdout.

    :param motif_file: str, path to the motif file
    :return: list of motif names as returned by the selected parser
    """
    if is_denovo_motif_file(motif_file):
        print('DENOVO HOMER')
        return read_motif_headers_homer(motif_file)

    print('Not DENOVO HOMER')
    return read_motif_headers_jaspar(motif_file)


def is_denovo_motif_file(filename):
    """
    Tell whether a motif file contains a header line tagged with "BestGuess:".

    A header line is any line starting with '>'. If the file is missing or
    any other error occurs while reading it, a message is printed and False
    is returned.

    :param filename: str, the path to the file to be read
    :return: bool, True if any header line contains "BestGuess:", otherwise False
    """
    try:
        with open(filename, 'r') as handle:
            # any() stops at the first matching header, like an early return.
            found = any(
                text.startswith('>') and "BestGuess:" in text
                for text in handle
            )
    except FileNotFoundError:
        print(f"Error: The file '{filename}' does not exist.")
        return False
    except Exception as e:
        print(f"An error occurred: {e}")
        return False
    return found

def read_motif_headers_jaspar(motif_file):
    """Collect motif names from the '>' header lines of a motif file.

    The name is the run of word characters directly after '>' that is
    followed by a tab; header lines that do not have that shape are skipped.
    The resulting list is printed and returned.
    """
    header_re = re.compile(r'^>(\w+)\t')

    with open(motif_file, 'r') as handle:
        # Lazily match every header line; consumed before the file closes.
        hits = (header_re.match(text) for text in handle if text.startswith('>'))
        motif_names = [hit.group(1) for hit in hits if hit]

    print(motif_names)
    return motif_names



def read_motif_headers_homer(motif_file):
    """Extract the "BestGuess:" motif names from the header lines of a file.

    Only lines starting with '>' are inspected. For every occurrence of
    ``BestGuess:`` in such a line, the text up to (not including) the next
    '/' is collected; if there is no '/', the rest of the stripped line is
    taken. Header lines without ``BestGuess:`` contribute nothing.

    :param motif_file: str, path to the motif file
    :return: list of str, names in file order
    """
    best_guess = re.compile(r"BestGuess:([^/]+)")
    names = []

    with open(motif_file, 'r') as file:
        for line in file:
            if line.startswith('>'):
                # strip() so a header without '/' does not capture the newline
                names.extend(best_guess.findall(line.strip()))

    return names


def get_motif_count(peak_file, genome, motif_file, force=False):
    """Get motif count using HOMER annotatePeaks.

    The result is cached next to the peak file as
    ``motifCount_counts_<peak basename>_<motif basename>.txt``. When that file
    is missing, or ``force`` is truthy, ``annotatePeaks.pl`` is run (it must
    be on PATH) and its standard output becomes the cache file.

    The command is executed without a shell, so paths containing spaces or
    shell metacharacters are passed through verbatim. Output is written to a
    temporary file and only moved into place when the command succeeds, so a
    failed run never leaves a truncated cache behind.

    :param peak_file: str, path to the peak file
    :param genome: str, genome argument handed to annotatePeaks.pl
    :param motif_file: str, path to the motif file (passed with ``-m``)
    :param force: bool, regenerate the cache even if it exists
    :return: tuple ``(table, motif_count_file)`` where ``table`` is the cache
             file loaded with ``pd.read_table``
    :raises subprocess.CalledProcessError: annotatePeaks.pl exited non-zero
    :raises OSError: annotatePeaks.pl could not be started or the output
                     could not be written
    """
    import subprocess  # local import keeps this function self-contained

    base_file = os.path.splitext(os.path.basename(peak_file))[0]
    motif_base_file = os.path.splitext(os.path.basename(motif_file))[0]
    base_path = os.path.dirname(os.path.abspath(peak_file))

    motif_count_file = os.path.join(
        base_path, 'motifCount_counts_' + base_file + '_' + motif_base_file + '.txt'
    )
    print('Motif_Count_File=', motif_count_file, 'Force=', force)

    command = [
        'annotatePeaks.pl', str(peak_file), str(genome),
        '-cpu', '12', '-noann', '-nogene',
        '-m', str(motif_file), '-nmotifs',
    ]

    cache_exists = os.path.exists(motif_count_file)
    if not cache_exists:
        print('FILE DOES NOT EXIST: ', motif_count_file)

    if not cache_exists or force:
        print(' '.join(command) + ' > ' + motif_count_file)
        print('Generating ' + motif_count_file)
        partial_file = motif_count_file + '.tmp'
        try:
            with open(partial_file, 'w') as out:
                subprocess.run(command, stdout=out, check=True)
            os.replace(partial_file, motif_count_file)
        except (OSError, subprocess.CalledProcessError):
            # Never leave a half-written result that a later call would load.
            if os.path.exists(partial_file):
                os.remove(partial_file)
            raise
    else:
        print('** Loading ** ' + motif_count_file)

    return pd.read_table(motif_count_file, sep="\t"), motif_count_file



def transform_combinations_to_matrix(series, peak_file):
    """Turn a Series of motif-combination counts into a binary matrix.

    Each index label of ``series`` is a comma-separated list of motif names
    (whitespace around names is ignored); each value is the count for that
    combination. The result has one row per label, in Series order, and these
    columns:

    * one 0/1 column per distinct motif name, in sorted order;
    * ``Count`` - the Series value;
    * ``Norm_Count`` - ``Count`` divided by the number of rows that
      ``pd.read_csv(peak_file, sep='\\t')`` yields;
    * ``Motif Subset`` - the stripped motif names, sorted and joined by ',';
    * ``code-<peak_file>`` - the 0/1 flags read as a binary number, the first
      motif column being the most significant bit.

    The motif columns are kept in sorted order so the column layout and the
    binary code are reproducible between runs.

    :param series: pd.Series indexed by combination label
    :param peak_file: str, path to the tab-separated peak file
    :return: pd.DataFrame as described above
    """
    # Parse every label once: (stripped member names, count).
    parsed = [
        ([name.strip() for name in combo.split(',')], count)
        for combo, count in series.items()
    ]
    motif_list = sorted({name for members, _ in parsed for name in members})

    # Get the total number of peaks
    peak_data = pd.read_csv(peak_file, sep='\t')
    total_sum = len(peak_data)

    code_column = f'code-{peak_file}'
    rows = []
    for members, count in parsed:
        present = set(members)
        current_row = {}
        binary_code = 0
        for motif in motif_list:
            flag = 1 if motif in present else 0
            current_row[motif] = flag
            # Shift left for each motif position, then set the low bit.
            binary_code = (binary_code << 1) | flag
        current_row['Count'] = count
        current_row['Norm_Count'] = count / total_sum
        current_row['Motif Subset'] = ','.join(sorted(members))
        current_row[code_column] = binary_code
        rows.append(current_row)

    columns = motif_list + ['Count', 'Norm_Count', 'Motif Subset', code_column]
    return pd.DataFrame(rows, columns=columns)

def create_motif_network(data0, min_node_size=1000, max_node_size=5000,
                         min_motif_set_count=None, fixed_edge_width=1):
    """
    Draw a network graph of motif co-occurrences on the current matplotlib axes.

    Rows with ``Count`` <= 0 are dropped, then rows whose sum over all columns
    except the last is not greater than ``min_motif_set_count``. Within each
    remaining row, every column whose value equals 1 (other than 'Count' and
    'normalized_count') is treated as a motif present in that row. Each pair
    of motifs present in the same row is connected by an edge whose weight
    accumulates the row's ``Count``. Motifs that never share a row with
    another motif get no node.

    Args:
        data0 (pd.DataFrame): DataFrame with 0/1 motif columns and a 'Count'
            column. The input is not modified.
        min_node_size (int, optional): Size of the smallest node. Default is 1000.
        max_node_size (int, optional): Size of the largest node. Default is 5000.
        min_motif_set_count (number, optional): Row-sum threshold described
            above. When None, a module-level variable of the same name is used
            if one exists (the value this function used to read implicitly),
            otherwise 0.
        fixed_edge_width (number, optional): Width used for every edge. Pass
            None to use each edge's accumulated weight as its width. Default is 1.

    Returns:
        None
    """
    if min_motif_set_count is None:
        min_motif_set_count = globals().get('min_motif_set_count', 0)

    # Remove rows where 'Count' is 0, then rows below the set-size threshold.
    data = data0[data0['Count'] > 0]
    data = data[data.iloc[:, :-1].sum(axis=1) > min_motif_set_count].copy()

    # Normalize the 'Count' column
    max_count = data['Count'].max()
    data['normalized_count'] = data['Count'] / max_count

    motif_columns = [
        col for col in data.columns if col not in ('Count', 'normalized_count')
    ]

    # Build edges from co-occurrence and accumulate each motif's total
    # normalized count in the same pass.
    G = nx.Graph()
    node_weights = defaultdict(int)
    for _, row in data.iterrows():
        motifs = [col for col in motif_columns if row[col] == 1]
        for motif in motifs:
            node_weights[motif] += row['normalized_count']
        for i, motif1 in enumerate(motifs):
            for motif2 in motifs[i + 1:]:
                if G.has_edge(motif1, motif2):
                    G[motif1][motif2]['weight'] += row['Count']
                else:
                    G.add_edge(motif1, motif2, weight=row['Count'])

    if G.number_of_nodes() == 0:
        # Nothing co-occurs: there is no layout to compute and nothing to scale.
        plt.title('Motif Co-occurrence Network')
        return

    # Scale node sizes between min_node_size and max_node_size
    min_weight = min(node_weights.values())
    max_weight = max(node_weights.values())
    weight_span = max_weight - min_weight
    if weight_span == 0:
        # All motifs weigh the same; avoid 0/0 and draw them at the minimum size.
        node_size = [min_node_size for _ in G.nodes()]
    else:
        node_size = [
            min_node_size
            + (node_weights[motif] - min_weight) / weight_span * (max_node_size - min_node_size)
            for motif in G.nodes()
        ]

    # Calculate edge widths
    if fixed_edge_width is not None:
        edge_width = [fixed_edge_width for _ in G.edges()]
    else:
        edge_width = [G[u][v]['weight'] for u, v in G.edges()]

    # Labels show the motif name and its total count (normalized weight
    # scaled back by max_count).
    labels = {motif: f"{motif}: \n {int(node_weights[motif] * max_count)}" for motif in G.nodes()}

    # Fixed seed keeps the layout reproducible; k and iterations spread nodes out.
    pos = nx.spring_layout(G, k=10/np.sqrt(G.order()), scale=2, iterations=2000, seed=1234)

    nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='skyblue', edgecolors='red', linewidths=1.5)
    nx.draw_networkx_edges(G, pos, width=edge_width, alpha=0.5)
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=8)

    plt.title('Motif Co-occurrence Network')
    
def all_possible_combinations33(df, min_set_count=0):
    """Count, for every motif combination, the rows that match it exactly.

    A row matches a combination when every motif in the combination is
    non-zero and every other column is zero. The returned Series has:

    * ``'0_Set'`` first - the number of rows where every column is zero
      (always included, regardless of ``min_set_count``);
    * then one entry per combination of ``min_set_count`` or more columns,
      keyed by the column names joined with ',', ordered by combination size
      and then by column order. Combinations that match no row get 0.

    The row patterns are tallied once and each combination is a dictionary
    lookup, instead of rescanning the DataFrame for each of the 2**n
    combinations. Column names are assumed to be unique strings.

    :param df: pd.DataFrame of presence flags, one column per motif
    :param min_set_count: smallest combination size to report
    :return: pd.Series of counts keyed as described above
    """
    from collections import Counter  # local import keeps this function self-contained

    # Get the list of motifs from the DataFrame columns
    motifs = df.columns.tolist()
    positions = range(len(motifs))

    # One presence tuple per row, e.g. (True, False, True).
    pattern_counts = Counter(map(tuple, df.ne(0).to_numpy().tolist()))

    combination_counts = {'0_Set': pattern_counts.get((False,) * len(motifs), 0)}

    for r in range(1, len(motifs) + 1):
        if r < min_set_count:
            continue
        for chosen in itertools.combinations(positions, r):
            members = set(chosen)
            pattern = tuple(i in members for i in positions)
            # Comma-separated key for easier reading and use in plots
            combination_key = ','.join(motifs[i] for i in chosen)
            combination_counts[combination_key] = pattern_counts.get(pattern, 0)

    return pd.Series(combination_counts)




def calculate_all_motif_co_occurrences(df, max_distance):
    """Expand each peak row into one binary row per motif co-occurrence.

    The first column of ``df`` holds the peak ID; the remaining columns are
    motifs. For every row, ``find_co_occurring_motifs(<motif values>,
    max_distance)`` is called (defined outside this block; presumably it
    returns an iterable of motif-name collections - confirm against its
    definition). Each returned occurrence becomes one output row with a 1 in
    the column of every motif it names; names that are not columns of ``df``
    are ignored in the row but still counted in ``category_counts``.

    Peak IDs must end in ``-<integer>``: the k-th occurrence of a peak is
    named ``<prefix>-<integer + k>``.

    :param df: pd.DataFrame, first column peak ID, other columns motifs
    :param max_distance: passed through to ``find_co_occurring_motifs``
    :return: tuple ``(binary_df, category_counts)``; ``binary_df`` has a
             'PeakID' column followed by the motif columns, and
             ``category_counts`` maps ``frozenset(occurrence)`` to how often
             it was seen
    """
    category_counts = defaultdict(int)

    # Collect the original order of motifs (first column is the peak ID).
    original_motifs = df.columns[1:].tolist()
    motif_columns = {motif: idx for idx, motif in enumerate(original_motifs)}

    binary_matrix = []
    peak_names = []

    for _, row in df.iterrows():
        # .iloc: the row Series is labelled by column name, so bare integer
        # keys would rely on deprecated positional fallback.
        co_occurrences = find_co_occurring_motifs(row.iloc[1:], max_distance)
        if not co_occurrences:
            continue

        id_parts = row.iloc[0].rsplit('-', 1)
        base_peak_id = id_parts[0]      # everything before the last '-'
        start_index = int(id_parts[1])  # the number after the last '-'

        for k, occurrence in enumerate(co_occurrences):
            binary_row = [0] * len(original_motifs)
            for motif in occurrence:
                if motif in motif_columns:  # only motifs that are columns
                    binary_row[motif_columns[motif]] = 1
            category_counts[frozenset(occurrence)] += 1

            # Increment the index for each co-occurrence
            binary_matrix.append(binary_row)
            peak_names.append(f"{base_peak_id}-{start_index + k}")

    # Create DataFrame from the binary matrix, peak names first.
    binary_df = pd.DataFrame(binary_matrix, columns=original_motifs)
    binary_df.insert(0, 'PeakID', peak_names)

    return binary_df, category_counts


import upsetplot  
import ipdb


def peak_motif_sets(
    peak_file,
    genome,
    motif_list_file,
    output_file=None,  # Optional parameter for output image base name
    Motif_recount=False,
    seperate_duplicates=False,
    min_motif_set_count=0,
    min_subset_count=0
):
    """
    Process motif analysis and generate an UpSet plot and a normalized bar graph.

    Parameters:
        peak_file (str): Path to the peak file.
        genome (str): Genome reference.
        motif_list_file (str): Path to the motif list file.
        output_file (str, optional): Path whose extension must be one of
                                     .png/.jpg/.jpeg/.tiff/.bmp/.gif. The
                                     extension is stripped and '_upset.png' and
                                     '_bar.png' are appended to save the two
                                     plots. If None (or the extension is not
                                     supported), plots are not saved.
        Motif_recount (bool, optional): Whether to force motif recount. Defaults to False.
        seperate_duplicates (bool, optional): When True, per-peak counts above 2
                                     are zeroed in the motif's own column and
                                     flagged in a new '<motif>_3+' column.
                                     Defaults to False.
        min_motif_set_count (int, optional): The UpSet plot keeps motif sets with
                                     more than this many motifs; the bar graph
                                     keeps sets with at least this many.
                                     Defaults to 0.
        min_subset_count (int, optional): Minimum subset size for plotting. Defaults to 0.

    Returns:
        tuple: ``(sorted_data, binary_table)`` - the (up to) 50 largest motif
        sets as a percentage of all peaks, ascending, and the peak-by-motif
        0/1 table.
    """
    # Read motif headers
    headers = read_motif_headers(motif_list_file)

    # Get motif count
    motif_count_table, _ = get_motif_count(
        peak_file, genome, motif_list_file, force=Motif_recount
    )

    # Format the motif count table: the last len(headers) columns are the
    # motif counts; the first column holds the peak IDs.
    formatted_table = motif_count_table.iloc[:, -len(headers):].T
    formatted_table.columns = motif_count_table.iloc[:, 0]
    formatted_table.columns.name = 'PeakID'
    formatted_table.index = headers

    # Handle duplicate motifs if specified
    if seperate_duplicates:
        # This loop used to iterate over an undefined name, `subset_motifs`.
        # A module-level variable of that name is still honoured if present;
        # otherwise every motif is split.
        motifs_to_split = globals().get('subset_motifs', list(headers))
        formatted_table = formatted_table.T
        for item in motifs_to_split:
            condition = formatted_table[item] > 2
            formatted_table.loc[condition, item] = 0
            new_col = f'{item}_3+'
            formatted_table[new_col] = 1 * condition
        formatted_table = formatted_table.T
        headers = formatted_table.index

    # Create a binary table
    binary_table = (formatted_table >= 0.5).astype(int)
    df = binary_table.T

    # Rename columns if headers match
    if len(headers) == len(df.columns):
        df.columns = headers
    else:
        print("Error: The number of headers does not match the number of columns in the DataFrame.")

    # Generate UpSet data: one row per observed presence pattern.
    df_upset = df.groupby(list(df.columns)).size().reset_index(name='count')
    df_upset['num_motifs'] = df_upset.drop('count', axis=1).sum(axis=1)
    filtered_df = df_upset[df_upset['num_motifs'] > min_motif_set_count]
    upset_data = filtered_df.set_index(df.columns.tolist())['count']

    # Initialize saving mechanism based on output_file extension
    save_as_images = False
    image_base = None
    image_extensions = ['.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif']

    if output_file:
        image_base, ext = os.path.splitext(output_file)
        ext = ext.lower()
        if ext in image_extensions:
            save_as_images = True
            # Ensure output directory exists; a bare file name has no
            # directory part and os.makedirs('') would raise.
            output_dir = os.path.dirname(output_file)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
        else:
            print(f"Unsupported output file extension: {ext}. Plots will not be saved.")

    # Generate UpSet plot
    fig_upset = plt.figure(figsize=(20, 20), constrained_layout=True)
    plot(
        upset_data,
        fig=fig_upset,
        orientation="horizontal",
        show_counts=True,
        min_subset_size=min_subset_count,
        sort_categories_by=None
    )

    title_txt = (
        f'Set intersections: [Membership > {min_motif_set_count}] '
        f'[Counts > {min_subset_count}]\n'
        f'PEAKS: [{peak_file}]\n'
        f'MOTIFS: [{motif_list_file}]\n'
    )
    fig_upset.suptitle(title_txt, fontsize=16)

    # Save UpSet plot as image
    if save_as_images:
        upset_image_path = f"{image_base}_upset.png"
        fig_upset.savefig(upset_image_path, bbox_inches="tight")

    # Exact-match counts for every motif combination (plus '0_Set').
    upset_data1 = all_possible_combinations33(df, min_set_count=min_motif_set_count)

    # Index by (cardinality, sorted motif tuple) to filter and order the sets.
    upset_data1.index = pd.MultiIndex.from_tuples(
        [
            (len(x.split(',')), tuple(sorted(x.split(','))))
            for x in upset_data1.index
        ],
        names=['Cardinality', 'Motifs']
    )
    upset_data1 = upset_data1[
        upset_data1.index.get_level_values('Cardinality') >= min_motif_set_count
    ]
    upset_data1 = upset_data1.sort_index(ascending=False)
    upset_data1.index = [', '.join(motifs) for _, motifs in upset_data1.index]

    # Keep the 50 largest sets; raw counts are kept for the bar labels.
    top_counts = upset_data1.sort_values(ascending=True).tail(50)
    total_sum = len(df)
    # Percentage of all peaks, truncated to two decimals.
    sorted_data = (top_counts / total_sum * 10000).astype(int) / 100

    # Generate normalized bar graphs
    fig_bar = plt.figure(figsize=(20, 30), constrained_layout=True)
    ax = sorted_data.plot(kind='barh', color='green', ax=plt.gca())
    plt.title(
        f'PEAKS [{total_sum}] : {peak_file}\n'
        f'MOTIFS: [{motif_list_file}]',
        fontsize=16
    )
    ax.set_ylabel(
        'Top 50 co-occurring motif groups with # of motifs >= '
        f'{min_motif_set_count}'
    )
    ax.set_xlabel('Co-occurrences (as % ratio of total peaks)')

    # Label each bar with its percentage and its actual peak count (not a
    # count re-derived from the truncated percentage).
    for i, (value, count) in enumerate(zip(sorted_data, top_counts)):
        ax.text(
            value,
            i,
            f' {value}, {count}',
            va='center',
            ha='left',
            color='black'
        )

    # Save bar graph as image
    if save_as_images:
        bar_image_path = f"{image_base}_bar.png"
        fig_bar.savefig(bar_image_path, bbox_inches="tight")

    # Prepare return values
    return_table = binary_table.T
    return_data = sorted_data

    return return_data, return_table


In [1]:
import itertools
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from upsetplot import UpSet, from_memberships, plot

# Configure environment variables
# The directories are joined to the existing PATH with os.pathsep so the
# first one is not fused onto the last existing entry, and no trailing empty
# entry (which a POSIX shell treats as the current directory) is added.
_HOMER_PATH_DIRS = [
    "/bioinformatics/homer/bin",
    "/usr/local/sbin",
    "/usr/sbin",
    "/usr/bin",
    "/usr/local/bin",
    "/usr/local/lib",
]
os.environ["PATH"] = os.pathsep.join(
    [entry for entry in [os.environ.get("PATH", "")] + _HOMER_PATH_DIRS if entry]
)

def read_motif_headers(motif_file):
    """Return motif names from a motif file, choosing the parser by file type.

    Files that ``is_denovo_motif_file`` recognises are parsed with
    ``read_motif_headers_homer``; all others with ``read_motif_headers_jaspar``.
    The choice is printed to stdout.
    """
    # Guard clause for the de novo case; everything else falls through.
    if is_denovo_motif_file(motif_file):
        print('DENOVO HOMER')
        return read_motif_headers_homer(motif_file)

    print('Not DENOVO HOMER')
    return read_motif_headers_jaspar(motif_file)

def is_denovo_motif_file(filename):
    """
    Scan a file for a '>' header line that contains "BestGuess:".

    Reading stops at the first such header. A missing file or any other
    error is reported with a printed message and treated as "not found".

    :param filename: str, the path to the file to be read
    :return: bool, True if any header line contains "BestGuess:", otherwise False
    """
    try:
        with open(filename, 'r') as motif_handle:
            # Look only at header lines, stopping at the first tagged one.
            header_lines = (text for text in motif_handle if text.startswith('>'))
            for header in header_lines:
                if "BestGuess:" in header:
                    return True
    except FileNotFoundError:
        print(f"Error: The file '{filename}' does not exist.")
        return False
    except Exception as e:
        print(f"An error occurred: {e}")
        return False
    return False

def read_motif_headers_jaspar(motif_file):
    """Read the motif names from the '>' header lines of a motif file.

    A name is the run of word characters right after '>' and before the
    first tab. Headers that do not match that shape are skipped. The list of
    names is printed before it is returned.
    """
    names = []

    with open(motif_file, 'r') as handle:
        for raw in handle:
            # Only header lines are candidates for a name.
            found = re.match(r'^>(\w+)\t', raw) if raw.startswith('>') else None
            if found is not None:
                names.append(found.group(1))

    print(names)
    return names

def read_motif_headers_homer(motif_file):
    """Collect the "BestGuess:" names from the '>' header lines of a file.

    For each occurrence of ``BestGuess:`` in a header, the text up to the
    next '/' (or the end of the stripped line) is taken as a name.
    """
    with open(motif_file, 'r') as handle:
        header_lines = [row.strip() for row in handle if row.startswith('>')]

    # Flatten the per-header matches into a single list, in file order.
    return [
        name
        for header in header_lines
        for name in re.findall(r"BestGuess:([^/]+)", header)
    ]

def get_motif_count(peak_file, genome, motif_file, force=False):
    """Get motif count using HOMER annotatePeaks.

    The result is cached beside the peak file as
    ``motifCount_counts_<peak basename>_<motif basename>.txt``. If that file
    does not exist, or ``force`` is truthy, ``annotatePeaks.pl`` (which must
    be on PATH) is run and its standard output becomes the cache file.

    The command runs without a shell, so paths with spaces or shell
    metacharacters are passed unchanged. Output goes to a temporary file that
    replaces the cache only after the command succeeds; a failed run raises
    and leaves any previous cache untouched.

    :param peak_file: str, path to the peak file
    :param genome: str, genome argument handed to annotatePeaks.pl
    :param motif_file: str, path to the motif file (passed with ``-m``)
    :param force: bool, regenerate the cache even if it exists
    :return: tuple ``(table, motif_count_file)``; ``table`` is the cache file
             loaded with ``pd.read_table``
    :raises subprocess.CalledProcessError: annotatePeaks.pl exited non-zero
    :raises OSError: annotatePeaks.pl could not be started or the output
                     could not be written
    """
    import subprocess  # local import keeps this function self-contained

    base_file = os.path.splitext(os.path.basename(peak_file))[0]
    motif_base_file = os.path.splitext(os.path.basename(motif_file))[0]

    base_path = os.path.dirname(os.path.abspath(peak_file))
    motif_count_file = os.path.join(base_path, f"motifCount_counts_{base_file}_{motif_base_file}.txt")

    print('Motif_Count_File=', motif_count_file, 'Force=', force)
    command = [
        'annotatePeaks.pl', str(peak_file), str(genome),
        '-cpu', '12', '-noann', '-nogene',
        '-m', str(motif_file), '-nmotifs',
    ]

    have_cache = os.path.exists(motif_count_file)
    if not have_cache:
        print('FILE DOES NOT EXIST:', motif_count_file)

    if not have_cache or force:
        print(' '.join(command) + ' > ' + motif_count_file)
        print('Generating ' + motif_count_file)
        scratch_file = motif_count_file + '.tmp'
        try:
            with open(scratch_file, 'w') as out:
                subprocess.run(command, stdout=out, check=True)
            os.replace(scratch_file, motif_count_file)
        except (OSError, subprocess.CalledProcessError):
            # Do not leave a half-written result behind.
            if os.path.exists(scratch_file):
                os.remove(scratch_file)
            raise
    else:
        print('** Loading ** ' + motif_count_file)

    return pd.read_table(motif_count_file, sep="\t"), motif_count_file

def transform_combinations_to_matrix(series, peak_file):
    """Transform motif combinations into a binary matrix.

    Each index label of ``series`` is a comma-separated list of motif names
    (surrounding whitespace is ignored) and each value is that combination's
    count. The returned DataFrame has one row per label, in Series order,
    with these columns:

    * one 0/1 column per distinct motif name, in sorted order;
    * ``Count`` - the Series value;
    * ``Norm_Count`` - ``Count`` divided by the number of rows returned by
      ``pd.read_csv(peak_file, sep='\\t')``;
    * ``Motif Subset`` - the stripped names, sorted and joined with ',';
    * ``code-<peak_file>`` - the flags read as a binary number whose most
      significant bit is the first motif column.

    Motif columns stay in sorted order so that both the layout and the
    binary code are the same on every run.

    :param series: pd.Series indexed by combination label
    :param peak_file: str, path to the tab-separated peak file
    :return: pd.DataFrame as described above
    """
    # Split and strip each label once.
    label_members = [
        [name.strip() for name in combo.split(',')] for combo in series.index
    ]
    motif_list = sorted({name for members in label_members for name in members})

    peak_data = pd.read_csv(peak_file, sep='\t')
    total_sum = len(peak_data)

    code_column = f'code-{peak_file}'
    records = []
    for members, count in zip(label_members, series.values):
        present = set(members)
        flags = [1 if motif in present else 0 for motif in motif_list]

        binary_code = 0
        for flag in flags:
            binary_code = (binary_code << 1) | flag  # earlier motif = higher bit

        record = dict(zip(motif_list, flags))
        record['Count'] = count
        record['Norm_Count'] = count / total_sum
        record['Motif Subset'] = ','.join(sorted(members))
        record[code_column] = binary_code
        records.append(record)

    columns = motif_list + ['Count', 'Norm_Count', 'Motif Subset', code_column]
    return pd.DataFrame(records, columns=columns)

def create_motif_network(data0, min_node_size=1000, max_node_size=5000,
                         min_motif_set_count=None, fixed_edge_width=1):
    """
    Draw a network graph of motif co-occurrences and show it.

    Rows with ``Count`` <= 0 are dropped, then rows whose sum over all columns
    except the last is not greater than ``min_motif_set_count``. In each
    remaining row, every column equal to 1 (other than 'Count' and
    'normalized_count') counts as a motif present in that row. Every pair of
    motifs present together is linked by an edge whose weight accumulates the
    row's ``Count``. Motifs that never appear together with another motif get
    no node.

    Args:
        data0 (pd.DataFrame): DataFrame with 0/1 motif columns and a 'Count'
            column. The input is not modified.
        min_node_size (int, optional): Size of the smallest node. Defaults to 1000.
        max_node_size (int, optional): Size of the largest node. Defaults to 5000.
        min_motif_set_count (number, optional): Row-sum threshold described
            above. When None, a module-level variable of the same name is used
            if one exists (the value this function used to read implicitly),
            otherwise 0.
        fixed_edge_width (number, optional): Width used for every edge. Pass
            None to use each edge's accumulated weight instead. Defaults to 1.

    Returns:
        None
    """
    if min_motif_set_count is None:
        min_motif_set_count = globals().get('min_motif_set_count', 0)

    data = data0[data0['Count'] > 0]
    data = data[data.iloc[:, :-1].sum(axis=1) > min_motif_set_count].copy()

    max_count = data['Count'].max()
    data['normalized_count'] = data['Count'] / max_count

    motif_columns = [
        col for col in data.columns if col not in ('Count', 'normalized_count')
    ]

    # One pass builds the edges and each motif's total normalized count.
    G = nx.Graph()
    node_weights = defaultdict(int)
    for _, row in data.iterrows():
        motifs = [col for col in motif_columns if row[col] == 1]
        for motif in motifs:
            node_weights[motif] += row['normalized_count']
        for i, motif1 in enumerate(motifs):
            for motif2 in motifs[i + 1:]:
                if G.has_edge(motif1, motif2):
                    G[motif1][motif2]['weight'] += row['Count']
                else:
                    G.add_edge(motif1, motif2, weight=row['Count'])

    if G.number_of_nodes() == 0:
        # No co-occurring motifs: nothing to lay out or scale.
        plt.title('Motif Co-occurrence Network')
        plt.show()
        return

    # Scale node sizes between min_node_size and max_node_size
    min_weight = min(node_weights.values())
    max_weight = max(node_weights.values())
    weight_span = max_weight - min_weight
    if weight_span == 0:
        # Equal weights would give 0/0; draw every node at the minimum size.
        node_size = [min_node_size for _ in G.nodes()]
    else:
        node_size = [
            min_node_size
            + (node_weights[motif] - min_weight) / weight_span * (max_node_size - min_node_size)
            for motif in G.nodes()
        ]

    # Calculate edge widths
    if fixed_edge_width is not None:
        edge_width = [fixed_edge_width for _ in G.edges()]
    else:
        edge_width = [G[u][v]['weight'] for u, v in G.edges()]

    # Labels show the motif name and its total count (normalized weight
    # scaled back by max_count).
    labels = {motif: f"{motif}: \n {int(node_weights[motif] * max_count)}" for motif in G.nodes()}

    # Fixed seed keeps the layout reproducible.
    pos = nx.spring_layout(G, k=10/np.sqrt(G.order()), scale=2, iterations=2000, seed=1234)

    nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='skyblue', edgecolors='red', linewidths=1.5)
    nx.draw_networkx_edges(G, pos, width=edge_width, alpha=0.5)
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=8)

    plt.title('Motif Co-occurrence Network')
    plt.show()

def all_possible_combinations33(df, min_set_count=0):
    """Generate all possible motif combinations with their exact-match counts.

    A row matches a combination when every motif in it is non-zero and every
    other column is zero. The returned Series contains:

    * ``'0_Set'`` first - rows in which every column is zero (always present,
      whatever ``min_set_count`` is);
    * then one entry for each combination of at least ``min_set_count``
      columns, keyed by the column names joined with ',', ordered by size and
      then by column order. Combinations that match no row get 0.

    Row patterns are tallied in a single pass and each combination is then a
    dictionary lookup, rather than a full DataFrame scan per combination.
    Column names are assumed to be unique strings.

    :param df: pd.DataFrame of presence flags, one column per motif
    :param min_set_count: smallest combination size to report
    :return: pd.Series of counts keyed as described above
    """
    from collections import Counter  # local import keeps this function self-contained

    motifs = df.columns.tolist()
    n_motifs = len(motifs)

    # Tally each row's presence pattern, e.g. (True, False, True).
    pattern_counts = Counter(map(tuple, df.ne(0).to_numpy().tolist()))

    combination_counts = {'0_Set': pattern_counts.get((False,) * n_motifs, 0)}

    for r in range(1, n_motifs + 1):
        if r < min_set_count:
            continue
        for chosen in itertools.combinations(range(n_motifs), r):
            pattern = [False] * n_motifs
            for position in chosen:
                pattern[position] = True
            key = ','.join(motifs[position] for position in chosen)
            combination_counts[key] = pattern_counts.get(tuple(pattern), 0)

    return pd.Series(combination_counts)

def calculate_all_motif_co_occurrences(df, max_distance):
    """Calculate all motif co-occurrences within a specified distance.

    The first column of ``df`` is the peak ID; the other columns are motifs.
    For each row, ``find_co_occurring_motifs(<motif values>, max_distance)``
    is called (defined outside this block; presumably it returns an iterable
    of motif-name collections - confirm against its definition). Every
    occurrence it returns becomes one output row with a 1 in the column of
    each motif it names. Names that are not columns of ``df`` are left out of
    the row but are still part of the key counted in ``category_counts``.

    Peak IDs must end in ``-<integer>``; the k-th occurrence of a peak is
    named ``<prefix>-<integer + k>``.

    :param df: pd.DataFrame, first column peak ID, other columns motifs
    :param max_distance: passed through to ``find_co_occurring_motifs``
    :return: tuple ``(binary_df, category_counts)``; ``binary_df`` has a
             'PeakID' column followed by the motif columns, and
             ``category_counts`` maps ``frozenset(occurrence)`` to its tally
    """
    category_counts = defaultdict(int)

    # Collect the original order of motifs (first column is the peak ID).
    original_motifs = df.columns[1:].tolist()
    motif_columns = {motif: idx for idx, motif in enumerate(original_motifs)}

    binary_matrix = []
    peak_names = []

    for _, row in df.iterrows():
        # Use .iloc: the row Series is labelled by column name, and bare
        # integer keys rely on a deprecated positional fallback.
        co_occurrences = find_co_occurring_motifs(row.iloc[1:], max_distance)
        if not co_occurrences:
            continue

        id_parts = row.iloc[0].rsplit('-', 1)
        base_peak_id = id_parts[0]      # text before the last '-'
        start_index = int(id_parts[1])  # number after the last '-'

        for k, occurrence in enumerate(co_occurrences):
            binary_row = [0] * len(original_motifs)
            for motif in occurrence:
                if motif in motif_columns:
                    binary_row[motif_columns[motif]] = 1
            category_counts[frozenset(occurrence)] += 1

            binary_matrix.append(binary_row)
            peak_names.append(f"{base_peak_id}-{start_index + k}")

    # Create DataFrame from the binary matrix, peak names first.
    binary_df = pd.DataFrame(binary_matrix, columns=original_motifs)
    binary_df.insert(0, 'PeakID', peak_names)

    return binary_df, category_counts

def peak_motif_sets(
    peak_file,
    genome,
    motif_list_file,
    output_file=None,
    Motif_recount=False,
    seperate_duplicates=False,
    min_motif_set_count=0,
    min_subset_count=0
):
    """
    Count motif-set memberships across peaks and draw an UpSet plot and a bar graph.

    Parameters:
        peak_file (str): Path to the peak file.
        genome (str): Genome reference.
        motif_list_file (str): Path to the motif list file.
        output_file (str, optional): Path with an image extension
            (.png, .jpg, .jpeg, .tiff, .bmp, .gif). When given, the plots are
            written next to it as '<base>_upset.png' and '<base>_bar.png'.
            Any other extension prints a message and nothing is saved.
            Defaults to None (plots are created and closed without saving).
        Motif_recount (bool, optional): Passed as ``force`` to
            ``get_motif_count``. Defaults to False.
        seperate_duplicates (bool, optional): When True, peaks with more than
            two hits of a motif are moved from that motif into a new
            '<motif>_3+' category. Defaults to False.
        min_motif_set_count (int, optional): Membership threshold. The UpSet
            plot keeps sets with strictly more motifs than this value; the bar
            graph keeps combinations with at least this many motifs.
            Defaults to 0.
        min_subset_count (int, optional): Passed to upsetplot as
            ``min_subset_size``. Defaults to 0.

    Returns:
        tuple: ``(percentages, binary_table)`` where ``percentages`` is a
        Series of the (at most) 50 largest motif combinations expressed as a
        percentage of all peaks, truncated to two decimals and sorted
        ascending, and ``binary_table`` is the peaks x motifs 0/1 table.
    """
    # Read motif headers
    headers = read_motif_headers(motif_list_file)

    # Get motif count (the second element, the count file path, is not needed here)
    motif_count_table, _ = get_motif_count(
        peak_file, genome, motif_list_file, force=Motif_recount
    )

    # Keep the trailing motif columns and transpose to motifs x peaks
    formatted_table = motif_count_table.iloc[:, -len(headers):].T
    formatted_table.columns = motif_count_table.iloc[:, 0]
    formatted_table.columns.name = 'PeakID'
    formatted_table.index = headers

    # Handle duplicate motifs if specified
    if seperate_duplicates:
        # NOTE(review): the original loop iterated over a name `subset_motifs`
        # that is not defined in this function. Honour a notebook-level
        # variable of that name if one exists; otherwise split every motif
        # instead of failing with NameError.
        duplicate_targets = globals().get('subset_motifs')
        if duplicate_targets is None:
            duplicate_targets = headers
        formatted_table = formatted_table.T
        for item in list(duplicate_targets):
            condition = formatted_table[item] > 2
            formatted_table.loc[condition, item] = 0
            formatted_table[f'{item}_3+'] = 1 * condition
        formatted_table = formatted_table.T
        headers = formatted_table.index

    # Binary presence table: a motif is present when its count is >= 0.5
    binary_table = (formatted_table >= 0.5).astype(int)
    df = binary_table.T  # peaks x motifs; columns already equal `headers`

    # Generate UpSet data: number of peaks per exact motif-membership pattern
    motif_names = df.columns.tolist()
    df_upset = df.groupby(motif_names).size().reset_index(name='count')
    df_upset['num_motifs'] = df_upset.drop('count', axis=1).sum(axis=1)
    upset_data = (
        df_upset[df_upset['num_motifs'] > min_motif_set_count]
        .set_index(motif_names)['count']
    )

    # Decide whether plots are saved, based on the output_file extension
    save_as_images = False
    image_base = None
    image_extensions = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')

    if output_file:
        image_base, ext = os.path.splitext(output_file)
        ext = ext.lower()
        if ext in image_extensions:
            save_as_images = True
            # os.makedirs('') raises, so only create a directory when the
            # path actually has a directory component.
            output_dir = os.path.dirname(output_file)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
        else:
            print(f"Unsupported output file extension: {ext}. Plots will not be saved.")

    # Generate UpSet plot
    fig_upset = plt.figure(figsize=(20, 20), constrained_layout=True)
    plot(
        upset_data,
        fig=fig_upset,
        orientation="horizontal",
        show_counts=True,
        min_subset_size=min_subset_count,
        sort_categories_by=None
    )

    title_txt = (
        f'Set intersections: [Membership > {min_motif_set_count}] '
        f'[Counts > {min_subset_count}]\n'
        f'PEAKS: [{peak_file}]\n'
        f'MOTIFS: [{motif_list_file}]\n'
    )
    fig_upset.suptitle(title_txt, fontsize=16)

    # Save UpSet plot as image
    if save_as_images:
        fig_upset.savefig(f"{image_base}_upset.png", bbox_inches="tight")
    plt.close(fig_upset)  # Close the figure to free memory

    # Count peaks per motif combination
    upset_data1 = all_possible_combinations33(df, min_set_count=min_motif_set_count)

    # Re-key by (cardinality, sorted motif tuple) to filter and order the combinations
    upset_data1.index = pd.MultiIndex.from_tuples(
        [
            (len(x.split(',')), tuple(sorted(x.split(','))))
            for x in upset_data1.index
        ],
        names=['Cardinality', 'Motifs']
    )
    upset_data1 = upset_data1[
        upset_data1.index.get_level_values('Cardinality') >= min_motif_set_count
    ]
    upset_data1 = upset_data1.sort_index(ascending=False)
    upset_data1.index = [', '.join(motifs) for _, motifs in upset_data1.index]

    # Keep the 50 largest combinations (ascending, so the largest bar is on top)
    top_counts = upset_data1.sort_values(ascending=True).tail(50)
    total_sum = len(df)

    # NOTE(review): the result is not used here; the call is kept because the
    # helper is defined elsewhere and may have side effects -- confirm.
    transform_combinations_to_matrix(top_counts, peak_file)

    # Percentage of all peaks, truncated (not rounded) to two decimals
    top_percent = (top_counts / total_sum * 10000).astype(int) / 100

    # Generate normalized bar graph
    fig_bar = plt.figure(figsize=(20, 30), constrained_layout=True)
    ax = top_percent.plot(kind='barh', color='green', ax=fig_bar.gca())
    ax.set_title(
        f'PEAKS [{total_sum}] : {peak_file}\n'
        f'MOTIFS: [{motif_list_file}]',
        fontsize=16
    )
    ax.set_ylabel(
        'Top 50 co-occurring motif groups with # of motifs >= '
        f'{min_motif_set_count}'
    )
    ax.set_xlabel('Co-occurrences (as % ratio of total peaks)')

    # Label each bar with its percentage and the exact peak count; the count
    # comes from the raw data rather than being rebuilt from the truncated
    # percentage, which could be off for small shares.
    for i, (value, count) in enumerate(zip(top_percent, top_counts)):
        ax.text(
            value,
            i,
            f' {value}, {int(count)}',
            va='center',
            ha='left',
            color='black'
        )

    # Save bar graph as image
    if save_as_images:
        fig_bar.savefig(f"{image_base}_bar.png", bbox_inches="tight")
    plt.close(fig_bar)  # Close the figure to free memory

    return top_percent, binary_table.T


In [2]:
from upsetplot import plot
import matplotlib.pyplot as plt
import os
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from PyPDF2 import PdfReader, PdfWriter

# -------------------------
# Configuration and Setup
# -------------------------

# Input files are read from one tree, results are written to another
base_dir_in = '/home/psaisan/MCAT/'
base_dir_out = '/home/psaisan/CARTMAN/'

# Genome reference
GENOME = 'mm10'

# Motif list shared by both runs
motif_list_file = os.path.join(base_dir_in, 'KCH_VS_KCN_w400_L70.motifs')

# The two peak sets to analyze
peak_file1 = os.path.join(base_dir_in, 'KCN_H3K_FC2_1000_w200.txt')
peak_file2 = os.path.join(base_dir_in, 'KCH_H3K_FC2_1000_w200.txt')

# Output directories
output_directory1 = os.path.join(base_dir_out, 'Example_Peak1/')
output_directory2 = os.path.join(base_dir_out, 'Example_Peak2/')

# Tab-separated binary tables, one per peak set
csv_out_file1 = os.path.join(base_dir_out, 'Peak1_table.txt')
csv_out_file2 = os.path.join(base_dir_out, 'Peak2_table.txt')

# Figure base paths; peak_motif_sets derives <base>_upset.png and <base>_bar.png
fig_output_path1 = os.path.join(base_dir_out, 'Peak1.png')
fig_output_path2 = os.path.join(base_dir_out, 'Peak2.png')

# Analysis options (all left at their defaults)
motif_recount = False
seperate_duplicates = False
min_motif_set_count = 0
min_subset_count = 0


def _analyze_and_export(peak_file, figure_path, table_path):
    """Run peak_motif_sets on one peak file and write its binary table as TSV.

    Returns the (summary, binary_table) pair produced by peak_motif_sets.
    """
    summary, binary_table = peak_motif_sets(
        peak_file=peak_file,
        genome=GENOME,
        motif_list_file=motif_list_file,
        Motif_recount=motif_recount,
        seperate_duplicates=seperate_duplicates,
        min_motif_set_count=min_motif_set_count,
        min_subset_count=min_subset_count,
        output_file=figure_path,
    )
    print('Save results in :', table_path)
    binary_table.to_csv(table_path, sep='\t')
    return summary, binary_table


# -------------------------
# Processing Peak File 1
# -------------------------
dp1, bdf1 = _analyze_and_export(peak_file1, fig_output_path1, csv_out_file1)

# -------------------------
# Processing Peak File 2
# -------------------------
dp2, bdf2 = _analyze_and_export(peak_file2, fig_output_path2, csv_out_file2)


DENOVO HOMER
Motif_Count_File= /home/psaisan/MCAT/motifCount_counts_KCN_H3K_FC2_1000_w200_KCH_VS_KCN_w400_L70.txt Force= False
** Loading ** /home/psaisan/MCAT/motifCount_counts_KCN_H3K_FC2_1000_w200_KCH_VS_KCN_w400_L70.txt


  fig_upset.savefig(upset_image_path, bbox_inches="tight")


Save results in : /home/psaisan/CARTMAN/Peak1_table.txt
DENOVO HOMER
Motif_Count_File= /home/psaisan/MCAT/motifCount_counts_KCH_H3K_FC2_1000_w200_KCH_VS_KCN_w400_L70.txt Force= False
** Loading ** /home/psaisan/MCAT/motifCount_counts_KCH_H3K_FC2_1000_w200_KCH_VS_KCN_w400_L70.txt


  fig_upset.savefig(upset_image_path, bbox_inches="tight")


Save results in : /home/psaisan/CARTMAN/Peak2_table.txt


In [None]:
# Inspect results for peak file 1: the normalized combination table first,
# then the binary table reloaded from the tab-separated file written earlier.
divider = '==='

print(divider)
print('Motif co-occurance counts (normalized %) for ' + peak_file1)
print(divider)
print(dp1)

print(divider)
print('Outuput File : ', csv_out_file1)
print(divider)
peak1_table = pd.read_csv(csv_out_file1, sep='\t')
display(peak1_table)
