# Analysis3 - extracting FiD encoder embedding

## CHECKING PARSER

In [None]:
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
import heapq
import pathlib
import shutil
from FiD.src.model import FiDT5
from src.model import FiDEncoderForSequenceClassification

from pprint import pprint
from tqdm.auto import tqdm
from src.data import BinaryCustomDatasetShuffle

import json
import math
import os
import logging
import sys
import evaluate
from util import utils

import transformers
import torch
import numpy as np
import random
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
    AutoModelForSequenceClassification,
    AutoModel,
    AutoConfig,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    set_seed,
    get_scheduler,
)
from util.arguments import ModelArguments, DataTrainingArguments, CustomTrainingArguments

In [None]:
# Expose only GPU index 1 to this process; set before any CUDA work is started.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
# One parser over the three argument dataclasses of this project.
parser = HfArgumentParser(
    (ModelArguments, DataTrainingArguments, CustomTrainingArguments)
)

# Parse an empty argument list, i.e. no command-line overrides are supplied.
(model_args, data_args, train_args) = parser.parse_args_into_dataclasses(args=[])

In [None]:
# Display the attribute dict of the parsed model arguments.
vars(model_args)

In [None]:
# Display the attribute dict of the parsed data arguments.
vars(data_args)

In [None]:
# Keep the training arguments' attribute dict in a variable for inspection.
train_dict = vars(train_args)

In [None]:
# Display the training-argument dict captured above.
train_dict

## modeling

In [None]:
from pprint import pprint
import numpy as np
import torch
from torch import nn
import transformers
from transformers import AutoConfig, AutoTokenizer
from transformers import T5PreTrainedModel
import copy
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

In [None]:
from FiD.src.model import FiDT5

In [None]:
# Checkpoint directory used by the config and model loads in this section.
model_path = '/data/philhoon-relevance/FiD/pretrained_models/nq_reader_large'
# Label count forwarded to the config only.
num_labels = 2

In [None]:
# Class whose from_pretrained() builds the reader model.
model_class = FiDT5

In [None]:
# The config is loaded with num_labels and only printed; the model load in
# this section does not receive it.
config = AutoConfig.from_pretrained(model_path, num_labels=num_labels)

In [None]:
pprint(config)

In [None]:
# Load the FiDT5 weights from the checkpoint directory.
model = model_class.from_pretrained(model_path)

In [None]:
# model.encoder          -> FiDT5.EncoderWrapper
# model.encoder.encoder  -> the encoder held inside that wrapper
#                           (T5 encoder architecture with the FiDT5 parameters)
# Unwrap it so the encoder can be called on its own.
model_encoder = model.encoder.encoder

In [None]:
# Display the concrete class of the unwrapped encoder.
type(model_encoder)

In [None]:
# The tokenizer is loaded from the stock 't5-base' checkpoint, not from model_path.
# NOTE(review): `return_dict` is normally a model option; presumably it has no
# effect on a tokenizer — confirm before relying on it.
tokenizer = AutoTokenizer.from_pretrained('t5-base', return_dict=False)

In [None]:
# Alternative kept for reference: request the T5Tokenizer class explicitly.
# tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base', return_dict=False)

In [None]:
# Sample record: an id, a question, and one passage context (id, title, text).
example = {
    "id": 12,
    "question": "when was the public service commission original version of the upsc set up",
    "ctx": {
        "id": "17105334",
        "title": "Bihar Public Service Commission",
        "text": "3 of the Regulations, 1960 the Commission was constituted with a Chairman and 10 (ten) other members. The strength of members was reduced to 6 (six) after bifurcation of the State of Bihar and the State of Jharkhand vide notification no. 7/PSC-1013/95 (Part-3) Per 8262 dated 9 October 2002 of the Personnel & Administrative Reforms Department, Bihar. Article 320 and 321 of the Constitution of India prescribes the mandate of the State Public Service Commissions, which are: a)Recruitment by conduct of Competitive Examinations/ through interviews to the services of the State Government. b)Advising the State Government on the suitability of",
    },
}


In [None]:
# Tokenize only the question of the sample record.
# Only a single string is passed, so presumably padding=True has nothing to pad.
padded_output = tokenizer(example['question'], padding=True)

In [None]:
# Display the raw tokenizer output; 'input_ids' and 'attention_mask' are read from it later.
padded_output

In [None]:
# Map the ids back to token strings to eyeball the tokenization.
tokenizer.convert_ids_to_tokens(padded_output['input_ids'])

In [None]:
test_ids = np.expand_dims(np.array(padded_output['input_ids']), 0)

In [None]:
test_ids = torch.from_numpy(test_ids)
print(test_ids)

In [None]:
test_attention = np.expand_dims(np.array(padded_output['attention_mask']), 0)

In [None]:
test_attention = torch.from_numpy(test_attention)
print(test_attention)

In [None]:
result = model_encoder.forward(input_ids = test_ids, attention_mask = test_attention)

In [None]:
vars(result)

In [None]:
result['last_hidden_state'].shape

## Get Input 

In [None]:
import FiD.src.data

In [None]:
eval_data = '/data/philhoon-relevance/FiD/open_domain_data/NQ/dev.json'

In [None]:
eval_examples = FiD.src.data.load_data(
        eval_data,
    )

In [None]:
# Number of loaded examples.
len(eval_examples)

In [None]:
# Fields available on a single example.
eval_examples[0].keys()

In [None]:
# Question text of the first example.
eval_examples[0]['question']

In [None]:
# Answers stored with the first example.
eval_examples[0]['answers']

In [None]:
# eval_examples[0]['id']

In [None]:
import FiD.src.data

In [None]:
n_context = 10

In [None]:
eval_dataset = FiD.src.data.Dataset(
    eval_examples,  
    n_context = n_context
)

In [None]:
len(eval_dataset[0]['passages'])

In [None]:
eval_dataset[0].keys()

In [None]:
# eval_dataset[0]

In [None]:
text_maxlength = 200

In [None]:
collator_function = FiD.src.data.Collator(text_maxlength, tokenizer, n_context)

In [None]:
eval_sampler = SequentialSampler(eval_dataset) 

In [None]:
per_gpu_batch_size = 2

In [None]:
eval_dataloader = DataLoader(
    eval_dataset, 
    sampler=eval_sampler, 
    batch_size=per_gpu_batch_size,
    num_workers=8,
    collate_fn=collator_function
)


In [None]:
# Create an iterator so that a single batch can be pulled by hand.
iter_ = iter(eval_dataloader)

In [None]:
# First collated batch.
ins = next(iter_)

In [None]:
# Batch layout: (index, target_ids, target_mask, passage_ids, passage_masks)

In [None]:
## Index (position 0 of the batch tuple)
ins[0]

In [None]:
# target_ids (position 1)
ins[1].shape

In [None]:
# target_mask (position 2)
ins[2].shape

In [None]:
# passage_ids = context_ids = input_ids (position 3)
ins[3].shape

In [None]:
# passage_masks  = context_masks = attention_masks (position 4)
ins[4].shape

In [None]:
# result = model_encoder.forward(input_ids = ins[3], attention_mask = ins[4])

In [None]:
# ins.keys()

In [None]:
# ins[2].size(0)

In [None]:
input_ids, attention_mask = ins[3], ins[4]

In [None]:
input_ids = input_ids.view(input_ids.size(0), -1)
attention_mask = attention_mask.view(attention_mask.size(0), -1)

In [None]:
print(input_ids.shape, attention_mask.shape)

In [None]:
bsz, total_length = input_ids.shape

In [None]:
print(bsz, total_length)

In [None]:
n_context

In [None]:
n_passages = n_context
print(n_passages)
print(total_length)

In [None]:
# n_passages = 100
passage_length = total_length // n_passages
print(passage_length)

In [None]:
input_ids = input_ids.view(bsz*n_passages, passage_length)
print(input_ids.shape)

In [None]:
attention_mask = attention_mask.view(bsz*n_passages, passage_length)
print(attention_mask.shape)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

In [None]:
model_encoder.to(device = device)

In [None]:
# t = torch.randn(2,2).cuda()
print(input_ids.is_cuda)  # returns True
# t = t.cpu()
# t.is_cuda  # returns False

In [None]:
device

In [None]:
input_ids = input_ids.to(device)
print(input_ids.is_cuda)
attention_mask = attention_mask.to(device)
print(input_ids.is_cuda)

In [None]:
# bsz * # psgs X 
print(input_ids.shape)
print(attention_mask.shape)

## Checking that inputs have shape 500 * 200 and are in the right order

In [None]:
# eval_dataset[0]

In [None]:
# tokenizer.convert_ids_to_tokens(input_ids[0][:])

In [None]:
# tokenizer.convert_ids_to_tokens(input_ids[1][:])

In [None]:
outputs = model_encoder(input_ids, attention_mask)

In [None]:
vars(outputs)

In [None]:
print(outputs[0].shape)
print(bsz, n_passages, passage_length)

In [None]:
output_by_batch = outputs[0]

In [None]:
# outputs = self.encoder(input_ids, attention_mask, **kwargs)
# outputs = (outputs[0].view(bsz, n_passages*passage_length, -1), ) + outputs[1:]

In [None]:
# outputs[0].shape

In [None]:
bsz

In [None]:
output_by_batch.shape

In [None]:
output_by_batch = output_by_batch.view(bsz, n_passages, passage_length, -1)

In [None]:
output_by_batch.shape

In [None]:
output_by = output_by_batch.detach().cpu()

In [None]:
print(output_by_batch.is_cuda)
print(output_by.is_cuda)

In [None]:
for i in range(2):
    print(output_by[i,].shape)

In [None]:
print(output_by.shape)