## Yong Zhu Cheng A0275768H

In [1]:
import torch
from torch import nn
from torch.nn import ModuleList
from torch.nn import functional as F
from tqdm.auto import tqdm
from torchsummary import summary
import gc

if torch.cuda.is_available():
    device = 'cuda'
else: device = 'cpu'

### Script for model size evaluation
Adapted from pytorch_modelsize by jacobkimmel. Significant modifications needed to accommodate Transformer logic.

In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
    
class SizeEstimator(object):

    def __init__(self, model, input_size=(1,1,32,32), bits=32):
        '''
        Estimates the size of PyTorch models in memory
        for a given input size
        '''
        self.model = model
        self.input_size = input_size
        self.bits = bits
        return

    def get_parameter_sizes(self):
        '''Get sizes of all parameters in `model`'''
        mods = list(self.model.modules())
        sizes = []
        
        for i in range(1,len(mods)):
            m = mods[i]
            p = list(m.parameters())
            for j in range(len(p)):
                sizes.append(np.array(p[j].size()))

        self.param_sizes = sizes
        return

    def get_output_sizes(self):
        '''Run sample input through each layer to get output sizes'''
        input_ = Variable(torch.FloatTensor(*self.input_size).to(device)).to(device)
        target_ = Variable(torch.FloatTensor(*self.input_size).to(device)).to(device)
        mods = list(self.model.named_modules())
        out_sizes = []
        for i in tqdm(range(1, len(mods))):
            name = mods[i][0]
            name_list = name.split('.')
            m = mods[i][1]

            # Special logic for specific layers
            if isinstance(m,ModuleList):
                if name_list[-1] == 'layers':
                    if 'encoder' in name_list:
                        for L in m:
                            out = L(input_)
                            out_sizes.append(np.array(out.size()))
                    elif 'decoder' in name_list:
                        for L in m:
                            memory_ = Variable(torch.FloatTensor(*self.memory_size).to(device)).to(device)
                            out = L(input_,memory_)
                            out_sizes.append(np.array(out.size()))
                elif 'mha1_e' in name_list or 'mha1_d' in name_list:
                    for L in m:
                        out = L(input_,input_,input_)
                        out_sizes.append(np.array(out.size()))
                elif 'mha2_d' in name_list and name_list[-1] == 'heads':
                    memory_ = Variable(torch.FloatTensor(*self.memory_size).to(device)).to(device)
                    for L in m:
                        out = L(input_,memory_,memory_)
                        out_sizes.append(np.array(out.size()))
            elif isinstance(m,nn.Linear) == False and ('mha1_e' in name_list or 'mha1_d' in name_list):
                out = m(input_,input_,input_)
                out_sizes.append(np.array(out.size()))
            elif (isinstance(m,nn.Linear) == False and 'mha2_d' in name_list):
                memory_ = Variable(torch.FloatTensor(*self.memory_size).to(device)).to(device)
                out = m(input_,memory_,memory_)
                out_sizes.append(np.array(out.size()))
            elif name_list[-1]=='decoder' or ('decoder' in name_list and name_list[-2]=='layers'):
                memory_ = Variable(torch.FloatTensor(*self.memory_size).to(device)).to(device)
                out = m(input_,memory_)
                out_sizes.append(np.array(out.size()))
            else:
                out = m(input_)
                out_sizes.append(np.array(out.size()))
            
            # Remembering other special layer sizes
            if name_list[-1] == 'encoder':
                self.memory_size = out.size()
            
            # The inputs for some layers should not be updated, e.g. in a module list
            if name_list[-1] != 'heads' and (name_list[-1].isdigit() is False or name_list[-2] != 'heads') and not (name_list[-1] in ['Wk','Wq','Wv']):
                input_ = out
        
        self.out_sizes = out_sizes
        return

    def calc_param_bits(self):
        '''Calculate total number of bits to store `model` parameters'''
        total_bits = 0
        total_bits = np.int64(total_bits)
        for i in range(len(self.param_sizes)):
            s = self.param_sizes[i]
            bits = np.int64(np.prod(np.array(s))*self.bits)
            total_bits += bits
        self.param_bits = total_bits
        return

    def calc_forward_backward_bits(self):
        '''Calculate bits to store forward and backward pass'''
        total_bits = 0
        total_bits = np.int64(total_bits)
        for i in range(len(self.out_sizes)):
            s = self.out_sizes[i]
            result = 1
            for n in s:
                result *= n
            bits = result*self.bits
            bits = np.int64(bits)
            total_bits += bits
        # multiply by 2 for both forward AND backward
        self.forward_backward_bits = (total_bits*2)
        return

    def calc_input_bits(self):
        '''Calculate bits to store input'''
        self.input_bits = np.prod(np.array(self.input_size))*self.bits
        return

    def estimate_size(self):
        '''Estimate model size in memory in megabytes and bits'''
        self.get_parameter_sizes()
        self.get_output_sizes()
        self.calc_param_bits()
        self.calc_forward_backward_bits()
        self.calc_input_bits()
        total = self.param_bits + self.forward_backward_bits + self.input_bits
        total_megabytes = (total/8)/(1024**2)
        return total_megabytes, total

### Build Model - Transformer
Building blocks:

In [3]:
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self,Q,K,V,mask=None,dropout=None):
        '''
        input shape: B*L*H
        H: hidden layer dim, i.e. 'model_size'
        L: sequence length
        '''
        out = torch.matmul(Q,K.transpose(1,2)) # shape: B*L*L
        out = out / (Q.shape[-1]**0.5)
        out = F.softmax(out,dim=-1)
        return torch.matmul(out,V)
        
class AttentionHead(nn.Module):
    def __init__(self, model_size,qkv_size):
        super().__init__()
        self.Wq = nn.Linear(model_size,qkv_size)
        self.Wk = nn.Linear(model_size,qkv_size)
        self.Wv = nn.Linear(model_size,qkv_size)
        self.attention = Attention()
    
    def forward(self,queries,keys,values):
        return self.attention(self.Wq(queries),
                             self.Wk(keys),
                             self.Wv(values))
    
class MultiHeadAttention(nn.Module):
    def __init__(self,num_heads,model_size,qkv_size):
        super().__init__()
        self.heads = ModuleList(
            [AttentionHead(model_size,qkv_size) for _ in range(num_heads)]
        )
        self.Wo = nn.Linear(num_heads*qkv_size,model_size)
        self.qkv_size = qkv_size
        self.model_size = model_size
    
    def forward(self,query,key,value):
        out_heads = [head(query,key,value) for head in self.heads]
        out = torch.cat(out_heads, dim=-1)
        return self.Wo(out)
    
class FeedForward(nn.Module):
    def __init__(self,model_size,hidden_size=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(model_size,hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size,model_size),
        )
    
    def forward(self,X):
        return self.net(X)

Construct Encoder:

In [4]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self,model_size,num_heads,ff_hidden_size,dropout):
        super().__init__()
        qkv_size = max(model_size//num_heads,1)
        
        # MHA
        self.mha1_e = MultiHeadAttention(num_heads,model_size,qkv_size)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(model_size) # LayerNorm vs BatchNorm
        
        # FF
        self.ff = FeedForward(model_size,ff_hidden_size)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(model_size)
        
    def forward(self,source):
        # MHA (*Self-attention)
        out1 = self.mha1_e(source,source,source)
        out1 = self.dropout1(out1)
        out1 = self.norm1(out1 + source)
        
        # FF
        out2 = self.ff(out1)
        out2 = self.dropout2(out2)
        out2 = self.norm2(out2 + out1)
        return out2

class TransformerEncoder(nn.Module):
    def __init__(self,
                num_layers=6,
                model_size=512,
                num_heads=8,
                ff_hidden_size=2048,
                dropout=0.1):
        super().__init__()
        self.layers = ModuleList(
            [TransformerEncoderLayer(model_size,
                                    num_heads,
                                    ff_hidden_size,
                                    dropout) for _ in range(num_layers)]
        )
    
    def forward(self,source):
        for L in self.layers:
            source = L(source)
        return source 

Construct Decoder:

In [5]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self,model_size,num_heads,ff_hidden_size,dropout):
        super().__init__()
        qkv_size = max(model_size//num_heads,1)
        
        # MHA 1
        self.mha1_d = MultiHeadAttention(num_heads,model_size,qkv_size)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(model_size) 
        
        # MHA 2
        self.mha2_d = MultiHeadAttention(num_heads,model_size,qkv_size)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(model_size) 
        
        # FF
        self.ff = FeedForward(model_size,ff_hidden_size)
        self.dropout3 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(model_size)
    
    def forward(self,target,memory):
        '''
        memory: output of encoder
        '''
        # MHA 1 (Self-attention)
        out1 = self.mha1_d(target,target,target)
        out1 = self.dropout1(out1)
        out1 = self.norm1(out1 + target)
        
        # MHA 2 (Source-target attention)
        out2 = self.mha2_d(out1,memory,memory)
        out2 = self.dropout2(out2)
        out2 = self.norm2(out2 + out1)
        
        # FF
        out3 = self.ff(out2)
        out3 = self.dropout3(out3)
        out3 = self.norm3(out3 + out2)
        return out3

class TransformerDecoder(nn.Module):
    def __init__(self,
                num_layers=6,
                model_size=512,
                num_heads=8,
                ff_hidden_size=2048,
                dropout=0.1):
        super().__init__()
        self.layers = ModuleList(
            [TransformerDecoderLayer(model_size,
                                    num_heads,
                                    ff_hidden_size,
                                    dropout) for _ in range(num_layers)]
        )
    
    def forward(self,target,memory):
        for L in self.layers:
            target = L(target,memory)
        return target

Complete Transformer:

In [6]:
class Transformer(nn.Module):
    def __init__(self,
                num_encoder_layers=6,
                num_decoder_layers=6,
                model_size=512,
                num_heads=8,
                ff_hidden_size=2048,
                dropout=0.1
                ):
        super().__init__()
        self.encoder = TransformerEncoder(
            num_layers=num_encoder_layers,
            model_size=model_size,
            num_heads=num_heads,
            ff_hidden_size=ff_hidden_size,
            dropout=dropout
        )
        self.decoder = TransformerDecoder(
            num_layers=num_decoder_layers,
            model_size=model_size,
            num_heads=num_heads,
            ff_hidden_size=ff_hidden_size,
            dropout=dropout
        )
    
    def forward(self,source,target):
        memory = self.encoder(source)
        return self.decoder(target,memory)

### Generate model summary
Some small modifications required for torchsummary code, primarily to accommodate multiple inputs to the Transformer class.

In [7]:
tf_model = Transformer().to(device)
summary(tf_model,input_size=(200,512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 200, 64]          32,832
            Linear-2              [-1, 200, 64]          32,832
            Linear-3              [-1, 200, 64]          32,832
         Attention-4              [-1, 200, 64]               0
     AttentionHead-5              [-1, 200, 64]               0
            Linear-6              [-1, 200, 64]          32,832
            Linear-7              [-1, 200, 64]          32,832
            Linear-8              [-1, 200, 64]          32,832
         Attention-9              [-1, 200, 64]               0
    AttentionHead-10              [-1, 200, 64]               0
           Linear-11              [-1, 200, 64]          32,832
           Linear-12              [-1, 200, 64]          32,832
           Linear-13              [-1, 200, 64]          32,832
        Attention-14              [-1, 

### Evaluate model size

In [8]:
if 'tf_model' in globals():
    del tf_model
gc.collect()

tf_model = Transformer().to(device)
se = SizeEstimator(tf_model,input_size=(16,200,512))
se.estimate_size()

  0%|          | 0/910 [00:00<?, ?it/s]

(7152.3359375, 59998142464)

**Estimated model size**: 7152.336 MB, 59998142464 bits