# Define LayeredSegments module with n layers of m segments each

In [2]:
import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F

class Segment(torch.nn.Module):
    r"""Learnable piecewise-linear function of a single scalar input.

    The module holds ``segment_features`` lines ``y_i(x) = w_i * x + b_i``.
    Consecutive lines ``i`` and ``i + 1`` meet at
    ``x_i = (b_{i+1} - b_i) / (w_i - w_{i+1})``. Line ``i`` is treated as
    active on ``[x_{i-1}, x_i)``: the first line has no lower bound and the
    last line has no upper bound. For every interval bound an input violates,
    that line's value is multiplied by ``alpha`` (``alpha=0`` gives a hard
    piecewise function). The output is the sum of all masked line values, so
    the output feature size is always 1.

    Args:
        in_features (int): Number of input features. ``forward`` currently
            supports only ``in_features == 1``.
        segment_features (int): Number of linear segments.
        alpha (float): Factor applied to a segment's value for each of its
            interval bounds that the input violates. Default: 0.0.

    Note:
        The masks assume the intersection points are increasing; nothing here
        enforces that ordering. Equal adjacent weights divide by zero when
        computing the intersections.
    """

    def __init__(self, in_features, segment_features, alpha=0.0):
        super().__init__()
        self.segment_features = segment_features
        self.in_features = in_features
        self.weight = nn.Parameter(torch.empty(segment_features, in_features))
        self.bias = nn.Parameter(torch.empty(segment_features))
        # torch.empty leaves arbitrary memory, so both parameters must be
        # initialized explicitly. Slopes and offsets are drawn from [-10, 10].
        nn.init.uniform_(self.weight, -10.0, 10.0)
        nn.init.uniform_(self.bias, -10.0, 10.0)
        self.alpha = alpha
        # Intersection x-values from the most recent forward pass (detached);
        # there is one fewer intersection than segments.
        self.intersections = torch.zeros(max(segment_features - 1, 0))

    def forward(self, x):
        """Evaluate the piecewise-linear function.

        Args:
            x (Tensor): Input of shape ``(N, 1)``.

        Returns:
            Tensor: Output of shape ``(N, 1)``.

        Raises:
            ValueError: If ``in_features != 1`` or ``x`` is not ``(N, 1)``.
        """
        if self.in_features != 1:
            raise ValueError(
                f"Segment.forward supports in_features == 1 only, got {self.in_features}"
            )
        if x.dim() != 2 or x.shape[1] != 1:
            raise ValueError(f"expected input of shape (N, 1), got {tuple(x.shape)}")

        # Value of every line at every input: (N, 1) * (1, S) + (S,) -> (N, S).
        yall = x * self.weight.T + self.bias

        # x-coordinate where line i meets line i + 1, i.e. where one segment
        # ends and the next begins: shape (S - 1,).
        intersections = (self.bias[1:] - self.bias[:-1]) / (
            self.weight[:-1] - self.weight[1:]
        ).squeeze(-1)
        self.intersections = intersections.detach()

        n_rows = x.shape[0]
        n_cuts = intersections.shape[0]
        # Build masks on the same device/dtype as the activations.
        ones = yall.new_ones((n_rows, 1))

        # Line i is bounded above by intersections[i]; the last line has no
        # upper bound, hence the appended column of ones.
        upper_mask = yall.new_ones((n_rows, n_cuts)).masked_fill(
            x >= intersections, self.alpha
        )
        upper_mask = torch.cat((upper_mask, ones), dim=1)

        # Line i + 1 is bounded below by intersections[i]; the first line has
        # no lower bound, hence the prepended column of ones.
        lower_mask = yall.new_ones((n_rows, n_cuts)).masked_fill(
            x < intersections, self.alpha
        )
        lower_mask = torch.cat((ones, lower_mask), dim=1)

        yall = yall * upper_mask * lower_mask

        # Sum segment contributions: (N, S) -> (N, 1).
        return yall.sum(dim=1, keepdim=True)