# ES-KMeans

This notebook is an exercise on ES-KMeans, the embedded segmental K-means model proposed in:
> H. Kamper, K. Livescu, and S. J. Goldwater,
"An embedded segmental K-means model for unsupervised segmentation and clustering of speech,"
in *Proc. ASRU*, 2017.

We have copied and modified some of the code available at
https://github.com/kamperh/eskmeans,
which is released under GNU General Public License version 3.

Shinozaki Lab Tokyo Tech  
http://www.ts.ip.titech.ac.jp/  
2021

## Table of contents
1. Get landmarks
1. Get segmentation list from landmarks
1. MFCC extraction
1. Embedding MFCC into a fixed length vector
1. ES-KMeans
1. Wave file segmentation
1. Visualize segmentation

## Install ES-KMeans

In [None]:
%cd ../tools
!git clone https://github.com/kamperh/eskmeans.git
%cd ..

## Download data

In [None]:
!wget https://tslab2.ip.titech.ac.jp/spolacq/data.zip
!wget https://tslab2.ip.titech.ac.jp/spolacq/exp1.zip
!wget https://tslab2.ip.titech.ac.jp/spolacq/fig.zip
!unzip -q data.zip
!unzip -q exp1.zip
!unzip -q fig.zip

In [None]:
import os
from pathlib import Path
import pickle
import random
import sys

import librosa
import matplotlib.pyplot as plt
import numpy as np
from pydub import AudioSegment
import scipy.signal as signal

sys.path.append("tools/eskmeans")

from tools.eskmeans.eskmeans.kmeans import KMeans
from tools.eskmeans.eskmeans.eskmeans_wordseg import ESKmeans
from utils.eskmeans_api import save

## Data path

In [None]:
class Args:
    """File paths shared by the steps of this notebook.

    Parameters
    ----------
    seg_dir : str, optional
        Directory that holds every segmentation input and output file.
        Defaults to ``"exp1/seg"``, so ``Args()`` yields the same paths as
        before.
    """

    def __init__(self, seg_dir="exp1/seg"):
        # Get landmarks
        self.sylseg = seg_dir + "/sylseg.csv"
        self.landmarks = seg_dir + "/landmarks_syllable_seg.pkl"
        # Get segmentation list from landmarks
        self.seglist = seg_dir + "/seglist.pkl"
        # MFCC extraction
        self.mfcc = seg_dir + "/mfccs.pkl"
        # MFCC downsampling
        self.embed = seg_dir + "/embed_dur_dic.pkl"
        # ES-KMeans
        self.txt_path = seg_dir + "/eskseg_result.txt"
        self.pkl_path = seg_dir + "/eskseg.pkl"
        # Wave file segmentation (reads the pickle written by ES-KMeans)
        self.wordseg = self.pkl_path
        self.wavseg_dir = seg_dir + "/segmented_wavs"


args = Args()

## Get landmarks

Landmarks are the frame indices (one frame = 0.01 seconds) of the syllable boundaries.
The syllable boundaries were obtained beforehand from the combined audio using MATLAB; here we convert them from seconds to frame indices.

In [None]:
def get_landmarks(sylseg_path=None):
    """Convert syllable boundaries in seconds into landmark frame indices.

    Each line of the input file holds the whitespace-separated boundary
    times (in seconds) of one wav file. Every time is converted to a frame
    index at 100 frames per second (0.01 s per frame), and the first
    boundary of each line is dropped.

    Parameters
    ----------
    sylseg_path : str, optional
        Path of the syllable-boundary file. Defaults to ``args.sylseg``.

    Returns
    -------
    list of list of int
        ``landmarks[m]`` holds the landmark frames of the m-th line.
    """
    if sylseg_path is None:
        sylseg_path = args.sylseg

    landmarks = []
    with open(sylseg_path) as f:
        for line in f:
            frames = [int(round(float(bound)*100.0)) for bound in line.split()]
            # The first boundary is not a landmark -- presumably it is the
            # start of the audio, which the segmentation list already covers
            # with an implicit start at frame 0 (TODO confirm).
            landmarks.append(frames[1:])

    # Saving is disabled in this exercise.
    # with open(args.landmarks, "wb") as f:
    #     pickle.dump(landmarks, f, -1)
    # print("landmarks saved to: " + args.landmarks)

    return landmarks

get_landmarks()

## Get segmentation list from landmarks

The segmentation list is the list of candidate word segments.
Each candidate spans at most four consecutive syllables (`N_LANDMARKS_MAX = 4`).

In [None]:
# Maximum number of consecutive landmarks (syllables) one candidate may span.
N_LANDMARKS_MAX = 4


def get_seglist_from_landmarks(landmarks=None, n_landmarks_max=N_LANDMARKS_MAX):
    """Enumerate the candidate word segments of every wav file.

    For each wav file the candidates start at frame 0 or at a landmark and
    end at one of the next ``n_landmarks_max`` landmarks. They are listed
    start by start, with the end landmarks in increasing order.

    Parameters
    ----------
    landmarks : list of list of int, optional
        Landmark frames per wav file. When omitted they are loaded from the
        pickle file ``args.landmarks``.
    n_landmarks_max : int, optional
        Maximum number of landmarks spanned by one candidate.

    Returns
    -------
    list of list of tuple
        ``seglist[m]`` holds the ``(start_frame, end_frame)`` candidates of
        the m-th wav file.
    """
    if landmarks is None:
        # NOTE: pickle can run arbitrary code; only load files you trust.
        with open(args.landmarks, "rb") as f:
            landmarks = pickle.load(f)

    seglist = []
    for landmarks_per_wav in landmarks:
        seglist_per_wav = []
        prev_landmark = 0  # the first candidates start at the beginning
        for i, landmark in enumerate(landmarks_per_wav):
            for end in landmarks_per_wav[i:i + n_landmarks_max]:
                seglist_per_wav.append((prev_landmark, end))
            prev_landmark = landmark
        seglist.append(seglist_per_wav)

    # Saving is disabled in this exercise.
    # with open(args.seglist, "wb") as f:
    #     pickle.dump(seglist, f, -1)

    return seglist


get_seglist_from_landmarks()

## MFCC extraction

We extract the MFCC from the combined audio.

In [None]:
def extract_mfcc(wav_paths=None):
    """Extract 20-dimensional MFCCs from each wav file.

    The audio is loaded at 44100 Hz and analysed with a hop of 441 samples,
    i.e. one MFCC frame every 0.01 s, the same frame interval as the
    landmarks.

    Parameters
    ----------
    wav_paths : list of str, optional
        Paths of the wav files. Defaults to the combined audio of this
        exercise.

    Returns
    -------
    list
        One transposed MFCC matrix per wav file. The downstream code slices
        it by frame along axis 0.
    """
    if wav_paths is None:
        wav_paths = ["data/combined_sounds_8foods_interval_spolacq1.wav"]

    mfccs = []
    for line in wav_paths:
        wav_path = line.rstrip()
        x, sr = librosa.load(wav_path, sr=44100)
        # Pass the signal as ``y=``: librosa 0.10 and later accept the audio
        # only as a keyword argument, so the positional form raises TypeError.
        mfccs_per_wav = librosa.feature.mfcc(y=x, sr=sr, hop_length=441, n_mfcc=20).T
        print("for wav file " + wav_path + ", mfcc shape:")
        print(mfccs_per_wav.shape)
        mfccs.append(mfccs_per_wav)

    # Saving is disabled in this exercise.
    # with open(args.mfcc, "wb") as handle:
    #     pickle.dump(mfccs, handle)

    return mfccs


# Commented out because it takes a long time to run.
# extract_mfcc()

## Embedding MFCC into a fixed length vector

To run K-means clustering over the candidate segments, we embed the MFCCs of each segment into a fixed-length vector by resampling them to 10 frames along the time axis.
In this case, the length is $\mathrm{n\_mfcc}\times10=20\times10=200$.
The duration of each segment (its number of frames) is used as a weight in the objective function.

In [None]:
def mfcc_downsampling(mfccs=None, seglist=None, n_samples=10):
    """Embed every candidate segment into a fixed-length vector.

    For a segment ``(i, j)`` the MFCC frames ``i`` to ``j`` (inclusive) are
    resampled along the time axis to ``n_samples`` frames with
    ``scipy.signal.resample`` and flattened in row-major order, so the
    vector has ``n_mfcc * n_samples`` entries grouped by coefficient.

    Parameters
    ----------
    mfccs : list of array, optional
        One ``(n_frames, n_mfcc)`` matrix per wav file. When omitted it is
        loaded from the pickle file ``args.mfcc``.
    seglist : list of list of tuple, optional
        ``(start_frame, end_frame)`` candidates per wav file. When omitted
        it is loaded from the pickle file ``args.seglist``.
    n_samples : int, optional
        Number of frames each segment is resampled to.

    Returns
    -------
    dict
        ``"embeddings"`` and ``"durations"``, each a list with one list per
        wav file, in the order of ``seglist``. A duration is the number of
        frames ``j + 1 - i`` of the segment.
    """
    # NOTE: pickle can run arbitrary code; only load files you trust.
    if mfccs is None:
        with open(args.mfcc, "rb") as handle:
            mfccs = pickle.load(handle)
    if seglist is None:
        with open(args.seglist, "rb") as handle:
            seglist = pickle.load(handle)

    embeddings = []
    durations = []
    for m, mfccs_per_wav in enumerate(mfccs):
        embeddings_per_wav = []
        durations_per_wav = []
        for i, j in seglist[m]:
            # (n_mfcc, n_frames) so that axis 1 is time.
            y = mfccs_per_wav[i:j+1, :].T
            y_new = signal.resample(y, n_samples, axis=1).flatten("C")
            embeddings_per_wav.append(y_new)
            durations_per_wav.append(j + 1 - i)
        embeddings.append(embeddings_per_wav)
        durations.append(durations_per_wav)

    embed_dur_dic = {"embeddings": embeddings, "durations": durations}

    # Saving is disabled in this exercise.
    # with open(args.embed, "wb") as f:
    #     pickle.dump(embed_dur_dic, f, -1)

    return embed_dur_dic


# Commented out because it takes a long time to run.
# mfcc_downsampling()

## ES-KMeans

We perform word segmentation using ES-KMeans, which alternates between segmentation and clustering for many iterations.

In [None]:
def _index_candidate_segments(n_slices, n_slices_max, durations_raw):
    """Build the per-span lookup arrays of one wav file.

    A span from slice ``start`` to slice ``end`` (exclusive, ``end > start``)
    is stored at position ``end*(end - 1)//2 + start``. Spans longer than
    ``n_slices_max`` slices have no embedding and keep the value -1.

    Returns the pair ``(vec_ids, durations)``: the embedding index and the
    duration of every span.
    """
    n_spans = (n_slices**2 + n_slices)//2
    vec_ids = np.full(n_spans, -1, dtype=int)
    durations = np.full(n_spans, -1, dtype=int)

    # Embeddings are consumed start by start, with increasing end; this is
    # the order in which the candidate segments were enumerated.
    i_embed = 0
    for cur_start in range(n_slices):
        last_end = min(n_slices, cur_start + n_slices_max)
        for cur_end in range(cur_start + 1, last_end + 1):
            i = cur_end*(cur_end - 1)//2 + cur_start
            vec_ids[i] = i_embed
            durations[i] = durations_raw[i_embed]
            i_embed += 1
    return vec_ids, durations


def eskmeans(n_iter=10, K_max=2, n_slices_max=4, p_boundary_init=1.0):
    """Run ES-KMeans word segmentation on the embedded candidate segments.

    The landmarks are loaded from ``args.landmarks`` and the embeddings and
    durations from ``args.embed``. Each wav file is registered under the key
    ``str(m)``, where ``m`` is its index.

    Parameters
    ----------
    n_iter : int, optional
        Number of iterations passed to ``ESKmeans.segment``. The original
        notebook has ``n_iter = 100`` commented out -- presumably 10 keeps
        the exercise short (TODO confirm).
    K_max : int, optional
        First argument of ``ESKmeans`` -- presumably the maximum number of
        clusters (verify against the eskmeans package).
    n_slices_max : int, optional
        Maximum number of landmarks one segment may span. It must equal the
        limit used when the segment list and embeddings were built, because
        the embeddings are indexed in that enumeration order.
    p_boundary_init : float, optional
        Forwarded unchanged to ``ESKmeans``.

    Returns
    -------
    tuple
        ``(segmenter, record)``: the fitted ``ESKmeans`` object and the
        value returned by ``segmenter.segment``.
    """
    # NOTE: pickle can run arbitrary code; only load files you trust.
    with open(args.landmarks, "rb") as handle:
        landmarks = pickle.load(handle)
    with open(args.embed, "rb") as handle:
        embed_dur_dic = pickle.load(handle)

    embeddings = embed_dur_dic["embeddings"]
    durations_raw = embed_dur_dic["durations"]

    # Dictionaries keyed by wav file, as passed to ESKmeans.
    embedding_mats = {}
    vec_ids_dict = {}
    durations_dict = {}
    landmarks_dict = {}
    for m, landmarks_per_wav in enumerate(landmarks):
        # n_slices == number of landmarks of this wav file
        vec_ids, durations = _index_candidate_segments(
            len(landmarks_per_wav), n_slices_max, durations_raw[m]
        )
        utt_key = str(m)
        embedding_mats[utt_key] = embeddings[m]
        vec_ids_dict[utt_key] = vec_ids
        durations_dict[utt_key] = durations
        landmarks_dict[utt_key] = landmarks_per_wav

    # Initialize model
    segmenter = ESKmeans(
        K_max, embedding_mats, vec_ids_dict, durations_dict, landmarks_dict,
        p_boundary_init=p_boundary_init, n_slices_max=n_slices_max
        )

    # Perform inference
    record = segmenter.segment(n_iter=n_iter)

    # Saving the assignment results is disabled in this exercise.
    # save(segmenter.acoustic_model, None, args.txt_path, args.pkl_path)

    return segmenter, record


eskmeans()

## Wave file segmentation

We segment the combined audio using the result of ES-KMeans.

In [None]:
# (wav index, frame index) pairs covered by the most recent segmentation.
val = []


def wav_segmentation(seglist=None, result_dict=None, wavs=None):
    """Cut the audio into the word segments selected by ES-KMeans.

    Every index in ``result_dict`` refers to the candidate segments of all
    wav files concatenated in order. The frames covered by the selected
    segments are collected in the module-level list ``val``. If any frame is
    covered twice, an error is printed to stderr and the function exits with
    status 1.

    Parameters
    ----------
    seglist : list of list of tuple, optional
        ``(start_frame, end_frame)`` candidates per wav file. When omitted
        it is loaded from the pickle file ``args.seglist``.
    result_dict : dict, optional
        Maps each key to the segment indices assigned to it. When omitted
        it is loaded from the pickle file ``args.wordseg``.
    wavs : list, optional
        One audio object per wav file, sliceable in milliseconds. When
        omitted the combined audio of this exercise is loaded.

    Raises
    ------
    SystemExit
        If two selected segments overlap.
    """
    if wavs is None:
        wavs = [AudioSegment.from_wav("data/combined_sounds_8foods_interval_spolacq1.wav")]
    # NOTE: pickle can run arbitrary code; only load files you trust.
    if seglist is None:
        with open(args.seglist, "rb") as handle:
            seglist = pickle.load(handle)
    if result_dict is None:
        with open(args.wordseg, "rb") as handle:
            result_dict = pickle.load(handle)

    # Flatten to (wav index, start frame, end frame) so that a global segment
    # index can be looked up directly.
    seglist_a = [(i, ss[0], ss[1]) for i, s in enumerate(seglist) for ss in s]

    # Start from an empty list on every call; frames left over from an
    # earlier call would otherwise be reported as overlaps.
    val.clear()

    for key, seg_indices in result_dict.items():
        for seg_index in seg_indices:
            wavi, start_frame, end_frame = seglist_a[seg_index]
            # start and end time in milliseconds (one frame = 10 ms)
            t_start = start_frame * 10
            t_end = end_frame * 10

            for i in range(t_start//10, t_end//10):
                val.append((wavi, i))

            wav_segment = wavs[wavi][t_start:t_end]
            # Exporting is disabled in this exercise.
            # wav_segment.export(args.wavseg_dir + "/" + str(key) + "_" + str(seg_index) + ".wav", format="wav")

    if len(set(val)) != len(val):  # at least one frame is covered twice
        print("Error: There is an overlap between segmented wavs. The segmentation is FAILED.", file=sys.stderr)
        # sys.exit instead of the interactive-only ``exit`` helper.
        sys.exit(1)


wav_segmentation()

## Visualize segmentation

In [None]:
# Plot the beginning of the combined audio with the syllable boundaries
# (black, upper half) and the ES-KMeans word boundaries (red, lower half).
wav_path = "data/combined_sounds_8foods_interval_spolacq1.wav"
wav, sr = librosa.load(wav_path)

# Only the first seconds are shown.
truncate_second = 11
truncated_wav = wav[:sr*truncate_second]

plt.figure(figsize=(16, 4))
plt.plot(np.arange(len(truncated_wav))/sr, truncated_wav)
plt.xlim([0, truncate_second])
plt.xlabel("time [s]")

# Syllable boundaries, in seconds.
with open(args.sylseg) as f:
    line = f.read()
sylseg = [float(landmark) for landmark in line.strip().split()]

# Candidate segments and the ES-KMeans result.
with open(args.seglist, "rb") as handle:
    seglist = pickle.load(handle)
with open(args.wordseg, "rb") as handle:
    result_dict = pickle.load(handle)

# Flatten to (wav index, start frame, end frame) so that the segment indices
# in result_dict can be looked up directly.
seglist_a = [(i, ss[0], ss[1]) for i, s in enumerate(seglist) for ss in s]

# Word boundaries, in seconds: frame index * 10 ms / 1000. Start and end of
# each selected segment are listed one after the other.
wordseg = [
    seglist_a[seg_index][pos] * 10 / 1000
    for key in result_dict.keys()
    for seg_index in result_dict[key]
    for pos in (1, 2)
]

plt.vlines(sylseg, color="black", ymin=0, ymax=1, label="Syllable segmentation")
plt.vlines(wordseg, color="red", ymin=-1, ymax=0, label="Word segmentation")
plt.legend(fontsize=16)