In [None]:
#!pip3 install --user datasets

In [None]:
#!pip3 install --user transformers -U

In [63]:
import os

import torch
import torch.nn as nn

from datasets import load_dataset

from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import get_linear_schedule_with_warmup

In [None]:
config = {
    "BATCH_SIZE": 32,
    "NUM_LABEL": 3,
    "NUM_EPOCH": 10,
    # Class index -> name. Must match the label ids of the Hugging Face `snli`
    # dataset (0 = entailment, 1 = neutral, 2 = contradiction); it is used to
    # print human-readable predictions.
    "CATEGORY": {
        0: "entailment",
        1: "neutral",
        2: "contradiction"
    },
    # Seed for torch's RNGs (classifier-head initialisation, DataLoader shuffling),
    # so that Restart & Run All reproduces the same run.
    "SEED": 42,
}
torch.manual_seed(config["SEED"])
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load SNLI and keep only the examples whose label is not -1
# (presumably the pairs without a gold label).
snli = load_dataset("snli").filter(lambda row: row["label"] != -1)

In [None]:
# Tokenizer for the same cased DistilBERT checkpoint the model cells load.
tokenizer = DistilBertTokenizerFast.from_pretrained(pretrained_model_name_or_path="distilbert-base-cased")

In [None]:
def preprocess(dataset, tokenizer, max_length=None):
    """Tokenize every split of ``dataset`` and return ``(train, validation, test)``.

    Each example is encoded as the pair (hypothesis, premise), truncated and padded
    to ``max_length`` tokens. The splits are formatted as torch tensors restricted
    to the columns the model consumes, and ``label`` is renamed to ``labels``
    because that is the keyword argument the model takes.

    Args:
        dataset: a ``DatasetDict`` with "train", "validation" and "test" splits whose
            rows have "hypothesis", "premise" and "label" fields.
        tokenizer: a Hugging Face tokenizer.
        max_length: length every sequence is padded/truncated to. ``None`` keeps the
            tokenizer's default (its model maximum length). A smaller value
            (e.g. 128) greatly reduces memory and compute for short sentence pairs.

    Returns:
        Tuple ``(train, validation, test)`` of tokenized splits.
    """

    def encode(batch):
        # Pad every example to the same length so the default DataLoader collate
        # function can stack them into batch tensors.
        return tokenizer(batch["hypothesis"], batch["premise"], truncation=True,
                         padding='max_length', max_length=max_length)

    encoded = dataset.map(encode, batched=True, batch_size=config["BATCH_SIZE"], num_proc=os.cpu_count())
    encoded.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])  # the model needs tensors as inputs
    encoded = encoded.rename_column("label", "labels")  # the model takes `labels`, not `label`

    return encoded["train"], encoded["validation"], encoded["test"]

In [None]:
# Tokenize the three SNLI splits in one pass.
train_data, validation_data, test_data = preprocess(snli, tokenizer)

In [None]:
# Shuffle the training split so that each epoch visits the examples in a new order.
# This matters all the more because the `train` loop in this notebook stops after a
# fixed number of batches: without shuffling, every epoch would see exactly the same
# leading slice of the training set.
train_data_loader = torch.utils.data.DataLoader(dataset=train_data,
                                                batch_size=config["BATCH_SIZE"],
                                                shuffle=True)
# Evaluation order does not change the averaged loss, so keep it deterministic.
validation_data_loader = torch.utils.data.DataLoader(dataset=validation_data,
                                                     batch_size=config["BATCH_SIZE"],
                                                     shuffle=False)
test_data_loader = torch.utils.data.DataLoader(dataset=test_data,
                                               batch_size=config["BATCH_SIZE"],
                                               shuffle=False)

In [None]:
def train(data_loader, model, optimizer, scheduler, device, verbose=False, max_batches=1001):
    """Run one training epoch and return the mean per-example training loss.

    Args:
        data_loader: iterable of dict batches with "input_ids", "attention_mask"
            and "labels" tensors.
        model: sequence-classification model whose output has a ``.loss``
            attribute when ``labels`` are passed.
        optimizer: optimizer over the model's trainable parameters.
        scheduler: accepted for interface compatibility only; it is NOT stepped
            here (``train_test`` steps it once per epoch).
        device: device the batch tensors are moved to.
        verbose: print progress every 100 batches.
        max_batches: stop the epoch after this many batches. The default (1001)
            reproduces the previous hard-coded cut-off after batch index 1000;
            ``None`` iterates over the whole ``data_loader``.

    Returns:
        Loss averaged over the examples seen, or ``float("nan")`` if no batch
        was processed.
    """
    from itertools import islice  # local import keeps this cell self-contained

    model.train()
    total_number, total_loss = 0, 0.0
    # islice(..., None) yields every batch; otherwise it stops after `max_batches`
    # batches without pulling an extra batch from the loader.
    for i, data in enumerate(islice(data_loader, max_batches)):
        if verbose and i % 100 == 0:
            print(f'\t[Training] Progress: {i}/{len(data_loader)}')
        input_ids = data["input_ids"].to(device)
        labels = data["labels"].to(device)
        attention_mask = data["attention_mask"].to(device)

        # Clear gradients before the forward pass so no leftover gradient from an
        # earlier call leaks into this update.
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Assuming `loss` is the batch mean, weight it by the batch size so the
        # result is a per-example average even when the last batch is smaller.
        batch_size = input_ids.shape[0]
        total_number += batch_size
        total_loss += batch_size * loss.item()

    if total_number == 0:
        return float("nan")
    return total_loss / total_number

In [None]:
def test(data_loader, model, device, verbose=False):
    model.eval()
    total_number, total_loss = 0, 0.0
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            if verbose and i%100==0:
                print(f'\t[Testing] Progress: {i}/{len(data_loader)}')
            inputs = {
                "input_ids": data["input_ids"].to(device),
                "labels": data["labels"].to(device),
                "attention_mask": data["attention_mask"].to(device)
            }
            outputs = model(inputs["input_ids"], inputs["attention_mask"], labels=inputs["labels"])
            loss = outputs.loss
            total_number += inputs["input_ids"].shape[0]
            total_loss += inputs["input_ids"].shape[0] * loss.item()

    return total_loss/total_number

In [None]:
def train_test(model, train_data_loader, validation_data_loader, test_data_loader, num_epoch, optimizer, scheduler, device, verbose=False):
    """Train for ``num_epoch`` epochs, then report the loss on the test split.

    After each epoch the train and validation losses are printed. The scheduler is
    stepped once per epoch (after validation), not once per batch.
    """
    for epoch in range(num_epoch):
        print(f'[Epoch] Epoch number: {epoch}')
        train_loss = train(data_loader=train_data_loader, model=model, optimizer=optimizer,
                           scheduler=scheduler, device=device, verbose=verbose)
        print(f'[Train loss] Train Loss: {train_loss}')
        validation_loss = test(validation_data_loader, model=model, device=device, verbose=verbose)
        print(f'[Validation loss] Validation Loss: {validation_loss}')
        scheduler.step()

    final_test_loss = test(test_data_loader, model, device, verbose=verbose)
    print(f'[Test loss] Test Loss: {final_test_loss}')

In [None]:
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased', num_labels=config["NUM_LABEL"])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# `train_test` calls `scheduler.step()` once per epoch, so the schedule is counted in
# epochs. `num_training_steps` must therefore cover every epoch, and the warm-up must
# be shorter than that. Any warm-up makes the LR 0 at step 0, so it would waste the
# whole first epoch. Without warm-up the LR starts at 1e-4 and decays linearly over
# NUM_EPOCH epochs.
NUM_WARMUP_EPOCHS = 0
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=NUM_WARMUP_EPOCHS,
                                            num_training_steps=config["NUM_EPOCH"])

In [None]:
# Freeze the DistilBERT encoder (`model.base_model`). Only the parameters outside it
# (presumably the classification head) remain trainable.
model.base_model.requires_grad_(False)
model = model.to(device)

In [None]:
train_test(
    model=model,
    train_data_loader=train_data_loader,
    validation_data_loader=validation_data_loader,
    test_data_loader=test_data_loader,
    num_epoch=config["NUM_EPOCH"],
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    verbose=True,
)

In [None]:
def get_prediction(model, tokenizer, hypothesis, premise, device="cpu", categories=None):
    """Classify one (hypothesis, premise) pair and print the predicted category.

    Args:
        model: sequence-classification model, already placed on ``device``.
        tokenizer: tokenizer matching ``model``.
        hypothesis: first sentence of the pair (same order as used in ``preprocess``).
        premise: second sentence of the pair.
        device: device the encoded inputs are moved to (default "cpu").
        categories: mapping from class index to category name; defaults to
            ``config["CATEGORY"]``.

    Returns:
        1-D CPU tensor holding one probability per category.
    """
    if categories is None:
        categories = config["CATEGORY"]

    # A single pair needs no padding; return_tensors="pt" yields (1, seq_len) tensors.
    encoded = tokenizer(hypothesis, premise, truncation=True, return_tensors="pt")
    model.eval()
    with torch.no_grad():
        # No labels at inference time: only the logits are needed.
        outputs = model(input_ids=encoded["input_ids"].to(device),
                        attention_mask=encoded["attention_mask"].to(device))

    # Softmax over the class axis (last dim), then drop the batch dimension.
    probabilities = torch.softmax(outputs.logits, dim=-1)[0].cpu()
    # Plain int so it can be used as a dict key.
    index_max = int(torch.argmax(probabilities).item())
    prediction = probabilities[index_max].item()

    print(f'{hypothesis} and {premise} are: {categories[index_max]} with a probability of {prediction}')

    return probabilities

## Example

In [None]:
# "MODEL_PATH" is a placeholder: point it at a whole model presumably saved with
# `torch.save(model, path)`.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load trusted files.
# Recent torch versions default to weights_only=True, which may refuse a whole pickled
# module; confirm against the installed torch version.
model = torch.load("MODEL_PATH", map_location=device).to(device)
model

In [None]:
# The premise describes a scene; the hypothesis is the statement checked against it
# (expected label: entailment). Each sentence must go into the matching variable,
# because the model was trained on (hypothesis, premise)-ordered inputs.
premise = "A soccer game with multiple males playing."
hypothesis = "Some men are playing a sport."
prob = get_prediction(model, tokenizer, hypothesis, premise, device=device)

## Attention visualization (BertViz)

In [None]:
!pip3 install bertviz --user

In [None]:
from bertviz import head_view

In [None]:
# Loading the saved fine-tuned model ("./best_model_aug_V2.pt") did not work here,
# because that model does not output attentions. A fresh pretrained model configured
# with output_attentions=True is loaded instead.
# NOTE(review): this rebinds `model`, replacing any fine-tuned model from earlier cells;
# its classification head is presumably freshly initialised.
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-cased',
    num_labels=config["NUM_LABEL"],
    output_attentions=True,
)
model.to(device)

In [None]:
def viz_viz(hypothesis, premise, model, tokenizer, device):
    """Render a bertviz head view of ``model``'s attention for one sentence pair.

    Attentions are requested explicitly at call time, so a model that was not
    configured with ``output_attentions=True`` can still be visualized. The model
    is switched to eval mode and run without gradient tracking.

    Returns None. If the model output carries no attentions, a message is printed
    and nothing is rendered.
    """
    encoded = tokenizer(hypothesis, premise, return_tensors="pt", add_special_tokens=True, truncation=True)
    # Keep a plain-list copy of the ids to map them back to token strings for the plot.
    input_ids = encoded["input_ids"][0].tolist()
    encoded = encoded.to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(**encoded, output_attentions=True)

    attention = getattr(outputs, "attentions", None)
    if attention is None:
        print("Model must output attentions")
        return

    input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    head_view(attention, input_tokens)

In [None]:
# Sentence pair to visualize: fill in both strings before running the next cell.
hypothesis, premise = "", ""

In [None]:
viz_viz(hypothesis=hypothesis, premise=premise, model=model, tokenizer=tokenizer, device=device)