In [2]:
pip install torchviz

Note: you may need to restart the kernel to use updated packages.


In [3]:
import torch
import torch.nn as nn
from torchviz import make_dot
CHARSET = " abcdefghijklmnopqrstuvwxyz,."

class YourModel(nn.Module):
    """Per-frame MLP + GRU encoder + MLP classifier over ``CHARSET``.

    ``prepare`` maps every input frame to ``H`` features, ``encoder`` (a
    single-layer GRU with ``batch_first=False``) runs over the time axis, and
    ``decode`` turns the GRU output of the last time step into one logit per
    output class.

    Args:
        H: hidden size of the MLP and GRU (the decoder's hidden layer is H//2).
        n_features: size of each input frame (default 80, previously hard-coded).
        n_classes: number of output logits (default ``len(CHARSET)``).
    """

    def __init__(self, H, n_features=80, n_classes=len(CHARSET)):
        super().__init__()
        self.prepare = nn.Sequential(
            nn.Linear(n_features, H),
            nn.BatchNorm1d(H),
            nn.ReLU(),
            nn.Linear(H, H),
            nn.BatchNorm1d(H),
            nn.ReLU()
        )
        self.encoder = nn.GRU(H, H, batch_first=False)
        self.decode = nn.Sequential(
            nn.Linear(H, H//2),
            nn.BatchNorm1d(H//2),
            nn.ReLU(),
            nn.Linear(H//2, n_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of sequences.

        Args:
            x: time-major tensor of shape ``(seq_len, batch, n_features)``,
                matching ``batch_first=False`` on the GRU.

        Returns:
            Tensor of shape ``(batch, n_classes)`` computed from the GRU
            output at the last time step.
        """
        seq_len, batch = x.shape[0], x.shape[1]
        # BatchNorm1d normalises over dim 1 of a 3-D input, which in this
        # time-major layout is the batch axis, not the features. Flatten time
        # and batch into rows so features sit in dim 1, then restore the layout.
        x = self.prepare(x.reshape(seq_len * batch, -1))
        x = x.reshape(seq_len, batch, -1)
        x, _ = self.encoder(x)
        x = x[-1]  # last time step: (batch, H)
        return self.decode(x)

In [8]:
H = 256
model = YourModel(H)

# Time-major dummy batch: (seq_len=1, batch=256, n_features=80).
model_input = torch.randn(1, 256, 80)
output = model(model_input)

# Draw the autograd graph with parameter nodes labelled by name; the PNG
# format is set once on the graph, so render() writes model.png and returns
# that path as the cell's output.
dot = make_dot(output, params=dict(model.named_parameters()))
dot.format = 'png'
dot.render("model")

'model.png'

In [9]:
# The GRU on its own: one time step, batch of one, H input features.
encoder = YourModel(H).encoder
encoder_input = torch.randn(1, 1, H)
encoder_output = encoder(encoder_input)[0]  # drop the final hidden state

# Render the encoder's autograd graph to encoder.png (path is the cell output).
dot = make_dot(encoder_output, params=dict(encoder.named_parameters()))
dot.format = 'png'
dot.render("encoder")

'encoder.png'

In [10]:
H = 256
decoder = YourModel(H).decode

# In YourModel.forward the decoder receives the GRU output of the last time
# step (x[-1]), i.e. a 2-D tensor of shape (batch, H). Use that shape here
# instead of a 3-D tensor whose dim 1 only coincidentally matched the
# BatchNorm1d channel count. Batch size > 1 because BatchNorm1d in training
# mode cannot compute statistics from a single sample.
decoder_input = torch.randn(4, H)
decoder_output = decoder(decoder_input)

# Render the decoder's autograd graph to decoder.png (path is the cell output).
dot = make_dot(decoder_output, params=dict(decoder.named_parameters()))
dot.format = 'png'
dot.render("decoder")

'decoder.png'