Skip to content

Commit

Permalink
Introduce symbol linking for mlir
Browse files Browse the repository at this point in the history
Fixes #84
  • Loading branch information
alexarice committed May 30, 2024
1 parent d49e413 commit 5fd8984
Show file tree
Hide file tree
Showing 16 changed files with 10,872 additions and 12,647 deletions.
3 changes: 3 additions & 0 deletions examples/mlir/sym.mlir
Git LFS file not shown
14 changes: 9 additions & 5 deletions sd-core/src/free_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ impl<T: Language> Expr<T> {
to_remove: &mut IndexSet<T::Var>,
) {
for bind in &self.binds {
bind.value.free_vars(vars);
bind.value.free_vars(vars, to_remove);
}

for value in &self.values {
value.free_vars(vars);
value.free_vars(vars, to_remove);
}

to_remove.extend(
Expand All @@ -43,18 +43,22 @@ impl<T: Language> Expr<T> {
}

impl<T: Language> Value<T> {
pub(crate) fn free_vars(&self, vars: &mut IndexSet<T::Var>) {
pub(crate) fn free_vars(&self, vars: &mut IndexSet<T::Var>, to_remove: &mut IndexSet<T::Var>) {
match self {
Value::Variable(v) => {
vars.insert(v.clone());
}
Value::Thunk(thunk) => {
thunk.free_vars(vars);
}
Value::Op { args, .. } => {
Value::Op { op, args } => {
for arg in args {
arg.free_vars(vars);
arg.free_vars(vars, to_remove);
}
if let Some(s) = op.sym_name() {
to_remove.insert(s.into());
}
vars.extend(op.symbols_used().map(|s| s.into()))
}
}
}
Expand Down
51 changes: 42 additions & 9 deletions sd-core/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{

use derivative::Derivative;
use either::Either::{self, Left, Right};
use itertools::Itertools;
#[cfg(test)]
use serde::Serialize;
use thiserror::Error;
Expand Down Expand Up @@ -56,10 +57,11 @@ pub enum Name<T: Language> {

impl<T: Language> WithType for Name<T> {
fn get_type(&self) -> WireType {
if matches!(self, Name::CF(_)) {
WireType::ControlFlow
} else {
WireType::Data
match self {
Name::CF(_) => WireType::ControlFlow,
Name::Nil => WireType::Data,
Name::FreeVar(v) => v.get_type(),
Name::BoundVar(v) => v.var().get_type(),
}
}
}
Expand Down Expand Up @@ -111,6 +113,8 @@ pub enum ConvertError<T: Language> {
Shadowed(T::Var),
#[error("Fragment did not have output")]
NoOutputError,
#[error("Uninitialised Inports for variables: {0:?}")]
UnitialisedInput(Vec<T::Var>),
}

/// Environments capture the local information needed to build a hypergraph from an AST
Expand Down Expand Up @@ -317,6 +321,10 @@ where
ProcessInput::InPort(_) => vec![Name::Nil],
};

if let Some(symbol) = op.sym_name() {
output_weights.push(Name::FreeVar(symbol.into()))
}

let cf = op.get_cf();

match &cf {
Expand All @@ -329,10 +337,22 @@ where
None => {}
}

let operation_node =
self.fragment
.add_operation(args.len(), output_weights, op.clone());
for (arg, in_port) in args.iter().rev().zip(operation_node.inputs().rev()) {
let symbol: Vec<_> = op.symbols_used().collect();

let len = args.len() + symbol.len();

let operation_node = self.fragment.add_operation(len, output_weights, op.clone());

let mut inputs = operation_node.inputs().rev();
self.inputs.extend(
symbol
.into_iter()
.rev()
.map_into()
.zip(inputs.by_ref())
.map(|(x, y)| (y, x)),
);
for (arg, in_port) in args.iter().rev().zip(inputs) {
self.process_value(arg, ProcessInput::InPort(in_port))?;
}

Expand All @@ -350,14 +370,21 @@ where

match input {
ProcessInput::Variables(inputs) => {
for (input, out_port) in inputs.into_iter().zip(out_ports) {
for (input, out_port) in inputs.into_iter().zip(out_ports.by_ref()) {
let var = input.var();
self.outputs
.insert(var.clone(), out_port)
.is_none()
.then_some(())
.ok_or(ConvertError::Shadowed(var.clone()))?;
}
if let Some(symbol) = op.sym_name() {
self.outputs
.insert(symbol.clone().into(), out_ports.next().unwrap())
.is_none()
.then_some(())
.ok_or(ConvertError::Shadowed(symbol.into()))?;
}
}
ProcessInput::InPort(in_port) => {
self.fragment.link(out_ports.next().unwrap(), in_port)?;
Expand Down Expand Up @@ -444,6 +471,12 @@ where

debug!("Expression processed");

if !env.inputs.is_empty() {
return Err(ConvertError::UnitialisedInput(
env.inputs.into_iter().map(|x| x.1).collect(),
));
}

Ok(env.fragment.build()?)
}
}
Expand Down
18 changes: 17 additions & 1 deletion sd-core/src/language/chil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ use pest_derive::Parser;
use serde::Serialize;

use super::{span_into_str, Fresh, GetVar, OpInfo};
use crate::common::{Empty, Matchable};
use crate::{
common::{Empty, Matchable},
hypergraph::traits::{WireType, WithType},
};

pub struct Chil;

Expand All @@ -23,6 +26,7 @@ impl super::Language for Chil {
type Addr = Addr;
type VarDef = VariableDef;
type BlockAddr = Empty;
type Symbol = Empty;
}

pub type Expr = super::Expr<Chil>;
Expand Down Expand Up @@ -126,6 +130,18 @@ pub struct Variable {
pub addr: Addr,
}

impl WithType for Variable {
fn get_type(&self) -> WireType {
WireType::Data
}
}

impl From<Empty> for Variable {
fn from(value: Empty) -> Self {
match value {}
}
}

impl Matchable for Variable {
fn is_match(&self, query: &str) -> bool {
// If a variable is "foo(id: %0)", then we match "foo(id: %0)", "foo", and "%0".
Expand Down
2 changes: 1 addition & 1 deletion sd-core/src/language/mlir.pest
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ toplevelitem = { operation | attribute_alias_def | type_alias_def }
/// https://mlir.llvm.org/docs/LangRef/#identifiers-and-keywords
bare_id = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_" | "$" | ".")* }
bare_id_list = { bare_id ~ ("," ~ bare_id)* }
symbol_ref_id = @{ "@" ~ (suffix_id | string_literal) ~ ("::" ~ symbol_ref_id)? }
symbol_ref_id = @{ "@" ~ (string_literal | suffix_id) ~ ("::" ~ symbol_ref_id)? }
value_id = @{ "%" ~ suffix_id }
alias_name = { bare_id }
suffix_id = @{ ASCII_DIGIT+ | ((ASCII_ALPHA | id_punct) ~ (ASCII_ALPHANUMERIC | id_punct)*) }
Expand Down
2 changes: 1 addition & 1 deletion sd-core/src/language/mlir/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl Attribute {
self.1.as_ref().and_then(|x| {
let mut chars = x.0.chars();
if let Some('@') = chars.next() {
Some(chars.filter(|c| *c != '@').collect())
Some(chars.filter(|c| *c != '@' && *c != '\"').collect())
} else {
None
}
Expand Down
101 changes: 73 additions & 28 deletions sd-core/src/language/mlir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@ use itertools::Itertools;

pub mod internal;

use pretty::RcDoc;
#[cfg(test)]
use serde::Serialize;

use self::internal::Attribute;
use super::{Fresh, Language, OpInfo, CF};
use crate::common::{Matchable, Unit};
use crate::{
common::{Matchable, Unit},
hypergraph::traits::{WireType, WithType},
prettyprinter::PrettyPrint,
};

pub struct Mlir;

Expand All @@ -21,6 +26,7 @@ impl Language for Mlir {
type Addr = Unit;
type VarDef = Var;
type BlockAddr = BlockAddr;
type Symbol = Symbol;
}

pub type Expr = super::Expr<Mlir>;
Expand Down Expand Up @@ -63,43 +69,62 @@ impl OpInfo<Mlir> for Op {
}
}

fn symbol_use(&self) -> impl Iterator<Item = &str> {
self.symbols.iter().map(|x| x.as_str())
fn symbols_used(&self) -> impl Iterator<Item = Symbol> {
self.symbols.iter().map(|x| Symbol(x.clone()))
}

fn sym_name(&self) -> Option<&str> {
self.sym_name.as_deref()
fn sym_name(&self) -> Option<Symbol> {
self.sym_name.as_ref().map(|x| Symbol(x.clone()))
}
}

#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(test, derive(Serialize))]
pub struct Var {
pub id: String,
pub index: Option<usize>,
pub enum Var {
Var { id: String },
VarIdx { id: String, index: usize },
Symbol(Symbol),
}

impl WithType for Var {
fn get_type(&self) -> WireType {
match self {
Var::Symbol(_) => WireType::SymName,
_ => WireType::Data,
}
}
}

impl From<Symbol> for Var {
fn from(value: Symbol) -> Self {
Var::Symbol(value)
}
}

impl Display for Var {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(idx) = self.index {
write!(f, "{}#{idx}", self.id)
} else {
f.write_str(&self.id)
match self {
Var::Var { id } => f.write_str(id),
Var::VarIdx { id, index } => write!(f, "{id}#{index}"),
Var::Symbol(s) => s.fmt(f),
}
}
}

impl Matchable for Var {
fn is_match(&self, query: &str) -> bool {
self.id == query
match self {
Var::Var { id } => id == query,
Var::VarIdx { id, .. } => id == query,
Var::Symbol(s) => s.is_match(query),
}
}
}

impl Fresh for Var {
fn fresh(number: usize) -> Self {
Var {
Var::Var {
id: format!("?{number}"),
index: None,
}
}
}
Expand All @@ -120,40 +145,60 @@ impl Matchable for BlockAddr {
}
}

#[derive(Clone, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(test, derive(Serialize))]
pub struct Symbol(pub String);

impl Display for Symbol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}

impl Matchable for Symbol {
fn is_match(&self, query: &str) -> bool {
self.0 == query || self.0 == query.chars().filter(|c| c != &'@').collect::<String>()
}
}

impl PrettyPrint for Symbol {
fn to_doc(&self) -> pretty::RcDoc<'_, ()> {
RcDoc::text(&self.0)
}
}

// Conversion from internal AST.

impl From<internal::Value> for Var {
fn from(value: internal::Value) -> Self {
Var {
id: value.id,
index: value.index.map(|idx| idx.0),
if let Some(i) = value.index {
Var::VarIdx {
id: value.id,
index: i.0,
}
} else {
Var::Var { id: value.id }
}
}
}

impl From<internal::TypedArg> for Var {
fn from(arg: internal::TypedArg) -> Self {
Var {
id: arg.id,
index: None,
}
Var::Var { id: arg.id }
}
}

impl From<internal::OpResult> for Vec<Var> {
fn from(op_result: internal::OpResult) -> Vec<Var> {
if let Some(idx) = op_result.index {
(0..idx.0)
.map(|x| Var {
.map(|x| Var::VarIdx {
id: op_result.id.clone(),
index: Some(x),
index: x,
})
.collect()
} else {
vec![Var {
id: op_result.id,
index: None,
}]
vec![Var::Var { id: op_result.id }]
}
}
}
Expand Down
Loading

0 comments on commit 5fd8984

Please sign in to comment.