# Article classification model

## Libraries and parameters

In [None]:
# Install Hugging Face transformers into the notebook environment (IPython `!` shell escape).
!pip install transformers

In [None]:
import os
import json
import time
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import transformers
from transformers import (
    AutoModel, 
    BertTokenizer, 
    BertForSequenceClassification,
    AdamW,
    LongformerTokenizerFast,
    LongformerModel, 
    LongformerConfig
)
from sklearn.metrics import  (
    f1_score, 
    accuracy_score, 
    multilabel_confusion_matrix, 
    confusion_matrix
)
from sklearn.model_selection import train_test_split, StratifiedKFold
from tqdm.auto import tqdm

# Disable the Hugging Face `tokenizers` library's internal parallelism.
# NOTE(review): presumably to avoid fork warnings/deadlocks when DataLoader
# worker processes (CONFIG['num_workers'] > 0) tokenize text — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
VER = 'vxl1'
DATA_PATH = '/home/jovyan/__LABELING/data'
# Per-version output directory for checkpoints, config and training history.
MDLS_PATH = f'/home/jovyan/__LABELING/models_{VER}'
CONFIG = {
    # Label Studio JSON export with the annotated articles.
    'data_file': 'project-6-at-2024-04-01-08-12-9704f322.json',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    # Backbone checkpoint; alternatives: `kazzand/ru-longformer-base-4096`,
    # `kazzand/ru-longformer-tiny-16384`.
    'bbone': 'kazzand/ru-longformer-tiny-16384',
    # Dropout probability before the classification head (falsy disables it).
    'dropout': .3,
    # How targets are built: 'all', 'explicit' or 'general'.
    'targets_type': 'general',
    'targets': [
        'target_0_explicit',
        'target_3_explicit',
        'target_3_general',
        'target_4_explicit',
        'target_4_general',
        'target_7_explicit',
        'target_7_general',
        'target_11_explicit',
        'target_11_general',
        'target_12_explicit',
        'target_12_general'
    ],
    # Human-readable (Russian) names of the SDG targets.
    'targets_description': {
        'target_0': 'ЦУР отсутствуют',
        'target_3': 'ЦУР 3 - Хорошее здоровье и благополучие',
        'target_4': 'ЦУР 4 - Качественное образование',
        'target_7': 'ЦУР 7 - Недорогостоящая и чистая энергия',
        'target_11': 'ЦУР 11 - Устойчивые города',
        'target_12': 'ЦУР 12 - Ответственное потребление и производство',
    },
    # False disables class weighting; a number scales the per-target
    # negative/positive ratio used as BCE `pos_weight`.
    'pos_weight': False,
    # Number of cross-validation folds (one model per fold).
    'folds': 5,
    # Tokens per article (4096 or 10240 used so far). NOTE(review): presumably
    # must not exceed the backbone's supported length (model-name suffix).
    'max_seq_len': 10240,
    'batch_size': 6,
    'num_workers': 4,
    # True: checkpoint / early-stop on val accuracy; False: on val loss.
    'acc': False,
    'epochs': 50,
    'lr': 2e-5,  # default `2e-5`
    # Epochs without improvement before early stopping.
    'patience': 5,
    'seed': 23
}
# makedirs also creates missing parent directories and is safe to re-run.
os.makedirs(MDLS_PATH, exist_ok=True)
with open(f'{MDLS_PATH}/config.json', 'w') as file:
    json.dump(CONFIG, file)

def seed_all(seed, deterministic=False):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: integer seed for `random`, `numpy` and `torch` (CPU and all
            CUDA devices); also exported as PYTHONHASHSEED.
        deterministic: the default False keeps cuDNN autotuning on
            (`benchmark=True`, `deterministic=False`). That is faster, but GPU
            results are not bit-reproducible. Pass True to request
            reproducible cuDNN kernels instead.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Setting this at runtime only affects child processes; the running
    # interpreter's hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = not deterministic
        torch.backends.cudnn.deterministic = deterministic

# Seed all RNGs once, before data loading and the fold split.
seed_all(CONFIG['seed'])

## Data

In [None]:
# The export contains Russian text. Decode it explicitly as UTF-8 (the JSON
# interchange encoding) instead of relying on the platform locale.
with open(f'{DATA_PATH}/{CONFIG["data_file"]}', encoding='utf-8') as file:
    data = json.load(file)

print('total records:', len(data))

In [None]:
def last_ann(anns):
    """Return the annotation with the greatest `updated_at`.

    Ties keep the one appearing last in `anns`. Raises IndexError when
    `anns` is empty.
    """
    by_time = sorted(anns, key=lambda ann: ann['updated_at'])
    return by_time[len(by_time) - 1]

def _task_to_record(task):
    """Flatten one exported task into a row dict.

    The row holds the task id (`idart`) and `article`: the title, anno and
    text fields with newlines turned into spaces, joined by single spaces.
    Every result item of the latest annotation adds a
    `<from_name>_<first choice>` key set to 1.
    """
    record = {
        'idart': task['id'],
        'article': ' '.join(
            task['data'][key].replace('\n', ' ')
            for key in ('title', 'anno', 'text')
        ),
    }
    labels = {
        item['from_name'] + '_' + item['value']['choices'][0]: 1
        for item in last_ann(task['annotations'])['result']
    }
    record.update(labels)
    return record


ds = [_task_to_record(task) for task in data]
df = pd.DataFrame(ds)
# Label columns missing from a task come out as NaN; treat them as 0.
df.fillna(0, inplace=True)
seq_len = [len(str(article).split()) for article in df['article']]
print('max sequence lenght:', max(seq_len))
display(df.head())

In [None]:
plt.style.use('ggplot')
plt.figure(figsize=(8, 4))
# Whitespace-token count per article; also kept on `df` as a `length` column.
df['length'] = df['article'].apply(lambda x: len(x.split()))
# Articles of 20000+ words are excluded so the long tail does not flatten the plot.
sns.histplot(df[df['length'] < 20000]['length'], bins=30)
plt.title('frequency of documents of a given length (< 20000 words)', fontsize=10)
plt.xlabel('length, words', fontsize=10)

In [None]:
def chk_balance(df, target_cols):
    """Print the row count, then for each column whose name contains
    'target' its positive count and share of rows (tab-separated)."""
    n_rows = len(df)
    print('total:', n_rows)
    for name in (c for c in target_cols if 'target' in c):
        positives = df[name].sum()
        share = '{:.1%}'.format(positives / n_rows)
        print(name, '\t', positives, '\t', share)

In [None]:
# Base target names (`target_N`) derived from the configured
# `target_N_explicit` / `target_N_general` columns.
target_cols = sorted({
    c.replace('_explicit', '').replace('_general', '')
    for c in CONFIG['targets']
})
targets_type = CONFIG['targets_type']
if targets_type == 'explicit':
    for col in target_cols:
        df[col] = df[col + '_explicit']
elif targets_type == 'general':
    for col in target_cols:
        if col == 'target_0':
            # `target_0` is only configured with an explicit variant.
            df[col] = df[col + '_explicit']
        else:
            # Positive when labelled either explicitly or generally
            # (vectorized row-wise max of the two 0/1 columns).
            df[col] = df[[col + '_explicit', col + '_general']].max(axis=1)
elif targets_type == 'all':
    target_cols = [col for col in df.columns if 'target' in col]
else:
    # Fail fast instead of continuing with undefined target columns.
    raise ValueError('`targets_type` parameter error')
if targets_type != 'all':
    # Articles with no positive SDG target are marked as "no SDG" (target_0).
    no_target = df[[c for c in target_cols if 'target_0' not in c]].sum(axis=1) == 0
    df.loc[no_target, 'target_0'] = 1
chk_balance(df, target_cols)

In [None]:
# Persist the resolved target column list alongside the rest of the config.
CONFIG.update({'target_cols': target_cols})
config_path = f'{MDLS_PATH}/config.json'
with open(config_path, 'w') as file:
    json.dump(CONFIG, file)

In [None]:
# Give every article a validation fold id in [0, CONFIG['folds']), stratified
# on the raw `target_0_explicit` label.
# NOTE: positional indices are used with `df.loc`, which assumes `df` still
# has its default 0..n-1 index.
skf = StratifiedKFold(CONFIG['folds'], shuffle=True, random_state=CONFIG['seed'])
df['fold'] = -1
fold_splits = skf.split(df, df['target_0_explicit'])
for fold_id, (_, holdout_idxs) in enumerate(fold_splits):
    df.loc[holdout_idxs, 'fold'] = fold_id

# Report the label balance of every train / validation split.
for fold_num in range(CONFIG['folds']):
    is_val = df['fold'].eq(fold_num).to_numpy()
    train_idxs = np.flatnonzero(~is_val)
    val_idxs = np.flatnonzero(is_val)
    df_train, df_val = df.loc[train_idxs], df.loc[val_idxs]
    print('FOLD', fold_num)
    for part in (df_train, df_val):
        chk_balance(part, target_cols)
    print('-' * 30)

## Training

In [None]:
class LongformerArticlesLabeling(torch.nn.Module):
    """Longformer backbone + optional dropout + linear multi-label head.

    The model returns one raw logit per target column; apply a sigmoid
    outside the model to get probabilities.

    Args:
        model_name: Hugging Face model id or path, loaded with `AutoModel`.
        target_cols: target column names; only their count (number of
            output logits) is used.
        dropout: dropout probability applied to the backbone features before
            the head; `None`/0 disables dropout.
    """

    def __init__(self, model_name, target_cols, dropout=None):
        super(LongformerArticlesLabeling, self).__init__()
        self.longformer = AutoModel.from_pretrained(model_name)
        # forward() checks this attribute, so it must exist even when
        # dropout is disabled.
        self.dropout = dropout
        if dropout:
            self.l2 = torch.nn.Dropout(dropout)
        # Take the feature width from the backbone config rather than
        # guessing it from the model name.
        hidden_size = self.longformer.config.hidden_size
        self.fc = torch.nn.Linear(hidden_size, len(target_cols))

    def forward(self, input_ids=None, attention_mask=None,
                global_attention_mask=None,
                token_type_ids=None, position_ids=None,
                inputs_embeds=None):
        """Return logits of shape (batch, len(target_cols)).

        Without an explicit `global_attention_mask`, global attention is put
        on the first token of every sequence only.
        """
        if global_attention_mask is None:
            if input_ids is not None:
                global_attention_mask = torch.zeros_like(input_ids)
            else:
                global_attention_mask = torch.zeros(
                    inputs_embeds.shape[:2],
                    dtype=torch.long,
                    device=inputs_embeds.device,
                )
            global_attention_mask[:, 0] = 1
        # return_dict=False yields a tuple; its second element (presumably
        # the pooled first-token output) feeds the head.
        _, features = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=False
        )
        if self.dropout:
            # Dropout is applied to the features that actually reach the head.
            features = self.l2(features)
        return self.fc(features)

In [None]:
class ArticlesDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes articles lazily.

    Args:
        df: source DataFrame; rows are addressed by position, so any index
            (not only 0..n-1) works.
        col: name of the text column.
        target_cols: names of the target columns.
        tokenizer: Hugging Face tokenizer providing `encode_plus` and
            `cls_token_id`.
        max_len: sequences are truncated / padded to this length.

    Each item is a dict with 1-D `ids`, `attention_mask` and
    `global_attention_mask` tensors, plus a float `targets` tensor with one
    entry per target column.
    """

    def __init__(self, df, col, target_cols, tokenizer, max_len):
        self.df = df
        self.max_len = max_len
        self.text = df[col]
        self.tokenizer = tokenizer
        self.targets = df[target_cols].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Positional lookup, matching the positional `self.targets[index]`
        # below, so text and labels always come from the same row.
        text = self.text.iloc[index]
        inputs = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_attention_mask=True,
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch dimension of 1; drop it.
        ids = inputs['input_ids'][0]
        attention_mask = inputs['attention_mask'][0]
        # Global attention on every CLS token, computed as one tensor
        # comparison instead of a Python loop over max_len tokens.
        global_attention_mask = (ids == self.tokenizer.cls_token_id).long()
        return {
            'ids': ids,
            'attention_mask': attention_mask,
            'global_attention_mask': global_attention_mask,
            'targets': torch.tensor(self.targets[index], dtype=torch.float)
        }

In [None]:
class ArticlesTrainer:
    """Train / validate a multi-label model with early stopping and checkpoints.

    Args:
        model: module called with `input_ids`, `attention_mask` and
            `global_attention_mask` keywords; returns one logit per target.
        device: device that batches are moved to.
        optimizer: optimizer stepped once per training batch.
        scheduler: optional scheduler stepped once per training batch
            (`None` disables it).
        criterion: loss callable `criterion(logits, targets, pos_weight=...)`,
            e.g. `binary_cross_entropy_with_logits`.
        pos_weight: per-target positive-class weights (array-like); `None`
            or an empty sequence disables weighting.
        acc_flag: if True, checkpoint / early-stop on validation accuracy
            (higher is better); otherwise on validation loss (lower is better).
    """

    def __init__(self, model, device,
                 optimizer, scheduler,
                 criterion, pos_weight,
                 acc_flag=True):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.pos_weight = pos_weight
        self.acc_flag = acc_flag
        if acc_flag:
            self.best_val_acc = 0
        else:
            self.best_val_loss = np.inf
        self.val_losses = []
        self.train_losses = []
        self.val_acc = []
        self.lastmodel = None
        # The weights never change, so convert them to a device tensor once
        # instead of on every batch.
        self._pos_weight_tensor = self._make_pos_weight(pos_weight, device)

    @staticmethod
    def _make_pos_weight(pos_weight, device):
        """Return `pos_weight` as a float tensor on `device`, or None if unset."""
        if pos_weight is None or len(pos_weight) == 0:
            return None
        return torch.as_tensor(pos_weight, dtype=torch.float, device=device)

    def fit(self, epochs, train_loader, val_loader,
            max_patience, save_mode='best', save_name='model'):
        """Train for up to `epochs` epochs with early stopping.

        After every epoch the model is validated. When the monitored metric
        (accuracy if `acc_flag`, else loss) improves, a checkpoint is saved
        and the patience counter resets. Training stops after `max_patience`
        consecutive epochs without improvement.

        Returns:
            dict with per-epoch 'train losses', 'val losses' and 'val accuracy'.
        """
        n_patience = 0
        for n_epoch in range(1, epochs + 1):
            self.info_message('EPOCH: {}', n_epoch)
            train_loss, train_time = self.train_epoch(train_loader)
            val_loss, val_acc, val_f1_mic, val_f1_mac, val_time = self.val_epoch(val_loader)
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.val_acc.append(val_acc)
            self.info_message(
                'epoch train: {} | loss: {:.2f} | time: {:.0f} sec',
                n_epoch, train_loss, train_time
            )
            self.info_message(
                'epoch val: {} | loss: {:.2f} | ' +
                'acc: {:.2f} | f1 micro: {:.2f} | f1 macro: {:.2f} | ' +
                'time: {:.0f} sec',
                n_epoch, val_loss, val_acc, val_f1_mic, val_f1_mac, val_time
            )
            if self.acc_flag:
                if self.best_val_acc < val_acc:
                    self.save_model(n_epoch, save_mode, save_name, val_loss, val_acc)
                    self.info_message(
                        'val accuracy improved {:.2f} -> {:.2f} | saved model to "{}"',
                        self.best_val_acc, val_acc, self.lastmodel
                    )
                    self.best_val_acc = val_acc
                    n_patience = 0
                else:
                    n_patience += 1
            else:
                if self.best_val_loss > val_loss:
                    self.save_model(n_epoch, save_mode, save_name, val_loss, val_acc)
                    self.info_message(
                        'val loss improved {:.2f} -> {:.2f} | saved model to "{}"',
                        self.best_val_loss, val_loss, self.lastmodel
                    )
                    self.best_val_loss = val_loss
                    n_patience = 0
                else:
                    n_patience += 1
            if n_patience >= max_patience:
                self.info_message(
                    '\nno improvement for last {} epochs',
                    n_patience
                )
                break
        history = {
            'train losses': self.train_losses,
            'val losses': self.val_losses,
            'val accuracy': self.val_acc
        }
        return history

    def train_epoch(self, train_loader):
        """Run one optimisation pass; return (mean batch loss, elapsed seconds)."""
        self.model.train()
        t = time.time()
        sum_loss = 0
        for step, batch in enumerate(train_loader, 1):
            ids = batch['ids'].to(self.device, dtype=torch.long)
            mask = batch['attention_mask'].to(self.device, dtype=torch.long)
            gamask = batch['global_attention_mask'].to(self.device, dtype=torch.long)
            targets = batch['targets'].to(self.device, dtype=torch.float)
            self.optimizer.zero_grad()
            outputs = self.model(
                input_ids=ids,
                attention_mask=mask,
                global_attention_mask=gamask
            )
            loss = self.criterion(outputs, targets, pos_weight=self._pos_weight_tensor)
            loss.backward()
            self.optimizer.step()
            if self.scheduler:
                self.scheduler.step()
            sum_loss += loss.detach().item()
            self.info_message(
                'train step {}/{} | train loss: {:.4f}           ',
                step, len(train_loader), sum_loss / step, end='\r'
            )
        return sum_loss / len(train_loader), int(time.time() - t)

    def val_epoch(self, val_loader):
        """Evaluate on `val_loader` and print per-target confusion matrices.

        Predictions are `sigmoid(logits) > 0.5`.

        Returns:
            (mean batch loss, subset accuracy, micro F1, macro F1, seconds).
        """
        self.model.eval()
        t = time.time()
        sum_loss = 0
        y_all = []
        outputs_all = []
        with torch.no_grad():
            for step, batch in enumerate(val_loader, 1):
                ids = batch['ids'].to(self.device, dtype=torch.long)
                mask = batch['attention_mask'].to(self.device, dtype=torch.long)
                # Use the same global attention as in training when the batch
                # provides it; otherwise the model falls back to its default.
                gamask = batch.get('global_attention_mask')
                if gamask is not None:
                    gamask = gamask.to(self.device, dtype=torch.long)
                targets = batch['targets'].to(self.device, dtype=torch.float)
                outputs = self.model(
                    input_ids=ids,
                    attention_mask=mask,
                    global_attention_mask=gamask
                )
                loss = self.criterion(outputs, targets, pos_weight=self._pos_weight_tensor)
                sum_loss += loss.item()
                y_all.extend(targets.cpu().numpy().tolist())
                outputs_all.extend(torch.sigmoid(outputs).cpu().numpy().tolist())
                self.info_message(
                    'val step {}/{} | val loss: {:.4f}               ',
                    step, len(val_loader), sum_loss / step, end='\r'
                )
        y_true = np.array(y_all).astype(int)
        y_pred = (np.array(outputs_all) > .5).astype(int)
        acc = accuracy_score(y_true, y_pred)
        # sklearn signature is (y_true, y_pred); each matrix is [[TN, FP], [FN, TP]].
        mcm = multilabel_confusion_matrix(y_true, y_pred)
        self.cm_print(mcm, start='\n')
        f1_mic = f1_score(y_true, y_pred, average='micro')
        f1_mac = f1_score(y_true, y_pred, average='macro')
        return sum_loss / len(val_loader), acc, f1_mic, f1_mac, int(time.time() - t)

    def save_model(self, n_epoch, save_mode, save_name, loss, acc):
        """Save model + optimizer state under MDLS_PATH; remember the path.

        With `save_mode='best'` a single file is overwritten each time.
        Otherwise a new file tagged with epoch, loss and accuracy is written.
        The stored `best_val_acc` / `best_val_loss` is the metric of the
        checkpoint being saved.
        """
        if save_mode == 'best':
            self.lastmodel = f'{MDLS_PATH}/{save_name}.pth'
        else:
            self.lastmodel = f'{MDLS_PATH}/{save_name}-e{n_epoch}-loss{loss:.3f}-acc{acc:.3f}.pth'
        dict_save = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'n_epoch': n_epoch,
        }
        if self.acc_flag:
            dict_save['best_val_acc'] = acc
        else:
            dict_save['best_val_loss'] = loss
        torch.save(dict_save, self.lastmodel)

    def display_plots(self):
        """Plot per-epoch train/val losses and validation accuracy."""
        fig, axes = plt.subplots(figsize=(16, 4), nrows=1, ncols=2)
        axes[0].set_title('training and validation losses')
        axes[0].plot(self.val_losses, label='val')
        axes[0].plot(self.train_losses, label='train')
        axes[0].set_xlabel('epochs')
        axes[0].set_ylabel('loss')
        axes[0].legend()
        axes[1].set_title('validation accuracy')
        axes[1].plot(self.val_acc, label='val')
        axes[1].set_xlabel('epochs')
        axes[1].set_ylabel('accuracy')
        axes[1].legend()
        plt.show()
        plt.close()

    @staticmethod
    def info_message(message, *args, end='\n'):
        """Print `message.format(*args)` with the given line ending."""
        print(message.format(*args), end=end)

    @staticmethod
    def cm_print(mcm, start='\n'):
        """Print each confusion matrix in `mcm`, separated by dashed lines."""
        print(start)
        for cm in mcm:
            print(cm)
            print('-' * 50)

In [None]:
def df_pos_weight(df, target_cols, pos_weight):
    """Compute per-target `pos_weight` values for BCE-with-logits.

    weight_c = pos_weight * (#negatives_c / #positives_c), counted over `df`.
    A target with no positives in `df` uses a positive count of 1. This keeps
    its weight finite; an infinite weight would make the loss inf/nan.

    Returns:
        numpy float array aligned with `target_cols`.
    """
    num_pos_samples = df[target_cols].sum()
    num_neg_samples = len(df) - num_pos_samples
    safe_pos = num_pos_samples.clip(lower=1)
    weights = pos_weight * (num_neg_samples / safe_pos).to_numpy(dtype=float)
    print('weights used:', weights)
    return weights

In [None]:
def train_art_model(df_train, df_val, target_cols,
                    device, model_name, max_seq_len,
                    epochs, save_name, patience,
                    batch_size, num_workers):
    """Train one model on a train/val split; return its checkpoint path.

    Steps:
    - Build the tokenizer and the train/val loaders.
    - Create a `LongformerArticlesLabeling` model with AdamW and
      BCE-with-logits loss; class weights come from CONFIG['pos_weight']
      when it is set.
    - Train with early stopping and plot the curves.
    - Write the history to `MDLS_PATH/history_<save_name>.json`.
    """
    print('=' * 20, f'MODEL TRAIN - {save_name}', '=' * 20)
    print('train:', df_train.shape, '| val:', df_val.shape)
    tokenizer = LongformerTokenizerFast.from_pretrained(model_name)

    def make_loader(frame, shuffle):
        # Wrap `frame` in an ArticlesDataset and a pinned-memory DataLoader.
        dataset = ArticlesDataset(
            df=frame,
            col='article',
            target_cols=target_cols,
            tokenizer=tokenizer,
            max_len=max_seq_len,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    train_loader = make_loader(df_train, shuffle=True)
    val_loader = make_loader(df_val, shuffle=False)

    model = LongformerArticlesLabeling(
        model_name=model_name,
        target_cols=target_cols,
        dropout=CONFIG['dropout'],
    )
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=1e-6)
    # An empty list means "no class weighting" for the trainer.
    pos_weight = (
        df_pos_weight(df_train, target_cols, CONFIG['pos_weight'])
        if CONFIG['pos_weight'] else []
    )
    trainer = ArticlesTrainer(
        model,
        device,
        optimizer,
        None,  # no LR scheduler
        nn.functional.binary_cross_entropy_with_logits,
        pos_weight,
        acc_flag=CONFIG['acc'],
    )
    history = trainer.fit(
        epochs,
        train_loader,
        val_loader,
        max_patience=patience,
        save_mode='best',
        save_name=save_name,
    )
    trainer.display_plots()
    with open(f'{MDLS_PATH}/history_{save_name}.json', 'w') as history_file:
        json.dump(history, history_file)
    return trainer.lastmodel

In [None]:
# Cross-validation: one model per fold, each validated on its held-out fold.
model_files = []
for fold_num in range(CONFIG['folds']):
    held_out = (df['fold'] == fold_num).to_numpy()
    # Fresh 0..n-1 indices for the per-fold frames.
    df_train = df.loc[np.flatnonzero(~held_out)].reset_index(drop=True)
    df_val = df.loc[np.flatnonzero(held_out)].reset_index(drop=True)
    checkpoint_path = train_art_model(
        df_train,
        df_val,
        target_cols,
        device=CONFIG['device'],
        model_name=CONFIG['bbone'],
        max_seq_len=CONFIG['max_seq_len'],
        epochs=CONFIG['epochs'],
        save_name=f'model_{fold_num}',
        patience=CONFIG['patience'],
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
    )
    model_files.append(checkpoint_path)
# After the loop, `df_train` / `df_val` hold the last fold's split.
print(model_files)
with open(f'{MDLS_PATH}/model_files.json', 'w') as file:
    json.dump(model_files, file)

## Inference

In [None]:
# Reload the checkpoint paths recorded by the training cell.
model_files_path = f'{MDLS_PATH}/model_files.json'
with open(model_files_path) as fh:
    model_files = json.loads(fh.read())
print("models files loaded:", model_files)

In [None]:
def infer(model_file, df, target_cols,
          model_name, max_seq_len,
          device, batch_size, num_workers):
    """Predict per-target probabilities for every row of `df` with one checkpoint.

    Args:
        model_file: checkpoint path; a dict with 'model_state_dict' as
            written by `ArticlesTrainer.save_model`.
        df: frame with an 'article' column and the `target_cols` columns
            (the dataset reads them; their values are not used for prediction).
        target_cols: target names; the output has one column per target.
        model_name: backbone id used for the tokenizer and model skeleton.
        max_seq_len: tokenization length.
        device: device to run on. GPU-trained checkpoints are remapped to it.
        batch_size, num_workers: DataLoader settings.

    Returns:
        DataFrame of sigmoid probabilities with `<target>_pred` columns, one
        row per row of `df` (same order, index 0..n-1).
    """
    print('PREDICT:', model_file, df.shape)
    tokenizer = LongformerTokenizerFast.from_pretrained(model_name)
    pred_dataset = ArticlesDataset(
        df=df,
        col='article',
        target_cols=target_cols,
        tokenizer=tokenizer,
        max_len=max_seq_len
    )
    pred_loader = torch.utils.data.DataLoader(
        pred_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True
    )
    model = LongformerArticlesLabeling(
        model_name=model_name,
        target_cols=target_cols,
        dropout=CONFIG['dropout']
    )
    model.to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    outputs_all = []
    with torch.no_grad():
        for batch in pred_loader:
            ids = batch['ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)
            gamask = batch['global_attention_mask'].to(device, dtype=torch.long)
            outputs = model(
                input_ids=ids,
                attention_mask=mask,
                global_attention_mask=gamask
            )
            outputs_all.extend(torch.sigmoid(outputs).cpu().numpy().tolist())
    # Passing the column names to the constructor also works for an empty `df`.
    return pd.DataFrame(outputs_all, columns=[c + '_pred' for c in target_cols])

In [None]:
# Average the fold models' probabilities on `df_val`.
# NOTE(review): if `df_val` is the last fold from the training cell, every
# model except that fold's own saw these rows during training. Metrics on
# this ensemble are then optimistic; a fair estimate needs each model to
# score only its own held-out fold.
fold_preds = [
    infer(
        model_file=model_file,
        df=df_val,
        target_cols=target_cols,
        model_name=CONFIG['bbone'],
        max_seq_len=CONFIG['max_seq_len'],
        device=CONFIG['device'],
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
    )
    for model_file in tqdm(model_files)
]
df_pred = pd.DataFrame()
for fold_pred in fold_preds:
    df_pred = fold_pred if not len(df_pred) else df_pred + fold_pred
df_pred /= len(model_files)
display(df_pred.head())

In [None]:
y_unique = target_cols
# Per-target confusion matrices of the averaged ensemble on `df_val`
# (threshold 0.5). sklearn takes (y_true, y_pred); each matrix is
# [[TN, FP], [FN, TP]].
y_true = df_val[target_cols].values.astype(int)
y_hat = (df_pred > .5).values
mcm = multilabel_confusion_matrix(y_true, y_hat)
for c, cm in zip(target_cols, mcm):
    # Positive count and share are both taken from the evaluated split.
    positives = df_val[c].sum()
    print(
        c, '\t',
        positives, '\t',
        '{:.1%}'.format(positives / len(df_val)),
    )
    print(cm)
    print('=' * 50)

In [None]:
# Pick one validation article as a single-document inference example.
text = df_val.at[1, 'article']
preview = text[:500]
print(preview, len(text))

In [None]:
# One-row frame: the article plus placeholder all-zero target columns
# (ArticlesDataset reads df[target_cols]).
d = {'article': [text], **dict.fromkeys(target_cols, 0)}
df_txt = pd.DataFrame(d)
df_txt

In [None]:
%%time
# Average every fold model's probabilities for the single article in `df_txt`.
preds_sum = pd.DataFrame()
for model_file in tqdm(model_files):
    single = infer(
        model_file=model_file,
        df=df_txt,
        target_cols=target_cols,
        model_name=CONFIG['bbone'],
        max_seq_len=CONFIG['max_seq_len'],
        device=CONFIG['device'],
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
    )
    preds_sum = single if not len(preds_sum) else preds_sum + single
df_pred = preds_sum / len(model_files)
df_pred