# Multi-label text classification using BERT

In [None]:
!nvidia-smi

In [None]:
# !pip install transformers

## Imports

In [21]:
import os
from typing import List
import json
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shutil
import sys
import logging 

logging.basicConfig(
     level=logging.INFO, 
     format= '[%(asctime)s|%(levelname)s|%(module)s.py:%(lineno)s] %(message)s',
     datefmt='%H:%M:%S'
 )
import tqdm.notebook as tq
from tqdm import tqdm
# Create new `pandas` methods which use `tqdm` progress
# (can use tqdm_gui, optional kwargs, etc.)
tqdm.pandas()
from collections import defaultdict

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import confusion_matrix, classification_report
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel, AdamW

from defi_textmine_2025.data import load_test_raw_data
from defi_textmine_2025.data import TARGET_COL, INTERIM_DIR, MODELS_DIR, submission_path

BASE_CHECKPOINT = "bert-base-uncased" #"bert-base-multilingual-uncased"
TASK_NAME = "multilabel_tagged_text"

entity_classes = {'TERRORIST_OR_CRIMINAL', 'LASTNAME', 'LENGTH', 'NATURAL_CAUSES_DEATH', 'COLOR', 'STRIKE', 'DRUG_OPERATION', 'HEIGHT', 'INTERGOVERNMENTAL_ORGANISATION', 'TRAFFICKING', 'NON_MILITARY_GOVERNMENT_ORGANISATION', 'TIME_MIN', 'DEMONSTRATION', 'TIME_EXACT', 'FIRE', 'QUANTITY_MIN', 'MATERIEL', 'GATHERING', 'PLACE', 'CRIMINAL_ARREST', 'CBRN_EVENT', 'ECONOMICAL_CRISIS', 'ACCIDENT', 'LONGITUDE', 'BOMBING', 'MATERIAL_REFERENCE', 'WIDTH', 'FIRSTNAME', 'MILITARY_ORGANISATION', 'CIVILIAN', 'QUANTITY_MAX', 'CATEGORY', 'POLITICAL_VIOLENCE', 'EPIDEMIC', 'TIME_MAX', 'TIME_FUZZY', 'NATURAL_EVENT', 'SUICIDE', 'CIVIL_WAR_OUTBREAK', 'POLLUTION', 'ILLEGAL_CIVIL_DEMONSTRATION', 'NATIONALITY', 'GROUP_OF_INDIVIDUALS', 'QUANTITY_FUZZY', 'RIOT', 'WEIGHT', 'THEFT', 'MILITARY', 'NON_GOVERNMENTAL_ORGANISATION', 'LATITUDE', 'COUP_D_ETAT', 'ELECTION', 'HOOLIGANISM_TROUBLEMAKING', 'QUANTITY_EXACT', 'AGITATING_TROUBLE_MAKING'}
categories_to_check = ['END_DATE', 'GENDER_MALE', 'WEIGHS', 'DIED_IN', 'HAS_FAMILY_RELATIONSHIP', 'IS_DEAD_ON', 'IS_IN_CONTACT_WITH', 'HAS_CATEGORY', 'HAS_CONTROL_OVER', 'IS_BORN_IN', 'IS_OF_SIZE', 'HAS_LATITUDE', 'IS_PART_OF', 'IS_OF_NATIONALITY', 'IS_COOPERATING_WITH', 'DEATHS_NUMBER', 'HAS_FOR_HEIGHT', 'INITIATED', 'WAS_DISSOLVED_IN', 'HAS_COLOR', 'CREATED', 'IS_LOCATED_IN', 'WAS_CREATED_IN', 'IS_AT_ODDS_WITH', 'HAS_CONSEQUENCE', 'HAS_FOR_LENGTH', 'INJURED_NUMBER', 'START_DATE', 'STARTED_IN', 'GENDER_FEMALE', 'HAS_LONGITUDE', 'RESIDES_IN', 'HAS_FOR_WIDTH', 'IS_BORN_ON', 'HAS_QUANTITY', 'OPERATES_IN', 'IS_REGISTERED_AS']

# Fit the binarizer once on the full relation vocabulary so that the one-hot
# column set and order (see the logged `classes_`) are fixed for the whole notebook.
mlb = MultiLabelBinarizer()
mlb.fit([categories_to_check])
logging.info(f"{mlb.classes_=}")

# Input: tagged-text dataset generated upstream; it must already exist.
generated_data_dir_path = os.path.join(INTERIM_DIR, "multilabel_tagged_text_dataset")
assert os.path.exists(generated_data_dir_path), f"Missing generated dataset: {generated_data_dir_path}"

# Output directories of the one-hot encoded train/validation splits.
preprocessed_data_dir = os.path.join(INTERIM_DIR, "one_hot_multilabel_tagged_text_dataset")
train_preprocessed_data_dir_path = os.path.join(preprocessed_data_dir,"train")
val_preprocessed_data_dir_path = os.path.join(preprocessed_data_dir,"val")

# Location of the fine-tuned model checkpoint.
model_dir_path = os.path.join(MODELS_DIR, f"finetuned-{BASE_CHECKPOINT}")
model_dict_state_path = os.path.join(model_dir_path,"MLTC_model_state.bin")

# Create the output directories in pure Python: unlike `! mkdir -p {path}` this does not
# depend on a POSIX shell and is safe for paths containing spaces or shell metacharacters.
for _dir_path in (train_preprocessed_data_dir_path, val_preprocessed_data_dir_path, model_dir_path):
    os.makedirs(_dir_path, exist_ok=True)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

[07:48:03|INFO|3617865939.py:41] mlb.classes_=array(['CREATED', 'DEATHS_NUMBER', 'DIED_IN', 'END_DATE', 'GENDER_FEMALE',
       'GENDER_MALE', 'HAS_CATEGORY', 'HAS_COLOR', 'HAS_CONSEQUENCE',
       'HAS_CONTROL_OVER', 'HAS_FAMILY_RELATIONSHIP', 'HAS_FOR_HEIGHT',
       'HAS_FOR_LENGTH', 'HAS_FOR_WIDTH', 'HAS_LATITUDE', 'HAS_LONGITUDE',
       'HAS_QUANTITY', 'INITIATED', 'INJURED_NUMBER', 'IS_AT_ODDS_WITH',
       'IS_BORN_IN', 'IS_BORN_ON', 'IS_COOPERATING_WITH', 'IS_DEAD_ON',
       'IS_IN_CONTACT_WITH', 'IS_LOCATED_IN', 'IS_OF_NATIONALITY',
       'IS_OF_SIZE', 'IS_PART_OF', 'IS_REGISTERED_AS', 'OPERATES_IN',
       'RESIDES_IN', 'STARTED_IN', 'START_DATE', 'WAS_CREATED_IN',
       'WAS_DISSOLVED_IN', 'WEIGHS'], dtype=object)


device(type='cuda')

In [6]:
def load_csv(dir_or_file_path: str, index_col=None, sep=',') -> pd.DataFrame:
    """Load one CSV file, or every ``*.csv`` file of a directory, into one DataFrame.

    Args:
        dir_or_file_path: path of a ``.csv`` file, or of a directory containing ``.csv`` files.
        index_col: forwarded to ``pd.read_csv``.
        sep: field separator, forwarded to ``pd.read_csv``.

    Returns:
        The row-wise concatenation of all loaded files, with a fresh 0..n-1 index.

    Raises:
        AssertionError: if a non-directory path does not end with ``.csv``,
            or if the directory contains no ``.csv`` file.
    """
    if os.path.isdir(dir_or_file_path):
        # glob yields files in an arbitrary, filesystem-dependent order; sorting
        # makes the row order of the concatenated frame reproducible.
        all_files = sorted(glob.glob(os.path.join(dir_or_file_path, "*.csv")))
    else:
        assert dir_or_file_path.endswith(".csv"), f"Not a CSV file: {dir_or_file_path}"
        all_files = [dir_or_file_path]
    assert len(all_files) > 0, f"No CSV file found in {dir_or_file_path}"
    frames = [
        pd.read_csv(filename, index_col=index_col, header=0, sep=sep)
        for filename in all_files
    ]
    return pd.concat(frames, axis=0, ignore_index=True)

def process_data(data: pd.DataFrame) -> pd.DataFrame:
    """Append one-hot label columns (one per class of ``mlb``) to ``data``.

    The labels are read from the ``TARGET_COL`` column, which must hold iterables
    of label names; the original columns, including ``TARGET_COL``, are kept.
    """
    one_hot_labels = pd.DataFrame(
        mlb.transform(data[TARGET_COL]),
        columns=mlb.classes_,
        index=data.index,
    )
    return pd.concat([data, one_hot_labels], axis=1)


def format_relations_str_to_list(labels_as_str: str) -> List[str]:
    """Parse the textual repr of a set/list of labels into a list of strings.

    Accepts strings such as ``"{'A', 'B'}"`` or ``"['A', 'B']"``: braces are turned
    into brackets and single quotes into double quotes so that the text is valid JSON.
    The repr of an empty Python set, ``"set()"``, is mapped to an empty list.

    Args:
        labels_as_str: serialized collection of label names.

    Returns:
        The labels, in the order in which they appear in the string.

    Raises:
        json.JSONDecodeError: if the normalized text is not a valid JSON array.
    """
    normalized = labels_as_str.strip()
    # str(set()) is "set()", which contains no braces and is not valid JSON.
    if normalized == "set()":
        return []
    return json.loads(
        normalized.replace("{", "[").replace("}", "]").replace("'", '"')
    )


def process_csv_to_csv(in_dir_or_file_path: str, out_dir_path: str) -> None:
    """One-hot encode each input CSV and write it, tab-separated, to ``out_dir_path``.

    ``in_dir_or_file_path`` is either a single ``.csv`` file or a directory whose
    ``*.csv`` files are all processed. Each output file keeps the base name of its input.
    """
    if not os.path.isdir(in_dir_or_file_path):
        assert in_dir_or_file_path.endswith(".csv")
        csv_paths = [in_dir_or_file_path]
    else:
        csv_paths = glob.glob(os.path.join(in_dir_or_file_path , "*.csv"))

    progress = tqdm(csv_paths)
    for csv_path in progress:
        progress.set_description(csv_path)
        out_path = os.path.join(out_dir_path, os.path.basename(csv_path))
        raw = load_csv(csv_path)
        # The label column is stored as text; parse it into lists before binarizing.
        parsed = raw.assign(**{TARGET_COL: raw[TARGET_COL].apply(format_relations_str_to_list)})
        process_data(parsed).to_csv(out_path, sep="\t")

## Preprocess and save data

- load generated data
- convert to dataframe
- convert categories into one-hot labels
- save into a tsv file

In [None]:
# Convert every generated split into its one-hot encoded counterpart
# (validation first, then train).
for split_name, split_out_dir in (
    ("val", val_preprocessed_data_dir_path),
    ("train", train_preprocessed_data_dir_path),
):
    process_csv_to_csv(os.path.join(generated_data_dir_path, split_name), split_out_dir)

## Load preprocessed data

In [9]:
# Work on fixed-size random subsamples to keep training time manageable.
# The seed is fixed so that the same rows are drawn after a kernel restart;
# otherwise the reloaded checkpoint would be evaluated on a different validation sample.
SAMPLING_SEED = 42
df_valid = load_csv(val_preprocessed_data_dir_path, index_col=0, sep='\t').sample(1500, random_state=SAMPLING_SEED)
df_train = load_csv(train_preprocessed_data_dir_path, index_col=0, sep='\t').sample(7000, random_state=SAMPLING_SEED)
logging.info(f"Train: {df_train.shape}, Valid: {df_valid.shape}")
df_valid.head()

[07:45:42|INFO|2820765577.py:3] Train: (7000, 42), Valid: (1500, 42)


Unnamed: 0,text_index,e1,e2,text,relations,CREATED,DEATHS_NUMBER,DIED_IN,END_DATE,GENDER_FEMALE,...,IS_OF_SIZE,IS_PART_OF,IS_REGISTERED_AS,OPERATES_IN,RESIDES_IN,STARTED_IN,START_DATE,WAS_CREATED_IN,WAS_DISSOLVED_IN,WEIGHS
391,41656,9,6,"Le 20 septembre 2017 en Colombie, la brigade d...",['IS_PART_OF'],0,0,0,0,0,...,0,1,0,0,0,0,0,0,0,0
1759,11823,1,2,"Le 12 janvier 2022, une vingtaine d’<e1><GROUP...","['IS_LOCATED_IN', 'HAS_CONTROL_OVER']",0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3458,3608,1,4,"En Australie, des manifestants vêtus de tee-sh...",['HAS_CONSEQUENCE'],0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4682,3692,16,5,"Le 6 décembre 2020, une attaque armée a eu lie...",['IS_IN_CONTACT_WITH'],0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5290,164,17,11,Trente personnes ont été carbonisées dans un v...,['OPERATES_IN'],0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,0


## Create the tokenized datasets for model input

In [12]:
# Tokenization hyperparameters
# MAX_LEN is the fixed sequence length: longer texts are truncated and shorter ones padded
# (see the truncation/padding arguments used when encoding).
MAX_LEN = 256  # TODO: increase
tokenizer = BertTokenizer.from_pretrained(BASE_CHECKPOINT)
# Disabled experiment: register the entity markers and entity-class tags as special tokens
# so that the tokenizer keeps each of them as a single token.
# NOTE(review): if this is re-enabled, the model's embedding matrix presumably has to be
# resized to the new vocabulary size before training - confirm against the transformers docs.
# task_special_tokens = ["<e1>", "</e1>", "<e2>", "</e2>"] + [
#     f"<{entity_class}>" for entity_class in entity_classes
# ]
# # add special tokens to the tokenizer
# num_added_tokens = tokenizer.add_tokens(task_special_tokens, special_tokens=True)
# num_added_tokens

In [None]:
# Test the tokenizer
test_text = "La <e2><NON_MILITARY_GOVERNMENT_ORGANISATION>police</e2> tchèque a <e2><NON_MILITARY_GOVERNMENT_ORGANISATION>mis la main</e2> sur le couple responsable d'un trafic d'œuvres d'art. Il s'agit de <e1><TERRORIST_OR_CRIMINAL>Patel</e1> et Mirna Maroski. Une <e2><NON_MILITARY_GOVERNMENT_ORGANISATION>perquisition</e2> à leur domicile a permis de retrouver une centaine de tableaux d'artistes européens. Il y avait également des pots en céramique et en porcelaine d'origine chinoise, ainsi que plusieurs faux documents de voyage. Les époux Maroski ont été conduits au poste de <e2><NON_MILITARY_GOVERNMENT_ORGANISATION>police</e2> dans un véhicule blindé. Mirna Maroski s'est évanouie une fois arrivée au poste. Elle a été amenée en ambulance au CHU de Motol où elle a été soignée. Monsieur Sergueï Alekseï, le directeur de l'hôpital, a demandé à ses collaborateurs d'être vigilants et de ne pas se laisser corrompre par la criminelle."
# generate encodings
encodings = tokenizer.encode_plus(test_text, 
                                  add_special_tokens = True,
                                  max_length = MAX_LEN,
                                  truncation = True,
                                  padding = "max_length", 
                                  return_attention_mask = True, 
                                  return_tensors = "pt")
# we get a dictionary with three keys (see: https://huggingface.co/transformers/glossary.html) 
encodings

In [None]:
tokenizer.batch_decode(encodings['input_ids'])

In [13]:
class CustomDataset(torch.utils.data.Dataset):
    """Torch dataset pairing tokenized texts with their multi-hot label vectors.

    The model input is read from the ``text`` column of ``df`` and the labels from
    the ``target_list`` columns. Every item is tokenized on access and padded or
    truncated to ``max_len`` tokens.
    """

    def __init__(self, df, tokenizer, max_len, target_list):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Texts and label rows are materialized once so that __getitem__ only indexes.
        self.title = list(df['text'])
        self.targets = self.df[target_list].values

    def __len__(self):
        return len(self.title)

    def __getitem__(self, index):
        # Collapse every run of whitespace (including newlines) into a single space.
        normalized_text = " ".join(str(self.title[index]).split())
        encoded = self.tokenizer.encode_plus(
            normalized_text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # The tokenizer returns tensors with a leading batch axis; flatten them to 1-D.
        sample = {
            key: encoded[key].flatten()
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
        }
        sample['targets'] = torch.FloatTensor(self.targets[index])
        sample['title'] = normalized_text
        return sample

In [None]:
df_valid[mlb.classes_].sum().sort_values(ascending=False)

In [None]:
df_train[mlb.classes_].sum().sort_values(ascending=False)

In [15]:
# Restrict the task to the 7 relation types with the most positive examples in the
# training subsample; the remaining relation types are not predicted by the model.
most_common_categories = df_train[mlb.classes_].sum().sort_values(ascending=False).index[:7]
logging.info(most_common_categories)
# Alternative kept for reference: train on every relation type.
# target_list = mlb.classes_.tolist()
# target_list is a pandas Index here (not a list); it fixes the order of the model outputs.
target_list = most_common_categories
logging.info(target_list)

[07:46:46|INFO|3645693013.py:2] Index(['IS_LOCATED_IN', 'HAS_CONTROL_OVER', 'IS_IN_CONTACT_WITH',
       'OPERATES_IN', 'STARTED_IN', 'IS_AT_ODDS_WITH', 'IS_PART_OF'],
      dtype='object')
[07:46:46|INFO|3645693013.py:4] Index(['IS_LOCATED_IN', 'HAS_CONTROL_OVER', 'IS_IN_CONTACT_WITH',
       'OPERATES_IN', 'STARTED_IN', 'IS_AT_ODDS_WITH', 'IS_PART_OF'],
      dtype='object')


In [None]:
df_train[most_common_categories].isnull().all(axis=1)

In [16]:
# Build the datasets from the non-label columns plus only the selected label columns.
# Both frames are filtered with the column set of df_train (the two frames are expected
# to share the same columns).
train_dataset = CustomDataset(pd.concat([df_train[df_train.columns.difference(mlb.classes_)], df_train[most_common_categories]], axis=1), tokenizer, MAX_LEN, target_list)
valid_dataset = CustomDataset(pd.concat([df_valid[df_train.columns.difference(mlb.classes_)], df_valid[most_common_categories]], axis=1), tokenizer, MAX_LEN, target_list)

In [17]:
# testing the dataset
next(iter(train_dataset))

{'input_ids': tensor([  101,  4372,  4066,  4630,  1040,  1521, 16655, 10301,  2139,  3008,
          1040,  1521,  3449, 23047,  1037,  1048,  1521, 12431,  2139,  7842,
          1026,  1041,  2475,  1028,  1026,  6831,  1028,  6039,  2063,  1026,
          1013,  1041,  2475,  1028,  8740, 11779,  1010, 21380,  1026,  1041,
          2487,  1028,  1026,  6831,  1028,  8254, 13775, 13477,  1026,  1013,
          1041,  2487,  1028,  1037,  3802,  2063,  6778,  2063,  1040,  1521,
         16655, 12943,  8303,  3258,  6887,  7274,  7413,  3393,  2321, 14736,
          2268,  1012,  4895, 26574,  3417,  4630, 16655,  6187,  3995,  9307,
          1037,  2173,  4895,  2522, 10421,  4887,  7505,  2365,  2522,  2226,
          3802,  1026,  1041,  2487,  1028,  1026,  6831,  1028, 11320,  2072,
          1026,  1013,  1041,  2487,  1028,  1037, 26927,  2015,  2365, 17266,
          1012,  6335,  1055,  1521,  9765,  4372, 11263,  2072, 16655,  1042,
         10054, 19804,  2229, 20704, 21

## Create data loaders

In [23]:
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32

# Data loaders
# Training batches are reshuffled at every epoch; num_workers=0 loads data in the main process.
train_data_loader = torch.utils.data.DataLoader(train_dataset, 
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# Validation batches keep the dataset order.
val_data_loader = torch.utils.data.DataLoader(valid_dataset, 
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

## Prepare the model to be trained

In [19]:
class BERTClass(torch.nn.Module):
    """BERT encoder followed by dropout and a linear multi-label classification head.

    The forward pass returns one raw logit per label (no sigmoid applied), which is
    the input expected by ``BCEWithLogitsLoss``.

    Args:
        num_labels: number of output labels. When omitted, it defaults to
            ``len(target_list)`` (the notebook-level list of target columns),
            which is the size used by the original implementation.
    """

    def __init__(self, num_labels=None):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(BASE_CHECKPOINT, return_dict=True)
        self.dropout = torch.nn.Dropout(0.3)
        if num_labels is None:
            num_labels = len(target_list)
        # Read the encoder width from its config instead of hard-coding 768, so that
        # switching BASE_CHECKPOINT to a model of another size keeps the head consistent.
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, num_labels)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return the logits computed from the pooled BERT output of the batch."""
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)

# Instantiate the classifier on top of the pretrained BASE_CHECKPOINT encoder.
model = BERTClass()

# # Freezing BERT layers: (tested, weaker convergence)
# for param in model.bert_model.parameters():
#     param.requires_grad = False

# Move the parameters to the selected device; as the last expression of the cell,
# this also displays the module structure.
model.to(device)



BERTClass(
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [27]:
def loss_fn(outputs, targets):
    """Mean binary cross-entropy between raw logits and multi-hot targets.

    The sigmoid is fused into the loss computation, which is more numerically
    stable (log-sum-exp trick) than applying a Sigmoid followed by a plain BCE loss.
    ``outputs`` must therefore be logits, not probabilities.
    """
    return torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)

In [28]:
# Define the optimizer.
# `transformers.AdamW` is deprecated and no longer shipped by recent transformers
# releases, so the PyTorch implementation is used instead. `eps` and `weight_decay`
# are set explicitly to the defaults of the transformers implementation (1e-6 and 0.0)
# so that the update rule matches what this notebook was trained with.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)



## Function to train the model

In [None]:
def train_model(training_loader, model, optimizer):
    """Train ``model`` for one epoch over ``training_loader``.

    Returns:
        A tuple ``(model, accuracy, mean_loss)`` where ``accuracy`` is the fraction of
        individual (sample, label) decisions that match the targets after thresholding
        the sigmoid outputs at 0.5, and ``mean_loss`` is the average of the batch losses.
    """
    batch_losses = []
    n_correct = 0
    n_decisions = 0

    # Training mode: dropout is active.
    model.train()

    progress = tq.tqdm(enumerate(training_loader), total=len(training_loader),
                       leave=True, colour='steelblue')
    for _, batch in progress:
        ids = batch['input_ids'].to(device, dtype=torch.long)
        mask = batch['attention_mask'].to(device, dtype=torch.long)
        token_type_ids = batch['token_type_ids'].to(device, dtype=torch.long)
        targets = batch['targets'].to(device, dtype=torch.float)

        # Forward pass: one logit per (sample, label).
        logits = model(ids, mask, token_type_ids)
        loss = loss_fn(logits, targets)
        batch_losses.append(loss.item())

        # Running accuracy, computed on detached copies so the graph is left untouched.
        decisions = torch.sigmoid(logits).cpu().detach().numpy().round()
        expected = targets.cpu().detach().numpy()
        n_correct += np.sum(decisions == expected)
        n_decisions += expected.size

        # Backward pass with gradient-norm clipping, then the parameter update.
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    return model, float(n_correct)/n_decisions, np.mean(batch_losses)

## Function to evaluate the model

In [25]:
def eval_model(validation_loader, model, optimizer):
    """Evaluate ``model`` over ``validation_loader`` without updating its weights.

    ``optimizer`` is accepted for signature symmetry with ``train_model`` but is not used.

    Returns:
        A tuple ``(accuracy, mean_loss)`` where ``accuracy`` is the fraction of individual
        (sample, label) decisions that match the targets after thresholding the sigmoid
        outputs at 0.5, and ``mean_loss`` is the average of the batch losses.
    """
    batch_losses = []
    n_correct = 0
    n_decisions = 0

    # Evaluation mode: dropout is disabled.
    model.eval()

    with torch.no_grad():
        for batch in validation_loader:
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)
            token_type_ids = batch['token_type_ids'].to(device, dtype=torch.long)
            targets = batch['targets'].to(device, dtype=torch.float)

            logits = model(ids, mask, token_type_ids)
            batch_losses.append(loss_fn(logits, targets).item())

            # The loss applies the sigmoid internally; apply it explicitly for the decisions.
            decisions = torch.sigmoid(logits).cpu().detach().numpy().round()
            expected = targets.cpu().detach().numpy()
            n_correct += np.sum(decisions == expected)
            n_decisions += expected.size

    return float(n_correct)/n_decisions, np.mean(batch_losses)

## Model Training

In [None]:
EPOCHS = 10
# NOTE(review): LEARNING_RATE and THRESHOLD are not referenced by the visible code -
# the optimizer is built with its own literal learning rate and the predictions are
# thresholded with round(). Confirm whether they are still needed.
LEARNING_RATE = 1e-05
THRESHOLD = 0.5 # threshold for the sigmoid

history = defaultdict(list)
best_accuracy = 0
# Refuse to overwrite an existing checkpoint (the message is now an f-string so that
# the path is actually interpolated).
assert not os.path.exists(model_dict_state_path), f"The trained model is already serialized at {model_dict_state_path}"

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    val_acc, val_loss = eval_model(val_data_loader, model, optimizer)

    print(f'train_loss={train_loss:.4f}, val_loss={val_loss:.4f} train_acc={train_acc:.4f}, val_acc={val_acc:.4f}')

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)
    # Keep only the weights of the epoch with the best validation accuracy so far.
    if val_acc > best_accuracy:
        torch.save(model.state_dict(), model_dict_state_path)
        best_accuracy = val_acc

In [None]:
# Plot the per-epoch accuracy and loss curves recorded in `history` on a single axis.
plt.rcParams["figure.figsize"] = (10,7)
plt.plot(history['train_acc'], label='train accuracy')
plt.plot(history['val_acc'], label='validation accuracy')
plt.plot(history['train_loss'], label='train loss')
plt.plot(history['val_loss'], label='validation loss')
plt.title('Training history')
plt.ylabel('Accuracy / loss')
plt.xlabel('Epoch')
plt.legend()
# The y-axis is limited to [0, 1]: loss values above 1 fall outside the visible range.
plt.ylim([0, 1])
plt.grid()

## Evaluation of the model

In [22]:
# Reload the best checkpoint saved during training.
# `map_location=device` lets the load succeed even when the checkpoint was written
# from a GPU session and the current session runs on another device (e.g. CPU only).
model = BERTClass()
model.load_state_dict(torch.load(model_dict_state_path, map_location=device))
model = model.to(device)



In [30]:
# Re-evaluate the reloaded best checkpoint.
# NOTE: this runs on `val_data_loader`, i.e. the validation subsample, not on test data.
val_acc, val_loss = eval_model(val_data_loader, model, optimizer)

In [31]:
# Accuracy of the best checkpoint on the validation data (the data that was used to
# select that checkpoint during training), so it is not a measure of generalization
# to unseen data.
# NOTE(review): eval_model counts matches per (sample, label) cell, so with mostly-zero
# label vectors this value can be high even for a weak model; check per-label metrics too.
val_acc

0.9515238095238095

## Prepare submission

In [32]:
# Load the generated test entity pairs. The target column is kept here (it appears empty
# in the displayed output) and is dropped later, when the test dataset is built.
df_test = load_csv(os.path.join(generated_data_dir_path, "test")) #.drop(TARGET_COL, axis=1)
df_test

Unnamed: 0,text_index,e1,e2,text,relations
0,51344,1,0,Un <e2><FIRE>incendie</e2> a eu lieu hier mati...,
1,51344,0,1,Un <e1><FIRE>incendie</e1> a eu lieu hier mati...,
2,51344,2,0,Un <e2><FIRE>incendie</e2> a eu lieu hier mati...,
3,51344,0,2,Un <e1><FIRE>incendie</e1> a eu lieu hier mati...,
4,51344,2,1,Un incendie a eu lieu hier matin au <e2><PLACE...,
...,...,...,...,...,...
174575,4998,19,22,Un braquage de banque a eu lieu à New York hie...,
174576,4998,22,20,Un braquage de banque a eu lieu à New York hie...,
174577,4998,20,22,Un braquage de banque a eu lieu à New York hie...,
174578,4998,22,21,Un braquage de banque a eu lieu à New York hie...,


In [None]:
# df_test.head().drop(TARGET_COL, axis=1).assign(**{cat: [0]*df_test.head().shape[0] for cat in target_list})

In [33]:
# CustomDataset always reads the `target_list` columns, so the test frame gets all-zero
# placeholder label columns. These are dummy values, not gold labels.
test_dataset = CustomDataset(df_test.drop(TARGET_COL, axis=1).assign(**{cat: [0]*df_test.shape[0] for cat in target_list}), tokenizer, MAX_LEN, target_list)

In [34]:
# Inference only: a larger batch size than for training is used.
TEST_BATCH_SIZE = 512

# shuffle=False keeps the predictions aligned with the row order of df_test.
test_data_loader = torch.utils.data.DataLoader(test_dataset, 
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

In [35]:
def get_predictions(model, data_loader):
    """Run ``model`` over ``data_loader`` and collect its predictions.

    Args:
        model: classifier called as ``model(input_ids, attention_mask, token_type_ids)``
            and returning one logit per label.
        data_loader: loader yielding batches with the keys ``title``, ``input_ids``,
            ``attention_mask``, ``token_type_ids`` and ``targets``.

    Returns:
        A tuple ``(titles, predictions, prediction_probs, target_values)``:
          titles - list of the input texts, in loader order
          predictions - CPU tensor of 0/1 decisions (rounded probabilities), one row per sample
          prediction_probs - CPU tensor of sigmoid probabilities, same shape as predictions
          target_values - CPU float tensor of the targets carried by the batches
    """
    model = model.eval()

    titles = []
    batch_predictions = []
    batch_probs = []
    batch_targets = []

    with torch.no_grad():
        for data in tqdm(data_loader):
            ids = data["input_ids"].to(device, dtype = torch.long)
            mask = data["attention_mask"].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)

            outputs = model(ids, mask, token_type_ids)
            # The sigmoid is part of the training loss, so apply it explicitly here.
            probs = torch.sigmoid(outputs).detach().cpu()

            titles.extend(data["title"])
            # Decision threshold of 0.5, implemented by rounding the probabilities.
            batch_predictions.append(probs.round())
            batch_probs.append(probs)
            # Targets are only collected, so they never need to visit the GPU.
            batch_targets.append(data["targets"].to(dtype = torch.float).detach().cpu())

    # Concatenate whole batches instead of stacking one tensor per sample:
    # same result, without creating one Python object per row.
    predictions = torch.cat(batch_predictions)
    prediction_probs = torch.cat(batch_probs)
    target_values = torch.cat(batch_targets)

    return titles, predictions, prediction_probs, target_values


In [36]:
titles, predictions, prediction_probs, target_values = get_predictions(model, test_data_loader)

100%|██████████| 341/341 [18:51<00:00,  3.32s/it]


In [42]:
df_train.columns

Index(['text_index', 'e1', 'e2', 'text', 'relations', 'CREATED',
       'DEATHS_NUMBER', 'DIED_IN', 'END_DATE', 'GENDER_FEMALE', 'GENDER_MALE',
       'HAS_CATEGORY', 'HAS_COLOR', 'HAS_CONSEQUENCE', 'HAS_CONTROL_OVER',
       'HAS_FAMILY_RELATIONSHIP', 'HAS_FOR_HEIGHT', 'HAS_FOR_LENGTH',
       'HAS_FOR_WIDTH', 'HAS_LATITUDE', 'HAS_LONGITUDE', 'HAS_QUANTITY',
       'INITIATED', 'INJURED_NUMBER', 'IS_AT_ODDS_WITH', 'IS_BORN_IN',
       'IS_BORN_ON', 'IS_COOPERATING_WITH', 'IS_DEAD_ON', 'IS_IN_CONTACT_WITH',
       'IS_LOCATED_IN', 'IS_OF_NATIONALITY', 'IS_OF_SIZE', 'IS_PART_OF',
       'IS_REGISTERED_AS', 'OPERATES_IN', 'RESIDES_IN', 'STARTED_IN',
       'START_DATE', 'WAS_CREATED_IN', 'WAS_DISSOLVED_IN', 'WEIGHS'],
      dtype='object')

In [43]:
mlb.classes_

array(['CREATED', 'DEATHS_NUMBER', 'DIED_IN', 'END_DATE', 'GENDER_FEMALE',
       'GENDER_MALE', 'HAS_CATEGORY', 'HAS_COLOR', 'HAS_CONSEQUENCE',
       'HAS_CONTROL_OVER', 'HAS_FAMILY_RELATIONSHIP', 'HAS_FOR_HEIGHT',
       'HAS_FOR_LENGTH', 'HAS_FOR_WIDTH', 'HAS_LATITUDE', 'HAS_LONGITUDE',
       'HAS_QUANTITY', 'INITIATED', 'INJURED_NUMBER', 'IS_AT_ODDS_WITH',
       'IS_BORN_IN', 'IS_BORN_ON', 'IS_COOPERATING_WITH', 'IS_DEAD_ON',
       'IS_IN_CONTACT_WITH', 'IS_LOCATED_IN', 'IS_OF_NATIONALITY',
       'IS_OF_SIZE', 'IS_PART_OF', 'IS_REGISTERED_AS', 'OPERATES_IN',
       'RESIDES_IN', 'STARTED_IN', 'START_DATE', 'WAS_CREATED_IN',
       'WAS_DISSOLVED_IN', 'WEIGHS'], dtype=object)

In [52]:
pd.DataFrame(0, index=df_test.index, columns=list(set(categories_to_check).difference(most_common_categories)))

KeyError: "['HAS_CONTROL_OVER', 'IS_AT_ODDS_WITH', 'IS_IN_CONTACT_WITH', 'IS_LOCATED_IN', 'IS_PART_OF', 'OPERATES_IN', 'STARTED_IN'] not in index"

In [55]:
pd.concat(
        [
            # df_test, 
            pd.DataFrame(predictions.numpy(), columns=most_common_categories, index=df_test.index),
            pd.DataFrame(0, index=df_test.index, columns=list(set(categories_to_check).difference(most_common_categories)))
        ],
        axis=1
)#[mlb.classes_]

Unnamed: 0,IS_LOCATED_IN,HAS_CONTROL_OVER,IS_IN_CONTACT_WITH,OPERATES_IN,STARTED_IN,IS_AT_ODDS_WITH,IS_PART_OF,GENDER_MALE,HAS_FOR_HEIGHT,IS_REGISTERED_AS,...,IS_BORN_ON,WAS_CREATED_IN,HAS_FOR_LENGTH,IS_BORN_IN,HAS_LONGITUDE,END_DATE,WEIGHS,HAS_CATEGORY,IS_OF_SIZE,GENDER_FEMALE
0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
174575,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
174576,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
174577,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
174578,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [72]:

# Turn the prediction matrix back into label tuples, one per test row:
# - the predicted columns cover only the trained categories; every other category
#   is filled with 0 (never predicted),
# - the columns are then reordered to `mlb.classes_`, the order that
#   `mlb.inverse_transform` relies on to map columns back to label names.
ml_labeled_test_df = pd.concat(
    [
        df_test.drop(TARGET_COL, axis=1),
        pd.Series(
            mlb.inverse_transform(
                pd.concat(
                    [
                        # df_test, 
                        pd.DataFrame(predictions.numpy(), columns=most_common_categories, index=df_test.index),
                        pd.DataFrame(0, index=df_test.index, columns=list(set(categories_to_check).difference(most_common_categories)))
                    ],
                    axis=1
                )[mlb.classes_].values
            ),
            name=TARGET_COL,
            index=df_test.index
        )
    ],
    axis=1
)
ml_labeled_test_df

Unnamed: 0,text_index,e1,e2,text,relations
0,51344,1,0,Un <e2><FIRE>incendie</e2> a eu lieu hier mati...,"(IS_LOCATED_IN, STARTED_IN)"
1,51344,0,1,Un <e1><FIRE>incendie</e1> a eu lieu hier mati...,"(IS_LOCATED_IN, STARTED_IN)"
2,51344,2,0,Un <e2><FIRE>incendie</e2> a eu lieu hier mati...,()
3,51344,0,2,Un <e1><FIRE>incendie</e1> a eu lieu hier mati...,()
4,51344,2,1,Un incendie a eu lieu hier matin au <e2><PLACE...,()
...,...,...,...,...,...
174575,4998,19,22,Un braquage de banque a eu lieu à New York hie...,()
174576,4998,22,20,Un braquage de banque a eu lieu à New York hie...,()
174577,4998,20,22,Un braquage de banque a eu lieu à New York hie...,()
174578,4998,22,21,Un braquage de banque a eu lieu à New York hie...,()


In [112]:
# For every text, list all predicted [e1, relation, e2] triples.
# Every relation predicted for an entity pair is emitted: the previous version kept
# only the first relation of each pair (`l[0]`), silently dropping the other labels
# of multi-label predictions. Pairs without any predicted relation contribute nothing.
text_idx_to_relations = {
    text_index: [
        [int(row.e1), relation, int(row.e2)]
        for row in group_df.itertuples(index=False)
        for relation in row.relations
    ]
    for text_index, group_df in tqdm(ml_labeled_test_df.groupby("text_index"))
}

100%|██████████| 400/400 [00:01<00:00, 329.80it/s]


In [113]:
text_idx_to_relations[13]

[[2, 'IS_IN_CONTACT_WITH', 1],
 [1, 'IS_IN_CONTACT_WITH', 2],
 [3, 'IS_IN_CONTACT_WITH', 1],
 [1, 'IS_IN_CONTACT_WITH', 3],
 [3, 'IS_IN_CONTACT_WITH', 2],
 [4, 'IS_IN_CONTACT_WITH', 1],
 [1, 'IS_IN_CONTACT_WITH', 4],
 [4, 'IS_IN_CONTACT_WITH', 2],
 [2, 'IS_IN_CONTACT_WITH', 4],
 [4, 'IS_AT_ODDS_WITH', 3],
 [3, 'IS_AT_ODDS_WITH', 4],
 [5, 'IS_IN_CONTACT_WITH', 1],
 [1, 'IS_IN_CONTACT_WITH', 5],
 [5, 'IS_IN_CONTACT_WITH', 2],
 [5, 'IS_AT_ODDS_WITH', 3],
 [3, 'IS_AT_ODDS_WITH', 5],
 [4, 'IS_PART_OF', 5],
 [6, 'IS_PART_OF', 1],
 [1, 'IS_PART_OF', 6],
 [6, 'IS_IN_CONTACT_WITH', 2],
 [2, 'IS_PART_OF', 6],
 [6, 'IS_PART_OF', 3],
 [3, 'IS_PART_OF', 6],
 [6, 'IS_IN_CONTACT_WITH', 4],
 [4, 'IS_PART_OF', 6],
 [7, 'IS_LOCATED_IN', 0],
 [0, 'IS_LOCATED_IN', 7],
 [7, 'HAS_CONTROL_OVER', 1],
 [1, 'HAS_CONTROL_OVER', 7],
 [7, 'IS_LOCATED_IN', 2],
 [2, 'IS_LOCATED_IN', 7],
 [7, 'HAS_CONTROL_OVER', 3],
 [3, 'IS_LOCATED_IN', 7],
 [7, 'IS_LOCATED_IN', 4],
 [4, 'IS_LOCATED_IN', 7],
 [7, 'OPERATES_IN', 5],


In [116]:
test_index = load_test_raw_data().index
test_index

Index([ 1204,  4909,  2353,  1210, 41948, 41092, 41094, 51395,   194, 41515,
       ...
       41765,  4961,  4969, 51257, 51446, 51452, 51491, 51492, 51495, 51742],
      dtype='int64', name='id', length=400)

In [125]:
# Build the submission frame: one row per test text, in the order of the raw test set.
_test_ids = load_test_raw_data().index
_missing_ids = _test_ids.difference(pd.Index(list(text_idx_to_relations.keys())))
if len(_missing_ids) > 0:
    # `.loc` would raise a KeyError here; such texts get an empty relation list instead.
    logging.warning(f"{len(_missing_ids)} test texts without any candidate entity pair: {list(_missing_ids)}")
submission_df = (
    pd.DataFrame({"id": list(text_idx_to_relations.keys()), TARGET_COL: list(text_idx_to_relations.values())})
    .set_index("id")
    .reindex(_test_ids)
)
# Serialize with json.dumps rather than str() + quote replacement: the output is the same
# for plain ints and strings, and it stays valid JSON whatever the labels contain.
# `default=int` converts numpy integers, which json cannot serialize natively.
submission_df = submission_df.assign(
    relations=submission_df.relations.map(
        lambda rels: json.dumps(rels if isinstance(rels, list) else [], default=int)
    )
)
submission_df

Unnamed: 0_level_0,relations
id,Unnamed: 1_level_1
1204,"[[2, ""IS_IN_CONTACT_WITH"", 1], [1, ""IS_IN_CONT..."
4909,"[[1, ""IS_PART_OF"", 0], [0, ""IS_PART_OF"", 1], [..."
2353,"[[2, ""IS_IN_CONTACT_WITH"", 1], [3, ""IS_IN_CONT..."
1210,"[[1, ""IS_LOCATED_IN"", 0], [0, ""IS_LOCATED_IN"",..."
41948,"[[3, ""IS_IN_CONTACT_WITH"", 2], [0, ""HAS_CONTRO..."
...,...
51452,"[[3, ""IS_LOCATED_IN"", 0], [0, ""IS_LOCATED_IN"",..."
51491,"[[2, ""IS_PART_OF"", 1], [1, ""IS_PART_OF"", 2], [..."
51492,"[[2, ""HAS_CONTROL_OVER"", 1], [1, ""HAS_CONTROL_..."
51495,"[[3, ""IS_IN_CONTACT_WITH"", 2], [2, ""IS_IN_CONT..."


In [126]:
submission_df.to_csv(submission_path)

In [38]:
print(f"titles:{len(titles)} \npredictions:{predictions.shape} \nprediction_probs:{prediction_probs.shape} \ntarget_values:{target_values.shape}")

titles:174580 
predictions:torch.Size([174580, 7]) 
prediction_probs:torch.Size([174580, 7]) 
target_values:torch.Size([174580, 7])


In [None]:
# Generate classification metrics
#
# NOTE: `target_values` and `predictions` come from get_predictions(model, test_data_loader),
# and the test dataset was built with every target column set to 0 as a placeholder.
# This report therefore compares the predictions with all-zero dummy labels and is not a
# real evaluation; use the outputs of get_predictions on val_data_loader for meaningful metrics.
#
# Note that the total support can be greater than the number of samples:
# some samples have multiple labels.

print(classification_report(target_values, predictions, target_names=target_list))

In [None]:
# import seaborn as sns
# def show_confusion_matrix(confusion_matrix):
#     hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
#     hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
#     hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
#     plt.ylabel('True category')
#     plt.xlabel('Predicted category');

In [None]:
# cm = confusion_matrix(target_values, predictions)
# df_cm = pd.DataFrame(cm, index=target_list, columns=target_list)
# show_confusion_matrix(df_cm)