In [1]:
import os
import matplotlib.pyplot as plt

from spmmsim.utils.simulate.all_layers_sim import all_layers_sim as sim
# from spmmsim.utils.report.report import report as rpt
from spmmsim.utils.config.architecture import architecture as arch
from spmmsim.utils.config.workload import workload 


class spmm_simulator:
    """Notebook-facing wrapper around the spmmsim architecture/workload parsers and simulator.

    Construction records the output locations and then validates and parses the given
    architecture and workload files via ``set_params``. ``simulate`` then runs
    ``all_layers_sim`` with the output locations supplied here.

    Args:
        arch_obj: Path to the architecture config file ('' to provide it later via
            ``set_params``).
        workload_obj: Path to the workload file ('' to provide it later).
        rpt_path: Output location forwarded to the simulator as its second argument
            (presumably the report directory -- confirm against all_layers_sim).
        rst_path: Output location forwarded to the simulator as its first argument
            (presumably the results directory -- confirm against all_layers_sim).
        verbosity: Stored as ``self.verbose``.

    Raises:
        FileNotFoundError: if a non-empty file path does not exist.
    """

    def __init__(self,
                 arch_obj='',
                 workload_obj='',
                 rpt_path='',
                 rst_path='',
                 verbosity=True):
        self.output_path = "./"
        # Honour the caller's flag instead of hard-coding True.
        self.verbose = verbosity
        # Keep output locations on the instance so simulate() does not depend on
        # notebook globals that happen to share these names.
        self.rpt_path = rpt_path
        self.rst_path = rst_path

        # Parser state objects, populated by set_params().
        self.arch = arch()
        self.workload = workload()

        self.architecture_file = ''
        self.workload_file = ''

        self.set_params(arch_filename=arch_obj, workload_filename=workload_obj)

    @staticmethod
    def _require_file(path, kind):
        """Return ``path`` unchanged if it exists, otherwise raise FileNotFoundError.

        Raising (rather than calling ``exit()``) keeps a Jupyter kernel alive and lets
        callers handle the error.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"{kind} file not found: {path}")
        return path

    def set_params(self, arch_filename='', workload_filename=''):
        """Validate, record and parse the architecture and workload files.

        Each path is handled independently. An empty argument keeps the previously
        recorded file, so calling ``set_params()`` with no arguments re-parses whatever
        is already configured. All supplied paths are validated before any state is
        changed, so a bad path leaves the instance untouched.

        Args:
            arch_filename: Architecture config file path, or '' to keep the current one.
            workload_filename: Workload file path, or '' to keep the current one.

        Raises:
            FileNotFoundError: if a non-empty path does not exist.
        """
        # 1. Validate every supplied path before mutating any state.
        new_arch_file = (self._require_file(arch_filename, "Architecture")
                         if arch_filename else self.architecture_file)
        new_workload_file = (self._require_file(workload_filename, "Workload")
                             if workload_filename else self.workload_file)
        self.architecture_file = new_arch_file
        self.workload_file = new_workload_file

        # 2. Parse the architecture (skipped until a file has been provided).
        if self.architecture_file:
            self.arch.read_arch_file(self.architecture_file)

        # 3. Register the workload with the architecture and load its arrays
        #    (skipped until a file has been provided).
        if self.workload_file:
            self.arch.set_workload_file(self.workload_file)
            self.workload.load_arrays(workload=self.workload_file, sparsity_inputs=True)

    def simulate(self):
        """Run ``all_layers_sim`` with the output locations given at construction.

        The simulator instance is kept on ``self.simulator`` for later inspection.
        """
        # Build one simulator instance and run *that* instance; calling ``run`` on the
        # class itself would be an unbound call.
        # NOTE(review): positional order (rst_path, rpt_path) is kept from the original
        # call -- confirm against all_layers_sim's constructor.
        self.simulator = sim(self.rst_path, self.rpt_path)
        self.simulator.run()

    # TODO: report() and plot() are not implemented yet (the report module import is
    # commented out in the imports cell).
	

In [2]:
# Simulation inputs and output locations (relative to the current working directory).
arch_path, workload_path = "./configs/arch/eyeriss.cfg", "./configs/workload/test.csv"
rpt_path, rst_path = "./output/rpt", "./output/rst"

In [3]:
# Build the simulator; construction also parses the architecture and workload files.
s = spmm_simulator(
    arch_obj=arch_path,
    workload_obj=workload_path,
    rpt_path=rpt_path,
    rst_path=rst_path,
    verbosity=True,
)

Load SparseMM from./configs/workload/test.csv


In [4]:
# Run the simulation for the simulator built above.
s.simulate();

In [5]:
# TODO: s.report() -- disabled until spmm_simulator provides a report() method.

In [6]:
# TODO: s.plot() -- disabled until spmm_simulator provides a plot() method.