## Imports

In [106]:
import numpy as np
import pandas as pd

from IPython.display import display, clear_output

## MLM

In [9]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
from transformers.activations import GELUActivation
from transformers.modeling_outputs import MaskedLMOutput
from transformers import DataCollatorForWholeWordMask
import evaluate

# import transformers

In [27]:
import wandb
wandb.init(project="kg-lm-integration", entity="tanny411")

[34m[1mwandb[0m: Currently logged in as: [33mtanny411[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668938150000184, max=1.0…

In [10]:
import torch

torch.cuda.is_available()#, torch.cuda.device_count(), torch.cuda.current_device(), torch.cuda.get_device_name(0)

  return torch._C._cuda_getDeviceCount() > 0


False

In [11]:
bert_model_name = "distilbert-base-uncased" ##"bert-base-cased"
model = AutoModel.from_pretrained(bert_model_name)
tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
from torch import nn
    
class BERTModified(nn.Module):
    def __init__(self, bert_model_name):
        super().__init__()

        self.base_model = AutoModel.from_pretrained(bert_model_name)
        self.config = self.base_model.config
        
        self.activation = GELUActivation() # for distilbert
        self.vocab_transform = nn.Linear(self.config.dim, self.config.dim)
        self.vocab_layer_norm = nn.LayerNorm(self.config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(self.config.dim, self.config.vocab_size)

        self.mlm_loss_fct = nn.CrossEntropyLoss()
        
        ## set to eval
        self.base_model.eval()
        
        ## freeze model
        for param in self.base_model.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids = None,
        attention_mask = None,
        head_mask = None,
        inputs_embeds = None,
        labels = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict= None,):
        
        base_model_output = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        hidden_states = base_model_output[0]  # (bs, seq_length, dim)
        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
        prediction_logits = self.activation(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)

        mlm_loss = None
        if labels is not None:
            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_logits,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions,
        )

In [13]:
BERTModified_model = BERTModified(bert_model_name)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

In [15]:
from datasets import load_dataset
dataset = load_dataset("zhengxuanzenwu/wikitext-2-split-128", split="test")

Using custom data configuration zhengxuanzenwu--wikitext-2-split-128-f504347a654a9463
Found cached dataset parquet (/home/a2khatun/.cache/huggingface/datasets/zhengxuanzenwu___parquet/zhengxuanzenwu--wikitext-2-split-128-f504347a654a9463/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


In [16]:
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result


# Use batched=True to activate fast multithreading!
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

tokenized_datasets

Loading cached processed dataset at /home/a2khatun/.cache/huggingface/datasets/zhengxuanzenwu___parquet/zhengxuanzenwu--wikitext-2-split-128-f504347a654a9463/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-2d367ba46d4f5d81.arrow


Dataset({
    features: ['input_ids', 'attention_mask', 'word_ids'],
    num_rows: 8192
})

In [17]:
chunk_size = 128

def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of max_len
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

Loading cached processed dataset at /home/a2khatun/.cache/huggingface/datasets/zhengxuanzenwu___parquet/zhengxuanzenwu--wikitext-2-split-128-f504347a654a9463/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-3fec34e95f2bc86d.arrow


Dataset({
    features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
    num_rows: 2253
})

In [18]:
from huggingface_hub import notebook_login

notebook_login()

Login successful
Your token has been saved to /home/a2khatun/.huggingface/token
[1m[31mAuthenticated through git-credential store but this isn't the helper defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in your terminal in case you want to set this credential helper as the default

git config --global credential.helper store[0m


In [19]:
train_size = 1000
test_size = 100

downsampled_dataset = lm_datasets.train_test_split(train_size=train_size, test_size=test_size, seed=42)
downsampled_dataset

Loading cached split indices for dataset at /home/a2khatun/.cache/huggingface/datasets/zhengxuanzenwu___parquet/zhengxuanzenwu--wikitext-2-split-128-f504347a654a9463/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-827befdce6c51896.arrow and /home/a2khatun/.cache/huggingface/datasets/zhengxuanzenwu___parquet/zhengxuanzenwu--wikitext-2-split-128-f504347a654a9463/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-3dd2a0e3127c60f1.arrow


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 100
    })
})

In [20]:
metric_name = evaluate.load("perplexity") # accuracy

In [32]:
from transformers import TrainingArguments

batch_size = 16

# Show the training loss with every epoch
logging_steps = len(downsampled_dataset['train']) // batch_size
model_name = "BERTModified"
output_dir = f"{model_name}-finetuned-wikitext-test"

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    push_to_hub=True,
#     fp16=True,
    logging_steps=logging_steps,
    num_train_epochs=1,
#     load_best_model_at_end=True,
#     metric_for_best_model="loss",#metric_name,
#     greater_is_better = False,
    logging_dir='logs',
    report_to="wandb",
#     no_cuda=True,
)

PyTorch: setting up devices


metric_for_best_model (str, optional) — Use in conjunction with load_best_model_at_end to specify the metric to use to compare two different models. Must be the name of a metric returned by the evaluation with or without the prefix "eval_". Will default to "loss" if unspecified and load_best_model_at_end=True (to use the evaluation loss).

If you set this value, greater_is_better will default to True. Don’t forget to set it to False if your metric is better when lower.

In [33]:
from transformers import Trainer

trainer = Trainer(
    model=BERTModified_model,
    args=training_args,
    train_dataset=downsampled_dataset["train"],
    eval_dataset=downsampled_dataset["test"],
    data_collator=data_collator,
)

/home/a2khatun/Downloads/KG/project/BERTModified-finetuned-wikitext-test is already a clone of https://huggingface.co/Aisha/BERTModified-finetuned-wikitext-test. Make sure you pull the latest changes with `repo.git_pull()`.


In [24]:
import math

eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

The following columns in the evaluation set don't have a corresponding argument in `BERTModified.forward` and have been ignored: word_ids. If word_ids are not expected by `BERTModified.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 16


>>> Perplexity: 37705.04


In [34]:
trainer.train()
# trainer.save_model("output/models/BERTModified")

The following columns in the training set don't have a corresponding argument in `BERTModified.forward` and have been ignored: word_ids. If word_ids are not expected by `BERTModified.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1000
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 63
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss
1,9.0863,8.856506


The following columns in the evaluation set don't have a corresponding argument in `BERTModified.forward` and have been ignored: word_ids. If word_ids are not expected by `BERTModified.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 16
Saving model checkpoint to BERTModified-finetuned-wikitext-test/checkpoint-63
Trainer.model is not a `PreTrainedModel`, only saving its state dict.


KeyError: 'eval_perplexity'

https://theaisummer.com/hugging-face-vit/

If for example we wanted to visualize the training process using the weights and biases library, we can use the WandbCallback. We can simply add another argument to the Trainer in the form of:
```
from transformers import WandbCallback
callbacks = [WandbCallback(...)]
```
One other thing: Take a look at the logging_dir='logs'. By saving the training logs, we can very easily initiate a tensorboard instance and track the training progress:

```
$ tensorboard --logdir logs/
```

An alternative is to use the TensorBoardCallback provided by the library.

In [58]:
import math

eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}") #21596 for 1 epoch

The following columns in the evaluation set don't have a corresponding argument in `BERTModified.forward` and have been ignored: word_ids. If word_ids are not expected by `BERTModified.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 16


>>> Perplexity: 828.09


In [59]:
trainer.push_to_hub()

Saving model checkpoint to BERTModified-finetuned-wikitext-test
Trainer.model is not a `PreTrainedModel`, only saving its state dict.


Upload file pytorch_model.bin:   0%|          | 32.0k/345M [00:00<?, ?B/s]

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/Aisha/BERTModified-finetuned-wikitext-test
   dbda7e7..ec41d60  main -> main

Dropping the following result as it does not have all the necessary fields:
{}
To https://huggingface.co/Aisha/BERTModified-finetuned-wikitext-test
   ec41d60..721863d  main -> main



'https://huggingface.co/Aisha/BERTModified-finetuned-wikitext-test/commit/ec41d60af7b0dbf3304bf650d114b445e1f716fa'

In [47]:
# from huggingface_hub import get_full_repo_name

# repo_name = get_full_repo_name(output_dir)
# repo_name

In [48]:
# from huggingface_hub import Repository

# repo = Repository(output_dir, clone_from=repo_name)

In [60]:
from transformers import pipeline

# Initialize MLM pipeline
mlm = pipeline('fill-mask', model=BERTModified_model, tokenizer=tokenizer)

# Get mask token
mask = mlm.tokenizer.mask_token

# Get result for particular masked phrase
phrase = f'Wikipedia is a free online {mask}, created and edited by volunteers around the world'

result = mlm(phrase)

# Print result
print(result)

[{'score': 0.016104264184832573, 'token': 1010, 'token_str': ',', 'sequence': 'wikipedia is a free online,, created and edited by volunteers around the world'}, {'score': 0.005953500047326088, 'token': 1998, 'token_str': 'and', 'sequence': 'wikipedia is a free online and, created and edited by volunteers around the world'}, {'score': 0.005101792048662901, 'token': 2000, 'token_str': 'to', 'sequence': 'wikipedia is a free online to, created and edited by volunteers around the world'}, {'score': 0.004948945250362158, 'token': 1030, 'token_str': '@', 'sequence': 'wikipedia is a free online @, created and edited by volunteers around the world'}, {'score': 0.004703795537352562, 'token': 1012, 'token_str': '.', 'sequence': 'wikipedia is a free online., created and edited by volunteers around the world'}]


In [61]:
for x in result:
    print(f">>> {x['sequence']}")

>>> wikipedia is a free online,, created and edited by volunteers around the world
>>> wikipedia is a free online and, created and edited by volunteers around the world
>>> wikipedia is a free online to, created and edited by volunteers around the world
>>> wikipedia is a free online @, created and edited by volunteers around the world
>>> wikipedia is a free online., created and edited by volunteers around the world


https://huggingface.co/course/chapter7/3?fw=pt#perplexity-for-language-models

In [78]:
## compare model weights

for p1, p2 in zip(model.parameters(), BERTModified_model.parameters()):
#     print(p1.shape, p2.shape)
    if p1.data.ne(p2.data).sum() > 0:
        print(False)

In [84]:
BERTModified_model

BERTModified(
  (base_model): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_

In [86]:
params = [x for x in BERTModified_model.parameters()]
for p in params[-6:]:
    print(p.shape)

torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([30522, 768])
torch.Size([30522])


## KG Model

In [3]:
import jsonlines

linked_wikitext_2 = "linked-wikitext-2/"
train = linked_wikitext_2+"train.jsonl"
valid = linked_wikitext_2+"valid.jsonl"
test = linked_wikitext_2+"test.jsonl"

## Create Embedding Map

In [4]:
import numpy as np
import pandas as pd
import gc

from IPython.display import display, clear_output

In [5]:
emb_tsv_file = "wikidata_translation_v1.tsv"

In [4]:
df_size = 0
ix = 0

for chunk in pd.read_csv(emb_tsv_file, 
                         delimiter='\t', 
                         header=None, 
                         chunksize=10000, 
                         skiprows=1, 
                         encoding='unicode-escape',
                         names=['id']+[f"embedding_{num}" for num in range(1,201)]):
    
    chunk = chunk.dropna()
    qid_df = chunk[chunk['id'].str.startswith("<http://www.wikidata.org/entity/")]
    df_size += len(qid_df)
    qid_df["id"] = qid_df["id"].apply(lambda x: x.split("/")[-1][:-1])
    qid_df.to_csv(f"qid_embedding.csv", index=False, header=False, mode="a")
#     qid_df.to_csv(f"qid_embeding/qid_embedding_{ix}.csv", index=False, header=False, mode="a")
    ix+=1
    
    clear_output()
    gc.collect()

print("Done!")

Done!


In [5]:
df_size

55032670

In [7]:
for embedding_qids in pd.read_csv('qid_embedding.csv', 
                         header=None, 
                         chunksize=10000, 
                         names=['id']+[f"embedding_{num}" for num in range(1,201)]):
    break

In [8]:
embedding_qids

Unnamed: 0,id,embedding_1,embedding_2,embedding_3,embedding_4,embedding_5,embedding_6,embedding_7,embedding_8,embedding_9,...,embedding_191,embedding_192,embedding_193,embedding_194,embedding_195,embedding_196,embedding_197,embedding_198,embedding_199,embedding_200
0,Q13442814,-0.0828,-0.1610,0.2152,-0.1867,-0.2638,0.2255,0.7521,-0.2523,-0.9443,...,0.4445,-0.2946,-0.1511,-0.4235,-0.0393,0.5657,-0.0626,-0.2530,-0.1620,0.0711
1,Q5,-0.2249,0.1400,-0.1503,-0.4524,-0.2512,0.2315,0.9779,0.6769,0.0827,...,-0.1156,0.2502,0.9211,-0.3094,-0.7469,0.4710,0.5569,-0.7302,-0.0854,0.2069
2,Q4167836,0.2119,-0.0596,-0.4983,-0.5406,-0.8534,0.1305,1.0053,0.0972,-0.6799,...,-0.0157,-0.2806,0.3098,-0.1842,-0.1467,-0.1482,0.4218,-0.5890,0.0637,0.0597
3,Q6581097,0.3376,0.2482,0.2568,-0.0287,-0.8143,-0.6013,0.4635,0.1256,-0.2780,...,0.0733,-0.0252,-0.2832,0.3008,-0.2845,0.5018,-0.3458,-0.7110,0.2288,-0.5106
4,Q16521,-0.3649,-0.7143,-0.1985,-0.8482,-0.6250,-0.1299,0.4153,-0.0469,-0.5142,...,0.7377,-0.2373,0.8955,-0.5313,0.6623,0.2243,0.6147,-0.5094,-0.2916,0.4188
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,Q26842633,-0.3039,-0.3113,-0.4618,-0.3129,-0.8078,0.0136,0.4512,-0.9440,0.1844,...,0.5912,0.9583,1.2572,0.5230,-0.5270,-0.3459,0.6681,-0.0368,-0.2663,-0.0909
9996,Q15763806,-1.2473,0.6681,-0.7949,-0.7073,0.1150,-0.8819,0.5306,0.4668,-0.1859,...,0.4987,0.3669,-0.1575,0.2222,0.0021,0.1813,0.4070,-0.2016,-0.2948,0.3876
9997,Q23916,0.0757,0.0808,-0.4001,-0.3316,0.1981,0.0842,0.7219,-0.0462,0.0717,...,-0.2297,-0.3597,0.4675,-0.2015,-0.0844,0.0382,0.4272,0.1239,-0.0524,0.0774
9998,Q864217,-0.5102,-0.3203,0.1878,0.0757,-0.7398,-0.1611,0.3297,0.8216,-0.7161,...,-0.1461,-0.5963,-0.2641,0.0739,-0.0134,-0.2057,-0.7483,0.1778,0.1733,0.2463


In [16]:
embedding_qids[embedding_qids['id'] == "Q5"].iloc[0,1:].values.reshape((1,200))

array([[-0.2249, 0.14, -0.1503, -0.4524, -0.2512, 0.2315, 0.9779, 0.6769,
        0.0827, 0.174, 0.0746, -0.3145, -0.4379, -0.2596, 0.2497, 0.6317,
        0.0649, 0.8525, -0.844, -0.325, 0.1749, 0.2906, -0.7133, 0.6766,
        -0.1764, 0.4321, -0.4739, -0.8006, -0.3901, 0.7154, 0.316,
        -0.0677, 0.4136, 0.016, 0.1898, -0.9656, 0.2196, -0.7675,
        -0.8107, -0.7817, 0.5434, -0.619, -0.422, -0.2386, -0.4517,
        0.1504, -0.431, -0.4969, 0.4982, -0.274, -0.55, -0.0774, -0.5733,
        0.9172, -0.118, -0.0867, -0.3846, -0.2733, -0.6376, -0.182,
        0.756, 0.3724, 0.0529, 0.261, 0.2723, 0.1296, 0.1688, 0.644,
        -0.0943, -0.6549, -0.2459, 0.1955, 0.1156, -0.0082, 0.0364,
        0.3262, 0.7579, -0.2566, -0.439, 0.0152, 0.1234, 1.1579, -0.9209,
        0.3654, -0.0719, 0.5497, -0.1902, -0.0008, -0.1432, 0.5596,
        0.5633, 0.768, 0.138, -0.2131, -0.5384, -0.0125, -0.7738, 1.1437,
        -0.1613, -0.2311, 0.2402, -0.4568, -0.752, -0.2087, 0.4484,
        -0.0869

In [12]:
qids = pd.read_csv('qid_embedding.csv', header=None, names=['id'], usecols=[0])
qids.info(memory_usage=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 55032670 entries, 0 to 55032669
Data columns (total 1 columns):
 #   Column  Dtype 
---  ------  ----- 
 0   id      object
dtypes: object(1)
memory usage: 419.9+ MB


In [13]:
qids

Unnamed: 0,id
0,Q13442814
1,Q5
2,Q4167836
3,Q6581097
4,Q16521
...,...
55032665,Q61905263
55032666,Q61908653
55032667,Q61908654
55032668,Q61909529


In [5]:
embed_range = pd.read_csv('qid_embedding.csv', 
                          header=None, 
                          nrows=10, 
                          skiprows=55032665,
                          names=['id']+[f"embedding_{num}" for num in range(1,201)])
embed_range

Unnamed: 0,id,embedding_1,embedding_2,embedding_3,embedding_4,embedding_5,embedding_6,embedding_7,embedding_8,embedding_9,...,embedding_191,embedding_192,embedding_193,embedding_194,embedding_195,embedding_196,embedding_197,embedding_198,embedding_199,embedding_200
0,Q61905263,-0.2547,0.1266,0.4141,-0.1665,0.0902,-0.2426,0.0482,0.0156,-0.0167,...,-0.2337,0.0648,0.3816,0.1224,-0.1724,0.1299,0.3947,0.0028,0.2577,0.09
1,Q61908653,-0.3424,0.0551,0.1168,-0.045,-0.2718,-0.3656,-0.1966,-0.0885,0.1675,...,0.1896,-0.0315,0.2755,-0.0479,-0.1841,0.0546,-0.0575,-0.0443,0.1206,-0.1594
2,Q61908654,-0.2041,-0.0213,0.1148,0.1016,-0.1224,-0.1033,0.1357,0.1167,-0.1219,...,0.1765,-0.1373,0.1718,-0.1315,-0.1707,-0.3487,-0.1054,0.1045,0.1903,-0.1669
3,Q61909529,-0.1326,0.0164,-0.0338,0.0486,-0.186,-0.4355,-0.1516,0.0393,0.1787,...,0.1099,-0.0978,0.2751,-0.1946,-0.2837,0.0414,-0.1365,0.0342,0.1047,-0.0946
4,Q61910646,-0.457,0.3926,0.3603,0.0041,-0.2131,-0.3305,-0.3827,0.0623,-0.0138,...,0.1182,-0.1485,-0.1201,0.1351,-0.0515,-0.2484,-0.0374,0.1714,-0.0102,-0.296


In [6]:
train, valid, test

('linked-wikitext-2/train.jsonl',
 'linked-wikitext-2/valid.jsonl',
 'linked-wikitext-2/test.jsonl')

In [11]:
## seletc Q-ids only in linked wikitext-2

import jsonlines

qid_set = set()

for dataset in [train, valid, test]:
    with jsonlines.open(dataset) as f:
        for line in f.iter():
            for annot in line['annotations']:
                qid_set.add(annot['id'])
        print(len(qid_set))

41058
44413
47932


In [16]:
qids_wktxt2 = qids[qids['id'].isin(qid_set)]

In [17]:
qids_wktxt2.to_csv("qids_wktxt2.csv", index=False)

In [26]:
qids_wktxt2 #.index.values

Unnamed: 0,id
1,Q5
4,Q16521
5,Q7432
6,Q30
7,Q1860
...,...
54913387,Q17042242
54973666,Q28208712
54991087,Q42377501
54993341,Q48816565


In [28]:
len(qids)

55032670

In [34]:
# embeds_wktxt = pd.DataFrame(columns=['id']+[f"embedding_{num}" for num in range(1,201)])

to_exclude = [i for i in range(55032670) if i not in qids_wktxt2.index.values]

embeds_wktxt = pd.read_csv('qid_embedding.csv',
                              header=None, 
                              skiprows=to_exclude,
                              names=['id']+[f"embedding_{num}" for num in range(1,201)])

len(embeds_wktxt)

46685

In [35]:
embeds_wktxt

Unnamed: 0,id,embedding_1,embedding_2,embedding_3,embedding_4,embedding_5,embedding_6,embedding_7,embedding_8,embedding_9,...,embedding_191,embedding_192,embedding_193,embedding_194,embedding_195,embedding_196,embedding_197,embedding_198,embedding_199,embedding_200
0,Q5,-0.2249,0.1400,-0.1503,-0.4524,-0.2512,0.2315,0.9779,0.6769,0.0827,...,-0.1156,0.2502,0.9211,-0.3094,-0.7469,0.4710,0.5569,-0.7302,-0.0854,0.2069
1,Q16521,-0.3649,-0.7143,-0.1985,-0.8482,-0.6250,-0.1299,0.4153,-0.0469,-0.5142,...,0.7377,-0.2373,0.8955,-0.5313,0.6623,0.2243,0.6147,-0.5094,-0.2916,0.4188
2,Q7432,0.2851,0.3203,0.1160,-0.2781,-0.0969,-0.4335,-0.6281,-0.4147,0.4710,...,0.3683,-0.6955,0.1620,0.1344,-0.1442,-0.0663,0.6458,-0.6264,0.8500,0.5677
3,Q30,1.0045,-0.1153,-0.0770,0.3894,-0.2120,-0.3979,0.3745,0.4687,-0.7981,...,0.3716,0.0019,-0.2455,0.5625,-0.3040,-0.4212,0.2615,-0.7009,-0.1871,-0.1948
4,Q1860,0.2808,-0.0076,0.3640,-0.3508,0.2861,0.1748,0.3726,0.5543,-0.6450,...,-0.2940,-1.1060,0.5231,-0.2903,0.2146,0.5142,0.0470,0.4795,0.0279,-0.9129
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46680,Q17042242,-0.1863,-0.0469,0.0483,0.1670,-0.2604,-0.0706,0.0765,0.0135,-0.0309,...,-0.0811,0.0552,0.3624,-0.0836,-0.1808,-0.0564,-0.0138,0.1261,-0.0036,0.0071
46681,Q28208712,-0.1813,-0.0169,0.0940,0.2144,-0.1460,-0.0879,0.1475,0.0829,-0.1985,...,0.1492,-0.1119,0.2504,-0.1439,-0.1438,-0.4153,-0.1049,0.3060,0.1783,-0.1870
46682,Q42377501,-0.2616,-0.1402,0.2998,0.0265,-0.2412,-0.2058,-0.2162,-0.1520,-0.1352,...,0.1989,-0.1950,0.4726,0.0322,0.0439,-0.0669,0.0326,-0.0173,0.2337,-0.1331
46683,Q48816565,0.0007,-0.0841,-0.0816,-0.0922,-0.0922,-0.0315,0.0156,0.3363,0.2449,...,0.0873,0.2758,0.2465,0.2607,0.2390,-0.4363,0.1404,-0.1441,-0.0521,0.0599


In [36]:
embeds_wktxt.to_csv("embeds_wktxt.csv", index=False)

In [37]:
embeds_wktxt.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 46685 entries, 0 to 46684
Columns: 201 entries, id to embedding_200
dtypes: float64(200), object(1)
memory usage: 71.6+ MB


In [40]:
import jsonlines

with jsonlines.open(valid) as f:
    for line in f.iter():
        print(len(line['tokens']))

1352
10663
2474
10767
3619
1247
4262
1559
3602
1689
2038
1456
6072
1832
3524
18805
3843
1291
1535
1955
2173
5384
3375
875
2588
857
3129
2023
6652
6681
5720
3951
5427
2929
5950
1301
4328
4530
2648
2046
4824
1364
3407
4253
1325
1287
2078
1705
5747
1314
5267
2746
7504
2746
1887
5184
2060
772
1358
3096


In [62]:
embeds_wktxt

Unnamed: 0,id,embedding_1,embedding_2,embedding_3,embedding_4,embedding_5,embedding_6,embedding_7,embedding_8,embedding_9,...,embedding_191,embedding_192,embedding_193,embedding_194,embedding_195,embedding_196,embedding_197,embedding_198,embedding_199,embedding_200
0,Q5,-0.2249,0.1400,-0.1503,-0.4524,-0.2512,0.2315,0.9779,0.6769,0.0827,...,-0.1156,0.2502,0.9211,-0.3094,-0.7469,0.4710,0.5569,-0.7302,-0.0854,0.2069
1,Q16521,-0.3649,-0.7143,-0.1985,-0.8482,-0.6250,-0.1299,0.4153,-0.0469,-0.5142,...,0.7377,-0.2373,0.8955,-0.5313,0.6623,0.2243,0.6147,-0.5094,-0.2916,0.4188
2,Q7432,0.2851,0.3203,0.1160,-0.2781,-0.0969,-0.4335,-0.6281,-0.4147,0.4710,...,0.3683,-0.6955,0.1620,0.1344,-0.1442,-0.0663,0.6458,-0.6264,0.8500,0.5677
3,Q30,1.0045,-0.1153,-0.0770,0.3894,-0.2120,-0.3979,0.3745,0.4687,-0.7981,...,0.3716,0.0019,-0.2455,0.5625,-0.3040,-0.4212,0.2615,-0.7009,-0.1871,-0.1948
4,Q1860,0.2808,-0.0076,0.3640,-0.3508,0.2861,0.1748,0.3726,0.5543,-0.6450,...,-0.2940,-1.1060,0.5231,-0.2903,0.2146,0.5142,0.0470,0.4795,0.0279,-0.9129
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46680,Q17042242,-0.1863,-0.0469,0.0483,0.1670,-0.2604,-0.0706,0.0765,0.0135,-0.0309,...,-0.0811,0.0552,0.3624,-0.0836,-0.1808,-0.0564,-0.0138,0.1261,-0.0036,0.0071
46681,Q28208712,-0.1813,-0.0169,0.0940,0.2144,-0.1460,-0.0879,0.1475,0.0829,-0.1985,...,0.1492,-0.1119,0.2504,-0.1439,-0.1438,-0.4153,-0.1049,0.3060,0.1783,-0.1870
46682,Q42377501,-0.2616,-0.1402,0.2998,0.0265,-0.2412,-0.2058,-0.2162,-0.1520,-0.1352,...,0.1989,-0.1950,0.4726,0.0322,0.0439,-0.0669,0.0326,-0.0173,0.2337,-0.1331
46683,Q48816565,0.0007,-0.0841,-0.0816,-0.0922,-0.0922,-0.0315,0.0156,0.3363,0.2449,...,0.0873,0.2758,0.2465,0.2607,0.2390,-0.4363,0.1404,-0.1441,-0.0521,0.0599


In [122]:
import jsonlines

for dataset, filename in zip([valid, test, train], ["valid", "test", "train"]):
    line_count = 0
    with jsonlines.open(dataset) as f:
        embed_list_dataset = []
        for line in f.iter():
            line_count += 1
            embed_list = np.zeros((len(line['tokens']), 200))
            
            for annot in line['annotations']:
                start, end = annot['span']
                qid = annot['id']
                df = embeds_wktxt[embeds_wktxt['id']==qid]
                if len(df)>0:
                    embed_list[start:end] = np.tile(
                            df.iloc[0,1:].values.reshape((1,200)),
                            (end-start, 1)
                        )
                else:
                    print(qid)
                    
            embed_list_dataset.append(embed_list)

#         with open(f'{filename}.npy', 'ab') as f:
#             np.save(f, embed_list)
        np.savez(f"{filename}.npz", *embed_list_dataset)
    
    print("\n", "="*50, filename, line_count, "\n")

T::11::+1964-06-21T00:00:00Z
T::11::+2004-09-02T00:00:00Z
V::0.8000::http://www.wikidata.org/entity/Q11229
T::9::+1970-00-00T00:00:00Z
T::9::+1970-00-00T00:00:00Z
T::11::+2002-11-22T00:00:00Z
T::11::+2002-11-22T00:00:00Z
T::11::+2002-11-22T00:00:00Z
T::11::+2002-11-22T00:00:00Z
T::11::+2002-11-22T00:00:00Z
T::9::+2008-00-00T00:00:00Z
T::11::+1986-10-07T00:00:00Z
Q4663150
T::11::+1986-10-07T00:00:00Z
T::9::+2015-00-00T00:00:00Z
T::9::+2016-00-00T00:00:00Z
T::11::+2007-09-09T00:00:00Z
T::9::+2008-00-00T00:00:00Z
T::11::+1873-04-04T00:00:00Z
T::9::+1874-00-00T00:00:00Z
T::11::+1873-04-04T00:00:00Z
T::9::+1758-00-00T00:00:00Z
T::9::+2005-01-01T00:00:00Z
T::11::+1991-06-23T00:00:00Z
T::11::+1991-06-23T00:00:00Z
T::11::+2009-09-09T00:00:00Z
T::11::+1869-08-05T00:00:00Z
T::11::+1940-01-09T00:00:00Z
T::11::+1839-08-27T00:00:00Z
T::11::+1840-09-02T00:00:00Z
T::11::+1804-03-04T00:00:00Z
T::11::+1808-01-26T00:00:00Z
T::11::+1810-01-01T00:00:00Z
T::9::+1832-00-00T00:00:00Z
T::9::+1829-01-01T00:00:

T::11::+0068-06-09T00:00:00Z
T::11::+0069-12-21T00:00:00Z
T::11::+0069-12-21T00:00:00Z
T::11::+0098-01-27T00:00:00Z
T::11::+0098-01-27T00:00:00Z
T::9::+0229-00-00T00:00:00Z
T::11::+0096-09-18T00:00:00Z
T::9::+1951-00-00T00:00:00Z
T::9::+1667-00-00T00:00:00Z
T::9::+1880-00-00T00:00:00Z
T::11::+1992-02-27T00:00:00Z
T::11::+1992-02-27T00:00:00Z
T::11::+2008-09-05T00:00:00Z
T::11::+2012-11-08T00:00:00Z
T::11::+2013-10-27T00:00:00Z
T::11::+2013-10-27T00:00:00Z
T::9::+1991-00-00T00:00:00Z
T::11::+2012-12-25T00:00:00Z
T::9::+2005-00-00T00:00:00Z
T::11::+2012-12-25T00:00:00Z
T::11::+1963-11-23T00:00:00Z
T::11::+1963-11-23T00:00:00Z
T::9::+1917-01-01T00:00:00Z
T::9::+2011-00-00T00:00:00Z
T::9::+2010-00-00T00:00:00Z
T::9::+1994-00-00T00:00:00Z
T::9::+1993-00-00T00:00:00Z
T::9::+2011-00-00T00:00:00Z
T::11::+2012-05-08T00:00:00Z
T::11::+2012-07-05T00:00:00Z
T::11::+2012-11-27T00:00:00Z
T::9::+2008-00-00T00:00:00Z
T::9::+2011-00-00T00:00:00Z
T::11::+2012-07-05T00:00:00Z
T::11::+2012-11-27T00:00:00Z

T::11::+1998-08-26T00:00:00Z
T::11::+1801-01-01T00:00:00Z
T::11::+1801-01-01T00:00:00Z
T::11::+1803-09-20T00:00:00Z
T::11::+1801-01-01T00:00:00Z
T::11::+1921-07-11T00:00:00Z
T::11::+1923-05-24T00:00:00Z
T::9::+1845-00-00T00:00:00Z
T::9::+2006-01-01T00:00:00Z
T::11::+1667-11-30T00:00:00Z
T::11::+1745-10-19T00:00:00Z
T::11::+1742-04-13T00:00:00Z
T::9::+1921-00-00T00:00:00Z
T::11::+1894-11-05T00:00:00Z
T::11::+1952-11-08T00:00:00Z
T::9::+1947-00-00T00:00:00Z
T::11::+1894-11-05T00:00:00Z
T::9::+1933-00-00T00:00:00Z
T::9::+1931-01-01T00:00:00Z
T::11::+1909-03-04T00:00:00Z
T::11::+1868-11-17T00:00:00Z
T::11::+2009-10-03T00:00:00Z
T::9::+2012-00-00T00:00:00Z
T::9::+1895-00-00T00:00:00Z
T::11::+1978-07-25T00:00:00Z
T::9::+2002-01-01T00:00:00Z
T::11::+1836-03-12T00:00:00Z
T::11::+1865-02-06T00:00:00Z
T::9::+1857-01-01T00:00:00Z
T::9::+1861-00-00T00:00:00Z
T::11::+1836-03-12T00:00:00Z
T::11::+1852-03-20T00:00:00Z
T::11::+1852-03-20T00:00:00Z
T::9::+1861-00-00T00:00:00Z
T::9::+1885-00-00T00:00:00

T::11::+1988-05-02T00:00:00Z
T::11::+1941-06-22T00:00:00Z
T::11::+1943-08-17T00:00:00Z
T::11::+1944-06-06T00:00:00Z
T::11::+1974-07-12T00:00:00Z
T::9::+1997-00-00T00:00:00Z
T::11::+2001-03-26T00:00:00Z
T::11::+2001-03-26T00:00:00Z
T::11::+1898-07-07T00:00:00Z
T::11::+1865-04-15T00:00:00Z
T::11::+1851-07-27T00:00:00Z
T::11::+1890-11-19T00:00:00Z
T::11::+1874-02-06T00:00:00Z
T::11::+1851-07-27T00:00:00Z
T::11::+1874-02-06T00:00:00Z
T::11::+1847-05-07T00:00:00Z
T::11::+1894-03-05T00:00:00Z
T::9::+1883-00-00T00:00:00Z
T::9::+1883-00-00T00:00:00Z
T::9::+1987-00-00T00:00:00Z
T::9::+1988-01-01T00:00:00Z
T::11::+2008-04-22T00:00:00Z
T::11::+1975-08-15T00:00:00Z
T::9::+1973-01-01T00:00:00Z
T::9::+1968-01-01T00:00:00Z
T::9::+1973-01-01T00:00:00Z
T::9::+1959-01-01T00:00:00Z
T::9::+1973-01-01T00:00:00Z
T::11::+1975-08-15T00:00:00Z
T::11::+1975-08-15T00:00:00Z
T::11::+2002-07-25T00:00:00Z
T::9::+1980-01-01T00:00:00Z
T::9::+1987-01-01T00:00:00Z
T::9::+1973-01-01T00:00:00Z
T::9::+2004-01-01T00:00:00Z

T::11::+1950-01-26T00:00:00Z
T::11::+1947-08-15T00:00:00Z
T::11::+1947-08-15T00:00:00Z
T::11::+1994-12-18T00:00:00Z
T::10::+2003-05-00T00:00:00Z
T::10::+2003-05-00T00:00:00Z
T::11::+1863-06-20T00:00:00Z
T::11::+2004-07-08T00:00:00Z
T::10::+2003-05-00T00:00:00Z
T::10::+2003-05-00T00:00:00Z
T::10::+2003-05-00T00:00:00Z
T::10::+2003-05-00T00:00:00Z
T::10::+2003-05-00T00:00:00Z
T::9::+2007-01-01T00:00:00Z
T::11::+1933-03-27T00:00:00Z
T::11::+1939-10-25T00:00:00Z
T::11::+1939-10-25T00:00:00Z
T::11::+1986-09-17T00:00:00Z
T::11::+1985-08-16T00:00:00Z
T::11::+1986-09-17T00:00:00Z
T::11::+1986-09-17T00:00:00Z
T::9::+1983-00-00T00:00:00Z
T::9::+1987-00-00T00:00:00Z
T::9::+1987-00-00T00:00:00Z
T::11::+1942-07-04T00:00:00Z
T::9::+1940-00-00T00:00:00Z
T::9::+1956-00-00T00:00:00Z
Q17305006
T::9::+1956-00-00T00:00:00Z
T::9::+1951-00-00T00:00:00Z
T::9::+1974-00-00T00:00:00Z
T::9::+1951-00-00T00:00:00Z
T::11::+2005-12-21T00:00:00Z
T::11::+2004-11-22T00:00:00Z
T::11::+2004-11-22T00:00:00Z
T::11::+1984-0

T::11::+2008-10-26T00:00:00Z
T::11::+2005-11-08T00:00:00Z
T::11::+2008-10-26T00:00:00Z
T::11::+2005-11-08T00:00:00Z
T::11::+2008-10-26T00:00:00Z
T::11::+2009-09-01T00:00:00Z
T::11::+2007-07-24T00:00:00Z
T::11::+2009-03-29T00:00:00Z
T::11::+2009-06-16T00:00:00Z
T::11::+2009-12-22T00:00:00Z
T::11::+2009-11-03T00:00:00Z
T::11::+2008-06-22T00:00:00Z
T::11::+1985-08-11T00:00:00Z
T::9::+2014-01-01T00:00:00Z
T::9::+2009-01-01T00:00:00Z
Q4764958
T::9::+1948-01-01T00:00:00Z
T::9::+1948-01-01T00:00:00Z
T::9::+1966-01-01T00:00:00Z
T::11::+1931-10-16T00:00:00Z
T::11::+2011-01-09T00:00:00Z
T::11::+1959-07-02T00:00:00Z
T::11::+1932-06-18T00:00:00Z
T::9::+0523-00-00T00:00:00Z
T::9::+1055-00-00T00:00:00Z
T::9::+1213-00-00T00:00:00Z
T::10::+1939-09-00T00:00:00Z
T::10::+1939-09-00T00:00:00Z
T::11::+1939-09-17T00:00:00Z
T::11::+1939-09-01T00:00:00Z
T::9::+1940-00-00T00:00:00Z
T::11::+1944-10-02T00:00:00Z
T::11::+1918-10-28T00:00:00Z
V::3.7000::http://www.wikidata.org/entity/Q11229
V::4.9000::http://www.w

T::9::+1949-01-01T00:00:00Z
T::11::+1951-11-19T00:00:00Z
T::9::+1960-01-01T00:00:00Z
T::9::+1961-01-01T00:00:00Z
T::11::+1955-01-22T00:00:00Z
T::9::+1958-01-01T00:00:00Z
T::9::+1958-01-01T00:00:00Z
T::9::+1960-01-01T00:00:00Z
T::9::+1961-01-01T00:00:00Z
T::9::+1960-01-01T00:00:00Z
T::9::+1961-01-01T00:00:00Z
T::9::+1961-01-01T00:00:00Z
T::9::+1966-01-01T00:00:00Z
T::9::+1977-01-01T00:00:00Z
T::9::+1977-01-01T00:00:00Z
T::9::+1977-01-01T00:00:00Z
T::11::+1976-12-27T00:00:00Z
T::9::+1984-01-01T00:00:00Z
T::9::+1984-01-01T00:00:00Z
T::9::+1987-00-00T00:00:00Z
T::11::+1995-01-01T00:00:00Z
T::9::+1992-00-00T00:00:00Z
T::9::+1993-00-00T00:00:00Z
T::11::+2009-10-07T00:00:00Z
T::11::+1895-12-21T00:00:00Z
T::11::+1979-11-12T00:00:00Z
T::11::+1923-02-09T00:00:00Z
T::11::+1929-10-22T00:00:00Z
T::11::+1979-11-12T00:00:00Z
T::11::+1895-12-21T00:00:00Z
T::11::+1915-05-09T00:00:00Z
T::11::+1918-04-01T00:00:00Z
T::11::+1923-02-09T00:00:00Z
T::11::+1944-07-02T00:00:00Z
T::11::+1979-11-12T00:00:00Z
T::1

T::11::+1845-03-18T00:00:00Z
T::11::+1863-02-27T00:00:00Z
T::11::+1861-04-11T00:00:00Z
T::9::+1650-00-00T00:00:00Z
T::11::+1913-01-23T00:00:00Z
T::11::+2009-06-27T00:00:00Z
T::9::+1955-00-00T00:00:00Z
T::9::+1946-00-00T00:00:00Z
T::9::+1946-00-00T00:00:00Z
T::9::+1936-01-01T00:00:00Z
T::10::+1931-08-01T00:00:00Z
T::11::+1941-06-22T00:00:00Z
T::9::+1941-01-01T00:00:00Z
T::9::+1941-01-01T00:00:00Z
T::9::+1861-00-00T00:00:00Z
T::9::+1075-00-00T00:00:00Z
T::9::+1070-01-01T00:00:00Z
T::11::+1072-02-07T00:00:00Z
T::9::+1089-00-00T00:00:00Z
T::11::+1086-07-14T00:00:00Z
T::9::+1075-00-00T00:00:00Z
T::11::+1119-03-10T00:00:00Z
T::11::+1911-09-29T00:00:00Z
T::11::+1911-09-29T00:00:00Z
T::11::+1939-09-01T00:00:00Z
T::11::+1931-03-03T00:00:00Z
T::11::+1849-11-13T00:00:00Z
T::11::+1849-11-13T00:00:00Z
T::11::+2015-09-25T00:00:00Z
T::11::+1889-11-08T00:00:00Z
T::9::+1910-00-00T00:00:00Z
T::11::+1951-04-22T00:00:00Z
T::11::+1951-10-19T00:00:00Z
T::11::+1393-07-17T00:00:00Z
T::11::+2014-02-11T00:00:00

#### Dataset stats (lines)
- valid 60
- test 60
- train 600

In [113]:
line

{'title': 'Meridian, Mississippi',
 'tokens': ['@@START@@',
  'Meridian',
  'is',
  'the',
  'sixth',
  'largest',
  'city',
  'in',
  'the',
  'state',
  'of',
  'Mississippi',
  ',',
  'United',
  'States',
  '.',
  '@@END@@',
  'It',
  'is',
  'the',
  'county',
  'seat',
  'of',
  'Lauderdale',
  'County',
  'and',
  'the',
  'principal',
  'city',
  'of',
  'the',
  'Meridian',
  ',',
  'Mississippi',
  'Micropolitan',
  'Statistical',
  'Area',
  '.',
  '@@END@@',
  'Along',
  'major',
  'highways',
  ',',
  'the',
  'city',
  'is',
  '93',
  'mi',
  '(',
  '150',
  'km',
  ')',
  'east',
  'of',
  'Jackson',
  ',',
  'Mississippi',
  ';',
  '@@END@@',
  '154',
  'mi',
  '(',
  '248',
  'km',
  ')',
  'southwest',
  'of',
  'Birmingham',
  ',',
  'Alabama',
  ';',
  '@@END@@',
  '202',
  'mi',
  '(',
  '325',
  'km',
  ')',
  'northeast',
  'of',
  'New',
  'Orleans',
  ',',
  'Louisiana',
  ';',
  'and',
  '231',
  'mi',
  '(',
  '372',
  'km',
  ')',
  'southeast',
  'of',
  'M

In [112]:
annot

{'source': 'KG',
 'id': 'T::11::+1964-06-21T00:00:00Z',
 'relation': ['P570'],
 'parent_id': ['Q2733240'],
 'span': [669, 670]}

In [109]:
import time

for ix in range(len(embed_list)):
    print(embed_list[ix])
    time.sleep(1)
    clear_output()

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]


KeyboardInterrupt: 

In [128]:
npfile = np.load("valid.npz")
npfile

<numpy.lib.npyio.NpzFile at 0x7fdb7d41ff70>

In [129]:
npfile.files

['arr_0',
 'arr_1',
 'arr_2',
 'arr_3',
 'arr_4',
 'arr_5',
 'arr_6',
 'arr_7',
 'arr_8',
 'arr_9',
 'arr_10',
 'arr_11',
 'arr_12',
 'arr_13',
 'arr_14',
 'arr_15',
 'arr_16',
 'arr_17',
 'arr_18',
 'arr_19',
 'arr_20',
 'arr_21',
 'arr_22',
 'arr_23',
 'arr_24',
 'arr_25',
 'arr_26',
 'arr_27',
 'arr_28',
 'arr_29',
 'arr_30',
 'arr_31',
 'arr_32',
 'arr_33',
 'arr_34',
 'arr_35',
 'arr_36',
 'arr_37',
 'arr_38',
 'arr_39',
 'arr_40',
 'arr_41',
 'arr_42',
 'arr_43',
 'arr_44',
 'arr_45',
 'arr_46',
 'arr_47',
 'arr_48',
 'arr_49',
 'arr_50',
 'arr_51',
 'arr_52',
 'arr_53',
 'arr_54',
 'arr_55',
 'arr_56',
 'arr_57',
 'arr_58',
 'arr_59']

In [127]:
npfile['arr_0']

array([[ 0.    ,  0.    ,  0.    , ...,  0.    ,  0.    ,  0.    ],
       [-0.5412,  0.4502,  0.2182, ..., -0.099 , -0.1066,  0.134 ],
       [-0.5412,  0.4502,  0.2182, ..., -0.099 , -0.1066,  0.134 ],
       ...,
       [ 0.    ,  0.    ,  0.    , ...,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    , ...,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    , ...,  0.    ,  0.    ,  0.    ]])

In [98]:
line['annotations']

[{'source': 'WIKI',
  'id': 'Q822935',
  'relation': ['@@NEW@@'],
  'parent_id': ['Q822935'],
  'span': [1, 3]},
 {'source': 'WIKI',
  'id': 'Q1788582',
  'relation': ['@@NEW@@'],
  'parent_id': ['Q1788582'],
  'span': [5, 7]},
 {'source': 'WIKI',
  'id': 'Q35657',
  'relation': ['@@NEW@@'],
  'parent_id': ['Q35657'],
  'span': [13, 15]},
 {'source': 'WIKI',
  'id': 'Q1408',
  'relation': ['R:P31', 'P131'],
  'parent_id': ['Q35657', 'Q822935'],
  'span': [16, 18]},
 {'source': 'COREF',
  'id': 'Q822935',
  'relation': ['@@REFLEXIVE@@', 'R:P131'],
  'parent_id': ['Q822935', 'Q1408'],
  'span': [20, 21]},
 {'source': 'WIKI',
  'id': 'Q811383',
  'relation': ['R:P131'],
  'parent_id': ['Q1408'],
  'span': [32, 35]},
 {'source': 'WIKI',
  'id': 'Q811493',
  'relation': ['@@NEW@@'],
  'parent_id': ['Q811493'],
  'span': [41, 44]},
 {'source': 'WIKI',
  'id': 'Q1073974',
  'relation': ['@@NEW@@'],
  'parent_id': ['Q1073974'],
  'span': [45, 47]},
 {'source': 'WIKI',
  'id': 'Q497795',
  'rel

In [99]:
' '.join(line['tokens'])

'@@START@@ Route 50 is a state highway in the southern part of the U.S. state of New Jersey . @@END@@ It runs 26.02 mi ( 41.88 km ) from an intersection with U.S. Route 9 ( US 9 ) and the Garden State Parkway in Upper Township , Cape May County to an intersection with US 30 and County Route 563 ( CR 563 ) in Egg Harbor City , Atlantic County . @@END@@ The route , which is mostly a two - lane undivided road , passes through mostly rural areas of Atlantic and Cape May counties as well as the communities of Tuckahoe , Corbin City , Estell Manor , and Mays Landing . @@END@@ NJ 50 intersects several roads , including Route 49 in Tuckahoe , US 40 in Mays Landing , and US 322 and the Atlantic City Expressway in Hamilton Township . @@END@@ The portion of current Route 50 between Seaville and Petersburg received funding in 1910 to become a spur of the Ocean Highway . @@END@@ In 1917 , what is now Route 50 was designated a part of pre-1927 Route 14 , a route that was to run from Cape May to Egg 

In [51]:
# with open('test.npy', 'wb') as f:
#     np.save(f, np.array([1, 2]))
#     np.save(f, np.array([1, 3]))
# with open('test.npy', 'rb') as f:
#     a = np.load(f)
#     b = np.load(f)
# a, b # [1 2] [1 3]

(array([1, 2]), array([1, 3]))

In [61]:
# with open('test2.npy', 'ab') as f:
#     np.save(f, np.array([6, 2]))

# with open('test2.npy', 'ab') as f:
#     np.save(f, np.array([1, 3]))
    
# with open('test2.npy', 'ab') as f:
#     np.save(f, np.array([2, 3]))
    
# with open('test2.npy', 'rb') as f:
#     x = np.load(f)
#     y = np.load(f)
#     z = np.load(f)
# x, y, z # [1 2] [1 3]

(array([1, 2]), array([1, 3]), array([2, 3]))

## Create PyTorch Dataset

In [130]:
## seletc Q-ids only in linked wikitext-2

import jsonlines

with jsonlines.open(valid) as f:
    for line in f.iter():
        print(line)
        break

{'title': 'New Jersey Route 50', 'tokens': ['@@START@@', 'Route', '50', 'is', 'a', 'state', 'highway', 'in', 'the', 'southern', 'part', 'of', 'the', 'U.S.', 'state', 'of', 'New', 'Jersey', '.', '@@END@@', 'It', 'runs', '26.02', 'mi', '(', '41.88', 'km', ')', 'from', 'an', 'intersection', 'with', 'U.S.', 'Route', '9', '(', 'US', '9', ')', 'and', 'the', 'Garden', 'State', 'Parkway', 'in', 'Upper', 'Township', ',', 'Cape', 'May', 'County', 'to', 'an', 'intersection', 'with', 'US', '30', 'and', 'County', 'Route', '563', '(', 'CR', '563', ')', 'in', 'Egg', 'Harbor', 'City', ',', 'Atlantic', 'County', '.', '@@END@@', 'The', 'route', ',', 'which', 'is', 'mostly', 'a', 'two', '-', 'lane', 'undivided', 'road', ',', 'passes', 'through', 'mostly', 'rural', 'areas', 'of', 'Atlantic', 'and', 'Cape', 'May', 'counties', 'as', 'well', 'as', 'the', 'communities', 'of', 'Tuckahoe', ',', 'Corbin', 'City', ',', 'Estell', 'Manor', ',', 'and', 'Mays', 'Landing', '.', '@@END@@', 'NJ', '50', 'intersects', 'se

In [131]:
line.keys()

dict_keys(['title', 'tokens', 'annotations'])

In [133]:
from datasets import load_dataset

data_files = {"train": train, "valid": valid, "test": test}
wikitest2_dataset = load_dataset("json", data_files=data_files)
wikitest2_dataset

Using custom data configuration default-c5571b3a8bc0c3d4


Downloading and preparing dataset json/default to /home/a2khatun/.cache/huggingface/datasets/json/default-c5571b3a8bc0c3d4/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

0 tables [00:00, ? tables/s]

0 tables [00:00, ? tables/s]

Dataset json downloaded and prepared to /home/a2khatun/.cache/huggingface/datasets/json/default-c5571b3a8bc0c3d4/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['title', 'tokens', 'annotations'],
        num_rows: 600
    })
    valid: Dataset({
        features: ['title', 'tokens', 'annotations'],
        num_rows: 60
    })
    test: Dataset({
        features: ['title', 'tokens', 'annotations'],
        num_rows: 60
    })
})