In [54]:
import numpy as np
from mnist1d.data import make_dataset, get_dataset_args

class MNIST1D:

    '''
    Minimal NumPy implementation of a 1-D convolution (cross-correlation)
    forward pass with a randomly initialised kernel.

    Input Dims (X): (B, C, W)
    Output Dims (Y): (B, output_channels, W_out), where
        W_out = (W + 2 * padding - K_eff) // stride + 1
        K_eff = (kernel_size - 1) * dilation_rate + 1

    The kernel is re-drawn from NumPy's global random state on every call to
    `forward`; `seed` (if not None) seeds that global state once, at
    construction time.
    '''

    def __init__(self, seed):
        '''
        Parameters
        ----------
        seed : int or None
            Seed for NumPy's global RNG. `None` leaves the RNG untouched.
        '''
        self.seed = seed
        self._set_seed()

    def forward(self, X, output_channels, kernel_size, padding = 0, stride = 1, dilation_rate = 1):
        '''
        Convolve `X` with a freshly drawn random kernel.

        Parameters
        ----------
        X : array-like of shape (B, C, W)
        output_channels : int, >= 1
        kernel_size : int
            Number of (non-dilated) kernel taps; must not exceed the padded width.
        padding : int
            Zeros added on both ends of the last axis.
        stride : int, >= 1
        dilation_rate : int, >= 1

        Returns
        -------
        numpy.ndarray of shape (B, output_channels, W_out)

        Raises
        ------
        AssertionError
            If `output_channels`, `kernel_size` or `dilation_rate` is invalid.
        ValueError
            If `stride` < 1 or the dilated kernel is wider than the padded input.
        '''
        self.padding = padding
        # X must be stored before kernel_size is assigned: the kernel_size
        # setter validates against the padded width of self.X.
        self.X = self._pad(X) # -> (Samples, Channels, Padded Features)
        self.output_channels = output_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation_rate = dilation_rate

        Y = self._forward_util()

        return Y

    def _forward_util(self):
        '''Draw the kernel and apply it to the padded input stored in self.X.'''
        if self.stride < 1:
            raise ValueError('Stride cannot be less than 1!')

        self._create_kernel()

        batch_size, in_channels, width = self.X.shape

        # self.kernel is 1-D in the single-input/single-output case; view it
        # uniformly as (out_channels, in_channels, taps) for the computation.
        weights = self.kernel.reshape(self.output_channels, in_channels, self.kernel.shape[-1])
        taps = weights.shape[-1]

        if width < taps:
            raise ValueError('Dilated kernel cannot be wider than the padded input!')

        output_size = (width - taps) // self.stride + 1
        # Extent covered by the window start positions, so that every strided
        # slice below yields exactly `output_size` elements.
        span = (output_size - 1) * self.stride + 1

        Y_all = np.zeros(shape = (batch_size, self.output_channels, output_size))

        # Accumulate one kernel tap at a time:
        #   Y[b, o, i] = sum_c sum_k X[b, c, i * stride + k] * W[o, c, k]
        for tap in range(taps):
            strided_input = self.X[:, :, tap : tap + span : self.stride]
            Y_all += np.einsum('bcw,oc->bow', strided_input, weights[:, :, tap])

        return Y_all

    def _create_kernel(self):
        '''
        Draw a uniform [0, 1) random kernel and store it in self.kernel.

        Shape is (K_eff,) when there is exactly one input and one output
        channel, otherwise (output_channels, in_channels, K_eff).
        '''
        in_channels = self.X.shape[1]

        if self.output_channels == 1 and in_channels == 1:
            kernel_mask = np.random.random_sample(size = (self.kernel_size))
            self.kernel = self._dilate(kernel_mask) if self.dilation_rate != 1 else kernel_mask
        else:
            kernel_mask = np.random.random_sample(size = (self.output_channels, in_channels, self.kernel_size))
            self.kernel = self._dilate_multiple_channels(kernel_mask) if self.dilation_rate != 1 else kernel_mask

    def _dilate(self, kernel_mask):
        '''
        Insert (dilation_rate - 1) zeros between consecutive taps along the
        last axis. Zero-valued taps are treated like any other tap.
        '''
        kernel_mask = np.asarray(kernel_mask)
        taps = kernel_mask.shape[-1]

        if taps <= 1 or self.dilation_rate == 1:
            return kernel_mask

        dilated_size = (taps - 1) * self.dilation_rate + 1
        dilated = np.zeros(shape = kernel_mask.shape[:-1] + (dilated_size,), dtype = kernel_mask.dtype)
        # Original taps land on every dilation_rate-th position.
        dilated[..., ::self.dilation_rate] = kernel_mask

        return dilated

    def _dilate_multiple_channels(self, kernel_mask):
        '''Dilate an (out_channels, in_channels, taps) kernel along its last axis.'''
        return self._dilate(kernel_mask)

    def _pad(self, X):
        '''Zero-pad the last axis of a 3-D input by self.padding on both sides.'''
        X = np.pad(X, pad_width = ((0, 0), (0, 0), (self.padding, self.padding)))
        return X

    def _set_seed(self):
        '''Seed NumPy's global RNG when a seed was provided.'''
        if self.seed is not None:
            np.random.seed(self.seed)

    @property
    def dilation_rate(self):
        return self._dilation_rate

    @dilation_rate.setter
    def dilation_rate(self, dilation_rate):
        assert dilation_rate >= 1, 'Dilation cannot be less than 1 for the Kernel!'
        self._dilation_rate = dilation_rate

    @property
    def output_channels(self):
        return self._output_channels

    @output_channels.setter
    def output_channels(self, output_channels):
        assert output_channels >= 1, 'Output Channels cannot be less than 1!'
        self._output_channels = output_channels

    @property
    def kernel_size(self):
        return self._kernel_size

    @kernel_size.setter
    def kernel_size(self, kernel_size):
        # Type check first so a non-integer never reaches the comparison.
        assert isinstance(kernel_size, (int, np.integer)), 'kernel_size must be int for 1D Conv'
        # Compare against the padded width (last axis), not the total element count.
        assert self.X.shape[-1] >= kernel_size, 'Kernel cannot be greater than input_vector!'
        self._kernel_size = kernel_size

In [55]:
# Build the dataset with the library's default arguments.
defaults = get_dataset_args()
data = make_dataset(defaults)
x = data['x']

# Insert a singleton channel axis: (samples, features) -> (samples, 1, features).
# The sizes are derived from the data instead of being hard-coded, so the cell
# keeps working if the dataset arguments change.
# NOTE(review): assumes data['x'] is 2-D (samples, features) — confirm.
x = x.reshape(x.shape[0], 1, -1)

# Run one randomly initialised convolution and report the output shape.
op = MNIST1D(seed=1)
y = op.forward(x, output_channels = 16, kernel_size = 5, dilation_rate = 1, stride = 2)
print(f"Output Shape: {y.shape}")


Output Shape: (4000, 16, 18)
