In [None]:
import gc
import os
from dotenv import load_dotenv
from random import choices
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm.notebook import tqdm_notebook
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import get_linear_schedule_with_warmup
from datasets import Dataset, concatenate_datasets
import evaluate

from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.utils import resample
from sklearn.utils.class_weight import compute_class_weight

# Enable tqdm's pandas integration (e.g. `progress_apply`) with notebook progress bars.
tqdm_notebook.pandas()

In [None]:
%load_ext autotime
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")

In [None]:
# Pick the best available accelerator instead of hard-coding one:
# Apple-silicon MPS first, then CUDA, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# The tokenizer and the classification backbone are loaded from the same checkpoint.
MODEL_NAME = "bert-base-multilingual-cased"

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)

pretrained = BertForSequenceClassification.from_pretrained(MODEL_NAME, token=HF_TOKEN)
pretrained = pretrained.to(device)

In [None]:
# Parameter summary of the backbone as loaded (the partial freeze happens in a later cell).
summary(pretrained)

In [None]:
# Freeze only the first six encoder layers; the rest of the model is left untouched.
# (Freezing the whole backbone would instead switch off `requires_grad` on every
# parameter of `pretrained`.)
for frozen_layer in pretrained.bert.encoder.layer[:6]:
    frozen_layer.requires_grad_(False)

In [None]:
# Summary again after the freeze, so the trainable / non-trainable split can be
# compared with the earlier output.
summary(pretrained)

In [None]:
# Load the training articles, keep only the English and Korean rows, and derive
# the text / stratification columns in a single chain. Building a fresh frame
# with `.loc` + `.assign` avoids writing new columns into a filtered slice
# (the pattern that triggers pandas' SettingWithCopyWarning).
data_df = (
    pd.read_parquet("dataset/train_df.parquet")
    .loc[lambda df: df["language"].isin(["English", "Korean"])]
    .assign(
        # Model input: title and body joined by a visible separator.
        sentence=lambda df: df["title"] + " || " + df["content"],
        # Label + language key, used to stratify the train/validation split.
        stratified_col=lambda df: df["impact_length_idx"].astype(str) + df["language"],
        # Plain-string copy of the same key, used for resampling
        # (`stratified_col` itself gets class-encoded before the split).
        resample_col=lambda df: df["stratified_col"],
    )
)

# preserve_index=True carries the dataframe index along as `__index_level_0__`,
# which is used later to look training rows up in `data_df`.
esg_dataset = Dataset.from_pandas(data_df, preserve_index=True)

In [None]:
# Fixed seed so the train/validation partition is identical on every run
# (the split was previously re-drawn at random on each execution).
SPLIT_SEED = 42

# Encode the label column as a ClassLabel feature.
esg_dataset = esg_dataset.class_encode_column("impact_length_idx")

# The stratification key is class-encoded on a temporary copy only for the
# split; 25% of the rows go to validation, stratified on label + language.
train_valid = (
    esg_dataset
    .class_encode_column("stratified_col")
    .train_test_split(test_size=0.25,
                      stratify_by_column="stratified_col",
                      seed=SPLIT_SEED)
)
train_dataset = train_valid["train"]
valid_dataset = train_valid["test"]

In [None]:
# Original dataframe index labels of the rows that ended up in the training split.
train_idx = sorted(train_dataset['__index_level_0__'])

# Size of every label/language group inside the training split.
class_counts = data_df.loc[train_idx, 'resample_col'].value_counts()

# Largest group and its size: the target that the other groups are upsampled to.
majority_count = class_counts.max()
majority_class = class_counts.idxmax()

def upsample(dataset, majority_class, majority_count, seed=42):
    """Upsample every class in ``resample_col`` to ``majority_count`` rows.

    Rows are grouped by their ``resample_col`` value. The majority class, and
    any class that already has at least ``majority_count`` rows, is kept
    exactly as it is; every smaller class is replaced by ``majority_count``
    rows drawn from it with replacement.

    Args:
        dataset: object exposing ``dataset["resample_col"]`` (iterable of
            labels, one per row) and ``dataset.select(indices)``.
        majority_class: label of the class that must never be resampled.
        majority_count: number of rows each smaller class is grown to.
        seed: seed for the sampling RNG, so repeated runs give the same
            result. Pass ``None`` for non-deterministic sampling.

    Returns:
        ``dataset.select(...)`` over the chosen row positions, grouped by class
        in order of first appearance. An empty input yields an empty selection.
    """
    # Local imports keep this cell self-contained.
    from collections import defaultdict
    from random import Random

    rng = Random(seed)

    # Single pass over the label column: class label -> row positions.
    positions_by_class = defaultdict(list)
    for position, label in enumerate(dataset["resample_col"]):
        positions_by_class[label].append(position)

    selected = []
    for label, positions in positions_by_class.items():
        if label == majority_class or len(positions) >= majority_count:
            # Already large enough: keep every original row untouched rather
            # than resampling it (which would drop unique examples).
            selected.extend(positions)
        else:
            # Draw with replacement until the class matches the majority size.
            selected.extend(rng.choices(positions, k=majority_count))

    # One select over the combined positions replaces per-class
    # filter + select + concatenate.
    return dataset.select(selected)

# Balanced training set: every label/language group is brought to the majority size.
resampled_train = upsample(dataset=train_dataset,
                           majority_class=majority_class,
                           majority_count=majority_count)

In [None]:
def collate_fn(data):
    """Turn a list of dataset rows into model-ready tensors on ``device``.

    Args:
        data: list of rows (dicts) that contain at least ``"sentence"``
            (the text to encode) and ``"impact_length_idx"`` (the class label).

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, labels)``. Every
        sentence is truncated / padded to exactly 400 tokens, so the three
        encoder inputs are ``(batch, 400)`` and ``labels`` is ``(batch,)``.
    """
    sentences = [row["sentence"] for row in data]
    labels = torch.tensor([row["impact_length_idx"] for row in data]).to(device)

    # Call the tokenizer directly: this is the supported batch API
    # (`batch_encode_plus` is deprecated). The encoding gets its own name
    # instead of overwriting the `data` argument, and the unused
    # `return_length` output is no longer requested.
    encoded = tokenizer(sentences,
                        truncation=True,
                        padding='max_length',
                        max_length=400,
                        return_tensors='pt')

    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    token_type_ids = encoded['token_type_ids'].to(device)

    return input_ids, attention_mask, token_type_ids, labels

# Training batches: reshuffled every epoch; the trailing partial batch is
# dropped so every optimisation step sees exactly 16 examples.
train_loader = DataLoader(dataset=resampled_train,
                          batch_size=16,
                          collate_fn=collate_fn,
                          shuffle=True,
                          drop_last=True)

# Validation batches: fixed order and NO drop_last, so metrics are computed on
# every validation example (dropping the last partial batch silently excluded
# up to 15 rows from each evaluation).
loader_valid = DataLoader(dataset=valid_dataset,
                          batch_size=16,
                          collate_fn=collate_fn,
                          shuffle=False,
                          drop_last=False)

In [None]:
class Model(torch.nn.Module):
    """BERT encoder followed by a linear classification head on the first token.

    Args:
        backbone: encoder module that returns ``hidden_states`` when called with
            ``output_hidden_states=True``. Defaults to the notebook-level
            ``pretrained`` model.
        num_labels: number of output classes (default 3).
    """

    def __init__(self, backbone=None, num_labels=3):
        super().__init__()
        # Register the encoder as a sub-module. Otherwise `model.parameters()`
        # only contains the linear head (so the optimizer never updates the
        # unfrozen BERT layers) and `model.train()` / `model.eval()` never
        # reach the encoder's dropout layers.
        self.backbone = pretrained if backbone is None else backbone
        # Head input size follows the encoder's hidden size (768 for BERT-base).
        self.fc = torch.nn.Linear(self.backbone.config.hidden_size, num_labels).to(device)
        print(summary(self.fc))

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return unnormalised class logits of shape ``(batch, num_labels)``."""
        out = self.backbone(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            output_hidden_states=True)

        # Last layer's hidden states, first token of every sequence.
        first_token = out["hidden_states"][-1][:, 0]

        # No softmax here: torch.nn.CrossEntropyLoss applies log-softmax itself,
        # so feeding it probabilities would squash the gradients. argmax over
        # logits gives the same predictions as argmax over probabilities.
        return self.fc(first_token)

# Build the classifier; the constructor also prints a summary of the linear head.
model = Model()


In [None]:
# NOTE(review): the original note here read "5e^4, 0.01" - presumably an
# earlier learning-rate / weight-decay setting (5e-4, 0.01); confirm.


def _release_cached_memory():
    """Run the garbage collector and release cached accelerator memory."""
    gc.collect()
    device_name = str(device)
    if device_name.startswith("mps"):
        torch.mps.empty_cache()
    elif device_name.startswith("cuda"):
        torch.cuda.empty_cache()


def _validation_accuracy(loader):
    """Return the accuracy of `model` over every batch of `loader` (no gradients)."""
    correct = 0
    total = 0
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in loader:
            preds = model(input_ids=input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    # Guard against an empty loader instead of dividing by zero.
    return correct / total if total else float("nan")


criterion = torch.nn.CrossEntropyLoss()

num_epochs = 3

optimizer = torch.optim.AdamW(model.parameters(),
                              lr=2e-5,
                              weight_decay=0.01)

# Total number of optimisation steps across all epochs.
num_training_steps = num_epochs * len(train_loader)

# Linear warm-up over the first 10% of the steps, then linear decay to zero.
# The scheduler documents an integer step count, hence the int().
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=int(0.1 * num_training_steps),
                                            num_training_steps=num_training_steps)

model.train()
for epoch in range(num_epochs):
    print(f"Starting epoch {epoch+1}/{num_epochs}")
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
        optimizer.zero_grad()

        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids)
        _release_cached_memory()

        loss = criterion(out, labels)
        loss.backward()
        _release_cached_memory()

        optimizer.step()
        scheduler.step()

        # Periodic progress report on the current training batch.
        if i % 25 == 0:
            preds = out.argmax(dim=1)
            accuracy = (preds == labels).sum().item() / len(labels)

            print(f"Epoch {epoch+1}/{num_epochs}, Step {i}, Loss: {loss.item()}, Accuracy: {accuracy}")

        _release_cached_memory()

    # End-of-epoch validation on the held-out split. Previously this re-used the
    # last *training* batch, so the reported number was not a validation score.
    model.eval()
    accuracy = _validation_accuracy(loader_valid)
    print("epoch valid accuracy:", accuracy)

    _release_cached_memory()

    model.train()


In [None]:
def test(loader=None, num_classes=3):
    """Evaluate ``model`` on a loader, print metrics and plot a confusion matrix.

    Args:
        loader: iterable of ``(input_ids, attention_mask, token_type_ids,
            labels)`` batches. Defaults to the notebook-level ``loader_valid``.
        num_classes: number of classes; fixes the confusion-matrix size and
            the tick labels (default 3).

    Prints accuracy and macro-averaged precision / recall / F1, then draws the
    confusion matrix. Returns nothing. The model is left in eval mode.
    """
    if loader is None:
        loader = loader_valid

    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in loader:
            preds = model(input_ids=input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids).argmax(dim=1)

            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    # Nothing to score: report it instead of dividing by zero.
    if not all_labels:
        print("No samples to evaluate.")
        return

    correct = sum(int(p == t) for p, t in zip(all_preds, all_labels))
    accuracy = correct / len(all_labels)
    # zero_division=0 yields the same value as the default but without the
    # warning when a class is never predicted.
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    print(f'Accuracy: {accuracy}')
    print(f'Precision: {precision}')
    print(f'Recall: {recall}')
    print(f'F1 Score: {f1}')

    # Pin the label set so the matrix is always num_classes x num_classes and
    # lines up with the tick labels even when a class is missing from the data.
    class_ids = list(range(num_classes))
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    plt.figure(figsize=(3,3))
    sns.heatmap(cm, annot=True, fmt="d", linewidths=.5,
                square = True, cmap = "Blues")
    plt.ylabel("Actual label")
    plt.xlabel("Predicted label")
    # Heatmap cells are centred on half-integers.
    plt.xticks(ticks=np.arange(num_classes) + 0.5, labels=class_ids, rotation=45, ha='right')
    plt.yticks(ticks=np.arange(num_classes) + 0.5, labels=class_ids, rotation=0)

    all_sample_title = "Confusion Matrix"
    plt.title(all_sample_title, size = 15)
    plt.show()

# Final evaluation on the validation loader: prints metrics and draws the confusion matrix.
test()
