# Libraries

In [None]:
#default libraries
#https://docs.python.org/ja/
import os
import sys
import io

import math
import random
import pprint
import time
import datetime
import typing
import json
import glob
import requests
import warnings
import gc
from pprint import pprint
import re

import numpy as np #https://numpy.org/
import pandas as pd #https://pandas.pydata.org/
import sklearn #https://scikit-learn.org/stable/

import matplotlib.pyplot as plt #https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.html
%matplotlib inline

import seaborn as sns
sns.set()

from tqdm import tqdm #https://tqdm.github.io/

import torch #https://pytorch.org/
import transformers #https://huggingface.co/transformers/

# Configuration

In [None]:
class CFG:
    """Notebook-wide settings, read as class attributes (``CFG.seed``, ``CFG.device``, ...).

    Evaluating this class body also sets two global cuDNN flags
    (``torch.backends.cudnn.deterministic`` and ``torch.backends.cudnn.benchmark``).
    """

    # --- paths / run mode ---
    data_path = "path/to/datasets/"
    save_path = ''
    debug = False
    seed = 0

    # --- hardware ---
    # First CUDA device when one is available, otherwise the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # --- training hyper-parameters ---
    batch_size = 128
    epochs = 20
    learning_rate = 1  # the author's trailing note on this value reads "0.001"
    kFold = 1
    amp = True

    # --- speed-related cuDNN flags ---
    # Reference: https://qiita.com/sugulu_Ogawa_ISID/items/62f5f7adee083d96a587

    # Deterministic mode; reportedly makes GPU execution slower.
    torch.backends.cudnn.deterministic = True

    # When the forward pass and the loss computation are fairly constant from
    # iteration to iteration, torch.backends.cudnn.benchmark = True speeds up
    # GPU computation. It is switched off here.
    torch.backends.cudnn.benchmark = False

def set_seed(seed=0):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (the default
    generator plus the current and all CUDA devices), and exports the value
    as the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # (TensorFlow would be seeded with tf.random.set_seed(seed) if it were used.)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

# CFG.device falls back to the CPU when CUDA is unavailable. The torch.cuda.*
# calls below only make sense for (and raise without) a CUDA device, so they
# are limited to the GPU case.
_on_cuda = CFG.device.type == 'cuda'

if _on_cuda:
    torch.cuda.set_device(CFG.device)

# Seed all RNGs for reproducibility, then report the selected device.
set_seed(CFG.seed)
print(CFG.device)

if _on_cuda:
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name())
    print(torch.cuda.device_count())

# Color

In [None]:
def color(string, fg='DEFAULT', bg='DEFAULT', fg_rgb=None, bg_rgb=None, style='END'):
    """Wrap ``str(string)`` in ANSI escape sequences for terminal colouring.

    Args:
        string: value to display; converted with ``str``.
        fg: foreground colour name (one of the names listed in ``palette``).
        bg: background colour name (same set of names as ``fg``).
        fg_rgb: optional sequence whose first three items are the R, G, B
            components of a 24-bit foreground colour; overrides ``fg``.
        bg_rgb: same as ``fg_rgb`` for the background; overrides ``bg``.
        style: text style name (one of the names listed in ``effects``).

    Returns:
        The style, foreground and background sequences, the text, and a
        trailing reset sequence, concatenated in that order.

    Raises:
        ValueError: if ``fg``, ``bg`` or ``style`` is not a known name. The
            names are looked up even when an RGB override is given.
    """
    # The position of a name in these lists is the digit used in the escape code.
    palette = ['BLACK', 'RED', 'GREEN', 'YELLOW', 'BLUE', 'PURPLE', 'CYAN', 'WHITE', '8', 'DEFAULT']
    effects = ['END', 'BOLD', '2', '3', 'UNDERLINE', '5', '6', 'REVERSE', 'INVISIBLE', '9']

    fg_code = f'\033[3{palette.index(fg)}m'
    bg_code = f'\033[4{palette.index(bg)}m'
    style_code = f'\033[0{effects.index(style)}m'

    # 24-bit colours replace the named ones when supplied.
    if fg_rgb:
        fg_code = f"\033[38;2;{fg_rgb[0]};{fg_rgb[1]};{fg_rgb[2]}m"
    if bg_rgb:
        bg_code = f"\033[48;2;{bg_rgb[0]};{bg_rgb[1]};{bg_rgb[2]}m"

    reset_code = '\033[0m'
    return ''.join((style_code, fg_code, bg_code, str(string), reset_code))

# Dataset

## Read Data

In [None]:
# Collect the train/val/test splits in one dict of DataFrames.
# NOTE(review): `val` and `test` are not defined in the visible part of this
# notebook; they must be provided before this cell runs.
data={
    'train':pd.DataFrame(columns=['feature','target']),
    'val':pd.DataFrame({'feature':val.feature,'target':val['target']}),
    'test':pd.DataFrame(test),
}

# Load the data
#data=pd.read_pickle('../dataset')

# Debug mode: keep only a 1% random sample so the notebook runs quickly.
# `data` is a dict of DataFrames above (a dict has no `.sample`), but becomes a
# single DataFrame if the read_pickle line is used instead, so handle both.
if CFG.debug:
    if isinstance(data,dict):
        data={
            split:frame.sample(frac=0.01,random_state=CFG.seed).reset_index(drop=True)
            for split,frame in data.items()
        }
    else:
        data=data.sample(frac=0.01,random_state=CFG.seed).reset_index(drop=True)
data

## split dataset

In [None]:
from sklearn.model_selection import train_test_split

# Stratified 80/20 train/validation split of `data_`, seeded with CFG.seed.
# NOTE(review): `data_` is not defined in the visible part of this notebook.
# NOTE(review): `data` is rebuilt here with only 'train' and 'val' keys, so a
# previously stored 'test' split is no longer reachable through it.
data = dict.fromkeys(('train', 'val'))

train_part, val_part = train_test_split(
    data_,
    test_size=0.2,
    random_state=CFG.seed,
    stratify=data_["target"],
)

# Renumber the rows of each split from 0.
data['train'] = train_part.reset_index(drop=True)
data['val'] = val_part.reset_index(drop=True)

## Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each feature with its target.

    Args:
        features: indexable collection of inputs; ``features[idx]`` must work
            for ``idx`` in ``range(len(features))``.
        targets: indexable collection of labels aligned with ``features``, or
            ``None`` for unlabeled data.
        phase: 'train', 'val' or 'test'. Stored on the instance; every phase
            returns items in the same format.
        transform: optional callable applied to each feature when it is read.
    """

    def __init__(self,features,targets=None,phase='train',transform=None):
        self.features=features
        self.targets=targets
        self.phase=phase #train/val/test
        self.transform=transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self,idx):
        """Return ``(feature, target)`` for sample ``idx``.

        The feature is passed through ``transform`` when one was given.
        When the dataset was built without targets, only the feature is
        returned (a ``None`` target could not be batched by the default
        DataLoader collate function).
        """
        feature=self.features[idx]
        if self.transform is not None:
            feature=self.transform(feature)

        if self.targets is None:
            return feature
        return feature,self.targets[idx]

In [None]:
# One Dataset per split. The split name is passed as `phase` so that the val
# and test datasets are no longer created with the default phase='train'.
dataset={
    phase:Dataset(data[phase].feature,data[phase].target,phase=phase)
    for phase in ('train','val','test')
}

## dataloader

In [None]:
# Optional custom collate function (disabled):
# def collate_fn(batch):
#     features,targets = zip(*batch)
#     return features,torch.tensor(targets).float()

# One DataLoader per split; only the training split is shuffled.
dataloader={
    phase:torch.utils.data.DataLoader(
        dataset[phase],
        #collate_fn=sentences_collate_fn,
        batch_size=CFG.batch_size,
        shuffle=(phase=='train'),
        # os.cpu_count() can return None; fall back to loading in the main process.
        num_workers=os.cpu_count() or 0,
        # Pinned memory is only requested when a CUDA device is present.
        pin_memory=torch.cuda.is_available(),
    )
    for phase in ('train','val','test')
}

In [None]:
# Sanity check: fetch the first mini-batch from the training loader and
# show the size of the inputs and the labels themselves.
first_batch = next(iter(dataloader["train"]))
inputs, labels = first_batch
print(inputs.size())
print(labels)

# Model

In [None]:
class Model(torch.nn.Module):
    """Two stacked 5x5 convolutions, each followed by ReLU.

    ``conv1`` maps 1 input channel to 20 channels and ``conv2`` maps 20 to 20.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5)
        self.conv2 = torch.nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``x`` and return the result."""
        # `torch.F` does not exist; the functional API lives in torch.nn.functional.
        x = torch.nn.functional.relu(self.conv1(x))
        return torch.nn.functional.relu(self.conv2(x))

    

In [None]:
# Build the model and move its parameters to the configured device.
model=Model()

model.to(CFG.device)

# Enable gradients on every parameter. Assigning `model.require_grad` only
# created an unused attribute; requires_grad_() is the nn.Module API for this
# and returns the module itself.
model=model.requires_grad_(True)

# Training

## Trainer

In [None]:
class Trainer():
    """Bundle the loss, metric, optimizer state and loops used to train a model.

    Attributes:
        loss_fun: ``torch.nn.BCEWithLogitsLoss`` used for the train/val loss.
        score_fun: element-wise agreement between thresholded predictions
            and thresholded targets.
        optimizer, lr_scheduler: ``None`` until ``cross_validation`` creates them.
        scaler: gradient scaler for mixed precision, enabled by ``CFG.amp``.

    NOTE(review): `slack` (used for progress notifications) is not defined or
    imported in the visible part of this notebook; it must be provided elsewhere.
    """

    def __init__(self):
        self.loss_fun=torch.nn.BCEWithLogitsLoss()
        # NOTE(review): predictions are fed to BCEWithLogitsLoss as logits, yet
        # they are thresholded at 0.5 here (probability 0.5 corresponds to
        # logit 0) — confirm the intended threshold.
        self.score_fun=lambda a,b:torch.eq(a>0.5,b>0.5)

        self.optimizer=None
        self.lr_scheduler=None
        self.scaler=torch.cuda.amp.GradScaler(enabled=CFG.amp)


    def train_val_test(self,model,dataloader,phase,epoch=-1):
        """Run one pass over ``dataloader`` in the given phase.

        Args:
            model: module called as ``model(features)``.
            dataloader: iterable yielding ``(features, targets)`` batches.
            phase: 'train' runs forward, backward and the optimizer step;
                'test' runs the forward pass only; any other value (e.g.
                'val') runs the forward pass plus loss and score.
            epoch: only used in the progress-bar description.

        Returns:
            ``(predictions, losses, scores)``: model outputs as numpy arrays
            (one entry per row of each batch output), one loss value per
            batch, and the correctness flags produced by ``score_fun``.
            ``losses`` and ``scores`` stay empty for phase 'test'.
        """

        # Put the model into the mode matching the phase.
        if phase=='train':
            model.train()
        else:
            model.eval()
        #model.to(CFG.device)

        predictions=[]
        losses=[]
        scores=[]

        # Loop over the mini-batches of the data loader.
        # tqdm writes into an in-memory buffer so that its latest progress
        # line can be forwarded to Slack below.
        tqdm_bar=io.StringIO()
        for features,targets in tqdm(dataloader,file=tqdm_bar,desc=f'model description\n{epoch=} {phase=}'):

            # Reset the gradients held by the optimizer.
            if phase=='train':self.optimizer.zero_grad()

            # Forward computation (gradients are tracked only while training).
            with torch.set_grad_enabled(phase=='train'):

                features = features.to(CFG.device,non_blocking=True)
                targets = targets.to(CFG.device,non_blocking=True)

                with torch.cuda.amp.autocast(enabled=CFG.amp):

                    preds=model(features)#,targets)

                    if phase!='test':
                        loss = self.loss_fun(preds, targets)  # compute the loss
                        losses.append(loss.item())

                        score = self.score_fun(preds, targets).cpu().numpy()  # correct / incorrect per element
                        scores.extend(score)

                # Back-propagation while training.
                if phase == 'train':
                    if CFG.amp:
                        # Mixed precision: go through the gradient scaler.
                        self.scaler.scale(loss).backward() # backward pass on the scaled loss
                        self.scaler.step(self.optimizer) # optimizer update
                        self.scaler.update() # scaler update
                    else:
                        loss.backward()
                        self.optimizer.step()

            predictions.extend(preds.detach().cpu().numpy())

            # Drop per-batch tensors and release cached memory before the next batch.
            del preds
            if phase!='test':del loss,score
            torch.cuda.empty_cache()
            gc.collect()

            # Report the latest progress-bar line and the running score.
            slack.update(
                '1635484096.026100',
                slack.textblock(
                    tqdm_bar.getvalue().split('\r')[-1]+
                    f"\nscore={np.mean(scores):.3f}"
                )
            )

        return predictions,losses,scores

    def cross_validation(self,model,dataloader):
        """Train ``model`` for ``CFG.epochs`` epochs per fold and save the best checkpoint.

        For each of the ``CFG.kFold`` folds, every epoch runs a training pass
        and a validation pass, records mean loss and score, and keeps a copy
        of the model/optimizer state whenever the validation score improves.
        Loss and score curves are plotted per fold; at the end the elapsed
        time is reported and the best checkpoint (if any) is written with
        ``torch.save``.

        Args:
            model: module to train; updated in place.
            dataloader: mapping providing the 'train' and 'val' data loaders.
        """
        import copy  # local import: only needed to snapshot the best checkpoint

        start_time=datetime.datetime.utcnow() + datetime.timedelta(hours=9)

        # torch.optim.AdamW replaces the deprecated transformers.AdamW;
        # eps=1e-6 keeps the default that implementation used.
        self.optimizer=torch.optim.AdamW(model.parameters(), lr=CFG.learning_rate,betas=(0.9, 0.999), eps=1e-6, weight_decay=1e-2)
        # The lambda returns the factor applied to the optimizer's base learning
        # rate: 1e-4, halved every 5 epochs.
        self.lr_scheduler=torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch:(1e-4)*(0.5**(epoch//5)))
        # The scheduler's `verbose` flag is deprecated; print the rate explicitly.
        print('learning rate:',self.lr_scheduler.get_last_lr())

        # Defined up front so the final report and save cannot hit an unbound
        # name (no folds, or no epoch improving on the initial score).
        bestscore=0
        bestmodel=None

        for fold in range(CFG.kFold):
            print('fold',fold)

            losses={'train':[],'val':[]}
            scores={'train':[],'val':[]}

            #self.initialize(CFG.seed,fold)
            #dataset,dataloader,model,optimizer,scheduler,scaler=initialize(CFG.seed,fold)

            bestscore=0
            bestmodel=None

            for epoch in tqdm(range(CFG.epochs)):
                print("epoch=",epoch)

                # To check the validation performance of the untrained model,
                # training could be skipped at epoch 0 (disabled):
                # if (epoch == 0) and (phase == 'train'):train_continue

                _,loss,score=self.train_val_test(model,dataloader['train'],'train',epoch=epoch)
                print(color("train score",bg='CYAN')+' :',color(np.mean(score),'CYAN'))
                scores['train'].append(np.mean(score))
                losses['train'].append(np.mean(loss))

                #plt.scatter(dataset['train'].targets,preds,color='blue',s=5)

                _,loss,score=self.train_val_test(model,dataloader['val'],'val',epoch=epoch)
                print(np.mean(loss))
                print(color("val score",bg='RED')+' :',color(np.mean(score),'RED'))

                scores['val'].append(np.mean(score))
                losses['val'].append(np.mean(loss))

                if bestscore < scores['val'][-1]:
                    bestscore = scores['val'][-1]
                    print(color("BEST SCORE",bg='YELLOW')+' :',color(bestscore,'YELLOW'))

                    # state_dict() hands out references to the live tensors, so
                    # later epochs would silently overwrite the "best" weights.
                    # Store independent copies instead (weights on the CPU).
                    bestmodel={
                        'state_dict': {
                            name:tensor.detach().to('cpu',copy=True)
                            for name,tensor in model.state_dict().items()
                        },
                        'optimizer_dict': copy.deepcopy(self.optimizer.state_dict()),
                        'bestscore':bestscore,
                        'seed':CFG.seed
                    }

                self.lr_scheduler.step() # update the learning rate
                print('learning rate:',self.lr_scheduler.get_last_lr())

            # Loss curves for this fold.
            plt.plot(range(CFG.epochs),losses['train'],color = "blue",label='train')
            plt.plot(range(CFG.epochs),losses['val'],color = "red",label='val')
            plt.legend()
            plt.title('Loss')
            plt.show()

            # Score curves for this fold, with a fixed 0.644 reference line.
            plt.plot(range(CFG.epochs),scores['train'],color = "blue",label='train')
            plt.plot(range(CFG.epochs),scores['val'],color = "red",label='val')
            plt.plot(range(CFG.epochs),[0.644]*CFG.epochs,color = "green",linestyle = "dotted",label='Egawa model')
            plt.legend()
            plt.title('Score')
            plt.show()



        exe_time=datetime.datetime.utcnow() + datetime.timedelta(hours=9)-start_time
        print(exe_time)
        slack.update(
            '1635484096.026100',
            slack.textblock(
                f"学習終了\n実行時間：{exe_time}\n{bestscore=:.3f}"
            )
        )

        if bestmodel is None:
            # No epoch improved on the initial score, so there is nothing to save.
            print("no checkpoint saved: validation score never exceeded 0")
        else:
            torch.save(bestmodel,"egawa_model+BERT_NLI_v3:"+str(bestscore)+".pth")
            

#             plt.plot(losses['train'],color='blue')
#             plt.plot(losses['valid'],color='red')
#             plt.show()

# Single Trainer instance shared by the training cells that follow.
trainer = Trainer()

## Train

In [None]:
# cross_validation looks up the 'train' and 'val' loaders itself, so it needs
# the whole dataloader dict (which has no 'CMV' entry).
trainer.cross_validation(model,dataloader)