In [1]:
# Root folder that is globbed for the speech ``.wav`` clips in a later cell.
data_dir = "/mnt/d/Programs/Python/PW/projects/asteroid/zip-hindi-2k"

In [79]:
from glob import glob
import torch
import torch.nn as nn   
from torch.utils.data import Dataset,DataLoader
import librosa
import soundfile as sf
from numpy.random import choice
import numpy as np  

In [3]:
# Collect every speech clip below ``data_dir``.
# ``**`` only descends through nested directories when ``recursive=True`` is
# passed; without it the pattern behaves like a single ``*`` level.
# Sorting gives a stable order, because ``glob`` returns files in arbitrary order.
files = sorted(glob(f"{data_dir}/**/*.wav", recursive=True))
len(files)

2000

In [4]:
# Audio front-end settings. The values match the defaults of MarbleNetDataset;
# NOTE(review): nothing in the visible notebook reads these names, so the
# dataset defaults are what actually take effect -- confirm that is intended.
SAMPLE_RATE = 16_000    # samples per second
SEG_LENGTH = 0.63       # presumably seconds per segment
NUM_FBANK = 64          # presumably the number of mel filter banks
WINDOW_LENGTH = 0.025   # presumably seconds per analysis window
OVERLAP = 0.010         # presumably seconds; the dataset uses its counterpart as hop length

In [21]:
class Prologue(nn.Module):
    """Input stem: one 2-D convolution followed by batch norm and ReLU.

    Args:
        in_channels: Channel count of the incoming tensor (default 1).
        out_channels: Number of convolution filters (default 128).
        kernel_size: Size of the convolution kernel (default 11).

    The convolution is built with ``padding='same'``.
    """

    def __init__(self, in_channels=1, out_channels=128, kernel_size=11):
        super().__init__()
        self.prolog = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(norm1(prolog(x)))``."""
        return self.relu(self.norm1(self.prolog(x)))


In [10]:
class QuartzSubBlock(nn.Module):
    """Separable-convolution unit: 1x1 conv, depthwise conv, batch norm, ReLU, dropout.

    Args:
        in_channels: Channels entering the 1x1 (pointwise) convolution.
        out_channels: Channels produced by the pointwise convolution and kept
            by every later layer.
        kernel_size: Kernel size of the depthwise convolution, which is padded
            by ``kernel_size // 2``.
    """

    def __init__(self, in_channels, out_channels, kernel_size) -> None:
        super().__init__()
        # groups == channels gives one independent filter per channel.
        # It is constructed before the pointwise layer, although forward()
        # applies the pointwise layer first.
        self.depthwise_conv = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=out_channels,
        )
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, 1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()  # default drop probability

    def forward(self, x):
        """Run pointwise -> depthwise -> norm -> ReLU -> dropout on ``x``."""
        mixed = self.pointwise_conv(x)
        filtered = self.depthwise_conv(mixed)
        activated = self.relu(self.norm(filtered))
        return self.dropout(activated)


In [24]:
class QuartzBlock(nn.Module):
    """Residual block: stacked sub-blocks plus a final separable conv, added to a 1x1 skip path.

    Main path:  ``sub_blocks -> pointwise_conv1 -> depthwise_conv1 -> norm1``
    Skip path:  ``pointwise_conv2 -> norm2`` applied to the block input
    Output:     ``dropout1(relu1(skip + main))``

    Args:
        out_channels: Channel count of both the input and the output.
        kernel_size: Kernel size of every depthwise convolution in the block,
            each padded by ``kernel_size // 2``.
        num_sub_blocks: Number of ``QuartzSubBlock`` units at the start of the
            main path (default 2).
    """

    def __init__(self,
                 out_channels,
                 kernel_size,
                 num_sub_blocks=2):
        super().__init__()
        # The plain list is kept only for backward compatibility; the modules
        # are registered through ``self.sub_blocks``.
        self.sub_block_list = [
            QuartzSubBlock(out_channels, out_channels, kernel_size)
            for _ in range(num_sub_blocks)
        ]
        self.sub_blocks = nn.Sequential(*self.sub_block_list)
        self.depthwise_conv1 = nn.Conv2d(out_channels,
                                         out_channels,
                                         kernel_size=kernel_size,
                                         padding=kernel_size // 2,
                                         groups=out_channels)
        self.pointwise_conv1 = nn.Conv2d(out_channels,
                                         out_channels,
                                         kernel_size=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout()
        self.pointwise_conv2 = nn.Conv2d(out_channels,
                                         out_channels,
                                         kernel_size=1)
        self.norm2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return the activated sum of the main path and the skip path.

        Each main-path stage consumes the previous stage's output. Earlier code
        fed ``x`` to every stage, which discarded the sub-block and
        ``pointwise_conv1`` results.
        """
        y = self.sub_blocks(x)
        y = self.pointwise_conv1(y)
        y = self.depthwise_conv1(y)
        y = self.norm1(y)
        # Skip connection is computed from the untouched block input.
        skip = self.norm2(self.pointwise_conv2(x))
        out = self.relu1(skip + y)
        return self.dropout1(out)


In [26]:
class Epilogue(nn.Module):
    """Tail stage: a (possibly dilated) 2-D convolution, then batch norm and ReLU.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Number of convolution filters.
        kernel_size: Size of the convolution kernel.
        dilation: Dilation of the convolution (default 1).

    No ``padding`` argument is given to the convolution, so the layer's
    default applies.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, dilation=dilation)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(norm(conv(x)))``."""
        convolved = self.conv(x)
        normalised = self.norm(convolved)
        return self.relu(normalised)

In [30]:
class MarbleNet(nn.Module):
    """Two-class classifier assembled from Prologue, three QuartzBlocks and two Epilogues.

    Pipeline: ``prolog -> resizer (128->64, 1x1) -> block_b1/b2/b3 ->
    epilogue1 (k=29, dilation=2) -> epilogue2 (1x1) -> conv1x1 (128->2) ->
    flatten -> linear (128->2)``.

    The final ``nn.Linear`` has ``in_features=128``, so the flattened
    ``conv1x1`` output must hold exactly 128 values per sample. For example,
    a ``(batch, 1, 64, 64)`` input gives ``2 x 8 x 8 = 128`` after the
    unpadded dilated convolution in ``epilogue1``. Other input sizes will
    fail at the linear layer.
    """

    def __init__(self) -> None:
        super().__init__()
        self.prolog = Prologue()
        # 1x1 convolution that halves the channel count coming out of the prologue.
        self.resizer = nn.Conv2d(128, 64, kernel_size=1)
        self.block_b1 = QuartzBlock(out_channels=64, kernel_size=13, num_sub_blocks=2)
        self.block_b2 = QuartzBlock(out_channels=64, kernel_size=15, num_sub_blocks=2)
        self.block_b3 = QuartzBlock(out_channels=64, kernel_size=17, num_sub_blocks=2)
        self.epilogue1 = Epilogue(in_channels=64, out_channels=128, kernel_size=29, dilation=2)
        self.epilogue2 = Epilogue(in_channels=128, out_channels=128, kernel_size=1)
        self.conv1x1 = nn.Conv2d(128, 2, kernel_size=1)
        self.linear = nn.Linear(128, 2)

    def forward(self, x):
        """Run every stage in order and return the ``(batch, 2)`` output of the linear head."""
        stages = (
            self.prolog,
            self.resizer,
            self.block_b1,
            self.block_b2,
            self.block_b3,
            self.epilogue1,
            self.epilogue2,
            self.conv1x1,
        )
        for stage in stages:
            x = stage(x)
        # Collapse channel and spatial axes into one feature axis per sample.
        flat = x.reshape(x.shape[0], -1)
        return self.linear(flat)

In [81]:
class MarbleNetDataset(Dataset):
    """Speech-vs-noise dataset that yields fixed-size log-mel spectrograms.

    Indices ``0 .. len(audio_files) - 1`` map to ``audio_files`` (label 1);
    the remaining indices map to ``noise_files`` (label 0).

    Args:
        audio_files: Paths of the clips labelled 1.
        noise_files: Paths of the clips labelled 0.
        sample_rate: Sampling rate handed to ``librosa.load``.
        seg_len: Segment length; multiplied by ``sample_rate`` and stored as
            an integer sample count.
            NOTE(review): the stored value is never read by ``__getitem__``,
            which crops/pads to a fixed 64 frames instead.
        num_filts: Number of mel bands (``n_mels``).
        win_len: Window length; multiplied by ``sample_rate`` and used as
            ``win_length``.
        overlap: Multiplied by ``sample_rate`` and used as ``hop_length``.
    """

    # Number of spectrogram frames every item is cropped or padded to.
    _NUM_FRAMES = 64

    def __init__(self, audio_files,
                 noise_files,
                 sample_rate=16_000,
                 seg_len=0.63,
                 num_filts=64,
                 win_len=0.025,
                 overlap=0.01):
        self.audio_files = audio_files
        self.noise_files = noise_files
        self.sample_rate = sample_rate
        self.seg_len = int(seg_len * sample_rate)
        self.num_filts = num_filts
        self.win_len = int(win_len * sample_rate)
        self.overlap = int(overlap * sample_rate)

    def __len__(self):
        """Total number of clips across both file lists."""
        return len(self.audio_files) + len(self.noise_files)

    def __getitem__(self, index=None):
        """Return ``(spectrogram, label)`` for ``index``.

        Args:
            index: Integer position in ``[-len(self), len(self))``. ``None``
                keeps the earlier behaviour of drawing a random clip with a
                50/50 speech/noise chance.

        Returns:
            A tuple of a ``(1, num_filts, 64)`` tensor and a scalar label
            tensor (1 for ``audio_files``, 0 for ``noise_files``).

        Raises:
            IndexError: If ``index`` is outside the dataset.
        """
        if index is None:
            # Legacy path: ignore position and sample a clip at random.
            if torch.rand(1).item() > 0.5:
                file = choice(self.audio_files, 1).item()
                label = 1
            else:
                file = choice(self.noise_files, 1).item()
                label = 0
        else:
            position = int(index)
            total = len(self)
            if position < 0:
                position += total
            if not 0 <= position < total:
                raise IndexError(
                    f"index {index} is out of range for dataset of size {total}"
                )
            num_audio = len(self.audio_files)
            if position < num_audio:
                file, label = self.audio_files[position], 1
            else:
                file, label = self.noise_files[position - num_audio], 0

        data, _ = librosa.load(file, sr=self.sample_rate, mono=True)
        mel_spectrogram = librosa.feature.melspectrogram(y=data, sr=self.sample_rate,
                                                         n_fft=512,
                                                         hop_length=self.overlap,
                                                         win_length=self.win_len,
                                                         n_mels=self.num_filts)
        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)

        frames = mel_spectrogram_db.shape[1]
        if frames > self._NUM_FRAMES:
            mel_spectrogram_db = mel_spectrogram_db[:, :self._NUM_FRAMES]
        elif frames < self._NUM_FRAMES:
            # NOTE(review): constant padding fills with 0, which on a dB scale
            # referenced to the maximum is the loudest value -- confirm this
            # is the intended padding value.
            mel_spectrogram_db = np.pad(mel_spectrogram_db,
                                        ((0, 0), (0, self._NUM_FRAMES - frames)),
                                        mode='constant')
        return torch.tensor(mel_spectrogram_db).unsqueeze(0), torch.tensor(label)

In [90]:
# Build the network with its default layer configuration.
marblenet_model = MarbleNet()

In [82]:
# Collect every noise clip below the noise folder.
# ``**`` only descends through nested directories when ``recursive=True`` is
# passed; without it the pattern behaves like a single ``*`` level.
# Sorting gives a stable order, because ``glob`` returns files in arbitrary order.
noise_files = sorted(
    glob('/mnt/d/Programs/Python/PW/projects/asteroid/noise-2k/**/*.wav', recursive=True)
)
len(noise_files)

2000

In [83]:
# Pair the speech clips with the noise clips; every other setting uses the
# MarbleNetDataset defaults.
dataset = MarbleNetDataset(audio_files=files, noise_files=noise_files)

In [87]:
# Shuffled mini-batches of three items each.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=3)

In [91]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [92]:
# Classification loss plus SGD with momentum and weight decay over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    marblenet_model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.001,
)

In [93]:
# Move the model's parameters and buffers onto the selected device.
marblenet_model = marblenet_model.to(device)

In [94]:
# Single-epoch training loop.
for epoch in range(1):
    marblenet_model.train()
    train_loss = 0.0   # sum of per-sample losses over the epoch
    seen = 0           # number of samples processed so far
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        output = marblenet_model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        # ``criterion`` returns the batch mean, so weight it by the batch size.
        # Dividing a sum of batch means by the sample count would under-report
        # the loss by roughly the batch size.
        batch_size = x.size(0)
        train_loss += loss.item() * batch_size
        seen += batch_size
    print(f"epoch: {epoch} loss: {train_loss/max(seen, 1)}")

In [None]:
# Serialise the whole model object (not only its state_dict) to disk.
# NOTE(review): loading such a file needs the model classes to be importable.
torch.save(obj=marblenet_model, f='./marble_net.pt')