In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import transformers
from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoModel
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import logging
from torch.utils.data import DataLoader
from torcheval.metrics.functional import multiclass_f1_score, multiclass_confusion_matrix
from copy import deepcopy
from IPython.display import clear_output
from huggingface_hub import notebook_login

import os 
while 'notebooks' in os.getcwd():
    os.chdir("..")
    
from src.preprocessing.make_dataset import ImageLayoutDataset
from src.model.trainer import BertTrainer, LayoutLMTrainer
import warnings
warnings.filterwarnings("ignore")

In [3]:
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
!mkdir logs

mkdir: cannot create directory ‘logs’: File exists


In [5]:
device = "cuda" if torch.cuda.is_available() else 'cpu'
model_name = "few_shot_learning"
dataset_name = "sroie"

In [6]:
logging.basicConfig(filename='logs/few_shot_learning_cord.log', encoding='utf-8', level= logging.INFO)

## Importing source and support dataset

In [7]:
source_dataset = load_dataset(
    "darentang/sroie",
    cache_dir = "/Data/pedro.silva/"
)

support_dataset = load_dataset(
    "nielsr/funsd",
    cache_dir = "/Data/pedro.silva/"
)

In [8]:
model = AutoModel.from_pretrained(
    "microsoft/layoutlm-base-uncased",
    cache_dir = "/Data/pedro.silva/"
).to(device)

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/layoutlm-base-uncased",
    cache_dir = "/Data/pedro.silva/"
)

Some weights of LayoutLMModel were not initialized from the model checkpoint at microsoft/layoutlm-base-uncased and are newly initialized: ['layoutlm.embeddings.word_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
source_dataset['train'].features['ner_tags'].feature.names

['O',
 'B-COMPANY',
 'I-COMPANY',
 'B-DATE',
 'I-DATE',
 'B-ADDRESS',
 'I-ADDRESS',
 'B-TOTAL',
 'I-TOTAL']

In [10]:
valid_labels ={
    0: 0,
    1: 1,
    2: 1,
    3: 2,
    4: 2,
    5: 3,
    6: 3,
    7: 4,
    8: 4
}

source_df = ImageLayoutDataset(
    source_dataset['train'],
    tokenizer,
    valid_labels_keymap= valid_labels
)

  0%|          | 0/626 [00:00<?, ?it/s]

100%|██████████| 626/626 [00:03<00:00, 188.64it/s]


In [11]:
input = source_df[0]

In [12]:
out = model(
    input_ids=input['input_ids'].reshape(1,-1), 
    bbox= input['bbox'].reshape([1, 512, 4]),
    attention_mask=input['attention_mask'].reshape(1,-1), 
    token_type_ids=input['token_type_ids'].reshape(1,-1),
)

In [13]:
out.last_hidden_state.shape

torch.Size([1, 512, 768])

In [14]:
class F_mean(torch.nn.Module):
    def __init__(self, 
                 input_dim:int, 
                 output_dim: int, 
                 device : str = 'cuda',
                 *args, 
                 **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.layer1 = torch.nn.Linear(input_dim, 512).to(device)
        self.layer2 = torch.nn.Linear(512, output_dim).to(device)
        self.elu = torch.nn.ELU().to(device)
        self.relu = torch.nn.ReLU().to(device)

    def forward(self, x: torch.Tensor):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return self.elu(x).squeeze()

class F_cov(torch.nn.Module):
    def __init__(self, 
                 input_dim:int, 
                 output_dim: int, 
                 device : str = 'cuda',
                 epsilon : float = 1e-14,
                 *args, 
                 **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.epsilon = epsilon

        self.layer1 = torch.nn.Linear(input_dim, 512).to(device)
        self.layer2 = torch.nn.Linear(512, output_dim).to(device)
        self.elu = torch.nn.ELU().to(device)
        self.relu = torch.nn.ReLU().to(device)

    def forward(self, x: torch.Tensor):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        x = self.elu(x)
        x += 1 + self.epsilon

        # return torch.diag(x)
        return torch.diag(x.squeeze())

In [15]:
input_dim = 768
output_dim = 64

f_mean = F_mean(
    input_dim,
    output_dim
)

f_cov = F_cov(
    input_dim,
    output_dim
)

In [16]:
out.last_hidden_state[:,0, :].shape

torch.Size([1, 768])

In [17]:
params = list(f_mean.parameters()) + list(f_cov.parameters()) + list(model.parameters())
optimizer = torch.optim.Adam(
    params,
    lr = 1e-4
)

## Train on source set

In [18]:
def KL_div(
    mu_1 : torch.Tensor,
    mu_2 : torch.Tensor,
    cov_1: torch.Tensor,
    cov_2: torch.Tensor
):
    l = cov_1.shape[0]
    inv_1 = cov_1.inverse()
    # inv_2 = cov_2.inverse()
    kl = 1/2 * (mu_1 - mu_2).T @ inv_1 @ (mu_1 - mu_2)
    kl+= 1/2 * torch.trace(inv_1 @ cov_2)
    kl+= -l/2
    kl+= torch.log(
        torch.det(cov_1)/torch.det(cov_2)
    )

    return kl

In [19]:
source_df[0]['labels'][:,3]

tensor([-100], device='cuda:0')

In [20]:
total_samples = len(source_df)
minibatch_size = 5
total_batches = total_samples // minibatch_size

# Initialize a list to store minibatches
minibatches = []

# Loop over the total number of batches
for _ in range(total_batches):
    # Generate random indices for the minibatch
    random_indices = np.random.choice(total_samples, minibatch_size, replace=False)
    # Append the minibatch to the list
    minibatches.append(random_indices)

In [21]:
len(minibatches)

125

In [22]:
checkpoint_f_mean = "/Data/pedro.silva/f_mean_sroie.pt"
checkpoint_f_cov = "/Data/pedro.silva/f_cov_sroie.pt"
checkpoint_llm = "peulsilva/container-source-sroie-checkpoint"

In [23]:
for batch_idx in range(len(minibatches[0:10])):
    loss = 0
    total_points = 0
    for document_idx in tqdm(minibatches[batch_idx]):
        means = {}
        covs = {}
        X_p = 0
        s_X = 0
        n_p = 0

        size = source_df[document_idx]['labels'].shape[1]
        input = source_df[document_idx]

        out = model(
            input_ids=input['input_ids'].reshape(1,-1), 
            bbox= input['bbox'].reshape([1, 512, 4]),
            attention_mask=input['attention_mask'].reshape(1,-1), 
            token_type_ids=input['token_type_ids'].reshape(1,-1),
        )

        for i in range(size):
            if source_df[document_idx]['labels'][:,i] == -100:
                continue
            mu_i = f_mean(out.last_hidden_state[:,i,:])
            cov_i = f_cov(out.last_hidden_state[:,i, :])

            means[i] = mu_i
            covs[i] = cov_i

        del out

        for i in range(size):
            if source_df[document_idx]['labels'][:,i] == -100:
                continue


            total_points+= 1
            for j in range(i+1, size):
                if source_df[document_idx]['labels'][:,j] == -100:
                    continue

                # mu_i = f_mean(out.last_hidden_state[:,i,:])
                # cov_i = f_cov(out.last_hidden_state[:,i, :])
                # mu_j = f_mean(out.last_hidden_state[:,j,:])
                # cov_j = f_cov(out.last_hidden_state[:,j, :])
                mu_i = means[i]
                cov_i = covs[i]
                mu_j = means[j]
                cov_j = covs[j]
                d_ij = 1/2 * (KL_div(mu_i, mu_j, cov_i, cov_j) + KL_div(mu_j, mu_i, cov_j, cov_i))

                if source_df[document_idx]['labels'][:,i] == source_df[document_idx]['labels'][:,j]:
                    X_p += torch.exp(-d_ij)
                    n_p += 1

                s_X += torch.exp(-d_ij)

        del means
        del covs

        loss -= torch.log(X_p / n_p / s_X)
        

    loss = loss/total_points
    clear_output()
    print(f"batch {batch_idx}")
    print(f"loss: {loss.item()}")

    loss.backward()

    optimizer.step()

    optimizer.zero_grad()

model.push_to_hub(checkpoint_llm)
torch.save(f_mean.state_dict(), checkpoint_f_mean)
torch.save(f_cov.state_dict(), checkpoint_f_cov)


batch 9
loss: 0.001353245577774942


model.safetensors:   0%|          | 0.00/451M [00:00<?, ?B/s]

In [24]:
loss

tensor(0.0014, device='cuda:0', grad_fn=<DivBackward0>)

## Few shot learning on support domain

In [25]:
support_dataset['train'].features['ner_tags'].feature.names

['O',
 'B-HEADER',
 'I-HEADER',
 'B-QUESTION',
 'I-QUESTION',
 'B-ANSWER',
 'I-ANSWER']

In [26]:
n_shots = 2
valid_labels_funsd = {
    0: 0,
    1:1,
    2:1,
    3:2,
    4:2,
    5:3,
    6:3
}
support_df = ImageLayoutDataset(
    support_dataset['train'],
    tokenizer,
    valid_labels_keymap=valid_labels_funsd
)

  0%|          | 0/149 [00:00<?, ?it/s]

100%|██████████| 149/149 [00:01<00:00, 112.24it/s]


In [27]:
input_dim = 768
output_dim = 64

f_mean = F_mean(
    input_dim,
    output_dim
)

f_cov = F_cov(
    input_dim,
    output_dim
)

f_mean.load_state_dict(torch.load(checkpoint_f_mean))
f_cov.load_state_dict(torch.load(checkpoint_f_cov))

model = AutoModel.from_pretrained(
    checkpoint_llm,
    cache_dir = '/Data/pedro.silva/'
).to(device)

model.safetensors:   0%|          | 0.00/451M [00:00<?, ?B/s]

In [28]:
params = list(f_mean.parameters()) + list(f_cov.parameters()) + list(model.parameters())
optimizer_support = torch.optim.Adam(
    params,
    lr = 1e-4
)

In [29]:
loss_prev = torch.inf
loss_ft = 1e100
loss_ft < loss_prev

True

In [30]:
loss_prev = torch.inf
loss_ft = 1e100
epoch = 0

while loss_ft < loss_prev:
    epoch+=1

    loss_prev = loss_ft
    loss_ft = 0
    total_points = 0
    for document_idx in tqdm(range(len(support_df[0:n_shots]))):
        means = {}
        covs = {}
        X_p = 0
        s_X = 0
        n_p = 0

        size = support_df[document_idx]['labels'].shape[1]
        input = support_df[document_idx]

        out = model(
            input_ids=input['input_ids'].reshape(1,-1), 
            bbox= input['bbox'].reshape([1, 512, 4]),
            attention_mask=input['attention_mask'].reshape(1,-1), 
            token_type_ids=input['token_type_ids'].reshape(1,-1),
        )

        for i in range(size):
            if support_df[document_idx]['labels'][:,i] == -100:
                continue
            mu_i = f_mean(out.last_hidden_state[:,i,:])
            cov_i = f_cov(out.last_hidden_state[:,i, :])

            means[i] = mu_i
            covs[i] = cov_i

        del out

        for i in range(size):
            if support_df[document_idx]['labels'][:,i] == -100:
                continue

            total_points+= 1
            for j in range(i+1, size):
                if support_df[document_idx]['labels'][:,j] == -100:
                    continue

                # mu_i = f_mean(out.last_hidden_state[:,i,:])
                # cov_i = f_cov(out.last_hidden_state[:,i, :])
                # mu_j = f_mean(out.last_hidden_state[:,j,:])
                # cov_j = f_cov(out.last_hidden_state[:,j, :])
                mu_i = means[i]
                cov_i = covs[i]
                mu_j = means[j]
                cov_j = covs[j]
                d_ij = 1/2 * (KL_div(mu_i, mu_j, cov_i, cov_j) + KL_div(mu_j, mu_i, cov_j, cov_i))

                if support_df[document_idx]['labels'][:,i] == support_df[document_idx]['labels'][:,j]:
                    X_p += torch.exp(-d_ij)
                    n_p += 1

                s_X += torch.exp(-d_ij)

        del means
        del covs

        loss_ft -= torch.log(X_p / n_p / s_X)
        

    loss_ft = loss_ft/total_points
    clear_output()
    print(f"epoch {epoch}")
    print(f"loss: {loss_ft.item()}")

    loss_ft.backward()

    optimizer_support.step()

    optimizer_support.zero_grad()

model.push_to_hub(f"peulsilva/container-source-sroie-funsd")
torch.save(f_mean.state_dict(), "/Data/pedro.silva/f_mean_sroie-2.pt")
torch.save(f_cov.state_dict(), "/Data/pedro.silva/f_cov_sroie-2.pt")


epoch 2
loss: 0.0005063646822236478


 50%|█████     | 1/2 [00:17<00:17, 17.20s/it]

## Evaluation

In [None]:
with torch.no_grad():
    for document_idx in tqdm(range(len(support_df[0:n_shots]))):
        X_p = 0
        s_X = 0
        n_p = 0

        size = support_df[document_idx]['labels'].shape[1]

        out = model(
            input_ids=input['input_ids'].reshape(1,-1), 
            bbox= input['bbox'].reshape([1, 512, 4]),
            attention_mask=input['attention_mask'].reshape(1,-1), 
            token_type_ids=input['token_type_ids'].reshape(1,-1),
        )

        for i in range(size):
            if support_df[document_idx]['labels'][:,i] == -100:
                continue

            h_i = out.last_hidden_state[:,i,:]


for document_idx in tqdm((range(len(support_df[n_shots:])))):
    