# Stroke analysis


In [1]:
# Import extras and suppress warnings to keep the tutorial clean
import os
import random
import warnings

# Silence every warning category so the cell outputs below stay readable.
# (compiam itself is imported in the next cell.)

warnings.filterwarnings('ignore')


In [2]:
# Bring in compiam's mridangam stroke classifier and create the instance
# (`msc`) that every following cell works with.
from compiam.timbre.stroke_classification import MridangamStrokeClassification

msc = MridangamStrokeClassification()

[   INFO   ] MusicExtractorSVM: no classifier models were configured by default


In [3]:
# Load the mridangam dataset into the classifier, using a data home given
# relative to the notebook's working directory. Later cells reach the loaded
# dataset through `msc.dataset`.
msc.load_mridangam_dataset(data_home="../audio/mir_dataset/")

100%|██████████| 6976/6976 [00:00<00:00, 13396.76it/s]


In [4]:
# Load every track of the mridangam dataset; the result is indexed by track id
mridangam_tracks = msc.dataset.load_tracks()

# Ask the dataset for a random partition of its track ids.
# NOTE: (0.9, 0.1) means two splits holding 90% and 10% of the whole dataset,
# named "train" and "validation" respectively.
split_dict = msc.dataset.get_random_track_splits(
    split_names=("train", "validation"),
    splits=(0.9, 0.1),
)

# Build one {track_id: track} dictionary per split
train_split = {
    track_id: mridangam_tracks[track_id] for track_id in split_dict["train"]
}
evaluation_split = {
    track_id: mridangam_tracks[track_id] for track_id in split_dict["validation"]
}

# Restrict the classifier to the training portion only
msc.mridangam_tracks = train_split
msc.mridangam_ids = list(train_split)


In [5]:
# Train the classifier and keep the value it returns.
# NOTE(review): presumably train_model() works on the msc.mridangam_tracks /
# msc.mridangam_ids assigned in the previous cell — confirm against compiam.
svm_accuracy = msc.train_model()


100%|██████████| 10/10 [00:34<00:00,  3.46s/it]


SVM model successfully trained with accuracy 91% in the testing set


In [6]:
% % capture
# Get paths from created evaluation split
eval_paths = [evaluation_split[x].audio_path for x in list(evaluation_split.keys())]

# Compute prediction from list of paths
predictions = msc.predict(eval_paths)

In [15]:
import json

In [17]:
# Score the predictions against the ground-truth stroke of each evaluation track.
# Counts are keyed by the *predicted* stroke, so the per-stroke "accuracy" below
# is the percentage of predictions of that stroke that turned out to be right.
predictions_stats = {x: {"correct": 0, "wrong": 0} for x in msc.list_strokes()}

true_count = 0
for audio_file, prediction in predictions.items():
    # The track id is the part of the file name that precedes the first "__"
    identifier = os.path.basename(audio_file).split("__")[0]
    if evaluation_split[identifier].stroke_name == prediction:
        predictions_stats[prediction]["correct"] += 1
        true_count += 1
    else:
        predictions_stats[prediction]["wrong"] += 1

for stat in predictions_stats.values():
    predicted_total = stat["correct"] + stat["wrong"]
    # A stroke that was never predicted has no defined percentage; report it as
    # None (JSON null) instead of failing with ZeroDivisionError.
    stat["accuracy"] = (
        stat["correct"] * 100 / predicted_total if predicted_total else None
    )

print("total:", len(predictions))
print("correct", true_count)

print(json.dumps(predictions_stats, indent=2))

total: 697
correct 546
{
  "bheem": {
    "correct": 3,
    "wrong": 0,
    "accuracy": 100.0
  },
  "cha": {
    "correct": 10,
    "wrong": 4,
    "accuracy": 71.42857142857143
  },
  "dheem": {
    "correct": 22,
    "wrong": 1,
    "accuracy": 95.65217391304348
  },
  "dhin": {
    "correct": 34,
    "wrong": 3,
    "accuracy": 91.89189189189189
  },
  "num": {
    "correct": 48,
    "wrong": 11,
    "accuracy": 81.35593220338983
  },
  "ta": {
    "correct": 82,
    "wrong": 28,
    "accuracy": 74.54545454545455
  },
  "tha": {
    "correct": 104,
    "wrong": 67,
    "accuracy": 60.8187134502924
  },
  "tham": {
    "correct": 9,
    "wrong": 1,
    "accuracy": 90.0
  },
  "thi": {
    "correct": 191,
    "wrong": 31,
    "accuracy": 86.03603603603604
  },
  "thom": {
    "correct": 43,
    "wrong": 5,
    "accuracy": 89.58333333333333
  }
}
