<a target="_blank" href="https://colab.research.google.com/github/shayongithub/vietnamese-mtl-model-for-sa-nli-tasks/blob/main/notebooks/MTL%20Model%20Inference.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### **Install packages**

In [2]:
!pip install accelerate==0.21.0 transformers==4.31.0 datasets==2.14.0 evaluate==0.4.0 loguru seqeval torchinfo xformers

Collecting accelerate==0.21.0
  Downloading accelerate-0.21.0-py3-none-any.whl (244 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m244.2/244.2 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers==4.31.0
  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets==2.14.0
  Downloading datasets-2.14.0-py3-none-any.whl (492 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m492.2/492.2 kB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate==0.4.0
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting loguru
  Downloading loguru-0.7.0-py3-none-any.whl (59 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.0/60.0 kB[0

### **Add support classes**

In [3]:
import torch.nn as nn
import torch
from typing import List
from transformers import AutoModel, AutoModelForSequenceClassification, RobertaForSequenceClassification
from transformers import PreTrainedModel
from transformers import PretrainedConfig, RobertaConfig


### Load packages
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset, load_from_disk
from datasets import Dataset, DatasetDict
from datasets import ClassLabel, Value
import pandas as pd
import re
import numpy as np
import matplotlib.pyplot as plt
from loguru import logger
import logging
import json
logging.basicConfig(level=logging.NOTSET)
from dataclasses import dataclass, field


@dataclass
class Task:
    task_id: int
    name: str
    task_type: str
    num_labels: int

task_sa = Task(
    task_id=0,
    name='uit-nlp/vietnamese_students_feedback',
    task_type="seq_classification",
    num_labels=2,
)

task_nli = Task(
    task_id=1,
    name='vinli',
    task_type="seq_classification",
    num_labels=3,
)

tasks = [task_sa, task_nli]
sa_task_id = torch.tensor([0], dtype=torch.int32)
zsl_task_id = torch.tensor([1], dtype=torch.int32)


class SequenceClassificationHead(nn.Module):
    def __init__(self, hidden_size, num_labels, dropout_p=0.1):
        super().__init__()
        # super().__init__(config=PretrainedConfig())
        self.num_labels = num_labels
        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(hidden_size, num_labels)

        self._init_weights()

    def forward(self, sequence_output, pooled_output, labels=None, **kwargs):
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if labels.dim() != 1:
                # Remove padding
                labels = labels[:, 0]

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.num_labels), labels.long().view(-1)
            )

        return logits, loss

    def _init_weights(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()


class MultiTaskModel(nn.Module):
    def __init__(self, encoder_name_or_path, tasks: List):
        super().__init__()
        # self.tasks = tasks

        self.encoder = AutoModel.from_pretrained(encoder_name_or_path, return_dict=False)

        self.output_heads = nn.ModuleDict()
        for task in tasks:
            decoder = self._create_output_head(self.encoder.config.hidden_size, task)
            # ModuleDict requires keys to be strings
            self.output_heads[str(task.task_id)] = decoder

    @staticmethod
    def _create_output_head(encoder_hidden_size: int, task):
        if task.task_type == "seq_classification":
            return SequenceClassificationHead(encoder_hidden_size, task.num_labels)
        else:
            raise NotImplementedError()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        task_ids=None,
        **kwargs,
    ):
        # print("input_ids: ", input_ids)
        # print("attention_mask: ", attention_mask)
        # print("token_type_ids: ", token_type_ids)
        # print("labels: ", labels)
        # print("task_ids: ", task_ids)

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output, pooled_output = outputs[:2]

        unique_task_ids_list = torch.unique(task_ids).tolist()

        loss_list = []
        logits = None
        for unique_task_id in unique_task_ids_list:

            task_id_filter = task_ids == unique_task_id

            logits, task_loss = self.output_heads[str(unique_task_id)].forward(
                sequence_output[task_id_filter],
                pooled_output[task_id_filter],
                labels=None if labels is None else labels[task_id_filter],
                attention_mask=attention_mask[task_id_filter],
            )

            if labels is not None:
                loss_list.append(task_loss)

        # logits are only used for eval. and in case of eval the batch is not multi task
        # For training only the loss is used
        outputs = (logits, outputs[2:])

        if loss_list:
            loss = torch.stack(loss_list)
            outputs = (loss.mean(),) + outputs

        return outputs

### **Load trained MTL model**

In [4]:
from transformers import AutoTokenizer
from torchinfo import summary

# Load trained model by state_dict
load_mtl_model = MultiTaskModel('vinai/phobert-base-v2', tasks)
tokenizer = AutoTokenizer.from_pretrained('vinai/phobert-base-v2')
load_mtl_model.load_state_dict(torch.load('/content/drive/MyDrive/Shay/models/multitask_model/phobert-v2-mtl-sequence-classification-model-no-config-10-epochs-512-merged-ds/pytorch_model.bin', map_location=torch.device('cpu')))
load_mtl_model.eval()

Downloading (…)lve/main/config.json:   0%|          | 0.00/678 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/540M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at vinai/phobert-base-v2 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/895k [00:00<?, ?B/s]

Downloading (…)solve/main/bpe.codes:   0%|          | 0.00/1.14M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


MultiTaskModel(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(64001, 768, padding_idx=1)
      (position_embeddings): Embedding(258, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): Laye

### **Evaluation on test set**

#### **Load testset**

In [5]:
uit_sa_ds = load_from_disk("/content/drive/MyDrive/Shay/MTL_Datasets/merged_uit_sa_ds")
vinli_ds = load_from_disk("/content/drive/MyDrive/Shay/MTL_Datasets/merged_vi_nli_ds")

uit_sa_ds_test = uit_sa_ds['validation']
vinli_ds_test = vinli_ds['validation']

In [6]:
uit_sa_ds_test

Dataset({
    features: ['sentence1', 'labels'],
    num_rows: 2999
})

In [7]:
vinli_ds_test[100:200]

{'labels': [2,
  1,
  0,
  0,
  0,
  0,
  1,
  1,
  0,
  1,
  0,
  2,
  1,
  1,
  2,
  1,
  0,
  0,
  1,
  0,
  2,
  1,
  1,
  2,
  2,
  0,
  2,
  0,
  2,
  1,
  0,
  1,
  2,
  0,
  0,
  0,
  1,
  2,
  2,
  1,
  0,
  2,
  1,
  2,
  0,
  2,
  2,
  2,
  2,
  0,
  0,
  1,
  2,
  1,
  0,
  2,
  0,
  2,
  2,
  1,
  0,
  1,
  2,
  1,
  2,
  2,
  2,
  0,
  0,
  1,
  1,
  0,
  1,
  2,
  0,
  0,
  2,
  2,
  2,
  1,
  0,
  2,
  2,
  0,
  0,
  0,
  0,
  0,
  1,
  2,
  0,
  1,
  2,
  0,
  2,
  0,
  1,
  0,
  1,
  1],
 'sentence1': ['giám đốc điều hành apple lisa jackson cho biết khó khăn của việc sử dụng năng lượng sạch như điện gió và ánh sáng mặt trời là không ổn định',
  'mở cửa cách đây 6 năm quán cà phê an luôn đông khách vì không gian hoài niệm với những đồ vật cổ xưa',
  'å trong tiếng bắc âu cổ có nghĩa là dòng sông nhỏ và có ít nhất bảy ngôi làng ở na uy có tên gọi này',
  'honda cho biết đã chia sẻ mọi thông tin có được với nhtsa và sẽ tiếp tục hợp tác trong cuộc điều tra mới nhất',
  'n

In [8]:
uit_sa_ds_test[100:200]

{'sentence1': ['ngoài những kiến thức trên lớp thầy còn hướng dẫn và giảng dạy nhiều kiến thức mới hỗ trợ tụi em rất nhiều trong quá trình hoàn thành đồ án',
  'giảng viên dạy rất là nhiệt tình',
  'quá nhiệt tình',
  'thầy lý thuyết wzjwz223 cũng thường xuyên đi trễ',
  'nên cho sinh viên slide để học',
  'cách giảng của thầy đa số trong lớp không ai hiểu dù thấy có giảng lại',
  'kiến thức chuyên môn sâu',
  'thầy dạy rất hay và kỹ',
  'thầy dạy hay dễ hiểu cho làm nhiều bài tập mở rộng để hiểu bài',
  'những kiến thức thầy cung cấp rất đa dạng và cần thiết cảm ơn thầy rất nhiều',
  'không cần phải mở lớp thực hành với môn học này',
  'sự nhiệt tình tận tâm',
  'thêm nhiều bài tập thiết thực hơn',
  'rất tận tình',
  'phương pháp học áp dụng thực tiễn',
  'nhiệt tình hăng hái',
  'giảng viên nhiệt tình giảng dạy hiệu quả bằng tiếng anh',
  'cần có nhiều giảng viên như cô dạy hơn',
  'ít cho bài tập',
  'có nhiều thầy cô đến lớp không đúng giờ',
  'đôi lúc sinh viên theo không kịp',
 

#### **Post-process NLI predictions**

In [9]:
def postprocess_nli(model_outputs,
                    label2id = {'entailment': 0, 'neutral': 1, 'contradiction': 2},
                    multi_label=False):

    candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
    sequences = [outputs["sequence"] for outputs in model_outputs]

    logits = np.concatenate([output["logits"].numpy() for output in model_outputs])
    N = logits.shape[0]
    n = len(candidate_labels)

    num_sequences = N // n
    reshaped_outputs = logits.reshape((num_sequences, n, -1))

    entailment_id = label2id['entailment']
    contradiction_id = label2id['contradiction']

    if multi_label or len(candidate_labels) == 1:
        # softmax over the entailment vs. contradiction dim for each label independently
        entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
        scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
        scores = scores[..., 1]
    else:
        # softmax the "entailment" logits over all candidate labels
        entail_logits = reshaped_outputs[..., entailment_id]
        scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)

    top_inds = list(reversed(scores[0].argsort()))

    return {
        "sequence": sequences[0],
        "labels": [candidate_labels[i] for i in top_inds],
        "scores": scores[0, top_inds].tolist(),
    }

#### **Run eval**

In [10]:
from transformers import pipeline, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding, set_seed
import numpy as np
from datasets import load_metric
import evaluate

In [11]:
import os

def get_subfolders_with_prefix(folder_path, prefix):
    subfolders = []
    for root, dirs, files in os.walk(folder_path):
        for dir in dirs:
            if dir.startswith(prefix):
                subfolders.append(os.path.join(root, dir))
    return subfolders

def sort_folders(folders):
    # Split each folder name by '-' and convert the second part to an integer
    # Use this integer for sorting
    sorted_folders = sorted(folders, key=lambda folder: int(folder.split('-')[1]))
    return sorted_folders

In [11]:
checkpoints = sort_folders(get_subfolders_with_prefix('/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023', 'checkpoint'))

checkpoints

['/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023/checkpoint-10272',
 '/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023/checkpoint-20544',
 '/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023/checkpoint-30816',
 '/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023/checkpoint-41088',
 '/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023/checkpoint-51360',
 '/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023/checkpoint-61632',
 '/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023/checkpoint-71904',
 '/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023/checkpoint-82176',
 '/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023/checkpoint-92448',
 '/content

In [12]:
subfolders = get_subfolders_with_prefix('/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_phobertv2_runs_16082023', 'checkpoint')
len(subfolders)

50

##### **ZSL**

In [None]:
accuracy_metric_zsl = evaluate.load('accuracy')
f1_metric_zsl = evaluate.load('f1', average='macro')
precision_metric_zsl = evaluate.load('precision', average='macro')
recall_metric_zsl = evaluate.load('recall', average='macro')


def compute_metrics_zsl(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    accuracy = accuracy_metric_zsl.compute(predictions=predictions, references=labels)[
        'accuracy'
    ]
    precision = precision_metric_zsl.compute(
        predictions=predictions, references=labels, average='macro',
    )['precision']
    recall = recall_metric_zsl.compute(
        predictions=predictions, references=labels, average='macro',
    )['recall']
    f1 = f1_metric_zsl.compute(predictions=predictions, references=labels, average='macro')[
        'f1'
    ]

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.55k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

In [None]:
# Tokenize the data
def pre_process_and_tokenize(batch):
    return tokenizer(
        batch['sentence1'], batch['sentence2'], truncation=True, padding=True,
    )

tokenized_vinli_train_dataset = vinli_ds['train'].map(
    pre_process_and_tokenize, batched=True)

tokenized_vinli_test_dataset = vinli_ds_test.map(
    pre_process_and_tokenize, batched=True,
)

###### **ZSL: ViDeberta xsmall**

In [None]:
videberta_xsmall_tokenizer = AutoTokenizer.from_pretrained('Fsoft-AIC/videberta-xsmall', max_length=512)
videberta_xsmall_data_collator = DataCollatorWithPadding(videberta_xsmall_tokenizer)

training_arg_zsl_videberta_xsmall = TrainingArguments("ZSL_Videberta_xsmall",
                                                      per_device_eval_batch_size=16,)

In [None]:
sorted_checkpoints = sort_folders(get_subfolders_with_prefix('/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_xsmall_runs_16082023', 'checkpoint'))

for index, checkpoint in enumerate(sorted_checkpoints):
    zsl_videberta_xsmall_model = AutoModelForSequenceClassification.from_pretrained(checkpoint,
                                      num_labels=3)

    trainer_zsl_videberta_xsmall = Trainer(
        model=zsl_videberta_xsmall_model,
        args=training_arg_zsl_videberta_xsmall,
        train_dataset=tokenized_vinli_train_dataset,
        eval_dataset=tokenized_vinli_test_dataset,
        compute_metrics=compute_metrics_zsl,
        tokenizer=videberta_xsmall_tokenizer,
        data_collator=videberta_xsmall_data_collator,
    )

    metrics = trainer_zsl_videberta_xsmall.evaluate(eval_dataset=tokenized_vinli_test_dataset)

    metrics["eval_samples"] = len(tokenized_vinli_test_dataset)
    logger.info(f"[{index+1}] - ZSL task ViDeBERTA-xsmall at {os.path.basename(checkpoint)} with F1-score = {metrics['eval_f1']} and Accuracy = {metrics['eval_accuracy']}")

    # trainer_zsl_videberta_xsmall.log_metrics(f"eval for ZSL task with ViDeBERTA-xsmall at {os.path.basename(checkpoint)}", metrics)


You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:07:50.868[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[1] - ZSL task ViDeBERTA-xsmall at checkpoint-10272 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:08:10.677[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[2] - ZSL task ViDeBERTA-xsmall at checkpoint-20544 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:08:26.941[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[3] - ZSL task ViDeBERTA-xsmall at checkpoint-30816 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:08:45.536[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[4] - ZSL task ViDeBERTA-xsmall at checkpoint-41088 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:09:03.653[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[5] - ZSL task ViDeBERTA-xsmall at checkpoint-51360 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:09:23.694[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[6] - ZSL task ViDeBERTA-xsmall at checkpoint-61632 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:09:42.525[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[7] - ZSL task ViDeBERTA-xsmall at checkpoint-71904 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:09:58.861[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[8] - ZSL task ViDeBERTA-xsmall at checkpoint-82176 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:10:17.928[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[9] - ZSL task ViDeBERTA-xsmall at checkpoint-92448 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:10:37.059[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[10] - ZSL task ViDeBERTA-xsmall at checkpoint-102720 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:10:55.708[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[11] - ZSL task ViDeBERTA-xsmall at checkpoint-112992 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:11:11.062[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[12] - ZSL task ViDeBERTA-xsmall at checkpoint-123264 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:11:27.014[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[13] - ZSL task ViDeBERTA-xsmall at checkpoint-133536 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:11:43.623[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[14] - ZSL task ViDeBERTA-xsmall at checkpoint-143808 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:11:59.581[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[15] - ZSL task ViDeBERTA-xsmall at checkpoint-154080 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:12:17.489[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[16] - ZSL task ViDeBERTA-xsmall at checkpoint-164352 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:12:36.689[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[17] - ZSL task ViDeBERTA-xsmall at checkpoint-174624 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:12:54.130[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[18] - ZSL task ViDeBERTA-xsmall at checkpoint-184896 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:13:09.635[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[19] - ZSL task ViDeBERTA-xsmall at checkpoint-195168 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:13:26.661[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[20] - ZSL task ViDeBERTA-xsmall at checkpoint-205440 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:13:42.759[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[21] - ZSL task ViDeBERTA-xsmall at checkpoint-215712 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:13:58.644[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[22] - ZSL task ViDeBERTA-xsmall at checkpoint-225984 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:14:14.064[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[23] - ZSL task ViDeBERTA-xsmall at checkpoint-236256 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:14:29.761[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[24] - ZSL task ViDeBERTA-xsmall at checkpoint-246528 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:14:45.468[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[25] - ZSL task ViDeBERTA-xsmall at checkpoint-256800 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:15:01.224[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[26] - ZSL task ViDeBERTA-xsmall at checkpoint-267072 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:15:16.706[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[27] - ZSL task ViDeBERTA-xsmall at checkpoint-277344 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:15:36.624[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[28] - ZSL task ViDeBERTA-xsmall at checkpoint-287616 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:15:52.950[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[29] - ZSL task ViDeBERTA-xsmall at checkpoint-297888 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:16:11.791[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[30] - ZSL task ViDeBERTA-xsmall at checkpoint-308160 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:16:27.857[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[31] - ZSL task ViDeBERTA-xsmall at checkpoint-318432 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:16:43.207[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[32] - ZSL task ViDeBERTA-xsmall at checkpoint-328704 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:16:59.238[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[33] - ZSL task ViDeBERTA-xsmall at checkpoint-338976 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:17:17.416[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[34] - ZSL task ViDeBERTA-xsmall at checkpoint-349248 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:17:33.493[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[35] - ZSL task ViDeBERTA-xsmall at checkpoint-359520 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:17:50.933[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[36] - ZSL task ViDeBERTA-xsmall at checkpoint-369792 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:18:06.456[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[37] - ZSL task ViDeBERTA-xsmall at checkpoint-380064 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:18:22.256[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[38] - ZSL task ViDeBERTA-xsmall at checkpoint-390336 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:18:38.383[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[39] - ZSL task ViDeBERTA-xsmall at checkpoint-400608 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:18:53.811[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[40] - ZSL task ViDeBERTA-xsmall at checkpoint-410880 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:19:09.131[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[41] - ZSL task ViDeBERTA-xsmall at checkpoint-421152 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:19:24.860[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[42] - ZSL task ViDeBERTA-xsmall at checkpoint-431424 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:19:40.295[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[43] - ZSL task ViDeBERTA-xsmall at checkpoint-441696 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:19:56.126[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[44] - ZSL task ViDeBERTA-xsmall at checkpoint-451968 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:20:16.255[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[45] - ZSL task ViDeBERTA-xsmall at checkpoint-462240 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:20:32.485[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[46] - ZSL task ViDeBERTA-xsmall at checkpoint-472512 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:20:47.940[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[47] - ZSL task ViDeBERTA-xsmall at checkpoint-482784 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:21:04.286[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[48] - ZSL task ViDeBERTA-xsmall at checkpoint-493056 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:21:20.010[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[49] - ZSL task ViDeBERTA-xsmall at checkpoint-503328 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 15:21:35.512[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 3>[0m:[36m20[0m - [1m[50] - ZSL task ViDeBERTA-xsmall at checkpoint-513600 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


###### **ZSL: ViDeberta base**

In [None]:
videberta_base_tokenizer = AutoTokenizer.from_pretrained('Fsoft-AIC/videberta-base', max_length=512)
videberta_base_data_collator = DataCollatorWithPadding(videberta_base_tokenizer)

training_arg_zsl_videberta_base = TrainingArguments("ZSL_Videberta_base",
                                                      per_device_eval_batch_size=16,)

videberta_base_zsl_checkpoints = sort_folders(get_subfolders_with_prefix('/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_videberta_base_runs_18082023', 'checkpoint'))

for index, checkpoint in enumerate(videberta_base_zsl_checkpoints):
    zsl_videberta_base_model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3)

    trainer_zsl_videberta_base = Trainer(
        model=zsl_videberta_base_model,
        args=training_arg_zsl_videberta_base,
        train_dataset=tokenized_vinli_train_dataset,
        eval_dataset=tokenized_vinli_test_dataset,
        compute_metrics=compute_metrics_zsl,
        tokenizer=videberta_base_tokenizer,
        data_collator=videberta_base_data_collator,
    )

    metrics = trainer_zsl_videberta_base.evaluate(eval_dataset=tokenized_vinli_test_dataset)

    metrics["eval_samples"] = len(tokenized_vinli_test_dataset)
    logger.info(f"[{index+1}] - ZSL task ViDeBERTA-base at {os.path.basename(checkpoint)} with F1-score = {metrics['eval_f1']} and Accuracy = {metrics['eval_accuracy']}")

    # trainer_zsl_videberta_xsmall.log_metrics(f"eval for ZSL task with ViDeBERTA-xsmall at {os.path.basename(checkpoint)}", metrics)

You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


[32m2023-08-20 15:22:30.836[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[1] - ZSL task ViDeBERTA-base at checkpoint-2568 with F1-score = 0.2481994910262839 and Accuracy = 0.3333333333333333[0m


[32m2023-08-20 15:23:19.125[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[2] - ZSL task ViDeBERTA-base at checkpoint-5136 with F1-score = 0.16926596801184216 and Accuracy = 0.33420707732634336[0m


[32m2023-08-20 15:24:10.206[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[3] - ZSL task ViDeBERTA-base at checkpoint-7704 with F1-score = 0.16673381623876674 and Accuracy = 0.327217125382263[0m


[32m2023-08-20 15:25:03.200[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[4] - ZSL task ViDeBERTA-base at checkpoint-10272 with F1-score = 0.17596100605858556 and Accuracy = 0.3328964613368283[0m


[32m2023-08-20 15:25:51.868[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[5] - ZSL task ViDeBERTA-base at checkpoint-12840 with F1-score = 0.25451692881062477 and Accuracy = 0.33377020532983837[0m


[32m2023-08-20 15:26:42.339[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[6] - ZSL task ViDeBERTA-base at checkpoint-15408 with F1-score = 0.18599045979169973 and Accuracy = 0.3320227173438183[0m


[32m2023-08-20 15:27:30.736[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[7] - ZSL task ViDeBERTA-base at checkpoint-17976 with F1-score = 0.3094146613169067 and Accuracy = 0.34294451725644387[0m


[32m2023-08-20 15:28:23.730[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[8] - ZSL task ViDeBERTA-base at checkpoint-20544 with F1-score = 0.27103093796772887 and Accuracy = 0.32809086937527304[0m


[32m2023-08-20 15:29:16.601[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[9] - ZSL task ViDeBERTA-base at checkpoint-23112 with F1-score = 0.26511517331406986 and Accuracy = 0.3315858453473132[0m


[32m2023-08-20 15:30:10.785[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[10] - ZSL task ViDeBERTA-base at checkpoint-25680 with F1-score = 0.21602966400520615 and Accuracy = 0.3285277413717781[0m


[32m2023-08-20 15:30:57.816[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[11] - ZSL task ViDeBERTA-base at checkpoint-28248 with F1-score = 0.2805704000628015 and Accuracy = 0.32809086937527304[0m


[32m2023-08-20 15:31:46.056[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[12] - ZSL task ViDeBERTA-base at checkpoint-30816 with F1-score = 0.2899271596063872 and Accuracy = 0.3350808213193534[0m


[32m2023-08-20 15:32:40.678[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[13] - ZSL task ViDeBERTA-base at checkpoint-33384 with F1-score = 0.1979302087661221 and Accuracy = 0.3328964613368283[0m


[32m2023-08-20 15:33:31.038[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[14] - ZSL task ViDeBERTA-base at checkpoint-35952 with F1-score = 0.3294244062614822 and Accuracy = 0.3350808213193534[0m


[32m2023-08-20 15:34:18.899[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[15] - ZSL task ViDeBERTA-base at checkpoint-38520 with F1-score = 0.31533011347239276 and Accuracy = 0.3315858453473132[0m


[32m2023-08-20 15:35:09.914[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[16] - ZSL task ViDeBERTA-base at checkpoint-41088 with F1-score = 0.2690769104305546 and Accuracy = 0.32940148536478814[0m


[32m2023-08-20 15:36:05.160[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[17] - ZSL task ViDeBERTA-base at checkpoint-43656 with F1-score = 0.22385379819622545 and Accuracy = 0.3377020532983836[0m


[32m2023-08-20 15:36:54.842[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[18] - ZSL task ViDeBERTA-base at checkpoint-46224 with F1-score = 0.21054615174179458 and Accuracy = 0.33245958934032327[0m


[32m2023-08-20 15:37:48.176[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[19] - ZSL task ViDeBERTA-base at checkpoint-48792 with F1-score = 0.19321817632417582 and Accuracy = 0.34076015727391873[0m


[32m2023-08-20 15:38:43.161[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[20] - ZSL task ViDeBERTA-base at checkpoint-51360 with F1-score = 0.18504433471761408 and Accuracy = 0.34032328527741373[0m


[32m2023-08-20 15:39:36.154[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[21] - ZSL task ViDeBERTA-base at checkpoint-53928 with F1-score = 0.17005997403995568 and Accuracy = 0.32940148536478814[0m


[32m2023-08-20 15:40:24.180[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[22] - ZSL task ViDeBERTA-base at checkpoint-56496 with F1-score = 0.16900387021055133 and Accuracy = 0.32809086937527304[0m


[32m2023-08-20 15:41:13.543[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[23] - ZSL task ViDeBERTA-base at checkpoint-59064 with F1-score = 0.16635968304262594 and Accuracy = 0.32809086937527304[0m


[32m2023-08-20 15:42:03.227[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[24] - ZSL task ViDeBERTA-base at checkpoint-61632 with F1-score = 0.21476538912641682 and Accuracy = 0.3411970292704238[0m


[32m2023-08-20 15:42:52.865[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[25] - ZSL task ViDeBERTA-base at checkpoint-64200 with F1-score = 0.17128536522137114 and Accuracy = 0.3285277413717781[0m


[32m2023-08-20 15:43:45.738[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[26] - ZSL task ViDeBERTA-base at checkpoint-66768 with F1-score = 0.16827548631642322 and Accuracy = 0.3289646133682831[0m


[32m2023-08-20 15:44:33.965[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[27] - ZSL task ViDeBERTA-base at checkpoint-69336 with F1-score = 0.16622000633752015 and Accuracy = 0.327217125382263[0m


[32m2023-08-20 15:45:20.166[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[28] - ZSL task ViDeBERTA-base at checkpoint-71904 with F1-score = 0.1738341454569047 and Accuracy = 0.33114897335080823[0m


[32m2023-08-20 15:46:12.438[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[29] - ZSL task ViDeBERTA-base at checkpoint-74472 with F1-score = 0.1743415616963582 and Accuracy = 0.32983835736129313[0m


[32m2023-08-20 15:47:03.934[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[30] - ZSL task ViDeBERTA-base at checkpoint-77040 with F1-score = 0.2144415486001164 and Accuracy = 0.33726518130187855[0m


[32m2023-08-20 15:47:56.326[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[31] - ZSL task ViDeBERTA-base at checkpoint-79608 with F1-score = 0.2366250076537566 and Accuracy = 0.33245958934032327[0m


[32m2023-08-20 15:48:48.894[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[32] - ZSL task ViDeBERTA-base at checkpoint-82176 with F1-score = 0.22538660100899563 and Accuracy = 0.33682830930537355[0m


[32m2023-08-20 15:49:38.921[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[33] - ZSL task ViDeBERTA-base at checkpoint-84744 with F1-score = 0.21210351688477355 and Accuracy = 0.3315858453473132[0m


[32m2023-08-20 15:50:31.305[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[34] - ZSL task ViDeBERTA-base at checkpoint-87312 with F1-score = 0.2029423226724365 and Accuracy = 0.33551769331585846[0m


[32m2023-08-20 15:51:24.113[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[35] - ZSL task ViDeBERTA-base at checkpoint-89880 with F1-score = 0.21776105427310233 and Accuracy = 0.3315858453473132[0m


[32m2023-08-20 15:52:14.911[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[36] - ZSL task ViDeBERTA-base at checkpoint-92448 with F1-score = 0.23190558312509527 and Accuracy = 0.3350808213193534[0m


[32m2023-08-20 15:53:13.823[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[37] - ZSL task ViDeBERTA-base at checkpoint-95016 with F1-score = 0.24432089045471486 and Accuracy = 0.3350808213193534[0m


[32m2023-08-20 15:54:15.581[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[38] - ZSL task ViDeBERTA-base at checkpoint-97584 with F1-score = 0.26570712787238704 and Accuracy = 0.32809086937527304[0m


[32m2023-08-20 15:55:05.317[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[39] - ZSL task ViDeBERTA-base at checkpoint-100152 with F1-score = 0.2147456558027633 and Accuracy = 0.33114897335080823[0m


[32m2023-08-20 15:55:57.859[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[40] - ZSL task ViDeBERTA-base at checkpoint-102720 with F1-score = 0.2713176709402567 and Accuracy = 0.33245958934032327[0m


[32m2023-08-20 15:56:49.506[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[41] - ZSL task ViDeBERTA-base at checkpoint-105288 with F1-score = 0.23880234316121007 and Accuracy = 0.32328527741371776[0m


[32m2023-08-20 15:57:41.609[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[42] - ZSL task ViDeBERTA-base at checkpoint-107856 with F1-score = 0.2490982790198842 and Accuracy = 0.32634338138925295[0m


[32m2023-08-20 15:58:30.791[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[43] - ZSL task ViDeBERTA-base at checkpoint-110424 with F1-score = 0.28016568957952764 and Accuracy = 0.33682830930537355[0m


[32m2023-08-20 15:59:19.017[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[44] - ZSL task ViDeBERTA-base at checkpoint-112992 with F1-score = 0.27225568727230004 and Accuracy = 0.3346439493228484[0m


[32m2023-08-20 16:00:06.777[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[45] - ZSL task ViDeBERTA-base at checkpoint-115560 with F1-score = 0.24893823425237757 and Accuracy = 0.327217125382263[0m


[32m2023-08-20 16:00:55.677[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[46] - ZSL task ViDeBERTA-base at checkpoint-118128 with F1-score = 0.266286435428499 and Accuracy = 0.3346439493228484[0m


[32m2023-08-20 16:01:46.826[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[47] - ZSL task ViDeBERTA-base at checkpoint-120696 with F1-score = 0.2666635051981742 and Accuracy = 0.32940148536478814[0m


[32m2023-08-20 16:02:34.137[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[48] - ZSL task ViDeBERTA-base at checkpoint-123264 with F1-score = 0.261233533737862 and Accuracy = 0.33114897335080823[0m


[32m2023-08-20 16:03:23.709[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[49] - ZSL task ViDeBERTA-base at checkpoint-125832 with F1-score = 0.2604115894834856 and Accuracy = 0.3307121013543032[0m


[32m2023-08-20 16:04:08.623[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[50] - ZSL task ViDeBERTA-base at checkpoint-128400 with F1-score = 0.261220173258466 and Accuracy = 0.3307121013543032[0m


###### **ZSL: PhoBERT base v2**

In [None]:
from transformers import pipeline, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding, set_seed
import numpy as np
from datasets import load_metric


phobertv2_tokenizer = AutoTokenizer.from_pretrained('vinai/phobert-base-v2', max_length=512)
phobertv2_data_collator = DataCollatorWithPadding(phobertv2_tokenizer)

training_arg_zsl_phobertv2 = TrainingArguments("ZSL_PhoBERT_v2",
                                                per_device_eval_batch_size=16,)

phobertv2_zsl_checkpoints = sort_folders(get_subfolders_with_prefix('/content/drive/MyDrive/Shay/models/remote_trained_mtl/zsl_phobertv2_runs_16082023', 'checkpoint'))

for index, checkpoint in enumerate(phobertv2_zsl_checkpoints):
    zsl_phobertv2_base_model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3)

    trainer_zsl_phobertv2 = Trainer(
        model=zsl_phobertv2_base_model,
        args=training_arg_zsl_phobertv2,
        train_dataset=tokenized_vinli_train_dataset,
        eval_dataset=tokenized_vinli_test_dataset,
        compute_metrics=compute_metrics_zsl,
        tokenizer=phobertv2_tokenizer,
        data_collator=phobertv2_data_collator,
    )

    metrics = trainer_zsl_phobertv2.evaluate(eval_dataset=tokenized_vinli_test_dataset)

    metrics["eval_samples"] = len(tokenized_vinli_test_dataset)
    logger.info(f"[{index+1}] - ZSL task PhoBERTv2 at {os.path.basename(checkpoint)} with F1-score = {metrics['eval_f1']} and Accuracy = {metrics['eval_accuracy']}")

    # trainer_zsl_videberta_xsmall.log_metrics(f"eval for ZSL task with ViDeBERTA-xsmall at {os.path.basename(checkpoint)}", metrics)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:04:44.214[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[1] - ZSL task PhoBERTv2 at checkpoint-10272 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:05:24.763[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[2] - ZSL task PhoBERTv2 at checkpoint-20544 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:06:04.661[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[3] - ZSL task PhoBERTv2 at checkpoint-30816 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:06:45.316[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[4] - ZSL task PhoBERTv2 at checkpoint-41088 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:07:23.577[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[5] - ZSL task PhoBERTv2 at checkpoint-51360 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:08:06.094[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[6] - ZSL task PhoBERTv2 at checkpoint-61632 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:08:45.627[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[7] - ZSL task PhoBERTv2 at checkpoint-71904 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:09:25.147[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[8] - ZSL task PhoBERTv2 at checkpoint-82176 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:10:05.273[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[9] - ZSL task PhoBERTv2 at checkpoint-92448 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:10:45.662[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[10] - ZSL task PhoBERTv2 at checkpoint-102720 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:11:26.632[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[11] - ZSL task PhoBERTv2 at checkpoint-112992 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:12:06.142[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[12] - ZSL task PhoBERTv2 at checkpoint-123264 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:12:45.958[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[13] - ZSL task PhoBERTv2 at checkpoint-133536 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:13:34.758[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[14] - ZSL task PhoBERTv2 at checkpoint-143808 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:14:16.336[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[15] - ZSL task PhoBERTv2 at checkpoint-154080 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:15:03.311[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[16] - ZSL task PhoBERTv2 at checkpoint-164352 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:15:43.377[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[17] - ZSL task PhoBERTv2 at checkpoint-174624 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:16:22.038[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[18] - ZSL task PhoBERTv2 at checkpoint-184896 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:17:02.907[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[19] - ZSL task PhoBERTv2 at checkpoint-195168 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:17:41.860[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[20] - ZSL task PhoBERTv2 at checkpoint-205440 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:18:23.778[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[21] - ZSL task PhoBERTv2 at checkpoint-215712 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:19:04.796[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[22] - ZSL task PhoBERTv2 at checkpoint-225984 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:19:43.826[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[23] - ZSL task PhoBERTv2 at checkpoint-236256 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:20:26.584[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[24] - ZSL task PhoBERTv2 at checkpoint-246528 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:21:10.429[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[25] - ZSL task PhoBERTv2 at checkpoint-256800 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:21:47.122[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[26] - ZSL task PhoBERTv2 at checkpoint-267072 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:22:31.822[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[27] - ZSL task PhoBERTv2 at checkpoint-277344 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:23:15.243[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[28] - ZSL task PhoBERTv2 at checkpoint-287616 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:24:02.373[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[29] - ZSL task PhoBERTv2 at checkpoint-297888 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:24:43.298[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[30] - ZSL task PhoBERTv2 at checkpoint-308160 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:25:22.159[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[31] - ZSL task PhoBERTv2 at checkpoint-318432 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:26:03.078[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[32] - ZSL task PhoBERTv2 at checkpoint-328704 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:26:42.961[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[33] - ZSL task PhoBERTv2 at checkpoint-338976 with F1-score = 0.16894977168949774 and Accuracy = 0.3394495412844037[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:27:32.398[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[34] - ZSL task PhoBERTv2 at checkpoint-349248 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:28:14.941[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[35] - ZSL task PhoBERTv2 at checkpoint-359520 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:28:59.951[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[36] - ZSL task PhoBERTv2 at checkpoint-369792 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:29:40.697[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[37] - ZSL task PhoBERTv2 at checkpoint-380064 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:30:17.852[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[38] - ZSL task PhoBERTv2 at checkpoint-390336 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:30:59.601[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[39] - ZSL task PhoBERTv2 at checkpoint-400608 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:31:38.077[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[40] - ZSL task PhoBERTv2 at checkpoint-410880 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:32:18.018[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[41] - ZSL task PhoBERTv2 at checkpoint-421152 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:33:02.382[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[42] - ZSL task PhoBERTv2 at checkpoint-431424 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:33:43.372[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[43] - ZSL task PhoBERTv2 at checkpoint-441696 with F1-score = 0.16452780519907867 and Accuracy = 0.32765399737876805[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:34:28.677[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[44] - ZSL task PhoBERTv2 at checkpoint-451968 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:35:08.645[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[45] - ZSL task PhoBERTv2 at checkpoint-462240 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:35:45.439[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[46] - ZSL task PhoBERTv2 at checkpoint-472512 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:36:25.378[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[47] - ZSL task PhoBERTv2 at checkpoint-482784 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:37:10.378[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[48] - ZSL task PhoBERTv2 at checkpoint-493056 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:37:57.031[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[49] - ZSL task PhoBERTv2 at checkpoint-503328 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


  _warn_prf(average, modifier, msg_start, len(result))
[32m2023-08-20 16:38:36.378[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 15>[0m:[36m31[0m - [1m[50] - ZSL task PhoBERTv2 at checkpoint-513600 with F1-score = 0.16650278597181253 and Accuracy = 0.3328964613368283[0m


##### **Sentiment Analysis**

In [12]:
import evaluate

accuracy_metric = evaluate.load('accuracy')
f1_metric = evaluate.load('f1')
precision_metric = evaluate.load('precision')
recall_metric = evaluate.load('recall')


def compute_metrics_sa(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)[
        'accuracy'
    ]
    precision = precision_metric.compute(
        predictions=predictions, references=labels,
    )['precision']
    recall = recall_metric.compute(
        predictions=predictions, references=labels,
    )['recall']
    f1 = f1_metric.compute(predictions=predictions, references=labels)[
        'f1'
    ]

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.55k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

In [13]:
raw_datasets = load_from_disk('/content/drive/MyDrive/Shay/training_script/merged_uit_sa_ds')
train_dataset = raw_datasets['train']
validation_dataset = raw_datasets['validation']
# Tokenize the data
def preprocess_function(row):
    return tokenizer(row['sentence1'], truncation=True, padding=True)

sa_tokenized_train_dataset = train_dataset.map(
        preprocess_function, batched=True)
sa_tokenized_validation_dataset = validation_dataset.map(
        preprocess_function, batched=True)

Map:   0%|          | 0/12478 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/2999 [00:00<?, ? examples/s]

###### **Sentiment Analysis: ViDeberta xsmall**

In [16]:
videberta_xsmall_tokenizer = AutoTokenizer.from_pretrained('Fsoft-AIC/videberta-xsmall')
videberta_xsmall_data_collator = DataCollatorWithPadding(videberta_xsmall_tokenizer)

training_arg_sa_videberta_xsmall = TrainingArguments("SA_Videberta_xsmall",
                                                      per_device_eval_batch_size=32,)

sa_videberta_xsmall_model = AutoModelForSequenceClassification.from_pretrained("/content/drive/MyDrive/Shay/training_script/output/models/sa_videberta_xsmall_model_10_epochs_21082023", num_labels=2)

trainer_sa_videberta_xsmall = Trainer(
    model=sa_videberta_xsmall_model,
    args=training_arg_sa_videberta_xsmall,
    train_dataset=sa_tokenized_train_dataset,
    eval_dataset=sa_tokenized_validation_dataset,
    compute_metrics=compute_metrics_sa,
    tokenizer=videberta_xsmall_tokenizer,
    data_collator=videberta_xsmall_data_collator,
)

metrics = trainer_sa_videberta_xsmall.evaluate(eval_dataset=sa_tokenized_validation_dataset)

metrics["eval_samples"] = len(sa_tokenized_validation_dataset)
trainer_sa_videberta_xsmall.log_metrics('eval for sa', metrics)

You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


***** eval for sa metrics *****
  eval_accuracy           =     0.4915
  eval_f1                 =     0.3753
  eval_loss               =     2.0162
  eval_precision          =     0.5382
  eval_recall             =     0.2881
  eval_runtime            = 0:00:07.04
  eval_samples            =       2999
  eval_samples_per_second =    425.573
  eval_steps_per_second   =     13.339


In [25]:
videberta_xsmall_tokenizer = AutoTokenizer.from_pretrained('Fsoft-AIC/videberta-xsmall')
videberta_xsmall_data_collator = DataCollatorWithPadding(videberta_xsmall_tokenizer)

training_arg_sa_videberta_xsmall = TrainingArguments("SA_Videberta_xsmall",
                                                      per_device_eval_batch_size=32,)

videberta_xsmall_sa_checkpoints = sort_folders(get_subfolders_with_prefix('/content/drive/MyDrive/Shay/training_script/output/runs/sa_videberta_xsmall_runs_10_epochs_21082023/', 'checkpoint'))

for index, checkpoint in enumerate(videberta_xsmall_sa_checkpoints):
    sa_videberta_xsmall_model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

    trainer_sa_videberta_xsmall = Trainer(
        model=sa_videberta_xsmall_model,
        args=training_arg_sa_videberta_xsmall,
        train_dataset=sa_tokenized_train_dataset,
        eval_dataset=sa_tokenized_validation_dataset,
        compute_metrics=compute_metrics_sa,
        tokenizer=videberta_xsmall_tokenizer,
        data_collator=videberta_xsmall_data_collator,
    )

    metrics = trainer_sa_videberta_xsmall.evaluate(eval_dataset=sa_tokenized_validation_dataset)

    metrics["eval_samples"] = len(sa_tokenized_validation_dataset)
    logger.info(f"[{index+1}] - SA task ViDeBERTA-xsmall at {os.path.basename(checkpoint)} with F1-score = {metrics['eval_f1']} and Accuracy = {metrics['eval_accuracy']}")

You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


[32m2023-08-21 09:24:29.596[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[1] - SA task ViDeBERTA-xsmall at checkpoint-195 with F1-score = 0.426677380474086 and Accuracy = 0.5241747249083027[0m


[32m2023-08-21 09:24:39.430[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[2] - SA task ViDeBERTA-xsmall at checkpoint-390 with F1-score = 0.08508158508158507 and Accuracy = 0.4764921640546849[0m


[32m2023-08-21 09:24:49.007[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[3] - SA task ViDeBERTA-xsmall at checkpoint-585 with F1-score = 0.29145728643216084 and Accuracy = 0.48282760920306766[0m


[32m2023-08-21 09:24:59.495[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[4] - SA task ViDeBERTA-xsmall at checkpoint-780 with F1-score = 0.2105263157894737 and Accuracy = 0.4798266088696232[0m


[32m2023-08-21 09:25:09.443[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[5] - SA task ViDeBERTA-xsmall at checkpoint-975 with F1-score = 0.33333333333333337 and Accuracy = 0.4838279426475492[0m


[32m2023-08-21 09:25:19.562[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[6] - SA task ViDeBERTA-xsmall at checkpoint-1170 with F1-score = 0.41068580542264754 and Accuracy = 0.5071690563521174[0m


[32m2023-08-21 09:25:29.153[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[7] - SA task ViDeBERTA-xsmall at checkpoint-1365 with F1-score = 0.3767634854771785 and Accuracy = 0.4991663887962654[0m


[32m2023-08-21 09:25:38.487[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[8] - SA task ViDeBERTA-xsmall at checkpoint-1560 with F1-score = 0.3966142684401451 and Accuracy = 0.5008336112037346[0m


[32m2023-08-21 09:25:48.612[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[9] - SA task ViDeBERTA-xsmall at checkpoint-1755 with F1-score = 0.3203539823008849 and Accuracy = 0.48782927642547513[0m


[32m2023-08-21 09:25:58.170[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 9>[0m:[36m25[0m - [1m[10] - SA task ViDeBERTA-xsmall at checkpoint-1950 with F1-score = 0.37525604260548956 and Accuracy = 0.4914971657219073[0m


###### **Sentiment Analysis: ViDeberta base**

In [None]:
from transformers import TrainingArguments, Trainer
import numpy as np
from datasets import load_metric

training_arg_videberta_base = TrainingArguments("SA_Videberta_base"),

trainer_videberta_base = Trainer(
    model=load_mtl_model,
    args=training_args,
    train_dataset=uit_sa_ds['train'],
    eval_dataset=uit_sa_ds_test,
    compute_metrics=compute_metrics,
)

###### **Sentiment Analysis: PhoBERT base v2**

In [None]:
from transformers import TrainingArguments, Trainer
import numpy as np
from datasets import load_metric

training_arg_phobertv2 = TrainingArguments("SA_Phobertv2"),

trainer_phobertv2 = Trainer(
    model=load_mtl_model,
    args=training_args,
    train_dataset=uit_sa_ds['train'],
    eval_dataset=uit_sa_ds_test,
    compute_metrics=compute_metrics,
)

##### **MTL Model**

In [None]:
from transformers import EvalPrediction
import numpy as np
from datasets import load_metric
import evaluate


seqeval_metric = evaluate.load("seqeval")

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")


def compute_metrics(p: EvalPrediction):

    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions

    if preds.shape[1] == 2:
        average = 'binary'
    elif preds.shape[1] == 3:
        average = 'macro'
    else:
        raise NotImplementedError()

    if preds.ndim == 2:
        logits, labels = p

        predictions = np.argmax(preds, axis=1)

        accuracy = accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"]
        precision = precision_metric.compute(predictions=predictions, references=labels, average=average)["precision"]
        recall = recall_metric.compute(predictions=predictions, references=labels, average=average)["recall"]
        f1 = f1_metric.compute(predictions=predictions, references=labels, average=average)["f1"]

        return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
    else:
        raise NotImplementedError()

Downloading builder script:   0%|          | 0.00/6.34k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.55k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

### **Inference**

In [None]:
from typing import List

def zsl_classifier(model: nn.Module,
                   premise: str,
                   candidate_labels: List,
                   hypothesis_template: str = 'Đây là một câu nói có nội dung về {}',
                   multi_label: bool = False):

    sequence_pairs = []
    sequence_pairs.extend([[premise, hypothesis_template.format(label)] for label in candidate_labels])

    model_outputs = []

    for sequence_pair, candidate_label in zip(sequence_pairs, candidate_labels):
        inputs_nli = tokenizer(sequence_pair[0],
                               sequence_pair[1],
                                padding='max_length',
                                max_length=512,
                                truncation=True,
                                return_tensors='pt')

        with torch.no_grad():
            outputs=load_mtl_model(**inputs_nli, task_ids=zsl_task_id)[0]
            model_outputs.append({'candidate_label': candidate_label,
                                  'sequence': premise,
                                  'logits': outputs})

    result = postprocess_nli(model_outputs, multi_label=multi_label)

    return result

result = zsl_classifier(load_mtl_model,
                        premise='nhiệt tình giảng dạy gần với sinh viên',
                        candidate_labels = ['nấu ăn', 'nhảy múa', 'giáo dục'],
                        multi_label=True)

result

{'sequence': 'nhiệt tình giảng dạy gần với sinh viên',
 'labels': ['giáo dục', 'nhảy múa', 'nấu ăn'],
 'scores': [0.7753003239631653, 0.6339335441589355, 0.4622208774089813]}

In [None]:
# Ex: "nhiệt tình giảng dạy gần gũi với sinh viên"
# Ex: "Tôi không thích đi du lịch lắm"

premise = input('Enter sentence: ')
logger.info(f'Sentence: {premise}')
candidate_labels = ['nấu ăn', 'nhảy múa', 'giáo dục']
hypothesises = [f'Đây là một câu nói có nội dung về {topic}' for topic in candidate_labels]



# Inputs for sentiment analysis
inputs_sa = tokenizer(premise,
                      padding='max_length',
                      max_length=512,
                      truncation=True,
                      return_tensors='pt')

with torch.no_grad():
    outputs_sa=load_mtl_model(**inputs_sa, task_ids=sa_task_id)[0]

probs_sa = torch.nn.functional.softmax(outputs_sa,dim=-1)
print(probs_sa)
pred_label_sa = torch.argmax(probs_sa,dim=-1).item()
print(pred_label_sa)
if pred_label_sa == 0:
    label_sa = 'negative'
else:
    label_sa = 'positive'

logger.success(f'Sentiment Analysis: {label_sa} - Probability: {float(probs_sa[0][pred_label_sa]):.3%}')

# model_outputs
# [{'candidate_label': 'du lịch', 'sequence': 'Tôi không thích đi d...ế giới lắm', 'logits': tensor([[-2.0524, -0... 2.3854]])},
# {'candidate_label': 'nấu ăn', 'sequence': 'Tôi không thích đi d...ế giới lắm', 'logits': tensor([[ 0.0509, -0... 0.2263]])},
# {'candidate_label': 'nhảy múa', 'sequence': 'Tôi không thích đi d...ế giới lắm', 'logits': tensor([[ 0.0869, -0...-0.0252]])}]

model_outputs = []

for hypothesis, candidate_label in zip(hypothesises, candidate_labels):

    # Inputs for nli
    inputs_nli = tokenizer(premise, hypothesis,
                           padding='max_length',
                           max_length=512,
                           truncation=True,
                           return_tensors='pt')

    with torch.no_grad():
        outputs=load_mtl_model(**inputs_nli, task_ids=zsl_task_id)[0]
        model_outputs.append({'candidate_label': candidate_label,
                              'sequence': premise,
                              'logits': outputs})

processed = postprocess_nli(model_outputs)

logger.success(f'Topic Identification -- Labels: {processed["labels"]} --> Scores: {processed["scores"]}')

Enter sentence: Tôi không thích đi làm đâu


[32m2023-08-16 15:51:18.293[0m | [1mINFO    [0m | [36m__main__[0m:[36m<cell line: 5>[0m:[36m5[0m - [1mSentence: Tôi không thích đi làm đâu[0m
[32m2023-08-16 15:51:20.282[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36m<cell line: 30>[0m:[36m30[0m - [32m[1mSentiment Analysis: negative - Probability: 99.974%[0m


tensor([[9.9974e-01, 2.6322e-04]])
0


[32m2023-08-16 15:51:27.908[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36m<cell line: 56>[0m:[36m56[0m - [32m[1mTopic Identification -- Labels: ['nhảy múa', 'giáo dục', 'nấu ăn'] --> Scores: [0.5495253205299377, 0.2613559067249298, 0.18911874294281006][0m
