In [1]:
import torch

# Project identifier; later handed to setup_logging(), which uses it as the
# Weights & Biases project name.
project_name = "receiptlayoutlm"

In [2]:
from datasets import load_from_disk, load_dataset

# Load the preprocessed receipts dataset (the captured output shows it being
# reused from the local cache when already downloaded).
# NOTE(review): `use_auth_token` is deprecated in newer `datasets` releases in
# favour of `token` -- confirm against the installed version before changing.
ds_receipts = load_dataset("sibrun/receipts", use_auth_token=True)
# Last expression of the cell: displays the feature schema of the train split.
ds_receipts['train'].features

Using custom data configuration sibrun--receipts-1fbadef9a86aa00b
Reusing dataset parquet (/Users/simon/.cache/huggingface/datasets/parquet/sibrun--receipts-1fbadef9a86aa00b/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


  0%|          | 0/2 [00:00<?, ?it/s]

{'image': Array3D(shape=(3, 224, 224), dtype='uint8', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=512, id=None),
 'bbox': Array2D(shape=(512, 4), dtype='int64', id=None),
 'labels': Sequence(feature=Value(dtype='int64', id=None), length=512, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=512, id=None)}

In [3]:
# Entity classes annotated on the receipts; 'O' (index 0) marks tokens that
# belong to none of them.
label_names = ['company', 'date', 'address', 'total']
labels = ['O', *label_names]
num_labels = len(labels)

# Bidirectional lookup tables between class index and class name.
ids_to_labels = dict(enumerate(labels))
labels_to_ids = {name: idx for idx, name in ids_to_labels.items()}

In [4]:
from transformers import AutoConfig

# Fetch the LayoutXLM base configuration and attach the receipt label set to it.
# NOTE(review): no later visible cell passes `xlm_config` to a model -- confirm
# whether it is still needed.
xlm_config = AutoConfig.from_pretrained(
    "microsoft/layoutxlm-base",
    num_labels=num_labels,
    id2label=ids_to_labels,
    label2id=labels_to_ids,
)

Downloading:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/904 [00:00<?, ?B/s]

In [None]:
from datasets import load_dataset, load_from_disk
from transformers import LayoutLMv2FeatureExtractor, LayoutXLMTokenizer, LayoutXLMProcessor
import torch
from transformers import LayoutLMv2ForTokenClassification

# Load the pretrained backbone with a token-classification head sized for the
# receipt label set. Without `num_labels` the head defaults to 2 classes, which
# does not match the 5 label ids defined above and breaks the loss computation.
# NOTE(review): the variable is named `*_xlm` and the config cell uses
# "microsoft/layoutxlm-base", but this checkpoint is LayoutLMv2 (uncased).
# Confirm which tokenizer produced the dataset's `input_ids` before switching.
model_xlm = LayoutLMv2ForTokenClassification.from_pretrained(
    "microsoft/layoutlmv2-base-uncased",
    num_labels=num_labels,
    id2label=ids_to_labels,
    label2id=labels_to_ids,
)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Have the dataset return torch tensors placed on `device`, and move the model
# to the same device.
ds_receipts.set_format(type="torch", device=device)
model_xlm.to(device)

In [1]:
from argparse import Namespace

# Training hyperparameters. They are kept in a plain dict and mirrored into a
# Namespace so they can be read as attributes (args.learning_rate, ...).
config = dict(
    train_batch_size=4,
    valid_batch_size=2,
    weight_decay=0.1,
    learning_rate=5e-5,
    num_train_epochs=2,
    seed=1,
    save_checkpoint_steps=100,
)

args = Namespace(**config)

In [None]:
from torch.utils.tensorboard import SummaryWriter
import logging
import wandb
import transformers
import datasets

def setup_logging(project_name):
    """Configure file/console logging, Weights & Biases and TensorBoard.

    Reads the module-level ``args`` namespace for the hyperparameters that are
    recorded with the run.

    Args:
        project_name: Name of the Weights & Biases project to log to.

    Returns:
        A tuple ``(logger, tb_writer, run_name)``: the module logger, a
        TensorBoard ``SummaryWriter`` and the name of the started wandb run.
    """
    import os  # local import keeps this cell self-contained

    # FileHandler opens its file on construction and fails if the directory
    # does not exist, so create it up front.
    os.makedirs("log", exist_ok=True)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.FileHandler("log/debug.log"),
            logging.StreamHandler(),
        ],
    )

    wandb.init(project=project_name, config=args)
    run_name = wandb.run.name

    tb_writer = SummaryWriter()
    # add_hparams requires a metrics dict; a placeholder entry is supplied.
    tb_writer.add_hparams(vars(args), {'0': 0})

    logger.setLevel(logging.INFO)
    datasets.utils.logging.set_verbosity_debug()
    transformers.utils.logging.set_verbosity_info()
    return logger, tb_writer, run_name

In [None]:
def log_metrics(step, metrics, logger, tb_writer):
    """Send one set of metrics to the logger, Weights & Biases and TensorBoard.

    Args:
        step: Training step the metrics belong to.
        metrics: Mapping of metric name to scalar value.
        logger: Logger that receives a one-line summary via ``info``.
        tb_writer: TensorBoard writer; each metric is written with ``add_scalar``.
    """
    logger.info(f"Step {step}: {metrics}")
    wandb.log(metrics)
    # Plain loop: the calls are made for their side effect only, so there is
    # no reason to build a throwaway list.
    for name, value in metrics.items():
        tb_writer.add_scalar(name, value, step)


In [59]:
from torch.utils.data import DataLoader

# Training batches are reshuffled on every pass; evaluation keeps dataset order.
train_dataloader = DataLoader(
    ds_receipts['train'],
    batch_size=args.train_batch_size,
    shuffle=True,
)
eval_dataloader = DataLoader(
    ds_receipts['test'],
    batch_size=args.valid_batch_size,
)

In [60]:
# Sanity check: pull one training batch and print the shape of every field.
batch = next(iter(train_dataloader))
for key, tensor in batch.items():
    print(key, tensor.shape)

image torch.Size([4, 3, 224, 224])
input_ids torch.Size([4, 512])
bbox torch.Size([4, 512, 4])
labels torch.Size([4, 512])
attention_mask torch.Size([4, 512])


In [None]:
def evaluate():
    """Run the model over ``eval_dataloader`` and report the mean batch loss.

    Uses the module-level ``model_xlm`` and ``eval_dataloader``. The model is
    left in eval mode; callers that continue training must switch it back.

    Returns:
        A tuple ``(loss, perplexity)`` of Python floats, where ``loss`` is the
        mean of the per-batch losses and ``perplexity`` is ``exp(loss)``.
    """
    model_xlm.eval()
    losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            model_outputs = model_xlm(**batch)
        losses.append(model_outputs.loss)
    # Per-batch losses are 0-dim tensors, which torch.cat rejects; stack them
    # into a 1-D tensor before averaging.
    total_loss = torch.mean(torch.stack(losses))
    # torch.exp saturates to inf instead of raising, so no overflow guard is needed.
    perplexity = torch.exp(total_loss)
    return total_loss.item(), perplexity.item()

In [None]:
from random import seed

# Seed Python's RNG and PyTorch's global RNG. DataLoader shuffling draws from
# the torch generator, so seeding `random` alone does not fix the batch order.
# NOTE(review): the model was instantiated in an earlier cell, so any randomly
# initialised weights are not covered by this seed.
seed(args.seed)
torch.manual_seed(args.seed)

logger, tb_writer, run_name = setup_logging(project_name)

# Apply the configured weight decay; without it AdamW silently falls back to
# its own default and the value recorded with the run would be misleading.
optimizer = torch.optim.AdamW(
    model_xlm.parameters(),
    lr=args.learning_rate,
    weight_decay=args.weight_decay,
)

global_step = 0
train_steps_total = len(train_dataloader) * args.num_train_epochs # total number of training steps

In [None]:
# Fine-tuning loop. `step` counts completed optimizer updates. Every
# `args.save_checkpoint_steps` updates the model is evaluated, saved locally
# and pushed to the Hub; the same is done once more after the last epoch.
checkpoint_dir = "../models/receiptlayoutlm"
hub_repo_id = "sibrun/receiptlayoutlm"

step = 0
model_xlm.train()
for epoch in range(args.num_train_epochs):
    print("Epoch:", epoch)
    for batch in train_dataloader:
        optimizer.zero_grad()
        outputs = model_xlm(**batch)
        loss = outputs.loss
        # Update before any evaluation so the autograd graph is released first.
        loss.backward()
        optimizer.step()

        train_loss = loss.item()
        log_metrics(step, {'steps': step, 'loss/train': train_loss}, logger, tb_writer)
        if step % 10 == 0:
            print(f"Loss after {step} steps: {train_loss}")
        step += 1

        # Checked after the increment, so no checkpoint is written before the
        # first update has happened.
        if step % args.save_checkpoint_steps == 0:
            logger.info('Evaluating and saving model checkpoint')
            eval_loss, perplexity = evaluate()
            log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity}, logger, tb_writer)
            model_xlm.save_pretrained(checkpoint_dir)
            model_xlm.push_to_hub(hub_repo_id, commit_message=f'step {step}')
            # evaluate() leaves the model in eval mode; switch back for training.
            model_xlm.train()

logger.info('Evaluating and saving model after training')
eval_loss, perplexity = evaluate()
log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity}, logger, tb_writer)
model_xlm.save_pretrained(checkpoint_dir)
model_xlm.push_to_hub(hub_repo_id, commit_message='final model')
# Make sure buffered TensorBoard events reach disk.
tb_writer.flush()