# Milestone 1

This milestone focuses on understanding the dataset and establishing a performance baseline through **exploratory data analysis (EDA)** and simple **heuristic-based methods** using `librosa`.

---

## Suggested Readings
- [Hugging Face Audio Course](https://huggingface.co/learn/audio-course/en/chapter0/introduction)
- [Librosa Documentation](https://librosa.org/doc/main/core.html#audio-loading)

---

## Instructions
Use this notebook to answer **all Milestone-1 questions**.

---

## Resources
- Notebook Link:  
  https://colab.research.google.com/drive/1m6UczhxQIke_raWSqukSWuiKbIVt7MMb?usp=sharing  

- Competition Link:  
  https://www.kaggle.com/competitions/jan-2026-dl-gen-ai-project/


In [11]:
import os
import glob
import numpy as np
import pandas as pd
from tqdm import tqdm
import librosa
import librosa.display
import matplotlib.pyplot as plt
import random
import torch

import warnings
warnings.filterwarnings("ignore")

In [12]:
#----------------------------- DON'T CHANGE THIS --------------------------
# Fixed experiment constants shared by the cells of this notebook.
DATA_SEED = 67          # seed applied to random / NumPy / torch at the bottom of this cell
TRAINING_SEED = 1234    # not referenced in the cells shown here; presumably reserved for training code
SR = 22050              # sample rate passed to librosa.load
DURATION = 5.0          # seconds; used both as the clip length when loading and as the long-silence threshold
N_FFT = 2048            # presumably the STFT window size; not used in the cells shown here
HOP_LENGTH = 512        # presumably the STFT hop size; not used in the cells shown here
N_MELS = 128            # presumably the number of mel bands; not used in the cells shown here
TOP_DB=20               # top_db passed to librosa.effects.split when detecting non-silent intervals
TARGET_SNR_DB = 10      # presumably a target signal-to-noise ratio in dB; not used in the cells shown here

# Seed every random source in use with DATA_SEED.
# NOTE(review): build_dataset creates its own random.Random(seed), so these
# global seeds do not influence the train/validation split.
random.seed(DATA_SEED)
np.random.seed(DATA_SEED)
torch.manual_seed(DATA_SEED)
torch.cuda.manual_seed(DATA_SEED)

In [13]:
# CONFIGURATION
# Root folder that build_dataset walks as <root>/<genre>/<song folder>/<stem>.wav.
DATA_ROOT = "/kaggle/input/jan-2026-dl-gen-ai-project/messy_mashup/genres_stems/"
# Genre sub-folder names; also the top-level keys of the dataset dictionaries.
GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
# Stem file names every song folder must contain to be considered complete.
STEMS = ['bass.wav', 'drums.wav', 'other.wav', 'vocals.wav']
# Stem names without extension, in the order the mixing cells load and stack them.
# Note: this order differs from STEMS.
STEM_KEYS = ['drums', 'vocals', 'bass', 'other']
# Genre and position (within the training list) of the song loaded for the mixing exercise.
GENRE_TO_TEST = 'rock'
SONG_INDEX = 0

In [14]:
def build_dataset(root_dir, val_split=0.17, seed=42):
    """Scan the stem folders and split the valid songs into train/validation.

    A song folder is valid when it contains every file listed in STEMS and
    none of those files is smaller than 4 KB. Valid songs are shuffled per
    genre with a dedicated ``random.Random(seed)`` and the last
    ``val_split`` fraction goes to validation.

    Args:
        root_dir: Directory holding one sub-folder per genre.
        val_split: Fraction of each genre's valid songs used for validation.
        seed: Seed of the private random generator used for shuffling.

    Returns:
        (train_dataset, val_dataset), each shaped
        ``{genre: {stem_name: [stem file paths...]}}``.
    """
    stem_names = [stem_file.replace('.wav', '') for stem_file in STEMS]
    train_dataset = {g: {name: [] for name in stem_names} for g in GENRES}
    val_dataset   = {g: {name: [] for name in stem_names} for g in GENRES}

    rng = random.Random(seed)
    min_bytes = 4 * 1024  # files below 4 KB are treated as corrupted

    # ------------------- write your code here -------------------------------
    for genre in tqdm(GENRES, desc="Processing Genres"):
        genre_dir = os.path.join(root_dir, genre)
        if not os.path.exists(genre_dir):
            print(f"Warning: Genre folder {genre} not found!")
            continue

        # Sorted so the shuffle below starts from a deterministic order.
        candidates = sorted(
            entry for entry in os.listdir(genre_dir)
            if os.path.isdir(os.path.join(genre_dir, entry))
        )

        # Keep only songs that are complete and have no undersized stem.
        usable = []
        for song in candidates:
            stem_files = [os.path.join(genre_dir, song, stem_file) for stem_file in STEMS]
            is_complete = all(os.path.exists(path) for path in stem_files)
            if is_complete and not any(os.path.getsize(path) < min_bytes for path in stem_files):
                usable.append(song)

        # Per-genre shuffle, then cut: head -> train, tail -> validation.
        rng.shuffle(usable)
        cut = int(len(usable) * (1 - val_split))

        for target, songs in ((train_dataset, usable[:cut]), (val_dataset, usable[cut:])):
            for song in songs:
                for stem_file, name in zip(STEMS, stem_names):
                    target[genre][name].append(os.path.join(genre_dir, song, stem_file))
    #-------------------------------------------------------------------------

    return train_dataset, val_dataset

# Build the train/validation dictionaries with the function defaults
# (val_split=0.17, seed=42).
# NOTE(review): the split seed is the default 42, not DATA_SEED — confirm this is intended.
tr, val = build_dataset(DATA_ROOT)

Processing Genres: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


In [15]:
# ---- Q1-Q3 Analysis ----
# Size thresholds, expressed in bytes.
MB = 1024 * 1024
threshold_small = 5.0491 * MB   # < 5.0491 MB
threshold_large = 5.0493 * MB   # > 5.0493 MB


def _iter_stem_sizes(root_dir):
    """Yield the size in bytes of every existing stem file under root_dir."""
    for genre_name in GENRES:
        genre_dir = os.path.join(root_dir, genre_name)
        if not os.path.exists(genre_dir):
            continue
        for entry in os.listdir(genre_dir):
            song_dir = os.path.join(genre_dir, entry)
            if not os.path.isdir(song_dir):
                continue
            for stem_file in STEMS:
                candidate = os.path.join(song_dir, stem_file)
                if os.path.exists(candidate):
                    yield os.path.getsize(candidate)


# Every count below is per stem FILE, although the printed labels say "songs".
stem_sizes = list(_iter_stem_sizes(DATA_ROOT))
corrupted_count = sum(1 for size in stem_sizes if size < 4 * 1024)
small_files     = sum(1 for size in stem_sizes if size < threshold_small)
large_files     = sum(1 for size in stem_sizes if size > threshold_large)

print(f"Corrupted songs (< 4KB)  : {corrupted_count}")
print(f"Total songs < 5.0491MB   : {small_files}")
print(f"Total songs > 5.0493MB   : {large_files}")
print()

# Q1
# NOTE(review): a file under 4 KB is also under 5.0491 MB, so it is counted in
# both terms of this sum — confirm that is what the question asks for.
q1 = corrupted_count + small_files
print(f"[Q1] Corrupted + Small   = {corrupted_count} + {small_files} = {q1}")

# Q2
q2 = abs(large_files - small_files)
print(f"[Q2] |Large - Small|     = |{large_files} - {small_files}| = {q2}")

# Q3: compare two list lengths taken from the train and validation dictionaries.
train_reggae_drums  = len(tr['reggae']['drums'])
val_country_vocals  = len(val['country']['vocals'])
q3 = abs(train_reggae_drums - val_country_vocals)
print(f"[Q3] |Train reggae drums - Val country vocals| = |{train_reggae_drums} - {val_country_vocals}| = {q3}")

Corrupted songs (< 4KB)  : 0
Total songs < 5.0491MB   : 1256
Total songs > 5.0493MB   : 184

[Q1] Corrupted + Small   = 0 + 1256 = 1256
[Q2] |Large - Small|     = |184 - 1256| = 1072
[Q3] |Train reggae drums - Val country vocals| = |83 - 17| = 66


In [16]:
def _longest_silence(non_silent_intervals, n_samples, sr):
    """Return the longest silent stretch of a signal and where it occurs.

    Args:
        non_silent_intervals: Sequence of (start, end) sample-index pairs of
            the non-silent regions, as produced by ``librosa.effects.split``.
        n_samples: Total number of samples in the signal.
        sr: Sample rate used to convert sample counts to seconds.

    Returns:
        (max_silence, silence_type): the longest silence in seconds and a
        list drawn from "start", "middle", "end" describing its location.
    """
    if len(non_silent_intervals) == 0:
        # No non-silent region at all: the whole clip is silence.
        return n_samples / sr, ["start", "middle", "end"]

    max_silence  = 0.0
    silence_type = []

    # Leading silence: everything before the first non-silent interval.
    start_silence = non_silent_intervals[0][0] / sr
    if start_silence > max_silence:
        max_silence  = start_silence
        silence_type = ["start"]

    # Trailing silence: everything after the last non-silent interval.
    end_silence = (n_samples - non_silent_intervals[-1][1]) / sr
    if end_silence > max_silence:
        max_silence  = end_silence
        silence_type = ["end"]
    elif abs(end_silence - max_silence) < 0.01:
        # Within 10 ms of the current maximum: report both locations.
        if "end" not in silence_type:
            silence_type.append("end")

    # Gaps between consecutive non-silent intervals.
    for i in range(len(non_silent_intervals) - 1):
        gap_start    = non_silent_intervals[i][1]
        gap_end      = non_silent_intervals[i + 1][0]
        gap_duration = (gap_end - gap_start) / sr

        if gap_duration > max_silence:
            max_silence  = gap_duration
            silence_type = ["middle"]
        elif abs(gap_duration - max_silence) < 0.01:
            if "middle" not in silence_type:
                silence_type.append("middle")

    return max_silence, silence_type


def find_long_silences(dataset_dict, sr=SR, threshold_sec=DURATION, top_db=TOP_DB):
    """
    Find every audio file whose longest silence is at least ``threshold_sec``.

    Input:
        dataset_dict: The dictionary structure {genre: {stem: [paths...]}}
        sr: Sample rate passed to librosa.load (None keeps the file's own rate)
        threshold_sec: Minimum longest-silence length, in seconds, to report
        top_db: Threshold passed to librosa.effects.split
    Output:
        df: Pandas DataFrame with columns Genre, Stem, Duration,
            Max_Silence_Sec, Silence_Location and File_Path, one row per
            file with silence >= threshold_sec. The columns are present even
            when no file qualifies, so column lookups on the result never fail.

    Files that fail to load or process are reported with a printed message
    and skipped.
    """
    columns = ["Genre", "Stem", "Duration", "Max_Silence_Sec",
               "Silence_Location", "File_Path"]
    records = []
    # ------------------- write your code here -------------------------------

    total_files = sum(len(paths) for genre_data in dataset_dict.values()
                      for paths in genre_data.values())
    print(f"Analyzing {total_files} files for long silences...")

    # The context manager closes the bar even if something unexpected escapes.
    with tqdm(total=total_files, desc="Processing files") as pbar:
        for genre, stems_dict in dataset_dict.items():
            for stem_name, file_paths in stems_dict.items():
                for file_path in file_paths:
                    try:
                        # Use the rate librosa actually loaded at, so sr=None works.
                        y, loaded_sr = librosa.load(file_path, sr=sr, duration=None)
                        total_duration = len(y) / loaded_sr

                        non_silent_intervals = librosa.effects.split(y, top_db=top_db)
                        max_silence, silence_type = _longest_silence(
                            non_silent_intervals, len(y), loaded_sr
                        )

                        # Store result if silence meets threshold
                        if max_silence >= threshold_sec:
                            records.append({
                                "Genre":            genre,
                                "Stem":             stem_name,
                                "Duration":         round(total_duration, 2),
                                "Max_Silence_Sec":  round(max_silence, 2),
                                "Silence_Location": ", ".join(silence_type),
                                "File_Path":        file_path
                            })

                    except Exception as e:
                        # Best effort: report the failing file and keep going.
                        print(f"Error processing {file_path}: {e}")

                    pbar.update(1)
    #-------------------------------------------------------------------------
    # Explicit columns keep the schema stable when `records` is empty.
    df = pd.DataFrame(records, columns=columns)
    return df


# --- EXECUTION ---
# Scan the training split once; the Q4-Q9 answers below all derive from this frame.
df_silence = find_long_silences(tr, threshold_sec=DURATION, top_db=TOP_DB)

# --- RESULTS ANALYSIS ---
# ------------------- write your code here -------------------------------
print(f"\nTotal files with silence >= {DURATION}s: {len(df_silence)}")
print(df_silence.head())
#-------------------------------------------------------------------------

# Hint: Create a pivot Table: Count by Genre vs Stem
if not df_silence.empty:
    pivot_table = df_silence.pivot_table(
        index='Genre',
        columns='Stem',
        aggfunc='size',
        fill_value=0,
    )
    print("\nPivot Table (Genre vs Stem):")
    print(pivot_table)

# ---- Q4-Q9 Analysis ----
print()
q4 = len(df_silence)
print(f"[Q4] Total sound files with silence >= 5s          : {q4}")

# Vocals rows feed both Q5 (count) and Q6 (mean longest silence).
vocals_df = df_silence[df_silence['Stem'] == 'vocals']
q5 = len(vocals_df)
print(f"[Q5] Total Vocals tracks with silence >= 5s        : {q5}")

if len(vocals_df) > 0:
    q6 = round(vocals_df['Max_Silence_Sec'].mean(), 2)
else:
    q6 = 0.0
print(f"[Q6] Average Silence Length in Vocals (secs)       : {q6}")

# Jazz drums rows feed Q7-Q9.
is_jazz_drums = df_silence['Genre'].eq('jazz') & df_silence['Stem'].eq('drums')
jazz_drums = df_silence.loc[is_jazz_drums]
q7 = len(jazz_drums)
print(f"[Q7] Jazz drums tracks with silence >= 5s          : {q7}")

# Exact match: rows whose location is a combination such as "start, middle" are not counted.
q8 = int(jazz_drums['Silence_Location'].eq('middle').sum())
print(f"[Q8] Jazz drums (silence >= 5s, location=middle)   : {q8}")

q9 = int(jazz_drums['Max_Silence_Sec'].ge(10.0).sum())
print(f"[Q9] Jazz drums (silence >= 5s, Max_Silence >= 10) : {q9}")

Analyzing 3320 files for long silences...


Processing files: 100%|██████████| 3320/3320 [06:49<00:00,  8.10it/s]


Total files with silence >= 5.0s: 680
   Genre  Stem  Duration  Max_Silence_Sec Silence_Location  \
0  blues  bass     30.01             7.08           middle   
1  blues  bass     30.01             8.68           middle   
2  blues  bass     30.01             8.38            start   
3  blues  bass     30.01             5.87           middle   
4  blues  bass     30.01            21.80            start   

                                           File_Path  
0  /kaggle/input/jan-2026-dl-gen-ai-project/messy...  
1  /kaggle/input/jan-2026-dl-gen-ai-project/messy...  
2  /kaggle/input/jan-2026-dl-gen-ai-project/messy...  
3  /kaggle/input/jan-2026-dl-gen-ai-project/messy...  
4  /kaggle/input/jan-2026-dl-gen-ai-project/messy...  

Pivot Table (Genre vs Stem):
Stem       bass  drums  other  vocals
Genre                                
blues        17     22      5      43
classical    68     57      5      69
country      16     16      2      16
disco         8      2      3      18





In [17]:
# Load the first DURATION seconds of each stem of one training song.
# Waveforms are appended in STEM_KEYS order.
stems_audio = []
try:
    for key in STEM_KEYS:
        # ------------------- write your code here -------------------------------
        # Load audio (Duration 5.0s for speed/consistency)
        stem_path = tr[GENRE_TO_TEST][key][SONG_INDEX]
        y, sr_loaded = librosa.load(stem_path, sr=SR, duration=DURATION)
        stems_audio.append(y)
        print(f"Loaded {key}: {stem_path}")
        print(f"  Shape: {y.shape}, Sample rate: {sr_loaded}")
        #-------------------------------------------------------------------------

    print("\nAudio loaded successfully.")
# NOTE(review): every handler below only prints and lets the notebook continue,
# so after a failure stems_audio may hold fewer than len(STEM_KEYS) waveforms.
except NameError:
    print("ERROR: 'tr' dictionary not found. Please run build_dataset() first.")
except IndexError:
    print(f"ERROR: Song index {SONG_INDEX} out of range for genre {GENRE_TO_TEST}.")
except Exception as e:
    print(f"ERROR: {e}")

Loaded drums: /kaggle/input/jan-2026-dl-gen-ai-project/messy_mashup/genres_stems/rock/rock.00092/drums.wav
  Shape: (110250,), Sample rate: 22050
Loaded vocals: /kaggle/input/jan-2026-dl-gen-ai-project/messy_mashup/genres_stems/rock/rock.00092/vocals.wav
  Shape: (110250,), Sample rate: 22050
Loaded bass: /kaggle/input/jan-2026-dl-gen-ai-project/messy_mashup/genres_stems/rock/rock.00092/bass.wav
  Shape: (110250,), Sample rate: 22050
Loaded other: /kaggle/input/jan-2026-dl-gen-ai-project/messy_mashup/genres_stems/rock/rock.00092/other.wav
  Shape: (110250,), Sample rate: 22050

Audio loaded successfully.


In [18]:
# ------------------- write your code here -------------------------------
# Stack them into a numpy array (Shape: 4 x Samples)
stems_stack = np.array(stems_audio)
print(f"Stems stack shape: {stems_stack.shape}")

# Mix the stems by summing them element-wise
mix_raw = np.sum(stems_stack, axis=0)
print(f"Mix raw shape: {mix_raw.shape}")
print(f"Mix raw length: {len(mix_raw)}")

# Calculate RMS Amplitude MANUALLY: sqrt(mean(x^2))
rms_val = np.sqrt(np.mean(mix_raw**2))
print(f"RMS Amplitude: {rms_val:.4f}")

# Peak Normalization: scale so the largest absolute sample becomes 1.0.
max_val = np.max(np.abs(mix_raw))
print(f"Max value (before normalization): {max_val:.4f}")

if max_val > 0:
    mix_norm = mix_raw / max_val
else:
    # An all-zero (silent) mix cannot be scaled; keep it unchanged.
    mix_norm = mix_raw

print(f"Max value (after normalization): {np.max(np.abs(mix_norm)):.4f}")

# VALIDATION
# Only meaningful when there was a non-zero peak to scale; a silent mix
# legitimately keeps a peak of 0 and must not trip the assertion.
if max_val > 0:
    assert np.isclose(np.max(np.abs(mix_norm)), 1.0), "Normalization failed."
    print("\nNormalization validation: PASSED")
else:
    print("\nNormalization validation: SKIPPED (mix is silent)")
#------------------------------------------------------------------------

# ---- Q10-Q12 Analysis ----
# Convert NumPy scalars to Python floats before rounding. Formatting a rounded
# float32 prints its float64 expansion (e.g. 0.10000000149011612) instead of
# the intended two-decimal value.
print(f"\n[Q10] Length of mix sample          : {len(mix_raw)}")
print(f"[Q11] RMS Amplitude of mix sample   : {round(float(rms_val), 2)}")
print(f"[Q12] Max value of normalized sample: {round(float(np.max(np.abs(mix_norm))), 2)}")

Stems stack shape: (4, 110250)
Mix raw shape: (110250,)
Mix raw length: 110250
RMS Amplitude: 0.1021
Max value (before normalization): 0.5894
Max value (after normalization): 1.0000

Normalization validation: PASSED

[Q10] Length of mix sample          : 110250
[Q11] RMS Amplitude of mix sample   : 0.10000000149011612
[Q12] Max value of normalized sample: 1.0
