<a href="https://colab.research.google.com/github/suhasineejain/SEMINAR--ConvTasNet/blob/main/Seminar.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**IMPORTING LIBRARIES**

In [None]:
!pip install torchaudio

import torch
import torchaudio
import torch.utils.data


import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import scipy,time,numpy
import itertools
from itertools import permutations

In [None]:
# Colab-only: mount Google Drive so that paths under /content/drive
# (dataset and checkpoints used further down) become available.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


**DATA LOADING**

In [None]:
class Timitdataset(torch.utils.data.Dataset):
  """Two-speaker mixture dataset backed by wav files on disk.

  ``base_dir`` must contain ``mix_2_spk_mix.txt`` (one utterance id per line)
  and the sub-directories ``mix/``, ``s1/`` and ``s2/`` holding ``<id>.wav``.

  Each item is ``(mixture, target)`` where ``target`` stacks the two source
  waveforms along dim 0. The mixture and both sources are zero-padded or
  cropped to ``max_len`` samples so that items collate into batches and the
  model output has the same length as the target.

  Args:
    base_dir: root directory of the dataset.
    transform: stored on the instance; it is not applied by this class.
    max_len: number of samples every waveform is fitted to (default 32000).
      Pass ``None`` to return the waveforms at their native length.
  """

  def __init__(self,base_dir,transform=None,max_len=32000):

    self.base_dir = base_dir
    self.transform = transform
    self.max_len = max_len

    # Read the utterance ids once; the context manager closes the file.
    # Line endings are stripped here (instead of slicing off the last
    # character later) so a final line without a newline keeps its full id,
    # and blank lines are ignored.
    self.lines = []
    with open(self.base_dir + '/mix_2_spk_mix.txt','r') as f:
      for line in f:
        name = line.rstrip('\r\n')
        if name:
          self.lines.append(name)

  def __len__(self):
    return len(self.lines)

  @staticmethod
  def _fit_length(waveform, length):
    """Crop or right-pad a (channels, samples) tensor to ``length`` samples."""
    nsample = waveform.shape[1]
    if nsample >= length:
      return waveform[:, :length]
    # new_zeros keeps dtype/device and the channel count of the waveform.
    pad = waveform.new_zeros((waveform.shape[0], length - nsample))
    return torch.cat((waveform, pad), 1)

  def __getitem__(self,idx):

    name = self.lines[idx]

    waveform_mix, _ = torchaudio.load(self.base_dir + '/mix/' + name + '.wav')
    waveform_s1, _ = torchaudio.load(self.base_dir + '/s1/' + name + '.wav')
    waveform_s2, _ = torchaudio.load(self.base_dir + '/s2/' + name + '.wav')

    if self.max_len is not None:
      waveform_mix = self._fit_length(waveform_mix, self.max_len)
      waveform_s1 = self._fit_length(waveform_s1, self.max_len)
      waveform_s2 = self._fit_length(waveform_s2, self.max_len)

    target = torch.cat((waveform_s1,waveform_s2),dim=0)

    return waveform_mix , target


**MODEL**

In [None]:
class cLN(nn.Module):
    """Cumulative (causal) layer normalization for (B, C, T) tensors.

    The statistics used at time step t are computed over all channels of
    time steps 0..t only, so the output at t never depends on later steps.
    A learnable per-channel gain and bias are applied afterwards.
    """

    def __init__(self, dimension, eps=1e-8):
        super(cLN, self).__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(1, dimension, 1))
        self.bias = nn.Parameter(torch.zeros(1, dimension, 1))

    def forward(self, input):
        channel = input.size(1)
        time_step = input.size(2)

        # Running sums over channels and past time steps: (B, T)
        cum_sum = torch.cumsum(input.sum(1), dim=1)
        cum_pow_sum = torch.cumsum(input.pow(2).sum(1), dim=1)

        # Number of entries that contributed to each running sum: (T,)
        entry_cnt = torch.arange(1, time_step + 1, dtype=input.dtype, device=input.device) * channel

        cum_mean = cum_sum / entry_cnt
        # E[x^2] - E[x]^2, clamped because rounding can make it slightly negative.
        cum_var = (cum_pow_sum / entry_cnt - cum_mean.pow(2)).clamp(min=0)
        cum_std = (cum_var + self.eps).sqrt()

        normed = (input - cum_mean.unsqueeze(1)) / cum_std.unsqueeze(1)
        return normed * self.gain + self.bias


class DepthConv1d(nn.Module):
    """1-D depthwise-separable convolution block with residual and skip outputs.

    Pipeline: 1x1 conv (input_channel -> hidden_channel), PReLU, norm,
    depthwise dilated conv, PReLU, norm, then 1x1 convs back to
    ``input_channel`` for the residual path and, if ``skip`` is set, the skip path.

    Args:
        input_channel: channels of the input and of both outputs.
        hidden_channel: channels inside the block.
        kernel: depthwise kernel size.
        padding: symmetric padding for the depthwise conv (non-causal mode only).
        dilation: dilation of the depthwise conv.
        skip: when True ``forward`` returns ``(residual, skip)``, else ``residual``.
        causal: when True the depthwise conv is padded by ``(kernel-1)*dilation``
            and the trailing frames are trimmed so no future frame is used;
            normalization then uses ``cLN`` instead of ``nn.GroupNorm``.
    """

    def __init__(self, input_channel, hidden_channel, kernel, padding, dilation=1, skip=True, causal=False):
        super(DepthConv1d, self).__init__()

        self.causal = causal
        self.skip = skip

        self.conv1d = nn.Conv1d(input_channel, hidden_channel, 1)
        if self.causal:
            self.padding = (kernel - 1) * dilation
        else:
            self.padding = padding
        self.dconv1d = nn.Conv1d(hidden_channel, hidden_channel, kernel, dilation=dilation,
          groups=hidden_channel,
          padding=self.padding)
        self.res_out = nn.Conv1d(hidden_channel, input_channel, 1)
        self.nonlinearity1 = nn.PReLU()
        self.nonlinearity2 = nn.PReLU()
        if self.causal:
            self.reg1 = cLN(hidden_channel, eps=1e-08)
            self.reg2 = cLN(hidden_channel, eps=1e-08)
        else:
            self.reg1 = nn.GroupNorm(1, hidden_channel, eps=1e-08)
            self.reg2 = nn.GroupNorm(1, hidden_channel, eps=1e-08)

        if self.skip:
            self.skip_out = nn.Conv1d(hidden_channel, input_channel, 1)

    def forward(self, input):
        output = self.reg1(self.nonlinearity1(self.conv1d(input)))
        output = self.dconv1d(output)
        # Conv1d pads both ends; dropping the last `padding` frames leaves only
        # left context. Guard against padding == 0 (kernel == 1), where the
        # slice `[:-0]` would otherwise return an empty tensor.
        if self.causal and self.padding > 0:
            output = output[:, :, :-self.padding]
        output = self.reg2(self.nonlinearity2(output))
        residual = self.res_out(output)
        if self.skip:
            skip = self.skip_out(output)
            return residual, skip
        else:
            return residual

In [None]:
class TCN(nn.Module):
    """Temporal convolutional network: stacked DepthConv1d blocks.

    Input is a feature sequence of shape (B, input_dim, L); output has shape
    (B, output_dim, L).

    Args:
        input_dim: channels of the input features.
        output_dim: channels produced by the output layer.
        BN_dim: bottleneck channels that flow between blocks.
        hidden_dim: channels inside each DepthConv1d block.
        layer: number of blocks per stack; with ``dilated`` the i-th block of a
            stack uses dilation ``2**i``.
        stack: number of stacks.
        kernel: depthwise kernel size. In non-causal mode the symmetric padding
            ``dilation*(kernel-1)//2`` preserves the sequence length for odd
            kernels (for kernel=3 this equals the dilation).
        skip: accumulate skip outputs and feed their sum to the output layer;
            otherwise the residual stream is used.
        causal: use causal blocks and ``cLN`` normalization (``cLN`` must be
            defined in the notebook for this mode).
        dilated: use exponentially growing dilation; otherwise dilation 1.

    Attributes:
        receptive_field: receptive field of the block stack, in frames.
    """

    def __init__(self, input_dim, output_dim, BN_dim, hidden_dim,
                 layer, stack, kernel=3, skip=True,
                 causal=False, dilated=True):
        super(TCN, self).__init__()

        # normalization
        if not causal:
            self.LN = nn.GroupNorm(1, input_dim, eps=1e-8)
        else:
            self.LN = cLN(input_dim, eps=1e-8)

        self.BN = nn.Conv1d(input_dim, BN_dim, 1)

        # TCN for feature extraction
        self.receptive_field = 0
        self.dilated = dilated

        self.TCN = nn.ModuleList([])
        for s in range(stack):
            for i in range(layer):
                dilation = 2**i if self.dilated else 1
                # Length-preserving symmetric padding for any odd kernel;
                # DepthConv1d overrides it with its own value in causal mode.
                padding = dilation * (kernel - 1) // 2
                self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=dilation, padding=padding, skip=skip, causal=causal))
                if i == 0 and s == 0:
                    # the very first block contributes its full kernel
                    self.receptive_field += kernel
                else:
                    self.receptive_field += (kernel - 1) * dilation

        # output layer
        self.output = nn.Sequential(nn.PReLU(),
                                    nn.Conv1d(BN_dim, output_dim, 1)
                                   )

        self.skip = skip

    def forward(self, input):

        # input shape: (B, N, L)
        output = self.BN(self.LN(input))

        if self.skip:
            skip_connection = 0.
            for block in self.TCN:
                residual, skip = block(output)
                output = output + residual
                skip_connection = skip_connection + skip
            output = self.output(skip_connection)
        else:
            for block in self.TCN:
                residual = block(output)
                output = output + residual
            output = self.output(output)

        return output

In [None]:
class TasNet(nn.Module):
    """Time-domain separation network: conv encoder, TCN mask estimator, transposed-conv decoder.

    ``forward`` accepts waveforms of shape (B, T) or (B, 1, T) and returns the
    separated waveforms with shape (B, num_spk, T).

    Args:
        enc_dim: number of encoder basis filters.
        feature_dim: bottleneck channels of the TCN (hidden size is 4x this).
        sr: sampling rate in Hz, used only to convert ``win`` to samples.
        win: encoder window length in milliseconds; the hop is half a window.
        layer, stack, kernel: TCN blocks per stack, number of stacks, kernel size.
        num_spk: number of sources to estimate.
        causal: build a causal TCN.

    NOTE(review): an earlier inline comment said ``sr`` had been changed to
    8000, but the default is 16000 while later cells play the audio at
    8000 Hz. The default is left untouched because ``win`` in samples fixes
    the encoder/decoder weight shapes (and thus checkpoint compatibility);
    pass ``sr=8000`` explicitly if a 2 ms window at 8 kHz is intended.
    """

    def __init__(self, enc_dim=512, feature_dim=128, sr=16000, win=2, layer=8, stack=3,
                 kernel=3, num_spk=2, causal=False):
        super(TasNet, self).__init__()

        # hyper parameters
        self.num_spk = num_spk

        self.enc_dim = enc_dim
        self.feature_dim = feature_dim

        self.win = int(sr*win/1000)
        self.stride = self.win // 2
        if self.stride < 1:
            # A window shorter than 2 samples gives a zero hop, which would
            # otherwise surface later as an obscure convolution error.
            raise ValueError("sr*win/1000 must be at least 2 samples, got {}".format(self.win))

        self.layer = layer
        self.stack = stack
        self.kernel = kernel

        self.causal = causal

        # input encoder
        self.encoder = nn.Conv1d(1, self.enc_dim, self.win, bias=False, stride=self.stride)

        # TCN separator
        self.TCN = TCN(self.enc_dim, self.enc_dim*self.num_spk, self.feature_dim, self.feature_dim*4,
                              self.layer, self.stack, self.kernel, causal=self.causal)

        self.receptive_field = self.TCN.receptive_field

        # output decoder
        self.decoder = nn.ConvTranspose1d(self.enc_dim, 1, self.win, bias=False, stride=self.stride)

    def pad_signal(self, input):
        """Zero-pad a waveform so that it is covered by whole encoder frames.

        Returns ``(padded, rest)`` where ``padded`` has shape (B, 1, T') and
        ``rest`` is the number of zeros appended before the extra half-window
        that is added on both ends.
        """
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input can only be 2 or 3 dimensional.")

        if input.dim() == 2:
            input = input.unsqueeze(1)
        batch_size = input.size(0)
        nsample = input.size(2)
        rest = self.win - (self.stride + nsample % self.win) % self.win
        if rest > 0:
            # new_zeros keeps the dtype and the exact device of the input.
            input = torch.cat([input, input.new_zeros(batch_size, 1, rest)], 2)

        pad_aux = input.new_zeros(batch_size, 1, self.stride)
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

    def forward(self, input):

        # padding
        output, rest = self.pad_signal(input)
        batch_size = output.size(0)

        # waveform encoder
        enc_output = self.encoder(output)  # B, N, L

        # generate one sigmoid mask per speaker and apply it to the encoding
        masks = torch.sigmoid(self.TCN(enc_output)).view(batch_size, self.num_spk, self.enc_dim, -1)  # B, C, N, L
        masked_output = enc_output.unsqueeze(1) * masks  # B, C, N, L

        # waveform decoder
        output = self.decoder(masked_output.view(batch_size*self.num_spk, self.enc_dim, -1))  # B*C, 1, T'
        # remove the padding added by pad_signal
        output = output[:,:,self.stride:-(rest+self.stride)].contiguous()  # B*C, 1, T
        output = output.view(batch_size, self.num_spk, -1)  # B, C, T

        return output

**LOSS FUNCTION**

In [None]:
def calc_sdr_torch(estimation, origin, mask=None, eps=1e-8):
    """
    Batch-wise scale-invariant SDR (in dB) between estimates and references.

    estimation: (batch, nsample)
    origin: (batch, nsample)
    mask: optional, (batch, nsample), binary; applied to both signals
    eps: small constant added to the powers so that silent references and
         perfect estimates give a finite value instead of +/-inf or NaN

    Returns a tensor of shape (batch,).
    """

    if mask is not None:
        origin = origin * mask
        estimation = estimation * mask

    origin_power = torch.pow(origin, 2).sum(1, keepdim=True) + eps  # (batch, 1)

    # projection of the estimate onto the reference
    scale = torch.sum(origin*estimation, 1, keepdim=True) / origin_power  # (batch, 1)

    est_true = scale * origin  # (batch, nsample)
    est_res = estimation - est_true  # (batch, nsample)

    true_power = torch.pow(est_true, 2).sum(1) + eps  # (batch,)
    res_power = torch.pow(est_res, 2).sum(1) + eps  # (batch,)

    return 10*torch.log10(true_power) - 10*torch.log10(res_power)  # (batch,)


In [None]:
def batch_SDR_torch(estimation, origin, mask=None):
    """
    Permutation-invariant SDR for a batch of multi-source signals.

    estimation: (batch, nsource, nsample)
    origin: (batch, nsource, nsample)
    mask: optional, (batch, nsample), binary

    For every utterance the source permutation with the highest total SDR is
    selected independently; the per-utterance best values are then summed over
    the batch and divided by the number of sources.

    Returns a tensor of shape (1,), so it can be negated and used directly as
    a training loss.
    """

    batch_size_est, nsource_est, nsample_est = estimation.size()
    batch_size_ori, nsource_ori, nsample_ori = origin.size()

    assert batch_size_est == batch_size_ori, "Estimation and original sources should have same shape."
    assert nsource_est == nsource_ori, "Estimation and original sources should have same shape."
    assert nsample_est == nsample_ori, "Estimation and original sources should have same shape."

    assert nsource_est < nsample_est, "Axis 1 should be the number of sources, and axis 2 should be the signal."

    nsource = nsource_est

    # zero mean signals
    estimation = estimation - torch.mean(estimation, 2, keepdim=True)
    origin = origin - torch.mean(origin, 2, keepdim=True)

    # pair-wise SDR: SDR[b, i, j] compares estimate i with reference j
    SDR = torch.stack(
        [torch.stack([calc_sdr_torch(estimation[:, i], origin[:, j], mask)
                      for j in range(nsource)], dim=1)
         for i in range(nsource)],
        dim=1)  # (batch, nsource, nsource)

    # total SDR of every permutation, per utterance (deterministic order)
    est_idx = list(range(nsource))
    SDR_perm = torch.stack(
        [SDR[:, est_idx, list(permute)].sum(1) for permute in permutations(range(nsource))],
        dim=1)  # (batch, nperm)

    # Pick the best permutation per utterance BEFORE reducing over the batch;
    # summing over the batch first would force one permutation on all
    # utterances and defeat permutation invariance.
    SDR_max, _ = torch.max(SDR_perm, dim=1)  # (batch,)

    return SDR_max.sum(0, keepdim=True) / nsource  # (1,)


**DEVICE**

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is read as a global by the training and evaluation functions below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


**INSTANCE OF MODEL**

In [None]:
# Build the separation network with its default hyper-parameters and place it on `device`.
model = TasNet().to(device)

# Quick sanity check: list the type and shape of every parameter tensor.
for param in model.parameters():
    print(type(param), param.size())

In [None]:
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for tensor in model.parameters():
        # frozen parameters (requires_grad == False) are not counted
        if tensor.requires_grad:
            total += tensor.numel()
    return total

# Report the trainable-parameter count with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 5,066,929 trainable parameters


**OPTIMIZATION**

In [None]:
from torch import optim

# Adam with L2 weight decay, starting at `learning_rate`.
learning_rate = 1e-3
T_max = 10
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-3)

# Cosine annealing towards eta_min = 1e-5 (same value as the previous 10e-6)
# over T_max scheduler steps. The schedule only advances when
# scheduler.step() is called. StepLR and a custom scheduler were earlier
# alternatives tried in this cell.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=1e-5)
print('done')

done


**TRAINING**

In [None]:
from torch.utils.data import random_split
from torch.utils.data import DataLoader

# Root of the mixture dataset on the mounted Drive.
base_dir = '/content/drive/My Drive/output_dir8k_new'

dataset = Timitdataset(base_dir)

# Hold out 10% of the utterances for validation.
n_val = int(len(dataset) * 0.1)
n_train = len(dataset) - n_val

# Seed the split so the same utterances end up in the validation set on every
# run. Training resumes from a saved checkpoint, so an unseeded split would
# let utterances seen during an earlier run leak into validation.
SPLIT_SEED = 0
train_dataset, valid_dataset = random_split(
    dataset, [n_train, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED))
b_size = 16          #16

train_loader = DataLoader(train_dataset,batch_size=b_size,shuffle=True, drop_last=True, pin_memory= True,num_workers=8)
# Validation is an average over all items, so shuffling adds nothing; a fixed
# order also makes the demo sample picked further down reproducible.
valid_loader = DataLoader(valid_dataset,batch_size=1,shuffle=False, drop_last=True, pin_memory=True,num_workers=8)



In [None]:
def train(model,train_loader,optimizer):
  """Run one optimisation pass over ``train_loader``.

  The loss of each batch is the negated value of ``batch_SDR_torch``.
  Batches are moved to the global ``device``. Returns the mean loss per batch.
  """
  model.train()
  running_loss = 0

  for mixture, sources in train_loader:
    mixture, sources = mixture.to(device), sources.to(device)

    estimate = model(mixture)
    loss = - batch_SDR_torch(estimate, sources)

    # standard step: clear old gradients, back-propagate, update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  return running_loss / len(train_loader)

In [None]:
def evaluation(model,valid_loader,optimizer):
  """Score the model on ``valid_loader`` without computing gradients.

  Returns the mean of ``batch_SDR_torch`` over the batches. Unlike ``train``
  the value is not negated, so higher is better. ``optimizer`` is accepted
  but not used.
  """
  model.eval()
  running_score = 0

  with torch.no_grad():
    for mixture, sources in valid_loader:
      mixture, sources = mixture.to(device), sources.to(device)
      score = batch_SDR_torch(model(mixture), sources)
      running_score += score.item()

  return running_score / len(valid_loader)

In [None]:
import time

def epoch_time(start_time, end_time):
    """Split the interval between two timestamps (in seconds) into whole minutes and seconds.

    Returns a ``(minutes, seconds)`` tuple of ints; fractions are truncated.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

In [None]:
import os
epoch =0
loss = 0
run_no = 1

# Checkpoints for this run live in <base_dir>/model_parameters.
checkpoint_dir = os.path.join(base_dir,'model_parameters')
checkpoint_path = os.path.join(checkpoint_dir,f'checkpoint{run_no}.pth')

# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

# First run: write an initial checkpoint so the load below always succeeds.
if not os.path.isfile(checkpoint_path):
        torch.save({'model_state_dict':model.state_dict(),
                  'epoch':epoch,
                  'optimizer_state_dict':optimizer.state_dict(),
                  'loss':loss
                  },
                  checkpoint_path)

# map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
checkpoint=torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
N_EPOCHS = 20

# NOTE(review): `scheduler` is created in the optimisation cell but is never
# stepped, so the learning rate stays constant. Add `scheduler.step()` at the
# end of each epoch if cosine annealing is actually wanted.
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):

    start_time = time.time()

    # train() returns the mean negated SDR; evaluation() returns the mean SDR.
    train_loss = train(model, train_loader, optimizer)
    valid_loss = evaluation(model, valid_loader,optimizer)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # `valid_loss` holds an SDR (higher is better), so negate it before
    # comparing; comparing it directly would track the worst epoch.
    if -valid_loss < best_valid_loss:
        best_valid_loss = -valid_loss

    # Both printed values are SDRs in dB (the training loss is negated back).
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {-train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

    # Periodic checkpoint every 5 epochs. Store this epoch's training loss;
    # the module-level `loss` is only the placeholder from the bootstrap cell.
    if (epoch + 1) % 5 == 0:
      torch.save({'model_state_dict':model.state_dict(),
                  'epoch':epoch,
                  'optimizer_state_dict':optimizer.state_dict(),
                  'loss':train_loss
                 },
                 os.path.join(base_dir,'model_parameters',f'checkpoint{run_no}.pth'))

# Final weights only (no optimizer state).
torch.save(model.state_dict(),
               os.path.join(base_dir,'model_parameters', f'final{run_no}.pth'))


**DEMO OUTPUT**

In [None]:
# Index of the validation batch used for the audio demo below; if the loader
# yields fewer batches, the last one is used.
DEMO_INDEX = 100

model.eval()

# Walk the loader up to the demo batch without running the model on the
# batches that would be thrown away anyway.
for count, (input, target) in enumerate(valid_loader):
    if count == DEMO_INDEX:
        break

input = input.to(device)
target = target.to(device)

# Inference only: no autograd graph is needed for playback.
with torch.no_grad():
    prediction = model(input)

In [None]:
# Sanity check: types and shapes of the demo prediction, target and mixture.
print(type(prediction),type(target),type(input))
print(prediction[0].shape,target.shape,input.shape)

<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([2, 32000]) torch.Size([1, 2, 32000]) torch.Size([1, 1, 32000])


In [None]:
import IPython.display as ipd
# Play output channel 0 of the model for the demo item (8 kHz playback).
# The loss is permutation invariant, so this channel may correspond to either target.
ipd.Audio(data = prediction.detach().cpu().numpy()[0][0], rate = 8000)

In [None]:
# Play reference source 0 (the s1 recording) of the demo item.
ipd.Audio(data = target.detach().cpu().numpy()[0][0], rate = 8000)

In [None]:
# Play reference source 1 (the s2 recording) of the demo item.
ipd.Audio(data = target.detach().cpu().numpy()[0][1], rate = 8000)

In [None]:
# Play output channel 1 of the model for the demo item.
ipd.Audio(data = prediction.detach().cpu().numpy()[0][1], rate = 8000)

In [None]:
# Play the input mixture that was fed to the model.
ipd.Audio(data = input.detach().cpu().numpy()[0][0], rate = 8000)