In [1]:
import torch
from transformers import BertTokenizer, BertModel

In [18]:
import joblib

mlb_file = "multi_label_binarizer.pkl"
mlb = joblib.load(mlb_file)

In [21]:
from transformers import BertForSequenceClassification, BertTokenizer
import torch


model_path = 'bert_classifier_last/'
model = BertForSequenceClassification.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(105879, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1

In [35]:
import pandas as pd
import os

lang = 'PT'

def read_semeval_data(directory_path=''):
    file_info = []
    for filename in os.listdir(directory_path):
        file_path = os.path.join(directory_path, filename)
        if os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8") as file:
                file_info.append((filename, file.read()))

    df = pd.DataFrame(file_info, columns=['filename', 'text'])
    return df

df = read_semeval_data(directory_path=f'test_set/{lang}/subtask-2-documents/')
df

Unnamed: 0,filename,text
0,PT_CC_TEST_436.txt,É “Anti-Greta” e defende que as alterações cli...
1,PT_CC_TEST_437.txt,Alterações Climáticas. Há quem diga que a culp...
2,PT_CC_TEST_438.txt,Google não está a cumprir promessa de bloquear...
3,PT_CC_TEST_441.txt,"Uma PAC que não nos serve\n\nNa sexta-feira, 2..."
4,PT_CC_TEST_442.txt,Em defesa do activismo ambiental e climático\n...
...,...,...
95,PT_URW_TEST_507.txt,“Bomba suja”: teoria da conspiração ou provoca...
96,PT_URW_TEST_508.txt,Lavrov diz que ou a Ucrânia aceita as proposta...
97,PT_URW_TEST_509.txt,Deputado russo quer nacionalizar fábricas de e...
98,PT_URW_TEST_510.txt,"""Negociado ou à força"", Putin mantém objetivo ..."


In [36]:
from datasets import Dataset

def tokenize_documents(documents, max_length=512, tokenizer=None):
    """
    Tokenizes documents into single sequences of max length `max_length`.
    Adds [CLS] and [SEP] tokens automatically and pads/truncates to `max_length`.

    Args:
       documents (list): List of text documents to tokenize.
       max_length (int): Maximum length for tokenized sequences.
       tokenizer (BertTokenizer): Pre-trained tokenizer instance.

    Returns:
       dict: Dictionary with `input_ids`, `attention_mask`
    """
    input_ids_list = []
    attention_mask_list = []

    for doc in documents:
        tokens = tokenizer(
            doc.lower(),
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids_list.append(tokens["input_ids"].squeeze(0).tolist())
        attention_mask_list.append(tokens["attention_mask"].squeeze(0).tolist())

    return {
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list
    }

test_texts = df['text'].tolist()
test_data = tokenize_documents(test_texts, max_length=512, tokenizer=tokenizer)
test_dataset = Dataset.from_dict(test_data)
test_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 100
})

In [37]:
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, accuracy_score, precision_recall_fscore_support, hamming_loss

def collate_fn(batch):
    input_ids = torch.tensor([item['input_ids'] for item in batch])
    attention_mask = torch.tensor([item['attention_mask'] for item in batch])
    return {"input_ids": input_ids, "attention_mask": attention_mask}

test_loader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)

all_logits = []

for batch in test_loader:
    inputs = {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"]
    }

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        all_logits.append(logits)

all_logits = torch.cat(all_logits, dim=0)

In [38]:
predictions = (torch.sigmoid(all_logits) > 0.3).int().numpy()
result_df_save = df[['filename']]

result_df_save['narrative'] = [
    ';'.join(list(set(':'.join(j.split(':')[:2]) for j in i))) if i else 'Other' 
    for i in mlb.inverse_transform(predictions)
]
result_df_save['sub_narrative'] = [
    ';'.join(i if i else ['Other']) 
    for i in mlb.inverse_transform(predictions)
]

In [39]:
result_df_save

Unnamed: 0,filename,narrative,sub_narrative
0,PT_CC_TEST_436.txt,CC: Criticism of institutions and authorities,CC: Criticism of institutions and authorities:...
1,PT_CC_TEST_437.txt,Other,Other
2,PT_CC_TEST_438.txt,Other,Other
3,PT_CC_TEST_441.txt,CC: Criticism of institutions and authorities,CC: Criticism of institutions and authorities:...
4,PT_CC_TEST_442.txt,CC: Criticism of climate movement;CC: Criticis...,CC: Criticism of climate movement: Climate mov...
...,...,...,...
95,PT_URW_TEST_507.txt,URW: Russia is the Victim;URW: Discrediting Uk...,URW: Discrediting Ukraine: Other;URW: Discredi...
96,PT_URW_TEST_508.txt,URW: Praise of Russia;URW: Russia is the Victi...,URW: Blaming the war on others rather than the...
97,PT_URW_TEST_509.txt,Other,Other
98,PT_URW_TEST_510.txt,URW: Praise of Russia;URW: Russia is the Victi...,URW: Discrediting Ukraine: Discrediting Ukrain...


In [40]:
result_df_save[['filename', 'narrative', 'sub_narrative']]\
    .to_csv(f'test_submissions/test_bert_{lang}.txt', index=False, sep='\t', header=None)