In [22]:
from tango import Workspace
workspace = Workspace.from_url("local:///root/workspace/read/temp/training")
result = workspace.step_result_for_run("seed_totto_pipeline", "evaluation")

In [23]:
result["result"]

{'document_retrieval': {'accuracy': 0.8819376230239868,
  'recall': 0.8819376230239868,
  'precision': 0.27042147517204285,
  'mrr': 0.507101833820343},
 'sentence_selection': {'accuracy': 0.8165030479431152,
  'recall': 0.8165030479431152,
  'precision': 0.656883180141449,
  'mrr': 0.7559483051300049},
 'table_verification': {'accuracy': 0.8594416975975037,
  'f1': 0.8585124015808105,
  'precision': 0.8836339116096497,
  'recall': 0.8347797989845276}}

In [13]:
result["failed_cases"]["document_retrieval"][0]

[{'sentence': 'As a band, Jet was active for over two decades total.',
  'table': [{'title': ['Jet'],
    'Origin': ['Melbourne, Victoria, Australia'],
    'Genres': ['Garage rock', 'hard rock', 'alternative rock', 'indie rock'],
    'Years active': ['2001-2012', '2016-present'],
    'Labels': ['Atlantic', 'Elektra', 'EMI Music Group', 'Rubber'],
    'Associated acts': ['The Bamboos',
     'The CA$inos',
     'The Wrights',
     'TISM',
     'DAMNDOGS']}],
  'label': False,
  'title': 'Jet',
  'linearized_table': '<page_title> Jet </page_title> <table> <cell> Jet <col_header> title </col_header> </cell> <cell> Melbourne, Victoria, Australia <col_header> Origin </col_header> </cell> <cell> Garage rock hard rock alternative rock indie rock <col_header> Genres </col_header> </cell> <cell> 2001-2012 2016-present <col_header> Years active </col_header> </cell> <cell> Atlantic Elektra EMI Music Group Rubber <col_header> Labels </col_header> </cell> <cell> The Bamboos The CA$inos The Wrights 

In [3]:
import torch
import json
from tango import Step
from tango.common.dataset_dict import DatasetDict
import pandas as pd
from transformers import TapasTokenizer

class TableDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        ex_id = idx % 2
        idx = idx // 2
        item = self.df.iloc[idx]
        table = pd.DataFrame(json.loads(item["table"]))
        cells = zip(*item["highlighted_cells"])
        cells = [list(x) for x in cells]
        sub_table = table.iloc[cells[0], cells[1]].reset_index().astype(str)

        if ex_id == 0:
            encoding = self.tokenizer(
                table=sub_table,
                queries=item["positive"],
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            encoding["labels"] = torch.tensor([1])
        else:
            encoding = self.tokenizer(
                table=sub_table,
                queries=item["negative"],
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            encoding["labels"] = torch.tensor([0])

        encoding = {key: val[-1] for key, val in encoding.items()}
        return encoding

    def __len__(self):
        return len(self.df)

In [7]:
tokenizer = TapasTokenizer.from_pretrained("google/tapas-base", max_question_length=256)
torch.manual_seed(1)
dev_df = pd.read_json("../temp/seed/sent_selection/data/dev.jsonl", lines=True)

dev_dataset = TableDataset(dev_df, tokenizer)

In [17]:
from torch.utils.data import DataLoader
import accelerate

import evaluate
name2metrics = {
    "accuracy": evaluate.load("accuracy"),
    "precision": evaluate.load("precision"),
    "recall": evaluate.load("recall"),
    "f1": evaluate.load("f1"),
}

dataloader = DataLoader(dev_dataset, batch_size=8, shuffle=False)
accelerator = accelerate.Accelerator()

print(type(model))

model, dataloader = accelerator.prepare(model, dataloader)

for batch in dataloader:
    y_hat = model(**batch)
    preds = y_hat.logits.argmax(dim=1)
    for metric in name2metrics.values():
        metric.add_batch(predictions=preds, references=batch["labels"])



<class 'transformers.models.tapas.modeling_tapas.TapasForSequenceClassification'>


In [18]:
for name, metric in name2metrics.items():
    print(name, metric.compute())

accuracy {'accuracy': 0.9653325817361894}
precision {'precision': 0.97524467472654}
recall {'recall': 0.9549041713641488}
f1 {'f1': 0.9649672457989177}
