In [1]:
import numpy as np
import pyabf
import os

In [4]:
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import pybaselines
from sklearn.cluster import DBSCAN
from scipy import stats

In [16]:
def jbcd(folder,
         file,
         sweep_n: int,
         start_time,
         end_time,
         threshold):
    """Baseline-correct one sweep window with JBCD, save a plot, return the signal.

    The ABF file ``folder/file`` is opened, the samples of sweep ``sweep_n``
    between ``start_time`` and ``end_time`` (both inclusive) are passed to
    ``pybaselines.morphological.jbcd`` and the resulting ``'signal'`` entry is
    shifted so that its median is zero.  A figure of that trace is written to
    ``folder/Fig`` (the directory is created when missing).

    Parameters
    ----------
    folder : str or path-like
        Directory that contains the ABF file; figures go to ``folder/Fig``.
    file : str
        ABF file name. The part before the first ``'.'`` is used in the
        figure file name.
    sweep_n : int
        Sweep number handed to ``ABF.setSweep``.
    start_time, end_time : float
        Window limits on the ``ABF.sweepX`` axis. Each must coincide with a
        sample time (within ``np.isclose`` tolerance).
    threshold : float
        Only used to draw the dashed red threshold line on the figure.

    Returns
    -------
    numpy.ndarray
        Median-centred JBCD signal of the selected window.

    Raises
    ------
    IndexError
        If ``start_time`` or ``end_time`` matches no sample of ``ABF.sweepX``.
    """
    abf = pyabf.ABF(Path(folder).joinpath(file))

    def first_index_at(t_value, label):
        # Index of the first sample whose time is numerically equal to t_value.
        hits = np.where(np.isclose(abf.sweepX, t_value))[0]
        if hits.size == 0:
            raise IndexError(
                f'{label}={t_value} does not coincide with any sample time of {file}')
        return hits[0]

    # The indices are looked up on the time axis of the sweep that is active
    # right after loading, i.e. before setSweep() is called below.
    start_ix = first_index_at(start_time, 'start_time')
    end_ix = first_index_at(end_time, 'end_time')
    window = slice(start_ix, end_ix + 1)  # end sample is included

    abf.setSweep(sweep_n)
    raw_current = abf.sweepY[window]

    # pybaselines returns (baseline, params); the processed trace is params['signal'].
    fit = pybaselines.morphological.jbcd(raw_current, beta_mult=1.05, gamma_mult=0.95)
    jbcd_late = fit[1]['signal']
    jbcd_late = jbcd_late - np.median(jbcd_late)

    # Time axis re-based so the window starts at 0.
    t = abf.sweepX[window] - abf.sweepX[start_ix]

    fig_dir = Path(folder).joinpath('Fig')
    fig_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create directories

    plt.ioff()
    fig, ax = plt.subplots(figsize=(180, 4))
    ax.set_title(f'{file} Late period. Sweep number: {sweep_n}')
    ax.plot(t, jbcd_late, color='blue')
    ax.set_ylim(-1, 10)
    ax.set_xlabel('time, s')
    ax.set_ylabel('pA')
    ax.plot(t, 0*t, '--', color='green', label='baseline')
    ax.plot(t, 0*t+threshold, '--', color='red', label=f'threshold={threshold}')
    ax.legend(loc=1)
    ax.grid()
    # NOTE(review): the file name uses sweep_n+1 while the title shows sweep_n - confirm
    # which numbering is intended.
    fig.savefig(fig_dir.joinpath(f'{file.split(".")[0]} sweep {sweep_n+1}'))
    plt.close(fig)

    return jbcd_late
    
    
def re_write_ABF(folder,
                 file,
                 start_time,
                 end_time,
                 threshold = 1.5,
                 excluded:list=None):
    """Write a new ABF1 file holding the JBCD-processed window of every kept sweep.

    Each sweep of ``folder/file`` that is not listed in ``excluded`` is
    processed with ``jbcd`` (which also saves one figure per sweep); the
    resulting traces are stacked row-wise and written with
    ``pyabf.abfWriter.writeABF1`` to ``<name>new.abf``, where ``<name>`` is the
    part of ``file`` before the first ``'.'``.  The output path is relative to
    the current working directory, not to ``folder``.

    Parameters
    ----------
    folder : str or path-like
        Directory that contains the ABF file.
    file : str
        ABF file name.
    start_time, end_time : float
        Window limits forwarded to ``jbcd``.
    threshold : float, default 1.5
        Forwarded to ``jbcd`` (drawn on the per-sweep figures).
    excluded : list, optional
        Sweep numbers (as found in ``ABF.sweepList``) to leave out.
        ``None`` (the default) excludes nothing.

    Raises
    ------
    ValueError
        If no sweep is left after applying ``excluded``.
    """
    skipped = [] if excluded is None else excluded

    abf = pyabf.ABF(Path(folder).joinpath(file))
    # Sample rate derived from the spacing of the first two time points.
    freq = (abf.sweepX[1]-abf.sweepX[0])**-1

    kept_sweeps = [sweep for sweep in abf.sweepList if sweep not in skipped]
    if not kept_sweeps:
        raise ValueError(f'no sweeps of {file} left to write after exclusion')

    traces = [jbcd(folder, file, sweep, start_time, end_time, threshold)
              for sweep in kept_sweeps]

    # One row per sweep; np.vstack replaces the deprecated np.row_stack alias.
    I_array = np.vstack(traces)
    pyabf.abfWriter.writeABF1(sweepData=I_array,
                              filename=file.split(".")[0]+'new.abf',
                              sampleRateHz=freq)
    
    
def get_events(atf_file_name, save:bool=True):
    """Parse the event table of a tab-separated ATF file into a DataFrame.

    The first two lines of the file are skipped, the third line supplies the
    column titles and every following non-blank line is one record.  Only the
    first 11 tab-separated fields of each line are used.  Fields that parse as
    ``float`` are stored as floats, everything else is kept as text.

    Parameters
    ----------
    atf_file_name : str or path-like
        Path of the ATF file to read.
    save : bool, default True
        When true, the table is also written to ``events.csv`` in the current
        working directory.

    Returns
    -------
    pandas.DataFrame
        The parsed table.

    Raises
    ------
    ValueError
        If the file has fewer than three lines, i.e. no column-title line.
    """
    n_fields = 11     # number of leading columns that are kept
    title_line = 2    # zero-based index of the line holding the column titles
    # NOTE(review): assumes the ATF header has no optional records, so that the
    # titles are on the third line - confirm for files from other exporters.

    def to_number(text):
        # Numeric fields become floats; anything else stays a string.
        try:
            return float(text)
        except ValueError:
            return text

    cols = None
    records = []
    with open(atf_file_name, mode='r') as events:
        for line_no, record in enumerate(events):
            if line_no < title_line:
                continue
            # Drop the line terminator so it cannot stick to the last field.
            fields = record.rstrip('\r\n').split('\t')[:n_fields]
            if line_no == title_line:
                cols = [s.strip('"') for s in fields]
            elif record.strip():
                records.append([to_number(field) for field in fields])

    if cols is None:
        raise ValueError(
            f'{atf_file_name}: no column-title line found (expected on line 3)')

    # Build the frame once instead of concatenating row by row.
    data = pd.DataFrame(records)
    data = data.rename(columns=dict(zip(data.columns, cols)))

    if save:
        data.to_csv('events.csv', header=data.columns, index=False)

    return data
         
        
def late_activity(atf_file, plot=True, folder=None, file=None, n_channels=None):
    """Summarise shut periods and DBSCAN clusters of events per sweep.

    ``<atf_file>.atf`` is converted with ``get_events`` (which writes
    ``events.csv`` in the working directory) and read back.  For every trace
    number of the ABF file ``folder/file`` the start/end times of the events
    with ``Level == 0`` are collected.  Sweeps with more than one such event
    are clustered with ``DBSCAN(eps=2, min_samples=2)`` on the
    (start, end) pairs; a cluster figure is saved to ``folder/Fig`` whenever at
    least one cluster is found.  When ``plot`` is true, a scatter plot of the
    shut periods is additionally saved for every sweep that has any.

    Parameters
    ----------
    atf_file : str
        Path of the ATF file without the ``.atf`` extension.
    plot : bool, default True
        Whether to save the per-sweep shut-period scatter plots.  The DBSCAN
        cluster figures are saved regardless of this flag.
    folder, file, n_channels : optional
        Directory of the ABF file, ABF file name and the channel count used to
        normalise the rates.  When left as ``None`` each value is taken from
        the notebook-level variable of the same name, which is how the
        function obtained them before these parameters existed.

    Returns
    -------
    dict
        Cluster count, list of openings per cluster (cluster size + 1), number
        of events with ``Level > 0``, number of sweeps, and the cluster and
        opening counts divided by ``n_channels`` and by the number of sweeps.

    Raises
    ------
    NameError
        If one of ``folder``, ``file``, ``n_channels`` is neither passed nor
        defined at notebook level.
    """
    # Resolve the recording context: explicit arguments win, otherwise fall
    # back to the notebook-level variables the original code relied on.
    module_scope = globals()
    resolved = {}
    for name, value in (('folder', folder), ('file', file), ('n_channels', n_channels)):
        if value is None:
            if name not in module_scope:
                raise NameError(
                    f"late_activity needs '{name}': pass it as an argument "
                    f"or define it as a notebook-level variable")
            value = module_scope[name]
        resolved[name] = value
    folder, file, n_channels = resolved['folder'], resolved['file'], resolved['n_channels']

    get_events(f'{atf_file}.atf')
    df = pd.read_csv('events.csv')

    # The 'Trace' column is matched against the ABF sweep numbers shifted by one.
    sweep_list = np.array(pyabf.ABF(Path(folder).joinpath(file)).sweepList)+1

    fig_dir = Path(folder).joinpath('Fig')
    file_stem = file.split(".")[0]

    # Shut periods (Level == 0) grouped per trace, in sweep_list order.
    shut_events = df[df['Level'] == 0]
    dcts = []
    for trace in sweep_list:
        in_trace = shut_events[shut_events['Trace'] == trace]
        dcts.append({'shut_start': in_trace['Event Start Time (ms)'].tolist(),
                     'shut_end': in_trace['Event End Time (ms)'].tolist()})

    total_number_of_late_openings = int((df['Level'] > 0).sum())

    def plot_shut_periods(ix, ax):
        # Scatter of shut-period start vs. end for sweep position ix.
        starts = dcts[ix]['shut_start']
        ends = dcts[ix]['shut_end']
        colors = np.random.rand(len(starts))  # arbitrary per-point colours
        ax.scatter(starts, ends, c=colors, s=600, alpha=0.4, marker='o', linewidths=3)
        ax.set_xlabel('shut period start, ms', fontsize=14)
        ax.set_ylabel('shut period end, ms', fontsize=14)
        # 40 % margin around the data range on both axes.
        ax.set_xlim(min(starts) - min(starts)*0.4,
                    max(starts) + max(starts)*0.4)
        ax.set_ylim(min(ends) - min(ends)*0.4,
                    max(ends) + max(ends)*0.4)
        ax.set_title(f'Shut times in sweep {ix+1}', fontsize=14)

    def cluster_analysis(ix):
        # Cluster the (start, end) pairs of sweep position ix; returns the
        # number of clusters and, per cluster, its size + 1.
        X = np.array([dcts[ix]['shut_start'], dcts[ix]['shut_end']]).T
        db = DBSCAN(eps=2, min_samples=2).fit(X)
        labels = db.labels_

        core_samples_mask = np.zeros_like(labels, dtype=bool)
        core_samples_mask[db.core_sample_indices_] = True

        # Label -1 marks noise and is not counted as a cluster.
        cluster_labels, cluster_sizes = np.unique(labels[labels >= 0], return_counts=True)
        number_of_clusters = cluster_labels.size

        if number_of_clusters > 0:
            unique_labels = set(labels)
            colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
            plt.ioff()
            # Dedicated figure so nothing is drawn onto a figure that happens
            # to be current in the notebook.
            fig, ax = plt.subplots()
            for k, col in zip(unique_labels, colors):
                if k == -1:
                    # Black used for noise.
                    col = [0, 0, 0, 1]

                class_member_mask = labels == k
                # Core samples get large markers, border/noise samples small ones.
                for subset_mask, marker_size in ((core_samples_mask, 14),
                                                 (~core_samples_mask, 6)):
                    xy = X[class_member_mask & subset_mask]
                    ax.plot(
                        xy[:, 0],
                        xy[:, 1],
                        "o",
                        markerfacecolor=tuple(col),
                        markeredgecolor="k",
                        markersize=marker_size,
                    )

            ax.set_title(f'Shut times clusters in sweep {ix+1}', fontsize=14)
            fig_dir.mkdir(parents=True, exist_ok=True)
            fig.savefig(fig_dir.joinpath(f'DBSCAN Shut Periods in {file_stem}, sweep {ix+1}.tif'))
            plt.close(fig)

        # NOTE(review): openings per cluster is taken as (shut periods in the
        # cluster) + 1 - confirm this matches the intended burst definition.
        openins_in_clusters = cluster_sizes+1
        return number_of_clusters, list(openins_in_clusters)

    total_number_of_clusters = 0
    total_openings_in_clusters = []
    for ix, d in enumerate(dcts):
        # DBSCAN is only run when a sweep has at least two shut periods.
        if len(d['shut_start']) > 1:
            number_of_clusters, openins_in_clusters = cluster_analysis(ix)
            total_number_of_clusters += number_of_clusters
            total_openings_in_clusters.extend(openins_in_clusters)

    if plot:
        for ix, d in enumerate(dcts):
            if len(d['shut_start']) > 0:
                plt.ioff()
                fig, ax = plt.subplots(figsize=(18, 4))
                plot_shut_periods(ix, ax)
                fig_dir.mkdir(parents=True, exist_ok=True)
                fig.savefig(fig_dir.joinpath(f'Shut Periods in {file_stem} sweep {ix+1}.tif'))
                plt.close(fig)

    # NOTE(review): the rate keys say "seconds" but the divisor is the number
    # of sweeps - presumably each sweep lasts 1 s; confirm.
    return {'total number of clusters': total_number_of_clusters,
            'numbers of openings in clusters': total_openings_in_clusters,
            'total_number_of_late_openings': total_number_of_late_openings,
            'total number of 1 s sweeps': sweep_list.size,
            'bursts // channels // seconds': total_number_of_clusters/n_channels/sweep_list.size,
            'total_number_of_late_openings// channels // seconds': total_number_of_late_openings/n_channels/sweep_list.size}