# 00 Torch Ember Core
> Analyzing How Model Improves

In [1]:
# default_exp core

In this tutorial, we'll use AlexNet as example, 
We can load AlexNet from ```torchvision```

By:
* Xiaochen Zhang
* Lai Wei

In [2]:
from torchvision.models.alexnet import AlexNet
import torch

In [3]:
model = AlexNet()

#### Sample data
Create a sample data, something like 2 normalized images in a batch, size 224,224

In [4]:
samp = (torch.rand(2,3,224,224)-1)*2

## Torch Ember Core

The essence of torch ember, is to place trackers within modules.

It will decorate the ```forward``` function to achieve following purpose

* What variables come in/out of the module
* The happening sequence, containing relationships between sub-modules
* The statistics we want for further analysis, eg.
    * Min, Max, Mean, Std, of input / outpout tensors
    * Min, Max, Mean, Std, of model weights at this iteration
    * Min, Max, Mean, Std, of model weights grad at this iteration

In [5]:
# export
from types import MethodType
from datetime import datetime
from torchember.helper import color,emberTracker
from functools import partial
import os

class moduleTrack(object):
    def __init__(self,module, name=None, root_module = False):
        self.module = module
        module.module_tracker = self
        
        self.base_module = True if len(list(module.modules()))==1 else False
        self.root_module = root_module
        
        self.name = name if name else module.__class__.__name__  
        #self.name = f'{name}_tracker' if name else f'{module.__class__.__name__}_tracker'
        self.id = id(module)
        self.children = []
        
    def __repr__(self):
        rt = f"<{self.name} @ {hex(self.id)}>"
        if hasattr(self,"input_dt"):
            rt+=f'\n\t[Inputs]{",".join(list(k+" "+str(list(v.shape)) for k,v in self.input_dt.items()))}'
        if hasattr(self,"output_dt"):
            rt+=f'\n\t[Outputs]{",".join(list(str(list(v.shape)) for v in self.output_dt))}'
        return rt

def get_stats(tensor):
    """
    The default statistic method, it will capture
    shape of the tensor
    mean, std, max, min of the tensor
    this will return a dictionary
    """
    def list_prod(l):
        result=1
        for i in l:
            result*=i
        return result
    return {"shape":list(tensor.shape),
            "mean":tensor.float().mean().item(), 
            "std":tensor.float().std().item(), 
            "max":tensor.float().max().item(), 
            "min":tensor.float().min().item(),
            "cnt_zero": ((tensor>-1e-10) & (tensor < 1e-10)).sum().item(),
            "zero_pct": float(((tensor>-1e-10) & (tensor < 1e-10)).sum().item())/list_prod(tensor.shape)}


    
class torchEmber(object):
    def __init__(self, model, verbose = True):
        color.green|"start analyzing model"
        self.modules = dict()
        self.verbose = verbose
        self.model = model
        
        if hasattr(model,"disarm"):
            model.disarm()
        
        self.model_name = self.model.__class__.__name__
        
        fname = f"{self.model_name}_{self.ts_str}"
        self.fname = fname
        
        self.t = emberTracker(fname)
        self.current_mt = None
        self.mt_log = []
        self.record_extra = False
        
        self.arm()
        
        self.legit_ttypes = ["in","out","weight"]
        for ttype in self.legit_ttypes: self.set_metric(ttype)(get_stats)

        if self.verbose: 
            color.green|f"[INFO][{self.ts_str}]Creating meta data"
        self.t[f"base_{fname}"]={"start":self.t.ts, 
                                 "user":os.environ["USER"]}
        self.t[f"vis_{fname}"] = {"vis_type":"standard"}
        self.t[f"structure_{fname}"] = self.mod_tree()
        
    def mark(self,**kwargs):
        self.t.mark(**kwargs)
        
    def parse_module(self,model, name, root_module = False):
        name = f"{name}({model.__class__.__name__})"
        mt = moduleTrack(model, name, root_module)
        self.modules[name]= mt
        model.forward = self.module_register(name,model)
        
        for cname,children in model.named_children():
            children_mt = self.parse_module(children,f"{name}.{cname}" )
            children_mt.parent = mt
            mt.children.append(children_mt)
        return mt
    
    def mod_tree(self):
        """
        Return the tree of module
        """
        return self.mod_tree_parse(self.model.module_tracker)
        
    def mod_tree_parse(self,mt):
        rt = {"name":mt.name, "short":mt.name.split(".")[-1]}
        if len(mt.children)>0:
            rt.update({"children":list(self.mod_tree_parse(i) for i in mt.children)})
        return rt
                
        
    @property
    def ts_str(self):
        return datetime.now().strftime("%Y%m%d_%H%M%S")
    
    @property
    def ts(self):
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        
    def arm(self):
        """
        arming the tracing function to self.model
        """
        if self.verbose: 
            color.yellow|f"[ARMING][START]{self.ts}"
        self.parse_module(self.model,"model", root_module = True)
        if self.verbose: 
            color.yellow|f"[ARMING][SUCCESS]{self.ts}"
            
    def disarm(self):
        """remove the tracing function"""
        for m in self.modules.values():
            if self.verbose: 
                color.blue|f"[DISARM][{m.name}]{self.ts}"
            self.recover(m)
        color.blue|f"[DISARM][DONE]{self.ts}"
            
    def recover(self, m):
        if hasattr(m.module.forward,"former"):
            m.module.forward = m.module.forward.former
            
    def rearm(self):
        self.disarm()
        self.arm()
    
    def reg_check(self,m):
        """
        register check
        """
        if hasattr(m.forward,"armed"):
            if m.forward.armed:
                return False
        return True
    
    def set_metric(self, ttype):
        assert ttype in self.legit_ttypes, f"ttype has to be one of {str(self.ttypes)}"
        def deco(f):
            setattr(self,f"record_{ttype}_core",self.record_core(f))
            return f
        return deco
    
    def add_record(f):
        def _inner(self, f_name): return partial(f, self, f_name)
        return _inner
    
    @add_record
    def record_core(self, f_name, tensor, extra_data):
        """
        extra_data: dict
        """
        dict_= f_name(tensor)
        dict_.update(extra_data)
        self.t(dict_)
        return dict_
    
    def record_input(self,mt):
        """
        Record the input tensors of the moduleTrack
        """
        for k,tensor in mt.input_dt.items():
            extra_data= {"module":mt.name,"ts":self.t.ts,"ttype":"input","tname":k}
            if self.record_extra: self.add_extra_info(extra_data)
            self.record_in_core(tensor, extra_data)
            
    def record_output(self,mt):
        """
        Record the output tensors of the moduleTrack
        """
        for i in range(len(mt.output_dt)):
            tensor = mt.output_dt[i]
            extra_data = {"module":mt.name,"ts":self.t.ts,"ttype":"output","tname":f"output_{i}"}
            if self.record_extra:self.add_extra_info(extra_data)
            self.record_out_core(tensor,extra_data)
            
    def record_weight(self,mt):
        """
        Record the weights of the moduleTrack
        """
        if mt.base_module:
            i = 0
            for p in mt.module.parameters():
                extra_data={"module":mt.name,"ts":self.t.ts,
                                            "ttype":"weight","tname":f"weight_{i}"}
                if self.record_extra: self.add_extra_info(extra_data)
                self.record_weight_core(p.data, extra_data)
                if p.requires_grad and (p.grad!= None):
                    extra_data={"module":mt.name,"ts":self.t.ts,
                                            "ttype":"weight_grad","tname":f"grad_{i}"}
                    if self.record_extra: self.add_extra_info(extra_data)
                    self.record_weight_core(p.grad, extra_data)
                i+=1
                
    def add_extra(self, **kwargs):
        """
        Record the epoch # and batch #, in order to track the change of parameters over training process.
        After the model is armed, when users put model in training loop, have option to set it up. 
        """
        self.record_extra = True
        self.extra_info={}
        for key, value in kwargs.items():
            self.extra_info.update({f'{key}': value})
        
    def add_extra_info(self,extra_data):
        extra_data.update(self.extra_info)
    
    def after_train(self):
        """
        reset record batch after training
        """
        if self.record_extra: 
            self.record_extra=False
            self.extra_info = None
        
        
    def module_register(self,name,m):
        if self.reg_check(m) == False: return m.forward
        f = m.forward
        mt = self.modules[name]
        vs = f.__code__.co_varnames
        mt.vars = vs[1:]
        if self.verbose: 
            color.cyan | f"[BUILD FORWARD][{name}]{self.ts}"
        def new_forward(*args,**kwargs):
            mt.input_dt = dict(zip(mt.vars[:len(args)],args))
            mt.input_dt.update(kwargs)
            
            self.record_input(mt)
            self.current_mt = mt
            if mt.root_module: self.mt_log=[]
            self.mt_log.append(f"enter {mt.name}")
            
            # ------execution of the function------
            outputs = f(*args,**kwargs)
            self.record_weight(mt)
            # ------execution of the function------
            
            self.mt_log.append(f"exit {mt.name}")
            
            if type(outputs) in [list,tuple]:
                mt.output_dt = [outputs]
            else:
                mt.output_dt = [outputs,]
            self.record_output(mt)
            
            if mt.root_module:
                self.t.refresh() # start a new "latest" file
            
            return outputs
        
        setattr(new_forward,"armed",True)
        setattr(new_forward,"former",f)
        
        def disarm(this):
            """
            Remove the trackers placed by torchember
            run model.disarm()
            """
            self.disarm()
            return this
        setattr(mt.module, "disarm",MethodType(disarm,mt.module))
        return new_forward

### Tracking a model !!

Start tracking a model

In [6]:
te = torchEmber(model)

[92mstart analyzing model[0m
[93m[ARMING][START]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential).0(Conv2d)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential).1(ReLU)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential).2(MaxPool2d)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential).3(Conv2d)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential).4(ReLU)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential).5(MaxPool2d)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential).6(Conv2d)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential).7(ReLU)]2020-03-07 10:37:37[0m
[96m[BUILD FORWARD][model(AlexNet).features(Sequential)

Remove the trackers we placed

In [7]:
model = model.disarm()

[94m[DISARM][model(AlexNet)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).0(Conv2d)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).1(ReLU)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).2(MaxPool2d)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).3(Conv2d)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).4(ReLU)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).5(MaxPool2d)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).6(Conv2d)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).7(ReLU)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).8(Conv2d)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).features(Sequential).9(ReLU)]2020-03-07 10:37:39[0m
[94m[DISARM][model(AlexNet).featur

Or like this

In [8]:
te.disarm()

[94m[DISARM][model(AlexNet)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).0(Conv2d)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).1(ReLU)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).2(MaxPool2d)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).3(Conv2d)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).4(ReLU)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).5(MaxPool2d)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).6(Conv2d)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).7(ReLU)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).8(Conv2d)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).features(Sequential).9(ReLU)]2020-03-07 10:37:40[0m
[94m[DISARM][model(AlexNet).featur

Okay, refresh the tracker

In [9]:
te.rearm()

[94m[DISARM][model(AlexNet)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).0(Conv2d)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).1(ReLU)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).2(MaxPool2d)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).3(Conv2d)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).4(ReLU)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).5(MaxPool2d)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).6(Conv2d)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).7(ReLU)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).8(Conv2d)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).features(Sequential).9(ReLU)]2020-03-07 10:37:41[0m
[94m[DISARM][model(AlexNet).featur

Run forward pass for 3 iterations, nothing strange happend

In [10]:
te.mark(phase="train")
for epoch in range(2):
    te.mark(epoch=epoch)
    for batch in range(3):
        te.add_extra(n_batch=batch)
        model(samp)
te.mark(phase="valid")
for epoch in range(2):
    te.mark(epoch=epoch)
    for batch in range(2):
        te.add_extra(n_batch=batch)
        model(samp)
te.after_train()

In [11]:
!ls -l ~/.torchember/log/AlexNet_20200303_235054

total 2248
-rw-r--r--  1 salvor  staff   59804 Mar  3 23:51 init-00_phase-train_epoch-0.log
-rw-r--r--  1 salvor  staff   59832 Mar  3 23:51 init-00_phase-train_epoch-1.log
-rw-r--r--  1 salvor  staff   39868 Mar  3 23:51 init-00_phase-valid_epoch-0.log
-rw-r--r--  1 salvor  staff  942405 Mar  3 23:52 init-00_phase-valid_epoch-1.log


### Check snowballing tensor stats

In [12]:
te.t.df

Now let's start record weight grad data, once we use backward(), we'll soon have grad data kick in when next forward pass is called

### Track weight gradients

In [13]:
for i in range(3):
    model(samp).mean().backward()

As you can see here, for conv layer, 
* grad_0 is for the 1st weight grad tensor(weight), 
* grad_1 is for the 2nd(bias)

### Module tree json
This file will be stored at ```$HOME/.torchember/data/structure_<modelname>_<date>_<time>.json```

In [None]:
te.mod_tree()

In [None]:
te.mt_log

### Check latest tensor stats

In [None]:
te.t.latest_df

### Redifine what you want to record

For the default statistic function, you can keep track shape, mean, std, max,min of a tensor.

The afore-mentioned tensor could mean all of the following
* module input tensors
* module output tensors
* module weight
* gradient of module weight

If you have more interesting metrics to follow, you can redifine the statistic tracking function

#### Redifine the weight tensor/ weight grad tensor  statitic function

In [None]:
@te.set_metric("weight")
def weight_stats(tensor):
    return {"num":tensor.numel(),"row_max":list(row.max().item() for row in tensor)}

#### Redifine the input or output statitic function

In [None]:
@te.set_metric("in")
def input_stats(tensor):
    return {"num":tensor.numel(),"row_min":list(row.min().item() for row in tensor)}

@te.set_metric("out")
def output_stats(tensor):
    return {"num":tensor.numel(),"row_min":list(row.min().item() for row in tensor)}

Let's give 1 forward pass again

In [None]:
model(samp)

The latest stats changed

In [21]:
te.t.latest_df

Unnamed: 0,num,row_min,module,ts,ttype,tname,row_max
0,301056,"[-1.9999991655349731, -1.9999991655349731]",model(AlexNet),2020-03-05 23:10:22,input,x,
1,301056,"[-1.9999991655349731, -1.9999991655349731]",model(AlexNet).features(Sequential),2020-03-05 23:10:22,input,input,
2,301056,"[-1.9999991655349731, -1.9999991655349731]",model(AlexNet).features(Sequential).0(Conv2d),2020-03-05 23:10:22,input,input,
3,23232,,model(AlexNet).features(Sequential).0(Conv2d),2020-03-05 23:10:22,weight,weight_0,"[0.0520191565155983, 0.05243277549743652, 0.05..."
4,23232,,model(AlexNet).features(Sequential).0(Conv2d),2020-03-05 23:10:22,weight_grad,grad_0,"[3.522511906339787e-05, 1.3242663044366054e-05..."
...,...,...,...,...,...,...,...
75,1000,,model(AlexNet).classifier(Sequential).6(Linear),2020-03-05 23:10:23,weight,weight_1,"[-0.004558052867650986, 0.014331953600049019, ..."
76,1000,,model(AlexNet).classifier(Sequential).6(Linear),2020-03-05 23:10:23,weight_grad,grad_1,"[0.003000000026077032, 0.003000000026077032, 0..."
77,2000,"[-0.03184821456670761, -0.0329386442899704]",model(AlexNet).classifier(Sequential).6(Linear),2020-03-05 23:10:23,output,output_0,
78,2000,"[-0.03184821456670761, -0.0329386442899704]",model(AlexNet).classifier(Sequential),2020-03-05 23:10:23,output,output_0,


## Placing tracker on variables
To be experimented here

In [22]:
w = list(model.features.parameters())[0]

In [23]:
from types import BuiltinMethodType,BuiltinFunctionType

In [24]:
x1 = torch.rand(5,6)
x2 = torch.rand(5,6)
x3 = x1*6+x2

In [25]:
x2.numel()

30

In [26]:
x1.abs_()

tensor([[0.1462, 0.6524, 0.6635, 0.0931, 0.8485, 0.3402],
        [0.6705, 0.0846, 0.6348, 0.3046, 0.7542, 0.6418],
        [0.6934, 0.4078, 0.9792, 0.1871, 0.7833, 0.6145],
        [0.6606, 0.6178, 0.2674, 0.4398, 0.4242, 0.2114],
        [0.9054, 0.9068, 0.6374, 0.8210, 0.7212, 0.4652]])

In [27]:
from types import MethodType

In [28]:
import inspect

In [29]:
def TorchTensorEmber(x):
    class TensorEmber(x.__class__):
        def __init__(self,x):
            self.host_ = x
            attrs = dir(x)
            for attr in attrs:
                self.super_attr(attr)
            
        def super_attr(self,attr):
            if inspect.isbuiltin(getattr(self.host_,attr))==False: return 
            def func(self,*args,**kwargs):
                print(attr)
                return getattr(super(),attr)(*args,**kwargs)
            func.__name__ = attr
            setattr(self,attr, MethodType(func,self))
            return func
            
    return TensorEmber(x)

In [30]:
x2 = TorchTensorEmber(x2)

In [31]:
x2.add(x1)

add


tensor([[1.0196, 1.1548, 1.1521, 0.1822, 1.7265, 0.4464],
        [1.2865, 0.4544, 0.9891, 0.8650, 1.1334, 1.2300],
        [1.3343, 0.8323, 1.9395, 1.1801, 1.5499, 0.7846],
        [1.1385, 1.2144, 0.6191, 0.6455, 0.9545, 0.8413],
        [1.3088, 1.7986, 1.2820, 1.6781, 1.5974, 0.9433]])

In [32]:
x2+x1

tensor([[1.0196, 1.1548, 1.1521, 0.1822, 1.7265, 0.4464],
        [1.2865, 0.4544, 0.9891, 0.8650, 1.1334, 1.2300],
        [1.3343, 0.8323, 1.9395, 1.1801, 1.5499, 0.7846],
        [1.1385, 1.2144, 0.6191, 0.6455, 0.9545, 0.8413],
        [1.3088, 1.7986, 1.2820, 1.6781, 1.5974, 0.9433]])

## Placing tracker on optimizer
To be experimented here