In [2]:
import math
import time

import torch
from torch import nn
from torch.nn import functional as F

import math
import time

import torch
from torch import nn
from torch.nn import functional as F


class BLSTM(nn.Module):
    """Thin wrapper around ``nn.LSTM`` with equal input and hidden sizes.

    Despite the name, the LSTM is built with ``bidirectional=False``, so each
    output step only depends on the current and earlier time steps.

    Args:
        dim: feature size of both the input and the hidden state.
        layers: number of stacked LSTM layers.
    """

    def __init__(self, dim, layers=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=dim,
            hidden_size=dim,
            num_layers=layers,
            bidirectional=False,
        )

    def forward(self, x, hidden=None):
        """Run the LSTM; returns ``(output, (h_n, c_n))`` exactly as ``nn.LSTM`` does.

        ``hidden`` is an optional ``(h_0, c_0)`` state carried over from a
        previous call (``None`` starts from zeros).
        """
        return self.lstm(x, hidden)

class Zmucs(nn.Module):
    """U-Net-like waveform model: strided conv encoder, LSTM bottleneck, transposed-conv decoder.

    Each of the ``depth`` encoder levels is
    ``Conv1d(kernel_size, stride) -> ReLU -> Conv1d(1x1, 2*ch) -> GLU`` and the
    matching decoder level is
    ``Conv1d(1x1, 2*ch) -> GLU -> ConvTranspose1d(kernel_size, stride)``,
    followed by a ReLU on every level except the outermost one. Channel count
    starts at ``hidden`` and doubles per level (capped at 10000). Encoder outputs
    are added to the decoder inputs as skip connections.

    Args:
        depth: number of encoder/decoder levels.
        hidden: channel count of the first level.
        kernel_size: kernel size of the strided (transposed) convolutions.
        stride: stride of the strided (transposed) convolutions.
        resample: factor used by :meth:`valid_length` and :attr:`total_stride`.
            NOTE(review): ``forward`` never resamples the signal, so values other
            than 1 look unsupported — confirm before relying on them.
    """

    def __init__(self,
                 depth=5,
                 hidden=48,
                 kernel_size=8,
                 stride=4,
                 resample=1) -> None:
        super().__init__()
        self.depth = depth
        self.hidden = hidden
        self.kernel_size = kernel_size
        self.stride = stride
        self.resample = resample
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.chin = 1
        self.chout = 1

        use_glu = True
        growth = 2
        max_hidden = 10000
        # One parameter-free activation instance is shared by every level.
        activation = nn.GLU(1) if use_glu else nn.ReLU()
        # GLU halves the channel dim, so the 1x1 convs produce twice the width.
        ch_scale = 2 if use_glu else 1

        in_ch, out_ch, width = self.chin, self.chout, hidden
        for level in range(depth):
            # Modules are created in the same order as before (encoder convs,
            # then decoder convs) so random weight initialisation is unchanged.
            self.encoder.append(nn.Sequential(
                nn.Conv1d(in_ch, width, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(width, width * ch_scale, 1),
                activation,
            ))
            decoder_layers = [
                nn.Conv1d(width, width * ch_scale, 1),
                activation,
                nn.ConvTranspose1d(width, out_ch, kernel_size, stride, bias=True),
            ]
            if level:
                decoder_layers.append(nn.ReLU())
            # Deepest level first, so decoder[i] mirrors encoder[depth - 1 - i].
            self.decoder.insert(0, nn.Sequential(*decoder_layers))
            in_ch = out_ch = width
            width = min(int(growth * width), max_hidden)

        self.lstm = BLSTM(in_ch, 1)

    def valid_length(self, length):
        """Return an input length >= ``length`` that every strided conv consumes exactly.

        The length is first pushed down through ``depth`` levels (rounding up,
        at least one step per level), then projected back up through the
        transposed convolutions.
        """
        n = math.ceil(length * self.resample)
        for _ in range(self.depth):
            n = max(math.ceil((n - self.kernel_size) / self.stride) + 1, 1)
        for _ in range(self.depth):
            n = (n - 1) * self.stride + self.kernel_size
        return int(math.ceil(n / self.resample))

    @property
    def total_stride(self):
        """Input samples per step of the deepest level: ``stride ** depth // resample``."""
        return pow(self.stride, self.depth) // self.resample

    def forward(self, signal):
        """Run the full model on ``signal``.

        ``signal`` is ``(batch, time)`` (a channel dim is inserted) or already
        ``(batch, chin, time)``. It is zero-padded on the right to
        :meth:`valid_length`, and the output ``(batch, chout, time)`` is trimmed
        back to the original number of samples.
        """
        if signal.dim() == 2:
            signal = signal.unsqueeze(1)
        n_samples = signal.shape[-1]
        x = F.pad(signal, (0, self.valid_length(n_samples) - n_samples))

        skips = []
        for layer in self.encoder:
            x = layer(x)
            skips.append(x)

        # nn.LSTM expects (time, batch, channels).
        x, _ = self.lstm(x.permute(2, 0, 1))
        x = x.permute(1, 2, 0)

        for layer in self.decoder:
            skip = skips.pop()
            x = layer(x + skip[..., :x.shape[-1]])
        return x[..., :n_samples]

class ZmucsSteamer():
    """Frame-by-frame streaming inference for a :class:`Zmucs` model.

    Incoming audio is appended to ``pending``. Whenever at least
    ``total_length`` samples are buffered, one frame is run through the model.
    Each frame reuses the encoder/decoder activations cached in ``conv_state``
    and the LSTM state in ``lstm_state``. The first ``stride`` output samples of
    the frame are emitted, and ``stride`` input samples are consumed.

    Args:
        zmucs: the model to stream.
        num_frames: number of ``zmucs.total_stride`` steps handled per frame.
        resample_lookahead: extra samples appended to each frame.
        resample_buffer: stored only; not used by this class.
    """

    def __init__(self,
                 zmucs,
                 num_frames=1,
                 resample_lookahead=0,
                 resample_buffer=0
                 ) -> None:
        self.zmucs = zmucs
        self.num_frames = num_frames
        self.resample_lookahead = resample_lookahead
        self.resample_buffer = resample_buffer
        # Buffered input not yet consumed, (chin, time).
        self.pending = torch.zeros(zmucs.chin, 0)

        # Input length giving one deepest-level step, plus one total stride per extra frame.
        self.frame_length = zmucs.valid_length(1) + zmucs.total_stride * (num_frames - 1)
        self.total_length = self.frame_length + resample_lookahead
        # Output samples emitted (and input samples consumed) per frame.
        self.stride = zmucs.total_stride * num_frames
        self.lstm_state = None
        # Cached per-level activations: encoder levels first, then decoder levels.
        self.conv_state = None

    def flush(self):
        """Flush remaining audio by padding it with zeros.

        Call this when you have no more input and want the last chunk of audio
        back. Afterwards ``conv_state`` is cleared. ``lstm_state`` and ``pending``
        are left as they are.

        Returns:
            The output for the samples that were still pending, ``(chout, time)``.
        """
        pending_length = self.pending.shape[1]
        padding = self.pending.new_zeros(self.zmucs.chin, self.total_length)
        out = self.feed(padding)
        # This reset used to sit after the ``return`` and never executed.
        self.conv_state = None
        return out[:, :pending_length]

    def feed(self, chunk, flush_it=False):
        """Append ``chunk`` to the buffer and process every complete frame.

        Args:
            chunk: 2-D tensor ``(chin, time)`` of new input samples.
            flush_it: unused; kept for backward compatibility.

        Returns:
            ``(chout, n * stride)`` tensor with the output of the ``n`` frames
            processed by this call. ``n`` may be 0, giving an empty time axis.

        Raises:
            ValueError: if ``chunk`` is not 2-D.
        """
        if chunk.dim() != 2:
            raise ValueError("Wav should be 2d")
        self.pending = torch.cat([self.pending, chunk], dim=1)
        outs = []
        # Streaming is inference-only. With autograd enabled, lstm_state and
        # conv_state would chain each frame's graph to the previous one, so
        # memory would grow for the whole stream.
        with torch.no_grad():
            while self.pending.shape[-1] >= self.total_length:
                frame = self.pending[:, :self.total_length]
                out = self._separate_frame(frame)
                outs.append(out[:, :self.stride])
                self.pending = self.pending[:, self.stride:]
        if not outs:
            # torch.cat([]) raises, so return an empty, correctly shaped tensor.
            return self.pending.new_zeros(self.zmucs.chout, 0)
        return torch.cat(outs, 1)

    def _separate_frame(self, frame):
        """Run one ``(chin, total_length)`` frame through the model using cached state.

        Returns the decoder output for this frame, ``(channels, time)``. The
        caller keeps the first ``self.stride`` samples.

        Raises:
            ValueError: if a decoder level fails (chained to the original error).
        """
        zmucs = self.zmucs
        x = frame[None]
        skips = []
        next_state = []
        first = self.conv_state is None
        stride = self.stride * zmucs.resample

        for encode in zmucs.encoder:
            length = x.shape[2]
            # Output steps this level advances per frame.
            stride //= zmucs.stride
            if not first:
                prev = self.conv_state.pop(0)
                # Drop the cached steps the frame hop has moved past...
                prev = prev[..., stride:]
                # ...and recompute only the output steps that are not cached yet.
                tgt = (length - zmucs.kernel_size) // zmucs.stride + 1
                missing = tgt - prev.shape[-1]
                offset = length - zmucs.kernel_size - zmucs.stride * (missing - 1)
                x = x[..., offset:]
            x = encode(x)
            if not first:
                x = torch.cat([prev, x], -1)
            next_state.append(x)
            skips.append(x)

        # nn.LSTM expects (time, batch, channels).
        x = x.permute(2, 0, 1)
        x, self.lstm_state = zmucs.lstm(x, self.lstm_state)
        x = x.permute(1, 2, 0)

        for idx, decode in enumerate(zmucs.decoder):
            try:
                skip = skips.pop(-1)
                x += skip[..., :x.shape[-1]]
                x = decode[2](decode[1](decode[0](x)))
                # Cache the overlapping tail without its bias. The next frame adds
                # it onto its own first samples, which already include the bias.
                next_state.append(x[..., -zmucs.stride:] - decode[2].bias.view(-1, 1))
                x = x[..., :-zmucs.stride]
                if not first:
                    prev = self.conv_state.pop(0)
                    x[..., :zmucs.stride] += prev
                if idx != zmucs.depth - 1:
                    x = decode[3](x)
            except Exception as err:
                raise ValueError(
                    f"decoding failed at decoder index {idx}, x shape {tuple(x.shape)}"
                ) from err

        self.conv_state = next_state
        return x[0]

def get_norm(ta, tb):
    """Print the relative L2 error ``||ta - tb|| / ||ta||`` as a percentage."""
    relative = torch.norm(ta - tb) / torch.norm(ta)
    print(f"delta batch/streaming: {relative:.2%}")

if __name__ == '__main__':
    # Compare the batch model with the streaming implementation on random audio.
    zmucs = Zmucs(depth=5)
    sig = torch.randn(1,160000)
    # Inference only: skip autograd so no graph is built for the batch pass or
    # chained across streamed frames.
    with torch.no_grad():
        # Offline (whole-signal) reference output.
        denoised = zmucs(sig)[0]

        streamer = ZmucsSteamer(zmucs)
        infer_start = time.time()
        outs = streamer.feed(sig)
        out_rt = torch.cat([outs[0], streamer.flush()[0]]).unsqueeze(0)
        infer_end = time.time() - infer_start
    # This is elapsed wall-clock seconds, not a real-time factor. An RTF would
    # divide by the audio duration, which needs a sample rate not given here.
    print(f"Total infer time: {infer_end:.2f}s")
    # get_norm(denoised,out_rt)
    # `outs` is the (channels, time) tensor returned by feed(), so this prints its channel count.
    print(len(outs))


Total infer time RTF: 76.01
1


In [3]:
# Scratch check: compute the relative L2 difference of two random tensors with
# torch.norm (printed as a percentage) and with torch.linalg.vector_norm
# (printed as the raw ratio).
ta =torch.randn(1,100)
tb =torch.randn(1,100)
print(f"streaming {torch.norm(ta-tb) / torch.norm(ta):.2%}")
linalgnorm = torch.linalg.vector_norm(ta - tb) / torch.linalg.vector_norm(ta)
print(f"linalg norm {linalgnorm}")

streaming 134.09%
linalg norm 1.3408938646316528


In [4]:
zmucs = Zmucs()


def list_shape(tenslist):
    """Print the ``.shape`` of every tensor in ``tenslist``, one per line.

    Uses a plain loop: the previous lambda built a throwaway list of ``None``
    values just to run ``print`` for its side effect.
    """
    for each in tenslist:
        print(each.shape)

In [5]:
import math
import time

import torch
from torch import nn
from torch.nn import functional as F

import math
import time

import torch
from torch import nn
from torch.nn import functional as F

def diff(ta, tb):
    """Print how far ``tb`` deviates from ``ta``, relative to the L2 norm of ``ta``."""
    numerator, denominator = torch.norm(ta - tb), torch.norm(ta)
    print(f"delta batch/streaming: {numerator / denominator:.2%}")

class Umucs:
    """Encoder-only harness for checking frame-by-frame (streaming) strided convolutions.

    Holds three ``Conv1d`` layers (1 -> 48 -> 96 -> 192 channels, kernel 8,
    stride 4). :meth:`main` cuts the input into overlapping frames and feeds
    them to :meth:`pred_frame`. That method caches each layer's output in
    ``conv_state``, so later frames only compute the output steps that are new.
    The stitched result is intended to equal ``c3(c2(c1(x)))`` on the whole input.
    """

    def __init__(self) -> None:
        self.depth = 3
        self.stride_inp = 4
        # Kernel size of every conv layer, also used by the length arithmetic below.
        self.kernel_size = 8

        self.c1 = nn.Conv1d(in_channels=1, out_channels=48,
                            kernel_size=self.kernel_size, stride=self.stride_inp)
        self.c2 = nn.Conv1d(in_channels=48, out_channels=96,
                            kernel_size=self.kernel_size, stride=self.stride_inp)
        self.c3 = nn.Conv1d(in_channels=96, out_channels=192,
                            kernel_size=self.kernel_size, stride=self.stride_inp)
        self.encoder = nn.ModuleList([self.c1, self.c2, self.c3])

        # Input samples between the starts of consecutive frames (total downsampling).
        self.stride = self.stride_inp ** self.depth
        # Shortest input that yields one output step at the deepest layer.
        self.frame_length = self.valid_length(1)
        # Per-layer outputs cached from the previous frame; empty means "no frame yet".
        # Initialised here so pred_frame() also works without going through main().
        self.conv_state = []

    def get_out(self, length, depth):
        """Return the output length after ``depth`` conv layers for ``length`` input samples.

        Lengths are rounded up and clamped to at least one step per layer.
        """
        kernel_size = self.kernel_size
        stride = self.stride_inp
        resample = 1
        length = math.ceil(length * resample)
        for _ in range(depth):
            length = math.ceil((length - kernel_size) / stride) + 1
            length = max(length, 1)
        return length

    def get_in(self, length, depth):
        """
        Determine the input_length given that we got `length` in the output.
        """
        kernel_size = self.kernel_size
        stride = self.stride_inp
        resample = 1
        length = math.ceil(length * resample)
        for _ in range(depth):
            length = (length - 1) * stride + kernel_size
        length = int(math.ceil(length / resample))
        return int(length)

    def valid_length(self, length):
        """Return the input length that ``length`` rounds up to through all layers."""
        out_len = self.get_out(length, self.depth)
        return self.get_in(out_len, self.depth)

    def frm_zmucs(self, inp):
        """Split ``inp`` (shape ``(1, time)``) into overlapping frames like a streaming buffer.

        Resets ``self.pending`` on every call. Returns a list of
        ``(1, frame_length)`` slices starting every ``self.stride`` samples. Any
        tail shorter than a frame stays in ``self.pending``.
        """
        self.pending = torch.zeros(1, 0)
        self.pending = torch.cat([self.pending, inp], dim=1)
        inp_frames = []
        while self.pending.shape[-1] >= self.frame_length:
            frame = self.pending[:, :self.frame_length]
            inp_frames.append(frame)
            self.pending = self.pending[:, self.stride:]
        return inp_frames

    def feed(self, inp: torch.Tensor, conv_state):
        """Stub: check the frame length and return ``[inp`` minus its first ``stride`` samples``]``.

        ``conv_state`` is currently unused.
        """
        expected_length = self.valid_length(1)
        assert inp.shape[-1] == expected_length

        return_values = [
            inp[:, self.stride:],
        ]
        return return_values

    def framed_inp(self, inp):
        """Return all overlapping frames of ``inp`` (shape ``(1, time)``) as ``(n_frames, frame_length)``."""
        return inp.unfold(1, self.valid_length(1), self.stride).squeeze(0)

    def main(self, inp):
        """Stream ``inp`` (shape ``(1, time)``) frame by frame and stitch the encoder outputs.

        Returns ``(1, 192, n_frames)``. The stacking assumes each frame yields
        exactly one deepest-layer step, which holds for the default configuration.
        """
        framed_inp = self.framed_inp(inp)
        self.conv_state = []
        out = [self.pred_frame(each.unsqueeze(0)) for each in framed_inp]
        # (n_frames, 1, 192, 1) -> (1, 1, 192, n_frames) -> (1, 192, n_frames)
        output = torch.stack(out).transpose(0, 3).squeeze(0)
        return output

    def pred_frame(self, frame):
        """Run one ``(1, frame_length)`` frame through the encoder, reusing cached outputs.

        Returns the deepest layer's output for this frame, ``(1, 192, steps)``.
        ``self.conv_state`` is replaced by this frame's per-layer outputs.
        """
        x = frame[None]
        first = len(self.conv_state) == 0
        stride = self.stride
        next_state = []
        for encode in self.encoder:
            # Output steps this layer advances per frame.
            stride = stride // self.stride_inp
            length = x.shape[2]

            if not first:
                prev = self.conv_state.pop(0)
                # Drop the cached steps the frame hop has moved past...
                prev = prev[..., stride:]
                # ...and recompute only the output steps that are not cached yet.
                tgt = (length - self.kernel_size) // self.stride_inp + 1
                missing = tgt - prev.shape[-1]
                offset = length - self.kernel_size - self.stride_inp * (missing - 1)
                x = x[..., offset:]

            x = encode(x)
            if not first:
                x = torch.cat([prev, x], -1)

            next_state.append(x)

        self.conv_state = next_state
        return x
            
# Build the toy encoder and a random 1600-sample input. frm_zmucs returns the
# overlapping (1, frame_length) frames taken every `stride` samples.
umucs = Umucs()
inp = torch.randn(1,1600)
zmucs_frames = umucs.frm_zmucs(inp)

In [6]:
# Streaming pass: encoder outputs computed frame by frame, stitched along time.
online_op = umucs.main(inp)
online_op.shape

torch.Size([1, 192, 23])

In [7]:
# Reference (offline) pass: the same three convs applied to the whole input at once.
offl_op = umucs.c3(umucs.c2(umucs.c1(inp[None])))
offl_op.shape

torch.Size([1, 192, 23])

In [8]:
# The streaming output must match the offline output (rtol=1e-6, atol=1e-6).
diff(offl_op,online_op)
assert torch.allclose(offl_op,online_op,1e-6,1e-6)

delta batch/streaming: 0.00%


In [9]:
# Stand-alone transposed conv with the same shape as the outermost Zmucs decoder
# layer at default settings (48 -> 1 channels, kernel 8, stride 4). With no
# padding, its output length is (L_in - 1) * 4 + 8.
d1 = nn.ConvTranspose1d(in_channels=48,
                        out_channels=1,
                        kernel_size=8,
                        stride=4,
                        bias=True)

In [15]:
# 400 input steps -> (400 - 1) * 4 + 8 = 1604 output samples.
inp1 = torch.randn(1,48,400)
d1(inp1).shape

torch.Size([1, 1, 1604])

In [18]:
# Only the last time step, reshaped to (1, 48, 1): one input step expands to
# (1 - 1) * 4 + 8 = 8 output samples.
d1(inp1[...,-1].unsqueeze(-1))

torch.Size([1, 1, 8])