In [10]:
from tango import Workspace
workspace = Workspace.from_url("local:///root/workspace/read/temp/training")
result = workspace.step_result_for_run("infotab_totto_pipeline", "eval_verification")

In [11]:
result

{'loss': 0.6931571960449219,
 'accuracy': {'accuracy': 0.5},
 'f1': {'f1': 0.0},
 'precision': {'precision': 0.0},
 'recall': {'recall': 0.0}}

In [19]:
model["failed_cases"]["document_retrieval"][0]

{'table': [{'Year': '1941',
   'Title': 'The Little Foxes',
   'Role': 'Cal',
   'Notes': ''},
  {'Year': '1953',
   'Title': 'The Joe Louis Story',
   'Role': 'Sam Langford',
   'Notes': ''},
  {'Year': '1963', 'Title': 'The Cool World', 'Role': 'Hurst', 'Notes': ''},
  {'Year': '1964', 'Title': 'Black Like Me', 'Role': 'Hodges', 'Notes': ''},
  {'Year': '1972',
   'Title': 'Dear Dead Delilah',
   'Role': 'Marshall',
   'Notes': ''},
  {'Year': '1972', 'Title': 'Corky', 'Role': 'Junkman', 'Notes': ''},
  {'Year': '1973',
   'Title': 'Badge 373',
   'Role': 'Superintendent',
   'Notes': ''},
  {'Year': '1975',
   'Title': 'Dog Day Afternoon',
   'Role': 'Howard',
   'Notes': '(final film role)'}],
 'table_webpage_url': 'http://en.wikipedia.org/wiki/John_Marriott_(actor)',
 'table_page_title': 'John Marriott (actor)',
 'table_section_title': 'Filmography',
 'table_section_text': '',
 'highlighted_cells': [[7, 0], [7, 1], [7, 2], [7, 3]],
 'example_id': -8824613940853128584,
 'overlap_su

In [3]:
import torch
import json
from tango import Step
from tango.common.dataset_dict import DatasetDict
import pandas as pd
from transformers import TapasTokenizer

class TableDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        ex_id = idx % 2
        idx = idx // 2
        item = self.df.iloc[idx]
        table = pd.DataFrame(json.loads(item["table"]))
        cells = zip(*item["highlighted_cells"])
        cells = [list(x) for x in cells]
        sub_table = table.iloc[cells[0], cells[1]].reset_index().astype(str)

        if ex_id == 0:
            encoding = self.tokenizer(
                table=sub_table,
                queries=item["positive"],
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            encoding["labels"] = torch.tensor([1])
        else:
            encoding = self.tokenizer(
                table=sub_table,
                queries=item["negative"],
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            encoding["labels"] = torch.tensor([0])

        encoding = {key: val[-1] for key, val in encoding.items()}
        return encoding

    def __len__(self):
        return len(self.df)

In [7]:
tokenizer = TapasTokenizer.from_pretrained("google/tapas-base", max_question_length=256)
torch.manual_seed(1)
dev_df = pd.read_json("../temp/seed/sent_selection/data/dev.jsonl", lines=True)

dev_dataset = TableDataset(dev_df, tokenizer)

In [17]:
from torch.utils.data import DataLoader
import accelerate

import evaluate
name2metrics = {
    "accuracy": evaluate.load("accuracy"),
    "precision": evaluate.load("precision"),
    "recall": evaluate.load("recall"),
    "f1": evaluate.load("f1"),
}

dataloader = DataLoader(dev_dataset, batch_size=8, shuffle=False)
accelerator = accelerate.Accelerator()

print(type(model))

model, dataloader = accelerator.prepare(model, dataloader)

for batch in dataloader:
    y_hat = model(**batch)
    preds = y_hat.logits.argmax(dim=1)
    for metric in name2metrics.values():
        metric.add_batch(predictions=preds, references=batch["labels"])



<class 'transformers.models.tapas.modeling_tapas.TapasForSequenceClassification'>


In [18]:
for name, metric in name2metrics.items():
    print(name, metric.compute())

accuracy {'accuracy': 0.9653325817361894}
precision {'precision': 0.97524467472654}
recall {'recall': 0.9549041713641488}
f1 {'f1': 0.9649672457989177}
