In [2]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
import numpy as np
import random

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

class Net(torch.nn.Module):
    def __init__(self, dims, device=None):
        super().__init__()
        self.layers = []
        for d in range(len(dims) - 1): 
            self.layers += [Layer(dims[d], dims[d + 1], id=d, device=device)]

    def predict(self, x):
        goodness_per_label = []
        for label in range(16):
            h = x
            goodness = []
            for layer in self.layers:
                h = layer(h)
                goodness += [h.pow(2).mean(1)]
            goodness_per_label += [sum(goodness).unsqueeze(1)]
        goodness_per_label = torch.cat(goodness_per_label, 1)
        return goodness_per_label.argmax(1)
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def train(self, x_pos, x_neg):
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers):
            print(f"training layer {i}")
            h_pos, h_neg = layer.train(h_pos, h_neg)

class Layer(nn.Linear):
    def __init__(self, in_features, out_features, id,
                 bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.id = id
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=0.03)
        self.threshold = 2.0
        self.num_epochs = 1000
        self.device = device

    def forward(self, x):
        x_direction = torch.nn.functional.normalize(x)
        x_direction = torch.tensor(x_direction, dtype=torch.float32)
        return self.relu(
            torch.matmul(x_direction, self.weight.T) + self.bias.unsqueeze(0))

    def train(self, x_pos, x_neg):
        if self.id == 2:
            for _ in tqdm(range(self.num_epochs)):
                g_pos = self.Channel(self.forward(x_pos))
                g_neg = self.Channel(self.forward(x_neg))
                g_pos_sum = g_pos[:, :128].pow(2) + g_pos[:, 128:].pow(2)
                g_neg_sum = g_neg[:, :128].pow(2) + g_neg[:, 128:].pow(2)
                g_pos_sum = torch.cat([g_pos_sum, g_pos_sum], 1)
                g_neg_sum = torch.cat([g_neg_sum, g_neg_sum], 1)
                g_pos = torch.div(g_pos, g_pos_sum)
                g_neg = torch.div(g_neg, g_neg_sum)
                loss = torch.relu(-(g_pos - g_neg).pow(2)).mean()
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
        else:
            for _ in tqdm(range(self.num_epochs)):
                g_pos = self.forward(x_pos)
                g_neg = self.forward(x_neg)
                g_pos_sum = g_pos[:, :128].pow(2) + g_pos[:, 128:].pow(2)
                g_neg_sum = g_neg[:, :128].pow(2) + g_neg[:, 128:].pow(2)
                g_pos_sum = torch.cat([g_pos_sum, g_pos_sum], 1)
                g_neg_sum = torch.cat([g_neg_sum, g_neg_sum], 1)
                g_pos = torch.div(g_pos, g_pos_sum)
                g_neg = torch.div(g_neg, g_neg_sum)
                loss = torch.relu(-(g_pos - g_neg).pow(2)).mean()
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()

    def Channel(self, x): 
        # add some noise
        stddev = np.sqrt(1 / (10 ** (4)))
        noise = torch.normal(mean=0, std=stddev, size=x.shape).to(device)
        return x + noise
    

class decoder(nn.Module):
    def __init__(self, block_length):
        super().__init__()
        self.conv1 = nn.Linear(block_length, 512)
        self.conv2 = nn.Linear(512, 256)
        self.conv3 = nn.Linear(256, 128)

    
    def forward(self, x):
        x = torch.sigmoid(self.conv1(x))
        x = torch.sigmoid(self.conv2(x))
        x = torch.sigmoid(self.conv3(x))
        return x

def Channel(x): 
    # add some noise
    stddev = np.sqrt(1 / (10 ** (4)))
    noise = torch.normal(mean=0, std=stddev, size=x.shape).to(device)
    return x + noise


In [4]:
batch_size = 512
block_length = 128

train_data = np.random.binomial(1, 0.5, [100000, block_length])
label_true, label_false = train_data, train_data
np.random.shuffle(label_false)

test_data = train_data
np.random.shuffle(test_data)
test_data = test_data[:100, :]
test_label = test_data

print(test_data.shape)
print(train_data.shape)
print(label_true.shape)

device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    
net_in = Net([block_length, 256, 256, 256], device)
# net_out = Net([block_length, 256, 128, 64, 32, 1], device)

x_pos = torch.Tensor(label_true).to(device)
x_neg = torch.Tensor(label_false).to(device)
            
net_in.train(x_pos, x_neg)

(100, 128)
(100000, 128)
(100000, 128)
training layer 0


 11%|█         | 110/1000 [00:02<00:22, 39.19it/s]


KeyboardInterrupt: 

In [7]:
# imm_data = torch.nn.functional.normalize(net_in.forward(x_pos)).cpu().detach()
# imm_label = x_pos.cpu().detach()
from copy import deepcopy

# data, label = train_data, train_data
# data = torch.Tensor(data).to(device)
# label = torch.Tensor(label).to(device)
# print(data.shape)
# print(label.shape)
# data = Channel(net_in.forward(x_pos)).detach()
data = torch.cat([x_pos, x_pos], 1)
label = x_pos.detach()

# train_data = np.random.binomial(1, 0.5, [100000, block_length])
# train_label = train_data


deco = decoder(256).to(device)
optimizer = Adam(deco.parameters(),lr=1e-3)
crit = nn.MSELoss()

for epoch in tqdm(range(5000)):
    optimizer.zero_grad()
    output = deco(data)
    loss = crit(output, label)
    loss.backward()
    optimizer.step()

    if epoch % 500 == 0:
        print(f"current loss: {loss.item()}")


  0%|          | 20/5000 [00:00<00:27, 182.07it/s]

current loss: 0.25545522570610046


 10%|█         | 521/5000 [00:06<00:52, 85.80it/s]

current loss: 0.017220621928572655


 20%|██        | 1022/5000 [00:12<00:46, 84.78it/s]

current loss: 0.0056803859770298


 30%|███       | 1522/5000 [00:18<00:40, 85.38it/s]

current loss: 0.0026116734370589256


 40%|████      | 2022/5000 [00:24<00:35, 84.88it/s]

current loss: 0.0015556573634967208


 50%|█████     | 2522/5000 [00:30<00:29, 84.56it/s]

current loss: 0.0010220810072496533


 60%|██████    | 3022/5000 [00:37<00:23, 84.65it/s]

current loss: 0.0007070524152368307


 70%|███████   | 3522/5000 [00:43<00:17, 84.41it/s]

current loss: 0.0005055690417066216


 80%|████████  | 4022/5000 [00:49<00:11, 84.48it/s]

current loss: 0.00037101603811606765


 90%|█████████ | 4522/5000 [00:55<00:05, 84.28it/s]

current loss: 0.00027591927209869027


100%|██████████| 5000/5000 [01:01<00:00, 81.24it/s]


In [8]:
# test_data = np.random.binomial(1, 0.5, [100, block_length])
test_label = test_data

x_pos = torch.Tensor(test_data).to(device)

encoded_msg = net_in.forward(x_pos)

print(encoded_msg)

channeled_msg = Channel(encoded_msg).to(device)

print(channeled_msg)

channeled_msg = torch.cat([x_pos, x_pos], 1)
output = deco(channeled_msg)
output = torch.Tensor(np.where(output.cpu().detach().numpy() >= 0.5, 1, 0)).to(device)

print(output)
print(test_label)
    
print(test_label - output.cpu().detach().numpy())
print(np.absolute(test_label - output.cpu().detach().numpy()).sum(1))

tensor([[0.0000, 0.0003, 0.0000,  ..., 0.0597, 0.0000, 0.0993],
        [0.0000, 0.0135, 0.0036,  ..., 0.0610, 0.0000, 0.0809],
        [0.0000, 0.0000, 0.0000,  ..., 0.0667, 0.0000, 0.1000],
        ...,
        [0.0000, 0.0083, 0.0000,  ..., 0.0593, 0.0000, 0.0730],
        [0.0000, 0.0005, 0.0000,  ..., 0.0436, 0.0000, 0.0745],
        [0.0000, 0.0169, 0.0000,  ..., 0.0743, 0.0000, 0.0655]],
       device='cuda:3', grad_fn=<ReluBackward0>)
tensor([[-0.0045, -0.0053,  0.0050,  ...,  0.0629, -0.0082,  0.1026],
        [-0.0053,  0.0139,  0.0091,  ...,  0.0682, -0.0062,  0.0835],
        [-0.0222, -0.0023,  0.0035,  ...,  0.0698,  0.0067,  0.0860],
        ...,
        [-0.0096,  0.0015,  0.0056,  ...,  0.0637,  0.0086,  0.0809],
        [-0.0143, -0.0017, -0.0074,  ...,  0.0301, -0.0089,  0.0808],
        [-0.0122,  0.0184, -0.0016,  ...,  0.0741,  0.0043,  0.0718]],
       device='cuda:3', grad_fn=<AddBackward0>)
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 

