In [None]:
import torch
import torch.nn.functional as F

class Net(torch.nn.Module):
    """Toy model holding a single (6, 1) weight that is reshaped into (3, 2) logits.

    Args:
        w: optional initial weights; anything ``torch.tensor`` accepts with
            exactly 6 elements. It is cast to float and reshaped to (6, 1).
            When omitted, weights are drawn from ``torch.rand`` (uniform [0, 1)).
    """

    def __init__(self, w=None):
        super().__init__()
        init = torch.rand(6, 1) if w is None else torch.tensor(w).float().view(6, 1)
        self.w = torch.nn.Parameter(init)

    def forward(self, x):
        """Return ``w @ x`` reshaped to (3, 2).

        ``x`` is expected to be a 1-element vector (e.g. ``torch.tensor([1.])``):
        ``view(3, 2)`` needs the product to contain exactly 6 elements.
        """
        product = torch.matmul(self.w, x)
        return product.view(3, 2)


In [None]:
from plothelper import PlotHelper


# Baseline: supervise only the third output variable (label 1) and watch how
# the unsupervised variables drift. Seed first so Net()'s random init, and
# therefore the whole run, is reproducible under Restart & Run All.
torch.manual_seed(0)

net = Net()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
plot = PlotHelper()

x = torch.tensor([1.])
y = torch.tensor([0, 0, 1])

for _ in range(100):
    opt.zero_grad()
    y_logit = net(x)  # (3, 2)
    # Slicing with [2:] keeps a leading dimension: (1, 2) logits vs (1,) target.
    loss = F.cross_entropy(y_logit[2:], y[2:])
    loss.backward()
    # Probability of class 1 for each variable. Detached so the plotting path
    # is explicitly outside autograd (preferred over the legacy `.data`).
    y_prob = torch.softmax(y_logit.detach(), dim=-1)
    plot.add(y0=y_prob[0, 1], y1=y_prob[1, 1], y2=y_prob[2, 1], loss=loss.detach())
    opt.step()

plot.show()


In [None]:
import torch
import itertools

import sys
sys.path.append("..")

from plothelper import PlotHelper
from pytorch_constraints.constraint import constraint
from pytorch_constraints.brute_force_solver import *
from pytorch_constraints.sampling_solver import *
from pytorch_constraints.tnorm_solver import ProductTNormLogicSolver

# Fresh model/optimiser for the constrained run; seed so the random init is reproducible.
torch.manual_seed(0)
net = Net()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
plot = PlotHelper()       # per-variable probabilities
plot_loss = PlotHelper()  # supervised / constraint / total loss


# Constraint predicate handed to `constraint(...)` below.
# y: 3, Bx3 -- one value per output variable; presumably a length-3 assignment
# (or Bx3 once batching is supported). TODO confirm what the solver passes in.
# NOTE(review): despite the name, the active rule is NOT xor(y[0], y[2]). It
# requires y[0] != y[1] and y[1] != y[2], i.e. an alternating assignment
# ((0,1,0) or (1,0,1) for binary values). The original XOR rule and another
# experiment are kept as comments.
# Only `#` comments are used here (no docstring) in case a solver inspects this
# function's source/AST -- TODO confirm whether that matters.
def xor(y):
    # Alternative (unused): XOR of y[0] and y[2].
    # return (y[0] and not y[2]) or (not y[0] and y[2])
    return y[0]!=y[1] and y[1]!=y[2]
    # Alternative (unused, unreachable after return): some adjacent pair equal.
    # return any([y[i]==y[i-1] for i in range(10)])

# Choose exactly one solver for the constraint; the alternatives are kept for comparison.
# xor_cons = constraint(xor)
num_samples = 100  # only used by the sampling-based solvers below
xor_cons = constraint(xor, ViolationBruteForceSolver())
# xor_cons = constraint(xor, SamplingSolver(num_samples))
# xor_cons = constraint(xor, WeightedSamplingSolver(num_samples))
# xor_cons = constraint(xor, ProductTNormLogicSolver())

# Same supervision as the baseline (only y[2] labelled) plus the constraint
# loss `closs`, which is added to the objective and backpropagated.
# Reuses `x` and `y` from the baseline cell.
# TODO(batching): x (1,) -> (B, 1); y_logit (3, 2) -> (B, 3, 2).
for _ in range(500):
    opt.zero_grad()
    y_logit = net(x)
    oloss = F.cross_entropy(y_logit[2:], y[2:])
    closs = xor_cons(y_logit)
    loss = oloss + closs
    loss.backward()
    # Detached copies for plotting only (preferred over the legacy `.data`).
    y_prob = torch.softmax(y_logit.detach(), dim=-1)
    plot.add(y0=y_prob[0, 1], y1=y_prob[1, 1], y2=y_prob[2, 1])
    plot_loss.add(oloss=oloss.detach(), closs=closs.detach(), loss=loss.detach())
    opt.step()

plot.show()
plot_loss.show()

In [None]:
import torch
import torch.nn.functional as F

class BatchNet(torch.nn.Module):
    """Batched toy model: each row of ``x`` is mapped to its own (3, 2) block.

    Args:
        w: optional initial weights; anything ``torch.tensor`` accepts with
            exactly 6 elements. It is cast to float and reshaped to (6, 1).
            When omitted, weights are drawn from ``torch.rand`` (uniform [0, 1)).
    """

    def __init__(self, w=None):
        super().__init__()
        if w is None:
            weight = torch.rand(6, 1)
        else:
            weight = torch.tensor(w).float().view(6, 1)
        self.w = torch.nn.Parameter(weight)

    def forward(self, x):
        """Return ``x @ w.T`` reshaped to (B, 3, 2).

        ``x`` is expected to be (B, 1); ``w.T`` is (1, 6), so each row yields
        6 values that are split into a (3, 2) block.
        """
        rows = x @ self.w.T
        return rows.view(-1, 3, 2)


In [None]:
from plothelper import PlotHelper


# Batched baseline: two identical examples, only the third variable supervised.
torch.manual_seed(0)  # reproducible BatchNet() init

net = BatchNet()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
plot0 = PlotHelper()  # batch element 0
plot1 = PlotHelper()  # batch element 1

x = torch.tensor([[1.],[1.]])
y = torch.tensor([[0, 0, 1],[0, 0, 1]])

for _ in range(100):
    opt.zero_grad()
    y_logit = net(x)  # (B, 3, 2)
    # Flatten the supervised slice to (B, 2) logits and (B,) targets.
    # reshape (not view) so this also works if the slice is non-contiguous.
    loss = F.cross_entropy(y_logit[:, 2:, :].reshape(-1, 2), y[:, 2:].reshape(-1))
    loss.backward()
    y_prob = torch.softmax(y_logit.detach(), dim=-1)
    # One plot per batch element: P(class 1) per variable plus the shared batch loss.
    for b, plot_b in enumerate((plot0, plot1)):
        plot_b.add(y0=y_prob[b, 0, 1], y1=y_prob[b, 1, 1], y2=y_prob[b, 2, 1], loss=loss.detach())
    opt.step()

plot0.show()
plot1.show()


In [None]:
import sys
sys.path.append("..")

from plothelper import PlotHelper

from pytorch_constraints.constraint import constraint
from pytorch_constraints.brute_force_solver import ViolationBruteForceSolver


torch.manual_seed(0)  # reproducible BatchNet() init

net = BatchNet()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
plot0 = PlotHelper()  # batch element 0
plot1 = PlotHelper()  # batch element 1

# Two identical examples; only the third variable of each is labelled (class 1).
x = torch.tensor([[1.],[1.]])
y = torch.tensor([[0, 0, 1],[0, 0, 1]])

# Batched constraint predicate: exclusive-or of variables 0 and 2 (variable 1
# is unconstrained). Indexing is column-wise, so y is presumably (B, 3) --
# TODO confirm what ViolationBruteForceSolver passes in.
# NOTE(review): this redefines the earlier `xor` with different semantics.
# Python `and`/`not` on a tensor with more than one element raise RuntimeError,
# so this only works if each y[:, i] holds a single element or the solver does
# not evaluate it as plain Python -- verify.
# Comments only (no docstring), in case a solver inspects this function's source.
def xor(y):
    return (y[:,0] and not y[:,2]) or (not y[:,0] and y[:,2])
xor_cons = constraint(xor, ViolationBruteForceSolver())

for _ in range(100):
    opt.zero_grad()
    y_logit = net(x)  # (B, 3, 2)
    # Supervised loss on the third variable only, flattened to (B, 2) / (B,).
    oloss = F.cross_entropy(y_logit[:, 2:, :].reshape(-1, 2), y[:, 2:].reshape(-1))
    # `xor` takes exactly one argument, so the logits are passed once
    # (as in the unbatched constraint cell).
    closs = xor_cons(y_logit)
    loss = oloss + closs
    loss.backward()
    y_prob = torch.softmax(y_logit.detach(), dim=-1)
    # One plot per batch element: P(class 1) per variable plus the shared batch loss.
    for b, plot_b in enumerate((plot0, plot1)):
        plot_b.add(y0=y_prob[b, 0, 1], y1=y_prob[b, 1, 1], y2=y_prob[b, 2, 1], loss=loss.detach())
    opt.step()

plot0.show()
plot1.show()
