In [1]:
import numpy as np
import torch
import math
from math import log
import torch.nn as nn
import torch.nn.functional as F

In [2]:

from typing import List
def create_binary_list_from_int(number: int) -> List[int]:
    """Return the binary digits of ``number``, most-significant bit first.

    No zero padding is applied, so ``0 -> [0]`` and ``6 -> [1, 1, 0]``.

    Raises:
        ValueError: if ``number`` is negative.
    """
    if number < 0:
        raise ValueError(f"Only Positive integers are allowed, actual number: {number}, type: {type(number)}")

    digit_chars = bin(number)[2:]  # strip the "0b" prefix
    return [int(ch) for ch in digit_chars]

In [3]:
# def correct_int(number: int) -> int:
#     lst = create_binary_list_from_int(number)
#     res = int("".join(str(x) for x in lst), 2)
#     return res

In [4]:
# num = 127
# print(create_binary_list_from_int(num))
# print(correct_int(num))

In [5]:
# correct_list = set([correct_int(x) for x in range(128)])
# correct_list

In [6]:




def gen_even_data(max_int: int, batch_sz: int):
    """Sample a batch of even integers in ``[0, max_int)`` as fixed-width bit vectors.

    Each row is the MSB-first binary encoding of one sample, left-padded with
    zeros to ``(max_int - 1).bit_length()`` bits (7 for ``max_int=128``), so the
    last column is always 0.

    Args:
        max_int: exclusive upper bound of the sampled values; must be >= 2.
        batch_sz: number of samples (rows).

    Returns:
        float32 tensor of shape ``(batch_sz, width)`` containing 0.0 / 1.0.

    Raises:
        ValueError: if ``max_int < 2``.
    """
    max_int = int(max_int)
    if max_int < 2:
        raise ValueError(f"max_int must be >= 2, got {max_int}")

    # Exact integer width; int(log(max_int - 1, 2)) + 1 can round across an
    # integer boundary for large values.
    max_len = (max_int - 1).bit_length()

    # [0, max_int) contains (max_int + 1) // 2 even numbers. The previous bound
    # max_int / 2 was truncated by randint, so for odd max_int the largest
    # even value was never produced.
    halves = np.random.randint(0, (max_int + 1) // 2, batch_sz)
    values = halves * 2

    # Extract bits MSB-first: column k holds bit (max_len - 1 - k).
    shifts = np.arange(max_len - 1, -1, -1)
    bits = (values[:, None] >> shifts) & 1
    return torch.tensor(bits, dtype=torch.float32)

def gen_noise(max_int: int, batch_sz: int):
    """Sample uniform integers in ``[0, max_int)`` as fixed-width bit vectors.

    Each row is the MSB-first binary encoding of one sample, left-padded with
    zeros to ``(max_int - 1).bit_length()`` bits. Without padding the rows had
    different lengths and ``torch.tensor`` raised on most batches.

    Args:
        max_int: exclusive upper bound of the sampled values; must be >= 2.
        batch_sz: number of samples (rows).

    Returns:
        float32 tensor of shape ``(batch_sz, width)`` containing 0.0 / 1.0.

    Raises:
        ValueError: if ``max_int < 2``.
    """
    max_int = int(max_int)
    if max_int < 2:
        raise ValueError(f"max_int must be >= 2, got {max_int}")

    max_len = (max_int - 1).bit_length()
    sampled_ints = np.random.randint(0, max_int, batch_sz)

    # Extract bits MSB-first: column k holds bit (max_len - 1 - k).
    shifts = np.arange(max_len - 1, -1, -1)
    bits = (sampled_ints[:, None] >> shifts) & 1
    return torch.tensor(bits, dtype=torch.float32)

In [7]:
gen_even_data(max_int=128, batch_sz=16)

tensor([[1., 0., 0., 0., 1., 1., 0.],
        [1., 1., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 1., 0.],
        [0., 1., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 1., 0.],
        [1., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 1., 1., 1., 0.],
        [0., 1., 1., 0., 0., 1., 0.],
        [1., 0., 0., 1., 1., 0., 0.],
        [0., 1., 1., 1., 1., 0., 0.],
        [0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 1., 0., 1., 1., 0.],
        [1., 0., 1., 0., 1., 1., 0.]])

In [8]:


class Generator(nn.Module):
    """MLP mapping a noise vector to a soft bit vector of the same length.

    Two hidden blocks (Linear -> BatchNorm1d -> ReLU with 256 and 512 units)
    are followed by a linear projection back to ``input_length`` features.
    The output is clamped to [0, 1].
    """

    def __init__(self, input_length: int):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(int(input_length), 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, int(input_length)),
        )

    def forward(self, x):
        """Map ``x`` of shape ``(batch, input_length)`` to values in [0, 1]."""
        x = self.main(x)
        # Same values as the former relu(1 - relu(1 - x)). That version built
        # its ones vector from a module-global `input_len`; clamp needs no
        # global state and follows x's device and dtype.
        return torch.clamp(x, 0.0, 1.0)

In [9]:
class Discriminator(nn.Module):
    """MLP scoring a bit vector with a sigmoid output.

    Layers, in order: Linear(input_length, 512) -> BatchNorm1d -> LeakyReLU(0.1)
    -> Linear(512, 256) -> BatchNorm1d -> LeakyReLU(0.1) -> Linear(256, 1)
    -> Sigmoid.
    """

    def __init__(self, input_length: int):
        super().__init__()
        in_features = int(input_length)
        layers = [
            nn.Linear(in_features, 512), nn.BatchNorm1d(512), nn.LeakyReLU(0.1),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.LeakyReLU(0.1),
            nn.Linear(256, 1), nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return one sigmoid score in [0, 1] per input row, shape ``(batch, 1)``."""
        return self.main(x)



In [10]:
def float_list_to_int(lst: list[float]):
    """Threshold each value at 0.5 and read the bits as an MSB-first binary integer.

    Values strictly greater than 0.5 become 1; everything else (including 0.5)
    becomes 0.

    Raises:
        ValueError: if ``lst`` is empty, because there are no digits to parse.
    """
    digit_string = "".join("1" if value > 0.5 else "0" for value in lst)
    return int(digit_string, 2)

In [11]:


max_int = 128
# Bit width of every sample: enough bits for all integers in [0, max_int).
# For powers of two this equals int(math.log(max_int, 2)) (7 for 128). It also
# agrees with the data width for other bounds (e.g. 100 -> 7), where the
# log-based expression gave 6 and mismatched the data.
input_len = (max_int - 1).bit_length()
generator = Generator(input_len)
discriminator = Discriminator(input_len)
# beta1 = 0.5 instead of Adam's default 0.9; the same lr is used for both networks.
generator_optim = torch.optim.Adam(generator.parameters(), lr=0.0004, betas=(0.5, 0.999))
discriminator_optim = torch.optim.Adam(discriminator.parameters(), lr=0.0004, betas=(0.5, 0.999))
# Binary cross-entropy on the discriminator's sigmoid output.
loss = nn.BCELoss()

def train(batch_size: int = 16, training_steps: int = 5000):
    """Run the GAN training loop on the module-level generator and discriminator.

    Each step:
      1. draws Gaussian noise of shape ``(batch_size, input_len)`` and runs the
         generator on it;
      2. updates the discriminator using the BCE on a real batch from
         ``gen_even_data`` (target 1) plus twice the BCE on the detached
         generated batch (target 0);
      3. updates the generator with a BCE loss that pushes the discriminator's
         score on its batch towards 1.

    Every 100 steps (including step 0) it prints each generated sample, decoded
    with ``float_list_to_int``, together with the discriminator's score for it.

    Args:
        batch_size: samples per real/generated batch. Both networks use
            BatchNorm1d in training mode, which needs more than one sample.
        training_steps: number of optimisation steps.
    """
    generator.train()
    discriminator.train()

    # The targets are identical every step, so build them once.
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)

    for i in range(training_steps):
        noise = torch.randn((batch_size, int(input_len)), dtype=torch.float)
        generated_data = generator(noise)

        # Discriminator update. Its gradients must be cleared every step.
        # Otherwise they accumulate across iterations, and the gradients that
        # generator_loss.backward() (below) leaves in the discriminator's
        # parameters would leak into its next update.
        discriminator_optim.zero_grad()
        true_data = gen_even_data(max_int, batch_size)
        true_discriminator_loss = loss(discriminator(true_data), real_labels)
        # detach(): the discriminator loss must not backpropagate into the generator.
        generator_discriminator_out = discriminator(generated_data.detach())
        generator_discriminator_loss = loss(generator_discriminator_out, fake_labels)
        discriminator_loss = true_discriminator_loss + 2 * generator_discriminator_loss
        discriminator_loss.backward()
        discriminator_optim.step()

        # Generator update: score the non-detached batch with the freshly
        # updated discriminator and push it towards the "real" label.
        generator_optim.zero_grad()
        generator_discriminator_out = discriminator(generated_data)
        generator_loss = loss(generator_discriminator_out, real_labels)
        generator_loss.backward()
        generator_optim.step()

        if i % 100 == 0:
            print(f"\n\niter: {i}")
            scores = generator_discriminator_out.detach()
            for sample, score in zip(generated_data.detach(), scores):
                print(f"data: {float_list_to_int(sample)}, discriminator out: {float(score)}")


In [12]:
train(batch_size=16, training_steps=5000)



iter: 0
data: 0, discriminator out: 0.36002740263938904
data: 8, discriminator out: 0.34836745262145996
data: 36, discriminator out: 0.3635422885417938
data: 36, discriminator out: 0.3369635045528412
data: 32, discriminator out: 0.38845038414001465
data: 6, discriminator out: 0.3734356164932251
data: 28, discriminator out: 0.3440278172492981
data: 4, discriminator out: 0.42969340085983276
data: 0, discriminator out: 0.30539435148239136
data: 2, discriminator out: 0.3507276475429535
data: 8, discriminator out: 0.37897780537605286
data: 0, discriminator out: 0.4425602853298187
data: 34, discriminator out: 0.37691086530685425
data: 56, discriminator out: 0.4008733332157135
data: 10, discriminator out: 0.39745157957077026
data: 0, discriminator out: 0.385061651468277


iter: 100
data: 50, discriminator out: 0.43741804361343384
data: 50, discriminator out: 0.43741804361343384
data: 50, discriminator out: 0.43741804361343384
data: 50, discriminator out: 0.43741804361343384
data: 50, discri

In [13]:
print(int(-0.9 > 0.5))  # a negative score thresholds to bit 0

0


In [14]:
# Decode a sample output vector: only entries > 0.5 count as 1 bits,
# so this reads as 0b1000100 = 68.
a = [1.0000, -1.0000, -1.0000, -1.0000, 0.9991, -0.9989, -1.0000]
print(float_list_to_int(a))

68
