In [12]:
from collections import defaultdict
import torch
import re
import os
import h5py
import numpy as np
from fairseq.models.transformer import TransformerModel
from fairseq import tasks
from fairseq.utils import resolve_max_positions
from torch_scatter import scatter_mean

In [13]:
def get_word_ids(indices, batch, tgt_dict):
    """Assign every target token the running index of the word it belongs to.

    A token opens a new word when its vocabulary string contains the "▁"
    marker or is the "</s>" symbol. Any other token (for example padding)
    inherits the index of the word opened most recently, so tokens seen before
    the first word start are numbered -1.

    Args:
        indices: iterable of sample ids, one per sentence; each must expose
            ``.item()``.
        batch: iterable of sentences, each an iterable of token ids.
        tgt_dict: mapping from token id to token string (indexed with ``[]``).

    Returns:
        A list with one list of integer word indices per sentence.
    """
    all_word_ids = []
    for sample_idx, tokens in zip(indices, batch):
        sample_no = sample_idx.item()
        word_no = -1
        per_token = []
        for tok_id in tokens:
            piece = tgt_dict[tok_id]
            opens_word = "▁" in piece or piece == "</s>"
            # Dirty fix kept from the original: the sample with id 491 (reported
            # for the validation set) begins with an unknown token, which must
            # count as a word start. The check is applied whatever split is
            # being processed.
            patched_unk = sample_no == 491 and piece == "<unk>"
            if opens_word or patched_unk:
                word_no += 1
            per_token.append(word_no)
        all_word_ids.append(per_token)
    return all_word_ids

def write_hidden(fout, indices, word_ids, inner_states):
    """Average token states into word states and store one entry per sample.

    Args:
        fout: writable mapping (an open h5py file in this notebook) that
            receives one array per sample, keyed by the sample id as a string.
        indices: sample ids, one per sentence; each must expose ``.item()``.
        word_ids: per-token word indices, one row per sentence.
        inner_states: sequence of per-layer hidden-state tensors.
    """
    # Stack the layers on a new dim 2, then swap dims 0 and 1.
    # Assumes every state is (time, batch, dim), which would give
    # (batch, time, layers, dim) -- TODO confirm against the decoder.
    layered = torch.stack(inner_states, dim=2)
    per_sentence = layered.transpose(0, 1)
    # Mean-pool along dim 1 all positions that share a word index.
    pooled = scatter_mean(per_sentence, word_ids, dim=1)
    for sample_idx, sent_states, sent_word_ids in zip(indices, pooled, word_ids):
        # Keep word slots 0 .. last_id - 1: the slot numbered by the final
        # word index is dropped (presumably the "</s>"/padding slot -- verify
        # against get_word_ids).
        last_id = sent_word_ids[-1]
        kept = sent_states[:last_id]
        key = str(sample_idx.item())
        fout[key] = kept.transpose(0, 1).float().cpu().numpy()

def forward(model, b, tgt_dict, out):
    """Run one batch through the model and write its word-level decoder states.

    Args:
        model: object exposing ``encoder`` and ``decoder`` sub-modules.
        b: batch dict with the keys ``'id'``, ``'target'`` and ``'net_input'``
            (the latter holding ``'src_tokens'``, ``'src_lengths'`` and
            ``'prev_output_tokens'``).
        tgt_dict: target dictionary mapping token ids to token strings.
        out: destination passed through to ``write_hidden``.
    """
    net_input = b['net_input']
    sample_ids = b['id']
    gold_tokens = b['target']

    # The model inputs are moved to the GPU; ids and targets are used as given.
    src_tokens = net_input['src_tokens'].cuda()
    src_lengths = net_input['src_lengths'].cuda()
    prev_output_tokens = net_input['prev_output_tokens'].cuda()

    encoder_out = model.encoder.forward(src_tokens, src_lengths)
    decoder_out = model.decoder(
        prev_output_tokens,
        encoder_out=encoder_out,
        features_only=True,
        return_all_hiddens=True,
    )
    layer_states = decoder_out[1]["inner_states"]

    # Word indices are computed from the targets and placed on the same
    # device as the decoder input.
    tgt_word_ids = torch.tensor(
        get_word_ids(sample_ids, gold_tokens, tgt_dict),
        dtype=torch.long,
        device=prev_output_tokens.device,
    )
    write_hidden(out, sample_ids, tgt_word_ids, layer_states)

def extract_rand_feats(mode, rand_init):
    """Dump word-level decoder states of a "rand"-data linearizer to HDF5.

    Loads ``checkpoint_avg.pt`` from ``output/Data_rand-Mode_<mode>`` and, for
    each of the train/valid/test splits, writes ``<split>.hdf5`` under
    ``feats/rand_<mode>`` (``feats/rand_<mode>rand`` when ``rand_init`` is
    true). Any existing output file is replaced.

    Args:
        mode: name of the training mode; used to build the checkpoint, data
            and output paths.
        rand_init: when true, the loaded model is replaced by a model freshly
            built from the same args and task, i.e. without the checkpoint
            weights.
    """
    data = "rand"
    output_path = "output/Data_{}-Mode_{}".format(data, mode)
    data_path = 'data-{}/{}'.format(data, mode)
    out_dir = "feats/{}_{}".format(data, mode + "rand" if rand_init else mode)
    max_tokens = 65536
    os.makedirs(out_dir, exist_ok=True)
    linearizer = TransformerModel.from_pretrained(
            output_path,
            checkpoint_file='checkpoint_avg.pt',
            data_name_or_path=data_path,
            constraints="unordered")
    if rand_init:
        linearizer.models[0] = linearizer.models[0].build_model(linearizer.args, linearizer.task)

    linearizer.half().cuda()
    # Modules start in training mode, so a model rebuilt above would otherwise
    # run with dropout enabled. Switching to eval mode explicitly keeps the
    # extracted features deterministic in both branches.
    linearizer.eval()

    model = linearizer.models[0]
    task = linearizer.task
    tgt_dict = task.target_dictionary

    for split in ["train", "valid", "test"]:
        print(split)
        task.load_dataset(split)
        biter = task.get_batch_iterator(
                dataset=task.dataset(split),
                max_tokens=max_tokens,
                max_positions=resolve_max_positions(
                        task.max_positions(),
                        model.max_positions(),
                        max_tokens)
        ).next_epoch_itr(shuffle=False)
        path = os.path.join(out_dir, "{}.hdf5".format(split))
        # Remove a stale file first so the "w" open always starts from scratch.
        if os.path.exists(path):
            os.remove(path)
        with torch.no_grad(), h5py.File(path, "w") as out:
            for b in biter:
                forward(model, b, tgt_dict, out)




In [14]:
# Features of the three trained "rand"-data linearizers ...
for mode in ("base", "pos", "udep"):
    extract_rand_feats(mode, rand_init=False)
    torch.cuda.empty_cache()
# ... followed by the re-initialised "base" model as a baseline.
extract_rand_feats("base", rand_init=True)
torch.cuda.empty_cache()


train
valid
test
train
valid
test
train
valid
test
train
valid
test


In [15]:
def get_word_ids(batch, tgt_dict, bpe_decoder):
    """Assign every target token the running index of the word it belongs to.

    Dictionary symbols that are decimal strings and whose integer value is a
    key of ``bpe_decoder`` are first replaced by the decoded BPE string. A
    token opens a new word when the resulting string starts with "Ġ" or with
    "</s>"; every other token inherits the current word index, so tokens seen
    before the first word start are numbered -1.

    Args:
        batch: iterable of sentences, each an iterable of token ids.
        tgt_dict: mapping from token id to dictionary symbol (a string).
        bpe_decoder: mapping from integer BPE id to BPE string.

    Returns:
        A list with one list of integer word indices per sentence.
    """
    opens_word = re.compile(r"(Ġ|</s>)").match

    def surface_form(tok_id):
        # Resolve a dictionary symbol to its BPE string when one is available.
        symbol = tgt_dict[tok_id]
        if symbol.isdecimal():
            bpe_id = int(symbol)
            if bpe_id in bpe_decoder:
                return bpe_decoder[bpe_id]
        return symbol

    all_word_ids = []
    for sentence in batch:
        word_no = -1
        per_token = []
        for tok_id in sentence:
            if opens_word(surface_form(tok_id)):
                word_no += 1
            per_token.append(word_no)
        all_word_ids.append(per_token)
    return all_word_ids

def write_hidden(fout, indices, word_ids, inner_states):
    """Average token states into word states and store one entry per sample.

    Args:
        fout: writable mapping (an open h5py file in this notebook) that
            receives one array per sample, keyed by the sample id as a string.
        indices: sample ids, one per sentence; each must expose ``.item()``.
        word_ids: per-token word indices, one row per sentence.
        inner_states: sequence of per-layer hidden-state tensors.
    """
    # Stack the layers on a new dim 2, then swap dims 0 and 1.
    # Assumes every state is (time, batch, dim), which would give
    # (batch, time, layers, dim) -- TODO confirm against the decoder.
    layered = torch.stack(inner_states, dim=2)
    per_sentence = layered.transpose(0, 1)
    # Mean-pool along dim 1 all positions that share a word index.
    pooled = scatter_mean(per_sentence, word_ids, dim=1)
    for sample_idx, sent_states, sent_word_ids in zip(indices, pooled, word_ids):
        # Keep word slots 0 .. last_id - 1: the slot numbered by the final
        # word index is dropped (presumably the "</s>"/padding slot -- verify
        # against get_word_ids).
        last_id = sent_word_ids[-1]
        kept = sent_states[:last_id]
        key = str(sample_idx.item())
        fout[key] = kept.transpose(0, 1).float().cpu().numpy()

def forward(model, b, tgt_dict, bpe_decoder, out):
    """Run one batch through the model and write its word-level decoder states.

    Args:
        model: object exposing ``encoder`` and ``decoder`` sub-modules.
        b: batch dict with the keys ``'id'``, ``'target'`` and ``'net_input'``
            (the latter holding ``'src_tokens'``, ``'src_lengths'`` and
            ``'prev_output_tokens'``).
        tgt_dict: target dictionary mapping token ids to dictionary symbols.
        bpe_decoder: mapping from integer BPE id to BPE string, forwarded to
            ``get_word_ids``.
        out: destination passed through to ``write_hidden``.
    """
    net_input = b['net_input']
    sample_ids = b['id']
    gold_tokens = b['target']

    # The model inputs are moved to the GPU; ids and targets are used as given.
    src_tokens = net_input['src_tokens'].cuda()
    src_lengths = net_input['src_lengths'].cuda()
    prev_output_tokens = net_input['prev_output_tokens'].cuda()

    encoder_out = model.encoder.forward(src_tokens, src_lengths)
    decoder_out = model.decoder(
        prev_output_tokens,
        encoder_out=encoder_out,
        features_only=True,
        return_all_hiddens=True,
    )
    layer_states = decoder_out[1]["inner_states"]

    # Word indices are computed from the targets and placed on the same
    # device as the decoder input.
    tgt_word_ids = torch.tensor(
        get_word_ids(gold_tokens, tgt_dict, bpe_decoder),
        dtype=torch.long,
        device=prev_output_tokens.device,
    )
    write_hidden(out, sample_ids, tgt_word_ids, layer_states)

def extract_bart_feats(mode, rand_init):
    """Dump word-level decoder states of a "bart"-data linearizer to HDF5.

    Loads ``checkpoint_avg.pt`` from ``output/Data_bart-Mode_<mode>`` with the
    GPT-2 BPE files under ``bart/`` and, for each of the train/valid/test
    splits, writes ``<split>.hdf5`` under ``feats/bart_<mode>``
    (``feats/bart_<mode>rand`` when ``rand_init`` is true). Any existing
    output file is replaced.

    Args:
        mode: name of the training mode; used to build the checkpoint, data
            and output paths.
        rand_init: selects the "<mode>rand" output directory and triggers the
            extra ``from_pretrained`` call below (see the review note there).
    """
    data = "bart"
    output_path = "output/Data_{}-Mode_{}".format(data, mode)
    data_path = 'data-{}/{}'.format(data, mode)
    out_dir = "feats/{}_{}".format(data, mode + "rand" if rand_init else mode)
    max_tokens = 16384
    os.makedirs(out_dir, exist_ok=True)

    linearizer = TransformerModel.from_pretrained(
        output_path,
        checkpoint_file="checkpoint_avg.pt",
        data_name_or_path=data_path,
        constraints="unordered",
        bpe="gpt2",
        gpt2_encoder_json="bart/encoder.json",
        gpt2_vocab_bpe="bart/vocab.bpe")

    if rand_init:
        # NOTE(review): the value returned by from_pretrained is discarded and
        # nothing here assigns into linearizer.models, so this call does not
        # appear to change the model used below -- the "<mode>rand" features
        # would then equal the regular ones. Presumably the intent was to swap
        # in the plain pre-trained BART weights; confirm and assign the loaded
        # model explicitly.
        linearizer.models[0].from_pretrained("bart/bart.base/", bpe="gpt2",
            gpt2_encoder_json="bart/encoder.json",
            gpt2_vocab_bpe="bart/vocab.bpe")
    linearizer.half().cuda()
    # Make sure dropout is disabled while features are extracted, whatever
    # mode the loaded model was left in.
    linearizer.eval()

    model = linearizer.models[0]
    task = linearizer.task
    tgt_dict = task.target_dictionary
    bpe_decoder = linearizer.bpe.bpe.decoder

    for split in ["train", "valid", "test"]:
        task.load_dataset(split)
        biter = task.get_batch_iterator(
                    dataset=task.dataset(split),
                    max_tokens=max_tokens,
                    max_positions=resolve_max_positions(
                        task.max_positions(),
                        model.max_positions(),
                        max_tokens)
                    ).next_epoch_itr(shuffle=False)
        path = os.path.join(out_dir, "{}.hdf5".format(split))
        # Remove a stale file first so the "w" open always starts from scratch.
        if os.path.exists(path):
            os.remove(path)
        with torch.no_grad(), h5py.File(path, "w") as out:
            for b in biter:
                forward(model, b, tgt_dict, bpe_decoder, out)

In [16]:

# Features of the three trained "bart"-data linearizers ...
for mode in ("base", "pos", "udep"):
    extract_bart_feats(mode, rand_init=False)
    torch.cuda.empty_cache()
# ... followed by the "base" run with rand_init enabled as a baseline.
extract_bart_feats("base", rand_init=True)
torch.cuda.empty_cache()
