In [None]:
import sys
import os  
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import traceback

# Expose the competition code as ./src so that the later `from src import ...` imports resolve.
!ln -s ../input/feedback ./src
# On Kaggle, extend the import path with the bundled helper directories and the working directory.
# presumably these directories provide gezi / melt / husky / lele — TODO confirm
if os.path.exists('/kaggle'):
  sys.path.append('/kaggle/input/pikachu/utils')
  sys.path.append('/kaggle/input/pikachu/third')
  sys.path.append('.')
# List the attached inputs as a quick sanity check.
!ls ../input

In [None]:
# Offline install: --no-index plus --find-links makes pip resolve icecream from the attached dataset only.
!pip install -q icecream --no-index --find-links=file:///kaggle/input/icecream/ 

In [None]:
# Offline install of pymp-pypi from the attached dataset (no package index is contacted).
!pip install -q pymp-pypi --no-index --find-links=file:///kaggle/input/pymp-pypi/pymp-pypi-0.4.5/dist

In [None]:
# Overwrite parts of the installed transformers package with patched sources.
# This is required to use the fast tokenizer for DeBERTa v2 / v3.
import shutil
from pathlib import Path

input_dir = Path("../input/deberta-v2-3-fast-tokenizer")
transformers_path = Path("/opt/conda/lib/python3.7/site-packages/transformers")
deberta_v2_path = transformers_path / "models" / "deberta_v2"

# 1) Slow -> fast conversion module: drop the installed copy, then copy the patched
#    file into the package directory.
convert_file = input_dir / "convert_slow_tokenizer.py"
conversion_path = transformers_path / convert_file.name
if conversion_path.exists():
  conversion_path.unlink()
shutil.copy(convert_file, transformers_path)

# 2) The two DeBERTa-v2 tokenizer modules, replaced the same way.
for filename in ('tokenization_deberta_v2.py', 'tokenization_deberta_v2_fast.py'):
  filepath = deberta_v2_path / filename
  if filepath.exists():
    filepath.unlink()
  shutil.copy(input_dir / filename, filepath)

In [None]:
from IPython.display import display
import tensorflow as tf
import torch
from absl import flags
FLAGS = flags.FLAGS
from transformers import AutoTokenizer
from datasets import Dataset
from src import config
from src.util import *
from src.get_preds import *
from src.eval import *
import melt as mt
import numpy as np
import glob
import gc
from numba import cuda
from gezi import tqdm
import gezi
import husky
import lele

In [None]:
# Turn on memory growth for every GPU TensorFlow can see.
# NOTE(review): presumably this keeps TensorFlow from reserving the whole GPU so the
# torch models run later still have memory — confirm against the TF docs.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
# Local evaluation subset: the ids of fold 0, sorted, capped at `num_test_ids` entries.
num_test_ids = 1000
folds = pd.read_csv('../input/feedback/folds.csv')
fold0_ids = folds.loc[folds.kfold == 0, 'id'].values
test_ids = np.sort(fold0_ids)[:num_test_ids]
len(test_ids)

In [None]:
gezi.init_flags()
model_root = '../input'
# Count the attached `feedback-model*` directories ...
model_dirs = [x for x in glob.glob(f'{model_root}/feedback-model*') if os.path.isdir(x)]
# ... then rebuild the list so that position i is always `feedback-model{i}`.
# NOTE(review): this assumes the directories are numbered contiguously from 0;
# a gap in the numbering would yield paths that do not exist.
model_dirs = [f'../input/feedback-model{i}' for i in range(len(model_dirs))]
model_dir = model_dirs[0]
# model_dirs

In [None]:
# Ordering matters: TensorFlow models must come first. After the last TF model the
# inference loop calls cuda.close() to release the GPU, and only then runs torch models.
tf_models = []
first_models = []
ic(first_models)
# Put tf_models and first_models at the front of the list.
# presumably gezi.unique_list drops duplicates while keeping first-seen order — TODO confirm
model_dirs = gezi.unique_list([*tf_models, *first_models, *model_dirs])

# `m` keeps the full ordered list; `used_model_indexes` picks the subset to ensemble.
# Only the LAST assignment to used_model_indexes below takes effect; the earlier
# assignments and the commented lines are alternative selections kept for reference.
m = model_dirs
used_model_indexes = list(range(len(model_dirs)))
# used_model_indexes = list(range(15))
# used_model_indexes = [9]
# 8,9,10,11,12,13
# used_model_indexes = [0,1,2,3]
# used_model_indexes = [0,1,2,3,8,9,10,11,12,13,14]
# used_model_indexes = [3, 4]
used_model_indexes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17]
# used_model_indexes = [30,31,32,33,34,35]
# used_model_indexes += [36,37,38,39,40]
# Active selection; indexing m[i] raises IndexError if fewer than 33 model dirs exist.
used_model_indexes = [30,31,32]
# used_model_indexes = [30]
model_dirs = [m[i] for i in used_model_indexes]
# Number of selected TF models; the inference loop uses it to decide when to call cuda.close().
used_tf_models = [x for x in model_dirs if x in tf_models]
num_tf_models = len(used_tf_models)


In [None]:
def get_batch_size(backbone, model_dir=''):
  """Pick the inference batch size for a model.

  Args:
    backbone: backbone name; matched by substring (e.g. 'longformer', 'deberta').
    model_dir: model directory/name; for deberta backbones a 'lenNNNN' tag in it
      selects a smaller batch size.

  Returns:
    The batch size as an int. The first matching rule wins; 128 when nothing matches.
  """
  # Backbones whose batch size does not depend on the model directory.
  for key, size in (('longformer', 32), ('bird', 32), ('bart', 64)):
    if key in backbone:
      return size

  if 'deberta' in backbone:
    # xlarge takes precedence over any length tag.
    if 'xlarge' in backbone:
      return 32
    # Length tag found in the model directory -> reduced batch size.
    for tag, size in (('len1024', 16), ('len1280', 16), ('len1536', 8), ('len1600', 8)):
      if tag in model_dir:
        return size
    return 64

  if 'funnel' in backbone:
    return 8
  return 128

In [None]:
def rename(backbone):
  """Translate a stored backbone name into its canonical name.

  Names without an alias are returned unchanged.
  """
  aliases = {
    'roberta-large-squad2': 'roberta-large',
    'large': 'funnel-large',
  }
  if backbone in aliases:
    return aliases[backbone]
  return backbone

In [None]:
# Collect one model name per selected model directory, in the same order as model_dirs.
mns = []
for i, model_dir in tqdm(enumerate(model_dirs), total=len(model_dirs)):
#   ic(model_dir)
  # Restoring the saved config updates FLAGS; the name is the basename of the
  # restored FLAGS.model_dir, not of the local `model_dir` path.
  gezi.restore_configs(model_dir)
#   ic(FLAGS.model_dir)
  mns.append(os.path.basename(FLAGS.model_dir))
#   assert 'online' in FLAGS.model_dir
mns

In [None]:
# MIN_WEIGHT is only referenced by commented-out filtering code later in the notebook.
MIN_WEIGHT = 1
weights_dict = {}
# Gate for weighted ensembling: get_weight() only looks a model up in the per-set
# tables when its name is a key here. With every entry commented out the dict is
# empty, so every model falls back to get_weight()'s default weight.
weights_dict =  {
#   'bart.start.run2': 9, 
#   'roberta.start.nwemb-0': 9,
#   'deberta.start': 8,
 
#   'deberta-xlarge.start': 9, 
#   'deberta-xlarge.end': 9, 

#   'deberta-v3.start.len1024.stride-256.seq_encoder-0': 10, 
#   'deberta-v3.start.len1024.stride-256': 6,

#   'deberta-v3.start.len1536': 7, 
#   'deberta-v3.start.len1024.rnn_bi': 8, 
#   'deberta-v3.end.len1024.seq_encoder-0': 10,
#   'deberta-v3.mid.len1024': 8,

#   'deberta-v3.start.stride-256.seq_encoder-0': 7, 
#   'deberta-v3.start.nwemb-0.mark_end-0': 10, 
#   'deberta-v3.se': 10,
#   'deberta-v3.se2': 10,

#   'longformer.start.len1536': 6,
#   'longformer.start.len1600': 6,
#   'funnel.start.len1536.bs-8': 6,
# #   'deberta-v3.start.len1280': 6,
# #   'deberta-v3.start.len1280.rnn_layers-2': 6,
    
#   'deberta-v3.start.len1024.stride-512': 4,
#   'electra.start.nwemb-0.run2': 7,
}
# Three alternative weight tables keyed by model name. get_weight(x, idx) reads
# table `idx` from the `weights_dicts` list built below.
weights_dict0 = {'bart.start.run2': 6,
                        'deberta-v3.end.len1024.seq_encoder-0': 6,
                        'deberta-v3.mid.len1024': 4,
                        'deberta-v3.se': 6,
                        'deberta-v3.se2': 1,
                        'deberta-v3.start.len1024.rnn_bi': 5,
                        'deberta-v3.start.len1024.stride-256': 6,
                        'deberta-v3.start.len1024.stride-256.seq_encoder-0': 10,
                        'deberta-v3.start.len1536': 4,
                        # 'deberta-v3.start.len1536': 2,
                        # 'deberta-v3.start.mui-end-mid': 1,
                        # 'deberta-v3.start.len1536.rnn_type-GRU': 4,
                        'deberta-v3.start.nwemb-0.mark_end-0': 8,
                        'deberta-v3.start.stride-256.seq_encoder-0': 9,
                        'deberta-xlarge.end': 0,
                        'deberta-xlarge.start': 6,
                        'deberta.start': 6,
                        'longformer.start.len1536': 9,
                        'roberta.start.nwemb-0': 6}
weights_dict1 = {'bart.start.run2': 7,
                        'deberta-v3.end.len1024.seq_encoder-0': 10,
                        'deberta-v3.mid.len1024': 6,
                        'deberta-v3.se': 2,
                        'deberta-v3.se2': 7,
                        'deberta-v3.start.len1024.rnn_bi': 8,
                        'deberta-v3.start.len1024.stride-256': 10,
                        'deberta-v3.start.len1024.stride-256.seq_encoder-0': 7,
                        'deberta-v3.start.len1536': 8,
                        # 'deberta-v3.start.len1536': 4,
                        # 'deberta-v3.start.mui-end-mid': 1,
                        # 'deberta-v3.start.len1536.rnn_type-GRU': 8,
                        'deberta-v3.start.nwemb-0.mark_end-0': 8,
                        'deberta-v3.start.stride-256.seq_encoder-0': 7,
                        'deberta-xlarge.end': 7,
                        'deberta-xlarge.start': 10,
                        'deberta.start': 6,
                        'longformer.start.len1536': 8,
                        'roberta.start.nwemb-0': 5}
weights_dict2 = {'bart.start.run2': 6,
                        'deberta-v3.end.len1024.seq_encoder-0': 7,
                        'deberta-v3.mid.len1024': 5,
                        'deberta-v3.se': 9,
                        'deberta-v3.se2': 5,
                        'deberta-v3.start.len1024.rnn_bi': 9,
                        'deberta-v3.start.len1024.stride-256': 6,
                        'deberta-v3.start.len1024.stride-256.seq_encoder-0': 10,
                        'deberta-v3.start.len1536': 6,
                        # 'deberta-v3.start.len1536': 1.5,
                        # 'deberta-v3.start.mui-end-mid': 0,
                        # 'deberta-v3.start.len1536.rnn_type-GRU': 3,
                        'deberta-v3.start.nwemb-0.mark_end-0': 0,
                        'deberta-v3.start.stride-256.seq_encoder-0': 5,
                        'deberta-xlarge.end': 8,
                        'deberta-xlarge.start': 4,
                        'deberta.start': 8,
                        'longformer.start.len1536': 9,
                        'roberta.start.nwemb-0': 6}
# Index order matters: position 0/1/2 is what get_weight's `idx` argument selects.
weights_dicts = [weights_dict0, weights_dict1, weights_dict2]
# Show the gate dict and its size.
# presumably gezi.sort_byval orders the entries by value — TODO confirm
ic(gezi.sort_byval(weights_dict))
len(weights_dict)

In [None]:
# def get_weight(x):
#   assert x in weights_dict
#   if weights_dict:
#     return weights_dict[x]
#   return 1
def get_weight(x, idx=0):
  """Return the ensemble weight of model name `x` under weight table `idx`.

  Args:
    x: model name (a key of `weights_dict` / the tables in `weights_dicts`).
    idx: index into `weights_dicts` selecting which weight table to read.

  Returns:
    The weight stored in `weights_dicts[idx]` when `x` is enabled in `weights_dict`
    and present in that table (this may be 0, which switches the model off).
    Otherwise the default weight 1.
  """
  default_weight = 1
  # `weights_dict` only acts as the on/off gate; with an empty gate every model
  # gets the default weight.
  if x not in weights_dict:
    return default_weight
  # The gate and the per-set tables are maintained separately and may not share
  # the same keys; fall back to the default instead of raising KeyError.
  return weights_dicts[idx].get(x, default_weight)

# if weights_dict:
#   mns_ori = mns.copy()
#   indxes = [i for i in range(len(mns)) if mns[i] in weights_dict and weights_dict[mns[i]] >= MIN_WEIGHT]
#   mns = [mns[i] for i in range(len(mns_ori)) if i in indxes]
#   model_dirs = [model_dirs[i] for i in range(len(mns_ori)) if i in indxes]
    
# One weight per selected model (same order as `mns` / `model_dirs`) for each of the
# three weight tables: table 0, 1 and 2 respectively.
weights = [get_weight(x) for x in mns]
weights2 = [get_weight(x, 1) for x in mns]
weights3 = [get_weight(x, 2) for x in mns]
# Print index, directory, model name and table-0 weight side by side for a visual check.
ic(list(zip(range(len(model_dirs)), model_dirs, mns, weights)), len(model_dirs))

In [None]:
# Optional local check: run the selected models on `test_ids` in 'train' mode, ensemble
# them with the three weight sets, and score the result against the training annotations.
# Disabled by default; flip `verify` to True to run it.
verify = False
# verify = True
if verify:
  mode = 'train'

  def _count_words(essay_id):
    """Return the number of whitespace-separated words in one training essay."""
    # `with` closes the file handle deterministically instead of leaving it to the GC.
    with open(f'../input/feedback-prize-2021/train/{essay_id}.txt') as f:
      return len(f.read().split())

  ensembler = Ensembler(need_sort=True)
  for i, model_dir in tqdm(enumerate(model_dirs), total=len(model_dirs)):
    ic(model_dir)
    # Re-initialise the flags and pin the options below to known values before
    # overlaying the configuration saved with this model.
    gezi.init_flags()
    FLAGS.multi_inputs = False
    FLAGS.seq_encoder = False
    FLAGS.merge_tokens = False
    FLAGS.split_punct = False
    FLAGS.custom_tokenize = False
    FLAGS.ori_deberta_v2_tokenizer = True
    FLAGS.max_len_valid = None
    FLAGS.lower = False
    FLAGS.stride = None
    FLAGS.scatter_method = 0
    gezi.restore_configs(model_dir)
    # if FLAGS.stride is None:
    #   FLAGS.max_len = 2048
    # FLAGS.fake_infer = True
    ic(FLAGS.merge_tokens)
    FLAGS.pymp = False
    FLAGS.eval_len = True
    ic(FLAGS.model_dir)
    # assert 'online' in FLAGS.model_dir
    # Remember the restored model_dir (used for the batch-size lookup) before
    # pointing FLAGS.model_dir at the local copy of the model.
    model_dir_ = FLAGS.model_dir
    FLAGS.model_dir = model_dir
    # Backbone: last path component, aliased via rename(), underscores -> hyphens,
    # then resolved under ../input.
    backbone = rename(FLAGS.backbone.split('/')[-1]).replace('_', '-')
    FLAGS.backbone = '../input/' + backbone
    model = get_model()
    model.eval()
    gezi.load_weights(model, model_dir)
    # Guard against a broken checkpoint: the last recorded overall F1 must exceed 0.6.
    d = pd.read_csv(f'{model_dir}/metrics.csv')
    display(d[['f1/Overall']])
    assert d['f1/Overall'].values[-1] > 0.6
    ic(FLAGS.adjacent_rule)
    double_times = 0
    inputs = get_inputs(FLAGS.backbone, sort=True, mode=mode, double_times=double_times, test_ids=test_ids)
    ic(len(inputs['id']))
    ic(gezi.get_mem_gb())
    batch_size = get_batch_size(backbone, model_dir_)
    ic(FLAGS.backbone, FLAGS.max_len, FLAGS.lower, FLAGS.multi_inputs, FLAGS.multi_inputs_srcs,
       FLAGS.merge_tokens, FLAGS.seq_encoder, FLAGS.use_relative_positions, FLAGS.stride,
       FLAGS.mask_inside, FLAGS.label_inside, FLAGS.word_combiner, FLAGS.scatter_method,
       FLAGS.custom_tokenize, FLAGS.ori_deberta_v2_tokenizer, FLAGS.split_punct,
       FLAGS.block_size, FLAGS.n_blocks, batch_size)

    # p = gezi.predict(model, inputs, batch_size, dynamic_keys=['input_ids', 'word_ids'], mask_key='attention_mask')
    p = gezi.predict(model, inputs, batch_size)
    # Attach the bookkeeping fields to the predictions before post-processing.
    p.update({
      'id': inputs['id'],
      'word_ids': inputs['word_ids'],
      'num_words': inputs['num_words']
    })
    convert_res(p)
    # The ground truth is not needed per model; it is loaded once after the loop.
    # ensembler.add(p, weights[i])
    ensembler.add(p, weights=[weights[i], weights2[i], weights3[i]])
    # df = get_preds(p)
    # display(df)

    if FLAGS.torch:
      torch.cuda.empty_cache()
    else:
      # only the last tf model should cuda close
      if i + 1 == num_tf_models:
        cuda.select_device(0)
        cuda.close()
    del model
    del inputs
    gc.collect()

  ic(gezi.get_mem_gb())
  p = ensembler.finalize()

  # Ground truth restricted to the evaluated essays, ordered by id.
  df_gt = pd.read_csv('../input/feedback-prize-2021/train.csv')
  df_gt = df_gt[df_gt.id.isin(test_ids)]
  df_gt = df_gt.sort_values('id')
  # Read each essay once per unique id rather than once per annotation row.
  num_words_by_id = {essay_id: _count_words(essay_id) for essay_id in df_gt.id.unique()}
  df_gt['num_words'] = df_gt.id.map(num_words_by_id)

  df = get_preds(p)
  display(df)
  res = get_metrics(df_gt, p)
  ic(res)

In [None]:
# NOTE(review): `P` is not defined anywhere in this notebook; presumably it comes from one
# of the `from src... import *` imports — confirm, otherwise this cell raises NameError.
ic(P)

In [None]:
# Main inference: run every selected model on the `mode` split, add its predictions to
# the ensembler with the model's weight, and free the model before loading the next one.
mode = 'test'
# mode = 'train'
ensembler = Ensembler(need_sort=True)
for i, model_dir in tqdm(enumerate(model_dirs), total=len(model_dirs)):
  ic(model_dir)
  # Re-initialise the flags and pin the options below to known values before
  # overlaying the configuration saved with this model.
  gezi.init_flags()
  FLAGS.multi_inputs = False
  FLAGS.seq_encoder = False
  FLAGS.merge_tokens = False
  FLAGS.split_punct = False
  FLAGS.custom_tokenize = False
  FLAGS.ori_deberta_v2_tokenizer = True
  FLAGS.max_len_valid = None
  FLAGS.lower = False
  FLAGS.stride = None
  FLAGS.scatter_method = 0
  gezi.restore_configs(model_dir)
  if mode == 'train':
    FLAGS.fake_infer = True
  FLAGS.pymp = False
  ic(FLAGS.model_dir)
#   assert 'online' in FLAGS.model_dir
  # Remember the restored model_dir (used for the batch-size lookup) before
  # pointing FLAGS.model_dir at the local copy of the model.
  model_dir_ = FLAGS.model_dir
  FLAGS.model_dir = model_dir
  # Backbone: last path component, aliased via rename(), underscores -> hyphens,
  # then resolved under ../input.
  backbone = rename(FLAGS.backbone.split('/')[-1]).replace('_', '-')
  FLAGS.backbone = '../input/' + backbone
  # NOTE(review): the verify loop calls model.eval() after get_model(); this loop does
  # not — confirm that gezi.predict puts torch models into eval mode.
  model = get_model()
#   model = Model()
  gezi.load_weights(model, model_dir)
  # Guard against a broken checkpoint: the last recorded overall F1 must exceed 0.6.
  d = pd.read_csv(f'{model_dir}/metrics.csv')
  display(d[['f1/Overall']])
  assert d['f1/Overall'].values[-1] > 0.6
  ic(FLAGS.adjacent_rule)
  # Set double_times to 6 (about 320 test ids) to check that GPU memory suffices for a new model.
  double_times = 0
#   double_times = 6

  inputs = get_inputs(FLAGS.backbone, sort=True, mode=mode, double_times=double_times)
  batch_size = get_batch_size(backbone, model_dir_)
  ic(gezi.get_mem_gb())
  ic(FLAGS.backbone, FLAGS.max_len, FLAGS.lower, FLAGS.multi_inputs, FLAGS.multi_inputs_srcs, 
     FLAGS.merge_tokens, FLAGS.seq_encoder, FLAGS.use_relative_positions, FLAGS.stride,
     FLAGS.mask_inside, FLAGS.label_inside, FLAGS.word_combiner, FLAGS.scatter_method,
     FLAGS.custom_tokenize, FLAGS.ori_deberta_v2_tokenizer, FLAGS.split_punct, 
     FLAGS.num_words_emb, FLAGS.mark_end, batch_size)

  p = gezi.predict(model, inputs, batch_size, dynamic_keys=['input_ids', 'word_ids'], mask_key='attention_mask')
  # Attach the bookkeeping fields to the predictions before post-processing.
  p.update({
    'id': inputs['id'],
    'word_ids': inputs['word_ids'],
    'num_words': inputs['num_words']
  })
  convert_res(p)
  # Single-weight ensembling with the table-0 weight; the three-weight variant is kept below.
  ensembler.add(p, weights[i])
#   ensembler.add(p, weights=[weights[i], weights2[i], weights3[i]])
  # Only show per-model predictions when the input set has fewer than 1000 entries.
  if len(inputs['id']) < 1000:
    df = get_preds(p)
    display(df)
    
  if FLAGS.torch:
    torch.cuda.empty_cache()
  else:
    # only the last tf model should cuda close
    if i + 1 == num_tf_models:
      cuda.select_device(0)
      cuda.close()
  # Drop the references and collect before loading the next model.
  del model
  del inputs
  gc.collect()

ic(gezi.get_mem_gb())
# Combined predictions of all models added above.
p = ensembler.finalize()

In [None]:
# Turn the ensembled predictions into the final prediction table and show it.
df = get_preds(p)
display(df)

In [None]:
# Write the prediction table to submission.csv without the index column.
df.to_csv('submission.csv',index=False)