# In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
from scipy.io import loadmat
import os
import glob

# Where PNGs are written when the caller does not pass ``output_dir``.
_DEFAULT_OUTPUT_DIR = '/Users/tstakuma/Desktop/rikkyo_action_analysis/PRJ'


def _load_trials(mat_file_path, csv_file_path):
    """Return mouse samples (with per-trial cumulative x/y) joined to CSV rows."""
    mat_data = loadmat(mat_file_path)
    mouse_df = pd.DataFrame(mat_data['mouseMovement'],
                            columns=['trial', 'sample', 'dx', 'dy'])
    # Positions restart at each trial: cumulative sum of displacements per trial.
    mouse_df['x'] = mouse_df.groupby('trial')['dx'].cumsum()
    mouse_df['y'] = mouse_df.groupby('trial')['dy'].cumsum()

    # CSV row k (0-based positional index) is paired with trial == k.
    # NOTE(review): if the MAT trial numbers are 1-based (MATLAB convention),
    # this pairs each trial with the following row's control level and drops
    # the last trial -- confirm against the data.
    control_df = pd.read_csv(csv_file_path).reset_index()
    return pd.merge(mouse_df, control_df, left_on='trial', right_on='index')


def _plot_trial(fig, ax, trial, group):
    """Draw one trial's trajectory, start marker, colour bar and title on ax."""
    x = group['x'].to_numpy()
    y = group['y'].to_numpy()
    steps = np.arange(len(x))

    ax.plot(x, y, '-k', alpha=0.2)
    points = ax.scatter(x, y, c=steps, cmap='turbo')
    ax.plot(x[0], y[0], 'Dr', label='start', markersize=8)
    ax.axis('equal')
    # Use the scatter as the mappable so the colour bar covers exactly the
    # plotted time steps 0..len(x)-1.
    cbar = fig.colorbar(points, ax=ax)
    cbar.set_label('Time step')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.legend()
    ax.set_xlim([x.min(), x.max()])
    ax.set_ylim([y.min(), y.max()])
    control_level = group['actual control'].iloc[0]
    ax.set_title(f'Trial = {trial}, Control = {control_level}')


def process_and_save_plots(mat_file_path, csv_file_path, output_dir=None, *,
                           close=True):
    """Plot every trial's mouse trajectory from one MAT/CSV pair and save a PNG.

    Each trial in the MAT file's ``mouseMovement`` array gets one subplot in a
    square grid. The trajectory is the per-trial cumulative sum of the
    ``dx``/``dy`` columns, coloured by time step, with the first sample marked
    by a red diamond. The subplot title shows the trial number and the
    ``'actual control'`` value from the matching CSV row.

    Parameters
    ----------
    mat_file_path : str
        MAT file containing a ``mouseMovement`` array whose four columns are
        read as ``trial, sample, dx, dy``.
    csv_file_path : str
        CSV file whose row ``k`` (0-based) is matched to ``trial == k``.
        Must contain an ``'actual control'`` column.
    output_dir : str, optional
        Directory for the PNG. It is created if missing. Defaults to the
        original project output folder.
    close : bool, keyword-only
        Close the figure after saving (default). Pass ``False`` to keep it
        open, e.g. for inline display in a notebook.

    Returns
    -------
    str
        Path of the written PNG, named after the MAT file's stem.

    Raises
    ------
    ValueError
        If no trial in the MAT file matches a row of the CSV file.
    """
    merged_df = _load_trials(mat_file_path, csv_file_path)
    trial_groups = list(merged_df.groupby('trial'))
    if not trial_groups:
        raise ValueError(
            f'No trials in {mat_file_path} matched rows of {csv_file_path}')

    # Size the grid by the trials that survived the merge, so no blank axes remain.
    num_trials = len(trial_groups)
    side = int(np.ceil(np.sqrt(num_trials)))

    # squeeze=False keeps `axs` 2-D even for a 1x1 grid (a single trial).
    fig, axs = plt.subplots(side, side, figsize=(50, 50), squeeze=False)
    try:
        fig.subplots_adjust(wspace=0.4, hspace=0.4)
        flat_axes = axs.ravel()  # row-major, same order as row = i // cols
        for ax in flat_axes:
            ax.set_aspect('equal')
            ax.set_box_aspect(1)

        for ax, (trial, group) in zip(flat_axes, trial_groups):
            _plot_trial(fig, ax, trial, group)

        # Remove the unused cells at the end of the square grid.
        for ax in flat_axes[num_trials:]:
            fig.delaxes(ax)

        fig.tight_layout()

        # Save the plot as <MAT file stem>.png.
        if output_dir is None:
            output_dir = _DEFAULT_OUTPUT_DIR
        os.makedirs(output_dir, exist_ok=True)
        stem = os.path.splitext(os.path.basename(mat_file_path))[0]
        img_path = os.path.join(output_dir, f'{stem}.png')
        fig.savefig(img_path)
    finally:
        # Large figures accumulate across a batch run unless released.
        if close:
            plt.close(fig)
    return img_path

# Input folder holding the raw behavioural data (paired MAT and CSV files).
folder_path = '/Users/tstakuma/Desktop/rikkyo_action_analysis/all datas/Kio University 20230518_行動データ'

# The files are paired by position after sorting. glob.escape guards against
# wildcard characters in the folder name.
mat_files = sorted(glob.glob(os.path.join(glob.escape(folder_path), '*.mat')))
csv_files = sorted(glob.glob(os.path.join(glob.escape(folder_path), '*.csv')))

# zip() would silently drop the extras and mis-pair everything after a gap.
if len(mat_files) != len(csv_files):
    raise ValueError(
        f'Found {len(mat_files)} MAT files but {len(csv_files)} CSV files in '
        f'{folder_path}; cannot pair them by sorted order.')

# Create and save one trajectory figure per MAT/CSV pair.
for mat_file, csv_file in zip(mat_files, csv_files):
    process_and_save_plots(mat_file, csv_file)