# Learning with few labels

This notebook examines how two techniques help when only a few labeled examples are available:
1) Semi-supervised learning
2) Label propagation

We work on the DBPedia dataset, a text classification task with 14 classes. We formulate it as a multi-label problem: the final classification layer has 14 outputs, and each example's label is one-hot encoded.

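We split the available training data (560k examples) into 100 labeled training examples, 9,900 labeled validation examples, and 550k examples that we treat as unlabeled. We report loss and accuracy, where the predicted class is the output with the highest logit. The code cells below currently run on `nlu_evaluation_data` with a smaller split (5,000 labeled training, 5,715 validation, and 15,000 unlabeled examples); the DBPedia calls are kept as comments.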
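
As a minimal sketch of the accuracy metric described above, the predicted class is the index of the highest logit, and it is compared with the position of the 1 in the one-hot label. The helper names `argmax` and `accuracy` are illustrative and not part of the notebook, and the toy data uses 3 classes instead of 14.

```python
def argmax(values):
    """Return the index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits, one_hot_labels):
    """Fraction of rows whose highest logit matches the one-hot label."""
    hits = sum(
        argmax(row) == argmax(label)
        for row, label in zip(logits, one_hot_labels)
    )
    return hits / len(logits)


# Toy example: rows 1 and 3 are predicted correctly, row 2 is not.
toy_logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 0.9], [0.0, 1.5, 0.4]]
toy_labels = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
toy_accuracy = accuracy(toy_logits, toy_labels)
```

The `compute_metrics` functions later in the notebook do the same thing with `np.argmax` on whole arrays.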

We compare four training scenarios:
1) A classifier which only uses labeled data;
2) A classifier using labeled data on top of a base model trained on unlabeled data using MLM training;
3) A classifier using labeled data, but also benefitting from label propagation on unlabeled data;
4) All together: A classifier using labeled data on top of a base model trained on unlabeled data using MLM training, and also benefitting from label propagation on unlabeled data.
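
Scenarios 3 and 4 rely on label propagation. The toy sketch below illustrates the general idea only: class scores spread iteratively over a similarity graph while the known labels stay clamped. It is an assumption about the generic technique, not the implementation used later in this notebook, and `propagate_labels` is a hypothetical helper. It also assumes every unlabeled node has at least one neighbor with non-zero weight.

```python
def propagate_labels(weights, seed_labels, num_classes, iterations=50):
    """Iteratively spread class scores over a weighted graph.

    weights[i][j] is the similarity between nodes i and j.
    seed_labels[i] is a class index for labeled nodes and None otherwise.
    """
    n = len(weights)
    # One-hot scores for labeled nodes, all zeros for unlabeled ones.
    scores = [
        [1.0 if seed_labels[i] == c else 0.0 for c in range(num_classes)]
        for i in range(n)
    ]
    for _ in range(iterations):
        updated = []
        for i in range(n):
            if seed_labels[i] is not None:
                updated.append(scores[i])  # clamp the known labels
                continue
            total = sum(weights[i])
            # Weighted average of the neighbors' current scores.
            updated.append([
                sum(weights[i][j] * scores[j][c] for j in range(n)) / total
                for c in range(num_classes)
            ])
        scores = updated
    # Final prediction: the class with the highest score per node.
    return [max(range(num_classes), key=lambda c: row[c]) for row in scores]


# Chain graph 0 - 1 - 2 - 3 with only the two end nodes labeled.
chain = [
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
]
predicted = propagate_labels(chain, [0, None, None, 1], num_classes=2)
```

In this toy graph each unlabeled node ends up with the label of the closer labeled end node.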


## Importing requirements
You can also add an extra cell to install the needed requirements:
```
!pip install torch
!pip install scikit-learn
!pip install transformers
!pip install datasets
```

In [1]:
import random
import torch
import numpy as np
from collections import defaultdict
from dataclasses import dataclass
from datasets import load_dataset, concatenate_datasets
from transformers import AutoModelForSequenceClassification, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, AutoModel
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments, EvalPrediction
from transformers.modeling_outputs import SequenceClassifierOutput
from tqdm import tqdm
from sklearn.semi_supervised import LabelPropagation

## Loading DBPedia dataset

In [None]:
# DBPedia version, kept for reference; the active line below loads nlu_evaluation_data
# dataset = load_dataset('dbpedia_14', split='train')
dataset = load_dataset('nlu_evaluation_data', split='train')
print(dataset)
print(dataset.features)

## Creating a tokenizer
We use BERT base uncased.

In [2]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

## Preprocessing the dataset
We process the dataset by tokenizing the text and one-hot encoding the labels.

In [None]:
# dataset = dataset.map(lambda examples: tokenizer(examples['content'], truncation=True, max_length=256, padding='max_length'), batched=True)
dataset = dataset.map(lambda examples: tokenizer(examples['text'], truncation=True, max_length=256, padding='max_length'), batched=True)
dataset = dataset.map(lambda examples: {'labels': [1.0 if i == examples['label'] else 0.0 for i in range(dataset.features['label'].num_classes)]}, batched=False)
dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
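
The one-hot step in the second `map` call can be checked in isolation. The sketch below mirrors that lambda with a hypothetical helper `one_hot`. The float values matter: in recent `transformers` versions the sequence classification heads appear to pick a multi-label (binary cross-entropy) loss when the labels are floating point and `num_labels > 1`. Treat that as an assumption to verify against the installed version.

```python
def one_hot(label, num_classes):
    """Encode an integer class index as a list of floats."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Class 3 out of 14 classes, as in the DBPedia setup.
encoded = one_hot(3, 14)
```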

## Creating splits

In [None]:
# trainset, valset = torch.utils.data.random_split(dataset, [550100, 9900], generator=torch.Generator().manual_seed(42))
# trainset_labeled, trainset_unlabeled = torch.utils.data.random_split(trainset, [100, 550000], generator=torch.Generator().manual_seed(42))

trainset, valset = torch.utils.data.random_split(dataset, [20000, 5715], generator=torch.Generator().manual_seed(42))
trainset_labeled, trainset_unlabeled = torch.utils.data.random_split(trainset, [5000, 15000], generator=torch.Generator().manual_seed(42))


## Creating the MLM trained model
We use the "unlabeled" data to train a model with Masked Language Modeling (MLM). We reuse this base model in some of the later experiments. The training takes several GPU hours.

In [None]:
mlm_model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')

mlm_data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

mlm_training_args = TrainingArguments(
    output_dir="./models/nlu_evaluation_data/mlm",
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=8,
    save_strategy='epoch',
    save_total_limit=5,
    prediction_loss_only=False,
    load_best_model_at_end=False,
    logging_first_step=True,
    logging_steps=100,
)

mlm_trainer = Trainer(
    model=mlm_model,
    args=mlm_training_args,
    data_collator=mlm_data_collator,
    train_dataset=trainset_unlabeled,
)

In [None]:
mlm_trainer.train()
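
To make the MLM objective concrete, here is a small sketch of the masking step on plain string tokens. It assumes the standard BERT recipe, which to our knowledge is also the default of `DataCollatorForLanguageModeling`: each token is selected with probability `mlm_probability`, and a selected token is replaced by `[MASK]` 80% of the time, by a random token 10% of the time, and left unchanged 10% of the time. The function `mask_tokens` is hypothetical. The real collator works on token ids and marks ignored positions with `-100` rather than `None`.

```python
import random

MASK_TOKEN = "[MASK]"


def mask_tokens(tokens, vocab, mlm_probability=0.15, seed=0):
    """Return (masked inputs, prediction targets) for one token sequence."""
    rng = random.Random(seed)
    inputs, labels = [], []
    for token in tokens:
        if rng.random() >= mlm_probability:
            # Not selected: keep the token and ignore it in the loss.
            inputs.append(token)
            labels.append(None)
            continue
        labels.append(token)  # the model must predict the original token
        roll = rng.random()
        if roll < 0.8:
            inputs.append(MASK_TOKEN)
        elif roll < 0.9:
            inputs.append(rng.choice(vocab))
        else:
            inputs.append(token)
    return inputs, labels


sentence = "the cat sat on the mat".split()
masked, targets = mask_tokens(sentence, vocab=sentence, mlm_probability=0.15)
```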

## Training a classifier only on labeled data

In [None]:
classifier_config = AutoConfig.from_pretrained('bert-base-uncased', num_labels=dataset.features['label'].num_classes)
classifier_model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', config=classifier_config)

# Freeze the BERT encoder so that only the classification head is trained.
# Comment out this loop to fine-tune the whole model.
for param in classifier_model.bert.parameters():
    param.requires_grad = False
    
def compute_metrics(p: EvalPrediction):
    accuracy = np.mean(np.argmax(p.predictions, axis=1) == np.argmax(p.label_ids, axis=1))
    return {'accuracy': accuracy}
    
classifier_training_args = TrainingArguments(
    output_dir='./models/nlu_evaluation_data/classifier',
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=64,
    save_strategy='epoch',
    save_total_limit=5,
    prediction_loss_only=False,
    load_best_model_at_end=False,
    logging_first_step=True,
    logging_steps=100,
    evaluation_strategy='epoch'
)

classifier_trainer = Trainer(
    model=classifier_model,
    args=classifier_training_args,
    train_dataset=trainset_labeled,
    eval_dataset=valset,
    compute_metrics=compute_metrics,
)

In [None]:
classifier_trainer.train()

## Training a classifier on labeled data on top of MLM

In [None]:
classifier_mlm_config = AutoConfig.from_pretrained('bert-base-uncased', num_labels=dataset.features['label'].num_classes)
classifier_mlm_model = AutoModelForSequenceClassification.from_pretrained('./models/nlu_evaluation_data/mlm/checkpoint-9375', config=classifier_mlm_config)

def compute_metrics(p: EvalPrediction):
    accuracy = np.mean(np.argmax(p.predictions, axis=1) == np.argmax(p.label_ids, axis=1))
    return {'accuracy': accuracy}
    
classifier_mlm_training_args = TrainingArguments(
    output_dir='./models/nlu_evaluation_data/classifier_on_mlm',
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=64,
    save_strategy='epoch',
    save_total_limit=5,
    prediction_loss_only=False,
    load_best_model_at_end=False,
    logging_first_step=True,
    logging_steps=100,
    evaluation_strategy='epoch'
)

classifier_mlm_trainer = Trainer(
    model=classifier_mlm_model,
    args=classifier_mlm_training_args,
    train_dataset=trainset_labeled,
    eval_dataset=valset,
    compute_metrics=compute_metrics,
)

In [None]:
classifier_mlm_trainer.train()
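
The `checkpoint-9375` directory loaded above is consistent with the MLM run's settings: 15,000 unlabeled examples, a batch size of 8, and 5 epochs. A small sketch of that arithmetic, assuming a single device and no gradient accumulation (`final_checkpoint_step` is a hypothetical helper):

```python
import math


def final_checkpoint_step(num_examples, per_device_batch_size, num_epochs):
    """Number of optimizer steps after the last epoch."""
    steps_per_epoch = math.ceil(num_examples / per_device_batch_size)
    return steps_per_epoch * num_epochs


# 15000 / 8 = 1875 steps per epoch, times 5 epochs.
mlm_final_step = final_checkpoint_step(15000, 8, 5)
```

If the split sizes, batch size, or number of epochs change, the checkpoint directory name changes with them, so the path passed to `from_pretrained` has to be updated.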

## A simple label agreement model

In [None]:
class BertForAgreement(torch.nn.Module):
    def __init__(self):
        super().__init__()
        bert_config = AutoConfig.from_pretrained('bert-base-uncased')
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        self.dropout = torch.nn.Dropout(0.1)
        self.linear = torch.nn.Linear(bert_config.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()
        
    def forward(self, dict_1, dict_2):
        # Target is 0 when both one-hot labels match and 1 when they differ.
        labels = (1.0 - torch.sum(dict_1['labels'] * dict_2['labels'], dim=1)).unsqueeze(1)
        # Encode both inputs with the same (shared) BERT encoder.
        bert_outputs_1 = self.bert(input_ids=dict_1['input_ids'], attention_mask=dict_1['attention_mask'], token_type_ids=dict_1['token_type_ids'])
        bert_outputs_2 = self.bert(input_ids=dict_2['input_ids'], attention_mask=dict_2['attention_mask'], token_type_ids=dict_2['token_type_ids'])
        # Element-wise squared difference of the pooled representations.
        aggregated = (bert_outputs_1.pooler_output - bert_outputs_2.pooler_output) ** 2
        logits = self.linear(self.dropout(aggregated))
        output = self.sigmoid(logits)
        # Mean squared error between the predicted and the target disagreement.
        loss = torch.mean((output - labels) ** 2)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )

@dataclass
class DataCollatorForPairing:
    def __call__(self, batch):
        # First rename the keys to *_1 (iterate over a copy of the keys,
        # because the dict is modified inside the loop)
        for k in list(batch.keys()):
            batch[k + '_1'] = batch.pop(k)
        # For each item create a positive and negative pair
        # 'input_ids', 'token_type_ids', 'attention_mask', 'labels'
        for index, label in enumerate(batch['labels_1']):
            indices_without_this = list(range(0, len(batch['labels_1'])))
            indices_without_this.remove(index) 
            labels_without_this = batch['labels_1'][indices_without_this]
            input_ids_without_this = batch['input_ids_1'][indices_without_this]
            token_type_ids_without_this = batch['token_type_ids_1'][indices_without_this]
            attention_mask_without_this = batch['attention_mask_1'][indices_without_this]
            # Negative pair
            is_other_label = torch.any(labels_without_this - label, dim=1)
            random_other_index = random.choice(torch.where(is_other_label)[0])
            
    
agreement_model = BertForAgreement()
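
The target and the pair feature used in `BertForAgreement.forward` can be illustrated without BERT. The sketch below uses plain lists and two hypothetical helpers. Note the polarity: the target is 1 when the labels differ, which is the opposite of the `match` column built further down, where 1 marks a pair with the same label.

```python
def disagreement_target(one_hot_a, one_hot_b):
    """0.0 when both one-hot labels match, 1.0 when they differ."""
    overlap = sum(a * b for a, b in zip(one_hot_a, one_hot_b))
    return 1.0 - overlap


def squared_difference(vec_a, vec_b):
    """Element-wise squared difference, the feature fed to the linear layer."""
    return [(a - b) ** 2 for a, b in zip(vec_a, vec_b)]


same = disagreement_target([0.0, 1.0, 0.0], [0.0, 1.0, 0.0])
different = disagreement_target([0.0, 1.0, 0.0], [1.0, 0.0, 0.0])
features = squared_difference([0.5, -1.0], [0.0, 1.0])
```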

## Graph Agreement Model (GAM) based training

In [None]:
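# Outline of the GAM training loop (not implemented yet).
# L is the labeled set, U is the unlabeled set.
#
# Repeat:
#   1. Train the agreement model G using L.
#   2. Train the classification model F using L and the predictions of G on U.
#   3. Extend L with the M = 200 most confident predictions of F on U.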
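
The last step of the outline above, extending L with the most confident predictions on U, can be sketched with plain lists. This is an illustration under assumptions, not the notebook's implementation: `extend_labeled_set` is a hypothetical helper, and "confidence" is taken to be the highest class probability of each prediction.

```python
def extend_labeled_set(labeled, unlabeled, probabilities, m=200):
    """Move the m most confident predictions from the unlabeled to the labeled set.

    labeled: list of (example, label) pairs
    unlabeled: list of examples
    probabilities: one list of class probabilities per unlabeled example
    """
    # Rank unlabeled examples by their highest class probability.
    ranked = sorted(
        range(len(unlabeled)),
        key=lambda i: max(probabilities[i]),
        reverse=True,
    )
    chosen = ranked[:m]
    new_labeled = list(labeled)
    for i in chosen:
        row = probabilities[i]
        # Use the predicted class as a pseudo-label.
        new_labeled.append((unlabeled[i], row.index(max(row))))
    chosen_set = set(chosen)
    remaining = [x for i, x in enumerate(unlabeled) if i not in chosen_set]
    return new_labeled, remaining


L = [("a", 0)]
U = ["b", "c", "d"]
P = [[0.6, 0.4], [0.1, 0.9], [0.55, 0.45]]
L, U = extend_labeled_set(L, U, P, m=2)
```

In the toy call, "c" and "b" move to the labeled set with pseudo-labels 1 and 0, and "d" stays unlabeled.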

In [3]:
def pair_up(batch, match):
    result = {
        'text_other': [],
        'match': [],
    }
    label_to_indices = defaultdict(list)
    for index, label in enumerate(batch['label']):
        label_to_indices[label].append(index)
    for label, text in zip(batch['label'], batch['text']):
        if match == 'positive':
            random_positive_index = random.choice(label_to_indices[label])
            result['match'].append(1)
            result['text_other'].append(batch['text'][random_positive_index])
        if match == 'negative':
            # Candidate labels: every other label that occurs in this batch
            labels_without_this = [other for other in label_to_indices if other != label]
            if not labels_without_this:
                raise ValueError('Cannot build a negative pair: the batch contains a single label.')
            random_negative_label = random.choice(labels_without_this)
            random_negative_index = random.choice(label_to_indices[random_negative_label])
            result['match'].append(0)
            result['text_other'].append(batch['text'][random_negative_index])
    return result
        

In [4]:
paired_dataset_positive = load_dataset('nlu_evaluation_data', split='train')
paired_dataset_positive = paired_dataset_positive.map(lambda examples: pair_up(examples, match='positive'), batched=True)



In [5]:
paired_dataset_negative = load_dataset('nlu_evaluation_data', split='train')
paired_dataset_negative = paired_dataset_negative.shuffle(seed=42)
paired_dataset_negative = paired_dataset_negative.map(lambda examples: pair_up(examples, match='negative'), batched=True)



In [6]:
paired_dataset = concatenate_datasets([paired_dataset_positive, paired_dataset_negative])

In [7]:
paired_dataset

Dataset({
    features: ['label', 'match', 'scenario', 'text', 'text_other'],
    num_rows: 51430
})

In [8]:
paired_dataset = paired_dataset.map(lambda examples: tokenizer(examples['text'], examples['text_other'], truncation=True, max_length=256, padding='max_length'), batched=True)
paired_dataset = paired_dataset.map(lambda examples: {'labels': examples['match']}, batched=True)
paired_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])



In [9]:
paired_dataset

Dataset({
    features: ['attention_mask', 'input_ids', 'label', 'labels', 'match', 'scenario', 'text', 'text_other', 'token_type_ids'],
    num_rows: 51430
})
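
The `token_type_ids` column comes from encoding each pair as a single sequence. For BERT-style tokenizers the layout is `[CLS] text [SEP] text_other [SEP]`, with segment id 0 for the first part and 1 for the second. The sketch below shows this layout on pre-split tokens. `pack_pair` is a hypothetical helper that ignores WordPiece splitting, truncation, and padding.

```python
def pack_pair(tokens_a, tokens_b):
    """Join two token lists the way a BERT-style tokenizer packs a text pair."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers [CLS], the first text, and the first [SEP].
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, token_type_ids


tokens, segments = pack_pair(["wake", "me", "up"], ["set", "an", "alarm"])
```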

In [10]:
paired_trainset, paired_valset = torch.utils.data.random_split(paired_dataset, [40000, 11430], generator=torch.Generator().manual_seed(42))


In [None]:
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')

# If you want to freeze BERT weights
# for param in model.bert.parameters():
#     param.requires_grad = False
    
def compute_metrics(p: EvalPrediction):
    accuracy = np.mean(np.argmax(p.predictions, axis=1) == p.label_ids)
    return {'accuracy': accuracy}
    
training_args = TrainingArguments(
    output_dir='./models/nlu_evaluation_data/test',
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    save_strategy='epoch',
    save_total_limit=5,
    prediction_loss_only=False,
    load_best_model_at_end=False,
    logging_first_step=True,
    logging_steps=1000,
    evaluation_strategy='steps'
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=paired_trainset,
    eval_dataset=paired_valset,
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

