In [None]:
import pyomo.environ as pyo
from pyomo.opt import SolverFactory


In [None]:
model = pyo.ConcreteModel()
model.nVars = pyo.Param(initialize=4)
model.N = pyo.RangeSet(model.nVars)
model.x = pyo.Var(model.N, within=pyo.Binary)
model.obj = pyo.Objective(expr=pyo.summation(model.x))
model.cuts = pyo.ConstraintList()
opt = SolverFactory("glpk")
opt.solve(model)

# Iterate, adding a cut to exclude the previously found solution
for i in range(5):
    expr = 0
    for j in model.x:
        if pyo.value(model.x[j]) < 0.5:
            expr += model.x[j]
        else:
            expr += 1 - model.x[j]
    model.cuts.add(expr >= 1)
    results = opt.solve(model)
    print("\n===== iteration", i)
    model.display()

In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import sympy

import sympy2jax

def assert_equal(x, y):
    x_leaves, x_tree = jtu.tree_flatten(x)
    y_leaves, y_tree = jtu.tree_flatten(y)
    assert x_tree == y_tree
    for xi, yi in zip(x_leaves, y_leaves):
        assert type(xi) is type(yi)
        if isinstance(xi, jnp.ndarray):
            assert xi.shape == yi.shape
            assert xi.dtype == yi.dtype
            assert jnp.all(xi == yi)
        else:
            assert xi == yi

def assert_sympy_allclose(x, y):
    assert isinstance(x, sympy.Expr)
    assert isinstance(y, sympy.Expr)
    assert x.func is y.func
    if isinstance(x, sympy.Float):
        assert abs(float(x) - float(y)) < 1e-5
    elif isinstance(x, sympy.Integer):
        assert x == y
    elif isinstance(x, sympy.Rational):
        assert x.numerator == y.numerator  # pyright: ignore
        assert x.denominator == y.denominator  # pyright: ignore
    elif isinstance(x, sympy.Symbol):
        assert x.name == y.name  # pyright: ignore
    else:
        assert len(x.args) == len(y.args)
        for xarg, yarg in zip(x.args, y.args):
            assert_sympy_allclose(xarg, yarg)

x_sym = sympy.symbols("x_sym")
y = 2.1 * x_sym**2
mod = sympy2jax.SymbolicModule(y)
x = jnp.array(1.1)

grad_m = eqx.filter_grad(lambda m, z: m(x_sym=z))(mod, x)
#print(grad_m)
grad_z = eqx.filter_grad(lambda z, m: m(x_sym=z))(x, mod)

true_grad_m = eqx.filter(
    sympy2jax.SymbolicModule(1.21 * x_sym**2), eqx.is_inexact_array
)
true_grad_z = jnp.array(4.2 * x)
print(grad_z)

assert_equal(grad_m, true_grad_m)
assert_equal(grad_z, true_grad_z)

mod2 = eqx.apply_updates(mod, grad_m)
expr = mod2.sympy()

assert_sympy_allclose(expr, 3.31 * x_sym**2)

R = sympy.symbols("R")
C = sympy.symbols("C")
edp = R*C + 2*C**2
mod = sympy2jax.SymbolicModule(edp)
R_i = jnp.array(2.0)
A_i = jnp.array(2.0)
C_i = jnp.array(3.0)
grad_C = eqx.filter_grad(lambda z, y, m: m(C=z, R=y))(C_i, R_i, mod)
print(grad_C)
grad_R = eqx.filter_grad(lambda a, z, y, m: m(A=z, R=a, C=y))(R_i, A_i, C_i, mod)
print(grad_R)

In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import sympy
import hw_symbols
import sympy2jax
from codesign import Codesign
import copy


def rotate_arr(args_arr):
    next_val = args_arr[0]
    for i in range(len(args_arr))[::-1]:    
        tmp = next_val
        next_val = args_arr[(i)%len(args_arr)]
        args_arr[(i)%len(args_arr)] = tmp
    return args_arr

def get_grad(grad_var, args_arr, jmod):
    return eqx.filter_grad(lambda gvar, arr0, arr, a: a(grad_var = gvar,
                                                V_dd=arr0,
                                                f=arr[1],
                                                MemReadL=arr[2],
                                                MemWriteL=arr[3],
                                                MemReadPact=arr[4],
                                                MemWritePact=arr[5],
                                                MemPpass=arr[6],
                                                Reff_And=arr[7],
                                                Reff_Or=arr[8],
                                                Reff_Add=arr[9],
                                                Reff_Sub=arr[10],
                                                Reff_Mult=arr[11],
                                                Reff_FloorDiv=arr[12],
                                                Reff_Mod=arr[13],
                                                Reff_LShift=arr[14],
                                                Reff_RShift=arr[15],
                                                Reff_BitOr=arr[16],
                                                Reff_BitXor=arr[17],
                                                Reff_BitAnd=arr[18],
                                                Reff_Eq=arr[19],
                                                Reff_NotEq=arr[20],
                                                Reff_Lt=arr[21],
                                                Reff_LtE=arr[22],
                                                Reff_Gt=arr[23],
                                                Reff_GtE=arr[24],
                                                Reff_USub=arr[25],
                                                Reff_UAdd=arr[26],
                                                Reff_IsNot=arr[27],
                                                Reff_Not=arr[28],
                                                Reff_Invert=arr[29],
                                                Reff_Regs=arr[30],
                                                Ceff_And=arr[31],
                                                Ceff_Or=arr[32],
                                                Ceff_Add=arr[33],
                                                Ceff_Sub=arr[34],
                                                Ceff_Mult=arr[35],
                                                Ceff_FloorDiv=arr[36],
                                                Ceff_Mod=arr[37],
                                                Ceff_LShift=arr[38],
                                                Ceff_RShift=arr[39],
                                                Ceff_BitOr=arr[40],
                                                Ceff_BitXor=arr[41],
                                                Ceff_BitAnd=arr[42],
                                                Ceff_Eq=arr[43],
                                                Ceff_NotEq=arr[44],
                                                Ceff_Lt=arr[45],
                                                Ceff_LtE=arr[46],
                                                Ceff_Gt=arr[47],
                                                Ceff_GtE=arr[48],
                                                Ceff_USub=arr[49],
                                                Ceff_UAdd=arr[50],
                                                Ceff_IsNot=arr[51],
                                                Ceff_Not=arr[52],
                                                Ceff_Invert=arr[53],
                                                Ceff_Regs=arr[54]))(grad_var, args_arr[0], args_arr, jmod)

mod = Codesign("testme.py", "codesign_log_dir")
edp = 2*hw_symbols.symbol_table["Ceff_Add"] + 4*hw_symbols.symbol_table["Ceff_Regs"] + 3*hw_symbols.symbol_table["V_dd"]
#print(edp)
args_arr = []
starting_vals = []
for name in hw_symbols.symbol_table:
    args_arr.append(jnp.array(mod.tech_params[hw_symbols.symbol_table[name]]))
    starting_vals.append(mod.tech_params[hw_symbols.symbol_table[name]])
original_arr = copy.deepcopy(args_arr)

#print(args_arr)
grad_map = {}
grad_var = sympy.symbols("grad_var")
j = 0
for name in hw_symbols.symbol_table:
    m = {}
    i = 0
    for other_name in hw_symbols.symbol_table:
        m[hw_symbols.symbol_table[other_name]] = list(hw_symbols.symbol_table)[(i+j)%len(hw_symbols.symbol_table)]
        i += 1
    edp_cur = edp.subs({
        hw_symbols.symbol_table[name]: grad_var
    })
    print(edp_cur)
    jmod = sympy2jax.SymbolicModule(edp_cur)
    grad_map[name] = get_grad(args_arr[j], args_arr, jmod)
    print(name, grad_map[name])
    print(args_arr)
    j += 1
    #raise Exception()
raise Exception()

memory needed: 116 bytes
nvm memory needed: 0 bytes
memory needed: 116 bytes
nvm memory needed: 0 bytes
2*Ceff_Add + 4*Ceff_Regs + 3*grad_var
V_dd 3.0
[Array(1.1, dtype=float32, weak_type=True), Array(2.e+09, dtype=float32, weak_type=True), Array(2.e-09, dtype=float32, weak_type=True), Array(2.e-09, dtype=float32, weak_type=True), Array(1.e-07, dtype=float32, weak_type=True), Array(1.e-07, dtype=float32, weak_type=True), Array(8.448e-08, dtype=float32, weak_type=True), Array(2010.4232, dtype=float32, weak_type=True), Array(2010.4232, dtype=float32, weak_type=True), Array(2004.8353, dtype=float32, weak_type=True), Array(2004.8353, dtype=float32, weak_type=True), Array(46.949585, dtype=float32, weak_type=True), Array(46.949585, dtype=float32, weak_type=True), Array(2004.8353, dtype=float32, weak_type=True), Array(1014.97614, dtype=float32, weak_type=True), Array(1014.97614, dtype=float32, weak_type=True), Array(2010.4232, dtype=float32, weak_type=True), Array(2010.4232, dtype=float32, we

Exception: 

In [None]:
import pandas as pd

# For now, only reading from asap7 data
# Will update in the future with more
f = open("tech_node_data/asap7data.txt", "r")
fl = f.readlines()
d = []
name = []
technode = []
for i in range(len(fl)):
    if fl[i].startswith("cell") and len(fl[i+2])>1:
        name.append(fl[i][fl[i].find("(")+1:fl[i].find("_")])
        technode.append(7)
        n = name[-1]
        d.append([float(fl[i+1].split(" ")[1])*1e-6, float(fl[i+5].split(" ")[1])*1e-15, float(fl[i+6].split(" ")[1])])
ind = pd.MultiIndex.from_arrays([technode, name], names=("tech node", "standard cell"))
df = pd.DataFrame(data=d, index=ind, columns=["area", "R_eff", "C_eff"])
df.to_csv("params/std_cell_data.csv")