In [1]:
import torch
import nlp

from transformers import T5ForConditionalGeneration, T5Tokenizer

from tqdm.auto import tqdm

from sklearn import metrics

In [2]:
# Load the T5 model and its tokenizer from the same local checkpoint directory,
# so the vocabulary always matches the weights being evaluated.
csqa_checkpoint = './models_csqa_3epochs'
model = T5ForConditionalGeneration.from_pretrained(csqa_checkpoint)
tokenizer = T5Tokenizer.from_pretrained(csqa_checkpoint)

In [3]:
# Read the validation split that was saved to disk and wrap it in a batched loader.
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
# No shuffling is requested here; later cells index valid_dataset by prediction
# position, which relies on batches following dataset order.
valid_data_path = './data/commonsense_qa/valid_data.pt'
valid_dataset = torch.load(valid_data_path)
dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=32,
)

In [4]:
# Generate predictions for every validation batch and decode both the generated
# sequences and the reference sequences into plain strings.
#
# `predictions` and `targets` end up as parallel lists (one string per example,
# in dataloader order) that the accuracy and error-inspection cells compare by
# exact string equality.
predictions = []
targets = []
for batch in tqdm(dataloader):
    # NOTE(review): early_stopping is a beam-search option; presumably a no-op
    # unless the model config enables beam search — confirm before relying on it.
    prediction = model.generate(input_ids = batch['input_ids'],
                                attention_mask = batch['attention_mask'],
                                max_length = 16,
                                early_stopping = True)

    # skip_special_tokens=True strips pad/eos markers from the decoded text.
    # Without it, tokenizer versions that keep special tokens on decode yield
    # strings such as "<pad> A</s>" for predictions and "A</s>" for targets,
    # which never compare equal and silently drive the accuracy to zero.
    prediction = [tokenizer.decode(ids, skip_special_tokens = True) for ids in prediction]
    target = [tokenizer.decode(ids, skip_special_tokens = True) for ids in batch['target_ids']]

    predictions.extend(prediction)
    targets.extend(target)

HBox(children=(FloatProgress(value=0.0, max=39.0), HTML(value='')))

  return function(data_struct)





In [5]:
# Exact-match accuracy: share of decoded predictions equal to their decoded target.
metrics.accuracy_score(y_true=targets, y_pred=predictions)

0.6183456183456183

In [18]:
# Collect the positions where the decoded prediction differs from the decoded
# target, then print each misclassified question with both answers.
incorrect_idxs = [
    idx
    for idx in range(len(predictions))
    if predictions[idx] != targets[idx]
]
for wrong_idx in incorrect_idxs:
    # Positions in `predictions` line up with rows of `valid_dataset`, so the
    # original encoded example can be looked up by the same index.
    example = valid_dataset[wrong_idx]
    question_text = tokenizer.decode(example['input_ids'])
    target_text = tokenizer.decode(example['target_ids'])
    print(question_text)
    print("Target Answer: {} Predicted Answer: {}".format(target_text, predictions[wrong_idx]))

question: A revolving door is convenient for two direction travel, but it also serves as a security measure at a what? options: A: bank B: library C: department store D: mall E: new york
Target Answer: A Predicted Answer: D
question: James was looking for a good place to buy farmland. Where might he look? options: A: midwest B: countryside C: estate D: farming areas E: illinois
Target Answer: A Predicted Answer: C
question: What island country is ferret popular? options: A: own home B: north carolina C: great britain D: hutch E: outdoors
Target Answer: C Predicted Answer: B
question: What do animals do when an enemy is approaching? options: A: feel pleasure B: procreate C: pass water D: listen to each other E: sing
Target Answer: D Predicted Answer: C
question: What do people typically do while playing guitar? options: A: cry B: hear sounds C: singing D: arthritis E: making music
Target Answer: C Predicted Answer: E
question: What would vinyl be an odd thing to replace? options: A: pan

question: What might someone believe in if they are cleaning clothes? options: A: feminism B: sanitation C: ruined D: wrinkles E: buttons to fall off
Target Answer: B Predicted Answer: A
question: Where would you find a basement that can be accessed with an elevator? options: A: eat cake B: closet C: church D: office building E: own house
Target Answer: D Predicted Answer: E
question: What could you get an unsmooth pit from? options: A: backyard B: rock C: mine D: cherry E: peach
Target Answer: E Predicted Answer: B
question: The man tried to reply to the woman, but he had difficulty keeping track of conversations that he didn't do what to? options: A: initiate B: ignore C: question D: answer E: ask
Target Answer: A Predicted Answer: B
question: How can someone be let into a brownstone? options: A: brooklyn B: ring C: subdivision D: bricks E: new york city
Target Answer: B Predicted Answer: D
question: Where would you keep an ottoman near your front door? options: A: living room B: par

Target Answer: E Predicted Answer: A
question: Where are horses judged on appearance? options: A: race track B: fair C: raised by humans D: in a field E: countryside
Target Answer: B Predicted Answer: A
question: Why do people read non fiction? options: A: having fun B: it's more relatable C: learn new things D: becoming absorbed E: falling asleep
Target Answer: C Predicted Answer: B
question: The man flew his airplane over the city and saw pollution visibly in the sky, what was polluted? options: A: forest B: street C: air D: caused by humans E: car show
Target Answer: C Predicted Answer: B
question: If not in a stream but in a market where will you find fish? options: A: stream B: aquarium C: refrigerator D: boat ride E: market
Target Answer: C Predicted Answer: E
question: During a shark filled tornado where should you not be? options: A: marine museum B: pool hall C: noodle house D: bad movie E: outside
Target Answer: E Predicted Answer: D
question: Where would you put a glass afte

question: Going public about a common problem can gain what for a celebrity? options: A: wide acceptance B: a degree C: pain D: getting high E: press coverage
Target Answer: A Predicted Answer: E
question: The electricity went out and everyone was shrouded in darkness. They all remained in their seats, because it would have been dangerous to try to find there way out. Where mihgt they have been? options: A: opera B: concert C: basement D: bedroom E: grand canyon
Target Answer: A Predicted Answer: B
question: If it is Chrismas time what came most recently before? options: A: halloween B: summer C: easter D: kwaanza E: give gift
Target Answer: A Predicted Answer: C
question: The criminal insisted he must do the crime to the bank teller, but she tried to convince him there were other ways in life and this was what? options: A: willing B: optional C: should not D: have to E: unnecessary
Target Answer: E Predicted Answer: B
question: what do you fill with ink to write? options: A: squid B: 

Target Answer: E Predicted Answer: C
question: There's one obvious reason to eat vegetables, they're plain what you? options: A: lose weight B: good for C: bland D: chewing E: fibre
Target Answer: B Predicted Answer: C
question: What is the sun ultimately responsible for? options: A: earth warming B: sun tan C: light D: life on earth E: heat
Target Answer: D Predicted Answer: E
question: They were searching for rocks, so they missed the birds overhead as they stared at the what? options: A: ground B: drawer C: surface of earth D: pizza E: waterfall
Target Answer: A Predicted Answer: E
question: When you wipe you feet on the door mat and walk through the door where do you enter? options: A: a chair B: school C: living room D: doorway E: bathroom
Target Answer: C Predicted Answer: E
question: What can you use to store a book while traveling? options: A: library of congress B: pocket C: backpack D: suitcase E: synagogue
Target Answer: D Predicted Answer: C
question: Where would you find g

question: If I wanted to eat something that is made from plants and needs to be washed, what would it be? options: A: roots B: millions of cells C: see work D: leaves to gather light E: flowers on
Target Answer: A Predicted Answer: E
question: The homeowner frowned at the price of gas, what did he have to do later? options: A: own home B: mail property tax payments C: board windows D: cut grass E: receive mail
Target Answer: D Predicted Answer: B
question: What has a shelf that does not allow you to see what is inside of it? options: A: chest of drawers B: stove C: hold alcohol D: bookcase E: grocery store
Target Answer: A Predicted Answer: D
question: The boat passenger was explaining his fear of blowfish, but the captain figured he meant piranhas since they were on a river in the what? options: A: cuba B: styx C: atlantic ocean D: france E: jungle
Target Answer: E Predicted Answer: C
question: Where could you find only a few office? options: A: skyscraper B: new york C: school buildi

Target Answer: A Predicted Answer: D
question: What could people do that involves talking? options: A: confession B: state park C: sing D: carnival E: opera
Target Answer: A Predicted Answer: C
question: If you're a child answering questions and an adult is asking them that adult is doing what? options: A: discussion B: explaning C: teaching D: confusion E: correct
Target Answer: C Predicted Answer: B
question: He has lactose intolerant, but was eating dinner made of cheese, what followed for him? options: A: digestive B: feel better C: sleepiness D: indigestion E: illness
Target Answer: D Predicted Answer: B
question: When you get an F, you fail. If you get A's you are? options: A: passed B: completing C: passed D: passing E: succeeding
Target Answer: D Predicted Answer: E
question: What is the main purpose of having a bath? options: A: cleanness B: wetness C: exfoliation D: use water E: hygiene
Target Answer: A Predicted Answer: E
question: The ball was hit over a boundary and struck

Target Answer: A Predicted Answer: E
question: What mall store sells jeans for a decent price? options: A: clothing store B: bedroom C: thrift store D: apartment E: gap
Target Answer: E Predicted Answer: C
question: Where can a bath towel be borrowed? options: A: cupboard B: at hotel C: swimming pool D: clothes line E: backpack
Target Answer: B Predicted Answer: C
question: John rode on the plain until it reached the ocean and couldn't go any farther. What might he have bee on? options: A: mountain B: fancy C: sandplain D: cliff E: gorge
Target Answer: D Predicted Answer: E
question: Where would you use a folding chair but not store one? options: A: beach B: city hall C: closet D: garage E: school
Target Answer: A Predicted Answer: D
question: What does impeachment mean for the president? options: A: vote B: election C: trouble D: board room E: corporation
Target Answer: C Predicted Answer: B
question: Where is hard to read note likely to be? options: A: fridge B: sheet music C: desk D

Target Answer: E Predicted Answer: D
question: If you really wanted a grape, where would you go to get it? options: A: winery B: fruit stand C: field D: kitchen E: food
Target Answer: B Predicted Answer: A


In [19]:
# Switch to a second checkpoint: `model` and `tokenizer` are rebound, so every
# cell executed after this one evaluates the weights stored in this directory.
siqa_csqa_checkpoint = './models/social_i_qa_commonsense_qa'
model = T5ForConditionalGeneration.from_pretrained(siqa_csqa_checkpoint)
tokenizer = T5Tokenizer.from_pretrained(siqa_csqa_checkpoint)

In [20]:
# Generate predictions with the currently loaded model for every validation
# batch, decoding generated and reference sequences into plain strings.
#
# `predictions` and `targets` are rebuilt from scratch here, so results from an
# earlier evaluation run are discarded rather than appended to.
predictions = []
targets = []
for batch in tqdm(dataloader):
    # NOTE(review): early_stopping is a beam-search option; presumably a no-op
    # unless the model config enables beam search — confirm before relying on it.
    prediction = model.generate(input_ids = batch['input_ids'],
                                attention_mask = batch['attention_mask'],
                                max_length = 16,
                                early_stopping = True)

    # skip_special_tokens=True strips pad/eos markers from the decoded text.
    # Without it, tokenizer versions that keep special tokens on decode yield
    # strings such as "<pad> A</s>" for predictions and "A</s>" for targets,
    # which never compare equal and silently drive the accuracy to zero.
    prediction = [tokenizer.decode(ids, skip_special_tokens = True) for ids in prediction]
    target = [tokenizer.decode(ids, skip_special_tokens = True) for ids in batch['target_ids']]

    predictions.extend(prediction)
    targets.extend(target)

HBox(children=(FloatProgress(value=0.0, max=39.0), HTML(value='')))




In [21]:
# Exact-match accuracy of the most recent prediction run against its targets.
metrics.accuracy_score(y_true=targets, y_pred=predictions)

0.6224406224406225