In [None]:
!pip install -q timm torchaudio librosa

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os, random, warnings
import numpy as np, pandas as pd
import torch, torchaudio, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
from glob import glob
from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold
import timm
warnings.filterwarnings('ignore')

# Configuration
# Paths: Drive-mounted dataset root. Training audio sits in per-class
# subfolders (globbed later as f"{TRAIN_DIR}/{cls}/*.wav"); test audio is a
# single flat folder.
ROOT = '/content/drive/MyDrive/the-frequency-quest'
TRAIN_DIR, TEST_DIR = f'{ROOT}/train/train', f'{ROOT}/test/test'
# Local (outside /content/drive) directory for precomputed mel spectrograms (.npy).
CACHE_DIR = '/content/mel_cache_optimized'
os.makedirs(CACHE_DIR, exist_ok=True)

# Label vocabulary: list order defines the integer class ids.
CLASSES = ['dog_bark','drilling','engine_idling','siren','street_music']
C2I = {c:i for i,c in enumerate(CLASSES)}  # class name -> id
I2C = {i:c for c,i in C2I.items()}         # id -> class name (used when writing predictions)

# Optimized audio parameters
# Every clip is resampled to SR and padded/truncated to DURATION seconds
# (SAMPLES samples). HOP = 320 samples at 32 kHz is 100 STFT frames per second.
SR, DURATION = 32000, 5.0
N_MELS, N_FFT, HOP = 256, 2048, 320
SAMPLES = int(SR * DURATION)

# Training parameters
# Spectrograms are resized to IMG_SIZE x IMG_SIZE before entering the CNN.
BATCH_SIZE, IMG_SIZE = 24, 384
MODEL_NAME = 'tf_efficientnetv2_m.in21k_ft_in1k'  # timm model id (loaded with pretrained weights)
EPOCHS, LR = 22, 1.8e-4  # LR is used as the OneCycle peak learning rate
MIXUP_ALPHA, CUTMIX_ALPHA = 0.7, 0.6  # Beta(alpha, alpha) parameters for batch mixing
LABEL_SMOOTH = 0.12
TTA_COUNT = 12  # forward passes per test batch; the first pass is un-augmented
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 42

def seed_everything(seed=SEED):
    """Seed the Python, NumPy and PyTorch (CPU and every GPU) RNGs and pin cuDNN.

    cuDNN is switched to deterministic kernels with autotuning disabled,
    trading some speed for run-to-run reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything()

# Enhanced mel extraction
def extract_mel(path):
    """Load an audio file and return a standardised log-mel spectrogram.

    The waveform is averaged to mono when it has several channels, resampled
    to ``SR`` when needed, and zero-padded or truncated to exactly ``SAMPLES``
    samples. The power (``power=2.0``) mel spectrogram is converted to dB with
    ``top_db=80``, standardised to zero mean / unit variance over the whole
    clip, and clipped to [-8, 8].

    Returns:
        numpy array of shape ``(N_MELS, n_frames)``.
    """
    waveform, native_sr = torchaudio.load(path)
    if waveform.shape[0] > 1:
        # Average channels -> mono, keeping a leading channel dim of 1.
        waveform = waveform.mean(0, keepdim=True)
    if native_sr != SR:
        waveform = torchaudio.transforms.Resample(native_sr, SR)(waveform)

    n_samples = waveform.shape[1]
    if n_samples < SAMPLES:
        waveform = torch.nn.functional.pad(waveform, (0, SAMPLES - n_samples))
    else:
        waveform = waveform[:, :SAMPLES]

    to_mel = torchaudio.transforms.MelSpectrogram(
        SR, N_FFT, hop_length=HOP, n_mels=N_MELS,
        f_min=30, f_max=SR // 2, power=2.0,
    )
    to_db = torchaudio.transforms.AmplitudeToDB(top_db=80)
    log_mel = to_db(to_mel(waveform))
    # Per-clip standardisation; the epsilon guards against silent clips.
    log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-6)
    return torch.clamp(log_mel, -8, 8).squeeze(0).numpy()

def cache_mels(files, prefix):
    """Compute (or reuse) the mel spectrogram of every audio file; return cache paths.

    Each array is stored as ``CACHE_DIR/<prefix>/<parent-dir>/<stem>.npy``.
    Keying on the split prefix and the source file's parent directory (the
    class folder for training audio) keeps same-named files from different
    folders or splits from colliding. With a flat, basename-only cache, a
    test clip sharing a name with a training clip would silently reuse the
    training clip's spectrogram, because existing cache files are never
    recomputed. The cached file keeps the audio file's stem, so mapping
    ``<stem>.npy`` back to ``<stem>.wav`` still recovers the original filename.

    Args:
        files: paths to audio files.
        prefix: split label (e.g. 'train', 'val', 'test'). Used in the
            progress-bar text and as the top-level cache sub-directory.

    Returns:
        List of ``.npy`` paths in the same order as ``files``.
    """
    cached = []
    for f in tqdm(files, desc=f'Cache {prefix}'):
        stem = os.path.splitext(os.path.basename(f))[0]
        out_dir = os.path.join(CACHE_DIR, prefix, os.path.basename(os.path.dirname(f)))
        os.makedirs(out_dir, exist_ok=True)
        out = os.path.join(out_dir, stem + '.npy')
        if not os.path.exists(out):
            # Write to a temporary file and rename it into place, so an
            # interrupted run cannot leave a truncated .npy that later runs
            # would treat as a valid cache hit. (np.save keeps a path that
            # already ends in '.npy' unchanged.)
            tmp = os.path.join(out_dir, stem + '.tmp.npy')
            np.save(tmp, extract_mel(f))
            os.replace(tmp, out)
        cached.append(out)
    return cached

# Load data: gather every training clip with its class id. Classes are visited
# in CLASSES order and files are sorted within each class folder, so the list
# order (and therefore the seeded split below) is reproducible.
_labelled = [
    (path, C2I[cls])
    for cls in CLASSES
    for path in sorted(glob(f"{TRAIN_DIR}/{cls}/*.wav"))
]
train_files = [path for path, _ in _labelled]
train_labels = [label for _, label in _labelled]

# Stratified split: hold out 7.5% of the training clips for validation while
# preserving class proportions.
from sklearn.model_selection import train_test_split
tr_files, vl_files, tr_labels, vl_labels = train_test_split(
    train_files, train_labels,
    test_size=0.075, stratify=train_labels, random_state=SEED,
)
ts_files = sorted(glob(f"{TEST_DIR}/*.wav"))

print(f"Train: {len(tr_files)} | Val: {len(vl_files)}")

# Cache spectrograms for each split, in train -> val -> test order.
tr_cache, vl_cache, ts_cache = (
    cache_mels(tr_files, 'train'),
    cache_mels(vl_files, 'val'),
    cache_mels(ts_files, 'test'),
)

# Efficient augmentation
def spec_aug(mel, freq_m=3, time_m=3):
    """Apply SpecAugment-style masking to a 2-D spectrogram and return a masked copy.

    ``freq_m`` frequency masks (12-40 rows each) are applied first, then
    ``time_m`` time masks (15-50 columns each). Each mask is filled with the
    mean of the array as it stands at that moment, i.e. after earlier masks.
    Masks may overlap or be clipped at the array edge. The input array is not
    modified.
    """
    n_rows, n_cols = mel.shape
    masked = mel.copy()

    for _ in range(freq_m):
        width = random.randint(12, 40)
        start = random.randint(0, max(0, n_rows - 40))
        masked[start:start + width, :] = masked.mean()

    for _ in range(time_m):
        width = random.randint(15, 50)
        start = random.randint(0, max(0, n_cols - 50))
        masked[:, start:start + width] = masked.mean()

    return masked

def aug_pipeline(mel):
    """Randomly augment a spectrogram for training.

    Each step is gated by its own coin flip, applied in this order:
      * p=0.8   : spec_aug with 2-4 frequency masks and 2-4 time masks;
      * p=0.5   : additive Gaussian noise with std 0.018;
      * p=0.5   : global gain, multiplying by U(0.75, 1.25);
      * p=0.135 : (0.45 * 0.3) shuffle the time frames (columns) randomly.
    If no step fires, the input array itself is returned.
    """
    if random.random() < 0.8:
        n_freq = random.randint(2, 4)
        n_time = random.randint(2, 4)
        mel = spec_aug(mel, n_freq, n_time)
    if random.random() < 0.5:
        mel = mel + np.random.randn(*mel.shape) * 0.018
    if random.random() < 0.5:
        mel = mel * np.random.uniform(0.75, 1.25)
    # Two separate draws (the second only if the first passes), so the RNG
    # call sequence stays 0.45-gate -> 0.3-gate -> permutation.
    if random.random() < 0.45 and random.random() < 0.3:
        mel = mel[:, np.random.permutation(mel.shape[1])]
    return mel

class AudioDataset(Dataset):
    """Dataset over cached mel-spectrogram ``.npy`` files.

    Each item's image is a ``(3, IMG_SIZE, IMG_SIZE)`` float32 tensor. It is
    the spectrogram (optionally passed through ``aug_pipeline``), bicubically
    resized and repeated across three identical channels.

    Args:
        paths: list of ``.npy`` paths.
        labels: integer class ids aligned with ``paths``, or ``None`` for
            unlabelled data.
        augment: apply ``aug_pipeline`` to each spectrogram when True.

    Items are ``(image, label_tensor)`` when labels are given. Otherwise they
    are ``(image, name)``, where ``name`` is the cache filename with '.npy'
    replaced by '.wav'.
    """

    def __init__(self, paths, labels=None, augment=False):
        self.paths, self.labels, self.augment = paths, labels, augment

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        mel = np.load(self.paths[idx])
        if self.augment:
            mel = aug_pipeline(mel)
        # Resize a single channel, then repeat it. The three channels are
        # identical, so interpolating each one separately (3x the bicubic
        # work) yields exactly the same tensor.
        img = torch.from_numpy(np.ascontiguousarray(mel, dtype=np.float32))[None, None]
        img = torch.nn.functional.interpolate(img, (IMG_SIZE, IMG_SIZE),
                                              mode='bicubic', align_corners=False)
        img = img[0].repeat(3, 1, 1)
        if self.labels is not None:
            return img, torch.tensor(self.labels[idx], dtype=torch.long)
        return img, os.path.basename(self.paths[idx]).replace('.npy', '.wav')

def mixup(x, y, alpha=MIXUP_ALPHA):
    """Blend each batch element with a randomly chosen partner from the same batch.

    Draws ``lam ~ Beta(alpha, alpha)``, or uses ``lam = 1`` when ``alpha <= 0``.
    Returns ``(lam * x + (1 - lam) * x[perm], y, y[perm], lam)``. The caller
    weights the two target sets by ``lam`` and ``1 - lam``.
    """
    lam = 1
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0)).to(x.device)
    mixed = lam * x + (1 - lam) * x[perm]
    return mixed, y, y[perm], lam

def cutmix(x, y, alpha=CUTMIX_ALPHA):
    """Paste a random rectangle from a shuffled copy of the batch into every sample.

    Args:
        x: image batch indexed as ``(N, C, H, W)``.
        y: labels aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter for the target mixing ratio.
            ``alpha <= 0`` gives ratio 1, i.e. an empty patch.

    Returns:
        ``(mixed_x, y, y[perm], lam)``. ``lam`` is the fraction of each image
        that still comes from its own sample. It is recomputed from the
        clipped box, so it matches the pixels actually replaced.

    The input tensor is left untouched and the patch is written into a copy,
    consistent with ``mixup``, which also returns a new tensor.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    _, _, H, W = x.shape
    cut_ratio = np.sqrt(1 - lam)
    cut_h, cut_w = int(H * cut_ratio), int(W * cut_ratio)
    cx, cy = np.random.randint(W), np.random.randint(H)
    x1, y1 = np.clip(cx - cut_w // 2, 0, W), np.clip(cy - cut_h // 2, 0, H)
    x2, y2 = np.clip(cx + cut_w // 2, 0, W), np.clip(cy + cut_h // 2, 0, H)
    # Clone instead of writing into the caller's tensor.
    mixed = x.clone()
    mixed[:, :, y1:y2, x1:x2] = x[perm, :, y1:y2, x1:x2]
    lam = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
    return mixed, y, y[perm], lam

# Weighted sampling: each training clip is drawn (with replacement) with
# probability inversely proportional to its class count, so in expectation
# every class contributes equally many samples per epoch.
class_counts = np.bincount(tr_labels, minlength=len(CLASSES))
sample_weights = [1.0 / (class_counts[label] + 1e-9) for label in tr_labels]
sampler = torch.utils.data.WeightedRandomSampler(
    weights=sample_weights, num_samples=len(sample_weights), replacement=True,
)

# DataLoaders: only the training set is augmented; val/test keep file order.
train_ds = AudioDataset(tr_cache, tr_labels, augment=True)
val_ds = AudioDataset(vl_cache, vl_labels, augment=False)
test_ds = AudioDataset(ts_cache, augment=False)

loader_opts = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_dl = DataLoader(train_ds, sampler=sampler, **loader_opts)
val_dl = DataLoader(val_ds, shuffle=False, **loader_opts)
test_dl = DataLoader(test_ds, shuffle=False, **loader_opts)

# Model: pretrained timm backbone with a fresh len(CLASSES)-way classifier head.
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=len(CLASSES),
                          drop_rate=0.35, drop_path_rate=0.25).to(DEVICE)

optimizer = AdamW(model.parameters(), lr=LR, weight_decay=4e-3)
# The scheduler is stepped once per batch, hence steps_per_epoch=len(train_dl).
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, LR, epochs=EPOCHS,
                                                steps_per_epoch=len(train_dl), pct_start=0.1)
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)
scaler = GradScaler()

# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. Otherwise a run that never beats 0.0 would leave the later
# torch.load to hit a missing file, or a stale one from an earlier run.
best_acc, best_path = -1.0, 'best_model_opt.pth'

print(f"\n{'='*60}\nTraining Started\n{'='*60}\n")

for epoch in range(1, EPOCHS+1):
    model.train()
    total_loss, correct, total = 0.0, 0.0, 0

    for xb, yb in tqdm(train_dl, desc=f"Epoch {epoch}/{EPOCHS}"):
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        # Batch-level mixing: mixup with p=0.5, otherwise cutmix with p=0.75
        # (0.375 overall), otherwise none. y_a / y_b are the two target sets;
        # lam is the weight of y_a.
        if random.random() < 0.5:
            xb, y_a, y_b, lam = mixup(xb, yb)
        elif random.random() < 0.75:
            xb, y_a, y_b, lam = cutmix(xb, yb)
        else:
            y_a, y_b, lam = yb, yb, 1.0

        optimizer.zero_grad()
        with autocast():
            pred = model(xb)
            loss = lam*criterion(pred, y_a) + (1-lam)*criterion(pred, y_b)

        scaler.scale(loss).backward()
        # Unscale first so gradient clipping sees the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Score against the mixed targets with the same lam weighting as the
        # loss. Comparing only against y_b (the shuffled partner labels) would
        # make this metric meaningless whenever mixing is applied.
        w = float(lam)
        hits = pred.detach().argmax(1)
        correct += (w * (hits == y_a).float() + (1 - w) * (hits == y_b).float()).sum().item()
        total_loss += loss.item()*xb.size(0)
        total += xb.size(0)

    train_acc = correct/total
    train_loss = total_loss/total

    # Validation (plain forward passes, no batch mixing)
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for xb, yb in tqdm(val_dl, desc='Validating'):
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            with autocast(): pred = model(xb)
            val_correct += (pred.argmax(1) == yb).sum().item()
            val_total += xb.size(0)

    val_acc = val_correct/val_total
    print(f"Epoch {epoch} | Loss: {train_loss:.4f} | Train: {train_acc*100:.2f}% | Val: {val_acc*100:.4f}%")

    # Keep the weights with the highest validation accuracy (earliest epoch wins ties).
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), best_path)
        print(f"✓ Best: {val_acc*100:.4f}%\n")

print(f"\n{'='*60}\nBest Val: {best_acc*100:.4f}%\n{'='*60}\n")

# Multi-level TTA inference
# Restore the best-validation weights, then average softmax probabilities over
# TTA_COUNT passes per batch. Pass 0 uses the clean input. Each later pass
# adds Gaussian noise (std 0.01) with p=0.6 and/or applies a global gain drawn
# from [0.92, 1.08] with p=0.5.
model.load_state_dict(torch.load(best_path))
model.eval()

all_preds = []
with torch.no_grad():
    for batch, names in tqdm(test_dl, desc='Predicting'):
        batch = batch.to(DEVICE)
        prob_sum = torch.zeros(batch.size(0), len(CLASSES)).to(DEVICE)

        for tta_idx in range(TTA_COUNT):
            view = batch.clone()
            if tta_idx:
                if random.random() < 0.6:
                    view += torch.randn_like(view) * 0.01
                if random.random() < 0.5:
                    view *= np.random.uniform(0.92, 1.08)
            with autocast():
                prob_sum += torch.softmax(model(view), dim=1)

        mean_prob = prob_sum / TTA_COUNT
        class_ids = mean_prob.argmax(1).cpu().numpy()
        all_preds.extend((name, I2C[int(k)]) for name, k in zip(names, class_ids))

submission = pd.DataFrame(all_preds, columns=['ID','Class'])
submission.to_csv('submission_optimized.csv', index=False)

print(f"\n✓ Submission saved | Total: {len(submission)}")
print(f"\nDistribution:\n{submission['Class'].value_counts()}")

from google.colab import files
files.download('submission_optimized.csv')


Mounted at /content/drive
Train: 3191 | Val: 259


Cache train:   0%|          | 0/3191 [00:00<?, ?it/s]

Cache val:   0%|          | 0/259 [00:00<?, ?it/s]

Cache test:   0%|          | 0/740 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/218M [00:00<?, ?B/s]


Training Started



Epoch 1/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1 | Train: 27.08% | Val: 69.4981%
✓ Best: 69.4981%



Epoch 2/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2 | Train: 38.42% | Val: 80.3089%
✓ Best: 80.3089%



Epoch 3/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3 | Train: 46.73% | Val: 93.4363%
✓ Best: 93.4363%



Epoch 4/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4 | Train: 51.11% | Val: 92.6641%


Epoch 5/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5 | Train: 49.86% | Val: 96.1390%
✓ Best: 96.1390%



Epoch 6/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6 | Train: 51.61% | Val: 95.7529%


Epoch 7/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7 | Train: 55.72% | Val: 98.8417%
✓ Best: 98.8417%



Epoch 8/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8 | Train: 52.90% | Val: 97.2973%


Epoch 9/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9 | Train: 53.81% | Val: 98.0695%


Epoch 10/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10 | Train: 56.69% | Val: 99.6139%
✓ Best: 99.6139%



Epoch 11/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11 | Train: 54.47% | Val: 99.6139%


Epoch 12/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 12 | Train: 59.26% | Val: 99.6139%


Epoch 13/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 13 | Train: 55.78% | Val: 99.2278%


Epoch 14/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 14 | Train: 61.17% | Val: 99.2278%


Epoch 15/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 15 | Train: 64.62% | Val: 99.6139%


Epoch 16/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 16 | Train: 60.26% | Val: 99.6139%


Epoch 17/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 17 | Train: 57.16% | Val: 100.0000%
✓ Best: 100.0000%



Epoch 18/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 18 | Train: 56.44% | Val: 99.6139%


Epoch 19/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 19 | Train: 58.60% | Val: 99.6139%


Epoch 20/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 20 | Train: 55.12% | Val: 99.6139%


Epoch 21/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 21 | Train: 53.84% | Val: 99.6139%


Epoch 22/22:   0%|          | 0/133 [00:00<?, ?it/s]

Validating:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 22 | Train: 58.54% | Val: 99.6139%

Best Val: 100.0000%



Predicting:   0%|          | 0/31 [00:00<?, ?it/s]


✓ Submission saved | Total: 740

Distribution:
Class
street_music     152
drilling         151
engine_idling    151
dog_bark         148
siren            138
Name: count, dtype: int64


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>