# In [None]:
import os
import warnings
warnings.simplefilter('ignore', RuntimeWarning)
import sys
import logging

# Debug switch: when True the training frame is cut down to its first 100 rows
# after loading (see the `if IS_DEBUG:` statement further down).
IS_DEBUG = False
# Environment switch: True selects the /kaggle/... paths and skips the
# local-only setup (dotenv, MLLogger) performed later in this file.
IS_KAGGLE = True

if IS_KAGGLE:
    # Extra directories appended to the import search path on Kaggle
    # (presumably where the `mykaggle` package is attached as a dataset).
    package_paths = [
        '/kaggle/input/git-mykaggle/mykaggle/'
    ]
    for pth in package_paths:
        sys.path.append(pth)

from typing import Any, Dict, Tuple, List, Optional
import gc
import copy
from pathlib import Path
from enum import Enum
import random
import math
import yaml
import pickle
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast
from transformers import (
    AutoTokenizer, PreTrainedTokenizerFast, PreTrainedModel, AutoConfig, AutoModel,
    BertModel, RobertaModel, ElectraModel, DebertaModel, AlbertModel
)
from argparse import ArgumentParser, Namespace


def fix_seed(seed: int = 1019) -> None:
    '''
    Seed every random number generator used by this script.

    Covers Python's ``random``, the ``PYTHONHASHSEED`` environment variable,
    NumPy, and PyTorch (CPU and current CUDA device), and switches cuDNN to
    deterministic mode.

    :param seed: seed value applied to all generators
    '''
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore


# Experiment configuration, parsed from the inline YAML document below into a
# nested dict. Top-level keys: name, competition, mode, seed, device,
# is_full_training, training, model, headstack, ensemble.
# Keys written with no value (e.g. `additional_data:`, `fold_only:`,
# `head_stack:`) are loaded as None by the YAML parser.
# Lines starting with `#` inside the YAML are YAML comments; in particular the
# commented-out entries in `ensemble.lb_weights` and `ensemble.models` are
# skipped, leaving 13 weights and 13 model names.
SETTINGS = yaml.safe_load('''
name: '095_ensemble13_cv_kaggle93'
competition: commonlitreadabilityprize
mode: ensemble
seed: 1019
device: cuda
is_full_training: true
training:
    trainer: class
    use_amp: true
    num_gpus: 1
    train_file: train_mod.csv
    test_file: test.csv
    additional_data:
    num_folds: 5
    folds: stratified_5fold3.pkl
    learning_rate: 0.00002
    learning_rate_output: 0.0001
    num_epochs: 6
    batch_size: 4
    test_batch_size: 4
    num_accumulations: 4
    num_workers: 4
    scheduler: LinearDecayWithWarmUp
    scheduler_epoch: 10
    batch_scheduler: true
    max_length: 256
    warmup_epochs: 0.6
    logger_verbose_step: 10
    optimizer: AdamW
    weight_decay: 0.0
    optimizer_debias: true
    use_layerwise_optim_params: true
    use_large_output_lr: true
    optim_layerwise_type: 2
    loss: histgram
    loss_reduction: 'none'
    loss_rank_margin: 0.5
    loss_huber_delta: 0.5
    loss_use_rdrop: true
    loss_rdrop_alpha: 2.0
    use_rank_and_rmse_loss: true
    loss_rmse_weight: 1.0
    loss_external_weight: 1.0
    val_check_interval: 10 # 0 or None means per epoch
    ckpt_callback_verbose: true
    use_standard_error: false
    use_numerical_features: false
    num_numerical_features: 0
    mlm_probability: 0.15
model:
    model_name: microsoft/deberta-large
    model_type: custom_head_class
    pretrained: true
    num_classes: 1
    encoder_attn_dropout_rate: 0.1
    encoder_ffn_dropout_rate: 0.1
    layer_norm_eps: 0.0000001
    dropout_rate: 0.3
    output_activation: False
    custom_head_types: ['cls', 'avg', 'max', 'attn'] # ['cls', 'attn', 'avg', 'max', 'conv']
    custom_head_ensemble: avg
    output_head_features: true
    num_use_layers_as_output: 4
    num_reinit_layers: 5
    num_reinit_aux_layers: 0
    second_output_dim: 256
    head_hidden_dim: 1024
    head_intermediate_dim: 512
    num_output_heads: 2
    use_middle_layers: [-1]
    mlm_use_pooler: true
    ckpt_from: ckpt/929_de_pt_922m05/model_{fold}.pt
    head_stack:
    aux_num_hidden_layers: 4
    aux_hidden_size: 256
headstack:
    model_name:
ensemble:
    num_permutations: 12
    fold_only:
    ensemble_type: Nelder-Mead # Nelder-Mead
    lb_weights: [
        0.464,
        0.462,
        0.461,
        0.457,
        0.456,
        # 0.454,
        0.465,
        # 0.454,
        0.456,
        0.455,
        0.455,
        0.458,
        # 0.456,
        # 0.454,
        0.461,
        0.463,
        0.464,
        # 0.459
    ]
    lb_constant_weight: 1.5
    models:
        - 680_de_ft535_head
        - 713_de_ft535_head
        - 801_de_ft795_attn
        - 839_de_ft819_heads
        - 840_de_ft819_headsw
        # - 875_de_ft843_heads
        - 881_ro_ft853_heads
        # - 884_de_ft843_heads
        - 885_de_ft843_headsw
        - 906_de_ft843_heads # 10
        - 919_de_ft891_heads
        - 920_de_aux_ft890
        # - 931_de_chaux_ft890
        # - 938_gauss_ft843_ch
        - 941_alxxl_ft924
        - 960_alxxl_ft927_heads
        - 969_el_ft945_heads
        # - 081_de_ft961_classmodel
''')


def get_logger(name: str, level: int = logging.INFO) -> logging.Logger:
    '''
    Return the logger registered under ``name``, set up to print to stdout.

    ``logging.getLogger`` hands back the same object for the same name, so a
    second call (e.g. re-running a notebook cell) must not attach a second
    handler, otherwise every message is printed once per call. The handler
    installed here is tagged and reused on later calls; only its level and the
    logger's level are refreshed.

    :param name: namespace of the logger
    :param level: logging level applied to both the logger and its stdout handler
    :return: the configured logger
    '''
    logger = logging.getLogger(name)
    marker = '_installed_by_get_logger'
    # Look for the handler a previous call attached instead of stacking a new one.
    handler = next((h for h in logger.handlers if getattr(h, marker, False)), None)
    if handler is None:
        handler = logging.StreamHandler(sys.stdout)
        setattr(handler, marker, True)
        logger.addHandler(handler)
    handler.setLevel(level)
    logger.setLevel(level)
    return logger


# Module-wide logger (INFO level, printing to stdout via get_logger).
LOGGER = get_logger(__name__)

# misc
# Local (non-Kaggle) runs need a few extras that are not available on Kaggle.
if not IS_KAGGLE:
    import dotenv
    from mykaggle.util.ml_logger import MLLogger, assert_env
    torch.multiprocessing.set_sharing_strategy('file_system')
    dotenv.load_dotenv()
    assert_env()
    # Local runs build the encoder from a checkpoint directory's config rather
    # than from the hub name (see the `pretrained` flag in get_transformers_model).
    SETTINGS['model']['pretrained'] = False
else:
    # MLLogger cannot be imported on Kaggle; keep the name defined so that
    # annotations referring to it still resolve.
    MLLogger = Any  # type: ignore

# preparing constant path
# DATADIR: where the competition CSVs live.
# CKPTDIR: checkpoint directory of this experiment (named after SETTINGS['name']).
# OUTPUTDIR: where results are written.

if IS_KAGGLE:
    DATADIR = Path('/kaggle/input/') / SETTINGS["competition"]
    CKPTDIR = Path('/kaggle/input/ckpt-mykaggle/') / SETTINGS['name']
    OUTPUTDIR = Path('/kaggle/working')
else:
    DATADIR = Path('./data/')
    CKPTDIR = Path('./ckpt/') / SETTINGS['name']
    OUTPUTDIR = CKPTDIR

    # parents=True: a fresh checkout may not have ./ckpt/ yet, and a plain
    # mkdir() would raise FileNotFoundError for the missing parent.
    # exist_ok=True: replaces the separate exists() check, which could race
    # with another process creating the same directory.
    CKPTDIR.mkdir(parents=True, exist_ok=True)

# Directory that holds the checkpoint folders of all experiments.
ROOT_CKPTDIR = CKPTDIR.parent
TRAINDIR = DATADIR / 'train'
TESTDIR = DATADIR / 'test'

# load data
# On Kaggle the competition's own CSV files are read; locally the file names
# come from SETTINGS['training'].

if IS_KAGGLE:
    df_train = pd.read_csv(DATADIR / 'train.csv')
    df_test = pd.read_csv(DATADIR / 'test.csv')
    df_sub = pd.read_csv(DATADIR / 'sample_submission.csv')
else:
    if SETTINGS['training']['train_file'].endswith('ftr'):
        df_train = pd.read_feather(DATADIR / SETTINGS['training']['train_file'])
    else:
        df_train = pd.read_csv(DATADIR / SETTINGS['training']['train_file'])
        # This branch is only reached when IS_KAGGLE is False, so that flag
        # does not need to be tested again here.
        if SETTINGS['training']['loss'] == 'histgram' and SETTINGS['mode'] == 'training':
            # Open through a context manager so the file handle is closed as
            # soon as the targets are loaded.
            # NOTE: pickle executes arbitrary code on load; only point this at
            # files produced by this project.
            hist_path = DATADIR / f'hist_target_c{SETTINGS["model"]["num_classes"]}.pkl'
            with open(hist_path, 'rb') as hist_file:
                hist_target = pickle.load(hist_file)
            df_train[[f'hist_target_{i}' for i in range(SETTINGS['model']['num_classes'])]] = hist_target
    df_test = pd.read_csv(DATADIR / SETTINGS['training']['test_file'])
    # Locally the test file itself is used as the submission frame.
    df_sub = pd.read_csv(DATADIR / SETTINGS['training']['test_file'])
    if SETTINGS['training']['additional_data']:
        df_additional = pd.read_csv(DATADIR / SETTINGS['training']['additional_data'])


# Debug runs only keep the first 100 training rows.
if IS_DEBUG:
    df_train = df_train.iloc[:100]

# necessary parameters
# Values derived at runtime and written back into SETTINGS so that code which
# only receives the settings dict can read them.
SETTINGS['ckptdir'] = str(CKPTDIR)
SETTINGS['training']['ckptdir'] = str(CKPTDIR)
SETTINGS['training']['num_classes'] = SETTINGS['model']['num_classes']
# Optimizer steps per epoch: the training part of one fold holds
# (num_folds - 1) / num_folds of the rows, and one step consumes
# batch_size * num_accumulations rows. Both divisions round up.
SETTINGS['training']['num_batches'] = math.ceil((
    math.ceil(len(df_train) * (SETTINGS['training']['num_folds'] - 1) / SETTINGS['training']['num_folds'])
) / (SETTINGS['training']['batch_size'] * SETTINGS['training']['num_accumulations']))
# Total step count uses `scheduler_epoch`, not `num_epochs`.
SETTINGS['training']['num_total_steps'] = SETTINGS['training']['num_batches'] * SETTINGS['training']['scheduler_epoch']
# NOTE(review): the inline YAML above defines no top-level 'stacking' key, so
# this branch would raise KeyError if mode were set to 'stacking' without also
# adding that section.
if SETTINGS['mode'] == 'stacking':
    SETTINGS['stacking']['training']['num_total_steps'] = SETTINGS['training']['num_total_steps']
    SETTINGS['stacking']['training']['num_batches'] = SETTINGS['training']['num_batches']

LOGGER.info(f'Loaded data, train shape:{df_train.shape}, test shape:{df_test.shape}')


def parse() -> Namespace:
    '''
    Parse the command-line arguments of this script.

    The only option is ``--gpus``, a string naming the GPU index (or several
    indices separated by commas). It is ``None`` when not supplied.

    :return: the parsed argument namespace
    '''
    cli = ArgumentParser(description='Process some integers.')
    cli.add_argument(
        '--gpus',
        type=str,
        help='index of gpus. if multiple, use comma to list.',
    )
    return cli.parse_args()


class Mode(Enum):
    '''Phase a dataset or dataloader is built for.'''

    # Each member's value is its own name as a string.
    TRAIN = 'TRAIN'  # batches are shuffled and use the training batch size
    VALID = 'VALID'
    TEST = 'TEST'


class MyDataset(Dataset):
    '''
    Map-style ``torch.utils.data.Dataset`` over a dataframe of text excerpts.

    All excerpts are cleaned and tokenized once in the constructor (padded /
    truncated to ``settings['max_length']``); ``__getitem__`` then only slices
    the pre-computed arrays.

    Columns read from ``df``:
      * ``excerpt`` (required): the text.
      * ``target`` (optional): label; without it items carry no label.
      * ``weight`` (optional): per-sample weight. When absent it is created as
        1.0, and rows whose ``data_type`` is not NaN get
        ``settings['loss_external_weight']`` instead.
      * ``standard_error`` / ``dist_target`` (optional): used in TRAIN mode only.

    :param settings: training settings dict (reads ``max_length``,
        ``loss_external_weight``, ``use_standard_error``)
    :param df: source dataframe; it is copied and re-indexed, never modified
    :param datadir: unused here, kept for interface compatibility
    :param tokenizer: callable tokenizer returning a mapping of equally long id lists
    :param mode: dataset phase
    :param fold: unused here, kept for interface compatibility
    '''
    def __init__(
        self,
        settings: Dict,
        df: pd.DataFrame,
        datadir: Optional[Path],
        tokenizer: PreTrainedTokenizerFast,
        mode: Mode,
        fold: int,
    ) -> None:
        super().__init__()
        self.settings = settings
        self.df = df.reset_index(drop=True).copy()
        self.tokenizer = tokenizer
        self.mode = mode

        def clean_text(text: str) -> str:
            return text.replace('\n', ' ').strip()

        self.df['excerpt'] = self.df['excerpt'].apply(clean_text)

        inputs = self.tokenizer(
            self.df['excerpt'].values.tolist(),
            padding='max_length',
            truncation=True,
            max_length=self.settings['max_length'],
        )
        self.inputs = {k: np.array(v).astype(np.int32) for k, v in inputs.items()}
        if 'target' in self.df.columns:
            self.labels = self.df['target'].values
        else:
            self.labels = None
        if 'weight' not in self.df.columns:
            self.df['weight'] = 1.0
            if 'data_type' in self.df.columns:
                # Build the mask from self.df, not from the caller's df:
                # self.df was re-indexed to 0..n-1, so a mask carrying the
                # caller's index (e.g. a fold subset) cannot be aligned by .loc.
                is_external = self.df['data_type'].notna()
                self.df.loc[is_external, 'weight'] = self.settings.get('loss_external_weight', 1.0)
        self.keys = list(self.inputs.keys())
        self.epoch = 0

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, index: int):
        '''
        Return the model inputs of one row.

        :return: ``inp`` when the frame has no ``target`` column, otherwise
            ``(inp, (target, weight))`` with ``target`` a 0-d double tensor.
            ``inp`` maps each tokenizer output key to its int32 array and, in
            TRAIN mode with ``use_standard_error`` enabled, ``'std'`` to the
            row's ``standard_error``.
        '''
        data = self.df.iloc[index]
        weight = data['weight'] if 'weight' in data else 1.0
        inp = {}
        for key in self.keys:
            inp[key] = self.inputs[key][index]
        if self.settings.get('use_standard_error', False) and self.mode == Mode.TRAIN:
            std = data['standard_error']
            inp['std'] = std

        if self.labels is not None:
            target = self.labels[index]
            if 'dist_target' in data and self.mode == Mode.TRAIN and not np.isnan(data['dist_target']):
                # NaN is the "no override" marker and was excluded above. Take
                # the value as is: the former `dist_target or target` silently
                # dropped a legitimate override of exactly 0.0.
                target = data['dist_target']
            return inp, (torch.tensor(target, dtype=torch.double), weight)
        return inp


def get_dataloader(
    settings: Dict[str, Any],
    dataset: Dataset,
    mode: Mode,
    fold: int,
    *args, **kwargs
) -> DataLoader:
    '''
    Wrap ``dataset`` in a ``DataLoader`` configured for the given phase.

    In TRAIN mode batches are shuffled and sized by ``settings['batch_size']``;
    in every other mode order is kept and ``settings['test_batch_size']`` is
    used. The last incomplete batch is never dropped and memory is not pinned.

    :param settings: training settings dict (reads ``batch_size`` or
        ``test_batch_size``, and ``num_workers``)
    :param dataset: dataset to iterate
    :param mode: dataset phase
    :param fold: unused here, kept for interface compatibility
    :return: the configured dataloader
    '''
    is_train = mode == Mode.TRAIN
    size_key = 'batch_size' if is_train else 'test_batch_size'
    loader: DataLoader = DataLoader(
        dataset,
        batch_size=settings[size_key],
        shuffle=is_train,
        num_workers=settings['num_workers'],
        pin_memory=False,
        drop_last=False,
    )
    return loader


def get_transformers_model(
    settings: Dict[str, Any],
    model_name: str,
    pretrained: bool = True,
    ckptdir: Optional[Path] = None,
) -> PreTrainedModel:
    '''
    Build the transformer encoder with this project's dropout / layer-norm overrides.

    With ``pretrained=True`` config and weights both come from ``model_name``.
    Otherwise the config is read from ``ckptdir`` and a model of the matching
    architecture is created from that config alone (no weights are loaded here).

    :param settings: model settings dict (reads ``encoder_attn_dropout_rate``,
        ``encoder_ffn_dropout_rate``, ``layer_norm_eps``)
    :param model_name: model identifier; also used to pick the architecture
    :param pretrained: whether to load pretrained weights
    :param ckptdir: directory holding the config when ``pretrained`` is False
    :return: the encoder model
    '''
    config_source = str(ckptdir) if not pretrained else model_name
    config = AutoConfig.from_pretrained(config_source)

    # (config attribute, settings key, fallback value)
    overrides = (
        ('attention_probs_dropout_prob', 'encoder_attn_dropout_rate', 0.1),
        ('hidden_dropout_prob', 'encoder_ffn_dropout_rate', 0.1),
        ('layer_norm_eps', 'layer_norm_eps', 1e-5),
    )
    for attr, key, fallback in overrides:
        setattr(config, attr, settings.get(key, fallback))

    if pretrained:
        return AutoModel.from_pretrained(model_name, config=config)

    # First substring match wins, so the more specific names must come before
    # 'bert' (which is itself a substring of 'albert', 'roberta' and 'deberta').
    architectures = (
        ('albert', AlbertModel),
        ('roberta', RobertaModel),
        ('deberta', DebertaModel),
        ('bert', BertModel),
        ('electra', ElectraModel),
    )
    model_cls = next(
        (cls for marker, cls in architectures if marker in model_name),
        BertModel,  # unknown names fall back to plain BERT
    )
    return model_cls(config=config)


class ModelCustomHeadEnsemble(nn.Module):
    '''
    Transformer encoder with several pooling "heads" whose scalar predictions
    are combined into one output per sample.

    Which heads exist is given by ``settings['model']['custom_head_types']``.
    Every head except 'conv' pools the encoder output into one vector and maps
    it to a scalar through dropout and its own ``nn.Linear``; the 'conv' head
    produces its scalar through two 1-d convolutions instead.

    Settings read from ``settings['model']``: ``num_reinit_layers``,
    ``custom_head_types``, ``num_use_layers_as_output``, ``head_hidden_dim``,
    ``head_intermediate_dim``, ``conv_head_hidden_dim``,
    ``conv_head_kernel_size``, ``dropout_rate``, ``custom_head_ensemble``,
    ``output_head_features``.

    NOTE: submodules are created in a fixed order and under fixed attribute
    names; both influence random initialisation and the ``state_dict`` keys,
    so keep them stable when editing.
    '''
    def __init__(
        self,
        settings: Dict[str, Any],
        model: PreTrainedModel
    ) -> None:
        super().__init__()
        self.settings = settings
        self.model = model
        self.num_reinit_layers = settings['model'].get('num_reinit_layers', 0)
        self.head_types = settings['model']['custom_head_types']
        self.num_use_layers = self.settings['model']['num_use_layers_as_output']
        output_layers = {}

        # Head-specific submodules, created only for the heads that are enabled.
        if 'attn' in self.head_types:
            self.hidden_dim = self.settings['model']['head_hidden_dim']
            self.intermediate_dim = self.settings['model'].get('head_intermediate_dim', self.hidden_dim)
            # AttentionHead is defined later in this file.
            self.attn_head = AttentionHead(self.hidden_dim, self.intermediate_dim)
        if 'conv' in self.head_types:
            hidden_dim = self.settings['model'].get('conv_head_hidden_dim', 256)
            kernel_size = self.settings['model'].get('conv_head_kernel_size', 2)
            self.conv1 = nn.Conv1d(self.model.config.hidden_size, hidden_dim, kernel_size=kernel_size, padding=1)
            self.conv2 = nn.Conv1d(hidden_dim, 1, kernel_size=kernel_size, padding=1)
        if 'layers_sum' in self.head_types:
            # One learnable mixing weight per used encoder layer, starting at 1.
            self.layer_weight = nn.Parameter(torch.tensor([1] * self.num_use_layers, dtype=torch.float))

        # One scalar output layer per head. Heads whose name contains 'concat'
        # see num_use_layers hidden vectors side by side; 'conv' needs none.
        for head in self.head_types:
            if 'concat' in head:
                output_layers[head] = nn.Linear(self.model.config.hidden_size * self.num_use_layers, 1)
            elif head == 'conv':
                continue
            else:
                output_layers[head] = nn.Linear(self.model.config.hidden_size, 1)
        self.output_layers = nn.ModuleDict(output_layers)

        self.dropout = nn.Dropout(settings['model']['dropout_rate'])
        self.ensemble_type = settings['model']['custom_head_ensemble']
        if self.ensemble_type == 'weight':
            # Bias-free linear layer used only as a container for one weight per head.
            self.ensemble_weight = nn.Linear(len(self.head_types), 1, bias=False)
        self.output_head_features = settings['model'].get('output_head_features', False)

        self.initialize()

    def forward(self, inputs):
        '''
        Run the encoder and all enabled heads.

        :param inputs: mapping passed to the encoder as keyword arguments; must
            contain ``input_ids`` (its first dimension is used as batch size)
        :return: ``outputs`` reshaped to ``(batch,)``; when
            ``output_head_features`` is set, the tuple
            ``(outputs, features, head_features)`` where ``features`` holds
            each head's prediction reshaped to ``(batch,)`` and
            ``head_features`` each head's pooled state
        '''
        outputs = self.model(**inputs, output_hidden_states=True)
        head_features = []
        features = []
        # Heads are evaluated in the fixed order below, not in the order they
        # are listed in head_types.
        if 'cls' in self.head_types:
            # State at position 0 of the last layer.
            cls_state = outputs.last_hidden_state[:, 0, :]
            feature = self.output_layers['cls'](self.dropout(cls_state))
            head_features.append(cls_state)
            features.append(feature)
        if 'avg' in self.head_types:
            # Mean over dim 1 of the last layer (all positions, no masking).
            avg_pool = torch.mean(outputs.last_hidden_state, 1)
            feature = self.output_layers['avg'](self.dropout(avg_pool))
            head_features.append(avg_pool)
            features.append(feature)
        if 'max' in self.head_types:
            # Element-wise max over dim 1 of the last layer.
            max_pool = torch.max(outputs.last_hidden_state, 1)[0]
            feature = self.output_layers['max'](self.dropout(max_pool))
            head_features.append(max_pool)
            features.append(feature)
        if 'attn' in self.head_types:
            attn_state = self.attn_head(outputs.last_hidden_state)
            feature = self.output_layers['attn'](self.dropout(attn_state))
            head_features.append(attn_state)
            features.append(feature)
        if 'conv' in self.head_types:
            # Conv1d expects channels on dim 1, hence the permute.
            conv_state = self.conv1(outputs.last_hidden_state.permute(0, 2, 1))
            conv_state = F.relu(self.conv2(conv_state))
            feature, _ = torch.max(conv_state, -1)
            head_features.append(conv_state)
            features.append(feature)
        if 'layers_concat' in self.head_types:
            # Position-0 states of the last num_use_layers layers, concatenated.
            hidden_states = outputs.hidden_states[-self.num_use_layers:]
            cat_feature = torch.cat([state[:, 0, :] for state in hidden_states], -1)
            feature = self.output_layers['layers_concat'](self.dropout(cat_feature))
            head_features.append(cat_feature)
            features.append(feature)
        if 'layers_avg' in self.head_types:
            # Stack the used layers on a new last dim, keep position 0, average over layers.
            hidden_states = torch.stack(outputs.hidden_states[-self.num_use_layers:], -1)[:, 0, :, :]
            avg_feature = torch.mean(hidden_states, -1)
            feature = self.output_layers['layers_avg'](self.dropout(avg_feature))
            head_features.append(avg_feature)
            features.append(feature)
        if 'layers_sum' in self.head_types:
            hidden_states = torch.stack(outputs.hidden_states[-self.num_use_layers:], -1)[:, 0, :, :]
            # Layer weights are normalised to sum to 1 before mixing.
            weight = self.layer_weight[None, None, :] / self.layer_weight.sum()
            weighted_sum_feature = torch.sum(hidden_states * weight, -1)
            feature = self.output_layers['layers_sum'](self.dropout(weighted_sum_feature))
            head_features.append(weighted_sum_feature)
            features.append(feature)
        if 'layers_attn' in self.head_types:
            hidden_states = torch.stack(outputs.hidden_states[-self.num_use_layers:], -1)[:, 0, :, :]
            # NOTE(review): self.layer_attn is not created in __init__ above, so
            # enabling 'layers_attn' would raise AttributeError here — confirm
            # whether this head is still meant to be supported.
            attn_state = self.layer_attn(hidden_states)
            feature = self.output_layers['layers_attn'](self.dropout(attn_state))
            head_features.append(attn_state)
            features.append(feature)

        # Per-head scalars side by side on the last dim.
        outputs = torch.cat(features, -1)
        if len(self.head_types) > 1:
            if self.ensemble_type == 'avg':
                outputs = torch.mean(outputs, -1)
            elif self.ensemble_type == 'weight':
                if self.settings['training']['trainer'] == 'multi':
                    # With the 'multi' trainer the mixing does not backpropagate into the heads.
                    outputs = outputs.detach()
                weight = self.ensemble_weight.weight / torch.sum(self.ensemble_weight.weight)
                outputs = torch.sum(weight * outputs, -1)
        outputs = outputs.reshape(inputs['input_ids'].shape[0])
        if self.output_head_features:
            features = [f.reshape(inputs['input_ids'].shape[0]) for f in features]
            return outputs, features, head_features
        return outputs

    def initialize(self):
        '''
        Initialise the head layers and re-initialise the top encoder layers.

        Ensemble weights (if any) start at 1.0; all output layers and the last
        ``num_reinit_layers`` entries of ``self.model.encoder.layer`` go
        through ``_init_weight``.
        '''
        if self.ensemble_type == 'weight':
            torch.nn.init.constant_(self.ensemble_weight.weight, 1.0)
        self.output_layers.apply(self._init_weight)
        for i in range(self.num_reinit_layers):
            self.model.encoder.layer[-(1 + i)].apply(self._init_weight)

    def _init_weight(self, module):
        '''
        Reset one module: ``nn.Linear`` weights are drawn from a normal
        distribution with std ``config.initializer_range`` and biases are
        zeroed; any other module type is left untouched.
        '''
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.model.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()


class ModelCustomHeadClassification(nn.Module):
    '''
    Transformer encoder with several pooling heads that each predict a
    distribution over ``num_classes`` classes; the per-head distributions are
    then combined.

    Which heads exist is given by ``settings['model']['custom_head_types']``.
    Each head pools the encoder output into one vector, applies dropout and
    its own ``nn.Linear`` to ``num_classes`` logits, and the logits are turned
    into probabilities with softmax before the heads are combined.

    NOTE: submodules are created in a fixed order and under fixed attribute
    names; both influence random initialisation and the ``state_dict`` keys,
    so keep them stable when editing.
    '''
    def __init__(
        self,
        settings: Dict[str, Any],
        model: PreTrainedModel
    ) -> None:
        super().__init__()
        self.settings = settings
        self.model = model
        self.num_classes = settings['model']['num_classes']
        self.num_reinit_layers = settings['model'].get('num_reinit_layers', 0)
        self.head_types = settings['model']['custom_head_types']
        self.num_use_layers = self.settings['model']['num_use_layers_as_output']
        output_layers = {}

        # Head-specific submodules, created only for the heads that are enabled.
        if 'attn' in self.head_types:
            self.hidden_dim = self.settings['model']['head_hidden_dim']
            self.intermediate_dim = self.settings['model'].get('head_intermediate_dim', self.hidden_dim)
            # AttentionHead is defined later in this file.
            self.attn_head = AttentionHead(self.hidden_dim, self.intermediate_dim)
        if 'layers_sum' in self.head_types:
            # One learnable mixing weight per used encoder layer, starting at 1.
            self.layer_weight = nn.Parameter(torch.tensor([1] * self.num_use_layers, dtype=torch.float))

        # One classifier per head; 'concat' heads see num_use_layers hidden
        # vectors side by side.
        for head in self.head_types:
            if 'concat' in head:
                output_layers[head] = nn.Linear(self.model.config.hidden_size * self.num_use_layers, self.num_classes)
            else:
                output_layers[head] = nn.Linear(self.model.config.hidden_size, self.num_classes)
        self.output_layers = nn.ModuleDict(output_layers)

        self.dropout = nn.Dropout(settings['model']['dropout_rate'])
        self.ensemble_type = settings['model']['custom_head_ensemble']
        if self.ensemble_type == 'weight':
            # Bias-free linear layer used only as a container for one weight per head.
            self.ensemble_weight = nn.Linear(len(self.head_types), 1, bias=False)
        self.output_head_features = settings['model'].get('output_head_features', False)

        self.initialize()

    def forward(self, inputs):
        '''
        Run the encoder and all enabled heads.

        :param inputs: mapping passed to the encoder as keyword arguments; must
            contain ``input_ids`` (its first dimension is used as batch size)
        :return: ``outputs`` reshaped to ``(batch, -1)``; when
            ``output_head_features`` is set, the tuple
            ``(outputs, features, head_features)`` where ``features`` holds
            each head's softmax output and ``head_features`` each head's
            pooled state
        '''
        outputs = self.model(**inputs, output_hidden_states=True)
        head_features = []
        features = []
        # Heads are evaluated in the fixed order below, not in the order they
        # are listed in head_types.
        if 'cls' in self.head_types:
            # State at position 0 of the last layer.
            cls_state = outputs.last_hidden_state[:, 0, :]
            feature = self.output_layers['cls'](self.dropout(cls_state))
            head_features.append(cls_state)
            features.append(feature)
        if 'avg' in self.head_types:
            # Mean over dim 1 of the last layer (all positions, no masking).
            avg_pool = torch.mean(outputs.last_hidden_state, 1)
            feature = self.output_layers['avg'](self.dropout(avg_pool))
            head_features.append(avg_pool)
            features.append(feature)
        if 'max' in self.head_types:
            # Element-wise max over dim 1 of the last layer.
            max_pool = torch.max(outputs.last_hidden_state, 1)[0]
            feature = self.output_layers['max'](self.dropout(max_pool))
            head_features.append(max_pool)
            features.append(feature)
        if 'attn' in self.head_types:
            attn_state = self.attn_head(outputs.last_hidden_state)
            feature = self.output_layers['attn'](self.dropout(attn_state))
            head_features.append(attn_state)
            features.append(feature)
        if 'layers_concat' in self.head_types:
            # Position-0 states of the last num_use_layers layers, concatenated.
            hidden_states = outputs.hidden_states[-self.num_use_layers:]
            cat_feature = torch.cat([state[:, 0, :] for state in hidden_states], -1)
            feature = self.output_layers['layers_concat'](self.dropout(cat_feature))
            head_features.append(cat_feature)
            features.append(feature)
        if 'layers_sum' in self.head_types:
            hidden_states = torch.stack(outputs.hidden_states[-self.num_use_layers:], -1)[:, 0, :, :]
            # Layer weights are normalised to sum to 1 before mixing.
            weight = self.layer_weight[None, None, :] / self.layer_weight.sum()
            weighted_sum_feature = torch.sum(hidden_states * weight, -1)
            feature = self.output_layers['layers_sum'](self.dropout(weighted_sum_feature))
            head_features.append(weighted_sum_feature)
            features.append(feature)

        # Logits -> class probabilities per head, then heads stacked on a new last dim.
        features = [F.softmax(f, -1) for f in features]
        outputs = torch.stack(features, -1)
        if len(self.head_types) > 1:
            if self.ensemble_type == 'avg':
                outputs = torch.mean(outputs, -1)
            elif self.ensemble_type == 'weight':
                if self.settings['training']['trainer'] == 'multi':
                    # With the 'multi' trainer the mixing does not backpropagate into the heads.
                    outputs = outputs.detach()
                weight = self.ensemble_weight.weight / torch.sum(self.ensemble_weight.weight)
                outputs = torch.sum(weight * outputs, -1)
        outputs = outputs.reshape((inputs['input_ids'].shape[0], -1))
        if self.output_head_features:
            return outputs, features, head_features
        return outputs

    def initialize(self):
        '''
        Initialise the head layers and re-initialise the top encoder layers.

        Ensemble weights (if any) start at 1.0; all output layers and the last
        ``num_reinit_layers`` entries of ``self.model.encoder.layer`` go
        through ``_init_weight``.
        '''
        if self.ensemble_type == 'weight':
            torch.nn.init.constant_(self.ensemble_weight.weight, 1.0)
        self.output_layers.apply(self._init_weight)
        for i in range(self.num_reinit_layers):
            self.model.encoder.layer[-(1 + i)].apply(self._init_weight)

    def _init_weight(self, module):
        '''
        Reset one module: ``nn.Linear`` weights are drawn from a normal
        distribution with std ``config.initializer_range`` and biases are
        zeroed; any other module type is left untouched.
        '''
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.model.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()


class ModelCustomHeadWithClassOutput(nn.Module):
    '''
    Multi-head model whose heads additionally see the output of a separate
    classification model.

    For every enabled head the pooled encoder state is concatenated with the
    class model's output (computed without gradients) before dropout and the
    head's ``nn.Linear``. The per-head predictions are then combined.

    NOTE: submodules are created in a fixed order and under fixed attribute
    names; both influence random initialisation and the ``state_dict`` keys,
    so keep them stable when editing.
    '''
    def __init__(
        self,
        settings: Dict[str, Any],
        model: PreTrainedModel,
        class_model: ModelCustomHeadClassification
    ) -> None:
        super().__init__()
        self.settings = settings
        self.model = model
        self.class_model = class_model
        # Width of the class model's output that is appended to every pooled state.
        self.num_add_classes = class_model.num_classes
        self.num_classes = settings['model']['num_classes']
        self.num_reinit_layers = settings['model'].get('num_reinit_layers', 0)
        self.head_types = settings['model']['custom_head_types']
        self.num_use_layers = self.settings['model']['num_use_layers_as_output']
        output_layers = {}

        # Head-specific submodules, created only for the heads that are enabled.
        if 'attn' in self.head_types:
            self.hidden_dim = self.settings['model']['head_hidden_dim']
            self.intermediate_dim = self.settings['model'].get('head_intermediate_dim', self.hidden_dim)
            # AttentionHead is defined later in this file.
            self.attn_head = AttentionHead(self.hidden_dim, self.intermediate_dim)
        if 'layers_sum' in self.head_types:
            # One learnable mixing weight per used encoder layer, starting at 1.
            self.layer_weight = nn.Parameter(torch.tensor([1] * self.num_use_layers, dtype=torch.float))

        # NOTE(review): a layer is created for 'concat' heads (without the
        # extra class inputs), but forward() below has no branch that uses a
        # 'concat' head — confirm whether such heads are expected here.
        for head in self.head_types:
            if 'concat' in head:
                output_layers[head] = nn.Linear(self.model.config.hidden_size * self.num_use_layers, self.num_classes)
            else:
                # Input = pooled hidden state + class model output.
                output_layers[head] = nn.Linear(self.model.config.hidden_size + self.num_add_classes, self.num_classes)
        self.output_layers = nn.ModuleDict(output_layers)

        self.dropout = nn.Dropout(settings['model']['dropout_rate'])
        self.ensemble_type = settings['model']['custom_head_ensemble']
        if self.ensemble_type == 'weight':
            # Bias-free linear layer used only as a container for one weight per head.
            self.ensemble_weight = nn.Linear(len(self.head_types), 1, bias=False)
        self.output_head_features = settings['model'].get('output_head_features', False)

        self.initialize()

    def forward(self, inputs):
        '''
        Run the encoder, the class model and all enabled heads.

        :param inputs: mapping passed to the encoder as keyword arguments and
            to the class model as a single argument; must contain
            ``input_ids`` (its first dimension is used as batch size)
        :return: ``outputs`` reshaped to ``(batch,)``; when
            ``output_head_features`` is set, the tuple
            ``(outputs, features, head_features)`` where ``head_features``
            holds each head's pooled state including the appended class outputs
        '''
        outputs = self.model(**inputs, output_hidden_states=True)
        with torch.no_grad():  # do not touch the class model's parameters
            # NOTE(review): indexing with [0] presumes the class model returns
            # a tuple (i.e. was built with output_head_features enabled) —
            # confirm against the caller.
            class_outputs = self.class_model(inputs)[0]  # [batch_size, num_classes]
            if self.settings['model'].get('use_log_for_class_output', False):
                class_outputs = torch.log(class_outputs)
        head_features = []
        features = []
        # Heads are evaluated in the fixed order below, not in the order they
        # are listed in head_types.
        if 'cls' in self.head_types:
            # State at position 0 of the last layer.
            cls_state = outputs.last_hidden_state[:, 0, :]
            cls_state = torch.cat([cls_state, class_outputs], -1)
            feature = self.output_layers['cls'](self.dropout(cls_state))
            head_features.append(cls_state)
            features.append(feature)
        if 'avg' in self.head_types:
            # Mean over dim 1 of the last layer (all positions, no masking).
            avg_pool = torch.mean(outputs.last_hidden_state, 1)
            avg_pool = torch.cat([avg_pool, class_outputs], -1)
            feature = self.output_layers['avg'](self.dropout(avg_pool))
            head_features.append(avg_pool)
            features.append(feature)
        if 'max' in self.head_types:
            # Element-wise max over dim 1 of the last layer.
            max_pool = torch.max(outputs.last_hidden_state, 1)[0]
            max_pool = torch.cat([max_pool, class_outputs], -1)
            feature = self.output_layers['max'](self.dropout(max_pool))
            head_features.append(max_pool)
            features.append(feature)
        if 'attn' in self.head_types:
            attn_state = self.attn_head(outputs.last_hidden_state)
            attn_state = torch.cat([attn_state, class_outputs], -1)
            feature = self.output_layers['attn'](self.dropout(attn_state))
            head_features.append(attn_state)
            features.append(feature)
        if 'layers_sum' in self.head_types:
            hidden_states = torch.stack(outputs.hidden_states[-self.num_use_layers:], -1)[:, 0, :, :]
            # Layer weights are normalised to sum to 1 before mixing.
            weight = self.layer_weight[None, None, :] / self.layer_weight.sum()
            weighted_sum_feature = torch.sum(hidden_states * weight, -1)
            weighted_sum_feature = torch.cat([weighted_sum_feature, class_outputs], -1)
            feature = self.output_layers['layers_sum'](self.dropout(weighted_sum_feature))
            head_features.append(weighted_sum_feature)
            features.append(feature)

        # Heads stacked on a new last dim.
        outputs = torch.stack(features, -1)
        if len(self.head_types) > 1:
            if self.ensemble_type == 'avg':
                outputs = torch.mean(outputs, -1)
            elif self.ensemble_type == 'weight':
                if self.settings['training']['trainer'] == 'multi':
                    # With the 'multi' trainer the mixing does not backpropagate into the heads.
                    outputs = outputs.detach()
                weight = self.ensemble_weight.weight / torch.sum(self.ensemble_weight.weight)
                outputs = torch.sum(weight * outputs, -1)
        # Reshaping to (batch,) only works when each sample ends up with a single value.
        outputs = outputs.reshape((inputs['input_ids'].shape[0]))
        if self.output_head_features:
            return outputs, features, head_features
        return outputs

    def initialize(self):
        '''
        Initialise the head layers and re-initialise the top encoder layers.

        Ensemble weights (if any) start at 1.0; all output layers and the last
        ``num_reinit_layers`` entries of ``self.model.encoder.layer`` go
        through ``_init_weight``. The class model is not touched.
        '''
        if self.ensemble_type == 'weight':
            torch.nn.init.constant_(self.ensemble_weight.weight, 1.0)
        self.output_layers.apply(self._init_weight)
        for i in range(self.num_reinit_layers):
            self.model.encoder.layer[-(1 + i)].apply(self._init_weight)

    def _init_weight(self, module):
        '''
        Reset one module: ``nn.Linear`` weights are drawn from a normal
        distribution with std ``config.initializer_range`` and biases are
        zeroed; any other module type is left untouched.
        '''
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.model.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()


class ModelCustomHeadAuxNetwork(nn.Module):
    '''
    Multi-head regression model that pairs the main encoder with a small
    auxiliary DeBERTa encoder built from scratch.

    The auxiliary encoder's position-0 state is concatenated to every head's
    pooled state before dropout and the head's ``nn.Linear``. Its size is
    controlled by ``settings['model']['aux_num_hidden_layers']`` and
    ``settings['model']['aux_hidden_size']``.

    NOTE: submodules are created in a fixed order and under fixed attribute
    names; both influence random initialisation and the ``state_dict`` keys,
    so keep them stable when editing.
    '''
    def __init__(
        self,
        settings: Dict[str, Any],
        config,
        model: PreTrainedModel
    ) -> None:
        super().__init__()
        self.settings = settings
        self.model = model
        self.num_reinit_layers = settings['model'].get('num_reinit_layers', 0)
        self.num_reinit_aux_layers = settings['model'].get('num_reinit_aux_layers', 0)
        self.head_types = settings['model']['custom_head_types']
        self.num_use_layers = self.settings['model']['num_use_layers_as_output']
        output_layers = {}

        # Head-specific submodules, created only for the heads that are enabled.
        if 'attn' in self.head_types:
            self.hidden_dim = self.settings['model']['head_hidden_dim']
            self.intermediate_dim = self.settings['model'].get('head_intermediate_dim', self.hidden_dim)
            # AttentionHead is defined later in this file.
            self.attn_head = AttentionHead(self.hidden_dim, self.intermediate_dim)
        if 'layers_sum' in self.head_types:
            # One learnable mixing weight per used encoder layer, starting at 1.
            self.layer_weight = nn.Parameter(torch.tensor([1] * self.num_use_layers, dtype=torch.float))

        # Auxiliary encoder: same config as the main model (deep copy, so the
        # caller's config is not modified) with depth and widths overridden.
        aux_config = copy.deepcopy(config)
        aux_config.num_hidden_layers = settings['model']['aux_num_hidden_layers']
        aux_config.hidden_size = settings['model']['aux_hidden_size']
        aux_config.num_attention_heads = aux_config.hidden_size // 8
        aux_config.intermediate_size = aux_config.hidden_size * 4
        self.aux_model = DebertaModel(aux_config)

        # One scalar output layer per head; every input is widened by the
        # auxiliary hidden size. 'conv' gets no layer (and has no branch in forward).
        for head in self.head_types:
            if 'concat' in head:
                output_layers[head] = nn.Linear(
                    self.model.config.hidden_size * self.num_use_layers + aux_config.hidden_size, 1
                )
            elif head == 'conv':
                continue
            else:
                output_layers[head] = nn.Linear(self.model.config.hidden_size + aux_config.hidden_size, 1)
        self.output_layers = nn.ModuleDict(output_layers)

        self.dropout = nn.Dropout(settings['model']['dropout_rate'])
        self.ensemble_type = settings['model']['custom_head_ensemble']
        if self.ensemble_type == 'weight':
            # Bias-free linear layer used only as a container for one weight per head.
            self.ensemble_weight = nn.Linear(len(self.head_types), 1, bias=False)
        self.output_head_features = settings['model'].get('output_head_features', False)

        self.initialize()

    def forward(self, inputs):
        '''
        Run both encoders and all enabled heads.

        :param inputs: mapping passed to both encoders as keyword arguments;
            must contain ``input_ids`` (its first dimension is used as batch size)
        :return: ``outputs`` reshaped to ``(batch,)``; when
            ``output_head_features`` is set, the tuple
            ``(outputs, features, head_features)`` where ``features`` holds
            each head's prediction reshaped to ``(batch,)`` and
            ``head_features`` each head's pooled main-encoder state (without
            the auxiliary part)
        '''
        outputs = self.model(**inputs, output_hidden_states=True)
        # Position-0 state of the auxiliary encoder's last layer.
        aux_outputs = self.aux_model(**inputs).last_hidden_state[:, 0, :]
        head_features = []
        features = []
        # Heads are evaluated in the fixed order below, not in the order they
        # are listed in head_types.
        if 'cls' in self.head_types:
            # State at position 0 of the last layer.
            cls_state = outputs.last_hidden_state[:, 0, :]
            feature = self.output_layers['cls'](self.dropout(torch.cat([cls_state, aux_outputs], axis=-1)))
            head_features.append(cls_state)
            features.append(feature)
        if 'avg' in self.head_types:
            # Mean over dim 1 of the last layer (all positions, no masking).
            avg_pool = torch.mean(outputs.last_hidden_state, 1)
            feature = self.output_layers['avg'](self.dropout(torch.cat([avg_pool, aux_outputs], axis=-1)))
            head_features.append(avg_pool)
            features.append(feature)
        if 'max' in self.head_types:
            # Element-wise max over dim 1 of the last layer.
            max_pool = torch.max(outputs.last_hidden_state, 1)[0]
            feature = self.output_layers['max'](self.dropout(torch.cat([max_pool, aux_outputs], axis=-1)))
            head_features.append(max_pool)
            features.append(feature)
        if 'attn' in self.head_types:
            attn_state = self.attn_head(outputs.last_hidden_state)
            feature = self.output_layers['attn'](self.dropout(torch.cat([attn_state, aux_outputs], axis=-1)))
            head_features.append(attn_state)
            features.append(feature)
        if 'layers_concat' in self.head_types:
            # Position-0 states of the last num_use_layers layers, concatenated.
            hidden_states = outputs.hidden_states[-self.num_use_layers:]
            cat_feature = torch.cat([state[:, 0, :] for state in hidden_states], -1)
            feature = self.output_layers['layers_concat'](self.dropout(torch.cat([cat_feature, aux_outputs], axis=-1)))
            head_features.append(cat_feature)
            features.append(feature)
        if 'layers_sum' in self.head_types:
            hidden_states = torch.stack(outputs.hidden_states[-self.num_use_layers:], -1)[:, 0, :, :]
            # Layer weights are normalised to sum to 1 before mixing.
            weight = self.layer_weight[None, None, :] / self.layer_weight.sum()
            weighted_sum_feature = torch.sum(hidden_states * weight, -1)
            feature = self.output_layers['layers_sum'](
                self.dropout(torch.cat([weighted_sum_feature, aux_outputs], axis=-1))
            )
            head_features.append(weighted_sum_feature)
            features.append(feature)

        # Per-head scalars side by side on the last dim.
        outputs = torch.cat(features, -1)
        if len(self.head_types) > 1:
            if self.ensemble_type == 'avg':
                outputs = torch.mean(outputs, -1)
            elif self.ensemble_type == 'weight':
                if self.settings['training']['trainer'] == 'multi':
                    # With the 'multi' trainer the mixing does not backpropagate into the heads.
                    outputs = outputs.detach()
                weight = self.ensemble_weight.weight / torch.sum(self.ensemble_weight.weight)
                outputs = torch.sum(weight * outputs, -1)
        outputs = outputs.reshape(inputs['input_ids'].shape[0])
        if self.output_head_features:
            features = [f.reshape(inputs['input_ids'].shape[0]) for f in features]
            return outputs, features, head_features
        return outputs

    def initialize(self):
        '''
        Initialise the head layers and re-initialise the top layers of both encoders.

        Ensemble weights (if any) start at 1.0; all output layers, the last
        ``num_reinit_layers`` entries of ``self.model.encoder.layer`` and the
        last ``num_reinit_aux_layers`` entries of
        ``self.aux_model.encoder.layer`` go through ``_init_weight``.
        '''
        if self.ensemble_type == 'weight':
            torch.nn.init.constant_(self.ensemble_weight.weight, 1.0)
        self.output_layers.apply(self._init_weight)
        for i in range(self.num_reinit_layers):
            self.model.encoder.layer[-(1 + i)].apply(self._init_weight)
        for i in range(self.num_reinit_aux_layers):
            self.aux_model.encoder.layer[-(1 + i)].apply(self._init_weight)

    def _init_weight(self, module):
        '''
        Reset one module: ``nn.Linear`` weights are drawn from a normal
        distribution with std taken from the main model's
        ``config.initializer_range`` and biases are zeroed; any other module
        type is left untouched.
        '''
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.model.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()


class ModelAuxNetwork(nn.Module):
    """Backbone model combined with a freshly built auxiliary ``DebertaModel``.

    Both networks receive the same inputs. The hidden state at position 0 of
    each ``last_hidden_state`` is concatenated, passed through dropout and a
    single linear layer, yielding one value per sample.
    """

    def __init__(
        self,
        settings: Dict[str, Any],
        config,
        model: PreTrainedModel,
    ) -> None:
        """Store the backbone and build the auxiliary encoder and output layer.

        Args:
            settings: Full settings dict; the ``'model'`` section supplies
                ``aux_num_hidden_layers``, ``aux_hidden_size``, ``dropout_rate``
                and optionally ``num_reinit_layers`` (default 0).
            config: Config object that is deep-copied and modified to create
                the auxiliary model's config.
            model: Backbone model used as the main encoder.
        """
        super().__init__()
        self.settings = settings
        self.config = config
        self.model = model
        model_settings = settings['model']
        self.num_reinit_layers = model_settings.get('num_reinit_layers', 0)

        # Work on a private copy so the caller's config object is not modified.
        aux_config = copy.deepcopy(config)
        aux_config.num_hidden_layers = model_settings['aux_num_hidden_layers']
        aux_config.hidden_size = model_settings['aux_hidden_size']
        aux_width = aux_config.hidden_size
        # Head count and feed-forward width are both derived from the aux hidden size.
        aux_config.num_attention_heads = aux_width // 8
        aux_config.intermediate_size = aux_width * 4

        self.aux_model = DebertaModel(aux_config)
        self.output_layer = nn.Linear(self.model.config.hidden_size + aux_width, 1)
        self.dropout = nn.Dropout(model_settings['dropout_rate'])

    def forward(self, inputs):
        """Return a 1-D tensor with one prediction per row of ``inputs['input_ids']``."""
        main_first_token = self.model(**inputs).last_hidden_state[:, 0, :]
        aux_first_token = self.aux_model(**inputs).last_hidden_state[:, 0, :]
        fused = torch.cat([main_first_token, aux_first_token], -1)
        scores = self.output_layer(self.dropout(fused))
        return scores.reshape(inputs['input_ids'].shape[0])

    def initialize(self):
        """Re-initialise the output layer and the last ``num_reinit_layers`` backbone layers."""
        reinit = self._init_weight
        self.output_layer.apply(reinit)
        for depth in range(1, self.num_reinit_layers + 1):
            self.model.encoder.layer[-depth].apply(reinit)

    def _init_weight(self, module):
        """Redraw ``nn.Linear`` weights from N(0, initializer_range) and zero the bias."""
        if not isinstance(module, nn.Linear):
            return
        module.weight.data.normal_(mean=0.0, std=self.model.config.initializer_range)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()


class AttentionHead(nn.Module):
    """Attention pooling that collapses dimension 1 of the input into a weighted sum.

    A score per position is computed as ``V(tanh(W(features)))``, normalised
    with a softmax over dimension 1 and used to weight ``features``.

    NOTE(review): ``out_features`` is set to ``hidden_dim``, while ``forward``
    returns vectors whose last dimension equals that of ``features``
    (``in_features``) - confirm which value callers expect.
    """

    def __init__(self, in_features, hidden_dim):
        super().__init__()
        self.in_features = in_features
        self.middle_features = hidden_dim

        # W is created before V; keep this order so parameter initialisation
        # consumes the random generator in the same sequence.
        self.W = nn.Linear(in_features, hidden_dim)
        self.V = nn.Linear(hidden_dim, 1)
        self.out_features = hidden_dim

    def forward(self, features):
        """Return the attention-weighted sum of ``features`` over dimension 1."""
        scores = self.V(torch.tanh(self.W(features)))
        weights = torch.softmax(scores, dim=1)
        return (weights * features).sum(dim=1)


def get_model(
    settings: Dict[str, Any],
    fold: Optional[int] = None,
    is_kaggle: bool = False,
    pretrained: bool = True,
    hg_model: Optional[PreTrainedModel] = None,
    *args, **kwargs
) -> nn.Module:
    """Build the model wrapper selected by ``settings['model']['model_type']``.

    Args:
        settings: Full settings dict; reads the ``'model'`` section and, when
            the backbone has to be created, ``settings['ckptdir']``.
        fold: Fold index; ``None`` is treated as 0. The backbone config is
            logged only for fold 0.
        is_kaggle: Accepted for interface compatibility; not used here.
        pretrained: Forwarded to ``get_transformers_model`` when no backbone
            is supplied.
        hg_model: Ready-made backbone; created via ``get_transformers_model``
            when ``None``.
        **kwargs: Must contain ``class_model`` when the model type is
            ``'custom_head_with_class'``.

    Returns:
        The constructed wrapper. ``'custom_head'`` and any type not listed
        below produce a ``ModelCustomHeadEnsemble``.
    """
    model_settings = settings['model']
    fold = fold or 0
    if hg_model is None:
        hg_model = get_transformers_model(
            model_settings, model_settings['model_name'], pretrained, ckptdir=settings['ckptdir']
        )
    if fold == 0:
        LOGGER.info(hg_model.config)

    model_type = model_settings.get('model_type', 'cls_base')
    if model_type == 'custom_head_class':
        return ModelCustomHeadClassification(settings, hg_model)
    if model_type == 'custom_head_aux':
        return ModelCustomHeadAuxNetwork(settings, hg_model.config, hg_model)
    if model_type == 'custom_head_with_class':
        return ModelCustomHeadWithClassOutput(settings, hg_model, kwargs['class_model'])
    if model_type == 'aux_network':
        return ModelAuxNetwork(settings, hg_model.config, hg_model)
    # 'custom_head' and every other value share the same default wrapper.
    return ModelCustomHeadEnsemble(settings, hg_model)


def predict(
    model: nn.Module,
    df: pd.DataFrame,
    dataloader: DataLoader,
    batch_size: int,
    num_classes: int,
    use_amp: bool = True
) -> np.ndarray:
    """Run ``model`` over ``dataloader`` and collect its outputs as a float32 array.

    Args:
        model: Module called as ``model(inputs)`` with a dict of tensors. If it
            returns a list or tuple, only the first element is kept.
        df: Only ``len(df)`` is used, to size the result.
        dataloader: Yields either an input dict or a list/tuple whose first
            element is the input dict. The keys ``'label'`` and
            ``'comparison_label'`` are removed before the model is called.
        batch_size: Kept for interface compatibility. Predictions are placed
            by the actual number of rows each batch produces, so a dataloader
            whose batch size differs from this value is handled correctly.
        num_classes: When greater than 1 the result has shape
            ``(len(df), num_classes)``, otherwise ``(len(df),)``.
        use_amp: Enable CUDA autocast during the forward pass. Ignored when
            running on CPU.

    Returns:
        Array of predictions in dataloader order.
    """
    shape = (len(df), num_classes) if num_classes > 1 else (len(df),)
    preds = np.zeros(shape, dtype=np.float32)
    # Fall back to CPU instead of crashing when no CUDA device is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    amp_enabled = use_amp and device.type == 'cuda'
    model.to(device)
    model.eval()
    offset = 0
    for batch in dataloader:
        inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
        for target_key in ('label', 'comparison_label'):
            inputs.pop(target_key, None)
        for key in list(inputs.keys()):
            # Every remaining input is cast to int64 before reaching the model.
            inputs[key] = inputs[key].to(device).long()
        with autocast(enabled=amp_enabled):
            with torch.no_grad():
                outputs = model(inputs)
                if isinstance(outputs, (list, tuple)):
                    outputs = outputs[0]
        batch_preds = outputs.detach().cpu().numpy()
        # A running offset keeps rows aligned even for short or oversized batches.
        preds[offset:offset + len(batch_preds)] = batch_preds
        offset += len(batch_preds)
    return preds


def test(
    settings: Dict[str, Any],
    model: nn.Module,
    dataloader: DataLoader,
    df_test: pd.DataFrame,
) -> np.ndarray:
    """Run ``predict`` with the parameters stored in ``settings['training']``.

    Reads ``test_batch_size``, ``use_amp`` and ``num_classes`` from the
    training section and returns whatever ``predict`` returns.
    """
    training = settings['training']
    batch_size, use_amp, num_classes = (
        training[key] for key in ('test_batch_size', 'use_amp', 'num_classes')
    )
    return predict(model, df_test, dataloader, batch_size, num_classes, use_amp)


def ensemble_inference(
    models: List[str],
    df: pd.DataFrame,
    num_classes: int,
    ckptdir: Path,
    datadir: Path,
    is_kaggle: bool,
    ensemble_type: str = 'avg',
    weights: Optional[np.ndarray] = None
) -> np.ndarray:
    """Predict with every fold of every listed model and blend the results.

    For each model name, ``settings.yml`` and the tokenizer are loaded from
    ``ckptdir / model_name``. Each fold checkpoint ``model_{fold}.pt`` is run
    on ``df`` and the fold predictions are averaged per model.

    Args:
        models: Names of the sub-directories of ``ckptdir`` to load.
        df: Data to predict on.
        num_classes: Accepted for interface compatibility; not used here.
        ckptdir: Directory holding one sub-directory per model.
        datadir: Passed through to ``MyDataset``.
        is_kaggle: Passed through to ``get_model``.
        ensemble_type: ``'avg'`` averages the models. Any other value uses
            ``weights`` when given and falls back to the plain mean otherwise.
        weights: One weight per entry of ``models``, multiplied with each
            model's predictions before summing.

    Returns:
        Array of length ``len(df)`` with the blended predictions.
    """
    device = torch.device('cuda')

    whole_preds = np.zeros((len(models), len(df)))
    for i, model_name in enumerate(models):
        gc.collect()
        torch.cuda.empty_cache()
        LOGGER.info(f'inference by {model_name} start.')
        model_ckptdir = ckptdir / model_name
        # Context manager so the settings file handle is closed right away.
        with open(model_ckptdir / 'settings.yml', 'r') as settings_file:
            model_settings = yaml.safe_load(settings_file)
        model_settings['ckptdir'] = model_ckptdir
        model_settings['training']['ckptdir'] = model_ckptdir
        model_settings['model']['ckpt_from'] = None
        tokenizer = AutoTokenizer.from_pretrained(model_ckptdir)
        num_folds = model_settings['training']['num_folds']
        for fold in range(num_folds):
            gc.collect()
            torch.cuda.empty_cache()
            ds = MyDataset(model_settings['training'], df, datadir, tokenizer, Mode.TEST, fold)
            dataloader = get_dataloader(model_settings['training'], ds, Mode.TEST, fold)
            if model_settings['name'] == '081_de_ft961_classmodel':
                # This model needs a companion classification model that is
                # configured by a second settings file in the same directory.
                with open(model_ckptdir / '049_settings.yml', 'r') as settings_file:
                    class_model_settings = yaml.safe_load(settings_file)
                class_model_settings['ckptdir'] = str(model_ckptdir)
                # Use the is_kaggle argument here as well, matching the other branch.
                class_model = get_model(
                    class_model_settings, fold=fold, is_kaggle=is_kaggle, pretrained=False
                ).cuda()
                model = get_model(
                    model_settings, fold=fold, is_kaggle=is_kaggle, pretrained=False, class_model=class_model
                )
            else:
                class_model = None
                model = get_model(model_settings, fold=fold, is_kaggle=is_kaggle, pretrained=False)
            model.load_state_dict(torch.load(model_ckptdir / f'model_{fold}.pt'))
            model.to(device)
            preds = test(model_settings, model, dataloader, df)
            whole_preds[i] += preds / num_folds
            # Drop the references so the next gc / empty_cache can reclaim memory.
            del model, class_model

    if ensemble_type != 'avg' and weights is not None:
        return np.sum(whole_preds * weights[:, np.newaxis], axis=0)
    return np.mean(whole_preds, axis=0)


def do_head_stacking(
    settings: Dict[str, Any],
    df_train: pd.DataFrame,
    df_test: pd.DataFrame
) -> Tuple[np.ndarray, np.ndarray]:
    """Placeholder for head stacking; no implementation is present in this file.

    Does nothing and returns ``None``, regardless of the annotated return type.
    """
    return None


def do_ensemble(
    settings: Dict[str, Any],
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    fold: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """Run ensemble inference on ``df_test`` with the hard-coded model weights.

    The model list and ensemble type come from ``settings['ensemble']``; the
    weights below are passed along unchanged. ``df_train`` and ``fold`` are
    accepted but not used.

    NOTE: despite the annotated tuple, the single array returned by
    ``ensemble_inference`` is returned as-is.
    """
    model_names = settings['ensemble']['models']
    # One weight per ensemble member, in the same order as the model list.
    best_ensemble_weights = np.array([
        0.02372718, 0.08941032, 0.04026971, 0.04065091,
        0.11687406, 0.069342, 0.12299674, 0.01315429,
        0.05438842, 0.04086522, 0.16173649, 0.11391101,
        0.12956535,
    ])

    return ensemble_inference(
        models=model_names,
        df=df_test,
        num_classes=settings['model']['num_classes'],
        ckptdir=CKPTDIR.parent,
        datadir=TESTDIR,
        is_kaggle=IS_KAGGLE,
        ensemble_type=settings['ensemble']['ensemble_type'],
        weights=best_ensemble_weights,
    )


def do_stacking(
    settings: Dict[str, Any],
    df_train: pd.DataFrame,
    df_test: pd.DataFrame
) -> Tuple[np.ndarray, np.ndarray]:
    """Placeholder for stacking; no implementation is present in this file.

    Does nothing and returns ``None``, regardless of the annotated return type.
    """
    return None


def do_training(settings: Dict[str, Any], df: pd.DataFrame):
    """Placeholder for training; does nothing and returns ``None``."""
    return None


def do_inference(settings: Dict[str, Any], df: pd.DataFrame):
    """Average the predictions of every fold checkpoint in ``CKPTDIR`` on ``df``.

    Args:
        settings: Full settings dict. Reads ``is_full_training`` and the
            ``'training'`` section (``num_folds``, ``trainer``).
        df: Data to predict on.

    Returns:
        Array of length ``len(df)``. It is all zeros when
        ``settings['is_full_training']`` is false, in which case nothing is
        loaded and inference is skipped.
    """
    tst = settings['training']
    whole_preds = np.zeros((len(df)))
    # Bail out before touching the checkpoint directory when there is nothing to run.
    if not settings['is_full_training']:
        LOGGER.info('inference is skipped since is_full_training is False')
        return whole_preds

    tokenizer = AutoTokenizer.from_pretrained(CKPTDIR)
    num_folds = tst['num_folds']

    centers = None
    if 'class' in tst['trainer']:
        # Loaded once, because the file is the same for every fold.
        # NOTE: pickle can execute arbitrary code on load; centers.pkl must be
        # a trusted artifact.
        with open(CKPTDIR / 'centers.pkl', 'rb') as centers_file:
            centers = pickle.load(centers_file)

    for fold in range(num_folds):
        ds = MyDataset(tst, df, TESTDIR, tokenizer, Mode.TEST, fold)
        dataloader = get_dataloader(tst, ds, Mode.TEST, fold)
        model = get_model(settings, fold=fold, is_kaggle=IS_KAGGLE, pretrained=False)
        model.load_state_dict(torch.load(CKPTDIR / f'model_{fold}.pt'))
        preds = test(settings, model, dataloader, df)
        if centers is not None:
            # Collapse the per-class outputs into one value per row by
            # weighting each class with its center.
            preds = np.sum(preds * centers[None, :], -1)
        whole_preds += preds / num_folds
    return whole_preds


def submit(settings: Dict[str, Any], df_sub: pd.DataFrame, preds: np.ndarray, suffix: str = '') -> None:
    """Store ``preds`` in the ``target`` column of ``df_sub`` and write it as CSV.

    The file is written to ``OUTPUTDIR / 'submission{suffix}.csv'`` without the
    index, and the frame is printed afterwards. ``df_sub`` is modified in
    place; ``settings`` is not used.
    """
    df_sub['target'] = preds
    destination = OUTPUTDIR / f'submission{suffix}.csv'
    df_sub.to_csv(destination, index=False)
    print(df_sub)


def main(settings: Dict[str, Any]):
    """Dispatch on ``settings['mode']`` (default ``'training'``).

    In the Kaggle environment ``'training'`` is replaced by ``'inference'``.
    ``'other_training'`` only trains; every other mode produces predictions
    and writes a submission. In ``'ensemble'`` mode the optional
    ``settings['ensemble']['fold_only']`` value is forwarded as ``fold`` and
    used as the submission file suffix. Any mode not listed below trains and
    then runs inference.
    """
    mode = settings.get('mode', 'training')
    if IS_KAGGLE and mode == 'training':
        mode = 'inference'
    LOGGER.info(f'start {mode}')

    if mode == 'other_training':
        # Training only: no predictions, hence no submission file.
        do_training(settings, df_train)
        return

    suffix = ''
    if mode == 'ensemble':
        fold = settings['ensemble'].get('fold_only')
        if fold is None:
            preds = do_ensemble(settings, df_train, df_test)
        else:
            preds = do_ensemble(settings, df_train, df_test, fold=fold)
            suffix = str(fold)
    elif mode == 'stacking':
        preds = do_stacking(settings, df_train, df_test)
    elif mode == 'inference':
        preds = do_inference(settings, df_test)
    elif mode == 'headstack':
        preds = do_head_stacking(settings, df_train, df_test)
    else:
        do_training(settings, df_train)
        preds = do_inference(settings, df_test)

    submit(settings, df_sub, preds, suffix)


if __name__ == '__main__':
    # Script entry point. Outside Kaggle the command-line arguments select the
    # visible GPUs and tokenizer parallelism is switched on; in both
    # environments the seed is fixed before main() runs.
    if not IS_KAGGLE:
        args = parse()
        LOGGER.info(f'starting with args: {args}')
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    else:
        LOGGER.info('starting in kaggle environment')
    fix_seed(SETTINGS['seed'])
    main(SETTINGS)
