## Author: Terrill Toe
The initial draft to the MLsystems project.

The objective:

Use a graph profiler construct a computational graph. The graph will encapsulate all the operations within an iteration of training a model. The nodes are individual operations and the edges will represent the dependencies of input and output data

Deliverables:

- [ ] Presentation containing the following 
    - [ ] Description of the intended design for the whole project
    - [ ] pseudocode of each component with explanation.
    - [ ] current progress
    - [ ] experimental results obtained up to the midpoint
- [ ] A design document describing the first phase of the project and experimental analysis that needs to be uploaded as a PDF on Canvas.
    - [ ] design of the profiler
    - [ ] pseudocode of each component with explanation.
    - [ ] experimental analysis consisting of deliverables 4(a): Computation and memory profiling statistics and static analysis and 4(b) Peak memory consumption vs mini-batch size bar graph [w/o AC]. In general the experimental analysis on the document for each experiment should include a paragraph that describes each of the experiments, a paragraph that describes observations and a graph that demonstrates the results. When presenting, every experiment can be presented in a single slide using the graph and just enough textual info to understand the setup and results.


In [2]:
# installs
# !pip3 install chardet
import numpy
from copy import deepcopy
from functools import wraps
import os
import logging
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.fx as fx
from torch.fx.experimental.proxy_tensor import make_fx
from torch.distributed._functional_collectives import all_reduce
from torch.nn.parallel import DistributedDataParallel as DDP
from graph_prof import GraphProfiler
from benchmarks import Experiment



ModuleNotFoundError: No module named 'torchbenchmark'

In [2]:
class TTModel(nn.Module):
    """
    The TTModel for testing of graph profilers
    """
    def __init__(self, layers: int, dim: int):
        super().__init__()
        modules = []
        for _ in range(layers):
            modules.extend([nn.Linear(dim, dim), nn.ReLU()])
        self.mod = nn.Sequential(*modules)
        
    def forward(self, x):
        return self.mod(x)
    
def training_step(
        model: torch.nn.Module, optim: torch.optim.Optimizer, batch: torch.Tensor
):
    out: torch.Tensor = model(batch)
    out.sum().backward()
    optim.step()
    optim.zero_grad()
    
def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
    logging.getLogger().setLevel(logging.DEBUG if rank == 0 else logging.CRITICAL)
    if torch.backends.mps.is_available():
        logging.info(f"Torch MPS is available for this MacOS device")
    else:
        raise ValueError(f"Torch MPS is not available for this MacOS")
    if rank is None or world_size is None:
        dist.init_process_group(backend="nccl")
    # else:
    #     # dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    logging.info(f"Number of visisble devices: {torch.cuda.device_count()}")
    # set the device to mps instead of cuda
    mps_device = torch.device("mps")
    # torch.cuda.set_device(rank)
    torch.manual_seed(20)
    batch_size = 100
    layers = 10
    dim = 100
    num_iters = 5
    model = TTModel(layers, dim).to(mps_device)
    batch = torch.randn(batch_size, dim).to(mps_device)
    optim = torch.optim.Adam(
        model.parameters(), lr=0.01, foreach=False, fused=False, capturable=True
    )
    
    for param in model.parameters():
        if param.requires_grad:
            param.register_hook(all_reduce)

In [3]:
# torch.cuda.current_device()
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

tensor([1.], device='mps:0')


In [4]:
# check if the backend is set up on the mac
print(torch.backends.mps.is_available()) #the MacOS is higher than 12.3+
print(torch.backends.mps.is_built()) #MPS is activated
print(torch.backends.mps.is_macos13_or_newer())

True
True
True


In [5]:
# run the worker to check the model can train
rank = dist.get_rank()
world_size = dist.get_world_size()
run_worker(rank, world_size)

ValueError: Default process group has not been initialized, please make sure to call init_process_group.

In [7]:
# create a graph profiler of the model
layers = 10
dim = 100
num_iters = 5
model = TTModel(layers, dim).to(mps_device)
TT_nn_graph = fx.symbolic_trace(model)
print(type(TT_nn_graph))

<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>


In [9]:
profiler_example = GraphProfiler(module=TT_nn_graph)

Node name:  x
Node type:  placeholder
Node target:  x
Input to this node []
Users of this node:  {mod_0: None}
Node name:  mod_0
Node type:  call_module
Node target:  mod.0
Input to this node [x]
Users of this node:  {mod_1: None}
Node name:  mod_1
Node type:  call_module
Node target:  mod.1
Input to this node [mod_0]
Users of this node:  {mod_2: None}
Node name:  mod_2
Node type:  call_module
Node target:  mod.2
Input to this node [mod_1]
Users of this node:  {mod_3: None}
Node name:  mod_3
Node type:  call_module
Node target:  mod.3
Input to this node [mod_2]
Users of this node:  {mod_4: None}
Node name:  mod_4
Node type:  call_module
Node target:  mod.4
Input to this node [mod_3]
Users of this node:  {mod_5: None}
Node name:  mod_5
Node type:  call_module
Node target:  mod.5
Input to this node [mod_4]
Users of this node:  {mod_6: None}
Node name:  mod_6
Node type:  call_module
Node target:  mod.6
Input to this node [mod_5]
Users of this node:  {mod_7: None}
Node name:  mod_7
Node ty

In [10]:
# get the node definitions
profiler_example.module.graph.print_tabular()

opcode       name    target    args       kwargs
-----------  ------  --------  ---------  --------
placeholder  x       x         ()         {}
call_module  mod_0   mod.0     (x,)       {}
call_module  mod_1   mod.1     (mod_0,)   {}
call_module  mod_2   mod.2     (mod_1,)   {}
call_module  mod_3   mod.3     (mod_2,)   {}
call_module  mod_4   mod.4     (mod_3,)   {}
call_module  mod_5   mod.5     (mod_4,)   {}
call_module  mod_6   mod.6     (mod_5,)   {}
call_module  mod_7   mod.7     (mod_6,)   {}
call_module  mod_8   mod.8     (mod_7,)   {}
call_module  mod_9   mod.9     (mod_8,)   {}
call_module  mod_10  mod.10    (mod_9,)   {}
call_module  mod_11  mod.11    (mod_10,)  {}
call_module  mod_12  mod.12    (mod_11,)  {}
call_module  mod_13  mod.13    (mod_12,)  {}
call_module  mod_14  mod.14    (mod_13,)  {}
call_module  mod_15  mod.15    (mod_14,)  {}
call_module  mod_16  mod.16    (mod_15,)  {}
call_module  mod_17  mod.17    (mod_16,)  {}
call_module  mod_18  mod.18    (mod_17,)  {}


In [14]:
for _ in range(10):
    batch = torch.randn(10, dim).to(mps_device)
    profiler_example.run(batch)

In [15]:
profiler_example.runtimes_sec

{x: [0.0003540515899658203,
  0.00014090538024902344,
  1.5020370483398438e-05,
  1.2159347534179688e-05,
  1.0967254638671875e-05,
  1.3113021850585938e-05,
  1.0967254638671875e-05,
  7.152557373046875e-06,
  8.821487426757812e-06,
  9.059906005859375e-06,
  6.9141387939453125e-06],
 mod_0: [0.2624201774597168,
  0.001909017562866211,
  0.0002300739288330078,
  0.00028896331787109375,
  0.00015211105346679688,
  0.00016617774963378906,
  0.00013589859008789062,
  0.000102996826171875,
  0.00013017654418945312,
  9.799003601074219e-05,
  0.00011277198791503906],
 mod_1: [0.041297197341918945,
  0.0005528926849365234,
  9.703636169433594e-05,
  9.918212890625e-05,
  8.797645568847656e-05,
  0.00010180473327636719,
  9.179115295410156e-05,
  4.982948303222656e-05,
  6.29425048828125e-05,
  5.817413330078125e-05,
  5.626678466796875e-05],
 mod_2: [0.00026488304138183594,
  0.0001728534698486328,
  9.608268737792969e-05,
  8.606910705566406e-05,
  8.916854858398438e-05,
  8.70227813720703