# In[0]:
class BaseModel:
    """Shared training, evaluation, plotting and checkpointing logic.

    This class only creates the input placeholders and bookkeeping state;
    ``self.loss`` and ``self.optimizer_step`` start as ``None`` and must be
    assigned by a subclass before ``train``/``test`` are used.
    NOTE(review): ``train``/``test``/``predict`` fetch ``self.pred_y_ph``
    without feeding it, so a subclass presumably replaces it with the
    network's output tensor — confirm in the subclasses.
    """

    # Every per-epoch metric history list. Each name is both an instance
    # attribute (e.g. ``self.train_isol_recall_history``) and a key of the
    # history dict written by ``save_history`` and read by ``restore_model``.
    _HISTORY_KEYS = (
        # overall
        "train_loss_history", "train_recall_history",
        "train_precision_history", "train_f1_score_history",
        "test_loss_history", "test_recall_history",
        "test_precision_history", "test_f1_score_history",
        # isol
        "train_isol_recall_history", "train_isol_precision_history",
        "train_isol_f1_score_history",
        "test_isol_recall_history", "test_isol_precision_history",
        "test_isol_f1_score_history",
        # rand
        "train_rand_recall_history", "train_rand_precision_history",
        "train_rand_f1_score_history",
        "test_rand_recall_history", "test_rand_precision_history",
        "test_rand_f1_score_history",
        # ucho
        "train_ucho_recall_history", "train_ucho_precision_history",
        "train_ucho_f1_score_history",
        "test_ucho_recall_history", "test_ucho_precision_history",
        "test_ucho_f1_score_history",
        # mus
        "train_mus_recall_history", "train_mus_precision_history",
        "train_mus_f1_score_history",
        "test_mus_recall_history", "test_mus_precision_history",
        "test_mus_f1_score_history",
    )

    def __init__(self):
        self.loss = None
        self.optimizer_step = None

        # 'ph' stands for 'placeholder'
        # input tensor of shape [batch, in_height, in_width, in_channels]
        self.input_X_ph = tf.placeholder('float32', [None, 1, WINDOW_SIZE, 1], name="input_X")
        self.input_y_ph = tf.placeholder('float32', [None, MIDI_PITCH_AMOUNT], name="input_y")
        self.pred_y_ph = tf.placeholder('float32', [None, MIDI_PITCH_AMOUNT], name="pred_y")

        # One empty list per tracked metric (see _HISTORY_KEYS).
        for key in self._HISTORY_KEYS:
            setattr(self, key, [])

        self.history_metrics = {}
        self.history_parameters = []

        self.current_epoch_number = 0
        self.overall_train_time = 0.0

        # Timestamp suffix shared by the checkpoint, history and figure files.
        local_tz = pytz.timezone('Europe/Moscow')
        now = datetime.datetime.now(local_tz)
        self.postfixe_history_filename = now.strftime("_%Y-%m-%d_%H:%M")

        # ====================================================================

        self.learning_rate = tf.placeholder(tf.float32, shape=None, name='learning_rate')
        # Previously also named 'learning_rate', which TF silently renamed.
        self.mom = tf.placeholder(tf.float32, shape=None, name='momentum')

        self.start_freq = 50
        self.end_freq = 6000

        self.wscale = 0.0001

        # Created lazily by _get_saver once the subclass has built its variables.
        self._saver = None

    def create_filters(self):
        """Fill the sine/cosine filter banks, one filter per frequency.

        Filter ``i`` holds ``sin``/``cos`` of
        ``2*pi*frequencies[i]*discrete_time/SAMPLERATE``. Reads
        ``lvl1_filters_amount``, ``frequencies`` and ``discrete_time`` and
        writes into the preallocated ``wsin``/``wcos`` arrays; all of these
        are expected to be set up by a subclass (not visible here).
        """
        for i in range(self.lvl1_filters_amount):
            args = 2 * np.pi * self.frequencies[i] * self.discrete_time / SAMPLERATE
            self.wsin[0, :, 0, i] = np.sin(args)
            self.wcos[0, :, 0, i] = np.cos(args)

    @staticmethod
    def _binarize(y_pred, threshold):
        """Return a copy of ``y_pred`` with 1 where ``>= threshold`` and 0 elsewhere.

        The result keeps ``y_pred``'s dtype.
        """
        arr = np.asarray(y_pred)
        return (arr >= threshold).astype(arr.dtype)

    @staticmethod
    def _score(y_true, y_pred_binary):
        """Return (precision, recall, f1) over all flattened target cells."""
        true_flat = np.asarray(y_true).flatten()
        pred_flat = np.asarray(y_pred_binary).flatten()
        return (precision_score(true_flat, pred_flat),
                recall_score(true_flat, pred_flat),
                f1_score(true_flat, pred_flat))

    def train(self, X, y_true, sess, lr, mom):
        """Run one optimizer step on a batch and score its predictions.

        Predictions are binarised at 0.5 with the same rule as ``test``
        (``>= 0.5`` -> 1); the old ``np.round`` mapped exactly 0.5 to 0.

        Returns:
            (loss, precision, recall, f1) for this batch.
        """
        X = np.reshape(X, (X.shape[0], 1, WINDOW_SIZE, 1))

        loss, y_pred, _ = sess.run([self.loss, self.pred_y_ph, self.optimizer_step],
                                   feed_dict={self.input_X_ph: X,
                                              self.input_y_ph: y_true,
                                              self.learning_rate: lr,
                                              self.mom: mom})

        precision, recall, f1 = self._score(y_true, self._binarize(y_pred, 0.5))
        return loss, precision, recall, f1

    def test(self, X, y_true, sess, threshhold=0.5):
        """Evaluate a batch without training.

        Returns:
            (loss, precision, recall, f1), with predictions counted as
            positive when ``>= threshhold``.
        """
        X = np.reshape(X, (X.shape[0], 1, WINDOW_SIZE, 1))

        loss, y_pred = sess.run([self.loss, self.pred_y_ph],
                                feed_dict={self.input_X_ph: X,
                                           self.input_y_ph: y_true})

        precision, recall, f1 = self._score(y_true, self._binarize(y_pred, threshhold))
        return loss, precision, recall, f1

    def get_batches(self, X, y, batch_size=64):
        """Yield consecutive (X, y) slices of at most ``batch_size`` rows."""
        n_samples = X.shape[0]

        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            yield X[start:end], y[start:end]

    @staticmethod
    def _from_globals(name, value):
        """Return ``value``, or the module-level global ``name`` if ``value`` is None.

        Keeps the old behaviour of reading notebook globals implicitly.

        Raises:
            NameError: if ``value`` is None and no such global exists.
        """
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError("name '{}' is not defined; pass it explicitly".format(name)) from None

    @staticmethod
    def _average_batch_metrics(batch_metrics):
        """Average (loss, precision, recall, f1) tuples column-wise.

        An empty input yields four NaNs (numpy's mean of an empty sequence).
        Note these are unweighted means of per-batch scores.
        """
        columns = list(zip(*batch_metrics)) or [()] * 4
        loss, precision, recall, f1 = (np.mean(column) for column in columns)
        return loss, precision, recall, f1

    def get_epoch_train_parameters(self, X, y, sess, batch_size=None, lr=None, mom=None):
        """Train for one pass over (X, y) in mini-batches.

        Args:
            batch_size, lr, mom: batch size, learning rate and momentum. When
                omitted, the module-level globals of the same names are used,
                as the previous implementation did implicitly.

        Returns:
            Mean (loss, precision, recall, f1) over the batches.
        """
        batch_size = self._from_globals("batch_size", batch_size)
        lr = self._from_globals("lr", lr)
        mom = self._from_globals("mom", mom)
        return self._average_batch_metrics(
            self.train(x_batch, y_batch, sess, lr, mom)
            for x_batch, y_batch in self.get_batches(X, y, batch_size))

    def get_epoch_test_parameters(self, X, y, sess, threshhold=0.5, batch_size=None):
        """Evaluate (X, y) in mini-batches without training.

        Args:
            threshhold: positive-prediction threshold passed to ``test``.
            batch_size: batch size; defaults to the module-level global
                ``batch_size``, as the previous implementation did implicitly.

        Returns:
            Mean (loss, precision, recall, f1) over the batches.
        """
        batch_size = self._from_globals("batch_size", batch_size)
        print("threshhold is: {}".format(threshhold))
        return self._average_batch_metrics(
            self.test(x_batch, y_batch, sess, threshhold)
            for x_batch, y_batch in self.get_batches(X, y, batch_size))

    def _plot_panel(self, position, title, ylabel, series, log_scale=False):
        """Draw one subplot of the 6x4 history grid.

        ``series`` is an iterable of (history attribute name, color, legend label).
        """
        plt.subplot(6, 4, position)
        plt.title(title, fontsize=14)
        plt.xlabel("#epoch", fontsize=12)
        plt.ylabel(ylabel, fontsize=12)
        if log_scale:
            plt.yscale("log")
        for attr, color, label in series:
            plt.plot(getattr(self, attr), color=color, label=label)
        plt.legend(fontsize=12)

    def display_history(self):
        """Plot all metric histories on a 6x4 grid, save the figure and show it.

        Row 1: overall train/test loss (log scale), precision, recall, f1-score.
        Row 2: precision / recall / f1-score for every data part on one panel.
        Rows 3-6: train vs test per data part (ISOL, RAND, UCHO, MUS).
        The first column of rows 2-6 is left empty.
        """
        colors = {
            "train": '#B33900', "train_isol": '#F001A8', "train_rand": '#FFA200',
            "train_ucho": '#FFE600', "train_mus": '#FF0601',
            "test": '#198F00', "test_isol": '#01FFD0', "test_rand": '#01FA48',
            "test_ucho": '#009EFF', "test_mus": '#3601E0',
        }
        # (attribute infix, axis/title label)
        metrics = (("precision", "precision"), ("recall", "recall"), ("f1_score", "f1-score"))
        subsets = ("isol", "rand", "ucho", "mus")

        try:
            plt.style.use('seaborn-paper')
        except OSError:
            # matplotlib >= 3.6 renamed (and 3.8 removed) the old seaborn styles.
            plt.style.use('seaborn-v0_8-paper')

        plt.rc('xtick', labelsize=10)    # fontsize of the tick labels
        plt.rc('ytick', labelsize=10)
        plt.figure(figsize=(26, 48))

        # first row: overall train vs test
        for position, (infix, label) in enumerate((("loss", "loss"),) + metrics, start=1):
            self._plot_panel(
                position, "Train and test " + label, label,
                [("train_{}_history".format(infix), colors["train"], "train"),
                 ("test_{}_history".format(infix), colors["test"], "test")],
                log_scale=(infix == "loss"))

        # second row: every data part on one panel per metric
        for position, (infix, label) in enumerate(metrics, start=6):
            series = []
            for split in ("train", "test"):
                series.append(("{}_{}_history".format(split, infix),
                               colors[split], split + "_overall"))
                series.extend(("{}_{}_{}_history".format(split, subset, infix),
                               colors[split + "_" + subset],
                               "{}_{}".format(split, subset))
                              for subset in subsets)
            self._plot_panel(position, label + " on different parts of the dataset",
                             label, series)

        # rows 3-6: train vs test for a single data part
        for row, subset in enumerate(subsets):
            for offset, (infix, label) in enumerate(metrics):
                self._plot_panel(
                    10 + 4 * row + offset,
                    "test and train {}: {}".format(label, subset.upper()), label,
                    [("train_{}_{}_history".format(subset, infix),
                      colors["train_" + subset], "train_" + subset),
                     ("test_{}_{}_history".format(subset, infix),
                      colors["test_" + subset], "test_" + subset)])

        plt.subplots_adjust(wspace=0.3)
        plt.savefig('/content/gdrive/My Drive/multipitch_estimation/tmp/model' + self.postfixe_history_filename + '.png', bbox_inches='tight')

        plt.show()

    def predict(self, X, sess, threshold=0.5):
        """Return binarised predictions (1 where ``>= threshold``) for X.

        X is a batch of windows of shape (n_samples, WINDOW_SIZE). A single
        1-D window is promoted to a batch of one, so the result then has a
        leading dimension of 1.
        """
        X = np.atleast_2d(X)
        X_reshaped = np.reshape(X, (X.shape[0], 1, WINDOW_SIZE, 1))
        y_pred = sess.run(self.pred_y_ph, feed_dict={self.input_X_ph: X_reshaped})
        return self._binarize(y_pred, threshold)

    def _get_saver(self):
        """Return a ``tf.train.Saver``, building it once and reusing it.

        Building a new Saver on every call adds fresh save/restore ops to the
        graph each time.
        """
        saver = getattr(self, "_saver", None)
        if saver is None:
            saver = tf.train.Saver()
            self._saver = saver
        return saver

    def save_history(self, sess, train_time, n_epoch, lr, batch_size):
        """Update counters, dump the metric history and checkpoint the session.

        Args:
            train_time: seconds to add to ``overall_train_time``.
            n_epoch: for how many epochs the given parameters were in effect.
            lr, batch_size: accepted but currently not recorded.
        """
        self.current_epoch_number += n_epoch
        self.overall_train_time += train_time

        self.history_metrics.update({key: getattr(self, key) for key in self._HISTORY_KEYS})
        self.history_metrics.update({
            "last_epoch_number": self.current_epoch_number,
            "overall_train_time": self.overall_train_time,
            "postfixe_history_filename": self.postfixe_history_filename,
        })

        np.save("experiment_metrics_history" + self.postfixe_history_filename, self.history_metrics)

        save_path = self._get_saver().save(sess, "/content/gdrive/My Drive/multipitch_estimation/tmp/model" + self.postfixe_history_filename + ".ckpt")
        print("Model saved in path: %s" % save_path)

        print("The train history have been saved! Current epoch number: {}".format(self.current_epoch_number))

    def restore_model(self, restore_name, path_variables, path_history, sess):
        """Restore variables from a checkpoint and metric history from a .npy dict.

        Args:
            restore_name: if truthy, also reuse the saved filename suffix so
                later saves overwrite the restored files.
            path_variables: checkpoint prefix passed to ``Saver.restore``.
            path_history: file written by ``save_history``.
            sess: session whose variables are restored.
        """
        # Give every variable its initial value first; the checkpoint then
        # overwrites the variables it contains.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        self._get_saver().restore(sess, path_variables)

        # NOTE: allow_pickle=True unpickles the file; only load trusted history files.
        history = np.load(path_history, allow_pickle=True).item()

        for key in self._HISTORY_KEYS:
            setattr(self, key, history[key])
        self.current_epoch_number = history["last_epoch_number"]
        self.overall_train_time = history["overall_train_time"]

        if restore_name:
            self.postfixe_history_filename = history["postfixe_history_filename"]

        print("Model restored.")