In [24]:
import random as rd
import torch
import torch.nn as nn

def channel(chanInput):
    chanInput = torch.clip(chanInput,-1,1)
    erasedIndex = rd.randint(0,2)
    chanInput[erasedIndex:len(chanInput):3] = 0

    return chanInput + torch.empty(chanInput.shape).normal_(std=torch.sqrt(torch.tensor(5)))

class System_MLP(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.encoder = nn.Sequential(
        nn.Linear(7, 64),
        nn.ReLU(),
        nn.BatchNorm1d(64),
            
        nn.Linear(64,256),
        nn.ReLU(),
        nn.BatchNorm1d(256),
            
        nn.Linear(256, 700)
        )
        
        self.decoder = nn.Sequential(
        nn.Linear(700, 256),
        nn.ReLU(),
        nn.BatchNorm1d(256),
            
        nn.Linear(256, 128)
        )
        
    def forward(self, char):
        encoded_char = self.encoder(char)
        with torch.no_grad():
            channel_output = channel(encoded_char)
        decoded_char = self.decoder(channel_output)
        
        return decoded_char, encoded_char
    
class System_CNN(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.encoder = nn.Sequential(
        nn.ConvTranspose1d(1, 20, 6),
        nn.ReLU(),
        nn.BatchNorm1d(20),
        
        nn.ConvTranspose1d(20, 35, 9),
        nn.Flatten()
        )
        
        self.decoder = nn.Sequential(
        nn.Conv1d(35, 16, 11),
        nn.ReLU(),
        nn.BatchNorm1d(16),
            
        nn.Flatten(),
        nn.Linear(160, 128)
        )
        
    def forward(self, char):
        char = char.reshape(char.size(0),1,-1)
        encoded_char = self.encoder(char)
        with torch.no_grad():
            channel_output = channel(encoded_char)
        channel_output = channel_output.reshape(channel_output.size(0), 35, 20)
        decoded_char = self.decoder(channel_output)
        
        return decoded_char, encoded_char

In [27]:
def generate_data(nb_samples):
    return torch.empty(nb_samples,7).random_(2)

def create_targets(bin_list):
    bin_list = bin_list.flip(-1)
    int_target = torch.sum(bin_list*torch.tensor([2**k for k in range(7)]), -1)
    targets = torch.zeros((bin_list.size(0),128))
    for k in range(bin_list.size(0)):
        targets[k][int(int_target[k])-1] = 1.
    return targets

def train(model, data, data_targets, batch_size=100, nb_epochs=100):
    
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters())
    
    for e in range(nb_epochs):
        for b in range(0, data.size(0), batch_size):
            decoded_char, _ = model(data.narrow(0, b, batch_size))
            loss = criterion(decoded_char, data_targets.narrow(0, b, batch_size))
            model.zero_grad()
            loss.backward()
            optimizer.step()

In [29]:
data = generate_data(51000)
data_targets = create_targets(data)

In [30]:
system = System_MLP()
train(system, data, data_targets, nb_epochs=500)

In [31]:
data = generate_data(10000)
targets = create_targets(data).argmax(-1)
output = system(data)[0].argmax(-1)
#print(targets, output)
print("nb_errors:", torch.sum(targets != output).item())

nb_errors: 3671
