# Initialization

In [None]:
%%capture
!pip install datasets transformers
!pip install torch torchvision torchaudio

In [None]:
import datasets

from transformers import BertTokenizer, BertModel, BertConfig

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

import random

import pprint as pprintmodule
pp = pprintmodule.PrettyPrinter(indent=2)
pprint = pp.pprint

from tqdm.notebook import tqdm

import gc


In [None]:
# Device string used throughout the notebook: 'cuda' when torch reports a
# usable GPU, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Google Drive / Persistence

In [None]:
### Google Colab Model Persistence

from google.colab import drive
drive.mount('/content/drive')

# Directory prefix that get_path() joins in front of every model file name.
# NOTE: this is a *relative* path, while the drive is mounted at the absolute
# path '/content/drive'; it therefore only resolves when the working directory
# is '/content'.
MODEL_PERSIST_PREFIX = 'drive/MyDrive/CS6740/Final Project/saved_models/'

Mounted at /content/drive


In [None]:
from datetime import datetime
import os

def get_path(model_name):
    """Return the location at which a model file called `model_name` is stored.

    When the global MODEL_PERSIST_PREFIX has been defined (i.e. the persistence
    cell was executed) the name is joined onto that prefix; otherwise the name
    is returned unchanged, so files land in the current working directory.
    """
    if 'MODEL_PERSIST_PREFIX' not in globals():
        return model_name
    return os.path.join(MODEL_PERSIST_PREFIX, model_name)

def get_default_name(model):
    """Build a default base name: '<ClassName>-<MM-DD@HH-MM>' using the current local time."""
    class_name = type(model).__name__
    timestamp = datetime.now().strftime('%m-%d@%H-%M')
    return class_name + '-' + timestamp

def save_model(model, name=None):
    """Write `model.state_dict()` to disk and return the model.

    Args:
        model: module whose state dict is saved.
        name: file name to save under; when omitted (or empty) a default of
            get_default_name(model) + '.pt' is used.

    Returns:
        The same `model`, so the call can be chained.
    """
    name = name or (get_default_name(model) + '.pt')
    path = get_path(name)

    # torch.save does not create missing directories; make sure the target
    # folder exists so a finished training run is not lost to a missing path.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    print('Saving', name)
    torch.save(model.state_dict(), path)
    return model

def load_model(model, name):
    """Load a saved state dict called `name` into `model` and return the model.

    Tensors are mapped onto the notebook-wide DEVICE while loading.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    target_device = torch.device(DEVICE)
    state_dict = torch.load(get_path(name), map_location=target_device)
    model.load_state_dict(state_dict)
    return model

def get_default_checkpoint_name(model, epoch):
    """Return the checkpoint file name 'Checkpoint-<default model name>-e<epoch>.pt'."""
    base_name = get_default_name(model)
    return 'Checkpoint-{}-e{}.pt'.format(base_name, epoch)

def save_checkpoint(model, optim, epoch):
    """Save a training checkpoint (epoch, model state, optimizer state).

    The file name comes from get_default_checkpoint_name(model, epoch) and the
    location from get_path().

    Args:
        model: module whose state dict is stored under the 'model' key.
        optim: optimizer whose state dict is stored under the 'optim' key
            (note: inside this function the name shadows `torch.optim`).
        epoch: value stored under the 'epoch' key.

    Returns:
        The same `model`.
    """
    name = get_default_checkpoint_name(model, epoch)
    path = get_path(name)

    # torch.save does not create missing directories; create the target folder
    # first so a long training run cannot fail at checkpoint time.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    print('Saving', name)
    torch.save(
        {
            'epoch': epoch,
            'model': model.state_dict(),
            'optim': optim.state_dict()
        },
        path
    )
    return model

def load_checkpoint(model, optim, name):
    """Restore `model` and `optim` from the checkpoint file `name`.

    The checkpoint is expected to be a dict with 'model', 'optim' and 'epoch'
    entries (the layout written by save_checkpoint). Tensors are mapped onto
    DEVICE while loading.

    Returns:
        The value stored under the checkpoint's 'epoch' key.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    state = torch.load(get_path(name), map_location=torch.device(DEVICE))

    for target, key in ((model, 'model'), (optim, 'optim')):
        target.load_state_dict(state[key])
    return state['epoch']


# The Dataset

In [None]:
%%capture
# Load the 'docred' dataset through the `datasets` library and keep handles on
# the three splits used below (human-annotated train, validation, test).
DOCRED = datasets.load_dataset('docred')
DOCRED_TRAIN = DOCRED['train_annotated']
DOCRED_VALID = DOCRED['validation']
DOCRED_TEST = DOCRED['test']

In [None]:
def get_all_relation_ids(dataset):
    """Return the sorted list of distinct relation ids found in `dataset`.

    Each example is read through example['labels']['relation_id'].
    """
    seen = set()
    for example in tqdm(dataset, leave=False):
        seen.update(example['labels']['relation_id'])
    return sorted(seen)

def relation_ids_to_text(dataset):
    """Map every relation id in `dataset` to its relation text.

    Ids and texts are paired position-wise from example['labels']; when an id
    occurs more than once the text seen last wins.
    """
    id_to_text = {}
    for example in tqdm(dataset, leave=False):
        labels = example['labels']
        id_to_text.update(zip(labels['relation_id'], labels['relation_text']))
    return id_to_text

# Relation vocabulary derived from the annotated training split. The position
# of an id inside ALL_RELATION_IDS is used as its class index elsewhere in the
# notebook (see the ALL_RELATION_IDS.index(...) calls).
ALL_RELATION_IDS = get_all_relation_ids(DOCRED_TRAIN)
RELATION_ID_TO_TEXT = relation_ids_to_text(DOCRED_TRAIN)

# Sanity checks: no other split may contain a relation id that is missing from
# the training vocabulary (each check iterates over the whole split).
assert set(get_all_relation_ids(DOCRED['validation'])).issubset(set(ALL_RELATION_IDS))
assert set(get_all_relation_ids(DOCRED['test'])).issubset(set(ALL_RELATION_IDS))
assert set(get_all_relation_ids(DOCRED['train_distant'])).issubset(set(ALL_RELATION_IDS))

HBox(children=(FloatProgress(value=0.0, max=3053.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=3053.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))



HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))



# Tokenizer and Model

In [None]:
# Marker tokens inserted around entity mentions by format_example().
# Index i of ENT_BGN / ENT_END is the opening / closing marker of entity slot i.
ENT_BGN = ['[E0]', '[E1]', '[E2]']
ENT_END = ['[/E0]', '[/E1]', '[/E2]']
# Placeholder token that replaces the surface text of a blanked mention.
ENT_BLN = '[BLANK]'

def create_tokenizer_bert_model():
    """Create the 'bert-large-uncased' tokenizer and encoder with marker tokens added.

    The entity begin/end markers and the blank token are registered as
    additional special tokens, and the model's token embeddings are resized to
    the new tokenizer length. The model is moved to DEVICE.

    Returns:
        (tokenizer, bert_model)
    """
    pretrained_name = 'bert-large-uncased'
    tokenizer = BertTokenizer.from_pretrained(pretrained_name)
    bert_model = BertModel.from_pretrained(pretrained_name).to(DEVICE)

    # Register the entity markers and the blank token, then grow the embedding
    # matrix so the new ids have rows.
    marker_tokens = ENT_BGN + ENT_END + [ENT_BLN]
    tokenizer.add_special_tokens({'additional_special_tokens': marker_tokens})
    bert_model.resize_token_embeddings(len(tokenizer))

    return tokenizer, bert_model

def ent_token_ids(tokenizer):
    """Split the tokenizer's additional special token ids into marker groups.

    Relies on the ids being ordered as three begin markers, three end markers,
    then the blank token (the order used by create_tokenizer_bert_model).

    Returns:
        dict with 'bgn' (ids 0-2), 'end' (ids 3-5) and 'bln' (id 6).
    """
    bgn_ids = tokenizer.additional_special_tokens_ids[:3]
    end_ids = tokenizer.additional_special_tokens_ids[3:6]
    bln_id = tokenizer.additional_special_tokens_ids[6]
    return {'bgn': bgn_ids, 'end': end_ids, 'bln': bln_id}

In [None]:
def format_example(
    example,
    eid_and_bln
):
    """Render a document as one string with entity markers around selected mentions.

    Args:
        example: dict with 'sents' (a list of token lists) and 'vertexSet'
            (one list of mentions per entity; each mention carries 'sent_id'
            and a 'pos' pair of token offsets within that sentence).
        eid_and_bln: sequence of (entity index, blank flag) pairs. The position
            of a pair in this sequence selects the marker pair
            (ENT_BGN[i], ENT_END[i]); a true flag replaces the mention text
            with ENT_BLN.

    Returns:
        The tokens of all sentences joined by single spaces, with markers
        (and blanks) inserted.
    """
    # Flatten the sentences, remembering the document offset of each sentence.
    tokens = []
    sentence_starts = [0]
    for sentence in example['sents']:
        tokens += sentence
        sentence_starts.append(sentence_starts[-1] + len(sentence))

    mention_sets = example['vertexSet']

    # (start, end, slot) in document coordinates for every mention of the
    # selected entities, ordered by start offset.
    spans = []
    for slot, (entity_id, _) in enumerate(eid_and_bln):
        for mention in mention_sets[entity_id]:
            base = sentence_starts[mention['sent_id']]
            spans.append((base + mention['pos'][0], base + mention['pos'][1], slot))
    spans.sort(key=lambda span: span[0])

    # Resolve overlaps: walk from the last span to the first and keep a span
    # only if none of its token positions is already claimed by a kept span.
    claimed = set()
    kept_last_to_first = []
    for span in reversed(spans):
        covered = set(range(span[0], span[1]))
        if covered & claimed:
            continue
        claimed |= covered
        kept_last_to_first.append(span)

    # Insert markers back-to-front so earlier offsets stay valid.
    for bgn, end, slot in kept_last_to_first:
        tokens.insert(end, ENT_END[slot])
        tokens.insert(bgn, ENT_BGN[slot])

        if eid_and_bln[slot][1]:
            # Blank the mention: the first mention token becomes ENT_BLN and
            # the remaining mention tokens are removed (right to left).
            tokens[bgn + 1] = ENT_BLN
            for position in range(end, bgn + 1, -1):
                del tokens[position]

    return ' '.join(tokens)

# Demo: mark the first three entities of training document 25 without blanking.
format_example(DOCRED_TRAIN[25], [(0, False), (1, False), (2, False)])


'[E0] Chelsea [/E0] was an early [E1] 1970s [/E1] band from [E2] New York City [/E2] , best known for being the band of drummer Peter Criss before he joined Kiss . They released one album , the self - titled album Chelsea in 1971 and then collapsed during the recording of their unreleased second album . In August 1971 , the band became Lips ( a trio consisting of Criss and his [E0] Chelsea [/E0] bandmates Michael Benvenga and Stan Penridge ) . By the spring of 1973 , Lips was just the duo of Criss and Penridge and eventually disbanded completely . Their sound has been compared to the Moody Blues and Procol Harum . In 1973 , Pete Shepley & Mike Brand recorded an unreleased album which included post - Chelsea Michael Benvenga , a pre - Kiss Peter Criss , and on two songs Gene Simmons as session musicians . It was titled Captain Sanity .'

In [None]:
class BertWithEntityStartPooling(nn.Module):
    """Encode text with BERT and pool the hidden states at entity start markers.

    For each of the three entity slots, the hidden states at every attended
    position holding that slot's begin-marker id are max-pooled element-wise
    into one vector (a zero vector when the marker does not occur). The slot
    vectors are then concatenated pairwise in the order (0, 1), (0, 2), (1, 2).

    Attributes:
        hidden_size: width of each pair representation (2 * BERT hidden size).
    """

    # Slot pairs, in the order in which they appear along dim 1 of the output.
    _PAIRS = ((0, 1), (0, 2), (1, 2))

    def __init__(self, bert, ent_token_ids):
        """
        Args:
            bert: encoder called as bert(**input); must expose
                config.hidden_size and return an object with last_hidden_state.
            ent_token_ids: dict whose 'bgn' entry lists the begin-marker token
                ids, one per entity slot.
        """
        super(BertWithEntityStartPooling, self).__init__()

        self._bert = bert
        self._ent_bgn_ids = ent_token_ids['bgn']
        self._h = self._bert.config.hidden_size
        self.hidden_size = 2 * self._h

    def forward(self, input):
        """
        Args:
            input: mapping with 'input_ids' and 'attention_mask' tensors of
                shape (batch, seq); it is forwarded to BERT as keyword args.

        Returns:
            Tensor of shape (batch, 3, 2 * hidden): one row per slot pair.
        """
        hidden = self._bert(**input).last_hidden_state       # (batch, seq, hidden)
        input_ids = input['input_ids']
        attended = input['attention_mask'].bool()

        slot_vectors = []
        for bgn_id in self._ent_bgn_ids:
            # Positions that hold this slot's begin marker and are not padding.
            is_marker = (input_ids == bgn_id) & attended     # (batch, seq)

            # Max over marker positions only: everything else is pushed to -inf
            # so it can never win the max.
            masked = hidden.masked_fill(~is_marker.unsqueeze(-1), float('-inf'))
            pooled = masked.max(dim=1).values                # (batch, hidden)

            # Examples without this marker get a zero vector (on the same
            # device/dtype as the hidden states) instead of the -inf filler.
            has_marker = is_marker.any(dim=1, keepdim=True)
            slot_vectors.append(torch.where(has_marker, pooled, torch.zeros_like(pooled)))

        pair_vectors = [
            torch.cat((slot_vectors[first], slot_vectors[second]), dim=-1)
            for first, second in self._PAIRS
        ]
        return torch.stack(pair_vectors, dim=1)


class FullyConnectedLayer(nn.Module):

    """
    Two linear layers (input -> hidden -> output) with an optional activation
    function applied between them; without one, the identity is used.
    """

    def __init__(self, input, hidden, output, activation_fn=None):
        super().__init__()

        self._linear1 = nn.Linear(input, hidden)
        self._activation_fn = activation_fn or nn.Identity()
        self._linear2 = nn.Linear(hidden, output)

    def forward(self, input):
        projected = self._linear1(input)
        activated = self._activation_fn(projected)
        return self._linear2(activated)


class PerClassScore(nn.Module):
    """Score inputs against one learned column per relation in ALL_RELATION_IDS.

    The parameter matrix has shape (pcr_size, number of relations) and is
    initialised from a standard normal distribution.
    """

    def __init__(self, pcr_size):
        super().__init__()
        num_relations = len(ALL_RELATION_IDS)
        self._pcr = nn.Parameter(torch.randn(pcr_size, num_relations))

    def forward(self, input):
        """Return log-probabilities over relations, normalised along dim 2."""
        scores = torch.matmul(input, self._pcr)
        return F.log_softmax(scores, dim=2)


class TriBlank(nn.Module):
    """Full model: entity-start pooling -> two-layer projection -> per-relation scores.

    The three stages are chained in an nn.Sequential stored as `_seq`.
    """

    def __init__(self, bert, ent_token_ids):
        super().__init__()

        encoder = BertWithEntityStartPooling(bert, ent_token_ids)
        pair_width = encoder.hidden_size
        half_width = pair_width // 2

        self._seq = nn.Sequential(
            encoder,
            FullyConnectedLayer(pair_width, half_width, half_width),
            PerClassScore(half_width),
        )

    def forward(self, input):
        return self._seq(input)

In [None]:
def run_model(model, tokenizer, batch, eid_and_bln=[(0, False), (1, False), (2, True)]):
    """Run `model` on a few documents and print the predicted relation per entity pair.

    Args:
        model: module returning, per example, one row of relation scores for
            each entity pair in the order (0, 1), (0, 2), (1, 2).
        tokenizer: tokenizer applied to the formatted documents.
        batch: list of dataset examples.
        eid_and_bln: (entity index, blank flag) pairs passed to format_example.
            The default list is only read, never mutated.

    Prints the raw model output, then for every pair the head entity name, the
    predicted relation text and the tail entity name.
    """
    model.eval()

    input = tokenizer(
        [
            format_example(example, eid_and_bln)
            for example in batch
        ],
        padding=True,
        return_tensors='pt',
        # Truncate over-long documents, as every other tokenizer call in this
        # notebook does; otherwise they can exceed the encoder's input limit.
        truncation=True
    ).to(DEVICE)

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        output = model(input)
    print(output)
    print()
    for i, example in enumerate(batch):
        for ei in range(len(eid_and_bln) - 1):
            for ej in range(ei + 1, len(eid_and_bln)):
                # Pair (ei, ej) -> output row: (0, 1) -> 0, (0, 2) -> 1, (1, 2) -> 2.
                assert ei + ej - 1 in (0, 1, 2)
                pred = torch.argmax(output[i][ei + ej - 1]).item()
                print(example['vertexSet'][eid_and_bln[ei][0]][0]['name'])
                print(RELATION_ID_TO_TEXT[ALL_RELATION_IDS[pred]])
                print(example['vertexSet'][eid_and_bln[ej][0]][0]['name'])
                print()

def namespace():
    """Smoke test: build a fresh tokenizer and TriBlank model, run it on two
    training documents, then release cached GPU memory.

    Everything lives in this function's scope so the model is not kept alive by
    notebook globals.
    """
    tokenizer, bert_model = create_tokenizer_bert_model()
    marker_ids = ent_token_ids(tokenizer)
    model = TriBlank(bert_model, marker_ids).to(DEVICE)

    sample_batch = [DOCRED_TRAIN[5], DOCRED_TRAIN[6]]
    run_model(model, tokenizer, sample_batch)

    torch.cuda.empty_cache()
    gc.collect()

# namespace()

# Re2


In [None]:
def map_to_re2_dataset(dataset):
    """Flatten a dataset into one (doc index, head, relation class, tail) tuple per label.

    The relation class is the position of the label's relation id inside
    ALL_RELATION_IDS.

    Raises:
        KeyError: if a label uses a relation id that is not in ALL_RELATION_IDS.
    """
    # Build the id -> class-index table once instead of scanning the list for
    # every label; setdefault keeps the first position, like list.index.
    relation_class = {}
    for position, relation_id in enumerate(ALL_RELATION_IDS):
        relation_class.setdefault(relation_id, position)

    re2_dataset = []

    for i, example in enumerate(tqdm(dataset)):
        labels = example['labels']

        for head, rel, tail in zip(labels['head'], labels['relation_id'], labels['tail']):
            re2_dataset.append((i, head, relation_class[rel], tail))

    return re2_dataset

def form_re2_batch(docred_dataset, re2_dataset, tokenizer, start_index, batch_size, blank_alpha=None):
    """Build one tokenized two-entity batch and its gold relation classes.

    Args:
        docred_dataset: indexable collection of documents.
        re2_dataset: list of (doc index, head, relation class, tail) tuples.
        tokenizer: tokenizer applied to the formatted documents.
        start_index: index of the first tuple in the batch.
        batch_size: maximum number of tuples; the last batch may be shorter.
        blank_alpha: probability with which each entity is blanked
            independently; None selects the default of 0.7. An explicit 0.0 is
            honoured (no blanking).

    Returns:
        (examples, gold): the tokenizer output and a tensor of relation
        classes, both moved to DEVICE.
    """
    # `is None`, not `or`: a caller passing 0.0 wants blanking disabled.
    if blank_alpha is None:
        blank_alpha = 0.7
    end_index = min(start_index + batch_size, len(re2_dataset))
    rows = re2_dataset[start_index:end_index]

    examples = [
        format_example(
            docred_dataset[i],
            [
                (e0, random.random() < blank_alpha),
                (e1, random.random() < blank_alpha)
            ]
        )
        for (i, e0, _, e1) in rows
    ]

    examples = tokenizer(examples, padding=True, return_tensors='pt', truncation=True).to(DEVICE)

    gold = torch.tensor([row[2] for row in rows]).to(DEVICE)

    return (examples, gold)

def iter_re2_batches(docred_dataset, re2_dataset, tokenizer, batch_size, blank_alpha=None):
    """Yield form_re2_batch() results covering `re2_dataset` in order.

    Batches start at 0, batch_size, 2 * batch_size, ...; progress is shown
    with tqdm.
    """
    batch_starts = range(0, len(re2_dataset), batch_size)
    for offset in tqdm(batch_starts):
        batch = form_re2_batch(
            docred_dataset, re2_dataset, tokenizer, offset, batch_size,
            blank_alpha=blank_alpha
        )
        yield batch


In [None]:
def train_epoch_re2(
    model,
    tokenizer,
    optim,
    docred_dataset,
    re2_dataset,
    batch_size,
    max_grad_norm=None,
    blank_alpha=None
):
    """Train `model` for one pass over `re2_dataset` (two-entity examples).

    Per batch: zero gradients, compute the NLL loss of the first pair slot
    against the gold relation classes, back-propagate, optionally clip the
    gradient norm (when `max_grad_norm` is truthy) and step the optimizer.
    `optim` is the optimizer instance (the name shadows `torch.optim` here).
    """
    model.train()

    batches = iter_re2_batches(
        docred_dataset,
        re2_dataset,
        tokenizer,
        batch_size,
        blank_alpha=blank_alpha
    )

    for encoded, targets in batches:
        optim.zero_grad()

        # Swap the first two dims so [0] selects the first pair slot for every
        # example (assumes the model output is (batch, pair, class)).
        per_slot = model(encoded).transpose(0, 1)
        batch_loss = F.nll_loss(per_slot[0], targets)
        batch_loss.backward()

        if max_grad_norm:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optim.step()

        # Drop references and release cached memory before the next batch.
        del per_slot
        del batch_loss
        torch.cuda.empty_cache()
        gc.collect()

In [None]:
def eval_contingency_table_re2(model, tokenizer, dataset, re2_dataset, blank=False):
    """Evaluate two-entity examples one at a time and count (gold, predicted) pairs.

    Args:
        dataset: indexable collection of documents.
        re2_dataset: (doc index, head, relation class, tail) tuples.
        blank: whether both entities are blanked in the formatted text.

    Returns:
        Square list-of-lists `table` sized by ALL_RELATION_IDS where
        table[gold][predicted] is a count.
    """
    model.eval()

    num_relations = len(ALL_RELATION_IDS)
    table = [[0] * num_relations for _ in range(num_relations)]

    with torch.no_grad():
        for doc_index, head, gold_rel, tail in tqdm(re2_dataset):
            text = format_example(dataset[doc_index], [(head, blank), (tail, blank)])
            encoded = tokenizer(
                [text],
                padding=True,
                return_tensors='pt',
                truncation=True
            ).to(DEVICE)

            scores = model(encoded).squeeze(0)
            predicted = torch.argmax(scores[0]).item()
            table[gold_rel][predicted] += 1

    return table

# Re3

In [None]:
def all_index_triples(iter):
    """Yield every index triple (i, j, k) with i < j < k < len(iter), in lexicographic order."""
    size = len(iter)
    yield from (
        (i, j, k)
        for i in range(size)
        for j in range(i + 1, size)
        for k in range(j + 1, size)
    )

def is_related_set(example):
    """Return the set of entity index pairs that share at least one label.

    Every labelled (head, tail) is added in both directions, so membership
    ignores the direction of the relation.
    """
    labels = example['labels']
    related = set()
    for pair in zip(labels['head'], labels['tail']):
        related.add(pair)
        related.add(pair[::-1])
    return related

def extract_triples(example):
    """Return the index triples (i < j < k) where i-j and j-k are both related.

    Relatedness is direction-agnostic (see is_related_set).
    """
    related = is_related_set(example)
    return {
        (i, j, k)
        for i, j, k in all_index_triples(example['vertexSet'])
        if (i, j) in related and (j, k) in related
    }

def extract_triway(example):
    """Return the index triples (i < j < k) whose three pairs are all related.

    Relatedness is direction-agnostic (see is_related_set).
    """
    related = is_related_set(example)
    return {
        (i, j, k)
        for i, j, k in all_index_triples(example['vertexSet'])
        if (i, j) in related and (j, k) in related and (i, k) in related
    }

def print_relation(example, index_pair):
    """Print each labelled relation between the two entities in `index_pair`.

    Labels in either direction are printed; the entity given first in
    `index_pair` is always printed first. Entities are shown by the name of
    their first mention.
    """
    first, second = index_pair
    mentions = example['vertexSet']
    labels = example['labels']

    for head, relation, tail in zip(labels['head'], labels['relation_text'], labels['tail']):
        head_name = mentions[head][0]['name']
        tail_name = mentions[tail][0]['name']

        if (head, tail) == (first, second):
            parts = (head_name, relation, tail_name)
        elif (head, tail) == (second, first):
            parts = (tail_name, relation, head_name)
        else:
            continue

        print('{} {} {}'.format(*parts))

def print_triple_relation(example, index_triple):
    """Print the relations of the pairs (i, j), (j, k) and (i, k) of a triple, in that order."""
    i, j, k = index_triple
    for pair in ((i, j), (j, k), (i, k)):
        print_relation(example, pair)
            
def print_extract_triway(example):
    """Print every fully related triple of `example` followed by its relations and a blank line."""
    triway_triples = extract_triway(example)
    for triple in triway_triples:
        print(triple)
        print_triple_relation(example, triple)
        print()

# extract_triway(DOCRED_TRAIN[0])

In [None]:
def extract_labeled_edges(example):
    """Group an example's relation ids by directed (head, tail) entity pair.

    Returns:
        dict mapping (head, tail) to the list of relation ids labelled on that
        pair, in label order.
    """
    labels = example['labels']
    edges = {}

    for head, relation, tail in zip(labels['head'], labels['relation_id'], labels['tail']):
        edges.setdefault((head, tail), []).append(relation)

    return edges

def map_to_re3_dataset(dataset):
    """Enumerate all labelled two-hop chains i -> j -> k in every document.

    For each ordered triple of distinct entity indices where both (i, j) and
    (j, k) carry labels, one tuple is emitted per combination of their
    relations:

        (doc index, i, class of rel(i, j), j, class of rel(j, k), k)

    Relation classes are positions in ALL_RELATION_IDS. Output order is by
    document, then i, j, k ascending, then label order.

    Raises:
        KeyError: if a label uses a relation id that is not in ALL_RELATION_IDS.
    """
    # Build the id -> class-index table once; setdefault keeps the first
    # position, like list.index.
    relation_class = {}
    for position, relation_id in enumerate(ALL_RELATION_IDS):
        relation_class.setdefault(relation_id, position)

    re3_dataset = []

    for dataset_index, example in enumerate(tqdm(dataset)):
        labeled_edges = extract_labeled_edges(example)
        num_entities = len(example['vertexSet'])

        for i in range(num_entities):
            for j in range(num_entities):
                # Skip the whole k loop when (i, j) cannot start a chain.
                if i == j or (i, j) not in labeled_edges:
                    continue

                for k in range(num_entities):
                    if i == k or j == k or (j, k) not in labeled_edges:
                        continue

                    for rel_ij in labeled_edges[(i, j)]:
                        for rel_jk in labeled_edges[(j, k)]:
                            re3_dataset.append((
                                dataset_index,
                                i,
                                relation_class[rel_ij],
                                j,
                                relation_class[rel_jk],
                                k
                            ))

    return re3_dataset

def form_re3_batch(docred_dataset, re3_dataset, tokenizer, start_index, batch_size, blank_alpha=None):
    """Build one tokenized three-entity (chain) batch and its gold relation classes.

    Args:
        docred_dataset: indexable collection of documents.
        re3_dataset: list of (doc index, e0, class(e0, e1), e1, class(e1, e2), e2).
        tokenizer: tokenizer applied to the formatted documents.
        start_index: index of the first tuple in the batch.
        batch_size: maximum number of tuples; the last batch may be shorter.
        blank_alpha: probability with which each entity is blanked
            independently; None selects the default of 0.7. An explicit 0.0 is
            honoured (no blanking).

    Returns:
        (examples, gold_01, gold_12): the tokenizer output and the relation
        classes of the (e0, e1) and (e1, e2) pairs, all moved to DEVICE.
    """
    # `is None`, not `or`: a caller passing 0.0 wants blanking disabled.
    if blank_alpha is None:
        blank_alpha = 0.7
    end_index = min(start_index + batch_size, len(re3_dataset))
    rows = re3_dataset[start_index:end_index]

    examples = [
        format_example(
            docred_dataset[i],
            [
                (e0, random.random() < blank_alpha),
                (e1, random.random() < blank_alpha),
                (e2, random.random() < blank_alpha)
            ]
        )
        for (i, e0, _, e1, _, e2) in rows
    ]

    examples = tokenizer(examples, padding=True, return_tensors='pt', truncation=True).to(DEVICE)

    gold_01 = torch.tensor([row[2] for row in rows]).to(DEVICE)
    gold_12 = torch.tensor([row[4] for row in rows]).to(DEVICE)

    return (examples, gold_01, gold_12)

def iter_re3_batches(docred_dataset, re3_dataset, tokenizer, batch_size, blank_alpha=None):
    """Yield form_re3_batch() results covering `re3_dataset` in order.

    Batches start at 0, batch_size, 2 * batch_size, ...; progress is shown
    with tqdm.
    """
    batch_starts = range(0, len(re3_dataset), batch_size)
    for offset in tqdm(batch_starts):
        batch = form_re3_batch(
            docred_dataset, re3_dataset, tokenizer, offset, batch_size,
            blank_alpha=blank_alpha
        )
        yield batch


In [None]:
def train_epoch_re3(
    model,
    tokenizer,
    optim,
    docred_dataset,
    re3_dataset,
    batch_size,
    max_grad_norm=None,
    blank_alpha=None
):
    """Train `model` for one pass over `re3_dataset` (three-entity chains).

    Per batch: zero gradients, sum the NLL losses of pair slot 0 against the
    (e0, e1) classes and pair slot 2 against the (e1, e2) classes,
    back-propagate, optionally clip the gradient norm (when `max_grad_norm` is
    truthy) and step the optimizer. `optim` is the optimizer instance (the
    name shadows `torch.optim` here).
    """
    model.train()

    batches = iter_re3_batches(
        docred_dataset,
        re3_dataset,
        tokenizer,
        batch_size,
        blank_alpha=blank_alpha
    )

    for encoded, targets_01, targets_12 in batches:
        optim.zero_grad()

        # Swap the first two dims so [slot] selects one pair slot for every
        # example (assumes the model output is (batch, pair, class)).
        per_slot = model(encoded).transpose(0, 1)
        loss_01 = F.nll_loss(per_slot[0], targets_01)
        loss_12 = F.nll_loss(per_slot[2], targets_12)
        batch_loss = loss_01 + loss_12
        batch_loss.backward()

        if max_grad_norm:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optim.step()

        # Drop references and release cached memory before the next batch.
        del per_slot
        del loss_01, loss_12, batch_loss
        torch.cuda.empty_cache()
        gc.collect()

In [None]:
def eval_contingency_table_re3(model, tokenizer, dataset, re3_dataset, blank=False):
    """Evaluate three-entity chains one at a time and count (gold, predicted) pairs.

    For each tuple, pair slot 0 is scored against the (e0, e1) class and pair
    slot 2 against the (e1, e2) class; both go into the same table.

    Args:
        dataset: indexable collection of documents.
        re3_dataset: (doc index, e0, class(e0, e1), e1, class(e1, e2), e2) tuples.
        blank: whether all three entities are blanked in the formatted text.

    Returns:
        Square list-of-lists `table` sized by ALL_RELATION_IDS where
        table[gold][predicted] is a count.
    """
    model.eval()

    num_relations = len(ALL_RELATION_IDS)
    table = [[0] * num_relations for _ in range(num_relations)]

    with torch.no_grad():
        for doc_index, e0, rel_01, e1, rel_12, e2 in tqdm(re3_dataset):
            text = format_example(dataset[doc_index], [(e0, blank), (e1, blank), (e2, blank)])
            encoded = tokenizer(
                [text],
                padding=True,
                return_tensors='pt',
                truncation=True
            ).to(DEVICE)

            scores = model(encoded).squeeze(0)

            for slot, gold_rel in ((0, rel_01), (2, rel_12)):
                predicted = torch.argmax(scores[slot]).item()
                table[gold_rel][predicted] += 1

    return table


# Tri

In [None]:
def map_to_tri_dataset(dataset):
    """Enumerate all fully labelled entity triangles in every document.

    For each ordered triple of distinct entity indices where (i, j), (i, k)
    and (j, k) all carry labels, one tuple is emitted per combination of their
    relations:

        (doc index, i, j, k, class(i, j), class(i, k), class(j, k))

    Relation classes are positions in ALL_RELATION_IDS. Output order is by
    document, then i, j, k ascending, then label order.

    Raises:
        KeyError: if a label uses a relation id that is not in ALL_RELATION_IDS.
    """
    # Build the id -> class-index table once; setdefault keeps the first
    # position, like list.index.
    relation_class = {}
    for position, relation_id in enumerate(ALL_RELATION_IDS):
        relation_class.setdefault(relation_id, position)

    tri_dataset = []

    for dataset_index, example in enumerate(tqdm(dataset)):
        labeled_edges = extract_labeled_edges(example)
        num_entities = len(example['vertexSet'])

        for i in range(num_entities):
            for j in range(num_entities):
                # Skip the whole k loop when (i, j) cannot be a triangle side.
                if i == j or (i, j) not in labeled_edges:
                    continue

                for k in range(num_entities):
                    if i == k or j == k:
                        continue

                    if (i, k) not in labeled_edges or (j, k) not in labeled_edges:
                        continue

                    for rel_ij in labeled_edges[(i, j)]:
                        for rel_ik in labeled_edges[(i, k)]:
                            for rel_jk in labeled_edges[(j, k)]:
                                tri_dataset.append((
                                    dataset_index,
                                    i,
                                    j,
                                    k,
                                    relation_class[rel_ij],
                                    relation_class[rel_ik],
                                    relation_class[rel_jk],
                                ))

    return tri_dataset

def form_tri_batch(docred_dataset, tri_dataset, tokenizer, start_index, batch_size, blank_alpha=None):
    """Build one tokenized triangle batch and its three gold relation class tensors.

    Args:
        docred_dataset: indexable collection of documents.
        tri_dataset: list of
            (doc index, e0, e1, e2, class(e0, e1), class(e0, e2), class(e1, e2)).
        tokenizer: tokenizer applied to the formatted documents.
        start_index: index of the first tuple in the batch.
        batch_size: maximum number of tuples; the last batch may be shorter.
        blank_alpha: probability with which each entity is blanked
            independently; None selects the default of 0.7. An explicit 0.0 is
            honoured (no blanking).

    Returns:
        (examples, gold_01, gold_02, gold_12): the tokenizer output and the
        relation classes of the three pairs, all moved to DEVICE.
    """
    # `is None`, not `or`: a caller passing 0.0 wants blanking disabled.
    if blank_alpha is None:
        blank_alpha = 0.7
    end_index = min(start_index + batch_size, len(tri_dataset))
    rows = tri_dataset[start_index:end_index]

    examples = [
        format_example(
            docred_dataset[i],
            [
                (e0, random.random() < blank_alpha),
                (e1, random.random() < blank_alpha),
                (e2, random.random() < blank_alpha)
            ]
        )
        for (i, e0, e1, e2, _, _, _) in rows
    ]

    examples = tokenizer(examples, padding=True, return_tensors='pt', truncation=True).to(DEVICE)

    gold_01 = torch.tensor([row[4] for row in rows]).to(DEVICE)
    gold_02 = torch.tensor([row[5] for row in rows]).to(DEVICE)
    gold_12 = torch.tensor([row[6] for row in rows]).to(DEVICE)

    return (examples, gold_01, gold_02, gold_12)

def iter_tri_batches(docred_dataset, tri_dataset, tokenizer, batch_size, blank_alpha=None):
    """Yield form_tri_batch() results covering `tri_dataset` in order.

    Batches start at 0, batch_size, 2 * batch_size, ...; progress is shown
    with tqdm.
    """
    batch_starts = range(0, len(tri_dataset), batch_size)
    for offset in tqdm(batch_starts):
        batch = form_tri_batch(
            docred_dataset, tri_dataset, tokenizer, offset, batch_size,
            blank_alpha=blank_alpha
        )
        yield batch


In [None]:
def train_epoch_tri(
    model,
    tokenizer,
    optim,
    docred_dataset,
    tri_dataset,
    batch_size,
    max_grad_norm=None,
    blank_alpha=None
):
    """Train `model` for one pass over `tri_dataset` (fully labelled triangles).

    Per batch: zero gradients, sum the NLL losses of pair slots 0, 1 and 2
    against the (e0, e1), (e0, e2) and (e1, e2) classes, back-propagate,
    optionally clip the gradient norm (when `max_grad_norm` is truthy) and
    step the optimizer. `optim` is the optimizer instance (the name shadows
    `torch.optim` here).
    """
    model.train()

    batches = iter_tri_batches(
        docred_dataset,
        tri_dataset,
        tokenizer,
        batch_size,
        blank_alpha=blank_alpha
    )

    for encoded, targets_01, targets_02, targets_12 in batches:
        optim.zero_grad()

        # Swap the first two dims so [slot] selects one pair slot for every
        # example (assumes the model output is (batch, pair, class)).
        per_slot = model(encoded).transpose(0, 1)
        loss_01 = F.nll_loss(per_slot[0], targets_01)
        loss_02 = F.nll_loss(per_slot[1], targets_02)
        loss_12 = F.nll_loss(per_slot[2], targets_12)
        batch_loss = loss_01 + loss_02 + loss_12
        batch_loss.backward()

        if max_grad_norm:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optim.step()

        # Drop references and release cached memory before the next batch.
        del per_slot
        del loss_01, loss_02, loss_12, batch_loss
        torch.cuda.empty_cache()
        gc.collect()

In [None]:
def eval_contingency_table_tri(model, tokenizer, dataset, tri_dataset, blank=False):
    """Evaluate triangles one at a time and count (gold, predicted) pairs.

    For each tuple, pair slots 0, 1 and 2 are scored against the (e0, e1),
    (e0, e2) and (e1, e2) classes; all three go into the same table.

    Args:
        dataset: indexable collection of documents.
        tri_dataset: (doc index, e0, e1, e2, class 01, class 02, class 12) tuples.
        blank: whether all three entities are blanked in the formatted text.

    Returns:
        Square list-of-lists `table` sized by ALL_RELATION_IDS where
        table[gold][predicted] is a count.
    """
    model.eval()

    num_relations = len(ALL_RELATION_IDS)
    table = [[0] * num_relations for _ in range(num_relations)]

    with torch.no_grad():
        for doc_index, e0, e1, e2, rel_01, rel_02, rel_12 in tqdm(tri_dataset):
            text = format_example(dataset[doc_index], [(e0, blank), (e1, blank), (e2, blank)])
            encoded = tokenizer(
                [text],
                padding=True,
                return_tensors='pt',
                truncation=True
            ).to(DEVICE)

            scores = model(encoded).squeeze(0)

            for slot, gold_rel in enumerate((rel_01, rel_02, rel_12)):
                predicted = torch.argmax(scores[slot]).item()
                table[gold_rel][predicted] += 1

    return table


In [None]:
def eval_contingency_table_all(model, tokenizer, dataset, blank=False):
    """Evaluate `model` on the re2, re3 and tri views of `dataset`.

    Builds each derived dataset, computes its contingency table, and prints
    the three accuracies (re2, re3, tri) on one line. An accuracy is `nan`
    when the corresponding table is empty (e.g. the dataset contains no
    triangles).

    Returns:
        dict with the contingency tables under 're2', 're3' and 'tri'.
    """
    def compute_acc(tbl):
        # Accuracy = diagonal mass / total mass; undefined for an empty table.
        total = sum(sum(row) for row in tbl)
        if total == 0:
            return float('nan')
        correct = sum(tbl[i][i] for i in range(len(tbl)))
        return correct / total

    re2_dataset = map_to_re2_dataset(dataset)
    re2_tbl = eval_contingency_table_re2(model, tokenizer, dataset, re2_dataset, blank=blank)
    re2_acc = compute_acc(re2_tbl)

    re3_dataset = map_to_re3_dataset(dataset)
    re3_tbl = eval_contingency_table_re3(model, tokenizer, dataset, re3_dataset, blank=blank)
    re3_acc = compute_acc(re3_tbl)

    tri_dataset = map_to_tri_dataset(dataset)
    tri_tbl = eval_contingency_table_tri(model, tokenizer, dataset, tri_dataset, blank=blank)
    tri_acc = compute_acc(tri_tbl)

    print(re2_acc, re3_acc, tri_acc)

    return {
        're2': re2_tbl,
        're3': re3_tbl,
        'tri': tri_tbl
    }
