In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [31]:
# Toy regression data: targets follow y = 2x + 2, stored as column vectors.
x = torch.tensor([0, 1, 2, 3], dtype=torch.float32).reshape(-1, 1)
y = torch.tensor([2, 4, 6, 8], dtype=torch.float32).reshape(-1, 1)
# Derive the sample count from the data instead of hard-coding it, so editing the
# literals above cannot desynchronize N (later cells may read N as the sample count).
N = x.shape[0]
assert y.shape == x.shape, "x and y must have the same number of samples"
print(x)
print(y)

tensor([[0.],
        [1.],
        [2.],
        [3.]])
tensor([[2.],
        [4.],
        [6.],
        [8.]])


In [50]:
from itertools import permutations

# Leaf symbols a program may use, and the binary operators that combine them.
variables = "x w b".split()
binary_ops = list("@+*-")

def permutations_binary_op(variables, binary_ops):
	"""Build every ``"<lhs><op><rhs>"`` expression string over ordered pairs.

	Operators form the outer loop, so the result is grouped by operator in the
	order of ``binary_ops``. Within each group, pairs follow
	``itertools.permutations(variables, 2)`` order. An entry is never paired
	with itself, so there is no ``"x*x"``.
	"""
	return [
		f"{lhs}{op}{rhs}"
		for op in binary_ops
		for lhs, rhs in permutations(variables, 2)
	]

# op1 combines the leaf symbols; op2 may additionally reuse op1's result.
op1_combos = permutations_binary_op(variables, binary_ops)
op2_combos = permutations_binary_op(variables + ["op1"], binary_ops)
print(op1_combos)
print(op2_combos)

def gen_programs(op1_combos, op2_combos):
	"""Return every ``(op1, op2)`` pair: op1 varies slowest, op2 fastest.

	This is the full Cartesian product. It therefore also contains pairs whose
	``op2`` never references ``op1``.
	"""
	return [(first, second) for first in op1_combos for second in op2_combos]

programs = gen_programs(op1_combos, op2_combos)
# Every op1 is paired with every op2, so the search space is the full product.
assert len(programs) == len(op1_combos) * len(op2_combos)
print(programs[:10])
len(programs)

['x@w', 'x@b', 'w@x', 'w@b', 'b@x', 'b@w', 'x+w', 'x+b', 'w+x', 'w+b', 'b+x', 'b+w', 'x*w', 'x*b', 'w*x', 'w*b', 'b*x', 'b*w', 'x-w', 'x-b', 'w-x', 'w-b', 'b-x', 'b-w']
['x@w', 'x@b', 'x@op1', 'w@x', 'w@b', 'w@op1', 'b@x', 'b@w', 'b@op1', 'op1@x', 'op1@w', 'op1@b', 'x+w', 'x+b', 'x+op1', 'w+x', 'w+b', 'w+op1', 'b+x', 'b+w', 'b+op1', 'op1+x', 'op1+w', 'op1+b', 'x*w', 'x*b', 'x*op1', 'w*x', 'w*b', 'w*op1', 'b*x', 'b*w', 'b*op1', 'op1*x', 'op1*w', 'op1*b', 'x-w', 'x-b', 'x-op1', 'w-x', 'w-b', 'w-op1', 'b-x', 'b-w', 'b-op1', 'op1-x', 'op1-w', 'op1-b']
[('x@w', 'x@w'), ('x@w', 'x@b'), ('x@w', 'x@op1'), ('x@w', 'w@x'), ('x@w', 'w@b'), ('x@w', 'w@op1'), ('x@w', 'b@x'), ('x@w', 'b@w'), ('x@w', 'b@op1'), ('x@w', 'op1@x')]


1152

In [101]:
class Model(nn.Module):
	"""A two-step candidate program with a learnable weight ``w`` and bias ``b``.

	``op1`` and ``op2`` are Python expression strings, e.g. ``"x*w"`` and
	``"op1+b"``. ``op1`` may reference ``x``, ``w`` and ``b``. ``op2`` may also
	reference ``op1``. The value of ``op2`` is the model's prediction.
	"""

	def __init__(self, op1: str, op2: str):
		super().__init__()
		# Both parameters are (1, 1) float32 tensors initialised to zero.
		self.w = nn.Parameter(torch.zeros((1, 1), dtype=torch.float32))
		self.b = nn.Parameter(torch.zeros((1, 1), dtype=torch.float32))
		self.op1 = op1
		self.op2 = op2

	def forward(self, x, y):
		"""Evaluate the program on ``x`` and return its MSE loss against ``y``.

		Raises:
			ValueError: if the prediction's shape differs from ``y``'s shape.
				Without this check, ``F.mse_loss`` would silently broadcast mismatched
				shapes.
			Errors raised by torch while evaluating the expressions (for example an
			incompatible ``@``) propagate unchanged.
		"""
		# The expression strings are generated by this notebook, not taken from user
		# input. Still, evaluate them with no builtins and only the names the program
		# grammar allows, instead of the whole local/global namespace.
		no_builtins = {"__builtins__": {}}
		env = {"x": x, "w": self.w, "b": self.b}
		env["op1"] = eval(self.op1, no_builtins, env)
		preds = eval(self.op2, no_builtins, env)
		# Validate against the actual targets rather than a global sample count.
		if preds.shape != y.shape:
			raise ValueError(
				f"program ({self.op1!r}, {self.op2!r}) produced shape "
				f"{tuple(preds.shape)}, expected {tuple(y.shape)}"
			)
		# F.mse_loss takes (input, target).
		return F.mse_loss(preds, y)

In [102]:
def train(model: nn.Module, x, y, iterations: int = 50, lr: float = 0.1):
	"""Fit ``model`` with plain SGD and return the lowest loss observed.

	``model(x, y)`` must return a scalar loss tensor. The loss is recorded before
	each optimiser step. The parameters produced by the final step are therefore
	not evaluated.

	Args:
		model: module whose ``forward(x, y)`` returns the loss.
		x, y: inputs and targets, passed straight through to ``model``.
		iterations: maximum number of SGD steps (the notebook's original 50).
		lr: SGD learning rate (the notebook's original 0.1).

	Returns:
		The lowest finite loss seen, as a Python float. Returns ``inf`` if no finite
		loss was observed, e.g. ``iterations == 0`` or the first loss is NaN.

	Raises:
		Whatever ``model(x, y)`` or ``loss.backward()`` raises. Failures are not
		caught here.
	"""
	optim = torch.optim.SGD(model.parameters(), lr)

	best_loss = float("inf")
	for _ in range(iterations):
		loss = model(x, y)

		# A NaN/inf loss means training has diverged. It can never be the best
		# loss, and stepping on it would push NaN/inf into the parameters, so stop.
		if not torch.isfinite(loss):
			break

		best_loss = min(best_loss, loss.item())

		optim.zero_grad()
		loss.backward()
		optim.step()

	return best_loss

In [103]:
from tqdm import tqdm

def program_search(programs, x, y):
	"""Train one ``Model`` per ``(op1, op2)`` program and collect its score.

	Returns a list of ``(score, program)`` tuples in the same order as
	``programs``. ``score`` is the best loss from ``train``. It is ``-1`` when
	training raised any ``Exception``, which is how invalid programs are flagged.
	"""
	INVALID_PROGRAM = -1
	results = []
	for program in tqdm(programs):
		model = Model(*program)
		try:
			score = train(model, x, y)
		except Exception:
			# Best-effort search: any failure just marks this candidate invalid.
			score = INVALID_PROGRAM
		results.append((score, program))
	return results

# Train every candidate program; program_search scores failed ones as -1.
results = program_search(programs, x=x, y=y)

100%|██████████| 1152/1152 [00:10<00:00, 112.32it/s]


In [104]:
# Keep only programs that trained successfully (score is not the -1 sentinel).
valid_results = list(filter(lambda entry: entry[0] != -1, results))

In [111]:
# Rank programs by best training loss, ascending. The stable sort keeps search order on ties.
valid_results.sort(key=lambda result: result[0])
valid_results

[(0.0, ('x@w', 'w+op1')),
 (0.0, ('x@w', 'op1+w')),
 (0.0, ('x@b', 'b+op1')),
 (0.0, ('x@b', 'op1+b')),
 (0.0, ('x*w', 'w+op1')),
 (0.0, ('x*w', 'op1+w')),
 (0.0, ('x*b', 'b+op1')),
 (0.0, ('x*b', 'op1+b')),
 (0.0, ('w*x', 'w+op1')),
 (0.0, ('w*x', 'op1+w')),
 (0.0, ('b*x', 'b+op1')),
 (0.0, ('b*x', 'op1+b')),
 (1.2768488488745788e-09, ('x+w', 'x+op1')),
 (1.2768488488745788e-09, ('x+w', 'op1+x')),
 (1.2768488488745788e-09, ('x+b', 'x+op1')),
 (1.2768488488745788e-09, ('x+b', 'op1+x')),
 (1.2768488488745788e-09, ('w+x', 'x+op1')),
 (1.2768488488745788e-09, ('w+x', 'op1+x')),
 (1.2768488488745788e-09, ('b+x', 'x+op1')),
 (1.2768488488745788e-09, ('b+x', 'op1+x')),
 (1.2768488488745788e-09, ('x-w', 'x+op1')),
 (1.2768488488745788e-09, ('x-w', 'op1+x')),
 (1.2768488488745788e-09, ('x-b', 'x+op1')),
 (1.2768488488745788e-09, ('x-b', 'op1+x')),
 (1.2768488488745788e-09, ('w-x', 'x-op1')),
 (1.2768488488745788e-09, ('b-x', 'x-op1')),
 (0.000676481518894434, ('x@w', 'b+op1')),
 (0.00067648151

In [114]:
# Programs whose training raised an exception, marked with the -1 sentinel.
invalid_results = list(filter(lambda entry: entry[0] == -1, results))
invalid_results

[(-1, ('x@w', 'x@op1')),
 (-1, ('x@w', 'w@x')),
 (-1, ('x@w', 'w@b')),
 (-1, ('x@w', 'w@op1')),
 (-1, ('x@w', 'b@x')),
 (-1, ('x@w', 'b@w')),
 (-1, ('x@w', 'b@op1')),
 (-1, ('x@w', 'op1@x')),
 (-1, ('x@w', 'w+b')),
 (-1, ('x@w', 'b+w')),
 (-1, ('x@w', 'w*b')),
 (-1, ('x@w', 'b*w')),
 (-1, ('x@w', 'w-b')),
 (-1, ('x@w', 'b-w')),
 (-1, ('x@b', 'x@op1')),
 (-1, ('x@b', 'w@x')),
 (-1, ('x@b', 'w@b')),
 (-1, ('x@b', 'w@op1')),
 (-1, ('x@b', 'b@x')),
 (-1, ('x@b', 'b@w')),
 (-1, ('x@b', 'b@op1')),
 (-1, ('x@b', 'op1@x')),
 (-1, ('x@b', 'w+b')),
 (-1, ('x@b', 'b+w')),
 (-1, ('x@b', 'w*b')),
 (-1, ('x@b', 'b*w')),
 (-1, ('x@b', 'w-b')),
 (-1, ('x@b', 'b-w')),
 (-1, ('w@x', 'x@w')),
 (-1, ('w@x', 'x@b')),
 (-1, ('w@x', 'x@op1')),
 (-1, ('w@x', 'w@x')),
 (-1, ('w@x', 'w@b')),
 (-1, ('w@x', 'w@op1')),
 (-1, ('w@x', 'b@x')),
 (-1, ('w@x', 'b@w')),
 (-1, ('w@x', 'b@op1')),
 (-1, ('w@x', 'op1@x')),
 (-1, ('w@x', 'op1@w')),
 (-1, ('w@x', 'op1@b')),
 (-1, ('w@x', 'x+w')),
 (-1, ('w@x', 'x+b')),
 (-1, 