In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os, glob, json 

from sklearn.metrics import recall_score, precision_score, classification_report, confusion_matrix, plot_confusion_matrix
from sklearn.metrics import accuracy_score

from datetime import datetime

In [2]:
class BaseModel():
    """Base class bundling dataset loading and evaluation helpers.

    `train` and `grid_search` are placeholders that raise
    `NotImplementedError`. `read_dataset` and `model_accuracy` are
    static helpers that can be called on the class or on an instance.
    """

    def train(cls):
        """Placeholder for model training; always raises.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NOTE(review): left undecorated on purpose so existing calls keep
        # binding their first positional argument to `cls` exactly as before.
        raise NotImplementedError

    @staticmethod
    def grid_search():
        """Placeholder for a hyper-parameter search; always raises.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def read_dataset(dataset_type, dataset_name, split_sex:bool = False):
        """Load the original and augmented JSON datasets from disk.

        Scans `../datasets/json/{dataset_type}/{dataset_name}` (relative to
        the current working directory) for `.json` files. A file whose name
        contains 'augmented' fills the 'Augmented' slot; any other selected
        file fills the 'OriginalData' slot. Each file must hold a JSON object
        with 'features' and 'emotions' keys.

        Args:
            dataset_type: Sub-folder of `../datasets/json`.
            dataset_name: Sub-folder of `dataset_type` holding the files.
            split_sex: When truthy, only files whose name contains 'male'
                are read; otherwise only files whose name does not.

        Returns:
            dict: `{'OriginalData': {'X', 'y'}, 'Augmented': {'X', 'y'}}`.
            'X' and 'y' are numpy arrays for slots that were filled and
            empty lists for slots with no matching file.

        Raises:
            FileNotFoundError: if the dataset folder does not exist.
            KeyError: if a selected file lacks 'features' or 'emotions'.
        """
        dataset = {
            'OriginalData': {'X': [], 'y': []},
            'Augmented': {'X': [], 'y': []},
        }

        path = f'../datasets/json/{dataset_type}/{dataset_name}'

        # os.listdir() order is arbitrary; sorting makes it deterministic
        # which file wins when several files map to the same slot.
        for file in sorted(os.listdir(path)):
            if not file.endswith(".json"):
                continue

            # Substring test: a name containing 'female' also contains 'male',
            # so both kinds of file are selected when split_sex is truthy.
            # NOTE(review): later files overwrite earlier ones in the same
            # slot -- confirm whether male/female data should be merged.
            if ('male' in file) != bool(split_sex):
                continue

            slot = 'Augmented' if 'augmented' in file else 'OriginalData'
            with open(f"{path}/{file}", encoding='utf-8') as fp:
                data = json.load(fp)

            dataset[slot]['X'] = np.array(data['features'])
            dataset[slot]['y'] = np.array(data['emotions'])

        return dataset

    @staticmethod
    def model_accuracy(model, X_train, X_test, y_train, y_test):
        """Print, log and plot evaluation metrics for a fitted model.

        Prints train/test accuracy, per-class recall and precision and the
        classification report, appends the same text to
        `evaluationoutput/<dd-mm-YYYY>x.txt`, then draws a row-normalised
        confusion matrix, saves it to `evaluationoutput/<dd-mm-YYYY>` and
        shows it.

        Args:
            model: Object exposing `predict(X)` that is also accepted by
                `plot_confusion_matrix` (presumably a fitted scikit-learn
                classifier -- verify against callers).
            X_train: Training features passed to `model.predict`.
            X_test: Test features passed to `model.predict`.
            y_train: True labels for `X_train`.
            y_test: True labels for `X_test`.

        Returns:
            None.
        """
        path = 'evaluationoutput/'
        os.makedirs(path, exist_ok=True)

        dt = datetime.today().strftime('%d-%m-%Y')

        # Compute every metric once and reuse it for both console and file.
        train_accuracy = accuracy_score(y_train, model.predict(X_train))

        y_pred = model.predict(X_test)
        test_accuracy = accuracy_score(y_test, y_pred)
        recall = recall_score(y_test, y_pred, average=None)
        precision = precision_score(y_test, y_pred, average=None)
        report = classification_report(y_test, y_pred)

        print(f"Train accuracy is: {train_accuracy}")
        print(f"Test accuracy is: {test_accuracy}")
        print(f"\nRecall: {recall}")
        print(f"Precision: {precision}")
        print("\nClassification Report:")
        print(report)

        # Append mode: repeated runs on the same day accumulate in one file.
        with open(path+dt+'x.txt', 'a') as fp:
            fp.write(f"\nTrain accuracy is: {train_accuracy}")
            fp.write(f"\nTest accuracy is: {test_accuracy}")
            fp.write(f"\nRecall: {recall}")
            fp.write(f"\nPrecision: {precision}")
            # Trailing newline keeps the report's header row on its own line.
            fp.write("\nClassification Report:\n")
            fp.write(report)

        # NOTE(review): plot_confusion_matrix was deprecated in scikit-learn
        # 1.0 and removed in 1.2 (ConfusionMatrixDisplay.from_estimator is
        # the replacement) -- confirm the pinned scikit-learn version.
        plot_confusion_matrix(model, X_test, y_test, normalize='true')

        # Save before show(): on several backends show() releases the figure,
        # so saving afterwards can write an empty image.
        fig1 = plt.gcf()
        fig1.savefig(path + dt)
        plt.show()


In [8]:
# Load the non-sex-split 'ravdess' dataset from the 'test' folder; the returned
# dict is the last expression, so the notebook displays it as the cell output.
BaseModel.read_dataset(dataset_type='test', dataset_name='ravdess', split_sex=False)

{'OriginalData': {'X': array([[-5.44306213e+02,  3.86382599e+01, -1.08062201e+01, ...,
           8.08588928e-04,  2.61889014e-04,  2.27400069e-05],
         [-3.96396637e+02,  2.70301723e+01, -2.65522423e+01, ...,
           1.28889014e-03,  5.22386050e-04,  4.16709554e-05],
         [-4.27742279e+02,  2.99424496e+01, -2.29725513e+01, ...,
           8.08671990e-04,  3.63425730e-04,  2.90307489e-05],
         ...,
         [-4.39426758e+02,  5.28229065e+01, -1.31232004e+01, ...,
           4.92935127e-04,  1.56381386e-04,  1.95508146e-05],
         [-5.95779724e+02,  4.45684929e+01, -1.27382908e+01, ...,
           1.03444772e-05,  3.39071380e-06,  2.49413091e-07],
         [-5.55434692e+02,  6.54705048e+01, -1.08793592e+01, ...,
           6.17211117e-05,  2.40966838e-05,  2.10481926e-06]]),
  'y': array(['angry', 'fearful', 'fearful', 'disgust', 'neutral', 'fearful',
         'sad', 'disgust', 'angry', 'disgust', 'fearful', 'sad', 'happy',
         'neutral', 'fearful', 'fearful', '