In [3]:
import numpy as np
import torch
import torch.nn as nn
# Pick the compute device once: use the GPU when CUDA is usable, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [4]:
# Vocabulary: the file is read with np.loadtxt as strings, and a token's
# position in the loaded array is its integer id.
DICTIONARY_PATH = "./dictionary.txt"
DICTIONARY_list = np.loadtxt(DICTIONARY_PATH, dtype=str)

# Reverse lookup: token -> index into DICTIONARY_list.
DICTIONARY_dict = dict(zip(DICTIONARY_list, range(len(DICTIONARY_list))))

# Ids of the special tokens; a KeyError here means the vocabulary file
# does not contain the corresponding marker.
SOS_IDX, EOS_IDX, UNK_IDX = (
    DICTIONARY_dict[token] for token in ("<sos>", "<eos>", "<unk>")
)
print("SOS_IDX:", SOS_IDX, "EOS_IDX:", EOS_IDX, "UNK_IDX:", UNK_IDX)

SOS_IDX: 2 EOS_IDX: 1 UNK_IDX: 3


# Quantizing LSTM without attention

In [9]:
# BiLSTM token model (variant without an attention layer)
class BiLSTM_N_gramModel_WITHOUT_ATTENTION(nn.Module):
    """Bidirectional-LSTM token predictor with no attention layer.

    Token ids are embedded, passed through a bidirectional LSTM
    (``batch_first=True``), and every timestep's concatenated
    forward/backward state is projected to vocabulary-sized logits by two
    stacked linear layers (no activation in between).

    Args:
        vocab_size: number of embedding rows and width of the output logits.
        embedding_dim: embedding width; also the width between the two
            linear layers of ``fc``.
        hidden_size: LSTM hidden size per direction.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()
        # Forward and backward hidden states are concatenated by the LSTM.
        lstm_out_dim = hidden_size * 2

        # The attribute names (embedding / bilstm / fc) are the prefixes of
        # the state_dict keys, so they must stay as they are for checkpoints
        # to keep loading.
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embedding_dim)
        self.bilstm = nn.LSTM(input_size=embedding_dim,
                              hidden_size=hidden_size,
                              batch_first=True,
                              bidirectional=True)
        self.fc = nn.Sequential(
            nn.Linear(lstm_out_dim, embedding_dim),
            nn.Linear(embedding_dim, vocab_size),
        )

    def forward(self, x):
        """Return vocabulary logits for every position of ``x``.

        For ``x`` of shape (batch, seq_len) holding token ids, the result
        has shape (batch, seq_len, vocab_size).
        """
        token_vectors = self.embedding(x)
        lstm_states, _ = self.bilstm(token_vectors)
        return self.fc(lstm_states)
    
# Hyperparameters. load_state_dict below only succeeds when these produce
# the same parameter shapes as the ones stored in the checkpoint.
vocab_size, embedding_dim, hidden_size = len(DICTIONARY_list), 300, 256

# Build the float (unquantized) network and place it on the chosen device.
model_WITHOUT_ATTENTION = BiLSTM_N_gramModel_WITHOUT_ATTENTION(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_size=hidden_size,
)
model_WITHOUT_ATTENTION = model_WITHOUT_ATTENTION.to(DEVICE)

# The checkpoint tensors are mapped to CPU while loading; the weights live
# under the "model_state_dict" key.
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
state_dict = torch.load(
    "./unquantized/LSTM1_checkpoint.pth",
    map_location=torch.device('cpu'),
)
model_WITHOUT_ATTENTION.load_state_dict(state_dict["model_state_dict"])

<All keys matched successfully>

In [10]:
# Post-training dynamic quantization to int8 weights.
#
# Eager-mode dynamic quantization only runs on CPU, but the float model was
# created with `.to(DEVICE)` and may therefore sit on CUDA. Move it to the
# CPU first so this cell works regardless of the machine it runs on.
model_WITHOUT_ATTENTION = torch.quantization.quantize_dynamic(
    model_WITHOUT_ATTENTION.cpu(),
    # qconfig_spec is deliberately left at its default. Per the PyTorch
    # docs, for qint8 this covers nn.Linear and the recurrent layers
    # (nn.LSTM here); nn.Embedding is not part of that default set.
    dtype=torch.qint8,
)

In [17]:
# Write the current weights of model_WITHOUT_ATTENTION under the same
# "model_state_dict" key that the checkpoint loader reads. The file goes to
# the working directory, so the copy in ./unquantized/ is not overwritten.
torch.save(
    obj={'model_state_dict': model_WITHOUT_ATTENTION.state_dict()},
    f="./LSTM1_checkpoint.pth",
)

# Quantizing LSTM with attention

In [13]:
# BiLSTM token model (variant with a self-attention layer)
class BiLSTM_N_gramModel(nn.Module):
    """Bidirectional-LSTM token predictor that also reports attention weights.

    Token ids are embedded and passed through a bidirectional LSTM
    (``batch_first=True``). A single-head self-attention layer is run over
    the LSTM states, and two stacked linear layers project each timestep to
    vocabulary-sized logits.

    Args:
        vocab_size: number of embedding rows and width of the output logits.
        embedding_dim: embedding width; also the width between the two
            linear layers of ``fc``.
        hidden_size: LSTM hidden size per direction; the attention layer
            works on ``2 * hidden_size`` features.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()
        # Forward and backward hidden states are concatenated by the LSTM.
        lstm_out_dim = hidden_size * 2

        # The attribute names (embedding / bilstm / attention / fc) are the
        # prefixes of the state_dict keys, so they must stay as they are for
        # checkpoints to keep loading.
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embedding_dim)
        self.bilstm = nn.LSTM(input_size=embedding_dim,
                              hidden_size=hidden_size,
                              batch_first=True,
                              bidirectional=True)
        self.attention = nn.MultiheadAttention(embed_dim=lstm_out_dim,
                                               num_heads=1,
                                               dropout=0.0,
                                               batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(lstm_out_dim, embedding_dim),
            nn.Linear(embedding_dim, vocab_size),
        )

    def forward(self, x):
        """Return ``(logits, attention_weights)`` for a batch of token ids.

        For ``x`` of shape (batch, seq_len), ``logits`` has shape
        (batch, seq_len, vocab_size) and the attention weights have shape
        (batch, seq_len, seq_len).
        """
        token_vectors = self.embedding(x)
        lstm_states, _ = self.bilstm(token_vectors)

        # Self-attention: the LSTM states serve as query, key and value.
        # NOTE(review): the attended output is discarded. `fc` below is fed
        # the raw LSTM states, so attention only contributes the returned
        # weights. Confirm this is intended before changing it, since the
        # checkpoint was loaded into this exact forward pass.
        _attended, attn_output_weights = self.attention(
            lstm_states, lstm_states, lstm_states,
            need_weights=True,
            average_attn_weights=True,
        )

        return self.fc(lstm_states), attn_output_weights

# Build the attention variant with the hyperparameters defined earlier in
# the notebook (vocab_size, embedding_dim, hidden_size) and place it on the
# chosen device.
model = BiLSTM_N_gramModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_size=hidden_size,
)
model = model.to(DEVICE)

# The checkpoint tensors are mapped to CPU while loading; the weights live
# under the "model_state_dict" key.
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
state_dict = torch.load(
    "./unquantized/LSTM_attention_checkpoint.pth",
    map_location=torch.device('cpu'),
)
model.load_state_dict(state_dict["model_state_dict"])

<All keys matched successfully>

In [14]:
# Post-training dynamic quantization to int8 weights.
#
# Eager-mode dynamic quantization only runs on CPU, but the float model was
# created with `.to(DEVICE)` and may therefore sit on CUDA. Move it to the
# CPU first so this cell works regardless of the machine it runs on.
model = torch.quantization.quantize_dynamic(
    model.cpu(),
    # qconfig_spec is deliberately left at its default. Per the PyTorch
    # docs, for qint8 this covers nn.Linear and the recurrent layers
    # (nn.LSTM here); nn.Embedding is not part of that default set.
    dtype=torch.qint8,
)

In [16]:
# Write the current weights of `model` under the same "model_state_dict"
# key that the checkpoint loader reads. The file goes to the working
# directory, so the copy in ./unquantized/ is not overwritten.
torch.save(
    obj={'model_state_dict': model.state_dict()},
    f="./LSTM_attention_checkpoint.pth",
)