In [None]:
# imports
import matplotlib.pyplot as plt
import pandas as pd
import glob
import os
import csv
import re
import more_itertools as mit
from tqdm import tqdm

In [None]:
def delete_keys(d):
    """Prune incomplete annotations from a nested annotation dict, in place.

    ``d`` maps an event type to a dict of children, and each child is
    expected to be a mapping holding ``'start'`` and ``'stop'`` entries.

    * Every child missing ``'start'`` or ``'stop'`` is removed.
    * Every event type that has no children left afterwards (including ones
      that were empty to begin with) is removed.

    Args:
        d: dict of ``{event_type: {child_key: {'start': ..., 'stop': ...}}}``.
            May be empty.

    Returns:
        None. ``d`` is modified in place.
    """
    # Iterate over a snapshot of the keys because events may be deleted
    # from ``d`` while looping.
    for event in list(d):
        children = d[event]
        incomplete = [
            key for key in children
            if 'start' not in children[key] or 'stop' not in children[key]
        ]
        for key in incomplete:
            del children[key]
        # Checked after pruning so an event whose children were all
        # incomplete is dropped as well.
        if not children:
            del d[event]

def extract_all_images():
    """Build a Record for every annotation CSV and save its UC images.

    The record identifier is taken from each matched path as the four
    characters that sit immediately before the trailing ``.csv``.
    """
    record_nums = []
    for annotation_path in glob.glob(f'../../data/annotations/csv/annotation_*.csv'):
        record_nums.append(annotation_path[-8:-4])

    for record_num in tqdm(record_nums):
        # No name is bound to the Record, so it can be released as soon as
        # its images have been written.
        Record(record_num).saveImages()

def clear_all_images():
    """Build a Record for every annotation CSV and delete its saved images.

    The record identifier is taken from each matched path as the four
    characters that sit immediately before the trailing ``.csv``.
    """
    record_nums = []
    for annotation_path in glob.glob(f'../../data/annotations/csv/annotation_*.csv'):
        record_nums.append(annotation_path[-8:-4])

    for record_num in tqdm(record_nums):
        # No name is bound to the Record, so it can be released as soon as
        # its images have been removed.
        Record(record_num).clearImages()
        
def get_class_counts():
    """Print the number of saved ``record_*.png`` images per label directory.

    One ``LABEL: count`` line is emitted for each label type, looked up under
    ``../images/<LABEL>/``.
    """
    label_types = ['ACC', 'BRADY', 'DEC_EARLY', 'DEC_LATE', 'DEC_PROLONG', 'DEC_VAR', 'NONE', 'TACHY']

    report_lines = []
    for label_type in label_types:
        matches = glob.glob(f"../images/{label_type}/record_*.png", recursive=True)
        report_lines.append(f'{label_type}: {len(matches)}\n')

    # Each line already ends with a newline, so print() adds one blank line.
    print(''.join(report_lines))

In [None]:
"""Accessing and interacting with Record files"""

class Record:
    """One recording: its signal table plus the events annotated on it.

    Construction reads ``../../data/database/signals/<record_name>.csv`` into
    a DataFrame and parses ``../../data/annotations/csv/annotation_<record_name>.csv``
    into ``self.ann``, a dict of ``{event_type: {marker_key: {...}}}``. Each
    marker entry holds ``'start'`` and, once the closing marker has been seen,
    ``'stop'``; both are zero-based row numbers of the annotation file. UC
    entries additionally receive an integer ``'label'``.
    """

    # (event type, regex) pairs for markers whose event type is fixed by the
    # marker name. A leading "(" opens an event and a leading ")" closes it.
    _SIMPLE_MARKER_PATTERNS = (
        ('UC', r'[\(\)]UC\d+'),
        ('ACC', r'[\(\)]ACC\d+'),
        ('TACHY', r'[\(\)]TC\d+'),
        ('BRADY', r'[\(\)]BC\d+'),
    )
    # Deceleration markers: the event type is chosen by the LAST character of
    # the match via ``self.decel_map``.
    _DECEL_MARKER_PATTERN = r'[\(\)]DEC\w+'

    def __init__(self, record_name: str) -> None:
        """Load the signal CSV and the annotations for ``record_name``.

        Raises whatever ``pd.read_csv`` / ``open`` raise when either file is
        missing.
        """
        self.record_name = record_name

        # Parsed annotations, keyed by event type and then by marker key.
        self.ann = { 'UC': {},
                     'DEC_EARLY': {},
                     'DEC_VAR': {},
                     'DEC_LATE': {},
                     'DEC_PROLONG': {},
                     'ACC': {},
                     'TACHY': {},
                     'BRADY': {}
                     }

        # Event type -> integer class label stored on each UC.
        self.label_dict = { 'NONE': 0,
                     'DEC_EARLY': 1,
                     'DEC_VAR': 2,
                     'DEC_LATE': 3,
                     'DEC_PROLONG': 4,
                     'ACC': 5,
                     'TACHY': 6,
                     'BRADY': 7
        }

        # Inverse of ``label_dict``, keyed by the label rendered as a string.
        self.label_map_from_int = { '0': 'NONE',
                     '1':'DEC_EARLY',
                     '2':'DEC_VAR',
                     '3':'DEC_LATE',
                     '4':'DEC_PROLONG',
                     '5':'ACC',
                     '6':'TACHY',
                     '7':'BRADY'
        }

        # Final character of a DEC marker -> deceleration event type.
        self.decel_map = {'E':'DEC_EARLY',
              'V':'DEC_VAR',
              'L':'DEC_LATE',
              'P':'DEC_PROLONG'}

        # The literal '0.0' is read as NaN, i.e. treated as missing data.
        self._signalDf = pd.read_csv(f'../../data/database/signals/{record_name}.csv', na_values=['0.0'])

        # Call to get annotations
        self.__getannotations(self.record_name)


    def __getannotations(self, record_name) -> None:
        """Parse the annotation CSV into ``self.ann`` and label the UCs.

        Every row is joined into one string and scanned for opening ``(``
        and closing ``)`` markers. The row number recorded is the zero-based
        position of the row as yielded by ``csv.reader``.

        Raises:
            KeyError: if a DEC marker ends in a character that is not a key
                of ``self.decel_map``.
        """
        with open(f'../../data/annotations/csv/annotation_{record_name}.csv', newline='',
        encoding='UTF-8') as csvfile:
            annreader = csv.reader(csvfile, delimiter=',')
            for row_idx, row in enumerate(annreader):
                joined_row = ''.join(row)

                for event, pattern in self._SIMPLE_MARKER_PATTERNS:
                    for marker in re.findall(pattern, joined_row):
                        self._store_marker(event, marker, row_idx)

                for marker in re.findall(self._DECEL_MARKER_PATTERN, joined_row):
                    self._store_marker(self.decel_map[marker[-1]], marker, row_idx)

        # Call to add labels
        if self.ann['UC']:
            self.add_labels(delta=0)


    def _store_marker(self, event, marker, row_idx) -> None:
        """Record one opening or closing marker under ``self.ann[event]``.

        ``marker`` is the regex match including its leading parenthesis; the
        remainder is used as the key. An opening marker (re)creates the entry
        with its start row. A closing marker sets the stop row only when the
        matching entry already exists, so unmatched closers are ignored.
        """
        key = marker[1:]
        if marker[0] == '(':
            self.ann[event][key] = {'start': row_idx}
        elif key in self.ann[event]:
            self.ann[event][key]['stop'] = row_idx


    def add_labels(self, delta=0):
        """Assign an integer ``'label'`` to every UC in ``self.ann['UC']``.

        A UC with both start and stop is labelled with the class of an event
        whose start or stop row falls inside ``[uc_start - delta, uc_stop + delta)``.
        When several events qualify, the last one visited wins (event types
        are visited in ``self.ann`` order). UCs with no qualifying event, or
        without a start/stop pair, get label 0.

        The label is recomputed from scratch on every call, so calling again
        (e.g. with another ``delta``) never leaves a stale label behind.

        Args:
            delta: number of rows by which the UC window is widened on each side.
        """
        for uc_key, uc in self.ann['UC'].items():
            label = 0
            if self.event_child_exists('UC', uc_key):
                window_lo = uc['start'] - delta
                window_hi = uc['stop'] + delta
                for event_type, events in self.ann.items():
                    if event_type == 'UC':
                        continue
                    for event_key, event in events.items():
                        if not self.event_child_exists(event_type, event_key):
                            continue
                        # NOTE(review): an event that starts before the window
                        # and stops after it (fully containing the UC) does not
                        # match this test - confirm whether that is intended.
                        starts_inside = window_lo <= event['start'] < window_hi
                        stops_inside = window_lo <= event['stop'] < window_hi
                        if starts_inside or stops_inside:
                            label = self.label_dict[event_type]
            uc['label'] = label


    def plotUC(self, ucNum: int):
        """Show the FHR/UC plot for contraction number ``ucNum``.

        Nothing is plotted when that UC is unknown or lacks a start/stop pair.
        """
        plot_id = 'UC' + str(ucNum)
        if self.event_child_exists('UC', plot_id):
            uc = self.ann['UC'][plot_id]
            self.createPlot(uc['start'], uc['stop'], plot_id)
            plt.show()


    def createPlot(self, start: int, end: int, plotID: str):
        """Create a two-row figure (FHR on top, UC below) for rows ``start:end``.

        Missing samples in both signals are linearly interpolated before
        plotting. The figure is left as the current pyplot figure; nothing
        is returned.

        Args:
            start: first row of the signal table to plot.
            end: row at which the plotted slice stops (exclusive).
            plotID: identifier shown in the title after the record name.
        """
        x = self._signalDf['seconds'][start:end].to_numpy()
        y_uc = self._signalDf['UC'][start:end].interpolate(method='linear').to_numpy()
        y_fhr = self._signalDf['FHR'][start:end].interpolate(method='linear').to_numpy()

        # NOTE(review): the x-limits are derived from row numbers plus 2 while
        # x holds the 'seconds' column - confirm the two share the same scale.
        x_limits = (start + 2, end + 2)

        # FHR subplot
        fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True)
        axs[0].set_title(f'Record {self.record_name} {plotID}')
        axs[0].plot(x, y_fhr, '#1f77b4')
        axs[0].set_xlim(*x_limits)
        axs[0].set_ylabel('FHR')
        axs[0].set_ylim(0, 220)
        # Uterine Contraction subplot
        axs[1].plot(x, y_uc, '#ff7f0e')
        axs[1].set_xlim(*x_limits)
        axs[1].set_ylim(0, 140)
        axs[1].set_ylabel('Uterine Contraction')
        fig.set_size_inches(8, 4)


    def _savePlot(self, start, end, plotID, label):
        """Render rows ``start:end`` and save the figure under ``../images/<label>/``.

        The label directory is created when it does not exist yet. The figure
        is closed after saving so repeated calls do not accumulate figures.
        """
        self.createPlot(start, end, plotID)

        out_dir = f"../images/{label}"
        os.makedirs(out_dir, exist_ok=True)

        fig = plt.gcf()
        fig.set_size_inches(18, 8)
        fig.savefig(f"{out_dir}/record_{self.record_name}_{plotID}_{label}.png", bbox_inches='tight')
        plt.close(fig)


    def saveImages(self):
        """Save one image per usable UC, in the directory named after its label.

        A UC is skipped when it lacks a start/stop pair or when
        ``findMissing`` reports it. Prints a one-line status message; when
        the record has no UC annotations nothing is saved.
        """
        uc_events = self.ann.get('UC')
        if uc_events:
            missingdata = set(self.findMissing())

            for key, uc in uc_events.items():
                # key[2:] strips the 'UC' prefix, leaving the UC number.
                if int(key[2:]) not in missingdata and self.event_child_exists('UC', key):
                    label_name = self.label_map_from_int[str(uc['label'])]
                    self._savePlot(uc['start'], uc['stop'], key, label_name)

            print(f"Images for Record {self.record_name} have been saved.")
        else:
            print(f"There are no UCs for Record {self.record_name}.")


    def clearImages(self):
        """Delete every saved image of this record from all label directories.

        A file that cannot be removed is reported on stdout and skipped.
        """
        imgs = glob.glob(f"../images/*/record_{self.record_name}*.png", recursive=True)

        for img in imgs:
            try:
                os.remove(img)
            except OSError as e:
                print("Error: %s : %s" % (img, e.strerror))


    def findMissing(self, max_missing_samples=60):
        """List the UC numbers whose FHR has too long a run of missing samples.

        Only UCs with both start and stop are examined. A UC is reported when
        its FHR slice contains a run of consecutive NaN values longer than
        ``max_missing_samples``. The number reported is the one embedded in
        the UC key (``'UC7'`` -> 7), so it stays correct even when some UCs
        are incomplete or not numbered consecutively.

        Args:
            max_missing_samples: longest tolerated NaN run, in samples. The
                default of 60 is presumably 15 seconds at 4 samples per
                second - TODO confirm the sampling rate.

        Returns:
            list[int]: UC numbers, each at most once, in ``self.ann['UC']`` order.
        """
        flagged = []
        for uc_key, uc in self.ann['UC'].items():
            if not self.event_child_exists('UC', uc_key):
                continue
            fhr = self._signalDf['FHR'].iloc[uc['start']:uc['stop']]
            if self._longest_nan_run(fhr) > max_missing_samples:
                flagged.append(int(uc_key[2:]))
        return flagged


    @staticmethod
    def _longest_nan_run(values):
        """Return the length of the longest run of consecutive NaNs in a Series."""
        is_na = values.isna()
        if not is_na.any():
            return 0
        # The running count of non-missing samples stays constant across a
        # NaN run, so it identifies each run; summing the mask per run gives
        # the run lengths.
        run_id = (~is_na).cumsum()
        return int(is_na.groupby(run_id).sum().max())


    def event_child_exists(self, event, child):
        """Return True when ``self.ann[event][child]`` has both 'start' and 'stop'.

        Unknown event types or child keys yield False instead of raising.
        """
        try:
            entry = self.ann[event][child]
        except KeyError:
            return False
        return 'start' in entry and 'stop' in entry


In [None]:
# Delete the saved images of every annotated record.
clear_all_images()


In [None]:
# Save the UC images for every annotated record.
extract_all_images()

In [None]:
# Print how many saved images exist for each label type.
get_class_counts()