In [31]:
from thop import profile
import torch
import torch.nn.utils.prune as prune
import numpy as np

from timedenoiser.models.encdec import ShallowEncDec, EncDecDiagBiRNNSkip

In [37]:
def print_nonzeros(model, verbose=False):
    """Print a one-line summary of non-zero ("alive") vs. zero ("pruned") weights.

    Models pruned with ``torch.nn.utils.prune`` (and not yet finalised with
    ``prune.remove``) keep the *unpruned* values in a ``<name>_orig`` parameter and
    the binary mask in a ``<name>_mask`` buffer. ``named_parameters()`` only yields
    the ``_orig`` tensor, so the matching mask is applied here before counting;
    otherwise a freshly pruned model would be reported as 0% pruned.

    Args:
        model: the ``torch.nn.Module`` to inspect.
        verbose: if True, also print one line per parameter tensor.

    Returns:
        None. The summary is printed to stdout.
    """
    buffers = dict(model.named_buffers())
    nonzero = total = 0
    for name, param in model.named_parameters():
        # Pruning masks are buffers, not parameters, so this rarely triggers; kept so
        # any parameter whose name contains 'mask' is still skipped as before.
        if 'mask' in name:
            continue
        values = param.detach()
        if name.endswith('_orig'):
            mask = buffers.get(name[:-len('_orig')] + '_mask')
            if mask is not None:
                values = values * mask
        nz_count = int((values != 0).sum())
        n_params = values.numel()
        nonzero += nz_count
        total += n_params
        if verbose:
            share = 100 * nz_count / n_params if n_params else 0.0
            print(f"{name:20} | nonzeros = {nz_count:7} / {n_params:7} "
                  f"({share:6.2f}%) | total_pruned = {n_params - nz_count:7} "
                  f"| shape = {tuple(values.shape)}")
    pruned = total - nonzero
    # Guard the ratios: an all-zero or parameter-free model would divide by zero.
    rate = total / nonzero if nonzero else float('inf')
    pruned_pct = 100 * pruned / total if total else 0.0
    print(f"alive: {nonzero}, pruned : {pruned}, "
          f"total: {total}, Compression rate : {rate:10.2f}x "
          f"({pruned_pct:6.2f}% pruned)")

In [38]:
# Profile EncDecDiagBiRNNSkip here so this cell does not rely on `macs`/`params`
# leaking from the later profiling cell (that raises NameError on Restart & Run All).
macs, params = profile(EncDecDiagBiRNNSkip(4, 4), inputs=(torch.randn(1, 4, 100), ))
macs, params

(51107280.0, 620228.0)

In [39]:
# Count MACs/params of the speed/torque model on a random input of shape (1, 4, 100).
# NOTE(review): thop logs "Cannot find rule" for IndRNNCell/IndRNN and treats them as
# zero MACs and zero params, so these totals likely undercount the recurrent part —
# consider passing `custom_ops` for those classes.
speedtorque, inp = EncDecDiagBiRNNSkip(4, 4), torch.randn(1, 4, 100)
macs, params = profile(speedtorque, inputs=(inp,))

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.activation.Tanh'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'timedenoiser.models.indrnn.IndRNNCell'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'timedenoiser.models.indrnn.IndRNN'>. Treat it as zero Macs and zero Params.[00m
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[91m[WARN] Cannot find rule for <class 'timedenoiser.models.encdec.EncDecDiagBiRNNSkip'>. Treat it as zero Macs and zero Params.[00m


In [53]:
# Profile both models once. NOTE(review): the original loop rebuilt and re-profiled the
# same dense models for every `i` without ever applying pruning, so every row is
# identical. thop counts parameters/MACs from layer shapes, not from non-zero weights,
# so pruning would not change these figures anyway (see `print_nonzeros` for that).
# thop also reports IndRNN/IndRNNCell as zero MACs/params, so macs2/params2 likely
# undercount the recurrent part.
denoiser = ShallowEncDec(4, 4)
inp = torch.randn(1, 4, 100)
macs1, params1 = profile(denoiser, inputs=(inp, ))

speedtorque = EncDecDiagBiRNNSkip(4, 4)
inp = torch.randn(1, 4, 100)
macs2, params2 = profile(speedtorque, inputs=(inp, ))

# `i` is only a row label (pruning level in %); the figures do not depend on it.
for i in [60, 70, 80, 90]:
    print(i, params1, macs1, params2, macs2)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[91m[WARN] Cannot find rule for <class 'timedenoiser.models.encdec.ShallowEncDec'>. Treat it as zero Macs and zero Params.[00m
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.activation.Tanh'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'timedenoiser.models.indrnn.IndRNNCell'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'timedenoiser.models.indrnn.IndRNN'>. Treat it as zero Macs and zero Params.[00m
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[91m[WARN] Cannot find rule for <

In [66]:
# Extrapolate each model's figures to several pruning levels by scaling them with
# hard-coded reference ratios. NOTE(review): the reference numbers are not derived in
# this notebook and presumably come from an external pruning study — TODO cite source.
PRUNE_LEVELS = (60, 80, 90, 96.5)                 # pruning levels in % (row labels)
REF_DENSE_SIZE = 4.14                             # reference size, unpruned
REF_PRUNED_SIZES = (1.34, 0.647, 0.298, 0.101)    # reference sizes at PRUNE_LEVELS
REF_DENSE_SCORE = 77.37                           # reference score, unpruned
REF_PRUNED_SCORES = (77.05, 76.96, 76.31, 73.11)  # reference scores at PRUNE_LEVELS


def scale_by_size(value):
    """Return ``value`` scaled by each pruned/dense reference size ratio."""
    # Same evaluation order as `value * size/4.14`, so results are bit-identical.
    return [value * size / REF_DENSE_SIZE for size in REF_PRUNED_SIZES]


def scale_by_score(value):
    """Return ``value`` scaled by each dense/pruned reference score ratio."""
    # Same evaluation order as `value * 77.37/score`, so results are bit-identical.
    return [value * REF_DENSE_SCORE / score for score in REF_PRUNED_SCORES]


print(*PRUNE_LEVELS)
print('denoiser')
print(*scale_by_size(311))     # presumably params (K) from the thop profile
print(*scale_by_size(25.35))   # presumably MACs (M) from the thop profile
print(*scale_by_score(0.92))
print('diagbirnn')
print(*scale_by_size(621))     # presumably params (K) from the thop profile
print(*scale_by_size(51.11))   # presumably MACs (M) from the thop profile
print(*scale_by_score(0.05))
print(*scale_by_score(28.85))

60 80 90 96.5
denoiser
100.66183574879229 48.60314009661836 22.38599033816425 7.587198067632851
8.205072463768117 3.961702898550725 1.8247101449275365 0.6184420289855074
0.9238208955223882 0.9249012474012476 0.9327794522343075 0.9736068937217892
diagbirnn
201.00000000000003 97.05000000000001 44.7 15.150000000000002
16.542850241545896 7.987480676328503 3.6789323671497587 1.2468864734299518
0.050207657365347186 0.050266372141372154 0.05069453544751671 0.05291341813705376
28.969818299805326 29.003696725571732 29.250746953217146 30.53104226508002
