In [1]:
import torch
from torch_geometric.data import Data, DataLoader
from cnf import BipartiteData
from loss import LossCompute, AccuracyCompute, LossMetric
import time
import numpy as np
import matplotlib.pyplot as plt
from utils import load_checkpoint

In [2]:
# add a folder raw under dataset manually before execute this chunk
from data import SATDataset
ds = SATDataset('../dataset', 'RND3SAT/uf50-218', False)
last_trn, last_val = int(len(ds)), int(len(ds))
train_ds = ds[: last_trn]
valid_ds = ds[last_trn: last_val]
test_ds = ds[last_val:]

In [4]:
test_data = train_ds[1]
edge_index_pos = test_data.edge_index_pos
edge_index_neg = test_data.edge_index_neg
variable_count = max(max(edge_index_pos[1]), max(edge_index_neg[1])) + 1
clause_count = len(edge_index_pos[1])

# Debug for loss.py

In [5]:
def test_loss(iter_num, par_sm, par_sg, var_num, plot=False):
    loss_func = LossCompute(par_sm, par_sg, metric=LossMetric.linear_loss, debug=True)
    sat_rate = np.zeros(iter_num)
    loss_v = np.zeros(iter_num)
    for i in range(iter_num):
        x_s = LossCompute.push_to_side(torch.rand(var_num, 1), par_sg)
        loss, sm = loss_func(x_s, edge_index_pos, edge_index_neg, clause_count=218, is_train=False)
        satisfied_percentage = sum(sm > 0.5).numpy() / clause_count
        loss_v[i] = loss
        sat_rate[i] = satisfied_percentage
    if plot:
        plt.plot(sat_rate, loss_v, "ro")
        plt.xlabel("Satisfied Clauses / Numer of Clauses")
        plt.ylabel("Loss")
    return sat_rate, loss_v

In [5]:
start = time.time()
sat_rate, loss_v = test_loss(5000, 30, 50, variable_count, plot=True)
span = time.time() - start
print(f"Time for loss1 and loss2 to compute loss of 5000 FG respectively takes {span}s")

TypeError: __call__() missing 1 required positional argument: 'gr_idx_cls'

# Debug for models.py

In [6]:
import models
from args import make_args

In [8]:
class Args():
    def __init__(self):
        self.dataset = 'RND3SAT/uf50-218'
        self.dataset_root = '../dataset'
        self.loss = 'l2'
        self.use_gpu = True
        self.cuda = '0'
        self.graph_valid_ratio = 0.1
        self.graph_test_ratio = 0.1
        self.feature_transform = False
        self.drop_rate = 0.5
        self.speedip = False
        self.load_model = False
        self.batch_size = 1
        self.num_layers = 2
        self.num_encoder_layers = 4
        self.num_decoder_layers = 2
        self.num_meta_paths = 4
        self.encoder_channels = '1,16,32,32,32'
        self.decoder_channels = '32,16,16'
        self.self_att_heads = 8
        self.cross_att_heads = 8
        self.lr = 1e-6
        self.sm_par = 1e-3
        self.sig_par = 1e-3
        self.warmup_steps = 200
        self.opt_train_factor = 4
        self.epoch_num = 201
        self.epoch_log = 50
        self.epoch_save = 50
        self.save_root = 'saved_model'
        self.root = '../dataset'
        self.activation = 'relu'

## sample a test data

In [9]:
args = Args()
device = torch.device('cuda:0') if args.use_gpu and torch.cuda.is_available() else torch.device('cpu')
# device = 'cuda' if torch.cuda.is_available() else 'cpu
# download and save the dataset
dataset = SATDataset(args.root, args.dataset, use_negative=False)
dataset, perm = dataset.shuffle(return_perm=True)
num_clauses = dataset.num_clauses[perm][900:]
test_loader = DataLoader(dataset[900:],
                             batch_size=1,
                             shuffle=True)

## Initialize a model and to make it run through

In [16]:
model = models.make_model(args)
path = '../saved_model/check_point_500.pickle'
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
loss_func = LossCompute(30, 50, metric=LossMetric.linear_loss, debug=True)
accuracy = AccuracyCompute()
for i, batch in enumerate(test_loader):
    xv = model(batch, args)
    num_cls = num_clauses[i * 1 : (i + 1) * 1]
    gr_idx_cls = torch.cat([torch.tensor([i] * num_cls[i]) for i in range(num_cls.size(0))])
    loss, sm = loss_func(xv, batch.edge_index_pos, batch.edge_index_neg, clause_count=218, gr_idx_cls=gr_idx_cls, is_train=False)
    acc = accuracy(xv, batch.edge_index_pos, batch.edge_index_neg).item()
    print("new accuracy: ", flip_search(xv, sm, batch.edge_index_pos, batch.edge_index_neg, acc))
    print(acc)
    print(torch.sum(sm) / 218)

new accuracy:  None
0.0
tensor(0.9260, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9555, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9545, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9608, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9465, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9622, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9674, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9365, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9337, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9608, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9474, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9385, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9751, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9760, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9791, grad_fn=<DivBackward0>)
new accuracy:  None
0.0
tensor(0.9460, grad_fn=<DivBack

In [12]:
from loss import flip_search
x = torch.rand(4)
edge_index_pos = torch.tensor([
    [0, 0, 1, 2, 2],
    [0, 1, 1, 2, 3],
])
edge_index_neg = torch.tensor([
    [0, 1, 1, 2],
    [2, 0, 2, 1],
])
accuracy_compute = AccuracyCompute()
accuracy = accuracy_compute(x, edge_index_pos, edge_index_neg)
sm = LossCompute.get_sm(x, edge_index_pos, edge_index_neg, 10, 10)
print(x)
print("accuracy: ", accuracy)
print(sm)
print("new accuracy: ", flip_search(x, sm, edge_index_pos, edge_index_neg, accuracy))
print(x)

tensor([0.2695, 0.7240, 0.8077, 0.8801])
accuracy:  tensor(1.)
tensor([0.7167, 0.7261, 0.8555])
new accuracy:  None
tensor([0.2695, 0.7240, 0.8077, 0.8801])


vx = model(vx, vc, graph)   # graph is concatenated batch of 32 graphs  
vx.shape = (sum(lit_num_of_each_graph), 1)  
vc = sm = LossCompute(vx, graph)  
batch_index.shape = (sum(cls_num_of_each_graph), 1)  

# After finish model.py, debug train.py