# Fire Crack
> Analyzing How Model Improves

In [41]:
from torchvision.models.alexnet import AlexNet
import torch

In [42]:
# Build AlexNet with its default constructor arguments (no weights are loaded here).
model = AlexNet()

In [43]:
# Random input batch of shape (2, 3, 224, 224), values drawn by torch.rand.
samp = torch.rand(2,3,224,224)

In [44]:
# Sanity-check the shape produced by an (untraced) forward pass.
model(samp).shape

torch.Size([2, 1000])

In [45]:
# Reduce the model output to a scalar (its mean) and backprop so parameters get .grad.
model(samp).mean().backward()

In [46]:
from uuid import uuid4

In [47]:
from types import MethodType

# Registry of every moduleTrack created, keyed by id() of the tracked module.
module_tracks = dict()


class moduleTrack(object):
    """Tracing record for a single module.

    Stores the module, its class name and its ``id``. After a traced forward
    call, the tracer attaches ``input_dt`` (dict of argument name -> value)
    and ``output_dt`` (list of outputs), which ``__repr__`` summarises.
    """

    def __init__(self, module):
        self.module = module
        self.name = module.__class__.__name__
        self.id = id(module)
        # Register globally so a track can be looked up by module id.
        module_tracks[self.id] = self

    @staticmethod
    def _describe(value):
        """Return ``str(list(value.shape))``, or the type name if *value* has no shape."""
        shape = getattr(value, "shape", None)
        if shape is None:
            # Non-tensor arguments/outputs (None, tuples, flags...) must not
            # crash the repr, which is printed on every traced call.
            return type(value).__name__
        return str(list(shape))

    def __repr__(self):
        rt = f"<{self.name} @ {hex(self.id)}>"
        if hasattr(self, "input_dt"):
            rt += "\n\t[Inputs]" + ",".join(
                f"{k} {self._describe(v)}" for k, v in self.input_dt.items()
            )
        if hasattr(self, "output_dt"):
            rt += "\n\t[Outputs]" + ",".join(
                self._describe(v) for v in self.output_dt
            )
        return rt
    
# Registry of every dataEdge, keyed by the uuid it was created with.
data_edges = {}


class dataEdge(object):
    """A value flowing between modules.

    ``ups`` holds the tracks that produced it and ``downs`` the tracks that
    consumed it. Creating an edge registers it in ``data_edges``.
    """

    def __init__(self, uuid):
        self.uuid = uuid
        data_edges[uuid] = self
        self.ups = []
        self.downs = []

    def __repr__(self):
        producers = ",".join(map(str, self.ups))
        consumers = ",".join(map(str, self.downs))
        return f"\nEdge:\n\t - Inputs: {producers}\n\t - Outputs: {consumers}"
        
def get_stats(tensor):
    """Summarise *tensor* as a dict of Python floats: mean, std, max and min."""
    reducers = (
        ("mean", tensor.mean),
        ("std", tensor.std),
        ("max", tensor.max),
        ("min", tensor.min),
    )
    return {key: reduce().item() for key, reduce in reducers}

class fireCrack(object):
    """Trace a model's forward pass by wrapping every submodule's ``forward``.

    Each wrapped call records its inputs/outputs on the module's
    ``moduleTrack``, links values between modules through ``dataEdge``
    objects (kept in the module-level ``data_edges`` dict) and prints the
    track. Edges are only built on a module's first traced call.
    """

    def __init__(self, model):
        self.modules = dict()  # id(module) -> moduleTrack for modules this tracer wrapped
        self.uids = dict()  # NOTE(review): not used anywhere in this class
        self.model = model
        self.arm()

    def arm(self):
        """Attach the tracing wrapper to every module of ``self.model``."""
        for m in self.model.modules():
            m.forward = self.module_register(m)

    def disarm(self):
        """Remove the tracing wrapper from every module of ``self.model``."""
        for m in self.model.modules():
            self.recover(m)

    def rearm(self):
        """Unwrap every module, then wrap it again with a fresh moduleTrack."""
        self.disarm()
        self.arm()

    def reg_check(self, m):
        """Return True if *m* still needs to be wrapped.

        False when this tracer already registered *m* or when ``m.forward``
        is already an armed wrapper.
        """
        if id(m) in self.modules:
            return False
        return not getattr(m.forward, "armed", False)

    def recover(self, m):
        """Restore *m*'s original ``forward`` and drop its registration.

        The original bound method is stored on the wrapper (``former``), not
        on the module, so it must be read from ``m.forward``. Unregistering
        lets a later ``arm()`` wrap the module again.
        """
        former = getattr(m.forward, "former", None)
        if former is not None:
            m.forward = former
        self.modules.pop(id(m), None)

    def check_out_tid(self, mt):
        """Create one dataEdge per output of *mt* (first traced call only).

        Each output is tagged with ``tid`` (the edge uuid) so the module that
        consumes it can find the edge in ``check_in_tid``.
        """
        if hasattr(mt, "out_tid"):
            return
        mt.out_tid = list()
        for op in mt.output_dt:
            de = dataEdge(uuid4())
            # Shape belongs to this edge instance, not to the dataEdge class.
            shape = getattr(op, "shape", None)
            de.shape = list(shape) if shape is not None else None
            de.ups.append(mt)
            mt.out_tid.append(de)
            try:
                op.tid = de.uuid
            except AttributeError:
                # Values such as None cannot carry an attribute; they are
                # simply not linked to a downstream module.
                pass

    def check_in_tid(self, mt):
        """Link *mt* to the edges that produced its inputs (first traced call only)."""
        if hasattr(mt, "in_tid"):
            return
        mt.in_tid = dict()
        for k, ip in mt.input_dt.items():
            if hasattr(ip, "tid"):
                de = data_edges[ip.tid]
                mt.in_tid[k] = de
                de.downs.append(mt)

    def module_register(self, m):
        """Return a tracing wrapper around ``m.forward`` and register *m*.

        Returns the current ``m.forward`` unchanged when *m* does not need
        wrapping (see ``reg_check``).
        """
        if not self.reg_check(m):
            return m.forward
        f = m.forward
        mt = moduleTrack(m)
        self.modules[id(m)] = mt
        code = f.__code__
        # co_varnames lists parameters first and then locals; keep only the
        # positional parameters, dropping ``self``.
        mt.vars = code.co_varnames[1:code.co_argcount]

        def wraper(*args, **kwargs):
            # Positional args beyond the named parameters (e.g. *args) get
            # a generic name instead of borrowing a local variable's name.
            names = list(mt.vars) + [f"arg{i}" for i in range(len(mt.vars), len(args))]
            mt.input_dt = dict(zip(names, args))
            mt.input_dt.update(kwargs)

            self.check_in_tid(mt)

            # ------execution of the function------
            outputs = f(*args, **kwargs)
            # ------execution of the function------
            # Track each element of a sequence output as its own output.
            if isinstance(outputs, (list, tuple)):
                mt.output_dt = list(outputs)
            else:
                mt.output_dt = [outputs]

            self.check_out_tid(mt)

            print(mt)
            return outputs

        setattr(wraper, "armed", True)
        setattr(wraper, "former", f)
        return wraper

In [48]:
# Attach the tracer to every module of `model`, then call rearm() (disarm + arm).
fc = fireCrack(model)
fc.rearm()

In [49]:
# Traced forward pass: each wrapped module prints its track (inputs/outputs).
model(samp)

<Conv2d @ 0x16be5b8d0>
	[Inputs]input [2, 3, 224, 224]
	[Outputs][2, 64, 55, 55]
<ReLU @ 0x16bcfc490>
	[Inputs]input [2, 64, 55, 55]
	[Outputs][2, 64, 55, 55]
<MaxPool2d @ 0x16bcfc990>
	[Inputs]input [2, 64, 55, 55]
	[Outputs][2, 64, 27, 27]
<Conv2d @ 0x16bcfcb50>
	[Inputs]input [2, 64, 27, 27]
	[Outputs][2, 192, 27, 27]
<ReLU @ 0x16ba88790>
	[Inputs]input [2, 192, 27, 27]
	[Outputs][2, 192, 27, 27]
<MaxPool2d @ 0x16bcfc450>
	[Inputs]input [2, 192, 27, 27]
	[Outputs][2, 192, 13, 13]
<Conv2d @ 0x16bcfc290>
	[Inputs]input [2, 192, 13, 13]
	[Outputs][2, 384, 13, 13]
<ReLU @ 0x16bcfca10>
	[Inputs]input [2, 384, 13, 13]
	[Outputs][2, 384, 13, 13]
<Conv2d @ 0x16ba55e10>
	[Inputs]input [2, 384, 13, 13]
	[Outputs][2, 256, 13, 13]
<ReLU @ 0x16be5b110>
	[Inputs]input [2, 256, 13, 13]
	[Outputs][2, 256, 13, 13]
<Conv2d @ 0x16be5b790>
	[Inputs]input [2, 256, 13, 13]
	[Outputs][2, 256, 13, 13]
<ReLU @ 0x16be5b610>
	[Inputs]input [2, 256, 13, 13]
	[Outputs][2, 256, 13, 13]
<MaxPool2d @ 0x16be5b890>


tensor([[ 0.0185, -0.0033, -0.0087,  ...,  0.0088, -0.0076,  0.0108],
        [ 0.0164, -0.0021, -0.0101,  ...,  0.0016, -0.0085,  0.0085]],
       grad_fn=<AddmmBackward>)

In [50]:
# Inspect the edges recorded so far (uuid -> dataEdge).
data_edges

{UUID('2828a738-6290-454b-b8c4-cb4e6e5f10e0'): 
 Edge:
 	 - Inputs: <Conv2d @ 0x16be5b8d0>
 	[Inputs]input [2, 3, 224, 224]
 	[Outputs][2, 64, 55, 55]
 	 - Outputs: <ReLU @ 0x16bcfc490>
 	[Inputs]input [2, 64, 55, 55]
 	[Outputs][2, 64, 55, 55], UUID('d4185188-fc37-44cf-9beb-edc14b0992e5'): 
 Edge:
 	 - Inputs: <ReLU @ 0x16bcfc490>
 	[Inputs]input [2, 64, 55, 55]
 	[Outputs][2, 64, 55, 55]
 	 - Outputs: <MaxPool2d @ 0x16bcfc990>
 	[Inputs]input [2, 64, 55, 55]
 	[Outputs][2, 64, 27, 27], UUID('3b65cfbb-4797-4534-8fcc-2c5f5954e710'): 
 Edge:
 	 - Inputs: <MaxPool2d @ 0x16bcfc990>
 	[Inputs]input [2, 64, 55, 55]
 	[Outputs][2, 64, 27, 27]
 	 - Outputs: <Conv2d @ 0x16bcfcb50>
 	[Inputs]input [2, 64, 27, 27]
 	[Outputs][2, 192, 27, 27], UUID('480c612b-3015-408f-a275-eba44caa9452'): 
 Edge:
 	 - Inputs: <Conv2d @ 0x16bcfcb50>
 	[Inputs]input [2, 64, 27, 27]
 	[Outputs][2, 192, 27, 27]
 	 - Outputs: <ReLU @ 0x16ba88790>
 	[Inputs]input [2, 192, 27, 27]
 	[Outputs][2, 192, 27, 27], UUID('efc9

In [9]:
# Scratch experiment: two random (5, 6) tensors combined elementwise.
x1 = torch.rand(5,6)
x2 = torch.rand(5,6)
x3 = x1*6+x2

In [10]:
# torch.Size of x3.
sz = x3.size()

In [11]:
# Standard deviation over all elements of x3.
x3.std()

tensor(1.5808)

In [35]:
import numpy as np
# Convert the torch.Size to a plain Python list.
list(sz)

[5, 6]