In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, EarlyStoppingCallback, DataCollatorWithPadding, Trainer, TrainingArguments
from datasets import Dataset
import pandas as pd
from sklearn.metrics import f1_score
import torch
from tqdm import tqdm
from torch.nn.functional import softmax
import numpy as np

tqdm.pandas()

import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

ModuleNotFoundError: No module named 'transformers'

In [None]:
device = "mps"

In [None]:
def tokenize(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        max_length=128,
        truncation=True,
    )

def predict(texts, model, tokenizer):
    model.eval()
    predictions = []
    probs = []
    with torch.no_grad():
        for text in tqdm(texts):
            inputs = tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=128,
                return_token_type_ids=False,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            inputs = {
                key: value.to(device) for key, value in inputs.items()
            }
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )
            logit = outputs.logits
            prob = (
                softmax(logit, dim=1).cpu().numpy()
            ) 
            predictions.append(np.argmax(prob))
            probs.append(prob)
    if device == "cuda":
        torch.cuda.empty_cache()
    return predictions, probs

In [None]:
tokenizer = AutoTokenizer.from_pretrained("path_to_model_or_model_id")
model = AutoModelForSequenceClassification.from_pretrained("path_to_model_or_model_id", num_labels = 2).to(device)

In [None]:
train, valid, test = pd.read_csv("path_to_train_data"), pd.read_csv("path_to_validation_data"), pd.read_csv("path_to_test_data")

train_dataset = Dataset.from_dict(
    {
        "text": list(train["text"]),
        "label": list(train["locale"]),
    }
)

valid_dataset = Dataset.from_dict(
    {
        "text": list(valid["text"]),
        "label": list(valid["locale"]),
    }
)

train_dataset = train_dataset.map(tokenize, batched=True)
valid_dataset = valid_dataset.map(tokenize, batched=True)

Map:   0%|          | 0/52746 [00:00<?, ? examples/s]

Map:   0%|          | 0/17582 [00:00<?, ? examples/s]

In [7]:
datacollator = DataCollatorWithPadding(tokenizer=tokenizer)
callback = EarlyStoppingCallback(4, 0.01)

In [None]:
def compute_metrics(p):
    preds = torch.tensor(p.predictions)
    
    probs = softmax(preds, dim=1).numpy()
    predicted_classes = np.argmax(probs, axis=1)
    
    return {
        "f1_score": f1_score(p.label_ids, predicted_classes, average="weighted")
    }

train_args = TrainingArguments(
    num_train_epochs=30,
    learning_rate=3e-5,
    eval_strategy="steps",
    logging_steps=400,
    save_steps=400,
    output_dir="path_to_output",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    load_best_model_at_end=True
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    callbacks=[callback],
    data_collator=datacollator,
    compute_metrics=compute_metrics
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [9]:
trainer.train()

Step,Training Loss,Validation Loss,F1 Score
400,0.2031,0.134468,0.942235
800,0.1374,0.113566,0.951723
1200,0.1209,0.110407,0.954773
1600,0.1149,0.09961,0.956991
2000,0.0707,0.102611,0.959673
2400,0.0618,0.122075,0.959924
2800,0.0647,0.119699,0.960179
3200,0.0641,0.119592,0.956104


TrainOutput(global_step=3200, training_loss=0.10467703104019165, metrics={'train_runtime': 404.2132, 'train_samples_per_second': 391.472, 'train_steps_per_second': 12.239, 'total_flos': 3390436834900992.0, 'train_loss': 0.10467703104019165, 'epoch': 1.9405700424499697})

In [10]:
outs, probs = list(predict(list(test["text"]), model, tokenizer))

100%|██████████| 17583/17583 [00:58<00:00, 298.65it/s]


In [12]:
f1_score(outs, list(test["locale"]), average="macro")

np.float64(0.9533011502662643)

In [None]:
locs = {
    # Inner-circle vs. Outer-circle
    "en-AU": 0,
    "en-UK": 0,
    "en-IN": 1,
}
for locale in ["AU", "IN", "UK"]:
    for dom in ["reddit", "google"]:
        for split in ["train", "valid", "test"]:
            temp = pd.read_json("path_to_var_data", lines=True)

            temp["locale"] = [locs["en-"+locale]]*len(temp)
        
            preds, probs= predict(list(temp["text"]), model, tokenizer)  
    
            probs = [prob[0][locs["en-"+locale]] for prob in probs]
            outs = pd.DataFrame({
                "loc": preds,
                "probs": probs
                }
            )
            outs.to_csv("path_to_results")

            print(f'Loc:{locale} | Dom:{dom} | Split:{split} | F1:{f1_score(list(outs["loc"]), list(temp["locale"]), average="weighted")}')
        
            

100%|██████████| 1763/1763 [00:06<00:00, 287.74it/s]


Loc:AU | Dom:reddit | Split:train | F1:0.9199178753393411


100%|██████████| 241/241 [00:00<00:00, 289.85it/s]


Loc:AU | Dom:reddit | Split:valid | F1:0.9320682577018967


100%|██████████| 501/501 [00:01<00:00, 290.07it/s]


Loc:AU | Dom:reddit | Split:test | F1:0.9346246282944315


100%|██████████| 946/946 [00:03<00:00, 287.62it/s]


Loc:AU | Dom:google | Split:train | F1:0.9873329652621609


100%|██████████| 130/130 [00:00<00:00, 288.64it/s]


Loc:AU | Dom:google | Split:valid | F1:1.0


100%|██████████| 270/270 [00:00<00:00, 287.42it/s]


Loc:AU | Dom:google | Split:test | F1:0.9723087573554863


100%|██████████| 1685/1685 [00:05<00:00, 294.92it/s]


Loc:IN | Dom:reddit | Split:train | F1:0.6898577897862456


100%|██████████| 230/230 [00:00<00:00, 294.33it/s]


Loc:IN | Dom:reddit | Split:valid | F1:0.6516573396470081


100%|██████████| 479/479 [00:01<00:00, 294.18it/s]


Loc:IN | Dom:reddit | Split:test | F1:0.6647467882754161


100%|██████████| 1648/1648 [00:05<00:00, 290.71it/s]


Loc:IN | Dom:google | Split:train | F1:0.9036984689481744


100%|██████████| 225/225 [00:00<00:00, 290.96it/s]


Loc:IN | Dom:google | Split:valid | F1:0.9011494252873563


100%|██████████| 469/469 [00:01<00:00, 290.22it/s]


Loc:IN | Dom:google | Split:test | F1:0.9176345715033853


100%|██████████| 1007/1007 [00:03<00:00, 292.83it/s]


Loc:UK | Dom:reddit | Split:train | F1:0.9144587072496614


100%|██████████| 138/138 [00:00<00:00, 291.14it/s]


Loc:UK | Dom:reddit | Split:valid | F1:0.8820891441071916


100%|██████████| 287/287 [00:00<00:00, 292.08it/s]


Loc:UK | Dom:reddit | Split:test | F1:0.9120537473648652


100%|██████████| 1817/1817 [00:06<00:00, 285.85it/s]


Loc:UK | Dom:google | Split:train | F1:0.9925763194352191


100%|██████████| 248/248 [00:00<00:00, 285.07it/s]


Loc:UK | Dom:google | Split:valid | F1:0.9939556858911698


100%|██████████| 517/517 [00:01<00:00, 286.20it/s]

Loc:UK | Dom:google | Split:test | F1:0.9855167267238478





In [None]:
metric = {
    "loc":[],
    "dom": [],
    "split": [],
    "prob": []
         }

for dom in ["reddit", "google"]:
    for loc in ["au", "in", "uk"]:
        probs = []
        for split in ["train", "test", "valid"]:
            pred = pd.read_csv("path_to_results")

            metric["loc"] = metric["loc"] + [f"en-{loc.upper()}"]*len(pred)
            metric["dom"] = metric["dom"] + [dom]*len(pred)
            metric["split"] = metric["split"] + [split]*len(pred)
            metric["prob"] = metric["prob"] + list(pred.probs)

            probs.extend(list(pred.probs))
        
pd.DataFrame(metric)[["loc", "dom", "prob"]].groupby(["loc", "dom"]).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,prob
loc,dom,Unnamed: 2_level_1
en-AU,google,0.990104
en-AU,reddit,0.94491
en-IN,google,0.938977
en-IN,reddit,0.775043
en-UK,google,0.992911
en-UK,reddit,0.933627
