# Environment preparation

## Libraries

In [1]:
!pip install transformers
!pip install jsonlines

Collecting jsonlines
  Downloading jsonlines-4.0.0-py3-none-any.whl.metadata (1.6 kB)
Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)
Installing collected packages: jsonlines
Successfully installed jsonlines-4.0.0


In [2]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
import sklearn.metrics as metrics
import json
import csv
from torch.utils import data
from transformers import AutoModel, AutoTokenizer, AdamW, get_linear_schedule_with_warmup
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import Counter
from nltk.corpus import stopwords
import spacy
import jsonlines
import re
import time
import sklearn
from tqdm import tqdm
from scipy.special import softmax


In [3]:
!python -m spacy download en_core_web_lg

Collecting en-core-web-lg==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.7.1/en_core_web_lg-3.7.1-py3-none-any.whl (587.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m587.7/587.7 MB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: en-core-web-lg
Successfully installed en-core-web-lg-3.7.1
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_lg')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [4]:
import nltk
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

## Sample data

In [5]:
url = f"https://raw.githubusercontent.com/megagonlabs/ditto/master/data/er_magellan/Structured/Beer/train.txt"
!wget --no-cache --backups=1 {url}
url = f"https://raw.githubusercontent.com/megagonlabs/ditto/master/data/er_magellan/Structured/Beer/valid.txt"
!wget --no-cache --backups=1 {url}
url = f"https://raw.githubusercontent.com/megagonlabs/ditto/master/data/er_magellan/Structured/Beer/test.txt"
!wget --no-cache --backups=1 {url}

--2024-09-07 08:14:53--  https://raw.githubusercontent.com/megagonlabs/ditto/master/data/er_magellan/Structured/Beer/train.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84402 (82K) [text/plain]
Failed to rename train.txt to train.txt.1: (2) No such file or directory
Saving to: ‘train.txt’


2024-09-07 08:14:54 (52.8 MB/s) - ‘train.txt’ saved [84402/84402]

--2024-09-07 08:14:54--  https://raw.githubusercontent.com/megagonlabs/ditto/master/data/er_magellan/Structured/Beer/valid.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Leng

# Data preparation

## Augmenter

In [6]:
class Augmenter:
    """Data augmentation operator.

    Support both span and attribute level augmentation operators.
    """
    def __init__(self):
        pass

    def augment(self, tokens, labels, op='del'):
        """ Performs data augmentation on a sequence of tokens

        The supported ops:
           ['del', 'drop_col',
            'append_col', 'drop_token',
            'drop_len',
            'drop_sym',
            'drop_same',
            'swap',
            'ins',
            'all']

        Args:
            tokens (list of strings): the input tokens
            labels (list of strings): the labels of the tokens
            op (str, optional): a string encoding of the operator to be applied

        Returns:
            list of strings: the augmented tokens
            list of strings: the augmented labels
        """
        if 'del' in op:
            # insert padding to keep the length consistent
            # span_len = random.randint(1, 3)
            span_len = random.randint(1, 2)
            pos1, pos2 = self.sample_span(tokens, labels, span_len=span_len)
            if pos1 < 0:
                return tokens, labels
            new_tokens = tokens[:pos1] + tokens[pos2+1:]
            new_labels = tokens[:pos1] + labels[pos2+1:]
        elif 'swap' in op:
            span_len = random.randint(2, 4)
            pos1, pos2 = self.sample_span(tokens, labels, span_len=span_len)
            if pos1 < 0:
                return tokens, labels
            sub_arr = tokens[pos1:pos2+1]
            random.shuffle(sub_arr)
            new_tokens = tokens[:pos1] + sub_arr + tokens[pos2+1:]
            new_labels = tokens[:pos1] + ['O'] * (pos2 - pos1 + 1) + labels[pos2+1:]
        elif 'drop_len' in op:
            # drop tokens below a certain length
            all_lens = [len(token) for token, label in \
                    zip(tokens, labels) if label == 'O']
            if len(all_lens) == 0:
                return tokens, labels
            target_lens = random.choices(all_lens, k=1)
            new_tokens = []
            new_labels = []

            for token, label in zip(tokens, labels):
                if label != 'O' or len(token) not in target_lens:
                    new_tokens.append(token)
                    new_labels.append(label)
            return new_tokens, new_labels
        elif 'drop_sym' in op:
            def drop_sym(token):
                return ''.join([ch if ch.isalnum() else ' ' for ch in token])
            dropped_tokens = [drop_sym(token) for token in tokens]
            new_tokens = []
            new_labels = []
            for token, d_token, label in zip(tokens, dropped_tokens, labels):
                if random.randint(0, 4) != 0 or label != 'O':
                    new_tokens.append(token)
                    new_labels.append(label)
                else:
                    if d_token != '':
                        new_tokens.append(d_token)
                        new_labels.append(label)
        elif 'drop_same' in op:
            left_token = set([])
            right_token = set([])
            left = True
            for token, label in zip(tokens, labels):
                if label == 'O':
                    token = token.lower()
                    if left:
                        left_token.add(token)
                    else:
                        right_token.add(token)
                if token == '[SEP]':
                    left = False

            same = left_token & right_token
            targets = random.choices(list(same), k=1)
            new_tokens, new_labels = [], []
            for token, label in zip(tokens, labels):
                if token.lower() not in targets or label != 'O':
                    new_tokens.append(token)
                    new_labels.append(label)
            return new_tokens, new_labels
        elif 'drop_token' in op:
            new_tokens, new_labels = [], []
            for token, label in zip(tokens, labels):
                if label != 'O' or random.randint(0, 4) != 0:
                    new_tokens.append(token)
                    new_labels.append(label)
            return new_tokens, new_labels
        elif 'ins' in op:
            pos = self.sample_position(tokens, labels)
            symbol = random.choice('-*.,#&')
            new_tokens = tokens[:pos] + [symbol] + tokens[pos:]
            new_labels = labels[:pos] + ['O'] + labels[pos:]
            return new_tokens, new_labels
        elif 'append_col' in op:
            col_starts = [i for i in range(len(tokens)) if tokens[i] == 'COL']
            col_ends = [0] * len(col_starts)
            col_lens = [0] * len(col_starts)
            for i, pos in enumerate(col_starts):
                if i == len(col_starts) - 1:
                    col_lens[i] = len(tokens) - pos
                    col_ends[i] = len(tokens) - 1
                else:
                    col_lens[i] = col_starts[i + 1] - pos
                    col_ends[i] = col_starts[i + 1] - 1

                if tokens[col_ends[i]] == '[SEP]':
                    col_ends[i] -= 1
                    col_lens[i] -= 1
                    break
            candidates = [i for i, le in enumerate(col_lens) if le > 0]
            if len(candidates) >= 2:
                idx1, idx2 = random.sample(candidates,k=2)
                start1, end1 = col_starts[idx1], col_ends[idx1]
                sub_tokens = tokens[start1:end1+1]
                sub_labels = labels[start1:end1+1]
                val_pos = 0
                for i, token in enumerate(sub_tokens):
                    if token == 'VAL':
                        val_pos = i + 1
                        break
                sub_tokens = sub_tokens[val_pos:]
                sub_labels = sub_labels[val_pos:]

                end2 = col_ends[idx2]
                new_tokens = []
                new_labels = []
                for i in range(len(tokens)):
                    if start1 <= i <= end1:
                        continue
                    new_tokens.append(tokens[i])
                    new_labels.append(labels[i])
                    if i == end2:
                        new_tokens += sub_tokens
                        new_labels += sub_labels
                return new_tokens, new_labels
            else:
                new_tokens, new_labels = tokens, labels
        elif 'drop_col' in op:
            col_starts = [i for i in range(len(tokens)) if tokens[i] == 'COL']
            col_ends = [0] * len(col_starts)
            col_lens = [0] * len(col_starts)
            for i, pos in enumerate(col_starts):
                if i == len(col_starts) - 1:
                    col_lens[i] = len(tokens) - pos
                    col_ends[i] = len(tokens) - 1
                else:
                    col_lens[i] = col_starts[i + 1] - pos
                    col_ends[i] = col_starts[i + 1] - 1

                if tokens[col_ends[i]] == '[SEP]':
                    col_ends[i] -= 1
                    col_lens[i] -= 1
            candidates = [i for i, le in enumerate(col_lens) if le <= 8]
            if len(candidates) > 0:
                idx = random.choice(candidates)
                start, end = col_starts[idx], col_ends[idx]
                new_tokens = tokens[:start] + tokens[end+1:]
                new_labels = labels[:start] + labels[end+1:]
            else:
                new_tokens, new_labels = tokens, labels
        else:
            new_tokens, new_labels = tokens, labels

        return new_tokens, new_labels


    def augment_sent(self, text, op='all'):
        """ Performs data augmentation on a classification example.

        Similar to augment(tokens, labels) but works for sentences
        or sentence-pairs.

        Args:
            text (str): the input sentence
            op (str, optional): a string encoding of the operator to be applied

        Returns:
            str: the augmented sentence
        """
        # 50% of chance of flipping
        if ' [SEP] ' in text and random.randint(0, 1) == 0:
            left, right = text.split(' [SEP] ')
            text = right + ' [SEP] ' + left

        # tokenize the sentence
        current = ''
        tokens = text.split(' ')

        # avoid the special tokens
        labels = []
        for token in tokens:
            if token in ['COL', 'VAL']:
                labels.append('HD')
            elif token in ['[CLS]', '[SEP]']:
                labels.append('<SEP>')
            else:
                labels.append('O')

        if op == 'all':
            # RandAugment: https://arxiv.org/pdf/1909.13719.pdf
            N = 3
            ops = ['del', 'swap', 'drop_col', 'append_col']
            for op in random.choices(ops, k=N):
                tokens, labels = self.augment(tokens, labels, op=op)
        else:
            tokens, labels = self.augment(tokens, labels, op=op)
        results = ' '.join(tokens)
        return results

    def sample_span(self, tokens, labels, span_len=3):
        candidates = []
        for idx, token in enumerate(tokens):
            if idx + span_len - 1 < len(labels) and ''.join(labels[idx:idx+span_len]) == 'O'*span_len:
                candidates.append((idx, idx+span_len-1))
        if len(candidates) <= 0:
            return -1, -1
        return random.choice(candidates)

    def sample_position(self, tokens, labels, tfidf=False):
        candidates = []
        for idx, token in enumerate(tokens):
            if labels[idx] == 'O':
                candidates.append(idx)
        if len(candidates) <= 0:
            return -1
        return random.choice(candidates)

## Dataset

In [7]:
class DittoDataset(data.Dataset):
    """EM dataset"""

    def __init__(self,
                 path,
                 max_len=256,
                 size=None,
                 lm='roberta-base',
                 da=None):
        self.tokenizer = AutoTokenizer.from_pretrained(lm)
        self.pairs = []
        self.labels = []
        self.max_len = max_len
        self.size = size

        if isinstance(path, list):
            lines = path
        else:
            lines = open(path)

        for line in lines:
            s1, s2, label = line.strip().split('\t')
            self.pairs.append((s1, s2))
            self.labels.append(int(label))

        self.pairs = self.pairs[:size]
        self.labels = self.labels[:size]
        self.da = da
        if da is not None:
            self.augmenter = Augmenter()
        else:
            self.augmenter = None


    def __len__(self):
        """Return the size of the dataset."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return a tokenized item of the dataset.

        Args:
            idx (int): the index of the item

        Returns:
            List of int: token ID's of the two entities
            List of int: token ID's of the two entities augmented (if da is set)
            int: the label of the pair (0: unmatch, 1: match)
        """
        left = self.pairs[idx][0]
        right = self.pairs[idx][1]

        # left + right
        x = self.tokenizer.encode(text=left,
                                  text_pair=right,
                                  max_length=self.max_len,
                                  truncation=True)

        # augment if da is set
        if self.da is not None:
            combined = self.augmenter.augment_sent(left + ' [SEP] ' + right, self.da)
            left, right = combined.split(' [SEP] ')
            x_aug = self.tokenizer.encode(text=left,
                                      text_pair=right,
                                      max_length=self.max_len,
                                      truncation=True)
            return x, x_aug, self.labels[idx]
        else:
            return x, self.labels[idx]


    @staticmethod
    def pad(batch):
        """Merge a list of dataset items into a train/test batch
        Args:
            batch (list of tuple): a list of dataset items

        Returns:
            LongTensor: x1 of shape (batch_size, seq_len)
            LongTensor: x2 of shape (batch_size, seq_len).
                        Elements of x1 and x2 are padded to the same length
            LongTensor: a batch of labels, (batch_size,)
        """
        if len(batch[0]) == 3:
            x1, x2, y = zip(*batch)

            maxlen = max([len(x) for x in x1+x2])
            x1 = [xi + [0]*(maxlen - len(xi)) for xi in x1]
            x2 = [xi + [0]*(maxlen - len(xi)) for xi in x2]
            return torch.LongTensor(x1), \
                   torch.LongTensor(x2), \
                   torch.LongTensor(y)
        else:
            x12, y = zip(*batch)
            maxlen = max([len(x) for x in x12])
            x12 = [xi + [0]*(maxlen - len(xi)) for xi in x12]
            return torch.LongTensor(x12), \
                   torch.LongTensor(y)

## Summarizer

In [8]:
stopwords = set(stopwords.words('english'))

class Summarizer:
    """To summarize a data entry pair into length up to the max sequence length.

    Args:
        task_config (Dictionary): the task configuration
        lm (string): the language model (bert, albert, or distilbert)

    Attributes:
        config (Dictionary): the task configuration
        tokenizer (Tokenizer): a tokenizer from the huggingface library
    """
    def __init__(self, task_config, lm):
        self.config = task_config
        self.tokenizer = AutoTokenizer.from_pretrained(lm)
        self.len_cache = {}

        # build the tfidf index
        self.build_index()

    def build_index(self):
        """Build the idf index.

        Store the index and vocabulary in self.idf and self.vocab.
        """
        fns = [self.config['trainset'],
               self.config['validset'],
               self.config['testset']]
        content = []
        for fn in fns:
            with open(fn) as fin:
                for line in fin:
                    LL = line.split('\t')
                    if len(LL) > 2:
                        for entry in LL:
                            content.append(entry)

        vectorizer = TfidfVectorizer().fit(content)
        self.vocab = vectorizer.vocabulary_
        self.idf = vectorizer.idf_

    def get_len(self, word):
        """Return the sentence_piece length of a token.
        """
        if word in self.len_cache:
            return self.len_cache[word]
        length = len(self.tokenizer.tokenize(word))
        self.len_cache[word] = length
        return length

    def transform(self, row, max_len=128):
        """Summarize one single example.

        Only retain tokens of the highest tf-idf

        Args:
            row (str): a matching example of two data entries and a binary label, separated by tab
            max_len (int, optional): the maximum sequence length to be summarized to

        Returns:
            str: the summarized example
        """
        sentA, sentB, label = row.strip().split('\t')
        res = ''
        cnt = Counter()
        for sent in [sentA, sentB]:
            tokens = sent.split(' ')
            for token in tokens:
                if token not in ['COL', 'VAL'] and \
                   token not in stopwords:
                    if token in self.vocab:
                        cnt[token] += self.idf[self.vocab[token]]

        for sent in [sentA, sentB]:
            token_cnt = Counter(sent.split(' '))
            total_len = token_cnt['COL'] + token_cnt['VAL']

            subset = Counter()
            for token in set(token_cnt.keys()):
                subset[token] = cnt[token]
            subset = subset.most_common(max_len)

            topk_tokens_copy = set([])
            for word, _ in subset:
                bert_len = self.get_len(word)
                if total_len + bert_len > max_len:
                    break
                total_len += bert_len
                topk_tokens_copy.add(word)

            num_tokens = 0
            for token in sent.split(' '):
                if token in ['COL', 'VAL']:
                    res += token + ' '
                elif token in topk_tokens_copy:
                    res += token + ' '
                    topk_tokens_copy.remove(token)

            res += '\t'

        res += label + '\n'
        return res

    def transform_file(self, input_fn, max_len=256, overwrite=False):
        """Summarize all lines of a tsv file.

        Run the summarizer. If the output already exists, just return the file name.

        Args:
            input_fn (str): the input file name
            max_len (int, optional): the max sequence len
            overwrite (bool, optional): if true, then overwrite any cached output

        Returns:
            str: the output file name
        """
        out_fn = input_fn + '.su'
        if not os.path.exists(out_fn) or \
           os.stat(out_fn).st_size == 0 or overwrite:
            with open(out_fn, 'w') as fout:
                for line in open(input_fn):
                    fout.write(self.transform(line, max_len=max_len))
        return out_fn

## DK injector

In [9]:
class DKInjector:
    """Inject domain knowledge to the data entry pairs.

    Attributes:
        config: the task configuration
        name: the injector name
    """
    def __init__(self, config, name):
        self.config = config
        self.name = name
        self.initialize()

    def initialize(self):
        """Initialize spacy"""
        self.nlp = spacy.load('en_core_web_lg')

    def transform_file(self, input_fn, overwrite=False):
        """Transform all lines of a tsv file.

        Run the knowledge injector. If the output already exists, just return the file name.

        Args:
            input_fn (str): the input file name
            overwrite (bool, optional): if true, then overwrite any cached output

        Returns:
            str: the output file name
        """
        out_fn = input_fn + '.dk'
        if not os.path.exists(out_fn) or \
            os.stat(out_fn).st_size == 0 or overwrite:

            with open(out_fn, 'w') as fout:
                for line in open(input_fn):
                    LL = line.split('\t')
                    if len(LL) == 3:
                        entry0 = self.transform(LL[0])
                        entry1 = self.transform(LL[1])
                        fout.write(entry0 + '\t' + entry1 + '\t' + LL[2])
        return out_fn

    def transform(self, entry):
        """Transform a data entry.

        Use NER to regconize the product-related named entities and
        mark them in the sequence. Normalize the numbers into the same format.

        Args:
            entry (str): the serialized data entry

        Returns:
            str: the transformed entry
        """
        res = ''
        doc = self.nlp(entry, disable=['tagger', 'parser'])
        ents = doc.ents
        start_indices = {}
        end_indices = {}

        for ent in ents:
            start, end, label = ent.start, ent.end, ent.label_
            if label in ['PERSON', 'ORG', 'LOC', 'PRODUCT', 'DATE', 'QUANTITY', 'TIME']:
                start_indices[start] = label
                end_indices[end] = label

        for idx, token in enumerate(doc):
            if idx in start_indices:
                res += start_indices[idx] + ' '

            # normalizing the numbers
            if token.like_num:
                try:
                    val = float(token.text)
                    if val == round(val):
                        res += '%d ' % (int(val))
                    else:
                        res += '%.2f ' % (val)
                except:
                    res += token.text + ' '
            elif len(token.text) >= 7 and \
                 any([ch.isdigit() for ch in token.text]):
                res += 'ID ' + token.text + ' '
            else:
                res += token.text + ' '
        return res.strip()

# Model

In [10]:
class DittoModel(nn.Module):

    def __init__(self, device='cuda', lm='roberta-base', alpha_aug=0.8):
        super().__init__()

        self.bert = AutoModel.from_pretrained(lm)
        self.device = device
        self.alpha_aug = alpha_aug

        # linear layer
        hidden_size = self.bert.config.hidden_size
        self.fc = torch.nn.Linear(hidden_size, 2)


    def forward(self, x1, x2=None):
        """Encode the left, right, and the concatenation of left+right.

        Args:
            x1 (LongTensor): a batch of ID's
            x2 (LongTensor, optional): a batch of ID's (augmented)

        Returns:
            Tensor: binary prediction
        """
        x1 = x1.to(self.device) # (batch_size, seq_len)

        if x2 is not None:
            # MixDA
            x2 = x2.to(self.device) # (batch_size, seq_len)
            enc = self.bert(torch.cat((x1, x2)))[0][:, 0, :]
            batch_size = len(x1)
            enc1 = enc[:batch_size] # (batch_size, emb_size)
            enc2 = enc[batch_size:] # (batch_size, emb_size)

            aug_lam = np.random.beta(self.alpha_aug, self.alpha_aug)
            enc = enc1 * aug_lam + enc2 * (1.0 - aug_lam)
        else:
            enc = self.bert(x1)[0][:, 0, :]

        return self.fc(enc)

# Train utils

In [11]:
def evaluate(model, iterator, threshold=None):
    """Evaluate a model on a validation/test dataset

    Args:
        model (DMModel): the EM model
        iterator (Iterator): the valid/test dataset iterator
        threshold (float, optional): the threshold on the 0-class

    Returns:
        float: the F1 score
        float (optional): if threshold is not provided, the threshold
            value that gives the optimal F1
    """
    all_p = []
    all_y = []
    all_probs = []
    with torch.no_grad():
        for batch in iterator:
            x, y = batch
            logits = model(x)
            probs = logits.softmax(dim=1)[:, 1]
            all_probs += probs.cpu().numpy().tolist()
            all_y += y.cpu().numpy().tolist()

    if threshold is not None:
        pred = [1 if p > threshold else 0 for p in all_probs]
        f1 = metrics.f1_score(all_y, pred)
        return f1
    else:
        best_th = 0.5
        f1 = 0.0

        for th in np.arange(0.0, 1.0, 0.05):
            pred = [1 if p > th else 0 for p in all_probs]
            new_f1 = metrics.f1_score(all_y, pred)
            if new_f1 > f1:
                f1 = new_f1
                best_th = th

        return f1, best_th


def train_step(train_iter, model, optimizer, scheduler, hp):
    """Perform a single training step

    Args:
        train_iter (Iterator): the train data loader
        model (DMModel): the model
        optimizer (Optimizer): the optimizer (Adam or AdamW)
        scheduler (LRScheduler): learning rate scheduler
        hp (Namespace): other hyper-parameters (e.g., fp16)

    Returns:
        None
    """
    criterion = nn.CrossEntropyLoss()

    for i, batch in enumerate(train_iter):
        optimizer.zero_grad()

        if len(batch) == 2:
            x, y = batch
            prediction = model(x)
        else:
            x1, x2, y = batch
            prediction = model(x1, x2)

        loss = criterion(prediction, y.to(model.device))

        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 10 == 0: # monitoring
            print(f"step: {i}, loss: {loss.item()}")
        del loss


def train(trainset, validset, testset, hp):
    """Train and evaluate the model

    Args:
        trainset (DittoDataset): the training set
        validset (DittoDataset): the validation set
        testset (DittoDataset): the test set
        hp (Namespace): Hyper-parameters (e.g., batch_size,
                        learning rate)

    Returns:
        None
    """
    padder = trainset.pad
    # create the DataLoaders
    train_iter = data.DataLoader(dataset=trainset,
                                 batch_size=args['batch_size'],
                                 shuffle=True,
                                 num_workers=0,
                                 collate_fn=padder)
    valid_iter = data.DataLoader(dataset=validset,
                                 batch_size=args['batch_size'],
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=padder)
    test_iter = data.DataLoader(dataset=testset,
                                 batch_size=args['batch_size'],
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=padder)

    # initialize model, optimizer, and LR scheduler
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = DittoModel(device=device,
                       lm=args['lm'],
                       alpha_aug=args['alpha_aug'])
    model = model.cuda()
    optimizer = AdamW(model.parameters(), lr=args['lr'])

    num_steps = (len(trainset) // args['batch_size']) * args['n_epochs']
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=num_steps)

    best_dev_f1 = best_test_f1 = 0.0
    for epoch in range(1, args['n_epochs']+1):
        # train
        model.train()
        train_step(train_iter, model, optimizer, scheduler, hp)

        # eval
        model.eval()
        dev_f1, th = evaluate(model, valid_iter)
        test_f1 = evaluate(model, test_iter, threshold=th)

        if dev_f1 > best_dev_f1:
            best_dev_f1 = dev_f1
            best_test_f1 = test_f1
            if args['save_model']:
                # save the checkpoints for each component
                ckpt_path = os.path.join('model.pt')
                ckpt = {'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'epoch': epoch}
                torch.save(ckpt, ckpt_path)

        print(f"epoch {epoch}: f1={dev_f1}")

# Train

In [12]:
args = {
    "task": "Structured/Beer",
    "run_id": 123,
    "batch_size": 64,
    "max_len": 256,
    "lr": 3e-5,
    "n_epochs": 20,
    "finetuning": True,
    "save_model": True,
    "logdir": "checkpoints/",
    "lm": "roberta-base",
    "da": 'drop_col',
    "alpha_aug": 0.8,
    "dk": None,
    "summarize": False,
    "size": None
}

# set seeds
seed = args['run_id']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

trainset = 'train.txt'
validset = 'valid.txt'
testset = 'test.txt'
config = {
    'trainset': trainset,
    'validset': validset,
    'testset': testset,
}

# summarize the sequences up to the max sequence length
if args['summarize']:
    summarizer = Summarizer(config, lm=args['lm'])
    trainset = summarizer.transform_file(trainset, max_len=args['max_len'])
    validset = summarizer.transform_file(validset, max_len=args['max_len'])
    testset = summarizer.transform_file(testset, max_len=args['max_len'])

if args['dk'] is not None:
    injector = DKInjector(config, args['dk'])

    trainset = injector.transform_file(trainset)
    validset = injector.transform_file(validset)
    testset = injector.transform_file(testset)

# load train/dev/test sets
train_dataset = DittoDataset(trainset,
                             lm=args['lm'],
                             max_len=args['max_len'],
                             size=args['size'],
                             da=args['da'])
valid_dataset = DittoDataset(validset, lm=args['lm'])
test_dataset = DittoDataset(testset, lm=args['lm'])

# train and evaluate the model
train(train_dataset,
      valid_dataset,
      test_dataset,
      args)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


step: 0, loss: 0.6497491002082825
epoch 1: f1=0.2666666666666667
step: 0, loss: 0.39931249618530273
epoch 2: f1=0.27184466019417475
step: 0, loss: 0.4042988717556
epoch 3: f1=0.2692307692307693
step: 0, loss: 0.43629908561706543
epoch 4: f1=0.3773584905660377
step: 0, loss: 0.3302553594112396
epoch 5: f1=0.40625
step: 0, loss: 0.2828435003757477
epoch 6: f1=0.4444444444444445
step: 0, loss: 0.40554261207580566
epoch 7: f1=0.45454545454545453
step: 0, loss: 0.42321130633354187
epoch 8: f1=0.5777777777777778
step: 0, loss: 0.41000962257385254
epoch 9: f1=0.689655172413793
step: 0, loss: 0.2845877408981323
epoch 10: f1=0.7878787878787878
step: 0, loss: 0.203620046377182
epoch 11: f1=0.7741935483870968
step: 0, loss: 0.32986846566200256
epoch 12: f1=0.896551724137931
step: 0, loss: 0.12400054931640625
epoch 13: f1=0.896551724137931
step: 0, loss: 0.24103820323944092
epoch 14: f1=0.9032258064516129
step: 0, loss: 0.047060687094926834
epoch 15: f1=0.9032258064516129
step: 0, loss: 0.12831860

# Inference utils

In [13]:
def to_str(ent1, ent2, summarizer=None, max_len=256, dk_injector=None):
    """Serialize a pair of data entries

    Args:
        ent1 (Dictionary): the 1st data entry
        ent2 (Dictionary): the 2nd data entry
        summarizer (Summarizer, optional): the summarization module
        max_len (int, optional): the max sequence length
        dk_injector (DKInjector, optional): the domain-knowledge injector

    Returns:
        string: the serialized version
    """
    content = ''
    for ent in [ent1, ent2]:
        if isinstance(ent, str):
            content += ent
        else:
            for attr in ent.keys():
                content += 'COL %s VAL %s ' % (attr, ent[attr])
        content += '\t'

    content += '0'

    if summarizer is not None:
        content = summarizer.transform(content, max_len=max_len)

    new_ent1, new_ent2, _ = content.split('\t')
    if dk_injector is not None:
        new_ent1 = dk_injector.transform(new_ent1)
        new_ent2 = dk_injector.transform(new_ent2)

    return new_ent1 + '\t' + new_ent2 + '\t0'

def classify(sentence_pairs, model,
             lm='distilbert',
             max_len=256,
             threshold=None):
    """Apply the MRPC model.

    Args:
        sentence_pairs (list of str): the sequence pairs
        model (MultiTaskNet): the model in pytorch
        max_len (int, optional): the max sequence length
        threshold (float, optional): the threshold of the 0's class

    Returns:
        list of float: the scores of the pairs
    """
    inputs = sentence_pairs
    # print('max_len =', max_len)
    dataset = DittoDataset(inputs,
                           max_len=max_len,
                           lm=lm)
    # print(dataset[0])
    iterator = data.DataLoader(dataset=dataset,
                               batch_size=len(dataset),
                               shuffle=False,
                               num_workers=0,
                               collate_fn=DittoDataset.pad)

    # prediction
    all_probs = []
    all_logits = []
    with torch.no_grad():
        # print('Classification')
        for i, batch in enumerate(iterator):
            x, _ = batch
            logits = model(x)
            probs = logits.softmax(dim=1)[:, 1]
            all_probs += probs.cpu().numpy().tolist()
            all_logits += logits.cpu().numpy().tolist()

    if threshold is None:
        threshold = 0.5

    pred = [1 if p > threshold else 0 for p in all_probs]
    return pred, all_logits

def predict(input_path, output_path, config,
            model,
            batch_size=1024,
            summarizer=None,
            lm='distilbert',
            max_len=256,
            dk_injector=None,
            threshold=None):
    """Run the model over the input file containing the candidate entry pairs

    Args:
        input_path (str): the input file path
        output_path (str): the output file path
        config (Dictionary): task configuration
        model (DittoModel): the model for prediction
        batch_size (int): the batch size
        summarizer (Summarizer, optional): the summarization module
        max_len (int, optional): the max sequence length
        dk_injector (DKInjector, optional): the domain-knowledge injector
        threshold (float, optional): the threshold of the 0's class

    Returns:
        None
    """
    pairs = []

    def process_batch(rows, pairs, writer):
        predictions, logits = classify(pairs, model, lm=lm,
                                       max_len=max_len,
                                       threshold=threshold)
        # try:
        #     predictions, logits = classify(pairs, model, lm=lm,
        #                                    max_len=max_len,
        #                                    threshold=threshold)
        # except:
        #     # ignore the whole batch
        #     return
        scores = softmax(logits, axis=1)
        for row, pred, score in zip(rows, predictions, scores):
            output = {'left': row[0], 'right': row[1],
                'match': pred,
                'match_confidence': score[int(pred)]}
            writer.write(output)

    # input_path can also be train/valid/test.txt
    # convert to jsonlines
    if '.txt' in input_path:
        with jsonlines.open(input_path + '.jsonl', mode='w') as writer:
            for line in open(input_path):
                writer.write(line.split('\t')[:2])
        input_path += '.jsonl'

    # batch processing
    start_time = time.time()
    with jsonlines.open(input_path) as reader,\
         jsonlines.open(output_path, mode='w') as writer:
        pairs = []
        rows = []
        for idx, row in tqdm(enumerate(reader)):
            pairs.append(to_str(row[0], row[1], summarizer, max_len, dk_injector))
            rows.append(row)
            if len(pairs) == batch_size:
                process_batch(rows, pairs, writer)
                pairs.clear()
                rows.clear()

        if len(pairs) > 0:
            process_batch(rows, pairs, writer)


def tune_threshold(config, model, hp):
    """Tune the prediction threshold for a given model on a validation set"""
    validset = config['validset']

    # summarize the sequences up to the max sequence length
    summarizer = injector = None
    if hp["summarize"]:
        summarizer = Summarizer(config, lm=hp["lm"])
        validset = summarizer.transform_file(validset, max_len=hp["max_len"], overwrite=True)

    if hp["dk"] is not None:
        injector = DKInjector(config, hp["dk"])

        validset = injector.transform_file(validset)

    # load dev sets
    valid_dataset = DittoDataset(validset,
                                 max_len=hp["max_len"],
                                 lm=hp["lm"])

    valid_iter = data.DataLoader(dataset=valid_dataset,
                                 batch_size=64,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=DittoDataset.pad)

    f1, th = evaluate(model, valid_iter, threshold=None)

    return th



def load_model(lm, checkpoint='model.pt'):
    """Load a model for a specific task.

    Args:
        path (str): the path of the checkpoint directory
        lm (str): the language model
        use_gpu (boolean): whether to use gpu
        fp16 (boolean, optional): whether to use fp16

    Returns:
        the model
    """
    # load model
    if not os.path.exists(checkpoint):
        raise ValueError(f"Model not found at: {checkpoint}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = DittoModel(device=device, lm=lm)

    saved_state = torch.load(
        checkpoint, weights_only=True,
        map_location=lambda storage, loc: storage)
    model.load_state_dict(saved_state['model'])
    model = model.to(device)

    return model

# Inference

In [14]:
args = {
    "task": "Structured/Beer",
    "input_path": "test.txt",
    "output_path": "output_small.jsonl",
    "run_id": 123,
    "batch_size": 64,
    "max_len": 256,
    "lm": "roberta-base",
    "dk": None,
    "summarize": False
}

# set seeds
seed = args['run_id']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

trainset = 'train.txt'
validset = 'valid.txt'
testset = 'test.txt'
config = {
    'trainset': trainset,
    'validset': validset,
    'testset': testset,
}

model = load_model(args["lm"])

summarizer = dk_injector = None
if args["summarize"]:
    summarizer = Summarizer(config, args["lm"])

if args["dk"] is not None:
    dk_injector = DKInjector(config, args["dk"])

# tune threshold
threshold = tune_threshold(config, model, args)

# run prediction
print(f"Threshold: {threshold}")
predict(
    args["input_path"], args["output_path"], config, model,
    summarizer=summarizer, max_len=args["max_len"],
    lm=args["lm"], dk_injector=dk_injector,
    threshold=threshold
)

predicts = []
with jsonlines.open(args["output_path"], mode="r") as reader:
    for line in reader:
        predicts.append(int(line['match']))

labels = []
with open(args["input_path"]) as fin:
    for line in fin:
        labels.append(int(line.split('\t')[-1]))

f1 = sklearn.metrics.f1_score(labels, predicts)
print("Test f1 =", f1)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Threshold: 0.30000000000000004


91it [00:00, 59516.87it/s]


Test f1 = 0.7368421052631579
