In [1]:
import os
# Move into the project's src directory. The path is relative to the directory
# the kernel starts in, so a plain chdir would fail if this cell is run twice;
# skip it when we are already there.
SRC_DIR = "drive/My Drive/STAT212/DeepZip_code/src"
if not os.getcwd().replace(os.sep, "/").endswith(SRC_DIR):
    os.chdir(SRC_DIR)
os.getcwd()

In [11]:
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import argparse
import contextlib
import arithmeticcoding_fast
import json
from tqdm import tqdm
import struct
import tempfile
import shutil
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import torch.quantization
import zipfile

# Device configuration
# Seed PyTorch's global RNG first, then pin every tensor and module to the CPU.
# (GPU alternative: torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
torch.manual_seed(1)
device = torch.device('cpu')

# Command-line interface. Every option is optional and defaults to None;
# later cells assign the values directly on `args`.
parser = argparse.ArgumentParser(description='Input')
parser.add_argument('-model', action='store', dest='model_weights_file',
                    help='model weights file')
parser.add_argument('-model_name', action='store', dest='model_name',
                    help='model name')
parser.add_argument('-batch_size', action='store', dest='batch_size', type=int,
                    help='batch size')
parser.add_argument('-data', action='store', dest='sequence_npy_file',
                    help='data file')
parser.add_argument('-data_params', action='store', dest='params_file',
                    help='params file')
parser.add_argument('-output', action='store',dest='output_file_prefix',
                    help='compressed file name')

# parse_known_args (not parse_args): unrecognised argv entries are collected
# in `unknown` instead of aborting with a usage error.
args, unknown = parser.parse_known_args()

In [12]:
def strided_app(a, L, S):  # Window len = L, Stride len/stepsize = S
    """Return a read-only 2-D view of sliding windows over `a`.

    Row k of the result holds elements a[k*S : k*S + L]. No data is copied;
    the view is built with as_strided and marked non-writeable.

    Only `a.strides[0]` is used, so `a` is expected to be one-dimensional.
    """
    step_bytes = a.strides[0]
    window_count = 1 + (a.size - L) // S
    return np.lib.stride_tricks.as_strided(
        a,
        shape=(window_count, L),
        strides=(step_bytes * S, step_bytes),
        writeable=False,
    )
    
def _uniform_cumul(alphabet_size):
    """Build the cumulative-frequency table of a uniform distribution over the alphabet."""
    prob = np.ones(alphabet_size)/alphabet_size
    cumul = np.zeros(alphabet_size+1, dtype = np.uint64)
    # NOTE(review): the 10000000 scale and the +1 floor presumably have to match
    # the decompressor exactly -- do not change one without the other.
    cumul[1:] = np.cumsum(prob*10000000 + 1)
    return cumul


def predict_lstm(X, y, y_original, timesteps, bs, alphabet_size, model_name, final_step=False):
    """Arithmetic-code a symbol sequence using the global `model`'s predictions.

    Reads the module-level names `args` (for `args.temp_file_prefix`), `model`,
    `input_size` and `device`.

    Parameters
    ----------
    X : 2-D array whose row r is the context window used to predict y_original[r].
    y : unused; kept for interface compatibility.
    y_original : target symbols aligned with the rows of X.
    timesteps : context length (number of columns of X).
    bs : number of parallel output streams (ignored when final_step is True).
    alphabet_size : number of distinct symbols.
    model_name : unused; kept for interface compatibility.
    final_step : selects the mode, see below.

    Modes
    -----
    final_step=False
        The data is split into `bs` streams of num_iters = (len(X)+timesteps)//bs
        symbols each; stream i starts at row i*num_iters and is written to
        ``<temp_file_prefix>.<i>``. The first min(timesteps, num_iters) symbols of
        every stream are coded with a uniform distribution, the rest with the
        softmax of the model output.
    final_step=True
        A single stream is written to ``<temp_file_prefix>.last``: the first
        `timesteps` symbols (row 0 of X) are coded uniformly, then one symbol
        per row of X is coded with the model.

    Returns None.
    """
    opened = []  # every file opened here, so it can be closed even on failure
    try:
        # Inference only: no_grad avoids building an autograd graph per step.
        with torch.no_grad():
            if not final_step:
                num_iters = (len(X)+timesteps)//bs
                ind = np.arange(bs)*num_iters

                # open compressed files and compress first few characters using
                # uniform distribution
                for i in range(bs):
                    opened.append(open(args.temp_file_prefix+'.'+str(i),'wb'))
                bitout = [arithmeticcoding_fast.BitOutputStream(fh) for fh in opened]
                # 32 is passed through unchanged -- presumably the coder's state width
                enc = [arithmeticcoding_fast.ArithmeticEncoder(32, bitout[i]) for i in range(bs)]
                cumul = _uniform_cumul(alphabet_size)
                for i in range(bs):
                    for j in range(min(timesteps, num_iters)):
                        enc[i].write(cumul, X[ind[i],j])

                cumul = np.zeros((bs, alphabet_size+1), dtype = np.uint64)
                for j in range(num_iters - timesteps):
                    x = torch.Tensor(X[ind,:])
                    x = x.reshape(-1,timesteps, input_size).to(device)
                    outputs = model(x)
                    # outputs is (batch, classes): normalise over the class axis explicitly
                    prob = F.softmax(outputs, dim=1).detach().cpu().numpy()
                    cumul[:,1:] = np.cumsum(prob*10000000 + 1, axis = 1)
                    for i in range(bs):
                        enc[i].write(cumul[i,:], y_original[ind[i]])
                    ind = ind + 1
                # flush the coders
                for i in range(bs):
                    enc[i].finish()
                    bitout[i].close()
            else:
                opened.append(open(args.temp_file_prefix+'.last','wb'))
                bitout = arithmeticcoding_fast.BitOutputStream(opened[0])
                enc = arithmeticcoding_fast.ArithmeticEncoder(32, bitout)
                cumul = _uniform_cumul(alphabet_size)

                for j in range(timesteps):
                    enc.write(cumul, X[0,j])
                for i in range(len(X)):
                    x = torch.Tensor(X[i,:])
                    x = x.reshape(-1,timesteps, input_size).to(device)
                    outputs = model(x)
                    prob = F.softmax(outputs, dim=1).detach().cpu().numpy()
                    # prob has a single row; cumsum without axis flattens it
                    cumul[1:] = np.cumsum(prob*10000000 + 1)
                    enc.write(cumul, y_original[i][0])
                enc.finish()
                bitout.close()
    finally:
        # closing an already-closed file is a no-op
        for fh in opened:
            fh.close()
    return


# variable length integer encoding http://www.codecodex.com/wiki/Variable-Length_Integers
def var_int_encode(byte_str_len, f):
    """Write `byte_str_len` to `f` as a variable-length integer.

    Each output byte carries 7 payload bits, least-significant group first;
    the high bit (0x80) is set on every byte except the last. After each
    continuation byte the remaining value is decremented by one, so every
    integer has exactly one encoding.

    Parameters
    ----------
    byte_str_len : non-negative int to encode.
    f : binary file-like object with a ``write`` method.

    Raises
    ------
    ValueError
        If `byte_str_len` is negative (the loop would otherwise never end).
    """
    if byte_str_len < 0:
        raise ValueError("byte_str_len must be non-negative")
    remaining = byte_str_len
    while True:
        low7 = remaining & 127
        remaining >>= 7
        if remaining == 0:
            # final byte: continuation bit clear
            f.write(struct.pack('B', low7))
            return
        f.write(struct.pack('B', low7 | 128))
        remaining -= 1

In [13]:
# Which trained model to use. The weights path and the output prefix are both
# derived from this single name, so they cannot drift apart. One of:
#   "lstm_weights"
#   "lstm_pruned_weights"
#   "lstm_quantization_weights"
#   "lstm_quantization_pruned_weights"
# NOTE: for the quantization variants the model must also be dynamically
# quantized before its state dict is loaded (see the weight-loading cell).
model_variant = "lstm_weights"

args.sequence_npy_file="../data/processed_files/text8.npy"
args.params_file="../data/processed_files/text8.param.json"

args.model_weights_file="../data/trained_models/text8/" + model_variant
args.model_name="LSTM"

args.output_file_prefix="../data/compressed/text8/" + model_variant + ".compressed"

args.batch_size=1000

In [14]:
# load the data
# Scratch directory for the per-stream encoder outputs (removed after combining).
args.temp_dir = tempfile.mkdtemp()
args.temp_file_prefix = args.temp_dir + "/compressed"
np.random.seed(0)
series = np.load(args.sequence_npy_file)
series=series[0:100000]  # only the first 100000 symbols are compressed
series = series.reshape(-1, 1)
# scikit-learn renamed `sparse` to `sparse_output` in 1.2 and removed the old
# name in 1.4; try the new spelling first and fall back for older releases.
try:
    onehot_encoder = OneHotEncoder(sparse_output=False)
except TypeError:
    onehot_encoder = OneHotEncoder(sparse=False)
# NOTE: fit() returns the fitted encoder itself, not encoded data; the name is
# kept for compatibility.
onehot_encoded = onehot_encoder.fit(series)

batch_size = args.batch_size
timesteps = 64

with open(args.params_file, 'r') as f:
    params = json.load(f)

# record what is needed to interpret the compressed output
params['len_series'] = len(series)
params['bs'] = batch_size
params['timesteps'] = timesteps

# make sure the output directory exists before writing into it
output_dir = os.path.dirname(args.output_file_prefix)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(args.output_file_prefix+'.params','w') as f:
    json.dump(params, f, indent=4)

alphabet_size = len(params['id2char_dict'])

# sliding windows of timesteps+1 symbols: the first `timesteps` columns are the
# context, the last column is the symbol to predict
series = series.reshape(-1)
data = strided_app(series, timesteps+1, 1)

X = data[:, :-1]
Y_original = data[:, -1:]
Y = onehot_encoder.transform(Y_original)

# number of symbols covered by whole batches; the remainder is handled separately
l = (len(series)//batch_size)*batch_size

In [16]:
# Hyper Parameters
# Network shape
input_size = 1
hidden_size = 1024   # alternative tried previously: hidden_size = 64
num_layers = 1
num_classes = alphabet_size
# Optimisation settings
num_epochs = 10
lr = 0.01


# Define LSTM model
class simpleLSTM(nn.Module):
    """LSTM followed by a linear layer applied to the last timestep.

    Takes input of shape (batch, seq_len, input_size) (the LSTM is built with
    batch_first=True) and returns unnormalised scores of shape
    (batch, num_classes). Sub-modules and initial states are placed on the
    module-level `device`.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # attribute names `lstm` and `fc` determine the state_dict keys
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True).to(device)
        self.fc = nn.Linear(hidden_size, num_classes).to(device)

    def forward(self, x):
        # zero initial hidden and cell states, one per layer and batch element
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(*state_shape).to(device)
        c0 = torch.zeros(*state_shape).to(device)

        # run the recurrent layer; the final states are not needed
        seq_out, _ = self.lstm(x, (h0, c0))

        # classify from the output at the last timestep
        return self.fc(seq_out[:, -1, :])

# Build the network from the hyper-parameters defined above.
model = simpleLSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
)

In [17]:
# unzip the compressed model selected by args.model_weights_file
# Each archive sits next to its weights file and is named "<weights file>.zip",
# so both paths are derived from the configured weights path instead of being
# edited by hand for every model variant.
zip_path = args.model_weights_file + ".zip"
save_path = os.path.dirname(args.model_weights_file)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(save_path)

In [18]:
# load the model weights
# map_location remaps tensors that were saved from another device (e.g. a GPU)
# onto `device`, so the checkpoint also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load weights from a trusted source.
model.load_state_dict(torch.load(args.model_weights_file, map_location=device))

# # load the quantized model weights
# model = torch.quantization.quantize_dynamic(
#     model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
# )
# model.load_state_dict(torch.load(args.model_weights_file, map_location=device))

<All keys matched successfully>

In [19]:
# compress the data
# Re-create the scratch directory if an earlier run of this cell removed it.
os.makedirs(args.temp_dir, exist_ok=True)
try:
    predict_lstm(X, Y, Y_original, timesteps, batch_size, alphabet_size, args.model_name)

    if l < len(series)-timesteps:
        # enough symbols remain after the batched part for at least one model prediction
        predict_lstm(X[l:,:], Y[l:,:], Y_original[l:], timesteps, 1, alphabet_size, args.model_name, final_step = True)
    else:
        # the remainder is no longer than one context window: code it with a
        # uniform distribution
        with open(args.temp_file_prefix+'.last','wb') as f_last:
            bitout = arithmeticcoding_fast.BitOutputStream(f_last)
            enc = arithmeticcoding_fast.ArithmeticEncoder(32, bitout)
            prob = np.ones(alphabet_size)/alphabet_size

            cumul = np.zeros(alphabet_size+1, dtype = np.uint64)
            cumul[1:] = np.cumsum(prob*10000000 + 1)
            for j in range(l, len(series)):
                enc.write(cumul, series[j])
            enc.finish()
            bitout.close()

    # combine files into one file: each chunk is stored as a variable-length
    # integer byte count followed by the chunk's bytes, streams 0..batch_size-1
    # first and the '.last' stream at the end
    chunk_suffixes = [str(i) for i in range(batch_size)] + ['last']
    with open(args.output_file_prefix+'.combined','wb') as f_out:
        for suffix in chunk_suffixes:
            with open(args.temp_file_prefix+'.'+suffix,'rb') as f_in:
                byte_str = f_in.read()
            var_int_encode(len(byte_str), f_out)
            f_out.write(byte_str)
finally:
    # remove the scratch files whether or not compression succeeded
    shutil.rmtree(args.temp_dir)

