(stroke-classification)=
# Stroke classification
Percussion stroke classification has historically been the main timbre-oriented task in the computational analysis of Indian Art Music. As seen in the [instrumentation presentation](instrumentation), the musical arrangement in Indian Art Music is quite well defined, while good-quality data (especially monophonic recordings) is scarce for many of the instruments specific to Carnatic and Hindustani music. These factors, combined with the importance of the different stroke types in the main percussion instruments, have drawn research attention to mridangam {cite}`anantapadmanabhan_mridangam_2013, mridangam_stroke` and tabla stroke classification {cite}`4way_tabla`.

We will go through examples of these tasks in this walkthrough.


```python
# Import extras and suppress warnings to keep the tutorial clean
import os
import random
import warnings
from pprint import pprint

# Importing compiam into the project
import compiam

warnings.filterwarnings('ignore')
```

## Mridangam stroke classification

```python
from compiam.timbre.stroke_classification import MridangamStrokeClassification

msc = MridangamStrokeClassification()  # Let's use msc for simplicity
```

```
[   INFO   ] MusicExtractorSVM: no classifier models were configured by default
```


Let's start by loading the Mridangam Stroke Dataset. Since ``MridangamStrokeClassification`` is based on this dataset, `compiam` includes a specific function to load it and integrate it into the pipeline.

```python
msc.load_mridangam_dataset(data_home="../audio/mir_dataset/")
```

```
100%|██████████| 6976/6976 [00:00<00:00, 10035.84it/s]
```


```{note}
This function does not return a dataloader. Instead, the dataloader lives within the tool class, as we will see in the following steps of this walkthrough. You may check out the [MridangamStrokeClassification documentation](https://mtg.github.io/compIAM/source/timbre.html#compiam.timbre.stroke_classification.mridangam_stroke_classification.MridangamStrokeClassification) to learn how the tool class takes advantage of the dataloader.
```

```python
# Print the list of available mridangam strokes in the dataset
msc.list_strokes()
```

```
['bheem', 'cha', 'dheem', 'dhin', 'num', 'ta', 'tha', 'tham', 'thi', 'thom']
```

Let's train and evaluate a very basic model to classify mridangam strokes. We first use a utility method of the `mirdata` Dataset class to divide the Mridangam Stroke Dataset into splits. Since this dataset has no predefined splits, we use [``get_random_track_splits``](https://mirdata.readthedocs.io/en/latest/source/mirdata.html#mirdata.core.Dataset.get_random_track_splits) to create them randomly.
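
Conceptually, a random split shuffles the track IDs and cuts the shuffled list into chunks of the requested proportions. The following is a minimal sketch of that idea in plain Python; `random_split` is a hypothetical helper written for illustration (it is not the `mirdata` implementation), and the track IDs are made up.

```python
import random


def random_split(track_ids, splits=(0.9, 0.1), names=("train", "validation"), seed=None):
    """Shuffle track IDs and cut them into consecutive chunks of the given proportions.

    Hypothetical helper for illustration only; not part of mirdata or compiam.
    """
    if len(splits) != len(names):
        raise ValueError("splits and names must have the same length")
    if abs(sum(splits) - 1.0) > 1e-9:
        raise ValueError("split proportions must sum to 1")

    ids = list(track_ids)
    random.Random(seed).shuffle(ids)  # local RNG, so the global random state is untouched

    result = {}
    start = 0
    for i, (name, fraction) in enumerate(zip(names, splits)):
        # The last split takes whatever is left, so rounding never drops a track
        end = len(ids) if i == len(splits) - 1 else start + round(fraction * len(ids))
        result[name] = ids[start:end]
        start = end
    return result


toy_ids = [str(225000 + i) for i in range(20)]  # made-up track IDs
parts = random_split(toy_ids, seed=0)
print({name: len(ids) for name, ids in parts.items()})  # {'train': 18, 'validation': 2}
```

Fixing `seed` makes the split reproducible across runs, which matters when comparing models trained with different settings.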

```python
# Loading tracks for the mridangam dataset
mridangam_tracks = msc.dataset.load_tracks()

# Getting the list of IDs per split
# NOTE: We use (0.9, 0.1): two splits containing 90% and 10% of the whole dataset
split_dict = msc.dataset.get_random_track_splits(
    splits=(0.9, 0.1),
    split_names=("train", "validation")
)

# Get track dictionaries for the created splits
train_split = {x: mridangam_tracks[x] for x in split_dict["train"]}
evaluation_split = {y: mridangam_tracks[y] for y in split_dict["validation"]}

# Let's get a random track from the created evaluation split
random.choice(list(evaluation_split.items()))
```

```
('225104',
 Track(
   audio_path="../audio/mir_dataset/mridangam_stroke_1.5/B/225104__akshaylaya__thi-b-323.wav",
   stroke_name="thi",
   tonic="B",
   track_id="225104",
   audio: The track's audio

         Returns,
 ))
```

By default, the class uses the entire loaded dataset for training, so we update the dataset in the class with the training split: both the track dictionary and the list of track IDs.

```python
msc.mridangam_tracks = train_split
msc.mridangam_ids = list(train_split.keys())
```
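
Overwriting these attributes relies on the tool class reading its stored tracks and IDs when training is called, rather than copying them at load time. A minimal sketch of that pattern, assuming this behaviour; `ToyStrokeClassifier` is a hypothetical class that only counts the tracks it would train on and is not the `compiam` implementation.

```python
class ToyStrokeClassifier:
    """Hypothetical stand-in for a tool class that keeps its own dataloader state."""

    def __init__(self, tracks):
        # Track dictionary (ID -> annotation) stored on the instance at load time
        self.tracks = dict(tracks)
        self.ids = list(self.tracks.keys())

    def train_model(self):
        # Training reads whatever is currently stored on the instance
        used = [self.tracks[track_id] for track_id in self.ids]
        return len(used)  # placeholder: report how many tracks would be used


toy_tracks = {"1": "thi", "2": "tha", "3": "num", "4": "dhin"}  # made-up annotations
tool = ToyStrokeClassifier(toy_tracks)
print(tool.train_model())  # 4

# Restrict training to a subset by overwriting both attributes
toy_train_split = {k: toy_tracks[k] for k in ("1", "2", "3")}
tool.tracks = toy_train_split
tool.ids = list(toy_train_split.keys())
print(tool.train_model())  # 3
```

In this sketch, replacing only `tracks` while leaving the old `ids` list in place would raise a `KeyError`, which is why both attributes are overwritten together.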

**Let's now train the model!** We will train a Support Vector Machine (SVM) model using `scikit-learn`. The mridangam stroke classification tool in `compiam` uses [Essentia's MusicExtractor](https://essentia.upf.edu/streaming_extractor_music.html) to compute low-level features from the stroke recordings, which are then fed to the model.

```{note}
You can also train a different model and compare performance. Other options are available (see [the documentation of the tool](https://mtg.github.io/compIAM/source/timbre.html#mridangam-stroke-classification)), and you are welcome to open a Pull Request in `compiam` to add more models.
```
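
At its core, the training step maps each recording to a fixed-length feature vector and fits a classifier on (vector, stroke) pairs. To show that shape without extra dependencies, here is a minimal sketch that deliberately swaps the SVM for a much simpler nearest-centroid classifier. The two-dimensional feature values are made up, whereas the tool described above feeds Essentia descriptors to an SVM from `scikit-learn`.

```python
import math
from collections import defaultdict


def fit_centroids(vectors, labels):
    """Return {label: mean feature vector} from equal-length feature vectors."""
    grouped = defaultdict(list)
    for vector, label in zip(vectors, labels):
        grouped[label].append(vector)
    # Average each feature dimension within every stroke class
    return {
        label: [sum(column) / len(rows) for column in zip(*rows)]
        for label, rows in grouped.items()
    }


def predict(centroids, vector):
    """Assign the label whose centroid is closest in Euclidean distance."""
    return min(centroids, key=lambda label: math.dist(centroids[label], vector))


# Made-up 2-D "features"; real descriptor vectors have many more dimensions
# and are usually standardised before training.
train_vectors = [[0.1, 1.0], [0.2, 0.9], [0.9, 0.1], [1.0, 0.2]]
train_labels = ["thi", "thi", "num", "num"]

centroids = fit_centroids(train_vectors, train_labels)
print(predict(centroids, [0.15, 0.95]))  # thi
print(predict(centroids, [0.95, 0.15]))  # num
```

An SVM instead learns maximum-margin decision boundaries, which usually copes better with overlapping classes than comparing each recording against a single mean per stroke.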


```python
svm_accuracy = msc.train_model(model_type="svm")
```

```
100%|██████████| 10/10 [00:44<00:00,  4.46s/it]

SVM model successfully trained with accuracy 91% in the testing set
```


**The model has been trained!** The method also returns the testing accuracy, so we can store it, re-train the model with different settings, and compare the results.

Now we can predict the strokes for a given list of recordings. First, we need the list of audio paths for the evaluation split we created a few steps earlier.

```python
# Get the audio paths of the created evaluation split
eval_paths = [track.audio_path for track in evaluation_split.values()]

# Compute predictions for the whole list of paths
prediction = msc.predict(eval_paths)
```

```python
# Predict the stroke of a single recording from the evaluation split
prediction_1 = msc.predict(eval_paths[30])
print(prediction_1)
```

```
   mfcc0.mean  mfcc0.dev  mfcc1.mean  mfcc1.dev  mfcc2.mean  mfcc2.dev  \
0 -737.903198  61.411499   151.83432   7.143669  -34.474564    4.08334

   mfcc3.mean  mfcc3.dev  mfcc4.mean  mfcc4.dev  ...  \
0    9.920733   2.599565    0.318367    0.43185  ...

   lowLevel.spectral_strongpeak.mean  lowLevel.spectral_strongpeak.dev  \
0                           0.088152                          0.061141

   lowLevel.zerocrossingrate.mean  lowLevel.zerocrossingrate.dev  \
0                        0.017578                            0.0

   tonal.tuning_frequency.mean  tonal.tuning_frequency.dev  \
0                   439.491974                         0.0

   lowLevel.barkbands.mean  lowLevel.barkbands.dev  lowLevel.scvalleys.mean  \
0                 0.000207                0.000568                -7.096732

   lowLevel.scvalleys.dev
0                1.428743

[1 rows x 90 columns]
{'../audio/mir_dataset/mridangam_stroke_1.5/D#/229769__akshaylaya__thi-dsh-206.wav': 'thi
```

```python
# Visualise a random prediction from the model output
pprint(random.choice(list(prediction.items())))
```

```
('../audio/mir_dataset/mridangam_stroke_1.5/D#/228925__akshaylaya__dhin-dsh-126.wav',
 'dhin')
```


The file paths of the validation recordings already include the name of the stroke, so we can check how well our model classified the mridangam strokes. Alternatively, we can get the ground-truth annotations (stroke name and tonic) from the `mirdata` loader using a track ID.
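
Both the track ID and the stroke name can be read from the file name. A minimal sketch, assuming the naming pattern visible in the paths printed in this walkthrough (`<id>__<uploader>__<stroke>-<tonic>-<number>.wav`); `parse_stroke_filename` is a hypothetical helper, not part of `compiam` or `mirdata`.

```python
import os


def parse_stroke_filename(path):
    """Split a stroke file name into (track_id, stroke_name).

    Assumes names like "225104__akshaylaya__thi-b-323.wav":
    ID, then uploader, then "<stroke>-<tonic>-<number>".
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    track_id, _uploader, description = stem.split("__")
    stroke_name = description.split("-")[0]
    return track_id, stroke_name


print(parse_stroke_filename(
    "../audio/mir_dataset/mridangam_stroke_1.5/B/225104__akshaylaya__thi-b-323.wav"
))  # ('225104', 'thi')
```

When the loader is available, `evaluation_split[track_id].stroke_name` is the more robust source of ground truth, since it does not depend on the naming convention.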

```python
msc.dataset.choice_track()
```

```
Track(
  audio_path="../../../audio/mir_dataset/mridangam_stroke_1.5/C/226097__akshaylaya__thi-c-026.wav",
  stroke_name="thi",
  tonic="C",
  track_id="226097",
  audio: The track's audio

        Returns,
)
```

Note that the track ID is taken directly from the file name of each stroke recording (the prefix before the first `__`). Let's use it to compare the predicted and ground-truth strokes for a random prediction, and then count the correct predictions over the whole evaluation split.

```python
# Selecting a random example from the predicted files
predicted_file, predicted_stroke = random.choice(list(prediction.items()))

# Getting the ID from the file path (prefix before the first "__")
identifier = os.path.basename(predicted_file).split("__")[0]
ground_truth = evaluation_split[identifier].stroke_name

# Comparing target and estimation
if ground_truth == predicted_stroke:
    print("Nice! Predicted stroke in {}\n coincides with ground-truth {}".format(
        os.path.basename(predicted_file), ground_truth
    ))
else:
    print("Missed! Predicted stroke in {}\n does NOT coincide with ground-truth {}".format(
        os.path.basename(predicted_file), ground_truth
    ))

# Counting the correct predictions over the whole evaluation split
true_count = sum(
    predicted == evaluation_split[os.path.basename(path).split("__")[0]].stroke_name
    for path, predicted in prediction.items()
)
print("total:", len(prediction))
print("true:", true_count)
```
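
The overall count hides which strokes get confused most often. A minimal sketch of a per-stroke breakdown follows; `per_stroke_accuracy` is a hypothetical helper and the toy dictionaries are made up.

```python
from collections import Counter


def per_stroke_accuracy(predictions, ground_truth):
    """Return {stroke: fraction of its recordings predicted correctly}.

    predictions: {track_id: predicted stroke}
    ground_truth: {track_id: true stroke}
    Only IDs present in both dictionaries are scored.
    """
    totals, correct = Counter(), Counter()
    for track_id, predicted in predictions.items():
        if track_id not in ground_truth:
            continue
        true_stroke = ground_truth[track_id]
        totals[true_stroke] += 1
        if predicted == true_stroke:
            correct[true_stroke] += 1
    return {stroke: correct[stroke] / totals[stroke] for stroke in sorted(totals)}


toy_pred = {"1": "thi", "2": "tha", "3": "num", "4": "thi"}
toy_true = {"1": "thi", "2": "thi", "3": "num", "4": "thi"}
print(per_stroke_accuracy(toy_pred, toy_true))  # {'num': 1.0, 'thi': 0.6666666666666666}
```

With the notebook variables, the inputs could be built as `{os.path.basename(p).split("__")[0]: s for p, s in prediction.items()}` and `{i: t.stroke_name for i, t in evaluation_split.items()}`.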

```python
def absoluteFilePaths(directory):
    """Yield the absolute path of every file found under `directory` (recursively)."""
    for dirpath, _, filenames in os.walk(directory):
        for f in filenames:
            yield os.path.abspath(os.path.join(dirpath, f))


samples = list(absoluteFilePaths("../../../Mridangam_Sir/june6_210"))

print(samples)
```

```
['/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/bheem1.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/tha2.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/tha_thi_thom_num_2.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/num2.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/tha_thi_thom_num_3.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/chappu3.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/thi4.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/bheem1-2.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/tha_thi_thom_num_2_2.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/num1.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/ki3.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/ki1.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/bheem2.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/dheem1.wav', '/home/vaichu/projects/IITM/Mridangam_Sir/june6_210/bheem3.wav
```

```python
import importlib
import json

from compiam.timbre.stroke_classification.mridangam_stroke_classification.stroke_features import (
    features_for_pred,
)
```

```python
# Reload the feature-extraction module (useful while editing it locally)
importlib.reload(compiam.timbre.stroke_classification.mridangam_stroke_classification.stroke_features)

# Predict the strokes of the local recordings and print them keyed by file name
pre = msc.predict(samples)
print(json.dumps({os.path.basename(x): v for (x, v) in pre.items()}, indent=4, sort_keys=True))
```

```
{
    "tha2.wav": "thi"
}
```
