In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
import random
from transformer import Transformer
import os

device = 'cuda'
def rboolf(N, width, deg):
    coefficients = torch.randn(width).to(device)
    coefficients = (coefficients-coefficients.mean())/coefficients.pow(2).sum().sqrt()
    combs = torch.combinations(torch.arange(N+1), r=deg, with_replacement=True)
    combs = combs[torch.randperm(combs.size()[0])][:width].to(device) # Shuffled
    def func(x):
        binary = f"{x:0{N}b}"+"0"
        comps = []
        for elem in combs:
            res = 1
            for e in elem:
                bit = 1 if int(binary[e]) else -1
                res *= bit
            comps.append(res)
        return torch.dot(coefficients, torch.tensor(comps, dtype=torch.float32).to(device))
    return func, (coefficients, combs)

def generate_dataset(num_samples, N, batch_size):
    num_samples = 10000
    inputs = torch.tensor([random.randint(0, 2**N-1) for _ in range(num_samples)]).to(device)
    train_loader = DataLoader(inputs, shuffle=True, batch_size=batch_size)
    return train_loader 

def validate(model, func, N, num_samples=1000):
      model.eval()
      inputs = torch.tensor([random.randint(0, 2**N-1) for _ in range(num_samples)]).to(device)
      targets = torch.FloatTensor([float(func(x)) for x in inputs]).to(device)
      result = model(inputs).to(device)
      loss = (result - targets).pow(2).mean()
      return loss.detach()

def fitNetwork(function, loader, N, epochs, dir_name):
    model = Transformer(N, 120, 1, 1, 128, 1e-5).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0000006, weight_decay=0.1)
    # model.train()
    
    movAvg = 0
    summary = pd.DataFrame(columns=["iter", "loss"])
    # dir_name = f"{args.N}_{args.dim}_{args.l}_{args.h}_{args.f}"

    for epoch in range(epochs):    
        for idx, inputs in enumerate(loader):
          model.train()
        
          targets = torch.FloatTensor([float(function(x)) for x in inputs]).to(device)
    
          result = model(inputs)
          
          loss = (result - targets).pow(2).mean()
          movAvg = 0.99 * movAvg + (1-0.99) * (float(loss.detach()))

          (loss).backward()
          optimizer.step()
          optimizer.zero_grad()
        
          iteration = epoch*len(loader)+idx+1
          if (iteration) % 10 == 0:
            summary.loc[len(summary)] = {"iter":iteration, "loss":movAvg}
            summary.to_csv(f"{dir_name}/curr_func.csv")

          if (iteration) % 100 == 0:
            val_loss = validate(model, function, N, num_samples=10000)
            print(f"Iteration: {iteration}, Loss: {loss.detach():.3f}, Validation Loss: {val_loss:.3f}")
            path = os.path.join(dir_name, f"modelx_{iteration}.pt")
            torch.save(model.state_dict(), path)
          
          if movAvg < 0.01:
            break
    return model, summary

In [2]:
N           = 30
width       = 5
deg         = 2 
num_samples = 10000
batch_size  = 64
epochs      = 5
dir_name    = "/dartfs/rc/lab/C/CybenkoG/bool_sens/test"

dataloader  = generate_dataset(num_samples, N, batch_size)
function, _ = rboolf(N, width, deg)

model, _    = fitNetwork(function, dataloader, N, epochs, dir_name)
loss        = validate(model, function, N, num_samples=10000)

iters = [100*i for i in range(1,8)]
for iter in iters:
    model = Transformer(N, 120, 1, 1, 128, 1e-5).to(device)
    state_dict = torch.load(f"{dir_name}/modelx_{iter}.pt", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    val_loss = validate(model, function, N, num_samples=10000)
    print(f"Iteration: {iter}, Validation Loss: {val_loss:.3f}")



Iteration: 100, Loss: 1.140, Validation Loss: 1.477
Iteration: 100, Loss: 1.140, Validation Loss: 1.453
embeddings.weight
transformer.0.attn.in_proj_weight
transformer.0.attn.out_proj.weight
transformer.0.norm1.weight
transformer.0.norm1.bias
transformer.0.norm2.weight
transformer.0.norm2.bias
transformer.0.linear.0.weight
transformer.0.linear.0.bias
transformer.0.linear.2.weight
transformer.0.linear.2.bias
output_proj.weight


In [4]:
val = torch.tensor([random.randint(0, 2**N-1) for _ in range(5)]).to(device)

print(model(val))
print(model2(val))

tensor([[-0.2711],
        [-0.6370],
        [-0.3238],
        [ 0.0585],
        [-0.5553]], device='cuda:0', grad_fn=<MmBackward0>)
tensor([[-0.2711],
        [-0.6370],
        [-0.3238],
        [ 0.0585],
        [-0.5553]], device='cuda:0', grad_fn=<MmBackward0>)


In [45]:
import torch
device = 'cpu'
def rboolf(N, width, deg):
    coefficients = torch.randn(width).to(device)
    coefficients = (coefficients-coefficients.mean())/coefficients.pow(2).sum().sqrt()
    combs = torch.combinations(torch.arange(N+1), r=deg, with_replacement=True)
    combs = combs[torch.randperm(combs.size()[0])][:width].to(device) # Shuffled
    def func(x):
        binary = f"{x:0{N}b}"+"0"
        comps = []
        for elem in combs:
            res = 1
            for e in elem:
                bit = 1 if int(binary[e]) else -1
                res *= bit
            comps.append(res)
        return torch.dot(coefficients, torch.tensor(comps, dtype=torch.float32).to(device))
    return func, (coefficients, combs)

def reconstructboolf(N, coefficients, combs):
    def func(x):
        binary = f"{x:0{N}b}"+"0"
        comps = []
        for elem in combs:
            res = 1
            for e in elem:
                bit = 1 if int(binary[e]) else -1
                res *= bit
            comps.append(res)
        return torch.dot(coefficients, torch.tensor(comps, dtype=torch.float32).to(device))
    return func, (coefficients, combs)

In [None]:
f, (coeffs, combs) = rboolf(30, 5, 3);

torch.save(coeffs, f"test_func_coeffs.pt")
torch.save(combs, f"test_func_combs.pt")

combs = torch.load(f"test_func_combs.pt")
coeffs = torch.load(f"test_func_coeffs.pt")

f2, _ = reconstructboolf(30, coeffs, combs);

for i in range(100):
    print(f"Results: {f(i)}, {f2(i)}")




