In [1]:
import torch
# Display whether PyTorch can see a CUDA device; the extraction loop further
# down moves the model and the input tensors to the GPU with .cuda().
torch.cuda.is_available()

True

In [2]:
from tqdm import tqdm
import logging
import math
import os
import random
import sys
import time
import pickle
import copy

import numpy as np
import torch

from fairseq import (
    checkpoint_utils,
    distributed_utils,
    options,
    quantization_utils,
    tasks,
    utils,
)
from fairseq.data import iterators
from fairseq.logging import meters, metrics, progress_bar
from fairseq.trainer import Trainer
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from fairseq.models.pruned_transformer import PrunedTransformerModel

In [3]:
# Restore the argument object that was pickled to "argsfile.p".
# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# only load an args file that you created yourself / fully trust.
# The context manager closes the file handle, which the bare open() call leaked.
with open("argsfile.p", "rb") as args_file:
    args = pickle.load(args_file)

In [4]:
# Override the data path stored on the unpickled args with the dataset
# location on this machine.
args.data = '/raj-learn/data/wmt16_en_de_bpe32k'

In [5]:
# Build the fairseq task object from the restored args; it is used below to
# load the dataset, build the batch iterator and load the checkpoints.
task = tasks.setup_task(args)

In [6]:
# Load the split named by args.valid_subset into the task (no combining of
# shards, epoch 1) ...
task.load_dataset(args.valid_subset, combine=False, epoch=1)
# ... and keep a handle to it for the batch iterator built per checkpoint.
dataset = task.dataset(args.valid_subset)

In [60]:
# Directory that holds every checkpoint listed below (trailing slash is
# required: file names are appended to it by plain string concatenation).
checkpoint_dir = "/raj-learn/checkpoints/lr-rewind_0.75sparsity_0.2frac_30epochs/"

# Checkpoint file names, LTH0 through LTH8. The second "_"-separated field
# (e.g. "LTH3") is used later as the model tag for the output directory.
model_paths = [
    "checkpoint_LTH0_epoch60.pt",
    "checkpoint_LTH1_epoch60_sparsity0.168.pt",
    "checkpoint_LTH2_epoch60_sparsity0.302.pt",
    "checkpoint_LTH3_epoch60_sparsity0.410.pt",
    "checkpoint_LTH4_epoch60_sparsity0.496.pt",
    "checkpoint_LTH5_epoch60_sparsity0.565.pt",
    "checkpoint_LTH6_epoch60_sparsity0.620.pt",
    "checkpoint_LTH7_epoch60_sparsity0.664.pt",
    "checkpoint_LTH8_epoch60_sparsity0.699.pt",
]

In [64]:
import h5py

def make_hdf5_file(vectors, output_file_path):
    '''
    Write a mapping of arrays to an HDF5 file, one dataset per entry.

    Args:
        vectors: mapping key -> array-like (typically int -> np.array).
            Each key is converted with str() and used as the dataset name.
        output_file_path: path of the file to write. It is opened in 'w'
            mode. Missing parent directories are created first.

    Every dataset is stored with dtype float32.
    '''
    # h5py cannot create the file when its directory does not exist yet.
    # dirname is '' for a bare file name, in which case there is nothing to do.
    parent_dir = os.path.dirname(output_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with h5py.File(output_file_path, 'w') as fout:
        for key, embeddings in vectors.items():
            # Accept plain (nested) lists as well as arrays: .shape is only
            # available after conversion. A no-op for existing ndarrays.
            data = np.asarray(embeddings)
            fout.create_dataset(
                str(key),
                data.shape, dtype='float32',
                data=data)

In [66]:
%%time
import time

# For every checkpoint: run the validation set through the model and dump the
# per-sentence attention maps (encoder self-attention, encoder-decoder
# attention, decoder self-attention) to three HDF5 files.
for path in model_paths:
    t0 = time.time()
    # File names look like "checkpoint_LTH3_epoch60_...": field 1 is the tag.
    model_name = path.split('_')[1]

    args.path = checkpoint_dir + path
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        task=task,
    )
    model = models[0]
    model.cuda()
    model.eval()
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[m.max_positions() for m in models],
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Padding symbol of the target vocabulary, used to recover the true target
    # lengths (previously hard-coded as the literal 1).
    pad_idx = task.target_dictionary.pad()

    # sentence id -> stacked attention arrays (one entry per layer).
    all_attns_encenc = {}
    all_attns_encdec = {}
    all_attns_decdec = {}

    # Inference only: no_grad avoids building autograd graphs for the two
    # forward passes, which only wasted GPU memory before.
    with torch.no_grad():
        for batch in tqdm(itr):
            net_input = batch["net_input"]
            ids = batch["id"].tolist()
            src_lens = net_input["src_lengths"].tolist()

            # Move each input to the GPU once and reuse it for both passes.
            src_tokens = net_input["src_tokens"].cuda()
            src_lengths = net_input["src_lengths"].cuda()
            prev_output_tokens = net_input["prev_output_tokens"].cuda()

            enc_outputs = model.encoder(src_tokens, src_lengths,
                                        return_all_hiddens=False, return_all_attns=True)
            encenc_attns = [x.detach().cpu().numpy() for x in enc_outputs.encoder_self_attns]

            _, props = model(src_tokens, src_lengths, prev_output_tokens)
            encdec_attns = [x.detach().cpu().numpy() for x in props["encdec_attns"]]
            decdec_attns = [x.detach().cpu().numpy() for x in props["decdec_attns"]]

            # Number of non-padding target tokens per sentence.
            tgt_lens = (batch['target'] != pad_idx).sum(dim=1).tolist()

            # Strip padding per sentence. The slices keep the LAST src_len
            # source positions and the FIRST tgt_len target positions
            # (presumably left-padded sources / right-padded targets - confirm
            # against the task's padding settings). Note that the encdec
            # arrays are indexed with the batch on axis 1, the others on axis 0.
            for i, id_ in enumerate(ids):
                s_len, t_len = src_lens[i], tgt_lens[i]
                all_attns_encenc[id_] = np.array([attn[i, :, -s_len:, -s_len:] for attn in encenc_attns])
                all_attns_encdec[id_] = np.array([attn[:, i, :t_len, -s_len:] for attn in encdec_attns])
                all_attns_decdec[id_] = np.array([attn[i, :, :t_len, :t_len] for attn in decdec_attns])

    # One output file per attention type, inside a per-model directory that is
    # created on demand so the writer does not fail on a fresh machine.
    out_dir = f'/raj-learn/data/precomputed_attns/{model_name}'
    os.makedirs(out_dir, exist_ok=True)
    attns_by_kind = {
        "encenc": all_attns_encenc,
        "encdec": all_attns_encdec,
        "decdec": all_attns_decdec,
    }
    for kind, attns in attns_by_kind.items():
        make_hdf5_file(attns, f'{out_dir}/{kind}_attns_wmt_en_de_val.hdf5')
    print("Model %s took %.2fsec" % (model_name, time.time() - t0))

100%|██████████| 34/34 [00:15<00:00,  2.18it/s]


Model LTH0 took 172.75sec


100%|██████████| 34/34 [00:15<00:00,  2.27it/s]


Model LTH1 took 209.46sec


100%|██████████| 34/34 [00:14<00:00,  2.29it/s]


Model LTH2 took 202.57sec


100%|██████████| 34/34 [00:16<00:00,  2.12it/s]


Model LTH3 took 207.36sec


100%|██████████| 34/34 [00:15<00:00,  2.13it/s]


Model LTH4 took 213.00sec


100%|██████████| 34/34 [00:24<00:00,  1.38it/s]


Model LTH5 took 228.70sec


100%|██████████| 34/34 [00:30<00:00,  1.12it/s]


Model LTH6 took 229.51sec


100%|██████████| 34/34 [00:31<00:00,  1.09it/s]


Model LTH7 took 228.82sec


100%|██████████| 34/34 [00:29<00:00,  1.16it/s]


Model LTH8 took 213.13sec
CPU times: user 9min 16s, sys: 6min 35s, total: 15min 52s
Wall time: 31min 45s


In [57]:
# Sanity check: for every entry, the values along the last axis at index
# [0, 1] of the selected layer must sum to 1 (presumably one normalized
# attention row - confirm).
# NOTE(review): all_masks is not defined in any cell visible here; make sure
# the cell that builds it exists and runs before this one.
_checked_layers = (("encenc", 5), ("encdec", -1), ("decdec", -1))
for id_, mask in all_masks.items():
    for attn_kind, layer in _checked_layers:
        row_total = np.sum(mask[attn_kind][layer][0, 1, :])
        assert np.allclose(1, row_total)