In [1]:
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

from src.UrbanSound import UrbanSoundDataset, UrbanSoundExposureGenerator
from src.replay import ReplayExposureBlender

def normalize_tensor_wav(x, eps=1e-10, std=None):
    """Standardize ``x`` along its last dimension.

    The mean over the last dimension is always subtracted. The centered
    result is then divided by ``std + eps``.

    Args:
        x: tensor to normalize over its last dimension (presumably a batch of
            waveforms with samples on the last axis — confirm with callers).
        eps: small constant added to the divisor to avoid division by zero.
        std: optional divisor. When ``None``, ``x.std`` over the last
            dimension (keepdim) is used. Must broadcast against ``x``.

    Returns:
        The centered-and-scaled tensor. With the default ``std`` it has the
        same shape as ``x``.
    """
    centered = x - x.mean(-1, keepdim=True)
    scale = x.std(-1, keepdim=True) if std is None else std
    return centered / (scale + eps)

In [2]:
# Build the exposure generator over the UrbanSound8K directory at 16 kHz.
# NOTE(review): range(1, 10) presumably selects folds 1-9, and initial_K=4
# presumably sets how many classes form the initial set; confirm both
# against src.UrbanSound.UrbanSoundExposureGenerator.
exposure_generator = UrbanSoundExposureGenerator(
    'UrbanSound8K', range(1, 10), sr=16000,
    exposure_size=300, exposure_val_size=50, initial_K=4
)

initial_tr, initial_val, seen_classes = exposure_generator.get_initial_set()

# Materialize every exposure once as a (train, val, label) triple, then split
# the triples into three parallel lists indexed by exposure position.
# Comprehensions keep the per-item variables out of the notebook namespace.
# That way later cells cannot silently reuse "the last exposure" left over
# from a loop.
exposures = [exposure_generator[idx] for idx in range(len(exposure_generator))]
exposure_tr_list = [tr for tr, _, _ in exposures]
exposure_val_list = [val for _, val, _ in exposures]
exposure_label_list = [lbl for _, _, lbl in exposures]

In [3]:
# DataLoaders for the initial (pre-exposure) split.
# The validation loader must wrap `initial_val`, not `initial_tr`. Otherwise
# validation metrics would be computed on the training data.
# Validation order does not affect aggregate metrics, so it is not shuffled;
# this keeps evaluation passes reproducible.
initial_tr_loader = DataLoader(initial_tr, batch_size=4, shuffle=True, num_workers=4)
initial_val_loader = DataLoader(initial_val, batch_size=4, shuffle=False, num_workers=4)

In [4]:
# Pick the first exposure whose label is already in `seen_classes`. Pass its
# training split, together with the initial training set, to
# ReplayExposureBlender.
for i, label in enumerate(exposure_label_list):
    if label in seen_classes:
        exposure_tr = exposure_tr_list[i]
        exposure_val = exposure_val_list[i]
        break
else:
    # Without this guard, a miss would fall through with `exposure_tr` left
    # over from an earlier cell and `label` set to an unseen class, so the
    # blender would silently receive the wrong inputs.
    raise ValueError(
        f"No exposure label found in seen_classes={seen_classes!r}; "
        f"exposure labels: {exposure_label_list!r}"
    )

new_tr = ReplayExposureBlender(initial_tr, exposure_tr, seen_classes, label)
# Validation blend is currently disabled:
#new_val = ReplayExposureBlender(initial_val, exposure_val, seen_classes, label)

In [10]:
# Training loader over the replay-blended dataset, using the same loader
# settings as `initial_tr_loader`.
new_tr_loader = DataLoader(
    new_tr,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)
# Enable once `new_val` is constructed:
#new_val_loader = DataLoader(new_val, batch_size=4, shuffle=True, num_workers=4)