In [1]:
import sys
import random
import matplotlib
matplotlib.use('Qt5Agg')
import numpy as np
import matplotlib.pyplot as plt
import serial
import time
from PyQt5 import QtCore, QtWidgets
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QInputDialog, QMessageBox, QFileDialog
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT as NavigationToolbar
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from PyQt5.QtCore import QObject, QThread, pyqtSignal
from pathlib import Path
from pyID_GUI import pyID_arduinos

class MplCanvas(FigureCanvas):
    """Qt canvas wrapping one matplotlib Figure laid out as a 2x4 grid of axes.

    The axes are exposed as attributes ``ax1`` ... ``ax8`` in row-major
    order (``ax1``-``ax4`` top row, ``ax5``-``ax8`` bottom row).

    Args:
        parent: accepted for call-site compatibility; not used here.
        width, height: figure size in inches.
        dpi: figure resolution.
    """

    def __init__(self, parent=None, width=70, height=10, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        n_rows, n_cols = 2, 4
        for position in range(1, n_rows * n_cols + 1):
            setattr(self, f'ax{position}', figure.add_subplot(n_rows, n_cols, position))
        figure.subplots_adjust(left=.1, bottom=.1, right=.9, top=.9,
                               wspace=.3, hspace=.2)
        super().__init__(figure)


class MainWindow(QtWidgets.QMainWindow):
    """Main window of the PID / olfactometer GUI.

    On construction it calls the ``pyID_arduinos`` uploader helpers for the
    boards on COM4 and COM5, then shows a 2x4 grid of PID traces (one axis per
    valve) next to a column of control buttons. Serial acquisition runs in
    worker objects moved onto a ``QThread`` so the GUI stays responsive; the
    workers hand their results back through Qt signals.
    """

    def __init__(self, *args, **kwargs):
        super(MainWindow, self).__init__(*args, **kwargs)

        arduino_script_path = Path().cwd()

        arduino1_port = 'COM4'
        arduino2_port = 'COM5'

        pyID_arduinos.arduino1_PID_control_uploader(arduino_script_path, arduino1_port)
        pyID_arduinos.arduino2_PID_control_uploader(arduino_script_path, arduino2_port)

        # State read by the slots below; created before the window is shown.
        self.odors = ['min_oil', 'valve 1', 'valve 2', 'valve 3', 'valve 4', 'valve 5', 'valve 6', 'valve 7']
        self.PID_odordata = {}
        self.PID_timepoints = {}

        self.canvas = MplCanvas(self, width=20, height=10, dpi=100)
        toolbar = NavigationToolbar(self.canvas, self)

        self.b0 = self._make_button("test all valves")
        self.b0.adjustSize()
        self.b1 = self._make_button("min oil")
        self.b2 = self._make_button("valve 1")
        self.b3 = self._make_button("valve 2")
        self.b4 = self._make_button("valve 3")
        self.b5 = self._make_button("valve 4")
        self.b6 = self._make_button("valve 5")
        self.b7 = self._make_button("valve 6")
        self.b8 = self._make_button("valve 7")
        self.b9 = self._make_button("save PID data")

        # addWidget(widget, row, column, rowspan, columnspan).
        # The canvas fills column 2 over rows 0-6. Buttons use columns 0-1:
        # row 1 "test all valves", rows 2-5 the eight valve buttons in pairs,
        # row 6 "save PID data". No two widgets share a grid cell.
        layout = QtWidgets.QGridLayout()
        layout.addWidget(toolbar, 0, 0, 1, 1, alignment=QtCore.Qt.AlignRight)
        layout.addWidget(self.canvas, 0, 2, 7, 1)
        layout.addWidget(self.b0, 1, 0, 1, 2)
        valve_buttons = [self.b1, self.b2, self.b3, self.b4,
                         self.b5, self.b6, self.b7, self.b8]
        for i, button in enumerate(valve_buttons):
            layout.addWidget(button, 2 + i // 2, i % 2, 1, 1)
        layout.addWidget(self.b9, 6, 0, 1, 2)

        self.b0.clicked.connect(self.b0_clicked)
        self.b1.clicked.connect(self.b1_clicked)
        self.b2.clicked.connect(self.b2_clicked)
        self.b3.clicked.connect(self.b3_clicked)
        self.b4.clicked.connect(self.b4_clicked)
        self.b5.clicked.connect(self.b5_clicked)
        self.b6.clicked.connect(self.b6_clicked)
        self.b7.clicked.connect(self.b7_clicked)
        self.b8.clicked.connect(self.b8_clicked)
        self.b9.clicked.connect(self.b9_clicked)

        # A plain container widget holds the grid and becomes the central widget.
        widget = QtWidgets.QWidget()
        widget.setLayout(layout)
        self.setCentralWidget(widget)

        self.setup_plot()
        self.show()

    def _make_button(self, text):
        """Create a push button parented to this window and labelled *text*."""
        button = QtWidgets.QPushButton(self)
        button.setText(text)
        return button

    def _all_buttons(self):
        """Return every control button, b0 through b9."""
        return [self.b0, self.b1, self.b2, self.b3, self.b4,
                self.b5, self.b6, self.b7, self.b8, self.b9]

    def b0_clicked(self):
        """Ask for the number of presentations per odor, then test all valves.

        Nothing starts if the dialog is cancelled or the input is not a
        positive whole number.
        """
        text, ok = QInputDialog.getText(self, 'starting experiment...', 'Enter number of presentations for each odor')
        if not ok:
            return
        try:
            n_trials = int(text)
        except ValueError:
            QMessageBox.warning(self, 'invalid input', f'{text!r} is not a whole number')
            return
        if n_trials < 1:
            QMessageBox.warning(self, 'invalid input', 'Enter at least 1 presentation per odor')
            return
        self.n_odor_trials = n_trials
        self.odor_valve = 0
        self.test_all_valves()

    def _start_single_valve(self, valve_index):
        """Select *valve_index* (0-based index into ``self.odors``) and record once."""
        self.odor_valve = valve_index
        self.collect_PID_data()

    def b1_clicked(self):
        self._start_single_valve(0)

    def b2_clicked(self):
        self._start_single_valve(1)

    def b3_clicked(self):
        self._start_single_valve(2)

    def b4_clicked(self):
        self._start_single_valve(3)

    def b5_clicked(self):
        self._start_single_valve(4)

    def b6_clicked(self):
        self._start_single_valve(5)

    def b7_clicked(self):
        self._start_single_valve(6)

    def b8_clicked(self):
        self._start_single_valve(7)

    def b9_clicked(self):
        self.save_PID_data()

    def button_disabler(self):
        """Disable all control buttons (used while a worker thread is running)."""
        for button in self._all_buttons():
            button.setEnabled(False)

    def button_enabler(self):
        """Re-enable all control buttons."""
        for button in self._all_buttons():
            button.setEnabled(True)

    def collect_PID_data(self):
        """Record one PID trace for ``self.odor_valve`` in a worker thread.

        The trace arrives via ``update_single_valves``. When the worker
        finishes, the trace is plotted and the buttons are re-enabled.
        """
        # NOTE: ``self.thread`` shadows QObject.thread(); kept for compatibility.
        self.thread = QThread()
        self.button_push_collector = Button_push_collector(self.odor_valve)
        self.button_push_collector.moveToThread(self.thread)
        self.thread.started.connect(self.button_push_collector.collect_PID_data)
        self.button_push_collector.finished.connect(self.thread.quit)
        self.button_push_collector.finished.connect(self.button_push_collector.deleteLater)
        self.thread.finished.connect(self.thread.deleteLater)
        self.button_push_collector.signalExample.connect(self.update_single_valves)
        self.button_push_collector.finished.connect(self.update_plot)
        self.button_push_collector.finished.connect(self.button_enabler)
        # Disable first so a second click cannot replace a running thread.
        self.button_disabler()
        self.thread.start()

    def test_all_valves(self):
        """Run the randomized all-valve protocol in a worker thread.

        Uses ``self.n_odor_trials`` presentations per valve. Each trial's trace
        is stored by ``update_all_valves`` and drawn by ``update_plot``. When
        the run finishes, the save dialog opens and the buttons are
        re-enabled. The buttons stay disabled meanwhile, so a second run
        cannot replace ``self.thread`` while the first is still executing.
        """
        self.thread = QThread()
        self.all_valve_test = All_valve_test(self.n_odor_trials)
        self.all_valve_test.moveToThread(self.thread)
        self.thread.started.connect(self.all_valve_test.test_all_valves)
        self.all_valve_test.finished.connect(self.thread.quit)
        self.all_valve_test.finished.connect(self.all_valve_test.deleteLater)
        self.thread.finished.connect(self.thread.deleteLater)
        self.all_valve_test.signalExample.connect(self.update_all_valves)
        self.all_valve_test.update.connect(self.update_plot)
        self.all_valve_test.finished.connect(self.save_PID_data)
        self.all_valve_test.finished.connect(self.button_enabler)
        self.button_disabler()
        self.thread.start()

    def record_PID_data(self):
        """Store the latest trace (``self.xdata``/``self.ydata``) for the current odor.

        Block 0 starts a fresh list for the odor, discarding any earlier run's
        data. Later blocks append to that list.
        """
        odor = self.odors[self.odor_valve]
        if self.block == 0:
            print(self.odor_valve)
            print(odor)
            self.PID_odordata[odor] = [self.ydata]
            self.PID_timepoints[odor] = [self.xdata]
        else:
            self.PID_odordata.setdefault(odor, []).append(self.ydata)
            self.PID_timepoints.setdefault(odor, []).append(self.xdata)

    def save_PID_data(self):
        """Ask for a file name and save the recorded traces with ``np.save``.

        The file holds a one-element object array containing
        ``{'odor_data': ..., 'timepoints': ...}``. Load it with
        ``np.load(path, allow_pickle=True)[0]``. Cancelling the dialog writes
        nothing.
        """
        file_name, _ = QFileDialog.getSaveFileName(self, 'Save File')
        self.PID_data = {'odor_data': self.PID_odordata, 'timepoints': self.PID_timepoints}
        if not file_name:
            # np.save('') would otherwise create a stray '.npy' file.
            return
        np.save(file_name, [self.PID_data])

    def update_single_valves(self, x_vals, y_vals):
        """Receive one trace (timestamps, readings) from the single-valve worker."""
        self.xdata = x_vals
        self.ydata = y_vals

    def update_all_valves(self, x_vals, y_vals, odor_valve, block):
        """Receive one trial from the all-valve worker and record it."""
        self.xdata = x_vals
        self.ydata = y_vals
        self.odor_valve = odor_valve - 1  # from olfactometer coordinates back to python
        self.block = block
        self.record_PID_data()

    def setup_plot(self):
        """Clear all eight axes and apply titles, limits and axis labels."""
        self.axes = [self.canvas.ax1, self.canvas.ax2, self.canvas.ax3, self.canvas.ax4, self.canvas.ax5, self.canvas.ax6, self.canvas.ax7, self.canvas.ax8]
        self.titles = ['min oil', 'valve 1', 'valve 2', 'valve 3', 'valve 4', 'valve 5', 'valve 6', 'valve 7']
        for n, ax in enumerate(self.axes):
            ax.cla()
            ax.set_xlim(0, 10)
            ax.set_ylim(0, 300)
            ax.set_aspect(15 / 300)
            ax.set_title(self.titles[n])
            ax.set_xlabel('Time (s)')
            if n in (0, 4):  # first column of each row carries the y label
                ax.set_ylabel('Output voltage (mV)')
        self.updatenum = 0
        self.canvas.draw()

    def update_plot(self):
        """Overlay the latest trace on the axis of ``self.odor_valve`` and redraw."""
        ax = self.axes[self.odor_valve]
        # Traces accumulate on the axis (no cla()), so repeats are overlaid.
        ax.plot(self.xdata, self.ydata, color='k', linewidth=1)
        ax.set_xlim(0, 10)
        ax.set_ylim(0, 300)
        ax.set_aspect(15 / 300)
        ax.set_title(self.titles[self.odor_valve])
        self.canvas.draw()
        
# A worker class for testing all of the valves 
class All_valve_test(QObject):
    """Worker that presents every valve ``trials_per_odor`` times in shuffled blocks.

    Intended to be moved onto a ``QThread`` with ``test_all_valves`` as the
    entry slot. After each trial it emits
    ``signalExample(times, values, valve, block)`` followed by ``update``.
    ``finished`` is emitted once when the run ends, either normally or after
    a serial error. Valve numbers are 1-based (1..n_odors): they are the
    values written to the valve-control port.
    """
    finished = pyqtSignal()
    update = pyqtSignal()
    signalExample = pyqtSignal(list, list, int, int)

    def __init__(self, trials_per_odor, parent=None, valve_port='COM4',
                 pid_port='COM5', trial_duration=10):
        """
        Args:
            trials_per_odor: number of blocks; each valve appears once per block.
            parent: optional Qt parent. Must be None if the worker will be
                moved to a thread.
            valve_port: serial port the valve number is written to.
            pid_port: serial port PID readings are read from (one integer per line).
            trial_duration: seconds of PID data recorded per trial.
        """
        super().__init__(parent)
        self.trials_per_odor = trials_per_odor
        self.n_odors = 8
        self.valve_port = valve_port
        self.pid_port = pid_port
        self.trial_duration = trial_duration

    def test_all_valves(self):
        """Run every block, emitting one trace per trial, then emit ``finished``."""
        # Row b holds a random permutation of valves 1..n_odors for block b.
        blocked_odor_trials = np.empty((self.trials_per_odor, self.n_odors), dtype=int)
        for block in range(self.trials_per_odor):
            valves = np.arange(self.n_odors) + 1
            random.shuffle(valves)
            blocked_odor_trials[block, :] = valves

        try:
            for block in range(self.trials_per_odor):
                for trial in range(self.n_odors):
                    # Plain int for the int-typed signal and the serial command.
                    odor_valve = int(blocked_odor_trials[block, trial])
                    time_list, val_list = self._record_trial(odor_valve)
                    self.signalExample.emit(time_list, val_list, odor_valve, block)
                    self.update.emit()
        except serial.SerialException as err:
            # Stop the run, but still emit finished below so that whatever is
            # connected to it (thread shutdown, saving) runs.
            print(f'Serial error, aborting valve test: {err}')
        self.finished.emit()

    def _record_trial(self, odor_valve):
        """Open *odor_valve* and return (timestamps, readings) for ``trial_duration`` s.

        Lines that do not parse as an integer are skipped. That covers
        garbled or partial lines, and the empty read after the 1 s port
        timeout. The timeout keeps a silent sensor from stalling the loop
        for more than about a second past ``trial_duration``. Both ports are
        closed even if an error is raised.
        """
        time_list = []
        val_list = []
        with serial.Serial(self.valve_port, 9600) as ser1, \
                serial.Serial(self.pid_port, 9600, timeout=1) as ser2:
            time.sleep(2)  # wait two seconds to establish the serial connection
            ser1.write(bytes(str(odor_valve), 'utf-8'))
            start = time.perf_counter()
            while time.perf_counter() - start < self.trial_duration:
                line = ser2.readline()
                try:
                    in_val = int(line.decode('utf-8'))
                except (UnicodeDecodeError, ValueError):
                    continue
                val_list.append(in_val)
                time_list.append(time.perf_counter() - start)
                time.sleep(.005)
            ser1.flush()
            ser2.flush()
        return time_list, val_list

    
# A worker class for testing each valve via button press 
class Button_push_collector(QObject):
    """Worker that records one PID trace for a single valve (one button press).

    Intended to be moved onto a ``QThread`` with ``collect_PID_data`` as the
    entry slot. It emits ``signalExample(times, values)`` and then
    ``finished``. If a serial error occurs, both lists are empty.
    """
    finished = pyqtSignal()
    signalExample = pyqtSignal(list, list)

    def __init__(self, odor_valve, parent=None, valve_port='COM4',
                 pid_port='COM5', trial_duration=10):
        """
        Args:
            odor_valve: 0-based valve index. The value sent to the port is
                ``odor_valve + 1``.
            parent: optional Qt parent. Must be None if the worker will be
                moved to a thread.
            valve_port: serial port the valve number is written to.
            pid_port: serial port PID readings are read from (one integer per line).
            trial_duration: seconds of PID data to record.
        """
        super().__init__(parent)
        self.odor_valve = odor_valve + 1
        self.valve_port = valve_port
        self.pid_port = pid_port
        self.trial_duration = trial_duration

    def collect_PID_data(self):
        """Record one trace, emit it, then emit ``finished``."""
        try:
            time_list, val_list = self._record_trial(self.odor_valve)
        except serial.SerialException as err:
            # Still emit both signals so that whatever is connected to them
            # (thread shutdown, GUI updates) runs.
            print(f'Serial error, no PID data recorded: {err}')
            time_list, val_list = [], []
        self.signalExample.emit(time_list, val_list)
        self.finished.emit()

    def _record_trial(self, odor_valve):
        """Open *odor_valve* and return (timestamps, readings) for ``trial_duration`` s.

        Lines that do not parse as an integer are skipped. That covers
        garbled or partial lines, and the empty read after the 1 s port
        timeout. The timeout keeps a silent sensor from stalling the loop
        for more than about a second past ``trial_duration``. Both ports are
        closed even if an error is raised.
        """
        time_list = []
        val_list = []
        with serial.Serial(self.valve_port, 9600) as ser1, \
                serial.Serial(self.pid_port, 9600, timeout=1) as ser2:
            time.sleep(2)  # wait two seconds to establish the serial connection
            ser1.write(bytes(str(odor_valve), 'utf-8'))
            start = time.perf_counter()
            while time.perf_counter() - start < self.trial_duration:
                line = ser2.readline()
                try:
                    in_val = int(line.decode('utf-8'))
                except (UnicodeDecodeError, ValueError):
                    continue
                val_list.append(in_val)
                time_list.append(time.perf_counter() - start)
                time.sleep(.005)
            ser1.flush()
            ser2.flush()
        return time_list, val_list



if __name__ == '__main__':
    # Qt allows only one QApplication per process. Reuse the existing one
    # when this notebook cell is re-run instead of constructing a second.
    app = QtWidgets.QApplication.instance() or QtWidgets.QApplication(sys.argv)
    w = MainWindow()  # keep a reference so the window is not garbage-collected
    app.exec_()

Log Sketch uses 3296 bytes (1%) of program storage space. Maximum is 253952 bytes.
Global variables use 190 bytes (2%) of dynamic memory, leaving 8002 bytes for local variables. Maximum is 8192 bytes.

[92mUsed platform[0m [92mVersion[0m [90mPath[0m                                                                      
[93marduino:avr[0m   1.8.6   [90mC:\Users\rmb55\AppData\Local\Arduino15\packages\arduino\hardware\avr\1.8.6[0m

Sketch stats {'flash': 3296, 'flash_max': 253952, 'flash_percent': 0.012978830645161291, 'ram': 190, 'ram_max': 8192, 'ram_percent': 0.023193359375, 'compile_time': 1.120340599998599}
New upload port: COM4 (serial)

Log Sketch uses 2356 bytes (7%) of program storage space. Maximum is 32256 bytes.
Global variables use 190 bytes (9%) of dynamic memory, leaving 1858 bytes for local variables. Maximum is 2048 bytes.

[92mUsed platform[0m [92mVersion[0m [90mPath[0m                                                                      
[93marduino:avr

0