In [1]:
# Interactive figure backend: LinkagePlot.update() redraws its figure in place
# (fig.canvas.draw()) while the solver runs.
%matplotlib notebook

In [2]:
import numpy as np
import pandas as pd
import torch, IPython, itertools, string
import random, time, warnings
import matplotlib.pyplot as plt
from matplotlib import animation

In [3]:
# NOTE(review): this silences *every* warning for the whole kernel session,
# e.g. torch.optim's "duplicate parameters" warning and (Python 3.8+)
# SyntaxWarnings about `is` comparisons with string literals. Consider
# narrowing it to the specific category/module that is actually noisy.
warnings.simplefilter('ignore')

In [4]:
class Point(torch.nn.Module):
    """A named point of a linkage.

    Exactly one placement is chosen at construction, checked in this order:

    * ``at=r``          -- free point (``type='free_point'``); ``r`` becomes a
      trainable ``torch.nn.Parameter`` that the solver moves.
    * ``on=p``          -- coincident point (``type='co_point'``); shares the
      very same position tensor as the ``Point`` ``p``.
    * ``anchored_at=r`` -- fixed point (``type='anchored_point'``); ``r`` is a
      plain tensor, not a parameter.
    * nothing           -- null point (``type='null_point'``, ``r=None``).

    Args:
        linkage: the owning linkage (only stored here).
        name: label used by ``repr``.
        at, on, anchored_at: placement, see above. Coordinates are converted
            to float32 tensors (presumably [x, y, z] -- TODO confirm).

    Raises:
        NotImplementedError: ``on`` is a ``Line`` (not supported yet).
        TypeError: ``on`` is neither a ``Point`` nor a ``Line``.
    """

    def __init__(self, linkage, name, at=None, on=None, anchored_at=None):
        super(Point, self).__init__()
        self.linkage = linkage
        self.name = name
        if at is not None:
            self.mode = 'at'
            self.type = 'free_point'
            self.parent = None
            self.r = torch.nn.Parameter(self._to_position(at))
        elif on is not None:
            self.constrain_on(on)
        elif anchored_at is not None:
            self.mode = 'anchored_at'
            self.type = 'anchored_point'
            self.parent = None
            self.r = self._to_position(anchored_at)
        else:
            self.mode = None
            self.type = 'null_point'
            self.parent = None
            self.r = None

    @staticmethod
    def _to_position(value):
        """Return ``value`` as a fresh float32 tensor that owns its storage."""
        # torch.tensor() warns when handed an existing tensor; as_tensor plus
        # detach/clone yields the same values without aliasing caller data.
        return torch.as_tensor(value, dtype=torch.float).detach().clone()

    def __repr__(self):
        if self.mode == 'at':
            return('Point_{}(at={})'.format(self.name, str(self.r.tolist())))
        elif self.mode == 'on':
            return('Point_{}(on={})'.format(self.name, str(self.parent)))
        elif self.mode == 'anchored_at':
            return('Point_{}(anchored_at={})'.format(self.name, str(self.r.tolist())))
        else:
            return('Point_{}(None)'.format(self.name))

    def root(self):
        """Follow ``parent`` links upwards; return the first point without a parent (possibly self)."""
        if self.parent is None:
            return(self)
        return(self.parent.root())

    def constrain_on(self, parent):
        """Make this point coincide with ``parent`` by sharing its position tensor.

        Raises:
            NotImplementedError: ``parent`` is a ``Line``.
            TypeError: ``parent`` is neither a ``Point`` nor a ``Line``.
        """
        # Validate before mutating so a rejected call leaves the point unchanged.
        if type(parent).__name__ == 'Line':
            raise NotImplementedError(
                'Point_{}: constraining a point on a Line is not supported'.format(self.name))
        if not isinstance(parent, Point):
            raise TypeError(
                'Point_{}: cannot constrain on {!r}; expected a Point'.format(self.name, parent))
        self.mode = 'on'
        self.type = 'co_point'
        self.parent = parent
        # nn.Module refuses to overwrite a registered Parameter with a plain
        # tensor (e.g. re-attaching a free point to an anchor), so drop any
        # previous registration of `r` first.
        self._parameters.pop('r', None)
        self.r = parent.r

In [5]:
class Line(torch.nn.Module):
    """A straight segment between two end-points, with an optional length constraint.

    Each end is a new ``Point`` named ``<name>1`` / ``<name>2``: when the
    argument is a ``Point`` the end is placed ``on`` it (sharing its position),
    otherwise the argument is used as coordinates of a free point.

    Args:
        linkage: owning linkage; ``constrain_length`` calls its ``update()``.
        name: line name, also the prefix of the end-point names.
        r1, r2: a ``Point`` or coordinates for the first / second end.
    """

    def __init__(self, linkage, name, r1, r2):
        super(Line, self).__init__()
        self.linkage = linkage
        self.name = name
        self.target_length = None
        self.p1 = self._make_end(r1, '1')
        self.p2 = self._make_end(r2, '2')

    def _make_end(self, r, suffix):
        """Create end-point ``<name><suffix>`` on Point ``r`` or at coordinates ``r``."""
        end_name = '{}{}'.format(self.name, suffix)
        if isinstance(r, Point):
            return Point(self.linkage, end_name, on=r)
        return Point(self.linkage, end_name, at=r)

    def __repr__(self):
        return('Line_{}(p1={}, p2={})'.format(self.name, str(self.p1), str(self.p2)))

    @property
    def r(self):
        """Direction vector ``p2.r - p1.r``."""
        return self.p2.r - self.p1.r

    def length(self):
        """Euclidean length of the line as a differentiable 0-d tensor."""
        return self.r.pow(2).sum().pow(0.5)

    def constrain_length(self, L=None):
        """Set the target length (current length when ``L`` is None) and re-solve.

        Raises:
            ValueError: ``L`` is negative.
        """
        if L is None:
            self.target_length = self.length().item()
        else:
            if L < 0:
                raise ValueError(
                    'Line_{}: target length must be non-negative, got {}'.format(self.name, L))
            self.target_length = L
        self.linkage.update()

    def is_constrained(self):
        """True when a target length has been set."""
        return(self.target_length is not None)

    def soft_abs(self, x, a=10.0):
        """Smooth surrogate for ``|x|``: ``a|x|^3 / (1 + a x^2)``.

        Behaves like ``a|x|^3`` near 0 (zero gradient at 0) and tends to
        ``|x|`` for large ``|x|``; ``a`` sets how quickly it switches.
        """
        return a * x.abs().pow(3) / (1.0 + a * x.pow(2))

    def energy(self):
        """Constraint penalty: ``soft_abs(length - target)``, or 0.0 if unconstrained."""
        E = 0.0
        if self.target_length is not None:
            E += self.soft_abs(self.length() - self.target_length)
        return(E)

In [6]:
class Angle(torch.nn.Module):
    """Angle between two lines, with an optional target-angle constraint.

    ``theta()`` reports the angle in degrees; ``target_theta`` is stored in
    radians and compared against ``theta()`` (converted) in ``energy()``.

    Args:
        linkage: owning linkage; ``constrain_angle`` calls its ``update()``.
        line1, line2: the two lines (objects exposing ``r``, ``p1`` and ``name``).
    """

    def __init__(self, linkage, line1, line2):
        super(Angle, self).__init__()
        self.linkage = linkage
        self.line1 = line1
        self.line2 = line2
        self.target_theta = None  # radians, or None when unconstrained
        self.ccw = False          # measurement mode used by energy()

    def __repr__(self):
        return('Angle({},{})'.format(self.line1.name, self.line2.name))

    def theta(self, ccw=False):
        """Angle between the two lines' direction vectors, in degrees.

        ``line2``'s direction is reversed unless both lines' first end-points
        have the same root point.

        Args:
            ccw: False -> unsigned angle in [0, 180]. True -> angle measured
                counter-clockwise from line1 to line2 in the xy-plane, in [0, 360).
        """
        r1 = self.line1.r
        r2 = self.line2.r
        if self.line1.p1.root() is not self.line2.p1.root():
            r2 = -r2
        u1 = r1 / r1.pow(2).sum().pow(0.5)
        u2 = r2 / r2.pow(2).sum().pow(0.5)
        # Rounding can push the dot product of two unit vectors just past
        # +/-1, where arccos returns NaN.
        cos_theta = (u1 * u2).sum().clamp(-1.0, 1.0)
        theta = torch.arccos(cos_theta)
        if ccw:
            # z-component of u1 x u2: negative means line2 is clockwise of
            # line1, so the counter-clockwise angle is 2*pi - theta. Zero
            # (parallel or anti-parallel lines) keeps theta, i.e. 0 or 180 deg.
            cross_z = u1[0] * u2[1] - u1[1] * u2[0]
            if cross_z < 0:
                theta = 2 * np.pi - theta
        return theta * (180 / np.pi)

    def constrain_angle(self, theta=None, ccw=False):
        """Set the target angle (degrees; current angle if None) and re-solve.

        The target is stored in radians, reduced modulo 360 degrees. ``ccw``
        selects the measurement mode used by ``energy()``.
        """
        self.ccw = ccw
        if theta is None:
            # Freeze the current angle, measured exactly as energy() will.
            self.target_theta = self.theta(ccw).item() * np.pi / 180
        else:
            self.target_theta = (theta % 360) * np.pi / 180
        self.linkage.update()

    def is_constrained(self):
        """True when a target angle has been set."""
        return(self.target_theta is not None)

    def soft_abs(self, x, a=10.0):
        """Smooth surrogate for ``|x|``: ``a|x|^3 / (1 + a x^2)`` (-> |x| for large |x|)."""
        return a * x.abs().pow(3) / (1.0 + a * x.pow(2))

    def energy(self):
        """Constraint penalty: soft |deviation| in radians, wrapped to [-pi, pi)."""
        E = 0.0
        if self.target_theta is not None:
            deviation = self.theta(self.ccw) * (np.pi / 180) - self.target_theta
            # Angles are periodic: 350 deg against a 0 deg target is 10 deg off,
            # not 350, so the solver turns the short way round.
            deviation = torch.remainder(deviation + np.pi, 2 * np.pi) - np.pi
            E += self.soft_abs(deviation)
        return(E)

In [7]:
class Linkage():
    def __init__(self):       
        self.points = torch.nn.ModuleDict({})
        self.anchors = torch.nn.ModuleDict({})
        self.lines = torch.nn.ModuleDict({})
        self.angles = torch.nn.ModuleDict({})
        self.names = {}
        for _type in ['point', 'line']:
            self.names[_type] = []
            letters = string.ascii_letters[-26:]
            if _type is 'line':
                letters = letters.lower()
            for n in range(3):
                for t in itertools.product(letters, repeat=n):
                    self.names[_type].append(''.join(t))
            self.names[_type] = iter(self.names[_type][1:])
        self.names['anchor'] = self.names['point']
        self.plot = LinkagePlot(self)
        self.tolerance = 0.1
    
    def add_point_at(self, r):
        name = next(self.names['point'])
        self.points[name] = Point(self, name, at=r)
        self.plot.update()
        return(str(self.points[name]))
        
    def add_point_on(self, p):
        name = next(self.names['point'])
        self.points[name] = Point(self, name, on=p)
        self.plot.update()
        return(str(self.points[name]))
    
    def add_anchor_at(self, r):
        name = next(self.names['anchor'])
        self.anchors[name] = Point(self, name, anchored_at=r)
        self.plot.update()
        return(str(self.anchors[name]))
    
    def add_line(self, r1, r2):
        name = next(self.names['line'])
        self.lines[name] = Line(self, name, r1, r2)
        self.plot.update()
        return(str(self.lines[name]))
        
    def add_angle(self, line1, line2):
        name = '{}_{}'.format(line1.name, line2.name)
        self.angles[name] = Angle(self, line1, line2)
        self.plot.update()
        return(str(self.angles[name]))
        
    @property
    def N(self):
        N = 0
        N += len(self.points)
        N += len(self.anchors)
        N += 2*len(self.lines)
        return(N)
    
    @property
    def M(self):
        return(len(self.lines))
     
    def parameters(self):
        parameters = []
        for point in self.points.values():
            for param in point.parameters():
                parameters.append(param)
        for anchor in self.anchors.values():
            for param in anchor.parameters():
                parameters.append(param)
        for line in self.lines.values():
            for param in line.parameters():
                parameters.append(param)
        return(parameters)
        
    def xyz(self):
        xyz = torch.zeros((self.N, 3))
        n = 0
        for point in self.points.values():
            xyz[n] += point.r
            n += 1
        for anchor in self.anchors.values():
            xyz[n] += anchor.r
            n += 1
        for line in self.lines.values():
            xyz[n] += line.p1.r
            xyz[n+1] += line.p2.r
            n += 2
        return(xyz)
        
    def energy(self):
        E = 0.0
        for line in self.lines.values():
            E += line.energy()
        for angle in self.angles.values():
            E += angle.energy()
        return(E)
            
    def update(self, max_num_epochs=1000):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
        for epoch in range(max_num_epochs):
            optimizer.zero_grad()
            E = self.energy()
            E.backward()
            optimizer.step()
            self.plot.E_list.append(E.item())
            if E <= self.tolerance:
                break
        #if (E > self.tolerance or E.isnan()):
        #    raise Exception('Could not solve all constraints.')
        self.plot.update()
        time.sleep(0.01)

In [8]:
class LinkagePlot():
    """Matplotlib view of a Linkage: geometry on the left, solver energy on the right.

    Left axes (x/y only): origin in red, points in green, anchors in orange,
    lines in black -- solid when length-constrained, dotted otherwise.
    Right axes: log10 of the energies appended to ``E_list`` by the solver.
    """

    def __init__(self, linkage):
        self.linkage = linkage
        self.origin = torch.tensor([0,0,0])
        self.E_list = [10000.0]  # initial entry (log10 = 4) so the curve is never empty

        # Set up figure and axes
        self.size = 5
        self.lim = 5
        self.fig = plt.figure(figsize=(2*self.size,self.size))
        self.ax1 = self.fig.add_subplot(121, autoscale_on=False,
            xlim=(-self.lim,self.lim),
            ylim=(-self.lim,self.lim))
        self.ax2 = self.fig.add_subplot(122, autoscale_on=False,
            xlim=(0,1),
            ylim=(0,1))
        self.ax1.set_title('Configuration')
        self.ax2.set_title('log10(E)')

        self.ax1.scatter(
            [self.origin[0]], [self.origin[1]],
            marker='o', s=20, c='red', alpha=1, label='origin')

        # Artists keyed by element name; created lazily in update().
        self.points, self.anchors, self.lines = {}, {}, {}

        self.lnE_line, = self.ax2.plot([], [], 'b-', markersize=3, lw=0.5, label='log10(E)')
        self.time_text = self.ax1.text(0.05, 0.7, '', transform=self.ax1.transAxes)

    @staticmethod
    def _xy(point):
        """Return ``[x, y]`` of ``point.r`` as Python floats.

        matplotlib converts its inputs through numpy, which refuses tensors
        that require grad (free points hold an ``nn.Parameter``); plain floats
        avoid that and decouple the artists from the live tensors.
        """
        x, y = point.r.tolist()[:2]
        return [x, y]

    def update(self):
        """Create artists for new elements, move all artists to the current geometry, redraw."""
        for name, point in self.linkage.points.items():
            if name not in self.points:
                self.points[name] = self.ax1.scatter([], [], s=20, c='limegreen',
                    zorder=2, label=name)
            self.points[name].set_offsets([self._xy(point)])

        for name, anchor in self.linkage.anchors.items():
            if name not in self.anchors:
                self.anchors[name] = self.ax1.scatter([], [], s=20, c='orange',
                    zorder=1, label=name)
            self.anchors[name].set_offsets([self._xy(anchor)])

        for name, line in self.linkage.lines.items():
            # Length-constrained lines are drawn solid, unconstrained ones dotted.
            ls, lw = ('-', 1) if line.is_constrained() else (':', 1)
            if name not in self.lines:
                artist, = self.ax1.plot([], [], linestyle=ls, markersize=3, lw=lw,
                    c='black', zorder=0, label=name)
                self.lines[name] = artist
            (x1, y1), (x2, y2) = self._xy(line.p1), self._xy(line.p2)
            self.lines[name].set_data([x1, x2], [y1, y2])
            self.lines[name].set_linestyle(ls)
            self.lines[name].set_linewidth(lw)

        self.lnE_line.set_xdata(torch.arange(0,len(self.E_list)))
        self.lnE_line.set_ydata(torch.log10(torch.tensor(self.E_list)))
        self.ax2.set_xlabel('Epoch')
        self.ax2.set_xlim(0,len(self.E_list))
        self.ax2.set_ylim(-10,10)
        self.time_text.set_text('')
        self.fig.canvas.draw()

In [9]:
# TODO: turn the crank while simultaneously increasing the crank length.
# TODO: consider a Newton solver instead of the SGD loop in Linkage.update.

In [10]:
# Build the mechanism: anchors A and C, free point B, line a (A -> B) and
# line b (C -> A), plus the angle between b and a.
linkage = Linkage()

<IPython.core.display.Javascript object>

In [11]:
# Fixed pivot A.
linkage.add_anchor_at([-1,1,0])

'Point_A(anchored_at=[-1.0, 1.0, 0.0])'

In [12]:
# Free point B; its position is a Parameter moved by the solver.
linkage.add_point_at([0,1.25,0])

'Point_B(at=[0.0, 1.25, 0.0])'

In [13]:
#linkage.add_point_at([3,4,0])

In [14]:
# Fixed pivot C.
linkage.add_anchor_at([3,1,0])

'Point_C(anchored_at=[3.0, 1.0, 0.0])'

In [15]:
# Line a from anchor A to point B (the crank).
linkage.add_line(linkage.anchors['A'],linkage.points['B'])

'Line_a(p1=Point_a1(on=Point_A(anchored_at=[-1.0, 1.0, 0.0])), p2=Point_a2(on=Point_B(at=[0.0, 1.25, 0.0])))'

In [16]:
#linkage.add_line(linkage.points['B'],linkage.points['C'])

In [17]:
#linkage.add_line(linkage.points['C'],linkage.anchors['D'])

In [18]:
#linkage.add_line(linkage.anchors['D'],linkage.anchors['A'])

In [19]:
linkage.add_line(linkage.anchors['C'],linkage.anchors['A']) # line b from anchor C to anchor A; both ends fixed

'Line_b(p1=Point_b1(on=Point_C(anchored_at=[3.0, 1.0, 0.0])), p2=Point_b2(on=Point_A(anchored_at=[-1.0, 1.0, 0.0])))'

In [20]:
# Hold line a at length 2 (this runs the solver).
linkage.lines['a'].constrain_length(2)

In [21]:
#linkage.lines['b'].constrain_length(5)

In [22]:
#linkage.lines['c'].constrain_length(4)

In [23]:
# Angle between line b and line a, stored under the key 'b_a'.
linkage.add_angle(linkage.lines['b'], linkage.lines['a'])

'Angle(b,a)'

In [24]:
# Drive the b->a angle counter-clockwise through two full turns in
# 10-degree steps, pausing 3 s at every half turn (0, 180, ..., 720).
crank = linkage.angles['b_a']
for theta in np.arange(0, 720 + 10, 10):
    crank.constrain_angle(theta, ccw=True)
    if theta % 180 == 0:
        time.sleep(3)