In [1]:
import csv
import random
from collections import Counter

import numpy as np
import pandas as pd

In [2]:
def dataframe_to_array(dataset):
    """Convert a labelled DataFrame into a list of ``[fname, labels]`` pairs.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must provide a ``fname`` column and a ``labels`` column whose values
        are comma-separated label strings.

    Returns
    -------
    list
        One ``[fname, labels]`` entry per row, in row order, where ``labels``
        is the list obtained by splitting the row's label string on ``','``.
    """
    # Iterating the two columns directly avoids building a Series per row,
    # which is what ``iterrows()`` does.
    return [
        [fname, labels.split(',')]
        for fname, labels in zip(dataset['fname'], dataset['labels'])
    ]

In [3]:
def get_labels_counter(dataset):
    """Count how many times each label occurs across ``dataset``.

    ``dataset`` is an iterable of two-element ``(fname, labels)`` entries;
    only the ``labels`` part is used. Returns a ``collections.Counter``
    mapping each label to its number of occurrences.
    """
    tally = Counter()
    for sample in dataset:
        _, sample_labels = sample
        tally.update(sample_labels)
    return tally

In [4]:
# Read the curated training metadata and turn it straight into a list of
# [fname, labels] pairs, so `dataset` only ever refers to that list.
dataset = dataframe_to_array(pd.read_csv('../input/train_curated.csv'))

In [5]:
def check_input(k, arr):
    """Validate a fold distribution and return it as a NumPy array.

    Parameters
    ----------
    k : int
        Number of folds.
    arr : sequence of float, numpy.ndarray or None
        Fraction of the data assigned to each fold. ``None`` selects the
        uniform distribution ``1/k`` per fold.

    Returns
    -------
    numpy.ndarray
        The distribution as an array of length ``k``.

    Raises
    ------
    AssertionError
        If the distribution does not have ``k`` entries or does not sum to 1
        (within floating-point tolerance).
    """
    # `is None` rather than `== None`: comparing an ndarray with `==` is
    # element-wise and cannot be used as an `if` condition.
    if arr is None:
        arr = [1.0 / k for _ in range(k)]
    arr = np.array(arr)
    assert len(arr) == k, 'distribution must have exactly k entries'
    # Fractions such as 0.7 + 0.2 + 0.1 do not add up to exactly 1.0 in
    # binary floating point, so compare with a tolerance.
    assert np.isclose(arr.sum(), 1.0), 'distribution must sum to 1'
    return arr

In [6]:
def stratify_multilabel(dataset, k, dist=None):
    """Split a multi-label dataset into ``k`` folds, balancing each label.

    The split is greedy: the label with the fewest remaining samples is
    handled first, and every remaining sample carrying that label is sent to
    the fold that still wants the most samples of that label. Ties go to the
    lowest fold index.

    Parameters
    ----------
    dataset : sequence
        Entries of the form ``[fname, labels]`` where ``labels`` is a
        collection of label names. The input sequence is not modified.
    k : int
        Number of folds.
    dist : sequence of float, optional
        Target fraction of the data for each fold; validated by
        ``check_input`` (uniform when ``None``).

    Returns
    -------
    list of list
        ``k`` lists holding the original entries assigned to each fold.
    """
    remaining = list(dataset)
    labels_counter = get_labels_counter(remaining)
    dist = check_input(k, dist)
    folds = [[] for _ in range(k)]

    # Per fold: how many samples of each label it should still receive.
    # Targets are truncated to integers.
    desired = [
        {label: int(count * dist[j]) for label, count in labels_counter.items()}
        for j in range(k)
    ]
    # Per fold: how many samples it should still receive overall. Only used
    # to place samples that carry no label at all.
    desired_size = [len(remaining) * dist[j] for j in range(k)]

    while remaining:
        labels_left = get_labels_counter(remaining)
        if not labels_left:
            # Only label-less samples remain; there is no rarest label to
            # pick, so fill the folds that are furthest below their size.
            for sample in remaining:
                target = int(np.argmax(desired_size))
                folds[target].append(sample)
                desired_size[target] -= 1
            break

        rarest = min(labels_left, key=labels_left.get)
        kept = []
        for sample in remaining:
            sample_labels = sample[1]
            if rarest not in sample_labels:
                kept.append(sample)
                continue

            target = int(np.argmax([desired[j][rarest] for j in range(k)]))
            # The sample counts against every label it carries, not only
            # the rarest one.
            for label in sample_labels:
                desired[target][label] -= 1
            desired_size[target] -= 1
            folds[target].append(sample)
        remaining = kept
    return folds

In [7]:
# Build three folds with target proportions 0.8 / 0.1 / 0.1.
k = 3
S = stratify_multilabel(dataset, k, dist=[0.8, 0.1, 0.1])

# Report the per-label sample count of every fold.
for fold_index, fold in enumerate(S):
    print(fold_index, ':!!!')
    counter = get_labels_counter(fold)
    for label in counter:
        print(label, ': ', counter[label])
    print()

0 :!!!
Accordion :  38
Gasp :  39
Crowd :  61
Mechanical_fan :  40
Fill_(with_liquid) :  40
Sink_(filling_or_washing) :  61
Water_tap_and_faucet :  61
Bathtub_(filling_or_washing) :  61
Trickle_and_dribble :  43
Drip :  61
Squeak :  61
Dishes_and_pots_and_pans :  61
Gurgling :  61
Stream :  61
Buzz :  45
Cricket :  61
Bark :  61
Chirp_and_tweet :  61
Race_car_and_auto_racing :  45
Accelerating_and_revving_and_vroom :  61
Car_passing_by :  61
Motorcycle :  61
Glockenspiel :  45
Sigh :  46
Frying_(food) :  51
Crackle :  61
Traffic_noise_and_roadway_noise :  61
Sneeze :  51
Cutlery_and_silverware :  61
Raindrop :  61
Toilet_flush :  61
Chink_and_clink :  61
Purr :  53
Meow :  61
Clapping :  61
Hiss :  61
Bicycle_bell :  54
Cheering :  61
Applause :  61
Yell :  61
Bus :  61
Screaming :  61
Child_speech_and_kid_speaking :  61
Harmonica :  61
Tap :  61
Knock :  61
Cupboard_open_or_close :  61
Drawer_open_or_close :  61
Shatter :  61
Bass_drum :  61
Acoustic_guitar :  60
Fart :  61
Church_bel

In [8]:
# Shuffle every fold in place. A dedicated, seeded generator keeps the
# exported splits identical on each fresh Restart & Run All without touching
# the global `random` state.
SHUFFLE_SEED = 42
shuffle_rng = random.Random(SHUFFLE_SEED)
for s in S:
    shuffle_rng.shuffle(s)

# Output file for each fold, in fold order.
file_names = ['train_curated_small.csv', 'val_curated.csv', 'test_curated.csv']
train_set = S[0]
val_set = S[1]
test_set = S[2]

In [9]:
# Write each fold to its own two-column CSV (fname, labels).
for idx, file_name in enumerate(file_names):
    # newline='' hands line-ending control to the csv module.
    with open('../input/' + file_name, 'w', newline='') as f:
        # csv.writer quotes any field that contains a comma, so a multi-label
        # entry such as "Bark,Meow" stays a single `labels` field instead of
        # spilling into extra columns.
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(['fname', 'labels'])
        for fname, labels in S[idx]:
            writer.writerow([fname, ','.join(labels)])