In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
plt.style.use('ggplot')
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the
# old names, so a bare 'seaborn-colorblind' raises OSError on current installs. Use whichever
# name this Matplotlib provides (and skip the palette if neither exists).
for _style_name in ('seaborn-v0_8-colorblind', 'seaborn-colorblind'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break

In [14]:
def draw_edge_pair(x, ax=None):
    """Plot the two segments of a single edge pair.

    Parameters
    ----------
    x : sequence of 8 numbers, or a torch tensor of 8 values
        ``(x1, y1, x2, y2, x3, y3, x4, y4)``: first segment (x1, y1)-(x2, y2),
        second segment (x3, y3)-(x4, y4). A tensor may live on any device.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``), which is
        what the original ``plt.plot`` calls used.
    """
    if isinstance(x, torch.Tensor):
        # Matplotlib converts inputs to NumPy arrays, which fails for CUDA tensors
        # (and tensors requiring grad); hand it plain host floats instead.
        x = x.detach().cpu().tolist()
    if ax is None:
        ax = plt.gca()
    ax.plot([x[0], x[2]], [x[1], x[3]])
    ax.plot([x[4], x[6]], [x[5], x[7]])
    
    
def are_edge_pairs_crossed(p):
    """Flag which segment pairs strictly intersect.

    Each row of ``p`` holds ``(x1, y1, x2, y2, x3, y3, x4, y4)``, describing segment
    S = (x1, y1)-(x2, y2) and segment T = (x3, y3)-(x4, y4). The meeting point is written as
    ``start_s + alpha * (end_s - start_s) == start_t + beta * (end_t - start_t)``, and the
    2x2 system is solved with Cramer's rule. A pair counts as crossed only when both
    ``alpha`` and ``beta`` lie strictly inside (0, 1), so touching at an endpoint does not count.

    Parallel or collinear pairs make the determinant zero. The division then yields
    inf/nan, which fails every comparison, so such pairs are reported as not crossed.

    Returns a float tensor with one 1.0/0.0 entry per row.
    """
    start_s, end_s = p[:, :2], p[:, 2:4]
    start_t, end_t = p[:, 4:6], p[:, 6:]

    dir_s = end_s - start_s      # direction of S
    back_t = start_t - end_t     # direction of T, reversed
    gap = start_s - start_t      # offset from T's start to S's start

    sx, sy = dir_s[:, 0], dir_s[:, 1]
    tx, ty = back_t[:, 0], back_t[:, 1]
    gx, gy = gap[:, 0], gap[:, 1]

    det = sy * tx - sx * ty
    alpha = (ty * gx - tx * gy) / det
    beta = (sx * gy - sy * gx) / det

    alpha_inside = (alpha > 0) & (alpha < 1)
    beta_inside = (beta > 0) & (beta < 1)
    return (alpha_inside & beta_inside).float()



class EdgePairDataset(Dataset):
    """Synthetic dataset of random segment pairs in the unit square, labelled by crossing.

    Each sample is a tensor of 8 coordinates ``(x1, y1, x2, y2, x3, y3, x4, y4)`` drawn
    uniformly from [0, 1). Its label is the value ``are_edge_pairs_crossed`` returns for
    that row: 1.0 if segment (x1, y1)-(x2, y2) strictly crosses segment
    (x3, y3)-(x4, y4), else 0.0.

    Parameters
    ----------
    n : int
        Number of samples to generate.
    generator : torch.Generator, optional
        RNG used for sampling. Pass a seeded generator to get a reproducible dataset.
        Defaults to torch's global RNG, matching the previous behaviour.
    """

    def __init__(self, n=10000, generator=None):
        super().__init__()
        self.n = n
        self.data = torch.rand(n, 8, generator=generator)
        # Labels are computed once, up front, for the whole batch of samples.
        self.label = are_edge_pairs_crossed(self.data)

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        """Return ``(coordinates, label)`` for sample ``i``."""
        return self.data[i], self.label[i]
    
    
    
class CrossingDetector(nn.Module):
    """MLP that predicts the probability that an edge pair crosses.

    The model stacks ``nn.Linear`` layers between consecutive entries of ``layer_dims``.
    Every hidden layer is followed by ``nn.LeakyReLU``, and the last layer by
    ``nn.Sigmoid``, so the output is a probability suitable for ``nn.BCELoss``.

    Parameters
    ----------
    layer_dims : sequence of int, optional
        Feature sizes from input to output. Defaults to ``(8, 96, 256, 96, 1)``:
        8 coordinates in, one probability out.

    Raises
    ------
    ValueError
        If ``layer_dims`` has fewer than two entries.
    """

    def __init__(self, layer_dims=(8, 96, 256, 96, 1)):
        super().__init__()
        self.layer_dims = list(layer_dims)
        if len(self.layer_dims) < 2:
            raise ValueError(
                f'layer_dims needs at least an input and an output size, got {self.layer_dims}'
            )
        n_linear = len(self.layer_dims) - 1
        self.layers = []
        for i, (in_dim, out_dim) in enumerate(zip(self.layer_dims[:-1], self.layer_dims[1:])):
            self.layers.append(nn.Linear(in_dim, out_dim))
            # Hidden layers get LeakyReLU; the final layer gets Sigmoid to emit a probability.
            self.layers.append(nn.LeakyReLU() if i < n_linear - 1 else nn.Sigmoid())
        # Wrapping in nn.Sequential is what registers the parameters (state_dict keys
        # 'main.0.weight', ...). self.layers is a plain list kept only for inspection.
        self.main = nn.Sequential(*self.layers)

    def forward(self, x):
        """Map ``(..., layer_dims[0])`` inputs to ``(..., layer_dims[-1])`` probabilities."""
        return self.main(x)
    
    



In [15]:
# Seed the global RNG so the synthetic training data and the weight init are reproducible.
torch.manual_seed(0)

dataset = EdgePairDataset(n=int(1e6))
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CrossingDetector().to(device)
bce = nn.BCELoss()  # the model already ends in Sigmoid, so it outputs probabilities, not logits
optmizer = optim.SGD(model.parameters(), lr=1)  # name kept as-is: the training cell refers to `optmizer`

In [None]:


N_EPOCHS = 100
PLOT_EVERY = 20  # redraw the loss curve every this many epochs

model.train()
loss_curve = []  # mean training BCE per epoch
for epoch in tqdm(range(N_EPOCHS)):
    # Accumulate on-device so there is no host sync (`.item()`) on every batch.
    epoch_loss_sum = torch.zeros((), device=device)
    n_seen = 0
    for edge_pairs, targets in dataloader:
        edge_pairs, targets = edge_pairs.to(device), targets.to(device)
        pred = model(edge_pairs)
        loss = bce(pred, targets.view(-1, 1))

        optmizer.zero_grad()
        loss.backward()
        optmizer.step()

        # BCELoss averages over the batch by default; weight by batch size so a
        # short final batch contributes correctly to the epoch mean.
        epoch_loss_sum += loss.detach() * len(targets)
        n_seen += len(targets)

    # Record the epoch mean rather than only the last batch's loss, which is noisy.
    loss_curve.append((epoch_loss_sum / n_seen).item())

    if epoch % PLOT_EVERY == PLOT_EVERY - 1:
        fig, ax = plt.subplots()
        ax.plot(loss_curve)
        ax.set(title='Training loss', xlabel='epoch', ylabel='mean BCE')
        plt.show()


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

## Accuracy

In [None]:
# A fresh, independently sampled set of segment pairs. Order does not affect accuracy,
# so there is no need to shuffle.
test_loader = DataLoader(EdgePairDataset(n=int(1e6)), batch_size=1024, shuffle=False)
N_VIS_BATCHES = 0  # set > 0 to plot the first pair of that many batches with its prediction

model.eval()  # inference mode (this model has no dropout/batchnorm, but it is the correct mode)
correct = 0
total = 0
with torch.no_grad():
    for batch_idx, (edge_pairs, targets) in enumerate(tqdm(test_loader)):
        edge_pairs, targets = edge_pairs.to(device), targets.to(device)
        pred = model(edge_pairs)
        # Threshold the sigmoid output at 0.5 to get a hard crossed / not-crossed decision.
        correct += ((pred > 0.5).float() == targets.view(-1, 1)).sum().item()
        total += len(targets)

        if batch_idx < N_VIS_BATCHES:
            draw_edge_pair(edge_pairs[0].cpu())  # matplotlib needs host data
            plt.title(f'{pred[0].item()}/{targets[0].item()}')
            plt.xlim([0, 1])
            plt.ylim([0, 1])
            plt.show()
print(f'{correct}/{total} {correct/total}')

## Test: Optimization on a single pair of crossed edges