In [None]:
from matplotlib import pyplot as plt
import seaborn as sns
import os 
import yaml 
import glob 
import pandas as pd
from burst import *
import openpyxl

class LabelAnalysis:
    """Per-video analysis of YOLO label files around stimulation periods.

    For one video the class
      * looks up the video's row in ``VideoInfo.csv`` (columns used:
        ``VideoName``, ``Frequency``, ``FrameWidth``, ``FrameHeight``, ``FPS``,
        ``StartFrame``, ``EndFrame``),
      * reads the YOLO-format ``*.txt`` label files written by the detector,
      * merges them with the stimulation on/off table provided by ``Burst``,
      * builds and exports a peri-event table/figure aligned to stimulus onset.
    """

    def __init__(self, video_name):
        """Load the class map and the metadata row for ``video_name``.

        Args:
            video_name: Name of the detection run / video (matched
                case-insensitively against the ``VideoName`` column).

        Raises:
            Exception: If the video is not listed (see ``check_video_name``).
        """
        self.video_name = video_name
        self.__path = r'C:\Users\jyc13\Documents\Python Scripts\openCV_eyemovement\yolov5\runs\detect\{0}\labels'.format(video_name)
        # Use a context manager so the YAML file handle is closed deterministically.
        with open('object_classes.yaml', 'r') as class_file:
            self.__object_classes = yaml.safe_load(class_file)
        self.read_video_info()
        self.check_video_name()

        # check_video_name() leaves exactly one row in this_info; read scalars
        # from that row instead of calling int()/float() on a whole Series
        # (deprecated in pandas 2.x).
        info_row = self.this_info.iloc[0]
        self.__video_w = int(info_row['FrameWidth'])  # horizontal resolution
        self.__video_h = int(info_row['FrameHeight'])  # vertical resolution
        self.__video_fps = float(info_row['FPS'])  # frames per second
        self.__start_frame = int(info_row['StartFrame'])
        self.__end_frame = int(info_row['EndFrame'])
        # TODO: pixel-per-mm calibration is not wired in yet
        # (placeholder value in an earlier draft: 10.2 pixel/mm).

    @property
    def get_start_frame(self):
        """First frame (inclusive) of the analysed range, from ``StartFrame``."""
        return self.__start_frame

    @property
    def get_end_frame(self):
        """Last frame (inclusive) of the analysed range, from ``EndFrame``."""
        return self.__end_frame

    def read_video_info(self):
        """Read the video metadata table into ``self.video_info``."""
        # Raw string: the backslashes are path separators, not escape sequences.
        file_path = r"D:\Research\SC\DATA\Eye_Movement\VideoInfo.csv"
        self.video_info = pd.read_csv(file_path)

    def check_video_name(self):
        """Select the metadata row of this video into ``self.this_info``.

        The name is compared case-insensitively and *exactly* against the
        ``VideoName`` of rows whose ``Frequency`` is 20. Using one exact match
        for both the existence check and the selection avoids accepting a name
        that is merely a substring of another video and then ending up with an
        empty selection.

        Raises:
            Exception: If no row, or more than one row, matches.
        """
        wanted = self.video_name.lower()
        names_lower = self.video_info['VideoName'].astype(str).str.lower()
        is_target = (self.video_info['Frequency'] == 20) & (names_lower == wanted)
        matches = self.video_info.loc[is_target, :]

        if matches.empty:
            raise Exception(f"There is no file named {self.video_name}.")
        if len(matches) > 1:
            raise Exception(
                f"There are {len(matches)} rows named {self.video_name}; expected exactly one."
            )

        self.this_info = matches
        display(self.this_info)

    def read_labels(self):
        """Read every label file of the video into one DataFrame.

        Each line of a label file is YOLO style,
        ``num_class center_x center_y w h``, with all values expressed as a
        fraction of the frame resolution. The frame number is taken from the
        text after the last ``_`` in the file name.

        Returns:
            DataFrame sorted by ``n_frame`` and ``num_class`` with columns
            ``n_frame, num_class, class_name, center_x, center_y, width,
            height, size`` (pixel units), plus
            ``pupil-eye_x`` / ``pupil-eye_y`` (difference between consecutive
            rows of the same frame; NaN for the first row of a frame) and
            ``time_code`` (``n_frame / FPS``).

        Raises:
            ValueError: If no label row falls inside the configured frame range.
        """
        columns = ['n_frame', 'num_class', 'class_name',
                   'center_x', 'center_y', 'width', 'height', 'size']
        data = []

        for file in glob.glob(f'{self.__path}/*.txt'):
            # Parse the frame number from the file name only, so underscores
            # in the directory path cannot interfere.
            stem = os.path.splitext(os.path.basename(file))[0]
            n_frame = int(stem.split('_')[-1])
            if not (self.__start_frame <= n_frame <= self.__end_frame):
                continue

            with open(file, 'rt') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # Tolerate blank / trailing empty lines.
                        continue
                    num_class = int(fields[0])
                    class_name = str(self.__object_classes[num_class])
                    # Convert the relative coordinates to pixel units.
                    center_x = float(fields[1]) * self.__video_w
                    center_y = float(fields[2]) * self.__video_h
                    w = float(fields[3]) * self.__video_w
                    h = float(fields[4]) * self.__video_h
                    size = w * h  # area, assuming a width * height rectangle

                    data.append([n_frame, num_class, class_name,
                                 center_x, center_y, w, h, size])

        if not data:
            raise ValueError(
                f"No label rows found in {self.__path} for frames "
                f"{self.__start_frame}-{self.__end_frame}."
            )

        df = pd.DataFrame(data, columns=columns)
        df = df.sort_values(by=['n_frame', 'num_class'], ignore_index=True)

        # Within each frame, difference between consecutive detections in
        # num_class order. NOTE(review): the column names suggest
        # "pupil minus eye", which holds only if the eye class id is lower
        # than the pupil class id and each is detected once - confirm against
        # object_classes.yaml.
        df['pupil-eye_x'] = df.groupby(by=['n_frame'])['center_x'].diff()
        df['pupil-eye_y'] = df.groupby(by=['n_frame'])['center_y'].diff()

        # Time code of each frame, in units of 1 / FPS.
        df['time_code'] = df['n_frame'] / self.__video_fps

        return df

    def get_onoff(self):
        """Return the table produced by ``Burst.get_stim_on_frames`` for this video.

        The rest of the class expects it to carry an ``n_frame`` column (merge
        key) and a ``stim`` column (truthy while stimulation is on).
        """
        mouse = Burst(video_name=self.video_name)
        df_onoff = mouse.get_stim_on_frames()

        return df_onoff

    def merge_df(self):
        """Join the label table with the stimulation table on ``n_frame``.

        Uses the default inner join, so only frames present in both tables
        are kept.
        """
        df_base = self.read_labels()
        df_onoff = self.get_onoff()
        df_merge = df_base.merge(df_onoff, on=['n_frame'])

        return df_merge

    def get_result_df(self):
        """Return the merged label + stimulation table (see ``merge_df``)."""
        df_res = self.merge_df()

        return df_res

    def get_stim_range_list(self, df_res=None):
        """List the stimulation periods as ``(start_frame, end_frame)`` tuples.

        A period starts at the first row where ``stim`` becomes truthy and
        ends at the first following row where it becomes falsy. A period that
        is still on at the last row is not reported.

        Args:
            df_res: Optional already-merged table to scan. When omitted, the
                table is rebuilt with ``merge_df`` (which re-reads the labels).

        Returns:
            List of ``(start_frame, end_frame)`` tuples in row order.
        """
        if df_res is None:
            df_res = self.merge_df()

        range_list = []
        prev_on = False
        start = None

        # Series.iteritems() was removed in pandas 2.0; walking the two
        # columns in parallel also avoids a label-based lookup per row.
        for val, frame in zip(df_res['stim'], df_res['n_frame']):
            is_on = bool(val)
            if is_on != prev_on:
                if is_on:
                    start = frame
                else:
                    range_list.append((start, frame))
            prev_on = is_on

        return range_list

    def get_peth(self, x_span=180):
        """Build the peri-event table aligned to each stimulation onset.

        For every stimulation period that has at least ``x_span`` frames
        before its start and after its end, the rows within ``x_span`` frames
        of the onset are copied and annotated with ``peth_frame`` (frame
        number relative to onset) and ``n_cycle`` (0-based index of the
        accepted period).

        Args:
            x_span: Half-width of the window around onset, in frames.

        Returns:
            Concatenation of the windows of all accepted periods.

        Raises:
            ValueError: If no stimulation period satisfies the margin
                requirement.
        """
        res_df = self.get_result_df()
        # Reuse the merged table instead of reading every label file again.
        range_list = self.get_stim_range_list(res_df)
        last_frame = res_df['n_frame'].max()

        peth = []
        for start, end in range_list:
            if start < x_span or end > last_frame - x_span:
                continue
            rel_frame = res_df['n_frame'] - start
            in_window = (rel_frame >= -x_span) & (rel_frame <= x_span)
            # assign() returns a new frame, so there is no write into a slice.
            window = res_df.loc[in_window].assign(
                peth_frame=rel_frame[in_window].to_numpy(),
                n_cycle=len(peth),
            )
            peth.append(window)

        if not peth:
            raise ValueError(
                f"No stimulation period of {self.video_name} has {x_span} frames "
                "of margin on both sides."
            )

        return pd.concat(peth)

    def draw_peth(self, x='peth_frame', y='pupil-eye_x', hue='stim'):
        """Draw the peri-event line plot and return the seaborn Axes.

        Args:
            x: Column for the x axis.
            y: Column for the y axis. For the two difference columns, rows
                containing any NaN are dropped first (the first detection of
                each frame has no difference value).
            hue: Column used to colour the lines.

        Returns:
            The matplotlib Axes returned by ``sns.lineplot``.
        """
        df_peth = self.get_peth()
        if y in ('pupil-eye_x', 'pupil-eye_y'):
            df_plot = df_peth.dropna()
        else:
            df_plot = df_peth

        # lineplot aggregates the repeated x values across cycles and draws
        # the spread band around the mean.
        fig_peth = sns.lineplot(x=x, y=y, hue=hue, data=df_plot)
        plt.show()

        return fig_peth

    def save_peth_fig(self, save_path=r'./result_fig'):
        """Draw the default peri-event plot and save it as ``<video_name>.png``.

        Args:
            save_path: Output directory; created if it does not exist.
        """
        fig = self.draw_peth().figure
        save_name = self.video_name + '.png'
        if save_path:
            os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, save_name))
        # Release the figure so batch runs do not accumulate open figures.
        plt.close(fig)

    def peth_to_xlsx(self, save_path=r'./result_xlsx'):
        """Write the peri-event table to ``<video_name>.xlsx``.

        Args:
            save_path: Output directory; created if it does not exist.
        """
        df_peth = self.get_peth()
        save_name = self.video_name + '.xlsx'
        if save_path:
            os.makedirs(save_path, exist_ok=True)
        df_peth.to_excel(os.path.join(save_path, save_name))


In [None]:
# Work out which 20 Hz detection runs still have no exported PETH workbook.
detect_path = r'C:\Users\jyc13\Documents\Python Scripts\openCV_eyemovement\yolov5\runs\detect'
filenames = [x for x in os.listdir(detect_path) if '20hz' in x]
noproblems_path = r'C:\Users\jyc13\Documents\Python Scripts\openCV_eyemovement\analysis\result_xlsx'
noproblems = [x.replace('.xlsx', '') for x in os.listdir(noproblems_path)]
# Sorted so the batch below runs in the same order on every kernel
# (iteration order of a set of strings changes between Python sessions).
problems = sorted(set(filenames) - set(noproblems))
print(len(filenames), len(noproblems), len(problems))

In [None]:
# Export the PETH workbook (xlsx) and figure (png) for every pending video.
for _name in problems:
    try:
        video1_name = _name
        mouse1 = LabelAnalysis(video_name=video1_name)
        res_df = mouse1.get_result_df()
        display(res_df)
        peth_df = mouse1.get_peth()
        display(peth_df)
        mouse1.peth_to_xlsx()
        mouse1.save_peth_fig()
    except Exception as err:
        # Keep the batch going, but say why this video failed. Catching
        # Exception (not a bare except) lets Ctrl-C / kernel interrupt through.
        print(f'error: {video1_name} ({type(err).__name__}: {err})')

In [None]:
# x-axis: centre x of each detected class per frame, stimulation periods shaded.
# `mouse1` and `res_df` are the ones left by the export loop (last processed video).
# `range_list` was never defined anywhere in the notebook, so compute it here;
# the plotting cells below reuse it.
range_list = mouse1.get_stim_range_list()

sns.lineplot(x='n_frame', y='center_x', hue='class_name', data=res_df)

for (start, end) in range_list:
    plt.axvspan(start, end, facecolor='blue', alpha=0.5)

plt.xlim(0, 12000)
plt.show()

In [None]:
# y-axis: centre y of each detected class per frame, stimulation periods shaded in blue.
ax = sns.lineplot(data=res_df, x='n_frame', y='center_y', hue='class_name')
for start, end in range_list:
    ax.axvspan(start, end, facecolor='blue', alpha=0.5)
plt.show()

In [None]:
# diff_x: within-frame x difference per frame (thin line), stimulation periods shaded in blue.
ax = sns.lineplot(data=res_df, x='n_frame', y='pupil-eye_x', hue='class_name', linewidth=0.2)
for start, end in range_list:
    ax.axvspan(start, end, facecolor='blue', alpha=0.5)

# Zoom in on frames 6000-12000.
ax.set_xlim(6000, 12000)
plt.show()

In [None]:
# diff_y: within-frame y difference per frame, stimulation periods shaded in blue.
ax = sns.lineplot(data=res_df, x='n_frame', y='pupil-eye_y', hue='class_name')
for start, end in range_list:
    ax.axvspan(start, end, facecolor='blue', alpha=0.5)

# Zoom in on the first 3000 frames.
ax.set_xlim(0, 3000)
plt.show()