In [None]:
# CELL 1: Install dependencies and setup
!pip install transformers torch torchaudio sentencepiece librosa soundfile -q

import torch
import torchaudio
import numpy as np
import json
import os
from pathlib import Path
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Mounted at /content/drive
Using device: cuda
GPU: Tesla T4


In [None]:
# CELL 2: Configuration
# =============================================================
# EDIT THIS PATH to point to your audio file in Google Drive
# =============================================================
AUDIO_PATH = "/content/drive/MyDrive/satere_project/acts12.mp3"

# Output directory
OUTPUT_DIR = "/content/drive/MyDrive/satere_project/phase2_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model settings
MODEL_NAME = "facebook/wav2vec2-large-xlsr-53"
SAMPLE_RATE = 16000  # XLSR-53 expects 16kHz

# BPE settings
VOCAB_SIZE = 500     # Number of motifs to discover
MIN_FREQUENCY = 5    # Minimum times a pattern must appear

print(f"Audio file: {AUDIO_PATH}")
print(f"Output dir: {OUTPUT_DIR}")
print(f"File exists: {os.path.exists(AUDIO_PATH)}")

Audio file: /content/drive/MyDrive/satere_project/acts12.mp3
Output dir: /content/drive/MyDrive/satere_project/phase2_output
File exists: True


In [None]:
# CELL 3: Load and preprocess audio
import librosa
import soundfile as sf

print("Loading audio file...")

# Load audio and resample to 16kHz
audio, sr = librosa.load(AUDIO_PATH, sr=SAMPLE_RATE, mono=True)

duration_seconds = len(audio) / SAMPLE_RATE
duration_minutes = duration_seconds / 60

print(f"Audio loaded successfully!")
print(f"  Duration: {duration_minutes:.1f} minutes ({duration_seconds:.0f} seconds)")
print(f"  Sample rate: {sr} Hz")
print(f"  Samples: {len(audio):,}")

# Save preprocessed audio
preprocessed_path = os.path.join(OUTPUT_DIR, "audio_16khz.wav")
sf.write(preprocessed_path, audio, SAMPLE_RATE)
print(f"  Saved preprocessed audio to: {preprocessed_path}")
```

Press **Shift+Enter**.

You should see something like:
```
Audio loaded successfully!
  Duration: 9.3 minutes (557 seconds)
  Sample rate: 16000 Hz

SyntaxError: invalid syntax (ipython-input-1005178797.py, line 22)

In [None]:
# CELL 3: Load and preprocess audio
import librosa
import soundfile as sf

print("Loading audio file...")

audio, sr = librosa.load(AUDIO_PATH, sr=SAMPLE_RATE, mono=True)

duration_seconds = len(audio) / SAMPLE_RATE
duration_minutes = duration_seconds / 60

print(f"Audio loaded successfully!")
print(f"  Duration: {duration_minutes:.1f} minutes ({duration_seconds:.0f} seconds)")
print(f"  Sample rate: {sr} Hz")
print(f"  Samples: {len(audio):,}")

preprocessed_path = os.path.join(OUTPUT_DIR, "audio_16khz.wav")
sf.write(preprocessed_path, audio, SAMPLE_RATE)
print(f"  Saved preprocessed audio to: {preprocessed_path}")

Loading audio file...
Audio loaded successfully!
  Duration: 9.3 minutes (557 seconds)
  Sample rate: 16000 Hz
  Samples: 8,912,000
  Saved preprocessed audio to: /content/drive/MyDrive/satere_project/phase2_output/audio_16khz.wav


In [None]:
# CELL 4: Load XLSR-53 model
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor

print("Loading XLSR-53 model (this may take a minute)...")

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
model = Wav2Vec2Model.from_pretrained(MODEL_NAME)
model = model.to(device)
model.eval()

print("Model loaded successfully!")
print(f"  Model: {MODEL_NAME}")
print(f"  Device: {device}")

Loading XLSR-53 model (this may take a minute)...


pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Model loaded successfully!
  Model: facebook/wav2vec2-large-xlsr-53
  Device: cuda


In [None]:
# CELL 5: Extract acoustic features
from tqdm import tqdm

def extract_features_chunked(audio, model, feature_extractor, chunk_duration=30, overlap=1):
    chunk_samples = int(chunk_duration * SAMPLE_RATE)
    overlap_samples = int(overlap * SAMPLE_RATE)
    step = chunk_samples - overlap_samples

    all_features = []
    n_chunks = max(1, int(np.ceil((len(audio) - overlap_samples) / step)))

    print(f"Processing {n_chunks} chunks...")

    for i in tqdm(range(n_chunks)):
        start = i * step
        end = min(start + chunk_samples, len(audio))
        chunk = audio[start:end]

        inputs = feature_extractor(chunk, sampling_rate=SAMPLE_RATE, return_tensors="pt")
        input_values = inputs.input_values.to(device)

        with torch.no_grad():
            outputs = model(input_values)
            hidden_states = outputs.last_hidden_state

        if i > 0 and overlap_samples > 0:
            frames_to_skip = int(overlap_samples / SAMPLE_RATE * 50)
            hidden_states = hidden_states[:, frames_to_skip:, :]

        all_features.append(hidden_states.cpu().numpy())

    features = np.concatenate(all_features, axis=1)
    return features[0]

print("Extracting acoustic features...")
features = extract_features_chunked(audio, model, feature_extractor)

print(f"\nFeature extraction complete!")
print(f"  Feature shape: {features.shape}")
print(f"  Time frames: {features.shape[0]}")
print(f"  Feature dimension: {features.shape[1]}")

Extracting acoustic features...
Processing 20 chunks...


100%|██████████| 20/20 [00:08<00:00,  2.24it/s]


Feature extraction complete!
  Feature shape: (27830, 1024)
  Time frames: 27830
  Feature dimension: 1024





In [None]:
# CELL 6: Quantize to discrete units using K-means
from sklearn.cluster import MiniBatchKMeans

N_CLUSTERS = 100  # Number of discrete acoustic units

print(f"Clustering features into {N_CLUSTERS} discrete units...")

kmeans = MiniBatchKMeans(
    n_clusters=N_CLUSTERS,
    batch_size=1000,
    random_state=42,
    n_init=3
)

acoustic_units = kmeans.fit_predict(features)

print(f"\nQuantization complete!")
print(f"  Total frames: {len(acoustic_units):,}")
print(f"  Unique units: {len(np.unique(acoustic_units))}")

# Save acoustic units
units_path = os.path.join(OUTPUT_DIR, "acoustic_units.json")
with open(units_path, 'w') as f:
    json.dump({
        "units": acoustic_units.tolist(),
        "n_clusters": N_CLUSTERS,
        "total_frames": len(acoustic_units)
    }, f)

# Convert to pseudo-text
unit_string = " ".join([f"U{u:02d}" for u in acoustic_units])
print(f"\nPseudo-text sample: {unit_string[:200]}...")

pseudotext_path = os.path.join(OUTPUT_DIR, "pseudotext.txt")
with open(pseudotext_path, 'w') as f:
    f.write(unit_string)
print(f"Saved pseudo-text to: {pseudotext_path}")

Clustering features into 100 discrete units...

Quantization complete!
  Total frames: 27,830
  Unique units: 100

Pseudo-text sample: U35 U12 U55 U55 U25 U65 U65 U25 U25 U55 U55 U55 U55 U55 U55 U55 U55 U55 U55 U55 U25 U65 U65 U55 U65 U81 U82 U82 U53 U53 U83 U79 U79 U79 U44 U44 U44 U51 U66 U66 U66 U31 U66 U63 U38 U38 U59 U59 U61 U61 ...
Saved pseudo-text to: /content/drive/MyDrive/satere_project/phase2_output/pseudotext.txt


In [None]:
# CELL 7: Discover motifs using SentencePiece BPE
import sentencepiece as spm

print("Training BPE model to discover acoustic motifs...")

# Split into chunks (BPE needs multiple "sentences")
units_list = unit_string.split(" ")
chunk_size = 500
chunks = [" ".join(units_list[i:i+chunk_size]) for i in range(0, len(units_list), chunk_size)]

train_path = os.path.join(OUTPUT_DIR, "bpe_train.txt")
with open(train_path, 'w') as f:
    for chunk in chunks:
        f.write(chunk + "\n")

model_prefix = os.path.join(OUTPUT_DIR, "motifs")

spm.SentencePieceTrainer.train(
    input=train_path,
    model_prefix=model_prefix,
    vocab_size=116,
    model_type='bpe',
    character_coverage=1.0,
    max_sentence_length=10000,
    num_threads=4
)

sp = spm.SentencePieceProcessor()
sp.load(f"{model_prefix}.model")

print(f"BPE training complete!")
print(f"  Vocabulary size: {sp.get_piece_size()}")

Training BPE model to discover acoustic motifs...
BPE training complete!
  Vocabulary size: 116


In [None]:
# CELL 8: Analyze discovered motifs
from collections import Counter

encoded = sp.encode_as_pieces(unit_string)
motif_counts = Counter(encoded)

print("=" * 60)
print("DISCOVERED ACOUSTIC MOTIFS")
print("=" * 60)

# Get multi-unit motifs (the interesting ones)
multi_unit_motifs = [(m, c) for m, c in motif_counts.most_common()
                      if len(m.replace("▁", "").split("U")) > 2]

print(f"\nMulti-unit motifs (potential morphemes): {len(multi_unit_motifs)}")

print("\nTop 30 multi-unit motifs by frequency:")
print("-" * 50)
for i, (motif, count) in enumerate(multi_unit_motifs[:30]):
    n_units = len(motif.replace("▁", "").split("U")) - 1
    duration_ms = n_units * 20
    print(f"  {i+1:2d}. {motif:30s} count={count:4d}  ~{duration_ms}ms")

DISCOVERED ACOUSTIC MOTIFS

Multi-unit motifs (potential morphemes): 0

Top 30 multi-unit motifs by frequency:
--------------------------------------------------


In [None]:
# CELL 8b: Analyze all motifs (adjusted)
from collections import Counter

encoded = sp.encode_as_pieces(unit_string)
motif_counts = Counter(encoded)

print("=" * 60)
print("DISCOVERED ACOUSTIC MOTIFS")
print("=" * 60)

print(f"\nTotal unique motifs: {len(motif_counts)}")
print(f"Total tokens in sequence: {len(encoded)}")

print("\nTop 30 motifs by frequency:")
print("-" * 50)
for i, (motif, count) in enumerate(motif_counts.most_common(30)):
    # Count units in this motif
    units_in_motif = motif.replace("▁", "").strip().split("U")
    units_in_motif = [u for u in units_in_motif if u]  # Remove empty strings
    n_units = len(units_in_motif)
    duration_ms = n_units * 20
    print(f"  {i+1:2d}. {motif:25s} units={n_units:2d}  count={count:4d}  ~{duration_ms}ms")

# Store all motifs for later use
all_motifs = [(m, c) for m, c in motif_counts.most_common()]
print(f"\nTotal motifs to analyze: {len(all_motifs)}")

DISCOVERED ACOUSTIC MOTIFS

Total unique motifs: 101
Total tokens in sequence: 55660

Top 30 motifs by frequency:
--------------------------------------------------
   1. ▁U                        units= 0  count=27830  ~0ms
   2. 79                        units= 1  count= 764  ~20ms
   3. 58                        units= 1  count= 701  ~20ms
   4. 33                        units= 1  count= 669  ~20ms
   5. 68                        units= 1  count= 567  ~20ms
   6. 55                        units= 1  count= 527  ~20ms
   7. 32                        units= 1  count= 491  ~20ms
   8. 35                        units= 1  count= 475  ~20ms
   9. 12                        units= 1  count= 463  ~20ms
  10. 59                        units= 1  count= 463  ~20ms
  11. 74                        units= 1  count= 463  ~20ms
  12. 41                        units= 1  count= 459  ~20ms
  13. 81                        units= 1  count= 450  ~20ms
  14. 97                        units= 1  count= 449  ~

In [None]:
# CELL 8c: Find patterns using n-grams instead of BPE
from collections import Counter

# Work directly with the acoustic units sequence
units_list = [f"U{u:02d}" for u in acoustic_units]

# Find bigrams (2-unit patterns)
bigrams = [f"{units_list[i]}_{units_list[i+1]}" for i in range(len(units_list)-1)]
bigram_counts = Counter(bigrams)

# Find trigrams (3-unit patterns)
trigrams = [f"{units_list[i]}_{units_list[i+1]}_{units_list[i+2]}" for i in range(len(units_list)-2)]
trigram_counts = Counter(trigrams)

# Find 4-grams
fourgrams = [f"{units_list[i]}_{units_list[i+1]}_{units_list[i+2]}_{units_list[i+3]}" for i in range(len(units_list)-3)]
fourgram_counts = Counter(fourgrams)

print("=" * 60)
print("DISCOVERED ACOUSTIC MOTIFS (N-gram Analysis)")
print("=" * 60)

print("\nTop 20 Bigrams (2-unit patterns, ~40ms):")
print("-" * 50)
for i, (pattern, count) in enumerate(bigram_counts.most_common(20)):
    print(f"  {i+1:2d}. {pattern:20s} count={count:4d}")

print("\nTop 20 Trigrams (3-unit patterns, ~60ms):")
print("-" * 50)
for i, (pattern, count) in enumerate(trigram_counts.most_common(20)):
    print(f"  {i+1:2d}. {pattern:20s} count={count:4d}")

print("\nTop 20 Four-grams (4-unit patterns, ~80ms):")
print("-" * 50)
for i, (pattern, count) in enumerate(fourgram_counts.most_common(20)):
    print(f"  {i+1:2d}. {pattern:25s} count={count:4d}")

# Store for next step
motif_data = {
    'bigrams': bigram_counts.most_common(50),
    'trigrams': trigram_counts.most_common(50),
    'fourgrams': fourgram_counts.most_common(50)
}
print(f"\nStored top 50 patterns of each type for audio extraction.")

DISCOVERED ACOUSTIC MOTIFS (N-gram Analysis)

Top 20 Bigrams (2-unit patterns, ~40ms):
--------------------------------------------------
   1. U58_U58              count= 574
   2. U79_U79              count= 459
   3. U68_U68              count= 345
   4. U12_U12              count= 330
   5. U32_U32              count= 330
   6. U55_U55              count= 307
   7. U59_U59              count= 277
   8. U04_U04              count= 240
   9. U33_U33              count= 239
  10. U64_U64              count= 228
  11. U41_U41              count= 224
  12. U96_U96              count= 217
  13. U82_U82              count= 202
  14. U75_U75              count= 187
  15. U25_U25              count= 185
  16. U53_U53              count= 184
  17. U87_U87              count= 183
  18. U11_U11              count= 178
  19. U90_U90              count= 178
  20. U83_U83              count= 176

Top 20 Trigrams (3-unit patterns, ~60ms):
--------------------------------------------------
   1. U5

In [None]:
# CELL 9: Find heterogeneous patterns (more linguistically interesting)
from collections import Counter

units_list = [f"U{u:02d}" for u in acoustic_units]

# Find bigrams where units are DIFFERENT
hetero_bigrams = [f"{units_list[i]}_{units_list[i+1]}"
                  for i in range(len(units_list)-1)
                  if units_list[i] != units_list[i+1]]
hetero_bigram_counts = Counter(hetero_bigrams)

# Find trigrams with at least 2 different units
hetero_trigrams = []
for i in range(len(units_list)-2):
    tri = (units_list[i], units_list[i+1], units_list[i+2])
    if len(set(tri)) >= 2:  # At least 2 different units
        hetero_trigrams.append(f"{tri[0]}_{tri[1]}_{tri[2]}")
hetero_trigram_counts = Counter(hetero_trigrams)

# Find 4-grams with at least 2 different units
hetero_fourgrams = []
for i in range(len(units_list)-3):
    fg = (units_list[i], units_list[i+1], units_list[i+2], units_list[i+3])
    if len(set(fg)) >= 2:
        hetero_fourgrams.append(f"{fg[0]}_{fg[1]}_{fg[2]}_{fg[3]}")
hetero_fourgram_counts = Counter(hetero_fourgrams)

print("=" * 60)
print("HETEROGENEOUS MOTIFS (Different Units = More Interesting)")
print("=" * 60)

print("\nTop 20 Transition Bigrams:")
print("-" * 50)
for i, (pattern, count) in enumerate(hetero_bigram_counts.most_common(20)):
    print(f"  {i+1:2d}. {pattern:20s} count={count:4d}")

print("\nTop 20 Heterogeneous Trigrams:")
print("-" * 50)
for i, (pattern, count) in enumerate(hetero_trigram_counts.most_common(20)):
    print(f"  {i+1:2d}. {pattern:20s} count={count:4d}")

print("\nTop 20 Heterogeneous Four-grams:")
print("-" * 50)
for i, (pattern, count) in enumerate(hetero_fourgram_counts.most_common(20)):
    print(f"  {i+1:2d}. {pattern:25s} count={count:4d}")

# Store the interesting motifs
interesting_motifs = {
    'transitions': hetero_bigram_counts.most_common(30),
    'trigrams': hetero_trigram_counts.most_common(30),
    'fourgrams': hetero_fourgram_counts.most_common(30)
}

HETEROGENEOUS MOTIFS (Different Units = More Interesting)

Top 20 Transition Bigrams:
--------------------------------------------------
   1. U62_U33              count= 146
   2. U48_U43              count= 142
   3. U33_U48              count= 128
   4. U35_U97              count= 124
   5. U33_U81              count=  94
   6. U55_U25              count=  89
   7. U97_U36              count=  88
   8. U13_U33              count=  86
   9. U32_U47              count=  82
  10. U25_U55              count=  79
  11. U98_U35              count=  79
  12. U48_U02              count=  79
  13. U36_U74              count=  79
  14. U91_U92              count=  78
  15. U36_U33              count=  78
  16. U00_U60              count=  74
  17. U34_U40              count=  73
  18. U68_U84              count=  73
  19. U85_U13              count=  72
  20. U81_U43              count=  72

Top 20 Heterogeneous Trigrams:
--------------------------------------------------
   1. U62_U33_U33   

In [None]:
# CELL 10: Extract audio samples for motifs
import soundfile as sf

motif_audio_dir = os.path.join(OUTPUT_DIR, "motif_samples")
os.makedirs(motif_audio_dir, exist_ok=True)

def find_pattern_positions(units_list, pattern):
    """Find where a pattern occurs in the unit sequence."""
    pattern_units = pattern.split("_")
    positions = []
    for i in range(len(units_list) - len(pattern_units) + 1):
        if units_list[i:i+len(pattern_units)] == pattern_units:
            positions.append(i)
    return positions

def frame_to_time(frame_idx, total_frames, total_duration):
    """Convert frame index to time in seconds."""
    return (frame_idx / total_frames) * total_duration

# Collect top motifs from each category
motifs_to_extract = []

# Top 15 transition bigrams
for pattern, count in hetero_bigram_counts.most_common(15):
    motifs_to_extract.append(('bigram', pattern, count))

# Top 15 heterogeneous trigrams
for pattern, count in hetero_trigram_counts.most_common(15):
    motifs_to_extract.append(('trigram', pattern, count))

# Top 10 heterogeneous fourgrams
for pattern, count in hetero_fourgram_counts.most_common(10):
    motifs_to_extract.append(('fourgram', pattern, count))

print(f"Extracting audio for {len(motifs_to_extract)} motifs...")
print("-" * 50)

avita_data = []
units_list = [f"U{u:02d}" for u in acoustic_units]

for idx, (mtype, pattern, count) in enumerate(motifs_to_extract):
    positions = find_pattern_positions(units_list, pattern)

    # Extract up to 3 samples per motif
    samples_extracted = 0
    for pos in positions[:3]:
        # Convert frame position to time
        start_time = frame_to_time(pos, len(units_list), duration_seconds)

        # Duration based on pattern length (each unit ~20ms) plus padding
        n_units = len(pattern.split("_"))
        duration = (n_units * 0.02) + 0.15  # Add 150ms padding

        # Add context before
        start_time = max(0, start_time - 0.075)

        # Extract audio
        start_sample = int(start_time * SAMPLE_RATE)
        end_sample = int((start_time + duration) * SAMPLE_RATE)
        end_sample = min(end_sample, len(audio))

        sample_audio = audio[start_sample:end_sample]

        # Save audio file
        filename = f"motif_{idx:03d}_{mtype}_{samples_extracted}.wav"
        filepath = os.path.join(motif_audio_dir, filename)
        sf.write(filepath, sample_audio, SAMPLE_RATE)

        avita_data.append({
            "id": f"motif_{idx:03d}_{samples_extracted}",
            "motif_id": f"motif_{idx:03d}",
            "type": mtype,
            "pattern": pattern,
            "frequency": count,
            "audio_file": f"motif_samples/{filename}",
            "start_time": round(start_time, 3),
            "duration": round(duration, 3),
            "tags": {}
        })
        samples_extracted += 1

    if idx < 10:
        print(f"  {idx+1}. {pattern:25s} ({mtype:8s}) - {samples_extracted} samples")

print(f"\nTotal audio samples extracted: {len(avita_data)}")
print(f"Saved to: {motif_audio_dir}")

Extracting audio for 40 motifs...
--------------------------------------------------
  1. U62_U33                   (bigram  ) - 3 samples
  2. U48_U43                   (bigram  ) - 3 samples
  3. U33_U48                   (bigram  ) - 3 samples
  4. U35_U97                   (bigram  ) - 3 samples
  5. U33_U81                   (bigram  ) - 3 samples
  6. U55_U25                   (bigram  ) - 3 samples
  7. U97_U36                   (bigram  ) - 3 samples
  8. U13_U33                   (bigram  ) - 3 samples
  9. U32_U47                   (bigram  ) - 3 samples
  10. U25_U55                   (bigram  ) - 3 samples

Total audio samples extracted: 120
Saved to: /content/drive/MyDrive/satere_project/phase2_output/motif_samples


In [None]:
# CELL 11: Generate AViTA data file

# Save as JavaScript for AViTA
avita_js_path = os.path.join(OUTPUT_DIR, "avita_motifs.js")
with open(avita_js_path, 'w') as f:
    f.write("// AViTA Motif Data - Sateré-Mawé Acts 12\n")
    f.write(f"// Total motifs: {len(motifs_to_extract)}\n")
    f.write(f"// Total samples: {len(avita_data)}\n\n")
    f.write(f"const MOTIF_SAMPLES = {json.dumps(avita_data, indent=2)};\n")

# Save as JSON
avita_json_path = os.path.join(OUTPUT_DIR, "avita_motifs.json")
with open(avita_json_path, 'w') as f:
    json.dump(avita_data, f, indent=2)

print("=" * 60)
print("PHASE 2 COMPLETE!")
print("=" * 60)

print(f"""
FILES CREATED IN GOOGLE DRIVE:
{OUTPUT_DIR}/
├── audio_16khz.wav          (preprocessed audio)
├── acoustic_units.json      (raw unit sequence)
├── pseudotext.txt           (units as text)
├── avita_motifs.js          (for AViTA app)
├── avita_motifs.json        (same data as JSON)
└── motif_samples/           ({len(avita_data)} audio files)

SUMMARY:
- Input: {duration_minutes:.1f} minutes of Sateré-Mawé audio
- Extracted: {len(acoustic_units):,} acoustic frames
- Clustered into: {N_CLUSTERS} discrete units
- Discovered: {len(motifs_to_extract)} recurring motifs
- Created: {len(avita_data)} audio samples for tagging

NEXT STEPS:
1. Download the 'phase2_output' folder from Google Drive
2. Copy avita_motifs.js and motif_samples/ to your AViTA folder
3. Open AViTA and tag each motif with its grammatical meaning
""")

PHASE 2 COMPLETE!

FILES CREATED IN GOOGLE DRIVE:
/content/drive/MyDrive/satere_project/phase2_output/
├── audio_16khz.wav          (preprocessed audio)
├── acoustic_units.json      (raw unit sequence)
├── pseudotext.txt           (units as text)
├── avita_motifs.js          (for AViTA app)
├── avita_motifs.json        (same data as JSON)
└── motif_samples/           (120 audio files)

SUMMARY:
- Input: 9.3 minutes of Sateré-Mawé audio
- Extracted: 27,830 acoustic frames
- Clustered into: 100 discrete units
- Discovered: 40 recurring motifs
- Created: 120 audio samples for tagging

NEXT STEPS:
1. Download the 'phase2_output' folder from Google Drive
2. Copy avita_motifs.js and motif_samples/ to your AViTA folder
3. Open AViTA and tag each motif with its grammatical meaning



In [None]:
# Create avita_motifs.js file (JavaScript format)
avita_js_path = os.path.join(OUTPUT_DIR, "avita_motifs.js")
with open(avita_js_path, 'w') as f:
    f.write("const MOTIF_SAMPLES = ")
    f.write(json.dumps(avita_data, indent=2))
    f.write(";")

print(f"Created: {avita_js_path}")

Created: /content/drive/MyDrive/satere_project/phase2_output/avita_motifs.js


In [None]:
# Check if avita_data has content
print(f"avita_data has {len(avita_data)} items")

# Recreate the JS file with the data
avita_js_path = os.path.join(OUTPUT_DIR, "avita_motifs.js")
with open(avita_js_path, 'w') as f:
    f.write("const MOTIF_SAMPLES = ")
    f.write(json.dumps(avita_data, indent=2))
    f.write(";")

print(f"Created {avita_js_path} with {len(avita_data)} samples")

avita_data has 120 items
Created /content/drive/MyDrive/satere_project/phase2_output/avita_motifs.js with 120 samples


In [None]:
# Print the data so we can embed it
import json
print(json.dumps(avita_data))

[{"id": "motif_000_0", "motif_id": "motif_000", "type": "bigram", "pattern": "U62_U33", "frequency": 146, "audio_file": "motif_samples/motif_000_bigram_0.wav", "start_time": 1.026, "duration": 0.19, "tags": {}}, {"id": "motif_000_1", "motif_id": "motif_000", "type": "bigram", "pattern": "U62_U33", "frequency": 146, "audio_file": "motif_samples/motif_000_bigram_1.wav", "start_time": 2.527, "duration": 0.19, "tags": {}}, {"id": "motif_000_2", "motif_id": "motif_000", "type": "bigram", "pattern": "U62_U33", "frequency": 146, "audio_file": "motif_samples/motif_000_bigram_2.wav", "start_time": 3.207, "duration": 0.19, "tags": {}}, {"id": "motif_001_0", "motif_id": "motif_001", "type": "bigram", "pattern": "U48_U43", "frequency": 142, "audio_file": "motif_samples/motif_001_bigram_0.wav", "start_time": 6.169, "duration": 0.19, "tags": {}}, {"id": "motif_001_1", "motif_id": "motif_001", "type": "bigram", "pattern": "U48_U43", "frequency": 142, "audio_file": "motif_samples/motif_001_bigram_1.wa

In [None]:

# -----------------------------------------------------------------------------
# Step 3: Analyze discovered motifs
# -----------------------------------------------------------------------------
print("\nStep 3: Analyzing discovered motifs...")

# Encode the full corpus
all_encoded = []
for line in lines:
    encoded = sp.encode_as_pieces(line.strip())
    all_encoded.extend(encoded)

motif_counts = Counter(all_encoded)

print(f"  Total unique motifs: {len(motif_counts)}")
print(f"  Total tokens: {len(all_encoded):,}")

# Categorize motifs
single_u# =============================================================================
# SOUNDSCRIPT GENERATION: Full NT Corpus (45 hours)
# =============================================================================

import os
import json
from collections import Counter
import sentencepiece as spm

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"
OUTPUT_DIR = f"{PROJECT_ROOT}/soundscript_full"

os.makedirs(OUTPUT_DIR, exist_ok=True)

# -----------------------------------------------------------------------------
# Step 1: Combine all unit sequences from the full corpus
# -----------------------------------------------------------------------------
print("Step 1: Loading all unit sequences...")

all_units_file = f"{UNITS_DIR}/all_units_for_bpe.txt"

# Check file exists and count lines
with open(all_units_file, 'r') as f:
    lines = f.readlines()

print(f"  Found {len(lines)} files in corpus")

# Count total units
total_units = sum(len(line.strip().split()) for line in lines)
print(f"  Total acoustic units: {total_units:,}")

# -----------------------------------------------------------------------------
# Step 2: Train BPE with larger vocabulary for full corpus
# -----------------------------------------------------------------------------
print("\nStep 2: Training BPE model (this takes 10-30 minutes)...")

VOCAB_SIZE = 500  # Larger vocabulary for 45 hours of audio

spm.SentencePieceTrainer.train(
    input=all_units_file,
    model_prefix=f"{OUTPUT_DIR}/soundscript_satere",
    vocab_size=VOCAB_SIZE,
    model_type='bpe',
    character_coverage=1.0,
    num_threads=4,
    max_sentence_length=500000,  # Allow longer sequences
    train_extremely_large_corpus=True
)

print("  BPE model trained!")

# Load the trained model
sp = spm.SentencePieceProcessor()
sp.load(f"{OUTPUT_DIR}/soundscript_satere.model")
print(f"  Vocabulary size: {sp.get_piece_size()}")
nit = []
multi_unit = []
special = []

for motif, count in motif_counts.items():
    clean = motif.replace("▁", "").strip()
    if not clean or clean in ["U", ""]:
        special.append((motif, count))
    elif "_" in clean or len(clean) > 2:
        multi_unit.append((motif, count))
    else:
        single_unit.append((motif, count))

print(f"\n  Single-unit motifs: {len(single_unit)}")
print(f"  Multi-unit motifs: {len(multi_unit)}")
print(f"  Special tokens: {len(special)}")

# -----------------------------------------------------------------------------
# Step 4: Display top motifs
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("TOP 50 MOTIFS BY FREQUENCY")
print("=" * 70)

for i, (motif, count) in enumerate(motif_counts.most_common(50)):
    print(f"  {i+1:3d}. {motif:30s} count={count:,}")

# -----------------------------------------------------------------------------
# Step 5: Save motif inventory
# -----------------------------------------------------------------------------
print("\nStep 5: Saving SoundScript inventory...")

inventory = {
    "language": "Sateré-Mawé",
    "corpus_files": len(lines),
    "total_units": total_units,
    "vocab_size": sp.get_piece_size(),
    "motifs": [
        {"id": i, "symbol": motif, "count": count}
        for i, (motif, count) in enumerate(motif_counts.most_common())
    ]
}

with open(f"{OUTPUT_DIR}/soundscript_inventory.json", "w") as f:
    json.dump(inventory, f, indent=2)

print(f"  Saved to: {OUTPUT_DIR}/soundscript_inventory.json")

print("\n" + "=" * 70)
print("✓ SOUNDSCRIPT GENERATION COMPLETE")
print("=" * 70)
print(f"  Model: {OUTPUT_DIR}/soundscript_satere.model")
print(f"  Vocab: {OUTPUT_DIR}/soundscript_satere.vocab")
print(f"  Inventory: {OUTPUT_DIR}/soundscript_inventory.json")

Step 1: Loading all unit sequences...


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/satere_project/satere_units/all_units_for_bpe.txt'

In [None]:
# Explore what files you have
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"

print("Contents of satere_project:")
print("-" * 50)

for item in os.listdir(PROJECT_ROOT):
    item_path = os.path.join(PROJECT_ROOT, item)
    if os.path.isdir(item_path):
        print(f"📁 {item}/")
        # Show what's inside each folder
        try:
            contents = os.listdir(item_path)[:10]  # First 10 items
            for sub in contents:
                print(f"    {sub}")
            if len(os.listdir(item_path)) > 10:
                print(f"    ... and {len(os.listdir(item_path)) - 10} more files")
        except:
            pass
    else:
        print(f"📄 {item}")

Contents of satere_project:
--------------------------------------------------
📁 soundscript_full/


In [None]:
# Deep search for your tokenized files
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"

print("Full directory tree:")
print("=" * 60)

for root, dirs, files in os.walk(PROJECT_ROOT):
    level = root.replace(PROJECT_ROOT, '').count(os.sep)
    indent = '  ' * level
    print(f"{indent}📁 {os.path.basename(root)}/")

    # Show files
    sub_indent = '  ' * (level + 1)
    for file in files[:15]:
        print(f"{sub_indent}📄 {file}")
    if len(files) > 15:
        print(f"{sub_indent}... and {len(files) - 15} more files")

# Also check for specific files
print("\n" + "=" * 60)
print("Searching for key files...")
print("=" * 60)

key_patterns = ['units', 'bpe', 'corpus', 'json', 'model']
found = []

for root, dirs, files in os.walk(PROJECT_ROOT):
    for file in files:
        for pattern in key_patterns:
            if pattern in file.lower():
                found.append(os.path.join(root, file))

print(f"\nFound {len(found)} relevant files:")
for f in found[:30]:
    print(f"  {f}")

Full directory tree:
📁 satere_project/
  📁 soundscript_full/

Searching for key files...

Found 0 relevant files:


In [None]:
# Check the actual folders in your Drive
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"

folders_to_check = [
    "satere_units",
    "satere_motifs",
    "raw_audio",
    "phase2_output",
    "converted_audio"
]

print("Checking your folders:")
print("=" * 60)

for folder in folders_to_check:
    path = os.path.join(PROJECT_ROOT, folder)
    if os.path.exists(path):
        files = os.listdir(path)
        print(f"\n✓ {folder}/")
        print(f"  Total items: {len(files)}")

        # Show first few files
        for f in files[:5]:
            filepath = os.path.join(path, f)
            if os.path.isfile(filepath):
                size = os.path.getsize(filepath) / 1024  # KB
                print(f"    {f} ({size:.1f} KB)")
            else:
                print(f"    📁 {f}/")

        if len(files) > 5:
            print(f"    ... and {len(files) - 5} more")
    else:
        print(f"\n✗ {folder}/ - not found")

# Specifically check for the BPE training file
bpe_file = f"{PROJECT_ROOT}/satere_units/all_units_for_bpe.txt"
if os.path.exists(bpe_file):
    size_mb = os.path.getsize(bpe_file) / (1024 * 1024)
    print(f"\n✓ Found BPE training file: {size_mb:.1f} MB")
else:
    print(f"\n✗ BPE training file not found")

Checking your folders:

✗ satere_units/ - not found

✗ satere_motifs/ - not found

✗ raw_audio/ - not found

✗ phase2_output/ - not found

✗ converted_audio/ - not found

✗ BPE training file not found


In [None]:
# Remount Google Drive
from google.colab import drive

# Force remount
drive.flush_and_unmount()
print("Unmounted. Now remounting...")

drive.mount('/content/drive', force_remount=True)
print("Mounted!")

# Now check again
import os
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"

print("\nContents of satere_project:")
print("-" * 50)
for item in os.listdir(PROJECT_ROOT):
    print(f"  {item}")

Drive not mounted, so nothing to flush and unmount.
Unmounted. Now remounting...


ValueError: Mountpoint must not already contain files

In [None]:
# Clean up and remount
import os

# Remove the stale mount point
!rm -rf /content/drive
!mkdir -p /content/drive

# Now mount fresh
from google.colab import drive
drive.mount('/content/drive')

print("\nChecking satere_project:")
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
for item in os.listdir(PROJECT_ROOT):
    print(f"  {item}")

Mounted at /content/drive

Checking satere_project:
  raw_audio
  converted_audio
  satere_units
  satere_motifs
  acts12.mp3
  phase2_output


In [None]:
# Check the key files
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"

# Check satere_units folder
units_path = f"{PROJECT_ROOT}/satere_units"
files = os.listdir(units_path)
print(f"satere_units/ has {len(files)} files")

# Check for the BPE training file
bpe_file = f"{units_path}/all_units_for_bpe.txt"
if os.path.exists(bpe_file):
    size_mb = os.path.getsize(bpe_file) / (1024 * 1024)
    with open(bpe_file, 'r') as f:
        lines = f.readlines()
    print(f"\n✓ BPE training file found!")
    print(f"  Size: {size_mb:.1f} MB")
    print(f"  Files (lines): {len(lines)}")

    # Count total units
    total_units = sum(len(line.strip().split()) for line in lines)
    print(f"  Total acoustic units: {total_units:,}")
else:
    print("\n✗ BPE training file not found")

# Also check what's in satere_motifs (from yesterday's small test)
motifs_path = f"{PROJECT_ROOT}/satere_motifs"
if os.path.exists(motifs_path):
    print(f"\nsatere_motifs/ contents:")
    for f in os.listdir(motifs_path):
        print(f"  {f}")

satere_units/ has 232 files

✓ BPE training file found!
  Size: 22.4 MB
  Files (lines): 229
  Total acoustic units: 8,078,315

satere_motifs/ contents:
  satere_bpe.model
  satere_bpe.vocab
  motif_index.json
  motif_statistics.json


In [None]:
# =============================================================================
# SOUNDSCRIPT GENERATION: Full 45-hour corpus
# =============================================================================

import os
import json
from collections import Counter
import sentencepiece as spm

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"
OUTPUT_DIR = f"{PROJECT_ROOT}/soundscript_full"

os.makedirs(OUTPUT_DIR, exist_ok=True)

# The BPE training file
bpe_file = f"{UNITS_DIR}/all_units_for_bpe.txt"

print("=" * 60)
print("SOUNDSCRIPT GENERATION - Full Sateré-Mawé NT")
print("=" * 60)
print(f"Training data: 8,078,315 acoustic units from 229 files")

# Train BPE with larger vocabulary for the full corpus
print("\nTraining BPE model (10-30 minutes)...")

VOCAB_SIZE = 500  # Larger vocabulary for full corpus

spm.SentencePieceTrainer.train(
    input=bpe_file,
    model_prefix=f"{OUTPUT_DIR}/soundscript_satere",
    vocab_size=VOCAB_SIZE,
    model_type='bpe',
    character_coverage=1.0,
    num_threads=4,
    max_sentence_length=500000,
    train_extremely_large_corpus=True
)

print("✓ BPE model trained!")

# Load the model
sp = spm.SentencePieceProcessor()
sp.load(f"{OUTPUT_DIR}/soundscript_satere.model")
print(f"  Vocabulary size: {sp.get_piece_size()}")

# Analyze the motifs
print("\nAnalyzing discovered motifs...")

with open(bpe_file, 'r') as f:
    lines = f.readlines()

all_encoded = []
for line in lines:
    encoded = sp.encode_as_pieces(line.strip())
    all_encoded.extend(encoded)

motif_counts = Counter(all_encoded)

print(f"  Unique motifs: {len(motif_counts)}")
print(f"  Total tokens: {len(all_encoded):,}")

# Show top 40 motifs
print("\n" + "=" * 60)
print("TOP 40 SOUNDSCRIPT MOTIFS")
print("=" * 60)

for i, (motif, count) in enumerate(motif_counts.most_common(40)):
    print(f"  {i+1:3d}. {motif:35s} count={count:,}")

# Save inventory
inventory = {
    "language": "Sateré-Mawé",
    "corpus_files": 229,
    "total_units": 8078315,
    "vocab_size": sp.get_piece_size(),
    "motifs": [
        {"id": i, "symbol": motif, "count": count}
        for i, (motif, count) in enumerate(motif_counts.most_common())
    ]
}

with open(f"{OUTPUT_DIR}/soundscript_inventory.json", "w") as f:
    json.dump(inventory, f, indent=2)

print("\n" + "=" * 60)
print("✓ SOUNDSCRIPT COMPLETE")
print("=" * 60)
print(f"  Model: {OUTPUT_DIR}/soundscript_satere.model")
print(f"  Inventory: {OUTPUT_DIR}/soundscript_inventory.json")

SOUNDSCRIPT GENERATION - Full Sateré-Mawé NT
Training data: 8,078,315 acoustic units from 229 files

Training BPE model (10-30 minutes)...


RuntimeError: Internal: src/trainer_interface.cc(664) [(trainer_spec_.vocab_size()) == (model_proto->pieces_size())] Vocabulary size too high (500). Please set it to a value <= 204.

In [None]:
# =============================================================================
# SOUNDSCRIPT GENERATION: Full 45-hour corpus
# =============================================================================

import os
import json
from collections import Counter
import sentencepiece as spm

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"
OUTPUT_DIR = f"{PROJECT_ROOT}/soundscript_full"

os.makedirs(OUTPUT_DIR, exist_ok=True)

bpe_file = f"{UNITS_DIR}/all_units_for_bpe.txt"

print("=" * 60)
print("SOUNDSCRIPT GENERATION - Full Sateré-Mawé NT")
print("=" * 60)

# Use maximum allowed vocabulary
VOCAB_SIZE = 200

print(f"Training BPE with vocab_size={VOCAB_SIZE}...")

spm.SentencePieceTrainer.train(
    input=bpe_file,
    model_prefix=f"{OUTPUT_DIR}/soundscript_satere",
    vocab_size=VOCAB_SIZE,
    model_type='bpe',
    character_coverage=1.0,
    num_threads=4,
    max_sentence_length=500000,
    train_extremely_large_corpus=True
)

print("✓ BPE model trained!")

# Load the model
sp = spm.SentencePieceProcessor()
sp.load(f"{OUTPUT_DIR}/soundscript_satere.model")
print(f"  Vocabulary size: {sp.get_piece_size()}")

# Analyze the motifs
print("\nAnalyzing discovered motifs...")

with open(bpe_file, 'r') as f:
    lines = f.readlines()

all_encoded = []
for line in lines:
    encoded = sp.encode_as_pieces(line.strip())
    all_encoded.extend(encoded)

motif_counts = Counter(all_encoded)

print(f"  Unique motifs: {len(motif_counts)}")
print(f"  Total tokens: {len(all_encoded):,}")

# Show top 40 motifs
print("\n" + "=" * 60)
print("TOP 40 SOUNDSCRIPT MOTIFS")
print("=" * 60)

for i, (motif, count) in enumerate(motif_counts.most_common(40)):
    print(f"  {i+1:3d}. {motif:35s} count={count:,}")

# Save inventory
inventory = {
    "language": "Sateré-Mawé",
    "corpus_files": 229,
    "total_units": 8078315,
    "vocab_size": sp.get_piece_size(),
    "motifs": [
        {"id": i, "symbol": motif, "count": count}
        for i, (motif, count) in enumerate(motif_counts.most_common())
    ]
}

with open(f"{OUTPUT_DIR}/soundscript_inventory.json", "w") as f:
    json.dump(inventory, f, indent=2)

print("\n" + "=" * 60)
print("✓ SOUNDSCRIPT COMPLETE")
print("=" * 60)
print(f"  Model: {OUTPUT_DIR}/soundscript_satere.model")
print(f"  Inventory: {OUTPUT_DIR}/soundscript_inventory.json")

SOUNDSCRIPT GENERATION - Full Sateré-Mawé NT
Training BPE with vocab_size=200...
✓ BPE model trained!
  Vocabulary size: 200

Analyzing discovered motifs...
  Unique motifs: 100
  Total tokens: 8,078,315

TOP 40 SOUNDSCRIPT MOTIFS
    1. ▁29                                 count=266,472
    2. ▁82                                 count=217,421
    3. ▁26                                 count=179,759
    4. ▁81                                 count=161,271
    5. ▁85                                 count=158,889
    6. ▁55                                 count=153,712
    7. ▁18                                 count=148,241
    8. ▁62                                 count=138,527
    9. ▁44                                 count=131,139
   10. ▁23                                 count=130,346
   11. ▁19                                 count=125,666
   12. ▁16                                 count=122,186
   13. ▁1                                  count=119,696
   14. ▁78                  

In [None]:
# =============================================================================
# EXTRACT AUDIO SAMPLES FOR EACH MOTIF
# =============================================================================

import os
import json
import numpy as np
import scipy.io.wavfile as wav
from tqdm import tqdm
import random

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"
CONVERTED_DIR = f"{PROJECT_ROOT}/converted_audio"
OUTPUT_DIR = f"{PROJECT_ROOT}/soundscript_full"
SAMPLES_DIR = f"{OUTPUT_DIR}/motif_samples"

os.makedirs(SAMPLES_DIR, exist_ok=True)

print("=" * 60)
print("EXTRACTING AUDIO SAMPLES FOR MOTIFS")
print("=" * 60)

# Load one of the unit files to understand the format
unit_files = [f for f in os.listdir(UNITS_DIR) if f.endswith('.units.txt')]
print(f"Found {len(unit_files)} unit files")

# Get list of converted audio files
wav_files = [f for f in os.listdir(CONVERTED_DIR) if f.endswith('.wav')]
print(f"Found {len(wav_files)} WAV files")

# For each of the 100 motifs, find occurrences and extract samples
SAMPLES_PER_MOTIF = 3
motif_samples = {}

# The motifs are just the unit numbers (0-99)
motifs_to_sample = list(range(100))

print(f"\nExtracting {SAMPLES_PER_MOTIF} samples for each of {len(motifs_to_sample)} motifs...")

for motif_id in tqdm(motifs_to_sample, desc="Processing motifs"):
    motif_samples[motif_id] = []

    # Find files that contain this unit
    occurrences = []

    for unit_file in unit_files[:50]:  # Sample from first 50 files for speed
        unit_path = os.path.join(UNITS_DIR, unit_file)
        with open(unit_path, 'r') as f:
            units = f.read().strip().split()

        # Find positions of this motif
        for pos, unit in enumerate(units):
            if unit == str(motif_id):
                # Calculate timestamp (20ms per frame)
                start_time = pos * 0.02
                end_time = start_time + 0.1  # 100ms window

                wav_name = unit_file.replace('.units.txt', '.wav')
                if wav_name in wav_files:
                    occurrences.append({
                        "file": wav_name,
                        "start": start_time,
                        "end": end_time,
                        "position": pos
                    })

    # Sample up to SAMPLES_PER_MOTIF occurrences
    if len(occurrences) >= SAMPLES_PER_MOTIF:
        samples = random.sample(occurrences, SAMPLES_PER_MOTIF)
    elif len(occurrences) > 0:
        samples = occurrences
    else:
        continue

    # Extract audio for each sample
    for sample_idx, sample in enumerate(samples):
        wav_path = os.path.join(CONVERTED_DIR, sample["file"])

        try:
            sr, audio = wav.read(wav_path)
            start_frame = int(sample["start"] * sr)
            end_frame = int(sample["end"] * sr)

            # Add buffer
            start_frame = max(0, start_frame - int(0.05 * sr))
            end_frame = min(len(audio), end_frame + int(0.05 * sr))

            segment = audio[start_frame:end_frame]

            if len(segment) > 0:
                out_file = f"{SAMPLES_DIR}/motif_{motif_id:02d}_sample_{sample_idx}.wav"
                wav.write(out_file, sr, segment)
                motif_samples[motif_id].append(os.path.basename(out_file))
        except Exception as e:
            continue

# Save metadata
samples_meta = {
    "language": "Sateré-Mawé",
    "total_motifs": len(motifs_to_sample),
    "samples_per_motif": SAMPLES_PER_MOTIF,
    "motifs": [
        {"id": mid, "samples": samples}
        for mid, samples in motif_samples.items()
        if samples
    ]
}

with open(f"{OUTPUT_DIR}/motif_samples_meta.json", "w") as f:
    json.dump(samples_meta, f, indent=2)

# Count successful extractions
total_samples = sum(len(s) for s in motif_samples.values())
motifs_with_samples = len([m for m in motif_samples.values() if m])

print(f"\n✓ EXTRACTION COMPLETE")
print(f"  Motifs with samples: {motifs_with_samples}")
print(f"  Total audio samples: {total_samples}")
print(f"  Saved to: {SAMPLES_DIR}")

EXTRACTING AUDIO SAMPLES FOR MOTIFS
Found 229 unit files
Found 260 WAV files

Extracting 3 samples for each of 100 motifs...


Processing motifs: 100%|██████████| 100/100 [02:22<00:00,  1.43s/it]


✓ EXTRACTION COMPLETE
  Motifs with samples: 100
  Total audio samples: 300
  Saved to: /content/drive/MyDrive/satere_project/soundscript_full/motif_samples





In [None]:
# =============================================================================
# CREATE SOUNDSCRIPT MAPPER HTML FOR YOUR MOTIFS
# =============================================================================

import json
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
OUTPUT_DIR = f"{PROJECT_ROOT}/soundscript_full"

# Load the motif metadata
with open(f"{OUTPUT_DIR}/motif_samples_meta.json", "r") as f:
    meta = json.load(f)

# Load the inventory for counts
with open(f"{OUTPUT_DIR}/soundscript_inventory.json", "r") as f:
    inventory = json.load(f)

# Build motif data with counts
motif_counts = {m["symbol"].replace("▁", ""): m["count"] for m in inventory["motifs"]}

# Create JavaScript data
js_data = []
for motif in meta["motifs"]:
    motif_id = motif["id"]
    count = motif_counts.get(str(motif_id), 0)
    js_data.append({
        "id": motif_id,
        "symbol": f"M{motif_id:02d}",
        "count": count,
        "samples": motif["samples"]
    })

# Sort by count (most frequent first)
js_data.sort(key=lambda x: x["count"], reverse=True)

print(f"Prepared data for {len(js_data)} motifs")
print(f"Total samples: {sum(len(m['samples']) for m in js_data)}")

# Save as JavaScript file for the mapper
js_content = f"const MOTIF_DATA = {json.dumps(js_data, indent=2)};"

with open(f"{OUTPUT_DIR}/soundscript_data.js", "w") as f:
    f.write(js_content)

print(f"\n✓ Saved: {OUTPUT_DIR}/soundscript_data.js")

# Show top 10 most frequent
print("\nTop 10 most frequent motifs:")
for m in js_data[:10]:
    print(f"  {m['symbol']}: {m['count']:,} occurrences")

Prepared data for 100 motifs
Total samples: 300

✓ Saved: /content/drive/MyDrive/satere_project/soundscript_full/soundscript_data.js

Top 10 most frequent motifs:
  M29: 266,472 occurrences
  M82: 217,421 occurrences
  M26: 179,759 occurrences
  M81: 161,271 occurrences
  M85: 158,889 occurrences
  M55: 153,712 occurrences
  M18: 148,241 occurrences
  M62: 138,527 occurrences
  M44: 131,139 occurrences
  M23: 130,346 occurrences


In [None]:
# =============================================================================
# CREATE DOWNLOADABLE ZIP FOR GITHUB DEPLOYMENT
# =============================================================================

import shutil
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
OUTPUT_DIR = f"{PROJECT_ROOT}/soundscript_full"
DEPLOY_DIR = f"{PROJECT_ROOT}/soundscript_deploy"

# Create deployment folder
os.makedirs(DEPLOY_DIR, exist_ok=True)

# Copy motif samples
samples_src = f"{OUTPUT_DIR}/motif_samples"
samples_dst = f"{DEPLOY_DIR}/motif_samples"

if os.path.exists(samples_dst):
    shutil.rmtree(samples_dst)
shutil.copytree(samples_src, samples_dst)

# Copy data file
shutil.copy(f"{OUTPUT_DIR}/soundscript_data.js", DEPLOY_DIR)
shutil.copy(f"{OUTPUT_DIR}/soundscript_inventory.json", DEPLOY_DIR)

# Count files
sample_count = len(os.listdir(samples_dst))

print(f"✓ Deployment folder ready: {DEPLOY_DIR}")
print(f"  - soundscript_data.js")
print(f"  - soundscript_inventory.json")
print(f"  - motif_samples/ ({sample_count} files)")

# Create zip
zip_path = f"{PROJECT_ROOT}/soundscript_deploy.zip"
shutil.make_archive(zip_path.replace('.zip', ''), 'zip', DEPLOY_DIR)

print(f"\n✓ ZIP created: {zip_path}")
print(f"\nDownload this zip from Google Drive, then:")
print("1. Unzip it")
print("2. Create a new GitHub repository called 'soundscript-satere'")
print("3. Upload all files to the repository")
print("4. Enable GitHub Pages in Settings")

✓ Deployment folder ready: /content/drive/MyDrive/satere_project/soundscript_deploy
  - soundscript_data.js
  - soundscript_inventory.json
  - motif_samples/ (300 files)

✓ ZIP created: /content/drive/MyDrive/satere_project/soundscript_deploy.zip

Download this zip from Google Drive, then:
1. Unzip it
2. Create a new GitHub repository called 'soundscript-satere'
3. Upload all files to the repository
4. Enable GitHub Pages in Settings


In [None]:
# =============================================================================
# CREATE THE SOUNDSCRIPT MAPPER HTML
# =============================================================================

html_content = '''<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>SoundScript Mapper | Sateré-Mawé</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
            background: #0a0a0b;
            color: #f5f5f7;
            min-height: 100vh;
            line-height: 1.5;
        }
        .container { max-width: 1200px; margin: 0 auto; padding: 2rem; }
        header { margin-bottom: 2rem; padding-bottom: 1.5rem; border-bottom: 1px solid #2a2a2e; }
        h1 { font-size: 1.75rem; margin-bottom: 0.25rem; }
        .subtitle { color: #a1a1a6; }

        .stats { display: flex; gap: 2rem; margin-top: 1rem; flex-wrap: wrap; }
        .stat { display: flex; align-items: center; gap: 0.5rem; }
        .stat-dot { width: 10px; height: 10px; border-radius: 50%; }
        .stat-dot.morpheme { background: #34c759; }
        .stat-dot.syllable { background: #5e5ce6; }
        .stat-dot.noise { background: #ff6b6b; }
        .stat-dot.unclear { background: #ff9f0a; }
        .stat-dot.pending { background: #6e6e73; }
        .stat-count { font-weight: 600; font-family: monospace; }

        .progress-bar { width: 100%; height: 6px; background: #1a1a1e; border-radius: 3px; margin-top: 1rem; }
        .progress-fill { height: 100%; background: linear-gradient(90deg, #34c759, #0a84ff); border-radius: 3px; transition: width 0.3s; }

        .main-grid { display: grid; grid-template-columns: 280px 1fr; gap: 1.5rem; }
        @media (max-width: 800px) { .main-grid { grid-template-columns: 1fr; } }

        .motif-list { background: #141416; border-radius: 12px; border: 1px solid #2a2a2e; max-height: 70vh; overflow: hidden; display: flex; flex-direction: column; }
        .motif-list-header { padding: 1rem; border-bottom: 1px solid #2a2a2e; font-weight: 600; font-size: 0.875rem; color: #a1a1a6; text-transform: uppercase; }
        .motif-list-scroll { overflow-y: auto; flex: 1; }
        .motif-item { padding: 0.75rem 1rem; border-bottom: 1px solid #2a2a2e; cursor: pointer; display: flex; align-items: center; gap: 0.75rem; }
        .motif-item:hover { background: #1a1a1e; }
        .motif-item.active { background: #1a1a1e; border-left: 3px solid #0a84ff; }
        .motif-id { font-family: monospace; font-weight: 500; min-width: 45px; }
        .motif-status { width: 8px; height: 8px; border-radius: 50%; }
        .motif-preview { font-size: 0.875rem; color: #a1a1a6; flex: 1; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
        .motif-preview.has-content { color: #f5f5f7; font-style: italic; }
        .motif-count { font-size: 0.75rem; color: #6e6e73; font-family: monospace; }

        .editor { background: #141416; border-radius: 12px; border: 1px solid #2a2a2e; }
        .editor-header { padding: 1.25rem; border-bottom: 1px solid #2a2a2e; background: #1a1a1e; border-radius: 12px 12px 0 0; }
        .editor-title { font-family: monospace; font-size: 1.5rem; font-weight: 600; }
        .editor-meta { font-size: 0.875rem; color: #a1a1a6; margin-top: 0.25rem; }
        .editor-body { padding: 1.25rem; }

        .section-label { font-size: 0.75rem; font-weight: 600; color: #6e6e73; text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 0.5rem; }
        .audio-section { margin-bottom: 1.5rem; }
        .audio-samples { display: flex; gap: 0.5rem; flex-wrap: wrap; }
        .audio-btn { display: flex; align-items: center; gap: 0.5rem; padding: 0.625rem 1rem; background: #1a1a1e; border: 1px solid #2a2a2e; border-radius: 8px; color: #f5f5f7; font-size: 0.875rem; cursor: pointer; }
        .audio-btn:hover { background: #242428; border-color: #0a84ff; }
        .audio-btn.playing { background: #0a84ff; border-color: #0a84ff; }

        .form-group { margin-bottom: 1.25rem; }
        .form-label { display: block; font-size: 0.875rem; font-weight: 500; margin-bottom: 0.375rem; color: #a1a1a6; }
        .form-input { width: 100%; padding: 0.75rem; background: #1a1a1e; border: 1px solid #2a2a2e; border-radius: 8px; color: #f5f5f7; font-size: 1rem; }
        .form-input:focus { outline: none; border-color: #0a84ff; }
        textarea.form-input { min-height: 80px; resize: vertical; }

        .category-grid { display: grid; grid-template-columns: repeat(2, 1fr); gap: 0.5rem; }
        .category-btn { display: flex; align-items: center; gap: 0.5rem; padding: 0.75rem; background: #1a1a1e; border: 2px solid #2a2a2e; border-radius: 8px; color: #f5f5f7; font-size: 0.875rem; cursor: pointer; text-align: left; }
        .category-btn:hover { background: #242428; }
        .category-btn.selected.morpheme { border-color: #34c759; background: rgba(52,199,89,0.1); }
        .category-btn.selected.syllable { border-color: #5e5ce6; background: rgba(94,92,230,0.1); }
        .category-btn.selected.noise { border-color: #ff6b6b; background: rgba(255,107,107,0.1); }
        .category-btn.selected.unclear { border-color: #ff9f0a; background: rgba(255,159,10,0.1); }
        .category-icon { width: 24px; height: 24px; border-radius: 6px; display: flex; align-items: center; justify-content: center; font-size: 0.875rem; }
        .category-btn.morpheme .category-icon { background: rgba(52,199,89,0.2); }
        .category-btn.syllable .category-icon { background: rgba(94,92,230,0.2); }
        .category-btn.noise .category-icon { background: rgba(255,107,107,0.2); }
        .category-btn.unclear .category-icon { background: rgba(255,159,10,0.2); }

        .action-bar { display: flex; gap: 0.75rem; margin-top: 1.5rem; padding-top: 1rem; border-top: 1px solid #2a2a2e; flex-wrap: wrap; }
        .btn { padding: 0.75rem 1.25rem; border-radius: 8px; font-size: 0.875rem; font-weight: 600; cursor: pointer; border: none; }
        .btn-primary { background: #0a84ff; color: white; }
        .btn-primary:hover { background: #0077ed; }
        .btn-secondary { background: #1a1a1e; color: #f5f5f7; border: 1px solid #2a2a2e; }
        .btn-secondary:hover { background: #242428; }

        .export-section { margin-top: 1.5rem; padding: 1.25rem; background: #141416; border-radius: 12px; border: 1px solid #2a2a2e; }
        .export-title { font-weight: 600; margin-bottom: 0.75rem; }
        .export-buttons { display: flex; gap: 0.5rem; flex-wrap: wrap; }

        .keyboard-hint { font-size: 0.75rem; color: #6e6e73; margin-top: 0.5rem; }
        kbd { padding: 0.125rem 0.375rem; background: #1a1a1e; border: 1px solid #2a2a2e; border-radius: 4px; font-family: monospace; font-size: 0.6875rem; }
    </style>
</head>
<body>
    <div class="container">
        <header>
            <h1>🔊 SoundScript Mapper</h1>
            <div class="subtitle">Sateré-Mawé · 100 motifs from 45 hours of audio</div>
            <div class="stats">
                <div class="stat"><div class="stat-dot morpheme"></div><span>Morphemes</span><span class="stat-count" id="statMorpheme">0</span></div>
                <div class="stat"><div class="stat-dot syllable"></div><span>Syllables</span><span class="stat-count" id="statSyllable">0</span></div>
                <div class="stat"><div class="stat-dot noise"></div><span>Noise</span><span class="stat-count" id="statNoise">0</span></div>
                <div class="stat"><div class="stat-dot unclear"></div><span>Unclear</span><span class="stat-count" id="statUnclear">0</span></div>
                <div class="stat"><div class="stat-dot pending"></div><span>Pending</span><span class="stat-count" id="statPending">0</span></div>
            </div>
            <div class="progress-bar"><div class="progress-fill" id="progressFill"></div></div>
        </header>

        <div class="main-grid">
            <div class="motif-list">
                <div class="motif-list-header">Motifs (by frequency)</div>
                <div class="motif-list-scroll" id="motifList"></div>
            </div>
            <div class="editor">
                <div class="editor-header">
                    <div class="editor-title" id="editorTitle">Select a motif</div>
                    <div class="editor-meta" id="editorMeta">Click a motif from the list to begin</div>
                </div>
                <div class="editor-body" id="editorBody">
                    <p style="color:#6e6e73;text-align:center;padding:3rem;">👈 Select a motif to start mapping</p>
                </div>
            </div>
        </div>

        <div class="export-section">
            <div class="export-title">Export Mapping Data</div>
            <div class="export-buttons">
                <button class="btn btn-secondary" onclick="exportJson()">📄 JSON</button>
                <button class="btn btn-secondary" onclick="exportCsv()">📊 CSV</button>
                <button class="btn btn-secondary" onclick="exportSummary()">📋 Summary</button>
            </div>
        </div>
    </div>

    <script src="soundscript_data.js"></script>
    <script>
        let motifs = MOTIF_DATA;
        let mappings = JSON.parse(localStorage.getItem('soundscript_mappings_v2') || '{}');
        let currentIndex = -1;
        let currentAudio = null;

        function init() {
            renderList();
            updateStats();
            document.addEventListener('keydown', handleKeys);
        }

        function renderList() {
            const list = document.getElementById('motifList');
            list.innerHTML = motifs.map((m, i) => {
                const map = mappings[m.id] || {};
                const status = map.category || 'pending';
                const preview = map.transcription || (map.category === 'noise' ? '[noise]' : '');
                return \`<div class="motif-item \${i === currentIndex ? 'active' : ''}" onclick="selectMotif(\${i})">
                    <div class="motif-status stat-dot \${status}"></div>
                    <div class="motif-id">\${m.symbol}</div>
                    <div class="motif-preview \${preview ? 'has-content' : ''}">\${preview || 'unmapped'}</div>
                    <div class="motif-count">\${(m.count/1000).toFixed(0)}k</div>
                </div>\`;
            }).join('');
        }

        function selectMotif(i) {
            currentIndex = i;
            const m = motifs[i];
            const map = mappings[m.id] || {};

            document.getElementById('editorTitle').textContent = m.symbol;
            document.getElementById('editorMeta').textContent = \`\${m.count.toLocaleString()} occurrences · \${m.samples.length} samples\`;

            document.getElementById('editorBody').innerHTML = \`
                <div class="audio-section">
                    <div class="section-label">Audio Samples</div>
                    <div class="audio-samples">
                        \${m.samples.map((s, j) => \`<button class="audio-btn" id="audio\${j}" onclick="playAudio('\${s}', \${j})">▶ Sample \${j+1}</button>\`).join('')}
                    </div>
                </div>
                <div class="form-group">
                    <label class="form-label">What do you hear? (Sateré orthography)</label>
                    <input type="text" class="form-input" id="transcription" placeholder="e.g., pe, -iatu, hin" value="\${map.transcription || ''}" onchange="saveField('transcription', this.value)">
                </div>
                <div class="form-group">
                    <div class="section-label">Category</div>
                    <div class="category-grid">
                        <button class="category-btn morpheme \${map.category==='morpheme'?'selected':''}" onclick="saveField('category','morpheme')">
                            <div class="category-icon">✓</div><div>Morpheme</div>
                        </button>
                        <button class="category-btn syllable \${map.category==='syllable'?'selected':''}" onclick="saveField('category','syllable')">
                            <div class="category-icon">◐</div><div>Syllable</div>
                        </button>
                        <button class="category-btn noise \${map.category==='noise'?'selected':''}" onclick="saveField('category','noise')">
                            <div class="category-icon">✗</div><div>Noise/Silence</div>
                        </button>
                        <button class="category-btn unclear \${map.category==='unclear'?'selected':''}" onclick="saveField('category','unclear')">
                            <div class="category-icon">?</div><div>Unclear</div>
                        </button>
                    </div>
                    <div class="keyboard-hint"><kbd>1</kbd> Morpheme <kbd>2</kbd> Syllable <kbd>3</kbd> Noise <kbd>4</kbd> Unclear</div>
                </div>
                <div class="form-group">
                    <label class="form-label">Gloss / Meaning</label>
                    <input type="text" class="form-input" id="gloss" placeholder="e.g., locative 'in, on, at'" value="\${map.gloss || ''}" onchange="saveField('gloss', this.value)">
                </div>
                <div class="form-group">
                    <label class="form-label">Notes</label>
                    <textarea class="form-input" id="notes" placeholder="Any observations..." onchange="saveField('notes', this.value)">\${map.notes || ''}</textarea>
                </div>
                <div class="action-bar">
                    <button class="btn btn-primary" onclick="saveAndNext()">Save & Next →</button>
                    <button class="btn btn-secondary" onclick="navigate(-1)" \${i===0?'disabled':''}>← Prev</button>
                    <button class="btn btn-secondary" onclick="navigate(1)" \${i===motifs.length-1?'disabled':''}>Next →</button>
                </div>
                <div class="keyboard-hint" style="text-align:center;margin-top:0.5rem;"><kbd>←</kbd><kbd>→</kbd> Navigate <kbd>Space</kbd> Play <kbd>Enter</kbd> Save & Next</div>
            \`;
            renderList();
        }

        function playAudio(file, idx) {
            if (currentAudio) { currentAudio.pause(); document.querySelectorAll('.audio-btn').forEach(b => b.classList.remove('playing')); }
            currentAudio = new Audio('motif_samples/' + file);
            document.getElementById('audio' + idx).classList.add('playing');
            currentAudio.onended = () => document.getElementById('audio' + idx).classList.remove('playing');
            currentAudio.play();
        }

        function saveField(field, value) {
            const m = motifs[currentIndex];
            if (!mappings[m.id]) mappings[m.id] = {};
            mappings[m.id][field] = value;
            localStorage.setItem('soundscript_mappings_v2', JSON.stringify(mappings));
            updateStats();
            if (field === 'category') selectMotif(currentIndex);
            else renderList();
        }

        function saveAndNext() {
            if (currentIndex < motifs.length - 1) selectMotif(currentIndex + 1);
        }

        function navigate(dir) {
            const next = currentIndex + dir;
            if (next >= 0 && next < motifs.length) selectMotif(next);
        }

        function handleKeys(e) {
            if (currentIndex < 0) return;
            if (e.target.tagName === 'INPUT' || e.target.tagName === 'TEXTAREA') {
                if (e.key === 'Enter') { e.preventDefault(); saveAndNext(); }
                return;
            }
            if (e.key === 'ArrowLeft') navigate(-1);
            if (e.key === 'ArrowRight') navigate(1);
            if (e.key === ' ') { e.preventDefault(); const m = motifs[currentIndex]; if (m.samples[0]) playAudio(m.samples[0], 0); }
            if (e.key === 'Enter') saveAndNext();
            if (e.key === '1') saveField('category', 'morpheme');
            if (e.key === '2') saveField('category', 'syllable');
            if (e.key === '3') saveField('category', 'noise');
            if (e.key === '4') saveField('category', 'unclear');
        }

        function updateStats() {
            const stats = {morpheme:0, syllable:0, noise:0, unclear:0, pending:0};
            motifs.forEach(m => {
                const cat = mappings[m.id]?.category || 'pending';
                stats[cat]++;
            });
            document.getElementById('statMorpheme').textContent = stats.morpheme;
            document.getElementById('statSyllable').textContent = stats.syllable;
            document.getElementById('statNoise').textContent = stats.noise;
            document.getElementById('statUnclear').textContent = stats.unclear;
            document.getElementById('statPending').textContent = stats.pending;
            document.getElementById('progressFill').style.width = ((100 - stats.pending) / 100 * 100) + '%';
        }

        function exportJson() {
            const data = { language: 'Sateré-Mawé', exported: new Date().toISOString(), mappings: Object.entries(mappings).map(([id, m]) => ({motifId: parseInt(id), ...m})) };
            download(JSON.stringify(data, null, 2), 'soundscript_mapping.json', 'application/json');
        }

        function exportCsv() {
            const rows = ['motif_id,symbol,transcription,category,gloss,notes'];
            motifs.forEach(m => {
                const map = mappings[m.id] || {};
                rows.push([m.id, m.symbol, \`"\${map.transcription||''}"\`, map.category||'', \`"\${map.gloss||''}"\`, \`"\${(map.notes||'').replace(/"/g,'""')}"\`].join(','));
            });
            download(rows.join('\\n'), 'soundscript_mapping.csv', 'text/csv');
        }

        function exportSummary() {
            const stats = {morpheme:[], syllable:[], noise:[], unclear:[]};
            motifs.forEach(m => { const cat = mappings[m.id]?.category; if (cat) stats[cat].push({...m, ...mappings[m.id]}); });
            let txt = \`SOUNDSCRIPT MAPPING SUMMARY\\nLanguage: Sateré-Mawé\\nDate: \${new Date().toLocaleDateString()}\\n\\n\`;
            txt += \`MORPHEMES (\${stats.morpheme.length})\\n\${'='.repeat(40)}\\n\${stats.morpheme.map(m => \`\${m.symbol}: \${m.transcription || '?'} — \${m.gloss || 'no gloss'}\`).join('\\n')||'None'}\\n\\n\`;
            txt += \`SYLLABLES (\${stats.syllable.length})\\n\${'='.repeat(40)}\\n\${stats.syllable.map(m => \`\${m.symbol}: \${m.transcription || '?'}\`).join('\\n')||'None'}\\n\\n\`;
            txt += \`NOISE (\${stats.noise.length})\\n\${'='.repeat(40)}\\n\${stats.noise.map(m => m.symbol).join(', ')||'None'}\\n\`;
            download(txt, 'soundscript_summary.txt', 'text/plain');
        }

        function download(content, filename, type) {
            const a = document.createElement('a');
            a.href = URL.createObjectURL(new Blob([content], {type}));
            a.download = filename;
            a.click();
        }

        init();
    </script>
</body>
</html>'''

# Save to deployment folder
DEPLOY_DIR = f"{PROJECT_ROOT}/soundscript_deploy"
with open(f"{DEPLOY_DIR}/index.html", "w") as f:
    f.write(html_content)

print("✓ Created index.html in soundscript_deploy/")
print("\nNow update the zip file...")

# Update zip
import shutil
zip_path = f"{PROJECT_ROOT}/soundscript_deploy.zip"
if os.path.exists(zip_path):
    os.remove(zip_path)
shutil.make_archive(zip_path.replace('.zip', ''), 'zip', DEPLOY_DIR)

print(f"✓ Updated: {zip_path}")
print("\nDownload soundscript_deploy.zip from Google Drive and deploy to GitHub!")

  return \`<div class="motif-item \${i === currentIndex ? 'active' : ''}" onclick="selectMotif(\${i})">


✓ Created index.html in soundscript_deploy/

Now update the zip file...
✓ Updated: /content/drive/MyDrive/satere_project/soundscript_deploy.zip

Download soundscript_deploy.zip from Google Drive and deploy to GitHub!


In [None]:
# =============================================================================
# VOCODER SETUP: Convert SoundScript motifs back to audio
# =============================================================================

!pip install -q fairseq speechbrain

import os
import json
import numpy as np
import scipy.io.wavfile as wav
from tqdm import tqdm

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"
CONVERTED_DIR = f"{PROJECT_ROOT}/converted_audio"
OUTPUT_DIR = f"{PROJECT_ROOT}/soundscript_full"
VOCODER_DIR = f"{PROJECT_ROOT}/vocoder_output"

os.makedirs(VOCODER_DIR, exist_ok=True)

print("=" * 60)
print("VOCODER: Concatenative Synthesis")
print("=" * 60)
print("Building a simple vocoder from your audio corpus...")

# For each of the 100 acoustic units, extract a representative segment
# This creates a "codebook" of sounds we can concatenate

CODEBOOK_DIR = f"{VOCODER_DIR}/codebook"
os.makedirs(CODEBOOK_DIR, exist_ok=True)

# Load unit files to find examples of each unit
unit_files = [f for f in os.listdir(UNITS_DIR) if f.endswith('.units.txt')]
wav_files = {f.replace('.wav', ''): f for f in os.listdir(CONVERTED_DIR) if f.endswith('.wav')}

print(f"Building codebook from {len(unit_files)} files...")

# For each unit (0-99), collect multiple examples
unit_examples = {i: [] for i in range(100)}

for unit_file in tqdm(unit_files[:100], desc="Scanning for unit examples"):
    base_name = unit_file.replace('.units.txt', '')
    if base_name not in wav_files:
        continue

    # Read units
    with open(os.path.join(UNITS_DIR, unit_file), 'r') as f:
        units = [int(u) for u in f.read().strip().split()]

    # Find positions of each unit
    for pos, unit in enumerate(units):
        if len(unit_examples[unit]) < 10:  # Keep up to 10 examples per unit
            unit_examples[unit].append({
                'file': base_name,
                'position': pos,
                'context_before': units[max(0,pos-2):pos],
                'context_after': units[pos+1:pos+3]
            })

# Extract audio for each unit's best example
print("\nExtracting codebook audio...")

codebook = {}
for unit_id in tqdm(range(100), desc="Building codebook"):
    examples = unit_examples[unit_id]
    if not examples:
        continue

    # Use the first example
    ex = examples[0]
    wav_path = os.path.join(CONVERTED_DIR, wav_files[ex['file']])

    try:
        sr, audio = wav.read(wav_path)

        # Each unit is ~20ms, extract with small buffer
        start_ms = ex['position'] * 20
        end_ms = start_ms + 20

        start_sample = int((start_ms / 1000) * sr)
        end_sample = int((end_ms / 1000) * sr)

        # Add tiny buffer for smoother concatenation
        start_sample = max(0, start_sample - int(0.005 * sr))
        end_sample = min(len(audio), end_sample + int(0.005 * sr))

        segment = audio[start_sample:end_sample]

        if len(segment) > 0:
            codebook_path = f"{CODEBOOK_DIR}/unit_{unit_id:02d}.wav"
            wav.write(codebook_path, sr, segment)
            codebook[unit_id] = {
                'path': codebook_path,
                'duration_ms': len(segment) / sr * 1000
            }
    except Exception as e:
        continue

print(f"\n✓ Codebook built: {len(codebook)} units")

# Save codebook metadata
with open(f"{VOCODER_DIR}/codebook_meta.json", "w") as f:
    json.dump(codebook, f, indent=2)

print(f"  Saved to: {VOCODER_DIR}/codebook_meta.json")

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/9.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━[0m [32m7.9/9.6 MB[0m [31m237.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m9.6/9.6 MB[0m [31m240.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.6/9.6 MB[0m [31m137.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Requested omegaconf<2.1 from https://files.pythonhosted.org/packages/d0/eb/9d63ce09dd8aa85767c65668d5414958ea29648a0eec80a4a7d311ec2684/omegaconf-2.0.6-py3-none-any.whl (from fairseq) has invalid metadata: .* suffix can only be used with `==` or `!=` operators
    PyY

Scanning for unit examples: 100%|██████████| 100/100 [00:01<00:00, 54.11it/s]



Extracting codebook audio...


Building codebook: 100%|██████████| 100/100 [00:01<00:00, 53.71it/s]


✓ Codebook built: 100 units
  Saved to: /content/drive/MyDrive/satere_project/vocoder_output/codebook_meta.json





In [None]:
# =============================================================================
# VOCODER: Synthesize audio from unit sequences
# =============================================================================

import numpy as np
import scipy.io.wavfile as wav
from scipy import signal
import json
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
VOCODER_DIR = f"{PROJECT_ROOT}/vocoder_output"
CODEBOOK_DIR = f"{VOCODER_DIR}/codebook"

# Load codebook
with open(f"{VOCODER_DIR}/codebook_meta.json", "r") as f:
    codebook_meta = json.load(f)

# Load all codebook audio into memory
print("Loading codebook into memory...")
codebook_audio = {}
sample_rate = None

for unit_id in range(100):
    unit_path = f"{CODEBOOK_DIR}/unit_{unit_id:02d}.wav"
    if os.path.exists(unit_path):
        sr, audio = wav.read(unit_path)
        sample_rate = sr
        codebook_audio[unit_id] = audio.astype(np.float32)

print(f"✓ Loaded {len(codebook_audio)} units at {sample_rate} Hz")

def synthesize_from_units(unit_sequence, crossfade_ms=5):
    """
    Convert a sequence of unit IDs to audio using concatenative synthesis.

    Args:
        unit_sequence: list of integers (0-99)
        crossfade_ms: crossfade duration for smoothing

    Returns:
        numpy array of audio samples
    """
    if not unit_sequence:
        return np.array([])

    crossfade_samples = int(crossfade_ms / 1000 * sample_rate)

    # Start with first unit
    first_unit = unit_sequence[0]
    if first_unit not in codebook_audio:
        output = np.zeros(int(0.02 * sample_rate))  # 20ms silence
    else:
        output = codebook_audio[first_unit].copy()

    # Concatenate remaining units with crossfade
    for unit_id in unit_sequence[1:]:
        if unit_id not in codebook_audio:
            # Insert silence for missing units
            segment = np.zeros(int(0.02 * sample_rate))
        else:
            segment = codebook_audio[unit_id].copy()

        if len(segment) == 0:
            continue

        # Apply crossfade if possible
        if crossfade_samples > 0 and len(output) > crossfade_samples and len(segment) > crossfade_samples:
            # Create fade out for end of output
            fade_out = np.linspace(1, 0, crossfade_samples)
            # Create fade in for start of segment
            fade_in = np.linspace(0, 1, crossfade_samples)

            # Apply fades
            output[-crossfade_samples:] *= fade_out
            segment[:crossfade_samples] *= fade_in

            # Overlap-add
            output[-crossfade_samples:] += segment[:crossfade_samples]
            output = np.concatenate([output, segment[crossfade_samples:]])
        else:
            output = np.concatenate([output, segment])

    # Normalize
    max_val = np.max(np.abs(output))
    if max_val > 0:
        output = output / max_val * 0.9

    return output.astype(np.int16 if output.dtype != np.int16 else output.dtype)

# Test the synthesizer with a sample from your corpus
print("\n" + "=" * 60)
print("TESTING SYNTHESIZER")
print("=" * 60)

# Load a test unit sequence
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"
test_files = [f for f in os.listdir(UNITS_DIR) if f.endswith('.units.txt')][:1]

for test_file in test_files:
    with open(os.path.join(UNITS_DIR, test_file), 'r') as f:
        units = [int(u) for u in f.read().strip().split()[:500]]  # First 500 units (~10 seconds)

    print(f"\nSynthesizing from: {test_file}")
    print(f"  Units: {len(units)} (~{len(units)*20/1000:.1f} seconds)")

    # Synthesize
    audio = synthesize_from_units(units)

    # Save
    output_path = f"{VOCODER_DIR}/test_synthesis.wav"

    # Convert to int16 if needed
    if audio.dtype == np.float32:
        audio = (audio * 32767).astype(np.int16)

    wav.write(output_path, sample_rate, audio)

    print(f"  ✓ Saved: {output_path}")
    print(f"  Duration: {len(audio)/sample_rate:.1f} seconds")

print("\n" + "=" * 60)
print("✓ VOCODER READY")
print("=" * 60)
print("\nYou can now synthesize audio from any unit sequence!")
print("Download 'test_synthesis.wav' from vocoder_output/ to hear the result.")

Loading codebook into memory...
✓ Loaded 100 units at 16000 Hz

TESTING SYNTHESIZER

Synthesizing from: MAVWYIN1DA_B01_MAT_001.units.txt
  Units: 500 (~10.0 seconds)
  ✓ Saved: /content/drive/MyDrive/satere_project/vocoder_output/test_synthesis.wav
  Duration: 12.2 seconds

✓ VOCODER READY

You can now synthesize audio from any unit sequence!
Download 'test_synthesis.wav' from vocoder_output/ to hear the result.


In [None]:
# Debug: Check what's in the codebook
import os
import scipy.io.wavfile as wav
import numpy as np

VOCODER_DIR = f"{PROJECT_ROOT}/vocoder_output"
CODEBOOK_DIR = f"{VOCODER_DIR}/codebook"

print("Checking codebook files...")
print("=" * 60)

# Check a few codebook files
for unit_id in [0, 29, 82, 55]:  # Including some frequent units
    unit_path = f"{CODEBOOK_DIR}/unit_{unit_id:02d}.wav"
    if os.path.exists(unit_path):
        sr, audio = wav.read(unit_path)
        duration_ms = len(audio) / sr * 1000
        max_amplitude = np.max(np.abs(audio))
        print(f"Unit {unit_id:02d}: {duration_ms:.1f}ms, max amplitude: {max_amplitude}, samples: {len(audio)}")
    else:
        print(f"Unit {unit_id:02d}: FILE MISSING")

# Check the test synthesis file
test_path = f"{VOCODER_DIR}/test_synthesis.wav"
if os.path.exists(test_path):
    sr, audio = wav.read(test_path)
    print(f"\nTest synthesis file:")
    print(f"  Duration: {len(audio)/sr:.1f} seconds")
    print(f"  Max amplitude: {np.max(np.abs(audio))}")
    print(f"  Min amplitude: {np.min(audio)}")
    print(f"  Data type: {audio.dtype}")
    print(f"  Non-zero samples: {np.count_nonzero(audio)}")

# Check what the original audio looks like
CONVERTED_DIR = f"{PROJECT_ROOT}/converted_audio"
sample_wav = [f for f in os.listdir(CONVERTED_DIR) if f.endswith('.wav')][0]
sr, orig_audio = wav.read(os.path.join(CONVERTED_DIR, sample_wav))
print(f"\nOriginal audio sample ({sample_wav}):")
print(f"  Max amplitude: {np.max(np.abs(orig_audio))}")
print(f"  Data type: {orig_audio.dtype}")

Checking codebook files...
Unit 00: 30.0ms, max amplitude: 15010, samples: 480
Unit 29: 30.0ms, max amplitude: 2, samples: 480
Unit 82: 30.0ms, max amplitude: 8, samples: 480
Unit 55: 30.0ms, max amplitude: 10759, samples: 480

Test synthesis file:
  Duration: 12.2 seconds
  Max amplitude: 0
  Min amplitude: 0
  Data type: int16
  Non-zero samples: 0

Original audio sample (MAVWYIN1DA_B01_MAT_001.wav):
  Max amplitude: 23706
  Data type: int16


In [None]:
# =============================================================================
# FIXED VOCODER: Synthesize audio from unit sequences
# =============================================================================

import numpy as np
import scipy.io.wavfile as wav
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
VOCODER_DIR = f"{PROJECT_ROOT}/vocoder_output"
CODEBOOK_DIR = f"{VOCODER_DIR}/codebook"

# Reload codebook as int16 (original format)
print("Reloading codebook (keeping original int16 format)...")
codebook_audio = {}
sample_rate = None

for unit_id in range(100):
    unit_path = f"{CODEBOOK_DIR}/unit_{unit_id:02d}.wav"
    if os.path.exists(unit_path):
        sr, audio = wav.read(unit_path)
        sample_rate = sr
        codebook_audio[unit_id] = audio  # Keep as int16

print(f"✓ Loaded {len(codebook_audio)} units")

# Check amplitude distribution
amplitudes = [(uid, np.max(np.abs(audio))) for uid, audio in codebook_audio.items()]
amplitudes.sort(key=lambda x: x[1], reverse=True)
print("\nTop 10 loudest units:")
for uid, amp in amplitudes[:10]:
    print(f"  Unit {uid:02d}: amplitude {amp}")

print("\nTop 10 quietest units:")
for uid, amp in amplitudes[-10:]:
    print(f"  Unit {uid:02d}: amplitude {amp}")

def synthesize_fixed(unit_sequence):
    """Simple concatenation without fancy processing"""
    if not unit_sequence:
        return np.array([], dtype=np.int16)

    segments = []
    for unit_id in unit_sequence:
        if unit_id in codebook_audio:
            segments.append(codebook_audio[unit_id])
        else:
            # 20ms of silence for missing units
            segments.append(np.zeros(int(0.02 * sample_rate), dtype=np.int16))

    # Simple concatenation
    output = np.concatenate(segments)
    return output

# Test with the same file
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"
test_file = "MAVWYIN1DA_B01_MAT_001.units.txt"

with open(os.path.join(UNITS_DIR, test_file), 'r') as f:
    units = [int(u) for u in f.read().strip().split()[:500]]

print(f"\nSynthesizing {len(units)} units...")
print(f"First 20 units: {units[:20]}")

# Count how many units have good amplitude
good_units = sum(1 for u in units if u in codebook_audio and np.max(np.abs(codebook_audio[u])) > 100)
print(f"Units with amplitude > 100: {good_units}/{len(units)}")

# Synthesize
audio = synthesize_fixed(units)

print(f"\nOutput audio:")
print(f"  Samples: {len(audio)}")
print(f"  Max amplitude: {np.max(np.abs(audio))}")
print(f"  Duration: {len(audio)/sample_rate:.1f}s")

# Save
output_path = f"{VOCODER_DIR}/test_synthesis_fixed.wav"
wav.write(output_path, sample_rate, audio)
print(f"\n✓ Saved: {output_path}")

Reloading codebook (keeping original int16 format)...
✓ Loaded 100 units

Top 10 loudest units:
  Unit 42: amplitude 19580
  Unit 84: amplitude 19508
  Unit 06: amplitude 18949
  Unit 27: amplitude 18920
  Unit 62: amplitude 18329
  Unit 95: amplitude 18048
  Unit 31: amplitude 17993
  Unit 53: amplitude 17946
  Unit 90: amplitude 17907
  Unit 20: amplitude 17193

Top 10 quietest units:
  Unit 82: amplitude 8
  Unit 04: amplitude 7
  Unit 54: amplitude 3
  Unit 13: amplitude 2
  Unit 28: amplitude 2
  Unit 29: amplitude 2
  Unit 21: amplitude 1
  Unit 26: amplitude 1
  Unit 49: amplitude 1
  Unit 86: amplitude 1

Synthesizing 500 units...
First 20 units: [26, 26, 26, 26, 26, 49, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26]
Units with amplitude > 100: 291/500

Output audio:
  Samples: 235437
  Max amplitude: 19580
  Duration: 14.7s

✓ Saved: /content/drive/MyDrive/satere_project/vocoder_output/test_synthesis_fixed.wav


In [None]:
# =============================================================================
# REBUILD CODEBOOK: Find examples with actual audio content
# =============================================================================

import numpy as np
import scipy.io.wavfile as wav
import os
from tqdm import tqdm

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"
CONVERTED_DIR = f"{PROJECT_ROOT}/converted_audio"
VOCODER_DIR = f"{PROJECT_ROOT}/vocoder_output"
CODEBOOK_DIR = f"{VOCODER_DIR}/codebook_v2"

os.makedirs(CODEBOOK_DIR, exist_ok=True)

print("=" * 60)
print("REBUILDING CODEBOOK WITH HIGH-AMPLITUDE EXAMPLES")
print("=" * 60)

# Get file lists
unit_files = [f for f in os.listdir(UNITS_DIR) if f.endswith('.units.txt')]
wav_files = {f.replace('.wav', ''): f for f in os.listdir(CONVERTED_DIR) if f.endswith('.wav')}

# For each unit, find the LOUDEST example (not just the first one)
print("Scanning for best examples of each unit...")

best_examples = {i: {'amplitude': 0, 'file': None, 'position': None} for i in range(100)}

for unit_file in tqdm(unit_files[:80], desc="Scanning files"):
    base_name = unit_file.replace('.units.txt', '')
    if base_name not in wav_files:
        continue

    # Load audio
    wav_path = os.path.join(CONVERTED_DIR, wav_files[base_name])
    try:
        sr, audio = wav.read(wav_path)
    except:
        continue

    # Load units
    with open(os.path.join(UNITS_DIR, unit_file), 'r') as f:
        units = [int(u) for u in f.read().strip().split()]

    # Check each unit position
    for pos, unit_id in enumerate(units):
        # Extract this unit's audio
        start_sample = int(pos * 0.02 * sr)
        end_sample = int((pos + 1) * 0.02 * sr)

        if end_sample > len(audio):
            continue

        segment = audio[start_sample:end_sample]
        amplitude = np.max(np.abs(segment))

        # Keep if this is the loudest example we've found
        if amplitude > best_examples[unit_id]['amplitude']:
            best_examples[unit_id] = {
                'amplitude': amplitude,
                'file': base_name,
                'position': pos
            }

# Now extract the best examples
print("\nExtracting best examples...")

codebook_audio = {}
sample_rate = 16000

for unit_id in tqdm(range(100), desc="Building codebook"):
    ex = best_examples[unit_id]
    if ex['file'] is None:
        continue

    wav_path = os.path.join(CONVERTED_DIR, wav_files[ex['file']])
    sr, audio = wav.read(wav_path)
    sample_rate = sr

    # Extract with small buffer for smoother transitions
    start_sample = int(ex['position'] * 0.02 * sr)
    end_sample = int((ex['position'] + 1) * 0.02 * sr)

    # Add 5ms buffer on each side
    buffer = int(0.005 * sr)
    start_sample = max(0, start_sample - buffer)
    end_sample = min(len(audio), end_sample + buffer)

    segment = audio[start_sample:end_sample]

    if len(segment) > 0:
        codebook_path = f"{CODEBOOK_DIR}/unit_{unit_id:02d}.wav"
        wav.write(codebook_path, sr, segment)
        codebook_audio[unit_id] = segment

# Report results
print("\n" + "=" * 60)
print("CODEBOOK V2 STATISTICS")
print("=" * 60)

amplitudes = [(uid, best_examples[uid]['amplitude']) for uid in range(100)]
amplitudes.sort(key=lambda x: x[1], reverse=True)

print("\nTop 10 units by amplitude:")
for uid, amp in amplitudes[:10]:
    print(f"  Unit {uid:02d}: {amp}")

print("\nBottom 10 units by amplitude:")
for uid, amp in amplitudes[-10:]:
    print(f"  Unit {uid:02d}: {amp}")

low_amplitude = sum(1 for _, amp in amplitudes if amp < 500)
print(f"\nUnits with low amplitude (<500): {low_amplitude}")

# Save metadata
import json
meta = {unit_id: {'amplitude': int(best_examples[unit_id]['amplitude'])} for unit_id in range(100)}
with open(f"{VOCODER_DIR}/codebook_v2_meta.json", "w") as f:
    json.dump(meta, f, indent=2)

print(f"\n✓ Codebook v2 saved to: {CODEBOOK_DIR}")

REBUILDING CODEBOOK WITH HIGH-AMPLITUDE EXAMPLES
Scanning for best examples of each unit...


Scanning files: 100%|██████████| 80/80 [00:51<00:00,  1.57it/s]



Extracting best examples...


Building codebook: 100%|██████████| 100/100 [00:03<00:00, 31.96it/s]


CODEBOOK V2 STATISTICS

Top 10 units by amplitude:
  Unit 46: 25361
  Unit 20: 25218
  Unit 10: 25194
  Unit 44: 25156
  Unit 90: 24929
  Unit 28: 24815
  Unit 00: 24801
  Unit 81: 24728
  Unit 76: 24624
  Unit 55: 24613

Bottom 10 units by amplitude:
  Unit 32: 1168
  Unit 99: 1089
  Unit 02: 780
  Unit 39: 391
  Unit 56: 265
  Unit 04: 264
  Unit 82: 218
  Unit 26: 214
  Unit 29: 104
  Unit 78: 91

Units with low amplitude (<500): 7

✓ Codebook v2 saved to: /content/drive/MyDrive/satere_project/vocoder_output/codebook_v2





In [None]:
# =============================================================================
# SYNTHESIZE WITH CODEBOOK V2
# =============================================================================

import numpy as np
import scipy.io.wavfile as wav
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
VOCODER_DIR = f"{PROJECT_ROOT}/vocoder_output"
CODEBOOK_DIR = f"{VOCODER_DIR}/codebook_v2"
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"

# Load codebook v2
print("Loading codebook v2...")
codebook_audio = {}
sample_rate = 16000

for unit_id in range(100):
    unit_path = f"{CODEBOOK_DIR}/unit_{unit_id:02d}.wav"
    if os.path.exists(unit_path):
        sr, audio = wav.read(unit_path)
        sample_rate = sr
        codebook_audio[unit_id] = audio

print(f"✓ Loaded {len(codebook_audio)} units")

def synthesize(unit_sequence):
    """Concatenate units to produce audio"""
    segments = []
    for unit_id in unit_sequence:
        if unit_id in codebook_audio:
            segments.append(codebook_audio[unit_id])
        else:
            segments.append(np.zeros(int(0.02 * sample_rate), dtype=np.int16))
    return np.concatenate(segments) if segments else np.array([], dtype=np.int16)

# Test 1: Synthesize from Matthew 1
test_file = "MAVWYIN1DA_B01_MAT_001.units.txt"
with open(os.path.join(UNITS_DIR, test_file), 'r') as f:
    units = [int(u) for u in f.read().strip().split()]

# Take first 1000 units (~20 seconds)
units_subset = units[:1000]

print(f"\nSynthesizing Matthew 1 ({len(units_subset)} units, ~{len(units_subset)*0.02:.0f}s)...")
audio = synthesize(units_subset)

output_path = f"{VOCODER_DIR}/matthew1_synthesis.wav"
wav.write(output_path, sample_rate, audio)
print(f"✓ Saved: {output_path}")
print(f"  Duration: {len(audio)/sample_rate:.1f}s")
print(f"  Max amplitude: {np.max(np.abs(audio))}")

# Test 2: Also synthesize from a different book for comparison
test_file2 = [f for f in os.listdir(UNITS_DIR) if 'JHN' in f or 'MRK' in f or 'LUK' in f]
if test_file2:
    with open(os.path.join(UNITS_DIR, test_file2[0]), 'r') as f:
        units2 = [int(u) for u in f.read().strip().split()[:1000]]

    print(f"\nSynthesizing {test_file2[0][:20]}... ({len(units2)} units)")
    audio2 = synthesize(units2)

    output_path2 = f"{VOCODER_DIR}/other_book_synthesis.wav"
    wav.write(output_path2, sample_rate, audio2)
    print(f"✓ Saved: {output_path2}")

print("\n" + "=" * 60)
print("Download these files from vocoder_output/ and listen:")
print("  - matthew1_synthesis.wav")
print("  - other_book_synthesis.wav (if created)")
print("=" * 60)

Loading codebook v2...
✓ Loaded 100 units

Synthesizing Matthew 1 (1000 units, ~20s)...
✓ Saved: /content/drive/MyDrive/satere_project/vocoder_output/matthew1_synthesis.wav
  Duration: 30.0s
  Max amplitude: 25361

Synthesizing MAVWYIN1DA_B02_MRK_0... (1000 units)
✓ Saved: /content/drive/MyDrive/satere_project/vocoder_output/other_book_synthesis.wav

Download these files from vocoder_output/ and listen:
  - matthew1_synthesis.wav
  - other_book_synthesis.wav (if created)


In [None]:
# =============================================================================
# IMPROVED SYNTHESIS WITH CROSSFADE AND FILTERING
# =============================================================================

import numpy as np
import scipy.io.wavfile as wav
from scipy import signal
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
VOCODER_DIR = f"{PROJECT_ROOT}/vocoder_output"
CODEBOOK_DIR = f"{VOCODER_DIR}/codebook_v2"
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"

# Load codebook v2
print("Loading codebook v2...")
codebook_audio = {}
sample_rate = 16000

for unit_id in range(100):
    unit_path = f"{CODEBOOK_DIR}/unit_{unit_id:02d}.wav"
    if os.path.exists(unit_path):
        sr, audio = wav.read(unit_path)
        sample_rate = sr
        # Convert to float for processing
        codebook_audio[unit_id] = audio.astype(np.float32) / 32768.0

print(f"✓ Loaded {len(codebook_audio)} units")

def synthesize_smooth(unit_sequence, crossfade_ms=10):
    """
    Concatenate units with crossfade to reduce metallic artifacts.
    """
    if not unit_sequence:
        return np.array([], dtype=np.int16)

    crossfade_samples = int(crossfade_ms / 1000 * sample_rate)

    # Get first segment
    if unit_sequence[0] in codebook_audio:
        output = codebook_audio[unit_sequence[0]].copy()
    else:
        output = np.zeros(int(0.02 * sample_rate), dtype=np.float32)

    # Concatenate with crossfade
    for unit_id in unit_sequence[1:]:
        if unit_id in codebook_audio:
            segment = codebook_audio[unit_id].copy()
        else:
            segment = np.zeros(int(0.02 * sample_rate), dtype=np.float32)

        if len(segment) < crossfade_samples or len(output) < crossfade_samples:
            output = np.concatenate([output, segment])
            continue

        # Create crossfade
        fade_out = np.linspace(1.0, 0.0, crossfade_samples)
        fade_in = np.linspace(0.0, 1.0, crossfade_samples)

        # Apply fades
        output[-crossfade_samples:] *= fade_out
        segment[:crossfade_samples] *= fade_in

        # Overlap-add
        output[-crossfade_samples:] += segment[:crossfade_samples]
        output = np.concatenate([output, segment[crossfade_samples:]])

    # Apply gentle low-pass filter to reduce harshness
    nyquist = sample_rate / 2
    cutoff = 7000  # Hz - preserve speech but reduce high-frequency artifacts
    b, a = signal.butter(4, cutoff / nyquist, btype='low')
    output = signal.filtfilt(b, a, output)

    # Normalize
    max_val = np.max(np.abs(output))
    if max_val > 0:
        output = output / max_val * 0.85

    # Convert back to int16
    return (output * 32767).astype(np.int16)

# Synthesize Matthew 1 with smoothing
test_file = "MAVWYIN1DA_B01_MAT_001.units.txt"
with open(os.path.join(UNITS_DIR, test_file), 'r') as f:
    units = [int(u) for u in f.read().strip().split()[:1000]]

print(f"\nSynthesizing with crossfade smoothing...")
audio_smooth = synthesize_smooth(units, crossfade_ms=10)

output_path = f"{VOCODER_DIR}/matthew1_smooth.wav"
wav.write(output_path, sample_rate, audio_smooth)
print(f"✓ Saved: {output_path}")
print(f"  Duration: {len(audio_smooth)/sample_rate:.1f}s")

# Also try with longer crossfade
print(f"\nSynthesizing with longer crossfade (15ms)...")
audio_smoother = synthesize_smooth(units, crossfade_ms=15)

output_path2 = f"{VOCODER_DIR}/matthew1_smoother.wav"
wav.write(output_path2, sample_rate, audio_smoother)
print(f"✓ Saved: {output_path2}")

print("\n" + "=" * 60)
print("Download and compare:")
print("  - matthew1_smooth.wav (10ms crossfade)")
print("  - matthew1_smoother.wav (15ms crossfade)")
print("These should have less metallic noise.")
print("=" * 60)

Loading codebook v2...
✓ Loaded 100 units

Synthesizing with crossfade smoothing...
✓ Saved: /content/drive/MyDrive/satere_project/vocoder_output/matthew1_smooth.wav
  Duration: 20.0s

Synthesizing with longer crossfade (15ms)...
✓ Saved: /content/drive/MyDrive/satere_project/vocoder_output/matthew1_smoother.wav

Download and compare:
  - matthew1_smooth.wav (10ms crossfade)
  - matthew1_smoother.wav (15ms crossfade)
These should have less metallic noise.


In [None]:
# Convert to MP3 for WhatsApp
import subprocess
import os

VOCODER_DIR = f"{PROJECT_ROOT}/vocoder_output"

# Convert the smoothest version to MP3
input_file = f"{VOCODER_DIR}/matthew1_smoother.wav"
output_file = f"{VOCODER_DIR}/matthew1_whatsapp.mp3"

subprocess.run([
    "ffmpeg", "-y", "-i", input_file,
    "-codec:a", "libmp3lame", "-qscale:a", "2",
    output_file
], capture_output=True)

print(f"✓ Created: {output_file}")

# Also make a shorter version (first 10 seconds) in case they want a quick sample
input_file2 = f"{VOCODER_DIR}/matthew1_smoother.wav"
output_file2 = f"{VOCODER_DIR}/matthew1_short_whatsapp.mp3"

subprocess.run([
    "ffmpeg", "-y", "-i", input_file2,
    "-t", "10",  # First 10 seconds only
    "-codec:a", "libmp3lame", "-qscale:a", "2",
    output_file2
], capture_output=True)

print(f"✓ Created: {output_file2} (10 seconds)")

print("\nDownload from vocoder_output/:")
print("  - matthew1_whatsapp.mp3 (full ~20 seconds)")
print("  - matthew1_short_whatsapp.mp3 (10 seconds)")

✓ Created: /content/drive/MyDrive/satere_project/vocoder_output/matthew1_whatsapp.mp3
✓ Created: /content/drive/MyDrive/satere_project/vocoder_output/matthew1_short_whatsapp.mp3 (10 seconds)

Download from vocoder_output/:
  - matthew1_whatsapp.mp3 (full ~20 seconds)
  - matthew1_short_whatsapp.mp3 (10 seconds)


In [None]:
# =============================================================================
# NEURAL VOCODER TRAINING: HiFi-GAN for Sateré-Mawé
# =============================================================================
# This will take 4-8 hours to train. Start it and let it run overnight.

# Step 1: Install dependencies
!pip install -q torch torchaudio tensorboard

# Clone HiFi-GAN repository
import os
if not os.path.exists('/content/hifi-gan'):
    !git clone https://github.com/jik876/hifi-gan.git /content/hifi-gan

print("✓ Dependencies installed")

Cloning into '/content/hifi-gan'...
remote: Enumerating objects: 48, done.[K
remote: Total 48 (delta 0), reused 0 (delta 0), pack-reused 48 (from 1)[K
Receiving objects: 100% (48/48), 620.94 KiB | 22.18 MiB/s, done.
Resolving deltas: 100% (20/20), done.
✓ Dependencies installed


In [None]:
# =============================================================================
# STEP 2: PREPARE TRAINING DATA FOR HiFi-GAN
# =============================================================================

import os
import json
import numpy as np
import scipy.io.wavfile as wav
from tqdm import tqdm
import random

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
UNITS_DIR = f"{PROJECT_ROOT}/satere_units"
CONVERTED_DIR = f"{PROJECT_ROOT}/converted_audio"
HIFIGAN_DATA = f"{PROJECT_ROOT}/hifigan_training"

os.makedirs(f"{HIFIGAN_DATA}/wavs", exist_ok=True)
os.makedirs(f"{HIFIGAN_DATA}/units", exist_ok=True)

print("=" * 60)
print("PREPARING HiFi-GAN TRAINING DATA")
print("=" * 60)

# Get all files
unit_files = sorted([f for f in os.listdir(UNITS_DIR) if f.endswith('.units.txt')])
print(f"Found {len(unit_files)} unit files")

# We'll create segments of ~3-10 seconds for training
# This gives the model enough context to learn prosody

SEGMENT_LENGTH_SEC = 5  # 5 second segments
SAMPLE_RATE = 16000
SEGMENT_SAMPLES = SEGMENT_LENGTH_SEC * SAMPLE_RATE
UNITS_PER_SECOND = 50  # 20ms per unit = 50 units/second
SEGMENT_UNITS = SEGMENT_LENGTH_SEC * UNITS_PER_SECOND

training_pairs = []
segment_id = 0

print(f"\nCreating {SEGMENT_LENGTH_SEC}-second segments...")

for unit_file in tqdm(unit_files[:150], desc="Processing files"):  # Use 150 files for training
    base_name = unit_file.replace('.units.txt', '')
    wav_path = os.path.join(CONVERTED_DIR, base_name + '.wav')

    if not os.path.exists(wav_path):
        continue

    # Load audio
    try:
        sr, audio = wav.read(wav_path)
    except:
        continue

    # Load units
    with open(os.path.join(UNITS_DIR, unit_file), 'r') as f:
        units = [int(u) for u in f.read().strip().split()]

    # Create segments
    num_segments = len(audio) // SEGMENT_SAMPLES

    for seg_idx in range(num_segments):
        # Audio segment
        start_sample = seg_idx * SEGMENT_SAMPLES
        end_sample = start_sample + SEGMENT_SAMPLES
        audio_segment = audio[start_sample:end_sample]

        # Corresponding units
        start_unit = seg_idx * SEGMENT_UNITS
        end_unit = start_unit + SEGMENT_UNITS

        if end_unit > len(units):
            continue

        unit_segment = units[start_unit:end_unit]

        # Skip silent segments
        if np.max(np.abs(audio_segment)) < 500:
            continue

        # Save audio segment
        seg_name = f"seg_{segment_id:05d}"
        wav_out = f"{HIFIGAN_DATA}/wavs/{seg_name}.wav"
        wav.write(wav_out, sr, audio_segment)

        # Save unit sequence
        unit_out = f"{HIFIGAN_DATA}/units/{seg_name}.txt"
        with open(unit_out, 'w') as f:
            f.write(' '.join(map(str, unit_segment)))

        training_pairs.append({
            'id': seg_name,
            'source_file': base_name,
            'duration_sec': SEGMENT_LENGTH_SEC
        })

        segment_id += 1

print(f"\n✓ Created {len(training_pairs)} training segments")
print(f"  Total duration: {len(training_pairs) * SEGMENT_LENGTH_SEC / 60:.1f} minutes")

# Split into train/validation
random.shuffle(training_pairs)
val_size = min(100, len(training_pairs) // 10)
train_pairs = training_pairs[val_size:]
val_pairs = training_pairs[:val_size]

# Save file lists
with open(f"{HIFIGAN_DATA}/train_files.txt", 'w') as f:
    for p in train_pairs:
        f.write(p['id'] + '\n')

with open(f"{HIFIGAN_DATA}/val_files.txt", 'w') as f:
    for p in val_pairs:
        f.write(p['id'] + '\n')

# Save metadata
metadata = {
    'num_train': len(train_pairs),
    'num_val': len(val_pairs),
    'segment_length_sec': SEGMENT_LENGTH_SEC,
    'sample_rate': SAMPLE_RATE,
    'num_units': 100
}

with open(f"{HIFIGAN_DATA}/metadata.json", 'w') as f:
    json.dump(metadata, f, indent=2)

print(f"\n✓ Training set: {len(train_pairs)} segments")
print(f"✓ Validation set: {len(val_pairs)} segments")
print(f"✓ Data saved to: {HIFIGAN_DATA}")

PREPARING HiFi-GAN TRAINING DATA
Found 229 unit files

Creating 5-second segments...


Processing files: 100%|██████████| 150/150 [05:09<00:00,  2.06s/it]


✓ Created 22934 training segments
  Total duration: 1911.2 minutes

✓ Training set: 22834 segments
✓ Validation set: 100 segments
✓ Data saved to: /content/drive/MyDrive/satere_project/hifigan_training





In [None]:
# =============================================================================
# STEP 3: CREATE UNIT-BASED HiFi-GAN MODEL
# =============================================================================

import os
import sys
sys.path.append('/content/hifi-gan')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import scipy.io.wavfile as wav
from tqdm import tqdm
import json

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
HIFIGAN_DATA = f"{PROJECT_ROOT}/hifigan_training"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model"
os.makedirs(MODEL_DIR, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# =============================================================================
# Dataset
# =============================================================================

class UnitAudioDataset(Dataset):
    def __init__(self, data_dir, file_list, segment_length=16000*5):
        self.data_dir = data_dir
        self.segment_length = segment_length

        with open(file_list, 'r') as f:
            self.files = [line.strip() for line in f.readlines()]

        print(f"  Loaded {len(self.files)} files")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]

        # Load audio
        wav_path = f"{self.data_dir}/wavs/{name}.wav"
        sr, audio = wav.read(wav_path)
        audio = audio.astype(np.float32) / 32768.0

        # Load units
        unit_path = f"{self.data_dir}/units/{name}.txt"
        with open(unit_path, 'r') as f:
            units = [int(u) for u in f.read().strip().split()]

        return torch.FloatTensor(audio), torch.LongTensor(units)

# =============================================================================
# Simple but effective Unit-to-Audio Generator
# =============================================================================

class UnitVocoder(nn.Module):
    def __init__(self, num_units=100, unit_embed_dim=256, upsample_rates=[5, 5, 4, 4, 2]):
        super().__init__()

        # Unit embedding
        self.unit_embed = nn.Embedding(num_units, unit_embed_dim)

        # Initial convolution
        self.pre_conv = nn.Conv1d(unit_embed_dim, 512, kernel_size=7, padding=3)

        # Upsampling layers (20ms units -> 16kHz audio)
        # 50 units/sec -> 16000 samples/sec = 320x upsampling
        # 5 * 5 * 4 * 4 * 2 = 800, we'll trim
        self.upsample_rates = upsample_rates
        self.upsamples = nn.ModuleList()

        channels = 512
        for i, rate in enumerate(upsample_rates):
            out_channels = channels // 2 if i < len(upsample_rates) - 1 else 64
            self.upsamples.append(
                nn.Sequential(
                    nn.ConvTranspose1d(channels, out_channels,
                                       kernel_size=rate*2, stride=rate, padding=rate//2),
                    nn.LeakyReLU(0.1),
                    nn.Conv1d(out_channels, out_channels, kernel_size=7, padding=3),
                    nn.LeakyReLU(0.1),
                )
            )
            channels = out_channels

        # Output layer
        self.post_conv = nn.Conv1d(64, 1, kernel_size=7, padding=3)
        self.tanh = nn.Tanh()

    def forward(self, units):
        # units: (batch, seq_len)
        x = self.unit_embed(units)  # (batch, seq_len, embed_dim)
        x = x.transpose(1, 2)  # (batch, embed_dim, seq_len)

        x = self.pre_conv(x)
        x = F.leaky_relu(x, 0.1)

        for upsample in self.upsamples:
            x = upsample(x)

        x = self.post_conv(x)
        x = self.tanh(x)

        return x.squeeze(1)  # (batch, audio_len)

# =============================================================================
# Multi-Scale Discriminator
# =============================================================================

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.discriminators = nn.ModuleList([
            self._make_discriminator(1),
            self._make_discriminator(2),
            self._make_discriminator(4),
        ])
        self.pools = nn.ModuleList([
            nn.Identity(),
            nn.AvgPool1d(2, 2),
            nn.AvgPool1d(4, 4),
        ])

    def _make_discriminator(self, scale):
        return nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=15, padding=7),
            nn.LeakyReLU(0.1),
            nn.Conv1d(64, 128, kernel_size=41, stride=4, padding=20, groups=4),
            nn.LeakyReLU(0.1),
            nn.Conv1d(128, 256, kernel_size=41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(256, 512, kernel_size=41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 512, kernel_size=5, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        outputs = []
        for pool, disc in zip(self.pools, self.discriminators):
            x_pooled = pool(x)
            outputs.append(disc(x_pooled))
        return outputs

# =============================================================================
# Initialize models
# =============================================================================

print("\nInitializing models...")
generator = UnitVocoder(num_units=100).to(device)
discriminator = Discriminator().to(device)

num_params_g = sum(p.numel() for p in generator.parameters())
num_params_d = sum(p.numel() for p in discriminator.parameters())
print(f"  Generator parameters: {num_params_g:,}")
print(f"  Discriminator parameters: {num_params_d:,}")

# =============================================================================
# Create data loaders
# =============================================================================

print("\nLoading datasets...")
train_dataset = UnitAudioDataset(HIFIGAN_DATA, f"{HIFIGAN_DATA}/train_files.txt")
val_dataset = UnitAudioDataset(HIFIGAN_DATA, f"{HIFIGAN_DATA}/val_files.txt")

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

print("\n✓ Models and data ready for training!")

Using device: cuda

Initializing models...
  Generator parameters: 3,311,617
  Discriminator parameters: 5,455,491

Loading datasets...
  Loaded 22834 files
  Loaded 100 files
  Train batches: 1428
  Val batches: 25

✓ Models and data ready for training!


In [None]:
# =============================================================================
# STEP 4: TRAIN THE VOCODER
# =============================================================================

import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import time

# Training settings
NUM_EPOCHS = 50
LEARNING_RATE = 0.0002
SAVE_EVERY = 5  # Save checkpoint every N epochs

# Optimizers
optimizer_g = optim.AdamW(generator.parameters(), lr=LEARNING_RATE, betas=(0.8, 0.99))
optimizer_d = optim.AdamW(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.8, 0.99))

# Learning rate schedulers
scheduler_g = optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)
scheduler_d = optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.99)

# Loss functions
l1_loss = nn.L1Loss()
mse_loss = nn.MSELoss()

# Tensorboard
writer = SummaryWriter(f"{MODEL_DIR}/logs")

print("=" * 60)
print("TRAINING UNIT VOCODER")
print("=" * 60)
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch size: 16")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Device: {device}")
print("=" * 60)

global_step = 0
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()
    generator.train()
    discriminator.train()

    total_loss_g = 0
    total_loss_d = 0
    num_batches = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

    for batch_idx, (audio, units) in enumerate(progress_bar):
        audio = audio.to(device)
        units = units.to(device)

        # Generate audio
        audio_fake = generator(units)

        # Match lengths
        min_len = min(audio.shape[1], audio_fake.shape[1])
        audio = audio[:, :min_len]
        audio_fake = audio_fake[:, :min_len]

        # ===================
        # Train Discriminator
        # ===================
        optimizer_d.zero_grad()

        # Real audio
        real_out = discriminator(audio.unsqueeze(1))
        # Fake audio
        fake_out = discriminator(audio_fake.detach().unsqueeze(1))

        loss_d = 0
        for real, fake in zip(real_out, fake_out):
            loss_d += mse_loss(real, torch.ones_like(real))
            loss_d += mse_loss(fake, torch.zeros_like(fake))

        loss_d.backward()
        optimizer_d.step()

        # ===================
        # Train Generator
        # ===================
        optimizer_g.zero_grad()

        # Adversarial loss
        fake_out = discriminator(audio_fake.unsqueeze(1))
        loss_adv = 0
        for fake in fake_out:
            loss_adv += mse_loss(fake, torch.ones_like(fake))

        # Reconstruction loss (L1)
        loss_recon = l1_loss(audio_fake, audio) * 45  # Weight for L1 loss

        # Total generator loss
        loss_g = loss_adv + loss_recon

        loss_g.backward()
        optimizer_g.step()

        total_loss_g += loss_g.item()
        total_loss_d += loss_d.item()
        num_batches += 1
        global_step += 1

        # Update progress bar
        progress_bar.set_postfix({
            'G_loss': f'{loss_g.item():.3f}',
            'D_loss': f'{loss_d.item():.3f}'
        })

        # Log to tensorboard
        if global_step % 100 == 0:
            writer.add_scalar('Loss/Generator', loss_g.item(), global_step)
            writer.add_scalar('Loss/Discriminator', loss_d.item(), global_step)

    # Update learning rates
    scheduler_g.step()
    scheduler_d.step()

    # Epoch summary
    avg_loss_g = total_loss_g / num_batches
    avg_loss_d = total_loss_d / num_batches
    epoch_time = time.time() - epoch_start

    print(f"\n  Epoch {epoch+1} complete in {epoch_time:.1f}s")
    print(f"  Avg G Loss: {avg_loss_g:.4f}, Avg D Loss: {avg_loss_d:.4f}")

    # ===================
    # Validation
    # ===================
    if (epoch + 1) % 2 == 0:  # Validate every 2 epochs
        generator.eval()
        val_loss = 0

        with torch.no_grad():
            for audio, units in val_loader:
                audio = audio.to(device)
                units = units.to(device)

                audio_fake = generator(units)
                min_len = min(audio.shape[1], audio_fake.shape[1])
                val_loss += l1_loss(audio_fake[:, :min_len], audio[:, :min_len]).item()

        val_loss /= len(val_loader)
        print(f"  Validation L1 Loss: {val_loss:.4f}")
        writer.add_scalar('Loss/Validation', val_loss, epoch)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_g_state_dict': optimizer_g.state_dict(),
                'optimizer_d_state_dict': optimizer_d.state_dict(),
                'val_loss': val_loss,
            }, f"{MODEL_DIR}/best_model.pt")
            print(f"  ✓ Saved best model (val_loss: {val_loss:.4f})")

    # Save checkpoint
    if (epoch + 1) % SAVE_EVERY == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }, f"{MODEL_DIR}/checkpoint_epoch{epoch+1}.pt")
        print(f"  ✓ Saved checkpoint")

    # Generate sample audio every 10 epochs
    if (epoch + 1) % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_audio, sample_units = val_dataset[0]
            sample_units = sample_units.unsqueeze(0).to(device)
            generated = generator(sample_units).squeeze().cpu().numpy()

            # Save sample
            generated_int16 = (generated * 32767).astype(np.int16)
            sample_path = f"{MODEL_DIR}/sample_epoch{epoch+1}.wav"
            wav.write(sample_path, 16000, generated_int16)
            print(f"  ✓ Saved sample audio: {sample_path}")

print("\n" + "=" * 60)
print("✓ TRAINING COMPLETE")
print("=" * 60)
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Model saved to: {MODEL_DIR}")

writer.close()

TRAINING UNIT VOCODER
  Epochs: 50
  Batch size: 16
  Learning rate: 0.0002
  Device: cuda


Epoch 1/50: 100%|██████████| 1428/1428 [40:09<00:00,  1.69s/it, G_loss=7.432, D_loss=0.247]



  Epoch 1 complete in 2409.8s
  Avg G Loss: 6.2638, Avg D Loss: 0.5480


Epoch 2/50: 100%|██████████| 1428/1428 [40:07<00:00,  1.69s/it, G_loss=6.476, D_loss=0.449]


  Epoch 2 complete in 2407.1s
  Avg G Loss: 6.4641, Avg D Loss: 0.4347





  Validation L1 Loss: 0.0916
  ✓ Saved best model (val_loss: 0.0916)


Epoch 3/50: 100%|██████████| 1428/1428 [40:08<00:00,  1.69s/it, G_loss=7.151, D_loss=0.392]



  Epoch 3 complete in 2408.5s
  Avg G Loss: 6.4193, Avg D Loss: 0.4168


Epoch 4/50: 100%|██████████| 1428/1428 [40:14<00:00,  1.69s/it, G_loss=6.001, D_loss=0.870]


  Epoch 4 complete in 2414.1s
  Avg G Loss: 6.1879, Avg D Loss: 0.6943





  Validation L1 Loss: 0.0954


Epoch 5/50: 100%|██████████| 1428/1428 [40:15<00:00,  1.69s/it, G_loss=6.767, D_loss=0.558]



  Epoch 5 complete in 2415.0s
  Avg G Loss: 6.2297, Avg D Loss: 0.6593
  ✓ Saved checkpoint


Epoch 6/50: 100%|██████████| 1428/1428 [40:11<00:00,  1.69s/it, G_loss=5.346, D_loss=0.798]


  Epoch 6 complete in 2411.9s
  Avg G Loss: 6.2239, Avg D Loss: 0.6581





  Validation L1 Loss: 0.0989


Epoch 7/50: 100%|██████████| 1428/1428 [40:12<00:00,  1.69s/it, G_loss=6.523, D_loss=0.576]



  Epoch 7 complete in 2412.2s
  Avg G Loss: 6.2421, Avg D Loss: 0.6541


Epoch 8/50: 100%|██████████| 1428/1428 [40:13<00:00,  1.69s/it, G_loss=7.782, D_loss=0.497]


  Epoch 8 complete in 2413.4s
  Avg G Loss: 6.2512, Avg D Loss: 0.6502





  Validation L1 Loss: 0.0967


Epoch 9/50: 100%|██████████| 1428/1428 [40:13<00:00,  1.69s/it, G_loss=7.112, D_loss=0.572]



  Epoch 9 complete in 2413.0s
  Avg G Loss: 6.3734, Avg D Loss: 0.6357


Epoch 10/50: 100%|██████████| 1428/1428 [40:14<00:00,  1.69s/it, G_loss=5.246, D_loss=0.948]


  Epoch 10 complete in 2415.0s
  Avg G Loss: 6.3499, Avg D Loss: 0.6462





  Validation L1 Loss: 0.0981
  ✓ Saved checkpoint
  ✓ Saved sample audio: /content/drive/MyDrive/satere_project/hifigan_model/sample_epoch10.wav


Epoch 11/50: 100%|██████████| 1428/1428 [40:15<00:00,  1.69s/it, G_loss=5.311, D_loss=0.739]



  Epoch 11 complete in 2415.7s
  Avg G Loss: 6.3602, Avg D Loss: 0.6355


Epoch 12/50: 100%|██████████| 1428/1428 [40:12<00:00,  1.69s/it, G_loss=7.143, D_loss=0.562]


  Epoch 12 complete in 2412.5s
  Avg G Loss: 6.3228, Avg D Loss: 0.6362





  Validation L1 Loss: 0.0971


Epoch 13/50: 100%|██████████| 1428/1428 [40:14<00:00,  1.69s/it, G_loss=6.430, D_loss=0.780]



  Epoch 13 complete in 2414.2s
  Avg G Loss: 6.4224, Avg D Loss: 0.6193


Epoch 14/50: 100%|██████████| 1428/1428 [40:15<00:00,  1.69s/it, G_loss=6.454, D_loss=0.535]


  Epoch 14 complete in 2415.2s
  Avg G Loss: 6.4475, Avg D Loss: 0.6059





  Validation L1 Loss: 0.0980


Epoch 15/50: 100%|██████████| 1428/1428 [40:17<00:00,  1.69s/it, G_loss=6.418, D_loss=0.655]



  Epoch 15 complete in 2417.1s
  Avg G Loss: 6.4663, Avg D Loss: 0.5943
  ✓ Saved checkpoint


Epoch 16/50: 100%|██████████| 1428/1428 [40:18<00:00,  1.69s/it, G_loss=7.285, D_loss=0.565]


  Epoch 16 complete in 2418.6s
  Avg G Loss: 6.5010, Avg D Loss: 0.5799





  Validation L1 Loss: 0.0980


Epoch 17/50: 100%|██████████| 1428/1428 [40:17<00:00,  1.69s/it, G_loss=5.465, D_loss=0.692]



  Epoch 17 complete in 2417.5s
  Avg G Loss: 6.4730, Avg D Loss: 0.5859


Epoch 18/50: 100%|██████████| 1428/1428 [40:17<00:00,  1.69s/it, G_loss=6.164, D_loss=0.669]


  Epoch 18 complete in 2417.1s
  Avg G Loss: 6.5268, Avg D Loss: 0.5709





  Validation L1 Loss: 0.0985


Epoch 19/50: 100%|██████████| 1428/1428 [40:16<00:00,  1.69s/it, G_loss=6.723, D_loss=0.584]



  Epoch 19 complete in 2416.8s
  Avg G Loss: 6.5942, Avg D Loss: 0.5276


Epoch 20/50: 100%|██████████| 1428/1428 [40:19<00:00,  1.69s/it, G_loss=6.692, D_loss=0.557]


  Epoch 20 complete in 2419.0s
  Avg G Loss: 6.6849, Avg D Loss: 0.4644





  Validation L1 Loss: 0.0965
  ✓ Saved checkpoint
  ✓ Saved sample audio: /content/drive/MyDrive/satere_project/hifigan_model/sample_epoch20.wav


Epoch 21/50: 100%|██████████| 1428/1428 [40:17<00:00,  1.69s/it, G_loss=6.741, D_loss=0.236]



  Epoch 21 complete in 2417.9s
  Avg G Loss: 6.7695, Avg D Loss: 0.4220


Epoch 22/50: 100%|██████████| 1428/1428 [40:19<00:00,  1.69s/it, G_loss=7.123, D_loss=0.284]


  Epoch 22 complete in 2419.1s
  Avg G Loss: 6.8773, Avg D Loss: 0.3488





  Validation L1 Loss: 0.0979


Epoch 23/50: 100%|██████████| 1428/1428 [40:19<00:00,  1.69s/it, G_loss=7.362, D_loss=0.423]



  Epoch 23 complete in 2419.4s
  Avg G Loss: 6.8554, Avg D Loss: 0.3593


Epoch 24/50: 100%|██████████| 1428/1428 [40:19<00:00,  1.69s/it, G_loss=6.539, D_loss=0.164]


  Epoch 24 complete in 2419.5s
  Avg G Loss: 6.9048, Avg D Loss: 0.3414





  Validation L1 Loss: 0.0990


Epoch 25/50: 100%|██████████| 1428/1428 [40:19<00:00,  1.69s/it, G_loss=6.589, D_loss=0.105]



  Epoch 25 complete in 2419.8s
  Avg G Loss: 6.8876, Avg D Loss: 0.3214
  ✓ Saved checkpoint


Epoch 26/50:  29%|██▊       | 408/1428 [11:31<28:49,  1.70s/it, G_loss=6.759, D_loss=0.319]

In [None]:
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"

# Check for existing codebook files
possible_files = [
    "satere_soundscript_codebook_v1.npy",
    "kmeans_model.pkl",
    "kmeans_model.joblib",
    "codebook.npy",
    "cluster_centers.npy"
]

print("Searching for codebook files...\n")

for root, dirs, files in os.walk(PROJECT_ROOT):
    for f in files:
        if "kmeans" in f.lower() or "codebook" in f.lower() or "cluster" in f.lower():
            path = os.path.join(root, f)
            print(f"Found: {path}")

# Also check satere_units folder
units_dir = f"{PROJECT_ROOT}/satere_units"
if os.path.exists(units_dir):
    print(f"\nFiles in satere_units/: {os.listdir(units_dir)[:10]}...")

In [None]:
import os
print(os.listdir("/content/drive/MyDrive/satere_project"))

In [None]:
import os
MODEL_DIR = "/content/drive/MyDrive/satere_project/hifigan_model"
print("Saved checkpoints:")
for f in sorted(os.listdir(MODEL_DIR)):
    print(f"  {f}")

Saved checkpoints:


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/satere_project/hifigan_model'

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
MODEL_DIR = "/content/drive/MyDrive/satere_project/hifigan_model"
print("Saved checkpoints:")
for f in sorted(os.listdir(MODEL_DIR)):
    print(f"  {f}")

Saved checkpoints:
  best_model.pt
  checkpoint_epoch10.pt
  checkpoint_epoch15.pt
  checkpoint_epoch20.pt
  checkpoint_epoch25.pt
  checkpoint_epoch5.pt
  logs
  sample_epoch10.wav
  sample_epoch20.wav


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import json
from tqdm import tqdm

# Paths
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model"
TRAINING_DIR = f"{PROJECT_ROOT}/hifigan_training"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load training config
with open(f"{TRAINING_DIR}/training_manifest.json", "r") as f:
    manifest = json.load(f)

print(f"Training samples: {len(manifest['files'])}")

# ============ MODEL DEFINITIONS ============
class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size, dilations):
        super().__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()
        for d in dilations:
            self.convs1.append(nn.Conv1d(channels, channels, kernel_size, dilation=d, padding=(kernel_size*d-d)//2))
            self.convs2.append(nn.Conv1d(channels, channels, kernel_size, dilation=1, padding=(kernel_size-1)//2))

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, 0.1)
            xt = F.leaky_relu(c1(xt), 0.1)
            xt = c2(xt)
            x = x + xt
        return x

class Generator(nn.Module):
    def __init__(self, num_units=100, unit_embed_dim=256, upsample_rates=[8,8,4,2],
                 upsample_kernels=[16,16,8,4], resblock_kernels=[3,7,11], resblock_dilations=[[1,3,5],[1,3,5],[1,3,5]]):
        super().__init__()
        self.embedding = nn.Embedding(num_units, unit_embed_dim)
        self.conv_pre = nn.Conv1d(unit_embed_dim, 256, 7, padding=3)

        self.ups = nn.ModuleList()
        self.resblocks = nn.ModuleList()

        ch = 256
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernels)):
            self.ups.append(nn.ConvTranspose1d(ch, ch//2, k, u, padding=(k-u)//2))
            ch = ch//2
            for j, (rk, rd) in enumerate(zip(resblock_kernels, resblock_dilations)):
                self.resblocks.append(ResBlock(ch, rk, rd))

        self.conv_post = nn.Conv1d(ch, 1, 7, padding=3)
        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernels)

    def forward(self, x):
        x = self.embedding(x).transpose(1, 2)
        x = self.conv_pre(x)
        for i, up in enumerate(self.ups):
            x = F.leaky_relu(x, 0.1)
            x = up(x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x, 0.1)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x.squeeze(1)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(1, 16, 15, padding=7),
            nn.Conv1d(16, 64, 41, stride=4, padding=20, groups=4),
            nn.Conv1d(64, 256, 41, stride=4, padding=20, groups=16),
            nn.Conv1d(256, 1024, 41, stride=4, padding=20, groups=64),
            nn.Conv1d(1024, 1024, 41, stride=4, padding=20, groups=256),
            nn.Conv1d(1024, 1024, 5, padding=2),
        ])
        self.conv_post = nn.Conv1d(1024, 1, 3, padding=1)

    def forward(self, x):
        fmaps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            fmaps.append(x)
        x = self.conv_post(x)
        fmaps.append(x)
        return x.flatten(1, -1), fmaps

# ============ DATASET ============
class UnitAudioDataset(Dataset):
    def __init__(self, manifest, segment_length=16384):
        self.files = manifest['files']
        self.segment_length = segment_length

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        item = self.files[idx]
        units = np.load(item['units_path'])
        audio = np.load(item['audio_path'])

        unit_len = self.segment_length // 512
        if len(units) > unit_len:
            start = np.random.randint(0, len(units) - unit_len)
            units = units[start:start+unit_len]
            audio_start = start * 512
            audio = audio[audio_start:audio_start+self.segment_length]

        if len(audio) < self.segment_length:
            audio = np.pad(audio, (0, self.segment_length - len(audio)))
        if len(units) < unit_len:
            units = np.pad(units, (0, unit_len - len(units)))

        return torch.LongTensor(units), torch.FloatTensor(audio)

# ============ LOAD MODELS AND RESUME ============
generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.0002, betas=(0.8, 0.99))

# Load checkpoint
checkpoint_path = f"{MODEL_DIR}/checkpoint_epoch25.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
start_epoch = checkpoint['epoch'] + 1

print(f"✓ Resumed from epoch {checkpoint['epoch']}")
print(f"  Starting at epoch {start_epoch}")

# ============ DATALOADER ============
dataset = UnitAudioDataset(manifest)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)

# Validation set
val_dataset = UnitAudioDataset(manifest)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

print(f"✓ DataLoader ready: {len(dataloader)} batches")

# ============ TRAINING LOOP ============
NUM_EPOCHS = 50

for epoch in range(start_epoch, NUM_EPOCHS + 1):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0
    epoch_d_loss = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

    for units, audio in pbar:
        units = units.to(device)
        audio = audio.to(device)

        # Generate
        audio_fake = generator(units)

        # Match lengths
        min_len = min(audio.shape[1], audio_fake.shape[1])
        audio = audio[:, :min_len]
        audio_fake = audio_fake[:, :min_len]

        # Discriminator
        optimizer_d.zero_grad()
        real_out, _ = discriminator(audio.unsqueeze(1))
        fake_out, _ = discriminator(audio_fake.detach().unsqueeze(1))
        d_loss = torch.mean((real_out - 1)**2) + torch.mean(fake_out**2)
        d_loss.backward()
        optimizer_d.step()

        # Generator
        optimizer_g.zero_grad()
        fake_out, fake_fmaps = discriminator(audio_fake.unsqueeze(1))
        _, real_fmaps = discriminator(audio.unsqueeze(1))

        g_loss_adv = torch.mean((fake_out - 1)**2)
        g_loss_fm = sum(torch.mean(torch.abs(rf - ff)) for rf, ff in zip(real_fmaps, fake_fmaps))
        g_loss_l1 = F.l1_loss(audio_fake, audio) * 45
        g_loss = g_loss_adv + g_loss_fm * 2 + g_loss_l1
        g_loss.backward()
        optimizer_g.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()
        pbar.set_postfix({'G_loss': f'{g_loss.item():.3f}', 'D_loss': f'{d_loss.item():.3f}'})

    avg_g = epoch_g_loss / len(dataloader)
    avg_d = epoch_d_loss / len(dataloader)
    print(f"\n  Epoch {epoch} complete in {pbar.format_dict['elapsed']:.1f}s")
    print(f"  Avg G Loss: {avg_g:.4f}, Avg D Loss: {avg_d:.4f}")

    # Validation
    if epoch % 2 == 0:
        generator.eval()
        val_loss = 0
        with torch.no_grad():
            for units, audio in val_loader:
                units, audio = units.to(device), audio.to(device)
                fake = generator(units)
                min_len = min(audio.shape[1], fake.shape[1])
                val_loss += F.l1_loss(fake[:,:min_len], audio[:,:min_len]).item()
        val_loss /= len(val_loader)
        print(f"\n  Validation L1 Loss: {val_loss:.4f}")

    # Save checkpoint every 5 epochs
    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }, f"{MODEL_DIR}/checkpoint_epoch{epoch}.pt")
        print(f"  ✓ Saved checkpoint")

    # Save sample every 10 epochs
    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_units = torch.LongTensor(dataset[0][0]).unsqueeze(0).to(device)
            sample_audio = generator(sample_units).squeeze().cpu().numpy()
            import scipy.io.wavfile as wav
            wav.write(f"{MODEL_DIR}/sample_epoch{epoch}.wav", 16000, (sample_audio * 32767).astype(np.int16))
            print(f"  ✓ Saved sample audio: {MODEL_DIR}/sample_epoch{epoch}.wav")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

Using device: cuda


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/satere_project/hifigan_training/training_manifest.json'

In [None]:
import os

PROJECT_ROOT = "/content/drive/MyDrive/satere_project"

# Check hifigan_training folder
TRAINING_DIR = f"{PROJECT_ROOT}/hifigan_training"
if os.path.exists(TRAINING_DIR):
    print("Files in hifigan_training:")
    for f in sorted(os.listdir(TRAINING_DIR))[:20]:
        print(f"  {f}")
else:
    print("hifigan_training folder not found")

# Check what folders exist
print("\nMain project folders:")
for f in sorted(os.listdir(PROJECT_ROOT)):
    print(f"  {f}")

Files in hifigan_training:
  metadata.json
  train_files.txt
  units
  val_files.txt
  wavs

Main project folders:
  acts12.mp3
  converted_audio
  hifigan_model
  hifigan_training
  phase2_output
  raw_audio
  satere_motifs
  satere_units
  soundscript_deploy
  soundscript_deploy.zip
  soundscript_full
  vocoder_output


In [None]:
import json

with open(f"{PROJECT_ROOT}/hifigan_training/metadata.json", "r") as f:
    metadata = json.load(f)

print("Keys:", metadata.keys())
print("\nFirst few entries:")
if isinstance(metadata, dict):
    for k, v in list(metadata.items())[:3]:
        print(f"  {k}: {v}")
elif isinstance(metadata, list):
    for item in metadata[:3]:
        print(f"  {item}")

Keys: dict_keys(['num_train', 'num_val', 'segment_length_sec', 'sample_rate', 'num_units'])

First few entries:
  num_train: 22834
  num_val: 100
  segment_length_sec: 5


In [None]:
# Check train_files.txt format
with open(f"{PROJECT_ROOT}/hifigan_training/train_files.txt", "r") as f:
    lines = f.readlines()[:5]

print("First 5 lines of train_files.txt:")
for line in lines:
    print(f"  {line.strip()}")

# Check the units and wavs folders
print("\nSample files in units/:")
units_dir = f"{PROJECT_ROOT}/hifigan_training/units"
for f in sorted(os.listdir(units_dir))[:3]:
    print(f"  {f}")

print("\nSample files in wavs/:")
wavs_dir = f"{PROJECT_ROOT}/hifigan_training/wavs"
for f in sorted(os.listdir(wavs_dir))[:3]:
    print(f"  {f}")

First 5 lines of train_files.txt:
  seg_18430
  seg_02412
  seg_17923
  seg_01742
  seg_17771

Sample files in units/:
  seg_00000.txt
  seg_00001.txt
  seg_00002.txt

Sample files in wavs/:
  seg_00000.wav
  seg_00001.wav
  seg_00002.wav


In [None]:
# Check unit file format
with open(f"{PROJECT_ROOT}/hifigan_training/units/seg_00000.txt", "r") as f:
    content = f.read().strip()
print(f"Unit file content (first 200 chars):\n{content[:200]}")
print(f"\nTotal length: {len(content)}")

Unit file content (first 200 chars):
26 26 26 26 26 49 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 49 30 98 98 88 88 88 91 91 91 69 46 46 46 66 8 72 23 58 49 99 56 93 18 16 16 10 10 10 51 90 6 6 6 6 6 84 27 62 62 62 62 62 62 52 52

Total length: 734


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import scipy.io.wavfile as wav
from tqdm import tqdm

# Paths
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model"
TRAINING_DIR = f"{PROJECT_ROOT}/hifigan_training"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load file lists
with open(f"{TRAINING_DIR}/train_files.txt", "r") as f:
    train_files = [line.strip() for line in f.readlines()]

with open(f"{TRAINING_DIR}/val_files.txt", "r") as f:
    val_files = [line.strip() for line in f.readlines()]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# ============ MODEL DEFINITIONS ============
class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size, dilations):
        super().__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()
        for d in dilations:
            self.convs1.append(nn.Conv1d(channels, channels, kernel_size, dilation=d, padding=(kernel_size*d-d)//2))
            self.convs2.append(nn.Conv1d(channels, channels, kernel_size, dilation=1, padding=(kernel_size-1)//2))

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, 0.1)
            xt = F.leaky_relu(c1(xt), 0.1)
            xt = c2(xt)
            x = x + xt
        return x

class Generator(nn.Module):
    def __init__(self, num_units=100, unit_embed_dim=256, upsample_rates=[8,8,4,2],
                 upsample_kernels=[16,16,8,4], resblock_kernels=[3,7,11], resblock_dilations=[[1,3,5],[1,3,5],[1,3,5]]):
        super().__init__()
        self.embedding = nn.Embedding(num_units, unit_embed_dim)
        self.conv_pre = nn.Conv1d(unit_embed_dim, 256, 7, padding=3)

        self.ups = nn.ModuleList()
        self.resblocks = nn.ModuleList()

        ch = 256
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernels)):
            self.ups.append(nn.ConvTranspose1d(ch, ch//2, k, u, padding=(k-u)//2))
            ch = ch//2
            for j, (rk, rd) in enumerate(zip(resblock_kernels, resblock_dilations)):
                self.resblocks.append(ResBlock(ch, rk, rd))

        self.conv_post = nn.Conv1d(ch, 1, 7, padding=3)
        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernels)

    def forward(self, x):
        x = self.embedding(x).transpose(1, 2)
        x = self.conv_pre(x)
        for i, up in enumerate(self.ups):
            x = F.leaky_relu(x, 0.1)
            x = up(x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x, 0.1)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x.squeeze(1)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(1, 16, 15, padding=7),
            nn.Conv1d(16, 64, 41, stride=4, padding=20, groups=4),
            nn.Conv1d(64, 256, 41, stride=4, padding=20, groups=16),
            nn.Conv1d(256, 1024, 41, stride=4, padding=20, groups=64),
            nn.Conv1d(1024, 1024, 41, stride=4, padding=20, groups=256),
            nn.Conv1d(1024, 1024, 5, padding=2),
        ])
        self.conv_post = nn.Conv1d(1024, 1, 3, padding=1)

    def forward(self, x):
        fmaps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            fmaps.append(x)
        x = self.conv_post(x)
        fmaps.append(x)
        return x.flatten(1, -1), fmaps

# ============ DATASET ============
class UnitAudioDataset(Dataset):
    def __init__(self, file_list, training_dir, segment_length=16384):
        self.file_list = file_list
        self.training_dir = training_dir
        self.segment_length = segment_length
        self.sample_rate = 16000

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_id = self.file_list[idx]

        # Load units from text file
        with open(f"{self.training_dir}/units/{file_id}.txt", "r") as f:
            units = np.array([int(x) for x in f.read().strip().split()], dtype=np.int64)

        # Load audio from wav file
        sr, audio = wav.read(f"{self.training_dir}/wavs/{file_id}.wav")
        audio = audio.astype(np.float32) / 32768.0

        # Segment
        unit_len = self.segment_length // 512
        if len(units) > unit_len:
            start = np.random.randint(0, len(units) - unit_len)
            units = units[start:start+unit_len]
            audio_start = start * 512
            audio = audio[audio_start:audio_start+self.segment_length]

        if len(audio) < self.segment_length:
            audio = np.pad(audio, (0, self.segment_length - len(audio)))
        if len(units) < unit_len:
            units = np.pad(units, (0, unit_len - len(units)))

        return torch.LongTensor(units), torch.FloatTensor(audio)

# ============ LOAD MODELS AND RESUME ============
generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.0002, betas=(0.8, 0.99))

# Load checkpoint
checkpoint_path = f"{MODEL_DIR}/checkpoint_epoch25.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
start_epoch = checkpoint['epoch'] + 1

print(f"\n✓ Resumed from epoch {checkpoint['epoch']}")
print(f"  Starting at epoch {start_epoch}")

# ============ DATALOADER ============
dataset = UnitAudioDataset(train_files, TRAINING_DIR)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)

val_dataset = UnitAudioDataset(val_files, TRAINING_DIR)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

print(f"✓ DataLoader ready: {len(dataloader)} batches")

# ============ TRAINING LOOP ============
NUM_EPOCHS = 50

print("\n" + "="*60)
print("RESUMING TRAINING")
print("="*60)

for epoch in range(start_epoch, NUM_EPOCHS + 1):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0
    epoch_d_loss = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

    for units, audio in pbar:
        units = units.to(device)
        audio = audio.to(device)

        # Generate
        audio_fake = generator(units)

        # Match lengths
        min_len = min(audio.shape[1], audio_fake.shape[1])
        audio = audio[:, :min_len]
        audio_fake = audio_fake[:, :min_len]

        # Discriminator
        optimizer_d.zero_grad()
        real_out, _ = discriminator(audio.unsqueeze(1))
        fake_out, _ = discriminator(audio_fake.detach().unsqueeze(1))
        d_loss = torch.mean((real_out - 1)**2) + torch.mean(fake_out**2)
        d_loss.backward()
        optimizer_d.step()

        # Generator
        optimizer_g.zero_grad()
        fake_out, fake_fmaps = discriminator(audio_fake.unsqueeze(1))
        _, real_fmaps = discriminator(audio.unsqueeze(1))

        g_loss_adv = torch.mean((fake_out - 1)**2)
        g_loss_fm = sum(torch.mean(torch.abs(rf - ff)) for rf, ff in zip(real_fmaps, fake_fmaps))
        g_loss_l1 = F.l1_loss(audio_fake, audio) * 45
        g_loss = g_loss_adv + g_loss_fm * 2 + g_loss_l1
        g_loss.backward()
        optimizer_g.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()
        pbar.set_postfix({'G_loss': f'{g_loss.item():.3f}', 'D_loss': f'{d_loss.item():.3f}'})

    avg_g = epoch_g_loss / len(dataloader)
    avg_d = epoch_d_loss / len(dataloader)
    print(f"\n  Epoch {epoch} complete")
    print(f"  Avg G Loss: {avg_g:.4f}, Avg D Loss: {avg_d:.4f}")

    # Validation
    if epoch % 2 == 0:
        generator.eval()
        val_loss = 0
        with torch.no_grad():
            for units, audio in val_loader:
                units, audio = units.to(device), audio.to(device)
                fake = generator(units)
                min_len = min(audio.shape[1], fake.shape[1])
                val_loss += F.l1_loss(fake[:,:min_len], audio[:,:min_len]).item()
        val_loss /= len(val_loader)
        print(f"\n  Validation L1 Loss: {val_loss:.4f}")

    # Save checkpoint every 5 epochs
    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }, f"{MODEL_DIR}/checkpoint_epoch{epoch}.pt")
        print(f"  ✓ Saved checkpoint")

    # Save sample every 10 epochs
    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_units = torch.LongTensor(dataset[0][0]).unsqueeze(0).to(device)
            sample_audio = generator(sample_units).squeeze().cpu().numpy()
            wav.write(f"{MODEL_DIR}/sample_epoch{epoch}.wav", 16000, (sample_audio * 32767).astype(np.int16))
            print(f"  ✓ Saved sample audio: {MODEL_DIR}/sample_epoch{epoch}.wav")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

Using device: cuda
Training files: 22834
Validation files: 100


RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "embedding.weight", "conv_pre.weight", "conv_pre.bias", "ups.0.weight", "ups.0.bias", "ups.1.weight", "ups.1.bias", "ups.2.weight", "ups.2.bias", "ups.3.weight", "ups.3.bias", "resblocks.0.convs1.0.weight", "resblocks.0.convs1.0.bias", "resblocks.0.convs1.1.weight", "resblocks.0.convs1.1.bias", "resblocks.0.convs1.2.weight", "resblocks.0.convs1.2.bias", "resblocks.0.convs2.0.weight", "resblocks.0.convs2.0.bias", "resblocks.0.convs2.1.weight", "resblocks.0.convs2.1.bias", "resblocks.0.convs2.2.weight", "resblocks.0.convs2.2.bias", "resblocks.1.convs1.0.weight", "resblocks.1.convs1.0.bias", "resblocks.1.convs1.1.weight", "resblocks.1.convs1.1.bias", "resblocks.1.convs1.2.weight", "resblocks.1.convs1.2.bias", "resblocks.1.convs2.0.weight", "resblocks.1.convs2.0.bias", "resblocks.1.convs2.1.weight", "resblocks.1.convs2.1.bias", "resblocks.1.convs2.2.weight", "resblocks.1.convs2.2.bias", "resblocks.2.convs1.0.weight", "resblocks.2.convs1.0.bias", "resblocks.2.convs1.1.weight", "resblocks.2.convs1.1.bias", "resblocks.2.convs1.2.weight", "resblocks.2.convs1.2.bias", "resblocks.2.convs2.0.weight", "resblocks.2.convs2.0.bias", "resblocks.2.convs2.1.weight", "resblocks.2.convs2.1.bias", "resblocks.2.convs2.2.weight", "resblocks.2.convs2.2.bias", "resblocks.3.convs1.0.weight", "resblocks.3.convs1.0.bias", "resblocks.3.convs1.1.weight", "resblocks.3.convs1.1.bias", "resblocks.3.convs1.2.weight", "resblocks.3.convs1.2.bias", "resblocks.3.convs2.0.weight", "resblocks.3.convs2.0.bias", "resblocks.3.convs2.1.weight", "resblocks.3.convs2.1.bias", "resblocks.3.convs2.2.weight", "resblocks.3.convs2.2.bias", "resblocks.4.convs1.0.weight", "resblocks.4.convs1.0.bias", "resblocks.4.convs1.1.weight", "resblocks.4.convs1.1.bias", "resblocks.4.convs1.2.weight", "resblocks.4.convs1.2.bias", "resblocks.4.convs2.0.weight", "resblocks.4.convs2.0.bias", "resblocks.4.convs2.1.weight", "resblocks.4.convs2.1.bias", "resblocks.4.convs2.2.weight", "resblocks.4.convs2.2.bias", "resblocks.5.convs1.0.weight", "resblocks.5.convs1.0.bias", "resblocks.5.convs1.1.weight", "resblocks.5.convs1.1.bias", "resblocks.5.convs1.2.weight", "resblocks.5.convs1.2.bias", "resblocks.5.convs2.0.weight", "resblocks.5.convs2.0.bias", "resblocks.5.convs2.1.weight", "resblocks.5.convs2.1.bias", "resblocks.5.convs2.2.weight", "resblocks.5.convs2.2.bias", "resblocks.6.convs1.0.weight", "resblocks.6.convs1.0.bias", "resblocks.6.convs1.1.weight", "resblocks.6.convs1.1.bias", "resblocks.6.convs1.2.weight", "resblocks.6.convs1.2.bias", "resblocks.6.convs2.0.weight", "resblocks.6.convs2.0.bias", "resblocks.6.convs2.1.weight", "resblocks.6.convs2.1.bias", "resblocks.6.convs2.2.weight", "resblocks.6.convs2.2.bias", "resblocks.7.convs1.0.weight", "resblocks.7.convs1.0.bias", "resblocks.7.convs1.1.weight", "resblocks.7.convs1.1.bias", "resblocks.7.convs1.2.weight", "resblocks.7.convs1.2.bias", "resblocks.7.convs2.0.weight", "resblocks.7.convs2.0.bias", "resblocks.7.convs2.1.weight", "resblocks.7.convs2.1.bias", "resblocks.7.convs2.2.weight", "resblocks.7.convs2.2.bias", "resblocks.8.convs1.0.weight", "resblocks.8.convs1.0.bias", "resblocks.8.convs1.1.weight", "resblocks.8.convs1.1.bias", "resblocks.8.convs1.2.weight", "resblocks.8.convs1.2.bias", "resblocks.8.convs2.0.weight", "resblocks.8.convs2.0.bias", "resblocks.8.convs2.1.weight", "resblocks.8.convs2.1.bias", "resblocks.8.convs2.2.weight", "resblocks.8.convs2.2.bias", "resblocks.9.convs1.0.weight", "resblocks.9.convs1.0.bias", "resblocks.9.convs1.1.weight", "resblocks.9.convs1.1.bias", "resblocks.9.convs1.2.weight", "resblocks.9.convs1.2.bias", "resblocks.9.convs2.0.weight", "resblocks.9.convs2.0.bias", "resblocks.9.convs2.1.weight", "resblocks.9.convs2.1.bias", "resblocks.9.convs2.2.weight", "resblocks.9.convs2.2.bias", "resblocks.10.convs1.0.weight", "resblocks.10.convs1.0.bias", "resblocks.10.convs1.1.weight", "resblocks.10.convs1.1.bias", "resblocks.10.convs1.2.weight", "resblocks.10.convs1.2.bias", "resblocks.10.convs2.0.weight", "resblocks.10.convs2.0.bias", "resblocks.10.convs2.1.weight", "resblocks.10.convs2.1.bias", "resblocks.10.convs2.2.weight", "resblocks.10.convs2.2.bias", "resblocks.11.convs1.0.weight", "resblocks.11.convs1.0.bias", "resblocks.11.convs1.1.weight", "resblocks.11.convs1.1.bias", "resblocks.11.convs1.2.weight", "resblocks.11.convs1.2.bias", "resblocks.11.convs2.0.weight", "resblocks.11.convs2.0.bias", "resblocks.11.convs2.1.weight", "resblocks.11.convs2.1.bias", "resblocks.11.convs2.2.weight", "resblocks.11.convs2.2.bias", "conv_post.weight", "conv_post.bias". 
	Unexpected key(s) in state_dict: "unit_embed.weight", "pre_conv.weight", "pre_conv.bias", "upsamples.0.0.weight", "upsamples.0.0.bias", "upsamples.0.2.weight", "upsamples.0.2.bias", "upsamples.1.0.weight", "upsamples.1.0.bias", "upsamples.1.2.weight", "upsamples.1.2.bias", "upsamples.2.0.weight", "upsamples.2.0.bias", "upsamples.2.2.weight", "upsamples.2.2.bias", "upsamples.3.0.weight", "upsamples.3.0.bias", "upsamples.3.2.weight", "upsamples.3.2.bias", "upsamples.4.0.weight", "upsamples.4.0.bias", "upsamples.4.2.weight", "upsamples.4.2.bias", "post_conv.weight", "post_conv.bias". 

In [None]:
# Check the checkpoint structure
checkpoint_path = f"{MODEL_DIR}/checkpoint_epoch25.pt"
checkpoint = torch.load(checkpoint_path, map_location='cpu')

print("Checkpoint keys:")
for key in checkpoint.keys():
    print(f"  {key}")

print("\nGenerator state_dict keys (first 20):")
gen_keys = list(checkpoint['generator_state_dict'].keys())
for key in gen_keys[:20]:
    print(f"  {key}")

print(f"\n... total {len(gen_keys)} keys")

Checkpoint keys:
  epoch
  generator_state_dict
  discriminator_state_dict
  optimizer_g_state_dict
  optimizer_d_state_dict

Generator state_dict keys (first 20):
  unit_embed.weight
  pre_conv.weight
  pre_conv.bias
  upsamples.0.0.weight
  upsamples.0.0.bias
  upsamples.0.2.weight
  upsamples.0.2.bias
  upsamples.1.0.weight
  upsamples.1.0.bias
  upsamples.1.2.weight
  upsamples.1.2.bias
  upsamples.2.0.weight
  upsamples.2.0.bias
  upsamples.2.2.weight
  upsamples.2.2.bias
  upsamples.3.0.weight
  upsamples.3.0.bias
  upsamples.3.2.weight
  upsamples.3.2.bias
  upsamples.4.0.weight

... total 25 keys


In [None]:
# Check tensor shapes to understand the architecture
print("Generator weights shapes:")
for key, value in checkpoint['generator_state_dict'].items():
    print(f"  {key}: {value.shape}")

print("\nDiscriminator weights shapes:")
for key, value in checkpoint['discriminator_state_dict'].items():
    print(f"  {key}: {value.shape}")

Generator weights shapes:
  unit_embed.weight: torch.Size([100, 256])
  pre_conv.weight: torch.Size([512, 256, 7])
  pre_conv.bias: torch.Size([512])
  upsamples.0.0.weight: torch.Size([512, 256, 10])
  upsamples.0.0.bias: torch.Size([256])
  upsamples.0.2.weight: torch.Size([256, 256, 7])
  upsamples.0.2.bias: torch.Size([256])
  upsamples.1.0.weight: torch.Size([256, 128, 10])
  upsamples.1.0.bias: torch.Size([128])
  upsamples.1.2.weight: torch.Size([128, 128, 7])
  upsamples.1.2.bias: torch.Size([128])
  upsamples.2.0.weight: torch.Size([128, 64, 8])
  upsamples.2.0.bias: torch.Size([64])
  upsamples.2.2.weight: torch.Size([64, 64, 7])
  upsamples.2.2.bias: torch.Size([64])
  upsamples.3.0.weight: torch.Size([64, 32, 8])
  upsamples.3.0.bias: torch.Size([32])
  upsamples.3.2.weight: torch.Size([32, 32, 7])
  upsamples.3.2.bias: torch.Size([32])
  upsamples.4.0.weight: torch.Size([32, 64, 4])
  upsamples.4.0.bias: torch.Size([64])
  upsamples.4.2.weight: torch.Size([64, 64, 7])
  up

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import scipy.io.wavfile as wav
from tqdm import tqdm

# Paths
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model"
TRAINING_DIR = f"{PROJECT_ROOT}/hifigan_training"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load file lists
with open(f"{TRAINING_DIR}/train_files.txt", "r") as f:
    train_files = [line.strip() for line in f.readlines()]

with open(f"{TRAINING_DIR}/val_files.txt", "r") as f:
    val_files = [line.strip() for line in f.readlines()]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# ============ MODEL DEFINITIONS (matching checkpoint) ============
class Generator(nn.Module):
    def __init__(self, num_units=100, embed_dim=256):
        super().__init__()
        self.unit_embed = nn.Embedding(num_units, embed_dim)
        self.pre_conv = nn.Conv1d(embed_dim, 512, 7, padding=3)

        # 5 upsample blocks: each has ConvTranspose + LeakyReLU + Conv + LeakyReLU
        self.upsamples = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose1d(512, 256, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(256, 256, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(128, 128, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(32, 32, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(32, 64, 4, stride=2, padding=1, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
        ])

        self.post_conv = nn.Conv1d(64, 1, 7, padding=3)

    def forward(self, x):
        x = self.unit_embed(x).transpose(1, 2)  # (B, embed_dim, T)
        x = self.pre_conv(x)
        x = F.leaky_relu(x, 0.1)

        for upsample in self.upsamples:
            x = upsample(x)

        x = self.post_conv(x)
        x = torch.tanh(x)
        return x.squeeze(1)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # 3 sub-discriminators
        self.discriminators = nn.ModuleList([
            self._make_discriminator(),
            self._make_discriminator(),
            self._make_discriminator()
        ])
        self.pools = nn.ModuleList([
            nn.Identity(),
            nn.AvgPool1d(4, 2, padding=2),
            nn.AvgPool1d(16, 8, padding=8)
        ])

    def _make_discriminator(self):
        return nn.Sequential(
            nn.Conv1d(1, 64, 15, padding=7),
            nn.LeakyReLU(0.1),
            nn.Conv1d(64, 128, 41, stride=4, padding=20, groups=4),
            nn.LeakyReLU(0.1),
            nn.Conv1d(128, 256, 41, stride=4, padding=20, groups=8),
            nn.LeakyReLU(0.1),
            nn.Conv1d(256, 512, 41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 512, 5, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 1, 3, padding=1)
        )

    def forward(self, x):
        outputs = []
        for pool, disc in zip(self.pools, self.discriminators):
            x_pooled = pool(x)
            outputs.append(disc(x_pooled))
        return outputs

# ============ DATASET ============
class UnitAudioDataset(Dataset):
    def __init__(self, file_list, training_dir, segment_length=16384):
        self.file_list = file_list
        self.training_dir = training_dir
        self.segment_length = segment_length

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_id = self.file_list[idx]

        with open(f"{self.training_dir}/units/{file_id}.txt", "r") as f:
            units = np.array([int(x) for x in f.read().strip().split()], dtype=np.int64)

        sr, audio = wav.read(f"{self.training_dir}/wavs/{file_id}.wav")
        audio = audio.astype(np.float32) / 32768.0

        unit_len = self.segment_length // 512
        if len(units) > unit_len:
            start = np.random.randint(0, len(units) - unit_len)
            units = units[start:start+unit_len]
            audio_start = start * 512
            audio = audio[audio_start:audio_start+self.segment_length]

        if len(audio) < self.segment_length:
            audio = np.pad(audio, (0, self.segment_length - len(audio)))
        if len(units) < unit_len:
            units = np.pad(units, (0, unit_len - len(units)))

        return torch.LongTensor(units), torch.FloatTensor(audio)

# ============ LOAD MODELS AND RESUME ============
generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.0002, betas=(0.8, 0.99))

# Load checkpoint
checkpoint_path = f"{MODEL_DIR}/checkpoint_epoch25.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
start_epoch = checkpoint['epoch'] + 1

print(f"\n✓ Resumed from epoch {checkpoint['epoch']}")
print(f"  Starting at epoch {start_epoch}")

# ============ DATALOADER ============
dataset = UnitAudioDataset(train_files, TRAINING_DIR)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)

val_dataset = UnitAudioDataset(val_files, TRAINING_DIR)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

print(f"✓ DataLoader ready: {len(dataloader)} batches")

# ============ TRAINING LOOP ============
NUM_EPOCHS = 50

print("\n" + "="*60)
print("RESUMING TRAINING")
print("="*60)

for epoch in range(start_epoch, NUM_EPOCHS + 1):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0
    epoch_d_loss = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

    for units, audio in pbar:
        units = units.to(device)
        audio = audio.to(device)

        # Generate
        audio_fake = generator(units)

        # Match lengths
        min_len = min(audio.shape[1], audio_fake.shape[1])
        audio = audio[:, :min_len]
        audio_fake = audio_fake[:, :min_len]

        # Discriminator update
        optimizer_d.zero_grad()
        real_outs = discriminator(audio.unsqueeze(1))
        fake_outs = discriminator(audio_fake.detach().unsqueeze(1))

        d_loss = 0
        for real_out, fake_out in zip(real_outs, fake_outs):
            d_loss += torch.mean((real_out - 1)**2) + torch.mean(fake_out**2)
        d_loss.backward()
        optimizer_d.step()

        # Generator update
        optimizer_g.zero_grad()
        fake_outs = discriminator(audio_fake.unsqueeze(1))

        g_loss_adv = 0
        for fake_out in fake_outs:
            g_loss_adv += torch.mean((fake_out - 1)**2)

        g_loss_l1 = F.l1_loss(audio_fake, audio) * 45
        g_loss = g_loss_adv + g_loss_l1
        g_loss.backward()
        optimizer_g.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()
        pbar.set_postfix({'G_loss': f'{g_loss.item():.3f}', 'D_loss': f'{d_loss.item():.3f}'})

    avg_g = epoch_g_loss / len(dataloader)
    avg_d = epoch_d_loss / len(dataloader)
    print(f"\n  Epoch {epoch} complete")
    print(f"  Avg G Loss: {avg_g:.4f}, Avg D Loss: {avg_d:.4f}")

    # Validation
    if epoch % 2 == 0:
        generator.eval()
        val_loss = 0
        with torch.no_grad():
            for units, audio in val_loader:
                units, audio = units.to(device), audio.to(device)
                fake = generator(units)
                min_len = min(audio.shape[1], fake.shape[1])
                val_loss += F.l1_loss(fake[:,:min_len], audio[:,:min_len]).item()
        val_loss /= len(val_loader)
        print(f"\n  Validation L1 Loss: {val_loss:.4f}")

    # Save checkpoint every 5 epochs
    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }, f"{MODEL_DIR}/checkpoint_epoch{epoch}.pt")
        print(f"  ✓ Saved checkpoint")

    # Save sample every 10 epochs
    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_units = torch.LongTensor(dataset[0][0]).unsqueeze(0).to(device)
            sample_audio = generator(sample_units).squeeze().cpu().numpy()
            wav.write(f"{MODEL_DIR}/sample_epoch{epoch}.wav", 16000, (sample_audio * 32767).astype(np.int16))
            print(f"  ✓ Saved sample audio: {MODEL_DIR}/sample_epoch{epoch}.wav")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

Using device: cuda
Training files: 22834
Validation files: 100


RuntimeError: Error(s) in loading state_dict for Discriminator:
	size mismatch for discriminators.0.4.weight: copying a param with shape torch.Size([256, 8, 41]) from checkpoint, the shape in current model is torch.Size([256, 16, 41]).
	size mismatch for discriminators.1.4.weight: copying a param with shape torch.Size([256, 8, 41]) from checkpoint, the shape in current model is torch.Size([256, 16, 41]).
	size mismatch for discriminators.2.4.weight: copying a param with shape torch.Size([256, 8, 41]) from checkpoint, the shape in current model is torch.Size([256, 16, 41]).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import scipy.io.wavfile as wav
from tqdm import tqdm

# Paths
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model"
TRAINING_DIR = f"{PROJECT_ROOT}/hifigan_training"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load file lists
with open(f"{TRAINING_DIR}/train_files.txt", "r") as f:
    train_files = [line.strip() for line in f.readlines()]

with open(f"{TRAINING_DIR}/val_files.txt", "r") as f:
    val_files = [line.strip() for line in f.readlines()]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# ============ MODEL DEFINITIONS (matching checkpoint) ============
class Generator(nn.Module):
    def __init__(self, num_units=100, embed_dim=256):
        super().__init__()
        self.unit_embed = nn.Embedding(num_units, embed_dim)
        self.pre_conv = nn.Conv1d(embed_dim, 512, 7, padding=3)

        self.upsamples = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose1d(512, 256, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(256, 256, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(128, 128, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(32, 32, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(32, 64, 4, stride=2, padding=1, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
        ])

        self.post_conv = nn.Conv1d(64, 1, 7, padding=3)

    def forward(self, x):
        x = self.unit_embed(x).transpose(1, 2)
        x = self.pre_conv(x)
        x = F.leaky_relu(x, 0.1)

        for upsample in self.upsamples:
            x = upsample(x)

        x = self.post_conv(x)
        x = torch.tanh(x)
        return x.squeeze(1)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # 3 sub-discriminators with correct group sizes from checkpoint
        self.discriminators = nn.ModuleList([
            self._make_discriminator(),
            self._make_discriminator(),
            self._make_discriminator()
        ])
        self.pools = nn.ModuleList([
            nn.Identity(),
            nn.AvgPool1d(4, 2, padding=2),
            nn.AvgPool1d(16, 8, padding=8)
        ])

    def _make_discriminator(self):
        # Groups: 64->128 (groups=4, 64/4=16 in, 128/4=32 out per group)
        # But checkpoint shows: [128, 16, 41] meaning in_ch=16*groups
        # 64 in, 128 out, groups=4 means each group: 16 in -> 32 out
        # [256, 8, 41] means 8 input channels per group
        # 128 in, groups where 128/groups=8, so groups=16
        return nn.Sequential(
            nn.Conv1d(1, 64, 15, padding=7),
            nn.LeakyReLU(0.1),
            nn.Conv1d(64, 128, 41, stride=4, padding=20, groups=4),
            nn.LeakyReLU(0.1),
            nn.Conv1d(128, 256, 41, stride=4, padding=20, groups=16),  # Fixed: groups=16
            nn.LeakyReLU(0.1),
            nn.Conv1d(256, 512, 41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 512, 5, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 1, 3, padding=1)
        )

    def forward(self, x):
        outputs = []
        for pool, disc in zip(self.pools, self.discriminators):
            x_pooled = pool(x)
            outputs.append(disc(x_pooled))
        return outputs

# ============ DATASET ============
class UnitAudioDataset(Dataset):
    def __init__(self, file_list, training_dir, segment_length=16384):
        self.file_list = file_list
        self.training_dir = training_dir
        self.segment_length = segment_length

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_id = self.file_list[idx]

        with open(f"{self.training_dir}/units/{file_id}.txt", "r") as f:
            units = np.array([int(x) for x in f.read().strip().split()], dtype=np.int64)

        sr, audio = wav.read(f"{self.training_dir}/wavs/{file_id}.wav")
        audio = audio.astype(np.float32) / 32768.0

        unit_len = self.segment_length // 512
        if len(units) > unit_len:
            start = np.random.randint(0, len(units) - unit_len)
            units = units[start:start+unit_len]
            audio_start = start * 512
            audio = audio[audio_start:audio_start+self.segment_length]

        if len(audio) < self.segment_length:
            audio = np.pad(audio, (0, self.segment_length - len(audio)))
        if len(units) < unit_len:
            units = np.pad(units, (0, unit_len - len(units)))

        return torch.LongTensor(units), torch.FloatTensor(audio)

# ============ LOAD MODELS AND RESUME ============
generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.0002, betas=(0.8, 0.99))

# Load checkpoint
checkpoint_path = f"{MODEL_DIR}/checkpoint_epoch25.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
start_epoch = checkpoint['epoch'] + 1

print(f"\n✓ Resumed from epoch {checkpoint['epoch']}")
print(f"  Starting at epoch {start_epoch}")

# ============ DATALOADER ============
dataset = UnitAudioDataset(train_files, TRAINING_DIR)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)

val_dataset = UnitAudioDataset(val_files, TRAINING_DIR)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

print(f"✓ DataLoader ready: {len(dataloader)} batches")

# ============ TRAINING LOOP ============
NUM_EPOCHS = 50

print("\n" + "="*60)
print("RESUMING TRAINING")
print("="*60)

for epoch in range(start_epoch, NUM_EPOCHS + 1):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0
    epoch_d_loss = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

    for units, audio in pbar:
        units = units.to(device)
        audio = audio.to(device)

        audio_fake = generator(units)

        min_len = min(audio.shape[1], audio_fake.shape[1])
        audio = audio[:, :min_len]
        audio_fake = audio_fake[:, :min_len]

        # Discriminator
        optimizer_d.zero_grad()
        real_outs = discriminator(audio.unsqueeze(1))
        fake_outs = discriminator(audio_fake.detach().unsqueeze(1))

        d_loss = 0
        for real_out, fake_out in zip(real_outs, fake_outs):
            d_loss += torch.mean((real_out - 1)**2) + torch.mean(fake_out**2)
        d_loss.backward()
        optimizer_d.step()

        # Generator
        optimizer_g.zero_grad()
        fake_outs = discriminator(audio_fake.unsqueeze(1))

        g_loss_adv = 0
        for fake_out in fake_outs:
            g_loss_adv += torch.mean((fake_out - 1)**2)

        g_loss_l1 = F.l1_loss(audio_fake, audio) * 45
        g_loss = g_loss_adv + g_loss_l1
        g_loss.backward()
        optimizer_g.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()
        pbar.set_postfix({'G_loss': f'{g_loss.item():.3f}', 'D_loss': f'{d_loss.item():.3f}'})

    avg_g = epoch_g_loss / len(dataloader)
    avg_d = epoch_d_loss / len(dataloader)
    print(f"\n  Epoch {epoch} complete")
    print(f"  Avg G Loss: {avg_g:.4f}, Avg D Loss: {avg_d:.4f}")

    if epoch % 2 == 0:
        generator.eval()
        val_loss = 0
        with torch.no_grad():
            for units, audio in val_loader:
                units, audio = units.to(device), audio.to(device)
                fake = generator(units)
                min_len = min(audio.shape[1], fake.shape[1])
                val_loss += F.l1_loss(fake[:,:min_len], audio[:,:min_len]).item()
        val_loss /= len(val_loader)
        print(f"\n  Validation L1 Loss: {val_loss:.4f}")

    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }, f"{MODEL_DIR}/checkpoint_epoch{epoch}.pt")
        print(f"  ✓ Saved checkpoint")

    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_units = torch.LongTensor(dataset[0][0]).unsqueeze(0).to(device)
            sample_audio = generator(sample_units).squeeze().cpu().numpy()
            wav.write(f"{MODEL_DIR}/sample_epoch{epoch}.wav", 16000, (sample_audio * 32767).astype(np.int16))
            print(f"  ✓ Saved sample audio: {MODEL_DIR}/sample_epoch{epoch}.wav")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

Using device: cuda
Training files: 22834
Validation files: 100

✓ Resumed from epoch 24
  Starting at epoch 25
✓ DataLoader ready: 1428 batches

RESUMING TRAINING


Epoch 25/50:  54%|█████▍    | 768/1428 [2:07:24<1:49:29,  9.95s/it, G_loss=5.567, D_loss=0.089]


KeyboardInterrupt: 

In [None]:
!nvidia-smi

In [None]:
!nvidia-smi

Mon Jan 19 16:52:39 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   76C    P0             32W /   70W |    2188MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import scipy.io.wavfile as wav
from tqdm import tqdm

# Paths
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model"
TRAINING_DIR = f"{PROJECT_ROOT}/hifigan_training"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load file lists
with open(f"{TRAINING_DIR}/train_files.txt", "r") as f:
    train_files = [line.strip() for line in f.readlines()]

with open(f"{TRAINING_DIR}/val_files.txt", "r") as f:
    val_files = [line.strip() for line in f.readlines()]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# ============ MODEL DEFINITIONS ============
class Generator(nn.Module):
    def __init__(self, num_units=100, embed_dim=256):
        super().__init__()
        self.unit_embed = nn.Embedding(num_units, embed_dim)
        self.pre_conv = nn.Conv1d(embed_dim, 512, 7, padding=3)

        self.upsamples = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose1d(512, 256, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(256, 256, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(128, 128, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(32, 32, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(32, 64, 4, stride=2, padding=1, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
        ])

        self.post_conv = nn.Conv1d(64, 1, 7, padding=3)

    def forward(self, x):
        x = self.unit_embed(x).transpose(1, 2)
        x = self.pre_conv(x)
        x = F.leaky_relu(x, 0.1)

        for upsample in self.upsamples:
            x = upsample(x)

        x = self.post_conv(x)
        x = torch.tanh(x)
        return x.squeeze(1)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([
            self._make_discriminator(),
            self._make_discriminator(),
            self._make_discriminator()
        ])
        self.pools = nn.ModuleList([
            nn.Identity(),
            nn.AvgPool1d(4, 2, padding=2),
            nn.AvgPool1d(16, 8, padding=8)
        ])

    def _make_discriminator(self):
        return nn.Sequential(
            nn.Conv1d(1, 64, 15, padding=7),
            nn.LeakyReLU(0.1),
            nn.Conv1d(64, 128, 41, stride=4, padding=20, groups=4),
            nn.LeakyReLU(0.1),
            nn.Conv1d(128, 256, 41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(256, 512, 41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 512, 5, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 1, 3, padding=1)
        )

    def forward(self, x):
        outputs = []
        for pool, disc in zip(self.pools, self.discriminators):
            x_pooled = pool(x)
            outputs.append(disc(x_pooled))
        return outputs

# ============ DATASET WITH CACHING ============
class UnitAudioDataset(Dataset):
    def __init__(self, file_list, training_dir, segment_length=16384, cache_in_memory=False):
        self.file_list = file_list
        self.training_dir = training_dir
        self.segment_length = segment_length
        self.cache = {}
        self.cache_in_memory = cache_in_memory

        # Pre-load a subset to warm up
        if cache_in_memory:
            print("Pre-loading data into memory...")
            for i, file_id in enumerate(tqdm(file_list[:5000], desc="Caching")):
                self._load_and_cache(i, file_id)

    def _load_and_cache(self, idx, file_id):
        with open(f"{self.training_dir}/units/{file_id}.txt", "r") as f:
            units = np.array([int(x) for x in f.read().strip().split()], dtype=np.int64)
        sr, audio = wav.read(f"{self.training_dir}/wavs/{file_id}.wav")
        audio = audio.astype(np.float32) / 32768.0
        self.cache[idx] = (units, audio)
        return units, audio

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        if idx in self.cache:
            units, audio = self.cache[idx]
        else:
            file_id = self.file_list[idx]
            with open(f"{self.training_dir}/units/{file_id}.txt", "r") as f:
                units = np.array([int(x) for x in f.read().strip().split()], dtype=np.int64)
            sr, audio = wav.read(f"{self.training_dir}/wavs/{file_id}.wav")
            audio = audio.astype(np.float32) / 32768.0

        unit_len = self.segment_length // 512
        if len(units) > unit_len:
            start = np.random.randint(0, len(units) - unit_len)
            units = units[start:start+unit_len]
            audio_start = start * 512
            audio = audio[audio_start:audio_start+self.segment_length]

        if len(audio) < self.segment_length:
            audio = np.pad(audio, (0, self.segment_length - len(audio)))
        if len(units) < unit_len:
            units = np.pad(units, (0, unit_len - len(units)))

        return torch.LongTensor(units), torch.FloatTensor(audio)

# ============ LOAD MODELS AND RESUME ============
generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.0002, betas=(0.8, 0.99))

# Find best checkpoint
checkpoints = [f for f in os.listdir(MODEL_DIR) if f.startswith('checkpoint_epoch')]
if checkpoints:
    epochs = [int(f.split('epoch')[1].split('.')[0]) for f in checkpoints]
    best_epoch = max(epochs)
    checkpoint_path = f"{MODEL_DIR}/checkpoint_epoch{best_epoch}.pt"
    print(f"Found checkpoint: epoch {best_epoch}")
else:
    checkpoint_path = None
    print("No checkpoint found, starting fresh")

if checkpoint_path:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
    optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"✓ Resumed from epoch {checkpoint['epoch']}")
else:
    start_epoch = 1

print(f"  Starting at epoch {start_epoch}")

# ============ DATALOADER - OPTIMIZED ============
dataset = UnitAudioDataset(train_files, TRAINING_DIR, cache_in_memory=False)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,  # Avoid multiprocessing overhead with Drive
    pin_memory=True
)

val_dataset = UnitAudioDataset(val_files, TRAINING_DIR)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)

print(f"✓ DataLoader ready: {len(dataloader)} batches")

# Warm up GPU
print("Warming up GPU...")
dummy = torch.randn(16, 32).long().to(device)
_ = generator(dummy)
print("✓ GPU ready")

# ============ TRAINING LOOP ============
NUM_EPOCHS = 50

print("\n" + "="*60)
print("RESUMING TRAINING")
print("="*60)

for epoch in range(start_epoch, NUM_EPOCHS + 1):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0
    epoch_d_loss = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

    for units, audio in pbar:
        units = units.to(device, non_blocking=True)
        audio = audio.to(device, non_blocking=True)

        audio_fake = generator(units)

        min_len = min(audio.shape[1], audio_fake.shape[1])
        audio = audio[:, :min_len]
        audio_fake = audio_fake[:, :min_len]

        # Discriminator
        optimizer_d.zero_grad()
        real_outs = discriminator(audio.unsqueeze(1))
        fake_outs = discriminator(audio_fake.detach().unsqueeze(1))

        d_loss = 0
        for real_out, fake_out in zip(real_outs, fake_outs):
            d_loss += torch.mean((real_out - 1)**2) + torch.mean(fake_out**2)
        d_loss.backward()
        optimizer_d.step()

        # Generator
        optimizer_g.zero_grad()
        fake_outs = discriminator(audio_fake.unsqueeze(1))

        g_loss_adv = 0
        for fake_out in fake_outs:
            g_loss_adv += torch.mean((fake_out - 1)**2)

        g_loss_l1 = F.l1_loss(audio_fake, audio) * 45
        g_loss = g_loss_adv + g_loss_l1
        g_loss.backward()
        optimizer_g.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()
        pbar.set_postfix({'G_loss': f'{g_loss.item():.3f}', 'D_loss': f'{d_loss.item():.3f}'})

    avg_g = epoch_g_loss / len(dataloader)
    avg_d = epoch_d_loss / len(dataloader)
    print(f"\n  Epoch {epoch} complete")
    print(f"  Avg G Loss: {avg_g:.4f}, Avg D Loss: {avg_d:.4f}")

    if epoch % 2 == 0:
        generator.eval()
        val_loss = 0
        with torch.no_grad():
            for units, audio in val_loader:
                units, audio = units.to(device), audio.to(device)
                fake = generator(units)
                min_len = min(audio.shape[1], fake.shape[1])
                val_loss += F.l1_loss(fake[:,:min_len], audio[:,:min_len]).item()
        val_loss /= len(val_loader)
        print(f"\n  Validation L1 Loss: {val_loss:.4f}")

    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }, f"{MODEL_DIR}/checkpoint_epoch{epoch}.pt")
        print(f"  ✓ Saved checkpoint")

    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_units = torch.LongTensor(dataset[0][0]).unsqueeze(0).to(device)
            sample_audio = generator(sample_units).squeeze().cpu().numpy()
            wav.write(f"{MODEL_DIR}/sample_epoch{epoch}.wav", 16000, (sample_audio * 32767).astype(np.int16))
            print(f"  ✓ Saved sample audio: {MODEL_DIR}/sample_epoch{epoch}.wav")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

Using device: cuda
Training files: 22834
Validation files: 100
Found checkpoint: epoch 25
✓ Resumed from epoch 24
  Starting at epoch 25
✓ DataLoader ready: 1428 batches
Warming up GPU...
✓ GPU ready

RESUMING TRAINING


Epoch 25/50:   0%|          | 0/1428 [00:10<?, ?it/s]


AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import scipy.io.wavfile as wav
from tqdm import tqdm

# Clear any previous CUDA errors
torch.cuda.empty_cache()

# Paths
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model"
TRAINING_DIR = f"{PROJECT_ROOT}/hifigan_training"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load file lists
with open(f"{TRAINING_DIR}/train_files.txt", "r") as f:
    train_files = [line.strip() for line in f.readlines()]

with open(f"{TRAINING_DIR}/val_files.txt", "r") as f:
    val_files = [line.strip() for line in f.readlines()]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# ============ MODEL DEFINITIONS ============
class Generator(nn.Module):
    def __init__(self, num_units=100, embed_dim=256):
        super().__init__()
        self.unit_embed = nn.Embedding(num_units, embed_dim)
        self.pre_conv = nn.Conv1d(embed_dim, 512, 7, padding=3)

        self.upsamples = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose1d(512, 256, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(256, 256, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(128, 128, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(32, 32, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(32, 64, 4, stride=2, padding=1, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
        ])

        self.post_conv = nn.Conv1d(64, 1, 7, padding=3)

    def forward(self, x):
        x = self.unit_embed(x).transpose(1, 2)
        x = self.pre_conv(x)
        x = F.leaky_relu(x, 0.1)

        for upsample in self.upsamples:
            x = upsample(x)

        x = self.post_conv(x)
        x = torch.tanh(x)
        return x.squeeze(1)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([
            self._make_discriminator(),
            self._make_discriminator(),
            self._make_discriminator()
        ])
        self.pools = nn.ModuleList([
            nn.Identity(),
            nn.AvgPool1d(4, 2, padding=2),
            nn.AvgPool1d(16, 8, padding=8)
        ])

    def _make_discriminator(self):
        return nn.Sequential(
            nn.Conv1d(1, 64, 15, padding=7),
            nn.LeakyReLU(0.1),
            nn.Conv1d(64, 128, 41, stride=4, padding=20, groups=4),
            nn.LeakyReLU(0.1),
            nn.Conv1d(128, 256, 41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(256, 512, 41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 512, 5, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 1, 3, padding=1)
        )

    def forward(self, x):
        outputs = []
        for pool, disc in zip(self.pools, self.discriminators):
            x_pooled = pool(x)
            outputs.append(disc(x_pooled))
        return outputs

# ============ DATASET ============
class UnitAudioDataset(Dataset):
    def __init__(self, file_list, training_dir, segment_length=16384):
        self.file_list = file_list
        self.training_dir = training_dir
        self.segment_length = segment_length

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_id = self.file_list[idx]

        with open(f"{self.training_dir}/units/{file_id}.txt", "r") as f:
            units = np.array([int(x) for x in f.read().strip().split()], dtype=np.int64)

        sr, audio = wav.read(f"{self.training_dir}/wavs/{file_id}.wav")
        audio = audio.astype(np.float32) / 32768.0

        unit_len = self.segment_length // 512
        if len(units) > unit_len:
            start = np.random.randint(0, len(units) - unit_len)
            units = units[start:start+unit_len]
            audio_start = start * 512
            audio = audio[audio_start:audio_start+self.segment_length]

        if len(audio) < self.segment_length:
            audio = np.pad(audio, (0, self.segment_length - len(audio)))
        if len(units) < unit_len:
            units = np.pad(units, (0, unit_len - len(units)))

        return torch.LongTensor(units), torch.FloatTensor(audio)

# ============ LOAD MODELS AND RESUME ============
generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.0002, betas=(0.8, 0.99))

# Find best checkpoint
checkpoints = [f for f in os.listdir(MODEL_DIR) if f.startswith('checkpoint_epoch')]
if checkpoints:
    epochs = [int(f.split('epoch')[1].split('.')[0]) for f in checkpoints]
    best_epoch = max(epochs)
    checkpoint_path = f"{MODEL_DIR}/checkpoint_epoch{best_epoch}.pt"
    print(f"Found checkpoint: epoch {best_epoch}")
else:
    checkpoint_path = None
    print("No checkpoint found")

if checkpoint_path:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
    optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"✓ Resumed from epoch {checkpoint['epoch']}")
else:
    start_epoch = 1

print(f"  Starting at epoch {start_epoch}")

# ============ DATALOADER ============
dataset = UnitAudioDataset(train_files, TRAINING_DIR)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

val_dataset = UnitAudioDataset(val_files, TRAINING_DIR)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)

print(f"✓ DataLoader ready: {len(dataloader)} batches")

# ============ TRAINING LOOP ============
NUM_EPOCHS = 50

print("\n" + "="*60)
print("RESUMING TRAINING")
print("="*60)

for epoch in range(start_epoch, NUM_EPOCHS + 1):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0
    epoch_d_loss = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

    for units, audio in pbar:
        units = units.to(device)
        audio = audio.to(device)

        audio_fake = generator(units)

        min_len = min(audio.shape[1], audio_fake.shape[1])
        audio = audio[:, :min_len]
        audio_fake = audio_fake[:, :min_len]

        # Discriminator
        optimizer_d.zero_grad()
        real_outs = discriminator(audio.unsqueeze(1))
        fake_outs = discriminator(audio_fake.detach().unsqueeze(1))

        d_loss = 0
        for real_out, fake_out in zip(real_outs, fake_outs):
            d_loss += torch.mean((real_out - 1)**2) + torch.mean(fake_out**2)
        d_loss.backward()
        optimizer_d.step()

        # Generator
        optimizer_g.zero_grad()
        fake_outs = discriminator(audio_fake.unsqueeze(1))

        g_loss_adv = 0
        for fake_out in fake_outs:
            g_loss_adv += torch.mean((fake_out - 1)**2)

        g_loss_l1 = F.l1_loss(audio_fake, audio) * 45
        g_loss = g_loss_adv + g_loss_l1
        g_loss.backward()
        optimizer_g.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()
        pbar.set_postfix({'G_loss': f'{g_loss.item():.3f}', 'D_loss': f'{d_loss.item():.3f}'})

    avg_g = epoch_g_loss / len(dataloader)
    avg_d = epoch_d_loss / len(dataloader)
    print(f"\n  Epoch {epoch} complete")
    print(f"  Avg G Loss: {avg_g:.4f}, Avg D Loss: {avg_d:.4f}")

    if epoch % 2 == 0:
        generator.eval()
        val_loss = 0
        with torch.no_grad():
            for units, audio in val_loader:
                units, audio = units.to(device), audio.to(device)
                fake = generator(units)
                min_len = min(audio.shape[1], fake.shape[1])
                val_loss += F.l1_loss(fake[:,:min_len], audio[:,:min_len]).item()
        val_loss /= len(val_loader)
        print(f"\n  Validation L1 Loss: {val_loss:.4f}")

    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }, f"{MODEL_DIR}/checkpoint_epoch{epoch}.pt")
        print(f"  ✓ Saved checkpoint")

    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_units = torch.LongTensor(dataset[0][0]).unsqueeze(0).to(device)
            sample_audio = generator(sample_units).squeeze().cpu().numpy()
            wav.write(f"{MODEL_DIR}/sample_epoch{epoch}.wav", 16000, (sample_audio * 32767).astype(np.int16))
            print(f"  ✓ Saved sample audio: {MODEL_DIR}/sample_epoch{epoch}.wav")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

Using device: cuda
Training files: 22834
Validation files: 100
Found checkpoint: epoch 25
✓ Resumed from epoch 24
  Starting at epoch 25
✓ DataLoader ready: 1428 batches

RESUMING TRAINING


Epoch 25/50:  23%|██▎       | 328/1428 [50:08<2:48:09,  9.17s/it, G_loss=4.624, D_loss=0.063]


KeyboardInterrupt: 

In [None]:
import shutil
import os

# Copy training data to local Colab storage (much faster)
LOCAL_DIR = "/content/hifigan_local"
DRIVE_DIR = "/content/drive/MyDrive/satere_project/hifigan_training"

if not os.path.exists(LOCAL_DIR):
    print("Copying training data to local storage...")
    print("This will take 5-10 minutes but training will be MUCH faster.")
    shutil.copytree(DRIVE_DIR, LOCAL_DIR)
    print(f"✓ Copied to {LOCAL_DIR}")
else:
    print(f"✓ Local data already exists at {LOCAL_DIR}")

print(f"\nLocal files: {len(os.listdir(f'{LOCAL_DIR}/units'))} unit files")

Copying training data to local storage...
This will take 5-10 minutes but training will be MUCH faster.


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/satere_project/hifigan_training'

In [None]:
from google.colab import drive
drive.mount('/content/drive')

ValueError: Mountpoint must not already contain files

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import scipy.io.wavfile as wav
from tqdm import tqdm

# Paths - LOCAL storage for training data, DRIVE for saving checkpoints
LOCAL_DIR = "/content/hifigan_local"  # Fast local storage
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model"  # Save checkpoints to Drive

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load file lists from local
with open(f"{LOCAL_DIR}/train_files.txt", "r") as f:
    train_files = [line.strip() for line in f.readlines()]

with open(f"{LOCAL_DIR}/val_files.txt", "r") as f:
    val_files = [line.strip() for line in f.readlines()]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# ============ MODEL DEFINITIONS ============
class Generator(nn.Module):
    def __init__(self, num_units=100, embed_dim=256):
        super().__init__()
        self.unit_embed = nn.Embedding(num_units, embed_dim)
        self.pre_conv = nn.Conv1d(embed_dim, 512, 7, padding=3)

        self.upsamples = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose1d(512, 256, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(256, 256, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, 10, stride=5, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(128, 128, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, 8, stride=4, padding=2, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(32, 32, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(32, 64, 4, stride=2, padding=1, output_padding=0),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
        ])

        self.post_conv = nn.Conv1d(64, 1, 7, padding=3)

    def forward(self, x):
        x = self.unit_embed(x).transpose(1, 2)
        x = self.pre_conv(x)
        x = F.leaky_relu(x, 0.1)

        for upsample in self.upsamples:
            x = upsample(x)

        x = self.post_conv(x)
        x = torch.tanh(x)
        return x.squeeze(1)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([
            self._make_discriminator(),
            self._make_discriminator(),
            self._make_discriminator()
        ])
        self.pools = nn.ModuleList([
            nn.Identity(),
            nn.AvgPool1d(4, 2, padding=2),
            nn.AvgPool1d(16, 8, padding=8)
        ])

    def _make_discriminator(self):
        return nn.Sequential(
            nn.Conv1d(1, 64, 15, padding=7),
            nn.LeakyReLU(0.1),
            nn.Conv1d(64, 128, 41, stride=4, padding=20, groups=4),
            nn.LeakyReLU(0.1),
            nn.Conv1d(128, 256, 41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(256, 512, 41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 512, 5, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 1, 3, padding=1)
        )

    def forward(self, x):
        outputs = []
        for pool, disc in zip(self.pools, self.discriminators):
            x_pooled = pool(x)
            outputs.append(disc(x_pooled))
        return outputs

# ============ DATASET ============
class UnitAudioDataset(Dataset):
    def __init__(self, file_list, training_dir, segment_length=16384):
        self.file_list = file_list
        self.training_dir = training_dir
        self.segment_length = segment_length

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_id = self.file_list[idx]

        with open(f"{self.training_dir}/units/{file_id}.txt", "r") as f:
            units = np.array([int(x) for x in f.read().strip().split()], dtype=np.int64)

        sr, audio = wav.read(f"{self.training_dir}/wavs/{file_id}.wav")
        audio = audio.astype(np.float32) / 32768.0

        unit_len = self.segment_length // 512
        if len(units) > unit_len:
            start = np.random.randint(0, len(units) - unit_len)
            units = units[start:start+unit_len]
            audio_start = start * 512
            audio = audio[audio_start:audio_start+self.segment_length]

        if len(audio) < self.segment_length:
            audio = np.pad(audio, (0, self.segment_length - len(audio)))
        if len(units) < unit_len:
            units = np.pad(units, (0, unit_len - len(units)))

        return torch.LongTensor(units), torch.FloatTensor(audio)

# ============ LOAD MODELS AND RESUME ============
generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.0002, betas=(0.8, 0.99))

# Find best checkpoint from Drive
checkpoints = [f for f in os.listdir(MODEL_DIR) if f.startswith('checkpoint_epoch')]
if checkpoints:
    epochs = [int(f.split('epoch')[1].split('.')[0]) for f in checkpoints]
    best_epoch = max(epochs)
    checkpoint_path = f"{MODEL_DIR}/checkpoint_epoch{best_epoch}.pt"
    print(f"Found checkpoint: epoch {best_epoch}")
else:
    checkpoint_path = None
    print("No checkpoint found")

if checkpoint_path:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
    optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"✓ Resumed from epoch {checkpoint['epoch']}")
else:
    start_epoch = 1

print(f"  Starting at epoch {start_epoch}")

# ============ DATALOADER - using LOCAL storage ============
dataset = UnitAudioDataset(train_files, LOCAL_DIR)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_dataset = UnitAudioDataset(val_files, LOCAL_DIR)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

print(f"✓ DataLoader ready: {len(dataloader)} batches")

# ============ TRAINING LOOP ============
NUM_EPOCHS = 50

print("\n" + "="*60)
print("RESUMING TRAINING (using local storage)")
print("="*60)

for epoch in range(start_epoch, NUM_EPOCHS + 1):
    generator.train()
    discriminator.train()

    epoch_g_loss = 0
    epoch_d_loss = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

    for units, audio in pbar:
        units = units.to(device)
        audio = audio.to(device)

        audio_fake = generator(units)

        min_len = min(audio.shape[1], audio_fake.shape[1])
        audio = audio[:, :min_len]
        audio_fake = audio_fake[:, :min_len]

        # Discriminator
        optimizer_d.zero_grad()
        real_outs = discriminator(audio.unsqueeze(1))
        fake_outs = discriminator(audio_fake.detach().unsqueeze(1))

        d_loss = 0
        for real_out, fake_out in zip(real_outs, fake_outs):
            d_loss += torch.mean((real_out - 1)**2) + torch.mean(fake_out**2)
        d_loss.backward()
        optimizer_d.step()

        # Generator
        optimizer_g.zero_grad()
        fake_outs = discriminator(audio_fake.unsqueeze(1))

        g_loss_adv = 0
        for fake_out in fake_outs:
            g_loss_adv += torch.mean((fake_out - 1)**2)

        g_loss_l1 = F.l1_loss(audio_fake, audio) * 45
        g_loss = g_loss_adv + g_loss_l1
        g_loss.backward()
        optimizer_g.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()
        pbar.set_postfix({'G_loss': f'{g_loss.item():.3f}', 'D_loss': f'{d_loss.item():.3f}'})

    avg_g = epoch_g_loss / len(dataloader)
    avg_d = epoch_d_loss / len(dataloader)
    print(f"\n  Epoch {epoch} complete")
    print(f"  Avg G Loss: {avg_g:.4f}, Avg D Loss: {avg_d:.4f}")

    if epoch % 2 == 0:
        generator.eval()
        val_loss = 0
        with torch.no_grad():
            for units, audio in val_loader:
                units, audio = units.to(device), audio.to(device)
                fake = generator(units)
                min_len = min(audio.shape[1], fake.shape[1])
                val_loss += F.l1_loss(fake[:,:min_len], audio[:,:min_len]).item()
        val_loss /= len(val_loader)
        print(f"\n  Validation L1 Loss: {val_loss:.4f}")

    # Save checkpoint to DRIVE every 5 epochs
    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }, f"{MODEL_DIR}/checkpoint_epoch{epoch}.pt")
        print(f"  ✓ Saved checkpoint to Drive")

    # Save sample to DRIVE every 10 epochs
    if epoch % 10 == 0:
        generator.eval()
        with torch.no_grad():
            sample_units = torch.LongTensor(dataset[0][0]).unsqueeze(0).to(device)
            sample_audio = generator(sample_units).squeeze().cpu().numpy()
            wav.write(f"{MODEL_DIR}/sample_epoch{epoch}.wav", 16000, (sample_audio * 32767).astype(np.int16))
            print(f"  ✓ Saved sample audio to Drive")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

Using device: cuda
Training files: 22834
Validation files: 100
Found checkpoint: epoch 25
✓ Resumed from epoch 24
  Starting at epoch 25
✓ DataLoader ready: 1428 batches

RESUMING TRAINING (using local storage)


Epoch 25/50: 100%|██████████| 1428/1428 [07:37<00:00,  3.12it/s, G_loss=7.787, D_loss=0.012]



  Epoch 25 complete
  Avg G Loss: 5.5681, Avg D Loss: 0.1430
  ✓ Saved checkpoint to Drive


Epoch 26/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=7.085, D_loss=0.036]


  Epoch 26 complete
  Avg G Loss: 5.6244, Avg D Loss: 0.0634






  Validation L1 Loss: 0.0600


Epoch 27/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.12it/s, G_loss=8.098, D_loss=0.019]



  Epoch 27 complete
  Avg G Loss: 5.6361, Avg D Loss: 0.0563


Epoch 28/50: 100%|██████████| 1428/1428 [07:37<00:00,  3.12it/s, G_loss=5.139, D_loss=0.059]


  Epoch 28 complete
  Avg G Loss: 5.5948, Avg D Loss: 0.0505






  Validation L1 Loss: 0.0602


Epoch 29/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=4.438, D_loss=0.020]



  Epoch 29 complete
  Avg G Loss: 5.5855, Avg D Loss: 0.0866


Epoch 30/50: 100%|██████████| 1428/1428 [07:35<00:00,  3.13it/s, G_loss=6.537, D_loss=0.011]


  Epoch 30 complete
  Avg G Loss: 5.6172, Avg D Loss: 0.0471






  Validation L1 Loss: 0.0735
  ✓ Saved checkpoint to Drive
  ✓ Saved sample audio to Drive


Epoch 31/50: 100%|██████████| 1428/1428 [07:35<00:00,  3.14it/s, G_loss=5.447, D_loss=0.015]



  Epoch 31 complete
  Avg G Loss: 5.6211, Avg D Loss: 0.0393


Epoch 32/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=8.033, D_loss=0.027]


  Epoch 32 complete
  Avg G Loss: 5.6402, Avg D Loss: 0.0363






  Validation L1 Loss: 0.0535


Epoch 33/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=7.720, D_loss=0.004]



  Epoch 33 complete
  Avg G Loss: 5.6443, Avg D Loss: 0.0238


Epoch 34/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=2.747, D_loss=0.143]


  Epoch 34 complete
  Avg G Loss: 5.6419, Avg D Loss: 0.0231






  Validation L1 Loss: 0.0576


Epoch 35/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=4.742, D_loss=0.003]



  Epoch 35 complete
  Avg G Loss: 5.6765, Avg D Loss: 0.0233
  ✓ Saved checkpoint to Drive


Epoch 36/50: 100%|██████████| 1428/1428 [07:37<00:00,  3.12it/s, G_loss=6.640, D_loss=0.001]


  Epoch 36 complete
  Avg G Loss: 5.6775, Avg D Loss: 0.0222






  Validation L1 Loss: 0.0672


Epoch 37/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=5.126, D_loss=0.001]



  Epoch 37 complete
  Avg G Loss: 5.6520, Avg D Loss: 0.0206


Epoch 38/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=8.141, D_loss=0.009]


  Epoch 38 complete
  Avg G Loss: 5.6752, Avg D Loss: 0.0188






  Validation L1 Loss: 0.0533


Epoch 39/50: 100%|██████████| 1428/1428 [07:37<00:00,  3.12it/s, G_loss=6.552, D_loss=0.006]



  Epoch 39 complete
  Avg G Loss: 5.6421, Avg D Loss: 0.0227


Epoch 40/50: 100%|██████████| 1428/1428 [07:37<00:00,  3.12it/s, G_loss=6.919, D_loss=0.066]


  Epoch 40 complete
  Avg G Loss: 5.6288, Avg D Loss: 0.0295






  Validation L1 Loss: 0.0684
  ✓ Saved checkpoint to Drive
  ✓ Saved sample audio to Drive


Epoch 41/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.12it/s, G_loss=7.316, D_loss=0.022]



  Epoch 41 complete
  Avg G Loss: 5.6646, Avg D Loss: 0.0192


Epoch 42/50: 100%|██████████| 1428/1428 [07:37<00:00,  3.12it/s, G_loss=5.362, D_loss=0.001]


  Epoch 42 complete
  Avg G Loss: 5.6936, Avg D Loss: 0.0141






  Validation L1 Loss: 0.0470


Epoch 43/50: 100%|██████████| 1428/1428 [07:37<00:00,  3.12it/s, G_loss=6.630, D_loss=0.006]



  Epoch 43 complete
  Avg G Loss: 5.6573, Avg D Loss: 0.0132


Epoch 44/50: 100%|██████████| 1428/1428 [07:37<00:00,  3.12it/s, G_loss=6.683, D_loss=0.017]


  Epoch 44 complete
  Avg G Loss: 5.6638, Avg D Loss: 0.0203






  Validation L1 Loss: 0.0626


Epoch 45/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=4.240, D_loss=0.002]



  Epoch 45 complete
  Avg G Loss: 5.6578, Avg D Loss: 0.0158
  ✓ Saved checkpoint to Drive


Epoch 46/50: 100%|██████████| 1428/1428 [07:37<00:00,  3.12it/s, G_loss=5.284, D_loss=0.001]


  Epoch 46 complete
  Avg G Loss: 5.6778, Avg D Loss: 0.0174






  Validation L1 Loss: 0.0571


Epoch 47/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=2.965, D_loss=0.032]



  Epoch 47 complete
  Avg G Loss: 5.6775, Avg D Loss: 0.0178


Epoch 48/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=7.646, D_loss=0.021]


  Epoch 48 complete
  Avg G Loss: 5.6587, Avg D Loss: 0.0151






  Validation L1 Loss: 0.0557


Epoch 49/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.12it/s, G_loss=4.180, D_loss=0.003]



  Epoch 49 complete
  Avg G Loss: 5.6572, Avg D Loss: 0.0159


Epoch 50/50: 100%|██████████| 1428/1428 [07:36<00:00,  3.13it/s, G_loss=7.558, D_loss=0.001]


  Epoch 50 complete
  Avg G Loss: 5.6738, Avg D Loss: 0.0163






  Validation L1 Loss: 0.0649
  ✓ Saved checkpoint to Drive
  ✓ Saved sample audio to Drive

TRAINING COMPLETE


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import scipy.io.wavfile as wav
from tqdm import tqdm

# Clear GPU memory
torch.cuda.empty_cache()

# Paths
LOCAL_DIR = "/content/hifigan_local"
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model_v2"  # New folder for v2
os.makedirs(MODEL_DIR, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load file lists
with open(f"{LOCAL_DIR}/train_files.txt", "r") as f:
    train_files = [line.strip() for line in f.readlines()]

with open(f"{LOCAL_DIR}/val_files.txt", "r") as f:
    val_files = [line.strip() for line in f.readlines()]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# ============ DATASET ============
class UnitAudioDataset(Dataset):
    def __init__(self, data_dir, file_list):
        self.data_dir = data_dir
        self.files = file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]

        # Load audio
        wav_path = f"{self.data_dir}/wavs/{name}.wav"
        sr, audio = wav.read(wav_path)
        audio = audio.astype(np.float32) / 32768.0

        # Load units
        unit_path = f"{self.data_dir}/units/{name}.txt"
        with open(unit_path, 'r') as f:
            units = [int(u) for u in f.read().strip().split()]

        return torch.FloatTensor(audio), torch.LongTensor(units)

# ============ GENERATOR (same architecture) ============
class Generator(nn.Module):
    def __init__(self, num_units=100, embed_dim=256):
        super().__init__()
        self.unit_embed = nn.Embedding(num_units, embed_dim)
        self.pre_conv = nn.Conv1d(embed_dim, 512, 7, padding=3)

        self.upsamples = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose1d(512, 256, 10, stride=5, padding=2),
                nn.LeakyReLU(0.1),
                nn.Conv1d(256, 256, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, 10, stride=5, padding=2),
                nn.LeakyReLU(0.1),
                nn.Conv1d(128, 128, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, 8, stride=4, padding=2),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, 8, stride=4, padding=2),
                nn.LeakyReLU(0.1),
                nn.Conv1d(32, 32, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(32, 64, 4, stride=2, padding=1),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
        ])

        self.post_conv = nn.Conv1d(64, 1, 7, padding=3)

    def forward(self, x):
        x = self.unit_embed(x).transpose(1, 2)
        x = self.pre_conv(x)
        x = F.leaky_relu(x, 0.1)

        for upsample in self.upsamples:
            x = upsample(x)

        x = self.post_conv(x)
        x = torch.tanh(x)
        return x.squeeze(1)

# ============ DISCRIMINATOR (simplified) ============
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(1, 16, 15, stride=1, padding=7),
            nn.Conv1d(16, 64, 41, stride=4, padding=20, groups=4),
            nn.Conv1d(64, 256, 41, stride=4, padding=20, groups=16),
            nn.Conv1d(256, 512, 41, stride=4, padding=20, groups=64),
            nn.Conv1d(512, 512, 41, stride=4, padding=20, groups=64),
            nn.Conv1d(512, 512, 5, stride=1, padding=2),
        ])
        self.conv_post = nn.Conv1d(512, 1, 3, padding=1)

    def forward(self, x):
        x = x.unsqueeze(1)
        fmaps = []
        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, 0.1)
            fmaps.append(x)
        x = self.conv_post(x)
        return x, fmaps

# ============ TRAINING SETUP ============
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# KEY FIX 1: Lower discriminator learning rate
optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.00005, betas=(0.8, 0.99))  # 4x lower

# Learning rate schedulers
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999)

# Dataset and DataLoader
train_dataset = UnitAudioDataset(LOCAL_DIR, train_files)
val_dataset = UnitAudioDataset(LOCAL_DIR, val_files)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

print(f"✓ DataLoader ready: {len(train_loader)} batches")

# ============ LOSS FUNCTIONS ============
def feature_loss(real_fmaps, fake_fmaps):
    loss = 0
    for real_fm, fake_fm in zip(real_fmaps, fake_fmaps):
        loss += F.l1_loss(fake_fm, real_fm.detach())
    return loss

def discriminator_loss(real_out, fake_out):
    real_loss = torch.mean((1 - real_out) ** 2)
    fake_loss = torch.mean(fake_out ** 2)
    return real_loss + fake_loss

def generator_adv_loss(fake_out):
    return torch.mean((1 - fake_out) ** 2)

# ============ TRAINING LOOP ============
NUM_EPOCHS = 100  # More epochs
SAVE_EVERY = 10
SAMPLE_EVERY = 10

# KEY FIX 2: L1 loss weight (high at start, decreases over time)
L1_WEIGHT_START = 45.0
L1_WEIGHT_END = 10.0
FM_WEIGHT = 2.0

# KEY FIX 3: Discriminator warmup - only L1 loss for first N epochs
D_WARMUP_EPOCHS = 10

print("\n" + "="*60)
print("STARTING TRAINING (v2 - Fixed Hyperparameters)")
print("="*60)
print(f"  Discriminator warmup: {D_WARMUP_EPOCHS} epochs")
print(f"  L1 weight: {L1_WEIGHT_START} -> {L1_WEIGHT_END}")
print(f"  D learning rate: 0.00005 (4x lower than G)")
print("="*60 + "\n")

for epoch in range(1, NUM_EPOCHS + 1):
    generator.train()
    discriminator.train()

    # Compute L1 weight for this epoch (linear decay)
    l1_weight = L1_WEIGHT_START - (L1_WEIGHT_START - L1_WEIGHT_END) * (epoch - 1) / (NUM_EPOCHS - 1)

    epoch_g_loss = 0
    epoch_d_loss = 0
    epoch_l1_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

    for batch_idx, (real_audio, units) in enumerate(pbar):
        real_audio = real_audio.to(device)
        units = units.to(device)

        # Generate fake audio
        fake_audio = generator(units)

        # Match lengths
        min_len = min(real_audio.shape[1], fake_audio.shape[1])
        real_audio = real_audio[:, :min_len]
        fake_audio = fake_audio[:, :min_len]

        # L1 loss (always computed)
        l1_loss = F.l1_loss(fake_audio, real_audio)

        # ============ DISCRIMINATOR UPDATE ============
        if epoch > D_WARMUP_EPOCHS:
            optimizer_d.zero_grad()

            real_out, real_fmaps = discriminator(real_audio)
            fake_out, _ = discriminator(fake_audio.detach())

            d_loss = discriminator_loss(real_out, fake_out)
            d_loss.backward()
            optimizer_d.step()

            epoch_d_loss += d_loss.item()
        else:
            d_loss = torch.tensor(0.0)

        # ============ GENERATOR UPDATE ============
        optimizer_g.zero_grad()

        if epoch > D_WARMUP_EPOCHS:
            fake_out, fake_fmaps = discriminator(fake_audio)
            _, real_fmaps = discriminator(real_audio)

            g_adv_loss = generator_adv_loss(fake_out)
            fm_loss = feature_loss(real_fmaps, fake_fmaps)

            g_loss = g_adv_loss + FM_WEIGHT * fm_loss + l1_weight * l1_loss
        else:
            # Warmup: only L1 loss
            g_loss = l1_weight * l1_loss

        g_loss.backward()
        optimizer_g.step()

        epoch_g_loss += g_loss.item()
        epoch_l1_loss += l1_loss.item()

        pbar.set_postfix({
            'G': f'{g_loss.item():.3f}',
            'D': f'{d_loss.item():.3f}',
            'L1': f'{l1_loss.item():.4f}'
        })

    # Update learning rates
    scheduler_g.step()
    scheduler_d.step()

    # Epoch stats
    avg_g = epoch_g_loss / len(train_loader)
    avg_d = epoch_d_loss / len(train_loader) if epoch > D_WARMUP_EPOCHS else 0
    avg_l1 = epoch_l1_loss / len(train_loader)

    print(f"\n  Epoch {epoch}: G={avg_g:.4f}, D={avg_d:.4f}, L1={avg_l1:.4f}, L1_weight={l1_weight:.1f}")

    # Validation
    if epoch % 2 == 0:
        generator.eval()
        val_l1 = 0
        with torch.no_grad():
            for real_audio, units in val_loader:
                real_audio = real_audio.to(device)
                units = units.to(device)
                fake_audio = generator(units)
                min_len = min(real_audio.shape[1], fake_audio.shape[1])
                val_l1 += F.l1_loss(fake_audio[:, :min_len], real_audio[:, :min_len]).item()
        val_l1 /= len(val_loader)
        print(f"  Validation L1: {val_l1:.4f}")

    # Save checkpoint
    if epoch % SAVE_EVERY == 0:
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }
        torch.save(checkpoint, f"{MODEL_DIR}/checkpoint_epoch{epoch}.pt")
        print(f"  ✓ Saved checkpoint")

    # Save sample audio
    if epoch % SAMPLE_EVERY == 0:
        generator.eval()
        with torch.no_grad():
            sample_audio, sample_units = val_dataset[0]
            sample_units = sample_units.unsqueeze(0).to(device)
            generated = generator(sample_units).squeeze().cpu().numpy()
            generated = (generated * 32767).astype(np.int16)
            wav.write(f"{MODEL_DIR}/sample_epoch{epoch}.wav", 16000, generated)
        print(f"  ✓ Saved sample audio")

print("\n" + "="*60)
print("TRAINING COMPLETE (v2)")
print("="*60)

Using device: cuda


FileNotFoundError: [Errno 2] No such file or directory: '/content/hifigan_local/train_files.txt'

In [None]:
import os

# Check if drive is mounted
if os.path.exists('/content/drive/MyDrive'):
    print("Drive is mounted")

    # Check satere_project folder
    satere_path = '/content/drive/MyDrive/satere_project'
    if os.path.exists(satere_path):
        print(f"\nContents of {satere_path}:")
        for item in os.listdir(satere_path):
            print(f"  {item}")
    else:
        print(f"\n{satere_path} does not exist")
        print("\nContents of MyDrive:")
        for item in os.listdir('/content/drive/MyDrive')[:20]:
            print(f"  {item}")
else:
    print("Drive is NOT mounted. Run: drive.mount('/content/drive')")

Drive is mounted

Contents of /content/drive/MyDrive/satere_project:
  hifigan_model_v2


In [None]:
import os

# Search for hifigan_training
for root, dirs, files in os.walk('/content/drive/MyDrive'):
    if 'hifigan_training' in dirs:
        full_path = os.path.join(root, 'hifigan_training')
        print(f"Found: {full_path}")
        break

In [None]:
import os

# List everything in satere_project
path = '/content/drive/MyDrive/satere_project'
print("Contents of satere_project:")
for item in os.listdir(path):
    print(f"  {item}")

Contents of satere_project:
  hifigan_model_v2


In [None]:
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/drive')

Drive not mounted, so nothing to flush and unmount.


ValueError: Mountpoint must not already contain files

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

ValueError: Mountpoint must not already contain files

In [None]:
from google.colab import drive
drive.mount('/content/drive')

ValueError: Mountpoint must not already contain files

In [None]:
import os
import shutil

# Remove the existing mount point
if os.path.exists('/content/drive'):
    shutil.rmtree('/content/drive')

# Now mount
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
print(os.listdir('/content/drive/MyDrive/satere_project'))

['raw_audio', 'converted_audio', 'satere_units', 'satere_motifs', 'acts12.mp3', 'phase2_output', 'soundscript_full', 'soundscript_deploy', 'soundscript_deploy.zip', 'vocoder_output', 'hifigan_training', 'hifigan_model']


In [None]:
import shutil
import os

LOCAL_DIR = "/content/hifigan_local"
DRIVE_DIR = "/content/drive/MyDrive/satere_project/hifigan_training"

if os.path.exists(LOCAL_DIR):
    shutil.rmtree(LOCAL_DIR)

print("Copying training data to local storage...")
shutil.copytree(DRIVE_DIR, LOCAL_DIR)
print(f"✓ Copied to {LOCAL_DIR}")

print(f"\nLocal files: {len(os.listdir(f'{LOCAL_DIR}/units'))} unit files")

Copying training data to local storage...


KeyboardInterrupt: 

In [None]:
import os
import shutil
from concurrent.futures import ThreadPoolExecutor

LOCAL_DIR = "/content/hifigan_local"
DRIVE_DIR = "/content/drive/MyDrive/satere_project/hifigan_training"

# Create directories
os.makedirs(f"{LOCAL_DIR}/wavs", exist_ok=True)
os.makedirs(f"{LOCAL_DIR}/units", exist_ok=True)

# Copy text files first
shutil.copy(f"{DRIVE_DIR}/train_files.txt", LOCAL_DIR)
shutil.copy(f"{DRIVE_DIR}/val_files.txt", LOCAL_DIR)
print("✓ Copied file lists")

# Get file lists
wav_files = os.listdir(f"{DRIVE_DIR}/wavs")
unit_files = os.listdir(f"{DRIVE_DIR}/units")
print(f"Files to copy: {len(wav_files)} wavs, {len(unit_files)} units")

# Copy function
def copy_file(args):
    src, dst = args
    shutil.copy(src, dst)

# Copy wavs in parallel
print("Copying wavs...")
wav_args = [(f"{DRIVE_DIR}/wavs/{f}", f"{LOCAL_DIR}/wavs/{f}") for f in wav_files]
with ThreadPoolExecutor(max_workers=8) as executor:
    list(executor.map(copy_file, wav_args))
print(f"✓ Copied {len(wav_files)} wav files")

# Copy units in parallel
print("Copying units...")
unit_args = [(f"{DRIVE_DIR}/units/{f}", f"{LOCAL_DIR}/units/{f}") for f in unit_files]
with ThreadPoolExecutor(max_workers=8) as executor:
    list(executor.map(copy_file, unit_args))
print(f"✓ Copied {len(unit_files)} unit files")

print("\n✓ Done!")

✓ Copied file lists
Files to copy: 22934 wavs, 22934 units
Copying wavs...
✓ Copied 22934 wav files
Copying units...
✓ Copied 22934 unit files

✓ Done!


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import scipy.io.wavfile as wav
from tqdm import tqdm

# Clear GPU memory
torch.cuda.empty_cache()

# Paths
LOCAL_DIR = "/content/hifigan_local"
PROJECT_ROOT = "/content/drive/MyDrive/satere_project"
MODEL_DIR = f"{PROJECT_ROOT}/hifigan_model_v2"  # New folder for v2
os.makedirs(MODEL_DIR, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load file lists
with open(f"{LOCAL_DIR}/train_files.txt", "r") as f:
    train_files = [line.strip() for line in f.readlines()]

with open(f"{LOCAL_DIR}/val_files.txt", "r") as f:
    val_files = [line.strip() for line in f.readlines()]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")

# ============ DATASET ============
class UnitAudioDataset(Dataset):
    def __init__(self, data_dir, file_list):
        self.data_dir = data_dir
        self.files = file_list

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]

        # Load audio
        wav_path = f"{self.data_dir}/wavs/{name}.wav"
        sr, audio = wav.read(wav_path)
        audio = audio.astype(np.float32) / 32768.0

        # Load units
        unit_path = f"{self.data_dir}/units/{name}.txt"
        with open(unit_path, 'r') as f:
            units = [int(u) for u in f.read().strip().split()]

        return torch.FloatTensor(audio), torch.LongTensor(units)

# ============ GENERATOR (same architecture) ============
class Generator(nn.Module):
    def __init__(self, num_units=100, embed_dim=256):
        super().__init__()
        self.unit_embed = nn.Embedding(num_units, embed_dim)
        self.pre_conv = nn.Conv1d(embed_dim, 512, 7, padding=3)

        self.upsamples = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose1d(512, 256, 10, stride=5, padding=2),
                nn.LeakyReLU(0.1),
                nn.Conv1d(256, 256, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, 10, stride=5, padding=2),
                nn.LeakyReLU(0.1),
                nn.Conv1d(128, 128, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, 8, stride=4, padding=2),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, 8, stride=4, padding=2),
                nn.LeakyReLU(0.1),
                nn.Conv1d(32, 32, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
            nn.Sequential(
                nn.ConvTranspose1d(32, 64, 4, stride=2, padding=1),
                nn.LeakyReLU(0.1),
                nn.Conv1d(64, 64, 7, padding=3),
                nn.LeakyReLU(0.1)
            ),
        ])

        self.post_conv = nn.Conv1d(64, 1, 7, padding=3)

    def forward(self, x):
        x = self.unit_embed(x).transpose(1, 2)
        x = self.pre_conv(x)
        x = F.leaky_relu(x, 0.1)

        for upsample in self.upsamples:
            x = upsample(x)

        x = self.post_conv(x)
        x = torch.tanh(x)
        return x.squeeze(1)

# ============ DISCRIMINATOR (simplified) ============
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(1, 16, 15, stride=1, padding=7),
            nn.Conv1d(16, 64, 41, stride=4, padding=20, groups=4),
            nn.Conv1d(64, 256, 41, stride=4, padding=20, groups=16),
            nn.Conv1d(256, 512, 41, stride=4, padding=20, groups=64),
            nn.Conv1d(512, 512, 41, stride=4, padding=20, groups=64),
            nn.Conv1d(512, 512, 5, stride=1, padding=2),
        ])
        self.conv_post = nn.Conv1d(512, 1, 3, padding=1)

    def forward(self, x):
        x = x.unsqueeze(1)
        fmaps = []
        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, 0.1)
            fmaps.append(x)
        x = self.conv_post(x)
        return x, fmaps

# ============ TRAINING SETUP ============
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# KEY FIX 1: Lower discriminator learning rate
optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.00005, betas=(0.8, 0.99))  # 4x lower

# Learning rate schedulers
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999)

# Dataset and DataLoader
train_dataset = UnitAudioDataset(LOCAL_DIR, train_files)
val_dataset = UnitAudioDataset(LOCAL_DIR, val_files)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

print(f"✓ DataLoader ready: {len(train_loader)} batches")

# ============ LOSS FUNCTIONS ============
def feature_loss(real_fmaps, fake_fmaps):
    loss = 0
    for real_fm, fake_fm in zip(real_fmaps, fake_fmaps):
        loss += F.l1_loss(fake_fm, real_fm.detach())
    return loss

def discriminator_loss(real_out, fake_out):
    real_loss = torch.mean((1 - real_out) ** 2)
    fake_loss = torch.mean(fake_out ** 2)
    return real_loss + fake_loss

def generator_adv_loss(fake_out):
    return torch.mean((1 - fake_out) ** 2)

# ============ TRAINING LOOP ============
NUM_EPOCHS = 100  # More epochs
SAVE_EVERY = 10
SAMPLE_EVERY = 10

# KEY FIX 2: L1 loss weight (high at start, decreases over time)
L1_WEIGHT_START = 45.0
L1_WEIGHT_END = 10.0
FM_WEIGHT = 2.0

# KEY FIX 3: Discriminator warmup - only L1 loss for first N epochs
D_WARMUP_EPOCHS = 10

print("\n" + "="*60)
print("STARTING TRAINING (v2 - Fixed Hyperparameters)")
print("="*60)
print(f"  Discriminator warmup: {D_WARMUP_EPOCHS} epochs")
print(f"  L1 weight: {L1_WEIGHT_START} -> {L1_WEIGHT_END}")
print(f"  D learning rate: 0.00005 (4x lower than G)")
print("="*60 + "\n")

for epoch in range(1, NUM_EPOCHS + 1):
    generator.train()
    discriminator.train()

    # Compute L1 weight for this epoch (linear decay)
    l1_weight = L1_WEIGHT_START - (L1_WEIGHT_START - L1_WEIGHT_END) * (epoch - 1) / (NUM_EPOCHS - 1)

    epoch_g_loss = 0
    epoch_d_loss = 0
    epoch_l1_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

    for batch_idx, (real_audio, units) in enumerate(pbar):
        real_audio = real_audio.to(device)
        units = units.to(device)

        # Generate fake audio
        fake_audio = generator(units)

        # Match lengths
        min_len = min(real_audio.shape[1], fake_audio.shape[1])
        real_audio = real_audio[:, :min_len]
        fake_audio = fake_audio[:, :min_len]

        # L1 loss (always computed)
        l1_loss = F.l1_loss(fake_audio, real_audio)

        # ============ DISCRIMINATOR UPDATE ============
        if epoch > D_WARMUP_EPOCHS:
            optimizer_d.zero_grad()

            real_out, real_fmaps = discriminator(real_audio)
            fake_out, _ = discriminator(fake_audio.detach())

            d_loss = discriminator_loss(real_out, fake_out)
            d_loss.backward()
            optimizer_d.step()

            epoch_d_loss += d_loss.item()
        else:
            d_loss = torch.tensor(0.0)

        # ============ GENERATOR UPDATE ============
        optimizer_g.zero_grad()

        if epoch > D_WARMUP_EPOCHS:
            fake_out, fake_fmaps = discriminator(fake_audio)
            _, real_fmaps = discriminator(real_audio)

            g_adv_loss = generator_adv_loss(fake_out)
            fm_loss = feature_loss(real_fmaps, fake_fmaps)

            g_loss = g_adv_loss + FM_WEIGHT * fm_loss + l1_weight * l1_loss
        else:
            # Warmup: only L1 loss
            g_loss = l1_weight * l1_loss

        g_loss.backward()
        optimizer_g.step()

        epoch_g_loss += g_loss.item()
        epoch_l1_loss += l1_loss.item()

        pbar.set_postfix({
            'G': f'{g_loss.item():.3f}',
            'D': f'{d_loss.item():.3f}',
            'L1': f'{l1_loss.item():.4f}'
        })

    # Update learning rates
    scheduler_g.step()
    scheduler_d.step()

    # Epoch stats
    avg_g = epoch_g_loss / len(train_loader)
    avg_d = epoch_d_loss / len(train_loader) if epoch > D_WARMUP_EPOCHS else 0
    avg_l1 = epoch_l1_loss / len(train_loader)

    print(f"\n  Epoch {epoch}: G={avg_g:.4f}, D={avg_d:.4f}, L1={avg_l1:.4f}, L1_weight={l1_weight:.1f}")

    # Validation
    if epoch % 2 == 0:
        generator.eval()
        val_l1 = 0
        with torch.no_grad():
            for real_audio, units in val_loader:
                real_audio = real_audio.to(device)
                units = units.to(device)
                fake_audio = generator(units)
                min_len = min(real_audio.shape[1], fake_audio.shape[1])
                val_l1 += F.l1_loss(fake_audio[:, :min_len], real_audio[:, :min_len]).item()
        val_l1 /= len(val_loader)
        print(f"  Validation L1: {val_l1:.4f}")

    # Save checkpoint
    if epoch % SAVE_EVERY == 0:
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
        }
        torch.save(checkpoint, f"{MODEL_DIR}/checkpoint_epoch{epoch}.pt")
        print(f"  ✓ Saved checkpoint")

    # Save sample audio
    if epoch % SAMPLE_EVERY == 0:
        generator.eval()
        with torch.no_grad():
            sample_audio, sample_units = val_dataset[0]
            sample_units = sample_units.unsqueeze(0).to(device)
            generated = generator(sample_units).squeeze().cpu().numpy()
            generated = (generated * 32767).astype(np.int16)
            wav.write(f"{MODEL_DIR}/sample_epoch{epoch}.wav", 16000, generated)
        print(f"  ✓ Saved sample audio")

print("\n" + "="*60)
print("TRAINING COMPLETE (v2)")
print("="*60)

Using device: cuda


FileNotFoundError: [Errno 2] No such file or directory: '/content/hifigan_local/train_files.txt'

In [None]:
import os

# Check what's in the local folder
print(os.listdir('/content/hifigan_local'))

FileNotFoundError: [Errno 2] No such file or directory: '/content/hifigan_local'