In [1]:
# !apt install -y unzip imagemagick libopencv-dev
# !pip install pandas scikit-learn opencv-python wand wandb accelerate transformers timm bitsandbytes

In [2]:
import os, gc, time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from glob import glob
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
import cv2
from PIL import Image
from accelerate import Accelerator

# Open question from the author: does this setting affect memory usage under half precision?
torch.set_float32_matmul_precision('high')  # one of 'highest' | 'high' | 'medium'

# wandb settings. WANDB_MODE used to be assigned 'online' and then immediately
# 'offline'; only the last assignment ever took effect, so it is set once.
os.environ['WANDB_PROJECT']='basslibrary240210'
os.environ['WANDB_MODE']='offline'

######## logger ########
# `logging.handlers` is a submodule that `import logging` does not load by itself;
# import it explicitly so RotatingFileHandler is available regardless of which
# other libraries were imported first.
import sys, logging, logging.handlers, IPython
logger = logging.getLogger()
# Root logger writes to the notebook output (stdout) and to a rotating run.log file.
logging.basicConfig( handlers=[ logging.StreamHandler(stream=sys.stdout), logging.handlers.RotatingFileHandler(filename='run.log', mode='a', maxBytes=512000, backupCount=4) ] )
logging_fomatter = logging.Formatter( '%(asctime)s [%(levelname)-4.4s] %(message)s', datefmt='%m/%d %H:%M:%S' )
for _handler in logger.handlers:
    _handler.setFormatter(logging_fomatter)
logger.setLevel(logging.INFO)
def showtraceback(self, *args, **kwargs):
    """Replacement for IPython's traceback display: send the active exception to the logger."""
    logger.exception('-------Exception----------')
# Patch IPython so uncaught cell exceptions are logged (and therefore land in run.log too).
IPython.core.interactiveshell.InteractiveShell.showtraceback = showtraceback
logger.info('program started')

05/30 12:47:54 [INFO] program started


In [3]:
# Central experiment configuration; later cells read every setting from this dict.
CFG = {
    'OVERSAMPLING': True,
    'GRADIENT_CHECKPOINT': False,   # whether to use gradient checkpointing
    'ACCUMULATION_STEPS': 1,        # author's note: raising this needs several copies of the weights in memory
    'MAX_CLASSES': 7,               # len(train_df.label.unique())
    'SEED': 42,
    'N_SPLIT': 5,
    'LABEL_SMOOTHING': 0.05,
    'RESIZE_TO_FIT': False,
    'OPTIMIZER': 'AdamW',
    'INTERPOLATION': 'robidouxsharp',   # author's note: best so far
    'PRECISION': '16',
    # ---- model-specific settings ----
    'MODEL_NAME': "timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k",
    'IMG_SIZE': 448,
}
# Larger batches fit when gradient checkpointing is on (author's note: 20 -> 10GB).
CFG['BATCH_SIZE'] = 48 if CFG['GRADIENT_CHECKPOINT'] else 16
# [peak LR, floor LR], both scaled linearly with the effective batch size relative to 8.
CFG['LR'] = [ reference_lr / float(8) * CFG['BATCH_SIZE'] * CFG['ACCUMULATION_STEPS'] for reference_lr in (1e-5, 1e-6) ]
# Training images are loaded at the inference size unless configured otherwise.
CFG.setdefault('IMG_TRAIN_SIZE', CFG['IMG_SIZE'])
logger.info(CFG)

05/30 12:47:54 [INFO] {'OVERSAMPLING': True, 'GRADIENT_CHECKPOINT': False, 'ACCUMULATION_STEPS': 1, 'MAX_CLASSES': 7, 'SEED': 42, 'N_SPLIT': 5, 'LABEL_SMOOTHING': 0.05, 'RESIZE_TO_FIT': False, 'OPTIMIZER': 'AdamW', 'INTERPOLATION': 'robidouxsharp', 'PRECISION': '16', 'MODEL_NAME': 'timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', 'IMG_SIZE': 448, 'BATCH_SIZE': 16, 'LR': [2e-05, 2e-06], 'IMG_TRAIN_SIZE': 448}


In [4]:
# A CUDA device is required; stop immediately when none is visible.
assert torch.cuda.is_available()
device = 'cpu' if not torch.cuda.is_available() else 'cuda'
# Record the selected device and the default tensor dtype in the run log.
logger.info(f'{device=}')
logger.info(f'dtype={torch.get_default_dtype()}')

05/30 12:47:54 [INFO] device='cuda'
05/30 12:47:54 [INFO] dtype=torch.float32


In [5]:
def seed_everything(seed):
    """Seed every RNG used by the notebook and request deterministic cuDNN behaviour.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and all
    CUDA devices), and configures cuDNN for reproducible kernel selection.

    Args:
        seed: integer seed applied to all generators.
    """
    logger.info(f'seed_everything : {seed=}')

    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    # Only affects Python processes started after this point (hash randomisation of
    # the running interpreter was fixed at start-up).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different algorithms from run to run, which
    # defeats the deterministic setting above, so it must be switched off.
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['SEED'])

05/30 12:47:54 [INFO] seed_everything : seed=42


In [6]:
# Build train.csv once from the folder layout train/<rock_type>/<file>.jpg.
# NOTE(review): glob() does not guarantee an ordering, so the generated IDs / row
# order may differ between machines when the file is rebuilt.
if not os.path.exists('train.csv'):
    train_list = glob('train/*/*.jpg')
    train_df = pd.DataFrame(train_list)
    # The class name is the directory component right after "train".
    train_df['rock_type'] = train_df[0].map(lambda path: path.split(os.path.sep)[1])
    # Rewrite an "open/" path component to "./" (literal, non-regex replacement).
    train_df['img_path'] = train_df[0].str.replace('open' + os.path.sep, '.' + os.path.sep, regex=False)
    train_df['ID'] = train_df.index
    train_df.to_csv('train.csv', columns=['ID', 'img_path', 'rock_type'], index=False)
    logger.info('train.csv saved.')

In [7]:
# Edge length of the model input, taken from the experiment configuration.
image_size = CFG['IMG_SIZE']

# Training pipeline: random augmentation, then tensor conversion and normalisation
# (the mean/std values are the commonly used ImageNet statistics).
train_transform_list = [
    v2.TrivialAugmentWide(interpolation=v2.InterpolationMode.BICUBIC),
    v2.RandomHorizontalFlip(),   # author's note: improved the score
    v2.RandomVerticalFlip(),     # author's note: effect unclear
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
train_transform = v2.Compose(train_transform_list)

# Evaluation pipeline: same conversion and normalisation, no augmentation.
test_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [8]:
class CustomDataset(Dataset):
    """Image dataset that loads files on demand, optionally centre-crops and resizes them.

    Each item is a dict with ``'pixel_values'`` (the transformed image) and, when
    labels were supplied, ``'label'``.

    Args:
        img_path_list: sequence of image file paths.
        label_list: sequence of labels aligned with ``img_path_list``, or ``None``.
        load_img_size: target size (int or ``[width, height]``) images are resized to.
        shuffle: shuffle the rows once at construction time (uses pandas' default RNG).
        transforms: optional callable applied to the resized PIL image.
        interpolation: ``'pil_lanczos'``, ``'cv2_lanczos4'``, or any other value, which
            is passed to ImageMagick (wand) as the resize filter name.
    """

    def __init__(self, img_path_list, label_list, load_img_size, shuffle=False, transforms=None, interpolation='robidouxsharp' ):
        self.df = pd.DataFrame({'img_path_list': img_path_list})
        self.interpolation = interpolation
        self.load_img_size = load_img_size
        logger.info(f'load_img_size={load_img_size}')
        if label_list is not None:
            self.df['label_list'] = label_list
        if shuffle:
            self.df = self.df.sample(frac=1.0).reset_index(drop=True)
        self.transforms = transforms

    def get_interpolated_image(self, img, new_image_size):
        """Resize ``img`` (NumPy array or PIL image) and return a PIL image.

        ``new_image_size`` may be an int (square output) or a ``(width, height)`` pair.
        """
        if isinstance(new_image_size, int):
            new_image_size = ( new_image_size, new_image_size )
        else:
            new_image_size = tuple(new_image_size)

        if self.interpolation == 'pil_lanczos':
            if isinstance(img, np.ndarray ):
                img = Image.fromarray(img)
            return img.resize( new_image_size, Image.LANCZOS )

        if not isinstance(img, np.ndarray ):
            img = np.array(img)

        if self.interpolation == 'cv2_lanczos4':
            # OpenCV works in BGR order, so convert around the resize call.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            img = cv2.resize(img, new_image_size, interpolation=cv2.INTER_LANCZOS4)  # size given in pixels
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return Image.fromarray(img)

        # Any other name is handed to ImageMagick as the resize filter.
        from wand import image
        with image.Image.from_array(img) as src:
            src.resize( *new_image_size, filter=self.interpolation )
            return Image.fromarray(np.array(src))

    def numpy_sqaure_array(self, img):
        """Pad an H x W x C image with zeros to a centred square array of dtype uint8."""
        h,w,c = np.array(img).shape
        if h>w:
            sq_img = np.full((h, h, c), fill_value=0, dtype=np.uint8)
            sq_img[:,(h-w)//2:(h-w)//2+w] = img
        else:
            sq_img = np.full((w, w, c), fill_value=0, dtype=np.uint8)
            sq_img[(w-h)//2:(w-h)//2+h,:] = img
        return sq_img

    def get_image_from_index(self, index, img_size, exact = False ):
        """Load the image at ``index`` as RGB, crop it when applicable, and resize it.

        The image is resized to ``img_size`` when ``exact`` is true or when either
        side is not larger than ``img_size``; otherwise it is returned at its
        (possibly cropped) original size.
        """
        img_path = self.df.img_path_list[index]
        # Force 3-channel RGB so grayscale / CMYK / RGBA files do not break the
        # 3-channel normalisation downstream; the context manager closes the file.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        W, H = img.size
        # Labelled datasets take the central half of every even-indexed image.
        # NOTE(review): this also applies to a labelled validation set — confirm intended.
        if ('label_list' in self.df.columns) and ( index % 2 == 0 ):
            img = img.crop((W//4,H//4,W//4+W//2,H//4+H//2))
            W, H = img.size

        if exact or (W <= img_size) or (H <= img_size):
            img = self.get_interpolated_image(img, img_size )
        return img

    def __getitem__(self, index):
        image = self.get_image_from_index( index, self.load_img_size, exact = True)
        if self.transforms is not None:
            image = self.transforms(image)
        if 'label_list' in self.df.columns:
            label = self.df.label_list[index]
            return { 'pixel_values': image, 'label': label }
        else:
            return { 'pixel_values': image }

    def __len__(self):
        return len(self.df)

In [9]:
## ref: https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/blob/master/cosine_annealing_warmup/scheduler.py
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler

class CosineAnnealingWarmupRestarts(_LRScheduler):
    """
    Cosine-annealing learning-rate schedule with optional linear warmup and restarts.

    Within each cycle the learning rate rises linearly from ``min_lr`` to the
    cycle's ``max_lr`` over ``warmup_steps`` steps, then follows a half cosine
    back down towards ``min_lr``. When a cycle ends a new one starts; its
    length is scaled by ``cycle_mult`` and its peak is decayed by ``gamma``.
    Every parameter group of the optimizer has its ``lr`` overwritten with
    ``min_lr`` at construction time (see ``init_lr``).

        optimizer (Optimizer): Wrapped optimizer.
        first_cycle_steps (int): First cycle step size.
        cycle_mult(float): Cycle steps magnification. Default: 1.
        max_lr(float): First cycle's max learning rate. Default: 1e-5.
        min_lr(float): Min learning rate. Default: 1e-10.
        warmup_steps(int): Linear warmup step size. Default: 0.
        gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
        last_epoch (int): The index of last epoch. Default: -1.
    """
    
    def __init__(self,
                 optimizer : torch.optim.Optimizer,
                 first_cycle_steps : int,
                 cycle_mult : float = 1.,
                 max_lr : float = 1e-5,
                 min_lr : float = 1e-10,
                 warmup_steps : int = 0,
                 gamma : float = 1.,
                 last_epoch : int = -1
        ):
        # Warmup must end strictly before the first cycle does.
        assert warmup_steps < first_cycle_steps
        
        self.first_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle_mult = cycle_mult # cycle steps magnification
        self.base_max_lr = max_lr # first max learning rate
        self.max_lr = max_lr # max learning rate in the current cycle
        self.min_lr = min_lr # min learning rate
        self.warmup_steps = warmup_steps # warmup step size
        self.gamma = gamma # decrease rate of max learning rate by cycle
        
        self.cur_cycle_steps = first_cycle_steps # first cycle step size
        self.cycle = 0 # cycle count
        self.step_in_cycle = last_epoch # step size of the current cycle
        
        # All attributes above are assigned before the base-class constructor runs.
        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
        
        # set learning rate min_lr
        self.init_lr()
    
    def init_lr(self):
        """Reset every param group's lr (and ``base_lrs``) to ``min_lr``."""
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)
    
    def get_lr(self):
        """Return the learning rate of each param group for the current position in the cycle."""
        if self.step_in_cycle == -1:
            # Not stepped yet: stay at the base (minimum) learning rates.
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # Linear warmup from base_lr towards max_lr.
            return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
        else:
            # Half-cosine decay from max_lr down to base_lr over the rest of the cycle.
            return [base_lr + (self.max_lr - base_lr) \
                    * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step, or jump to ``epoch`` when it is given."""
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Cycle finished: start the next one and rescale its length.
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            # Explicit epoch: recompute the cycle index and the offset inside it.
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    # Cycle lengths form a geometric series; solve for the cycle index n.
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch
        
        # Decay the peak by gamma per cycle; modified (vs. the referenced implementation) so the peak never drops below min_lr.
        self.max_lr = (self.base_max_lr - self.min_lr) * (self.gamma**self.cycle) + self.min_lr
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

In [10]:
# Load the image index and encode the rock-type names as integer class ids.
train_df = pd.read_csv('./train.csv')
le = LabelEncoder()
train_df['rock_type_int'] = le.fit_transform(train_df['rock_type'])
train_df = train_df.drop(columns='rock_type')
# Very small images noted by the author (row index, ID, path, label, then two
# numbers that look like the image dimensions — TODO confirm):
# # 109826	109826	train/Etc/TRAIN_14569.jpg	Etc	28	23
# # 117048	117048	train/Etc/TRAIN_03346.jpg	Etc	34	32
# # 117545	117545	train/Etc/TRAIN_14567.jpg	Etc	28	23
# # 187913	187913	train/Mud_Sandstone/TRAIN_68767.jpg	Mud_Sandstone	32	1
# Only the last one is excluded. It is matched by file path rather than by the
# positional index 187913: the row order of train.csv depends on how the file was
# generated, so a fixed index can silently remove a different image.
_excluded_img_path = 'train/Mud_Sandstone/TRAIN_68767.jpg'
_is_excluded = (
    train_df['img_path'].astype(str)
    .str.replace('\\', '/', regex=False)   # tolerate Windows-style separators
    .str.endswith(_excluded_img_path)
)
if int(_is_excluded.sum()) != 1:
    logger.warning(f'expected exactly one row for {_excluded_img_path}, found {int(_is_excluded.sum())}')
train_df = train_df[ ~_is_excluded ].reset_index(drop=True).copy()

In [11]:
# Stratified, shuffled K-fold splitter; fold count and seed come from CFG.
skf = StratifiedKFold(
    n_splits=CFG['N_SPLIT'],
    shuffle=True,
    random_state=CFG['SEED'],
)

In [12]:
def collate_fn_cuda(batch):
    """Collate dataset items into one batch dict whose tensors are moved to ``device``.

    Args:
        batch: list of dicts, each with a ``'pixel_values'`` tensor and a ``'label'``.

    Returns:
        Dict with the stacked ``'pixel_values'`` and a ``'label'`` tensor.
    """
    pixel_stack = torch.stack([sample['pixel_values'] for sample in batch])
    label_vector = torch.tensor([sample['label'] for sample in batch])
    return {'pixel_values': pixel_stack.to(device), 'label': label_vector.to(device)}

In [13]:
def oversampling(df):
    """Balance the classes by resampling minority rows with replacement.

    Every ``rock_type_int`` class is topped up to the size of the largest class,
    then the combined frame is shuffled and re-indexed.

    Args:
        df: frame with a ``rock_type_int`` column.

    Returns:
        A new shuffled DataFrame in which every class has the same row count.
    """
    labels = df['rock_type_int']
    target_count = labels.value_counts().max()
    extra_rows = []
    for class_id in labels.unique():
        members = df[labels == class_id]
        shortfall = target_count - len(members)
        if shortfall == 0:
            # Already the largest class: nothing to add.
            logger.info(f'already max rock_type : {class_id}')
        else:
            extra_rows.append(members.sample(shortfall, replace=True))
    balanced = pd.concat([df] + extra_rows, axis=0)
    return balanced.sample(frac=1).reset_index(drop=True).copy()

In [14]:
def train(model, optimizer, intput_train_fold_df, valid_fold_df, device, validation_steps = 0.25, logging_steps = 10, use_amp=True, filename=''):
    """Train ``model`` on one fold with AMP, periodic validation, EMA tracking and early stopping.

    Relies on notebook globals: ``CFG``, ``logger``, ``le``, ``run`` (the active wandb
    run), ``train_transform`` / ``test_transform``, ``CustomDataset``, ``oversampling``,
    ``validation`` and ``CosineAnnealingWarmupRestarts``.

    Args:
        model: network to train; it is moved to ``device``.
        optimizer: optimizer already bound to ``model``'s parameters.
        intput_train_fold_df: training rows with ``img_path`` and ``rock_type_int`` columns.
        valid_fold_df: validation rows with the same columns.
        device: device string, also used as the autocast ``device_type``.
        validation_steps: an ``int`` is the number of optimizer steps between two
            validations; any other number is a fraction of one epoch, in which case
            every epoch is cut to a whole number of validation intervals.
        logging_steps: number of steps between two training-metric log entries.
        use_amp: enable float16 autocast and gradient scaling.
        filename: ``str.format`` template (fields ``epoch``, ``val_loss``,
            ``val_score``) for checkpoint paths, without extension. Empty/None
            disables checkpoint files.

    Returns:
        ``model`` carrying the weights of the best validation F1, returned when early
        stopping triggers or the epoch limit is reached; ``None`` if no validation
        ever scored above 0.

    Raises:
        ValueError: if an integer ``validation_steps`` is smaller than 1.
    """
    MAX_PATIENCE = 5            # non-improving validations tolerated before stopping
    MIN_DELTA = 0.001           # smallest F1 / loss change that resets the patience
    MIN_SCORE_TO_SAVE = 0.89    # checkpoints are written eagerly only above this best F1
    MAX_EPOCHS = 999

    def _make_train_dataset(fold_df):
        # Build the (shuffled, augmented) training dataset for one epoch.
        return CustomDataset( fold_df['img_path'].values, fold_df['rock_type_int'].values,
            interpolation=CFG['INTERPOLATION'], load_img_size=CFG['IMG_TRAIN_SIZE'], shuffle=True, transforms=train_transform)

    def _save_checkpoint(state_dict, path):
        # Write a checkpoint, creating its directory first if necessary.
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        torch.save( {"model": state_dict }, path )

    def _keep_latest_checkpoint(paths):
        # Delete every checkpoint file except the most recent one.
        latest = paths[-1]
        for stale in paths[:-1]:
            if stale != latest and os.path.exists(stale):
                os.remove(stale)
        return paths[-1:]

    def _restore_best():
        # Put the best recorded weights back into the model before handing it out.
        if best_model is not None and best_state is not None:
            best_model.load_state_dict(best_state)
        return best_model

    # Accelerate integration is disabled; `validation` still takes the argument.
    accelerator = None
    scaler = torch.GradScaler(device, enabled=use_amp)

    logger.info(f'{use_amp=}')

    best_score = 0
    best_loss  = 1000
    best_model = None
    best_state = None   # CPU copy of the weights at the best validation score
    best_patience = MAX_PATIENCE
    checkpoint_filenames = []

    # ---------------------------------------
    if CFG['OVERSAMPLING']:
        train_fold_df = oversampling(intput_train_fold_df.copy())
    else:
        train_fold_df = intput_train_fold_df.copy()
    train_dataset = _make_train_dataset(train_fold_df)
    val_dataset = CustomDataset( valid_fold_df['img_path'].values, valid_fold_df['rock_type_int'].values,
        interpolation=CFG['INTERPOLATION'], load_img_size=CFG['IMG_SIZE'], shuffle=False, transforms=test_transform)

    # ---------------------------------------
    # Resolve the validation interval to a step count once, up front, so the EMA
    # half-life below and every epoch use the same number.
    steps_per_epoch = (len(train_fold_df) + CFG['BATCH_SIZE'] - 1) // CFG['BATCH_SIZE']
    truncate_epochs = not isinstance(validation_steps, int)
    if truncate_epochs:
        validation_steps = min(max(1, int(steps_per_epoch * validation_steps)), max(1, steps_per_epoch))
    elif validation_steps < 1:
        raise ValueError(f'validation_steps must be >= 1, got {validation_steps}')

    scheduler = CosineAnnealingWarmupRestarts(
        optimizer,
        first_cycle_steps=steps_per_epoch // 4,
        cycle_mult=1.0, max_lr=CFG['LR'][0] * 2,
        min_lr=CFG['LR'][1],
        warmup_steps=0,
        gamma=0.93,  ## 2024.05.02
    )
    # ---------------------------------------
    # Inverse-frequency class weights; only the training weights enter the loss.
    train_class_weight = torch.tensor(
        ( len(train_fold_df) / CFG["MAX_CLASSES"] / train_fold_df.rock_type_int.value_counts().sort_index() ).to_numpy(),
        dtype=torch.float32 )
    valid_class_weight = torch.tensor(
        ( len(valid_fold_df) / CFG["MAX_CLASSES"] / valid_fold_df.rock_type_int.value_counts().sort_index() ).to_numpy(),
        dtype=torch.float32 )
    logger.info(f'{le.classes_=}')
    logger.info(f'{train_class_weight=}')
    logger.info(f'{valid_class_weight=}')
    loss_fn = nn.CrossEntropyLoss( weight=train_class_weight, label_smoothing=CFG['LABEL_SMOOTHING'], reduction='mean' )

    # The EMA model keeps a second copy of the weights (watch memory use).
    # Its half-life is the patience window expressed in optimizer steps.
    ema_decay = float(0.5 ** (1.0 / (validation_steps * MAX_PATIENCE)))
    ema_model = torch.optim.swa_utils.AveragedModel(model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(ema_decay))

    model = model.to(device)
    ema_model = ema_model.to(device)
    loss_fn = loss_fn.to(device)

    for epoch in range(1, MAX_EPOCHS + 1):
        model.train()
        train_loss = []
        pbar_postfix = {}
        train_labels = []
        train_preds = []
        lastlog_tm = time.time()

        # Draw a fresh oversampled training set every epoch.
        if CFG['OVERSAMPLING']:
            train_fold_df = oversampling(intput_train_fold_df.copy())
            train_dataset = _make_train_dataset(train_fold_df)
            # Train and validation folds must not share any image.
            assert (set(train_fold_df['img_path']) & set(valid_fold_df['img_path'])) == set()
        train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=4)
        val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE']*2, shuffle=False, num_workers=4)
        gc.collect()
        torch.cuda.empty_cache()

        # -----------------------
        max_steps = len(train_loader)
        if truncate_epochs:
            # Cut the epoch to a whole number of validation intervals (every epoch).
            max_steps = (max_steps//validation_steps)*validation_steps

        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
        for i, batch in enumerate(pbar):
            if i >= max_steps:
                # Nothing is trained past this point, so stop loading batches.
                break
            steps = i+1

            with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
                output = model( batch['pixel_values'].to(device) )
                loss = loss_fn(output, batch['label'].to(device) )

            scaler.scale(loss).backward()
            # Gradients are unscaled here so that clipping could be inserted before the step.
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

            # Start a new metric window right after each validation.
            if (steps-1) % validation_steps == 0:
                train_loss = []
                train_labels = []
                train_preds = []

            train_loss.append(loss.item())
            train_labels += batch['label'].detach().cpu().numpy().tolist()
            train_preds += output.detach().argmax(1).cpu().numpy().tolist()

            ema_model.update_parameters(model)

            if steps % logging_steps == 0:
                _train_loss = float(np.mean(train_loss))
                _train_f1_all = f1_score(train_labels, train_preds, average=None).tolist()

                pbar_postfix.update({
                    't_loss0': train_loss[-1],
                    'lr': optimizer.param_groups[0]["lr"],
                    'score' : float(np.mean(_train_f1_all)),
                } )
                pbar.set_postfix( pbar_postfix )
                # Write to the text log at most once per minute.
                if ( time.time() - lastlog_tm ) > 60:
                    t_f1_raw_str = ', '.join([ f'{float(v):.3f}' for v in _train_f1_all ])
                    logger.info(f'eps={epoch:d}, steps={steps}/{max_steps}, lr={optimizer.param_groups[0]["lr"]:.3g}, t_loss={_train_loss:.4f}, t_f1r=[{t_f1_raw_str}], t_f1={float(np.mean(_train_f1_all)):.3f}' )
                    lastlog_tm = time.time()

                run.log({
                    "epoch": epoch,
                    "step": steps,
                    "train":{
                        "loss": train_loss[-1],
                        "f1": np.mean(_train_f1_all),
                        **{ f'f1.{i}({le.classes_[i][:3].lower()})': v for i, v in enumerate(_train_f1_all) }
                    },
                    "lr": optimizer.param_groups[0]["lr"]
                }, step=(epoch-1)*max_steps+steps)

            if steps % validation_steps != 0:
                continue

            _val_loss, _val_score, _val_score_all = validation(accelerator, model, loss_fn, val_loader, use_amp)
            _val_score_all = _val_score_all.tolist()
            _val_score = float(_val_score)
            _train_loss = float(np.mean(train_loss))

            best_score_mark = '*' if best_score < _val_score else ' '
            best_loss_mark = '*' if best_loss > _val_loss else ' '
            pbar_postfix.update({
                'lr': optimizer.param_groups[0]["lr"],
                't_loss': _train_loss,
                'v_loss': _val_loss,
                'v_f1': _val_score,
            })
            pbar.set_postfix( pbar_postfix )
            v_f1_raw_str = ', '.join([ f'{float(v):.3f}' for v in _val_score_all ])
            logger.info(f'eps={epoch:d}, steps={steps}/{max_steps}, lr={optimizer.param_groups[0]["lr"]:.3g}, t_loss={_train_loss:.4f}, v_f1r=[{v_f1_raw_str}], v_loss={_val_loss:.4f}{best_loss_mark}, v_f1={float(_val_score):.4f}{best_score_mark}' )

            run.log({
                "epoch": epoch, "step": steps,
                "train": {
                    "avg_loss": _train_loss
                },
                "valid": {
                    "avg_loss": _val_loss,
                    "score": float(np.mean(_val_score_all)),
                    **{ f'f1.{i}({le.classes_[i][:3].lower()})': v for i, v in enumerate(_val_score_all) }
                },
                "lr": optimizer.param_groups[0]["lr"]
            }, step=(epoch-1)*max_steps+steps)

            if best_score < _val_score:
                # Only a clear improvement moves the reference score and resets patience.
                if ( _val_score - best_score ) > MIN_DELTA:
                    best_score = _val_score
                    best_patience = MAX_PATIENCE
                best_model = model
                # Snapshot the weights now: `model` keeps training, so a mere reference
                # would no longer hold the best weights when it is saved or returned later.
                best_state = { name: tensor.detach().to('cpu', copy=True) for name, tensor in model.state_dict().items() }

                if filename:
                    checkpoint_filenames.append(
                        filename.format(epoch=epoch, val_loss=_val_loss, val_score=_val_score) + '.ckpt' )
                    if best_score > MIN_SCORE_TO_SAVE:
                        _save_checkpoint(best_state, checkpoint_filenames[-1])
                        logger.info( f'{checkpoint_filenames[-1]} : saved.' )
                        checkpoint_filenames = _keep_latest_checkpoint(checkpoint_filenames)

                # Track the best loss as well.
                if best_loss > _val_loss:
                    best_loss = _val_loss
            elif ( best_loss > _val_loss ) and (( best_loss - _val_loss ) > MIN_DELTA ):
                best_loss = _val_loss
                best_patience = MAX_PATIENCE
            elif best_patience > 0:
                best_patience -= 1
            else:
                logger.info(f'NO_MORE_TRAINING, {best_score=:.4f}')
                # ## EMA --------------------
                # NOTE(review): update_bn feeds raw loader batches (dicts here) to the
                # model when it contains BatchNorm layers — confirm before using such a model.
                torch.optim.swa_utils.update_bn(train_loader, ema_model, device )
                ema_val_loss, ema_val_score, _ = validation(accelerator, ema_model, loss_fn, val_loader, use_amp)
                logger.info(f'EMA ::: ema_v_loss={ema_val_loss:.4f}, ema_v_f1={ema_val_score:.4f}')
                run.log({'ema_v_loss': ema_val_loss, 'ema_v_f1': ema_val_score })

                if filename:
                    save_filename = filename.format(epoch=epoch, val_loss=ema_val_loss, val_score=ema_val_score) + '-ema.ckpt'
                    _save_checkpoint(ema_model.state_dict(), save_filename)
                    logger.info( f'{save_filename} : (ema) saved.' )
                # ##========================
                # Best score never crossed the eager-save threshold: write it now.
                if checkpoint_filenames and best_state is not None and not os.path.exists(checkpoint_filenames[-1]):
                    _save_checkpoint(best_state, checkpoint_filenames[-1])
                    logger.info( f'{checkpoint_filenames[-1]} : saved.' )
                    checkpoint_filenames = _keep_latest_checkpoint(checkpoint_filenames)
                pbar.close()
                return _restore_best()
        pbar.close()

    logger.info(f'MAX_EPOCHS reached, {best_score=:.4f}')
    return _restore_best()

In [15]:
def validation(accelerator, model, loss_fn, val_loader, use_amp=False, desc=None):
    """Evaluate ``model`` on ``val_loader`` and report loss and per-class F1.

    The model is switched to eval mode for the pass and put back into training
    mode afterwards if that is how it arrived.

    Args:
        accelerator: when not ``None`` the batches are fed to the model as they
            are; otherwise they are moved to the global ``device`` and the
            forward pass runs under float16 autocast (if ``use_amp``).
        model: network to evaluate.
        loss_fn: callable ``(logits, labels) -> loss``.
        val_loader: iterable of dicts with ``'pixel_values'`` and ``'label'``.
        use_amp: enable autocast on the non-accelerator path.
        desc: progress-bar caption.

    Returns:
        ``(mean of the per-batch losses, mean of the per-class F1 scores,
        array of per-class F1 scores)``.
    """
    was_training = model.training
    model.eval()

    batch_losses = []
    predicted, expected = [], []
    last_report = time.time()

    with torch.no_grad():
        total_batches = len(val_loader)
        for step, batch in enumerate(tqdm(val_loader, desc=desc), start=1):
            if accelerator is None:
                with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
                    logits = model(batch['pixel_values'].to(device) )
                    loss = loss_fn(logits, batch['label'].to(device) )
            else:
                logits = model(batch['pixel_values'])
                loss = loss_fn(logits, batch['label'])

            expected.extend(batch['label'].detach().cpu().numpy().tolist())
            predicted.extend(logits.detach().argmax(1).cpu().numpy().tolist())
            batch_losses.append(loss.cpu().item())

            # Emit an interim report at most once a minute on long passes.
            if (time.time() - last_report) > 60:
                interim_f1 = f1_score(expected, predicted, average=None ).tolist()
                v_f1_raw_str = ', '.join([ f'{float(v):.3f}' for v in interim_f1 ])
                logger.info(f'steps={step}/{total_batches}, v_loss={float(np.mean(batch_losses)):.4f}, v_f1r=[{v_f1_raw_str}], v_f1={float(np.mean(interim_f1)):.3f}' )
                last_report = time.time()

        mean_loss = np.mean(batch_losses)
        per_class_f1 = f1_score(expected, predicted, average=None )

    # Restore the mode the model was in on entry.
    if was_training:
        model.train()
    return mean_loss, np.mean(per_class_f1), per_class_f1

In [16]:
def prediction(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and collect its raw outputs.

    Args:
        model: network to run; it is moved to ``device``.
        test_loader: iterable of dicts with a ``'pixel_values'`` tensor.
        device: device the model and the inputs are moved to.

    Returns:
        List with one entry (list of floats) per sample: the unnormalised model output.
    """
    model = model.to(device)
    was_training = model.training
    model.eval()

    collected = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            pixels = batch['pixel_values'].to(device, dtype=torch.float32 )
            logits = model(pixels)   # no softmax applied; the author judged it unnecessary here
            collected.extend(logits.detach().cpu().numpy().tolist())

    # Leave the model in the mode it arrived in.
    if was_training:
        model.train()
    return collected

In [17]:
class CustomModel(nn.Module):
    """Backbone wrapper that appends a GELU and a linear classification head.

    Args:
        model: feature extractor returning either a tensor or an object with a
            ``pooler_output`` attribute.
        model_shape: shape of the backbone output; element 1 is used as the
            head's input width.
    """
    def __init__(self, model, model_shape):
        super().__init__()
        feature_dim = model_shape[1]
        num_classes = CFG['MAX_CLASSES']

        self.model = model
        self.gelu = nn.GELU()
        self.clf = nn.Linear(feature_dim, num_classes)
        # Kaiming-uniform weights with a zero bias (author's note: gave the best results).
        torch.nn.init.kaiming_uniform_(self.clf.weight.data)
        self.clf.bias.data.fill_(0)
        logger.info(f"clf layer: nn.Linear({feature_dim}, {num_classes}) added and initialize.")

    def forward(self, x):
        features = self.model(x)
        # Non-tensor outputs are expected to expose the pooled features.
        if not isinstance(features, torch.Tensor):
            features = features.pooler_output
        return self.clf(self.gelu(features))

In [18]:
def create_model(model_name):
    """Instantiate a pretrained backbone that outputs ``CFG['MAX_CLASSES']`` logits.

    Dispatch on ``model_name``:
      * no ``/`` in the name: treated as a timm model (``timm/`` is prepended);
      * ``./...``: a local checkpoint loaded into timm's ``nextvit_large``;
      * ``facebook/hiera_...``: loaded through the ``hiera`` package;
      * ``timm/...``: created by timm with a fresh classifier head, whose weights
        are re-initialised, and returned directly;
      * anything else: loaded with ``transformers.AutoModel``.

    For the non-timm branches the output width is probed with a dummy batch and
    the model is wrapped in ``CustomModel`` when it does not already match the
    number of classes.

    Args:
        model_name: model identifier as described above.

    Returns:
        The model, in training mode.
    """
    import timm
    from transformers import AutoModel

    logger.info(f'create_model: {model_name}')
    if '/' not in model_name:
        model_name = 'timm/' + model_name

    if model_name.startswith('./'):
        import nextvit  # imported for its side effects only (presumably registers the architecture with timm)
        model = timm.create_model('nextvit_large', pretrained=True, checkpoint_path=model_name)
    elif model_name.startswith('facebook/hiera_'):
        from hiera import Hiera  ## pip install hiera-transformer
        model = Hiera.from_pretrained(model_name)
    elif model_name.startswith('timm/'):
        try:
            model = timm.create_model( model_name, pretrained=True, img_size=CFG['IMG_SIZE'], num_classes=CFG['MAX_CLASSES'] )
        except Exception as exc:
            # Retry without `img_size`; catching Exception (not a bare except) keeps
            # Ctrl-C / SystemExit working, and the reason is logged instead of hidden.
            logger.warning(f'create_model with img_size failed ({exc!r}); retrying without img_size')
            model = timm.create_model( model_name, pretrained=True, num_classes=CFG['MAX_CLASSES'] )

        ## initialise the classifier head
        ## author's scores: normal_ : 0.7334, uniform_ : 0.7862, xavier_uniform_ : 0.7971, kaiming_uniform_: 0.7982
        try:
            torch.nn.init.kaiming_uniform_(model.head.weight.data)
            model.head.bias.data.fill_(0)
        except AttributeError:
            # The head is a container without its own weight/bias: use its `fc` layer.
            torch.nn.init.kaiming_uniform_(model.head.fc.weight.data)
            model.head.fc.bias.data.fill_(0)
        return model
    else:
        model = AutoModel.from_pretrained(model_name)

    # Probe the output shape with one random image; no gradients are needed for this.
    model.eval()
    with torch.no_grad():
        output = model( torch.randn((1,3,CFG['IMG_SIZE'],CFG['IMG_SIZE'])) )
    if not isinstance(output, torch.Tensor):
        output = output.pooler_output
    assert (len(output.shape) == 2) and (output.shape[1] >= CFG['MAX_CLASSES'])
    if output.shape[1] != CFG['MAX_CLASSES']:
        model = CustomModel( model, output.shape )
    model.train()
    return model

In [None]:
from datetime import datetime
dt_str = datetime.now().strftime('%m%d%H%M')

for fold_idx, (train_index, valid_index) in enumerate(skf.split(train_df, train_df['rock_type_int'])):
    gc.collect()
    torch.cuda.empty_cache()
    
    logger.info(f'{fold_idx=} started')
    import wandb
    run = wandb.init(
        name=f'fold{fold_idx+1}_{CFG["MODEL_NAME"].split("/")[1].split("-")[0]}_{dt_str}',
        config=CFG,
        reinit=True)
    
    train_fold_df = train_df.iloc[train_index,:].copy().reset_index(drop=True)
    valid_fold_df = train_df.iloc[valid_index,:].copy().reset_index(drop=True)

    model = create_model(CFG['MODEL_NAME'])
    if (CFG['GRADIENT_CHECKPOINT']==True) and ('set_grad_checkpointing' in dir(model)):
        model.set_grad_checkpointing() ## for memory_efficient..
        logger.info('grad_checkpointing : True')
    else:
        model.compile() #
        logger.info('model compiled')
    
    ## wrap model
    optimizer_class_ = getattr(torch.optim, CFG['OPTIMIZER'] )
    logger.info(f'create optimizer : {optimizer_class_.__module__}')
    optimizer = optimizer_class_(
        model.parameters(),
        lr=CFG['LR'][0],
        weight_decay=0.001,  ## default는 0.01이며, 논문은 0.001임.
    )
    import bitsandbytes as bnb
    optimizer = bnb.optim.Adam(
        model.parameters(),
        lr=CFG['LR'][0],
        weight_decay=0.001,  ## default는 0.01이며, 논문은 0.001임.
        optim_bits=8,
    )
    
    model = train( 
        model, optimizer,
        train_fold_df,
        valid_fold_df, 
        # val_loader,
        # scheduler, 
        device,
        use_amp=(CFG['PRECISION'] == '16'),
        filename = f'./ckpt/{CFG["MODEL_NAME"].split("/")[1].split("-")[0]}-fold_idx={fold_idx}-' + 'epoch={epoch:02d}-val_loss={val_loss:.4f}-val_score={val_score:.4f}',
    )
    
    model = None
    gc.collect()
    torch.cuda.empty_cache()
    logger.info(f'{fold_idx=} finished')
    run.finish()
    
    try:
        # !python ~/send_telegram.py 'fold_idx={fold_idx} finished'
        last_chpt_info = !ls -t ./ckpt/ | head -n1
        last_chpt_info = ','.join( last_chpt_info[0][:-5].split('-')[1:] )
        !python ~/send_telegram.py {last_chpt_info}
    except:
        pass
    # 1pass만..
    break

05/30 12:47:55 [INFO] fold_idx=0 started


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


05/30 12:47:56 [INFO] create_model: timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k
05/30 12:47:57 [INFO] Loading pretrained weights from Hugging Face hub (timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k)
05/30 12:47:58 [INFO] [timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
05/30 12:47:58 [INFO] Missing keys (head.weight, head.bias) discovered while loading pretrained weights. This is expected if model is being adapted.
05/30 12:47:59 [INFO] model compiled
05/30 12:47:59 [INFO] create optimizer : torch.optim.adamw
05/30 12:47:59 [INFO] use_amp=True
05/30 12:47:59 [INFO] already max rock_type : 4
05/30 12:47:59 [INFO] load_img_size=448
05/30 12:47:59 [INFO] load_img_size=448
05/30 12:47:59 [INFO] le.classes_=array(['Andesite', 'Basalt', 'Etc', 'Gneiss', 'Granite', 'Mud_Sandstone',
       'Weathered_Rock'], dtype=object)
05/30 12:47:59 [INFO] train_class_weight=te

Epoch 1:   0%|          | 0/32524 [00:00<?, ?it/s]

05/30 12:49:02 [INFO] eps=1, steps=220/32524, lr=3.99e-05, t_loss=1.7288, t_f1r=[0.287, 0.520, 0.259, 0.370, 0.572, 0.366, 0.476], t_f1=0.407
05/30 12:50:03 [INFO] eps=1, steps=450/32524, lr=3.97e-05, t_loss=1.4532, t_f1r=[0.387, 0.658, 0.312, 0.451, 0.684, 0.434, 0.577], t_f1=0.500
05/30 12:51:04 [INFO] eps=1, steps=680/32524, lr=3.93e-05, t_loss=1.3452, t_f1r=[0.436, 0.709, 0.341, 0.506, 0.719, 0.448, 0.616], t_f1=0.539
05/30 12:52:05 [INFO] eps=1, steps=910/32524, lr=3.88e-05, t_loss=1.2819, t_f1r=[0.471, 0.742, 0.353, 0.534, 0.738, 0.467, 0.633], t_f1=0.563
05/30 12:53:06 [INFO] eps=1, steps=1140/32524, lr=3.82e-05, t_loss=1.2390, t_f1r=[0.487, 0.758, 0.369, 0.548, 0.749, 0.484, 0.652], t_f1=0.578
05/30 12:54:07 [INFO] eps=1, steps=1370/32524, lr=3.74e-05, t_loss=1.2061, t_f1r=[0.502, 0.774, 0.386, 0.566, 0.760, 0.491, 0.662], t_f1=0.591
05/30 12:55:09 [INFO] eps=1, steps=1600/32524, lr=3.65e-05, t_loss=1.1816, t_f1r=[0.518, 0.787, 0.399, 0.577, 0.764, 0.497, 0.666], t_f1=0.601
05/

  0%|          | 0/2376 [00:00<?, ?it/s]

05/30 13:25:30 [INFO] steps=307/2376, v_loss=0.6666, v_f1r=[0.905, 0.000, 0.000, 0.000, 0.000, 0.000, 0.802], v_f1=0.244
05/30 13:26:30 [INFO] steps=637/2376, v_loss=0.6096, v_f1r=[0.896, 0.962, 0.000, 0.000, 0.000, 0.000, 0.880], v_f1=0.391
05/30 13:27:30 [INFO] steps=968/2376, v_loss=0.6756, v_f1r=[0.868, 0.956, 0.633, 0.000, 0.000, 0.793, 0.851], v_f1=0.586
05/30 13:28:30 [INFO] steps=1299/2376, v_loss=0.6960, v_f1r=[0.828, 0.939, 0.589, 0.000, 0.000, 0.843, 0.830], v_f1=0.576
05/30 13:29:31 [INFO] steps=1630/2376, v_loss=0.6581, v_f1r=[0.823, 0.938, 0.554, 0.000, 0.927, 0.845, 0.811], v_f1=0.700
05/30 13:30:31 [INFO] steps=1961/2376, v_loss=0.6365, v_f1r=[0.821, 0.937, 0.516, 0.526, 0.936, 0.844, 0.792], v_f1=0.768
05/30 13:31:31 [INFO] steps=2292/2376, v_loss=0.6560, v_f1r=[0.814, 0.935, 0.457, 0.811, 0.927, 0.837, 0.770], v_f1=0.793
05/30 13:31:46 [INFO] eps=1, steps=8131/32524, lr=3.73e-05, t_loss=0.9036, v_f1r=[0.813, 0.934, 0.443, 0.823, 0.924, 0.835, 0.764], v_loss=0.6599*, v

  0%|          | 0/2376 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f99dfc4c680>
Traceback (most recent call last):
  File "/home/notebook/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1665, in __del__
    Exception ignored in: Exception ignored in: self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x7f99dfc4c680>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f99dfc4c680>

Traceback (most recent call last):
  File "/home/notebook/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1665, in __del__
Traceback (most recent call last):
    self._shutdown_workers()
  File "/home/notebook/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1648, in _shutdown_workers
  File "/home/notebook/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1665, in __del__
  File "/home/notebook/miniconda3/envs/py

05/30 14:09:16 [INFO] steps=324/2376, v_loss=0.5911, v_f1r=[0.929, 0.000, 0.000, 0.000, 0.000, 0.000, 0.845], v_f1=0.253
05/30 14:10:16 [INFO] steps=655/2376, v_loss=0.5674, v_f1r=[0.920, 0.968, 0.000, 0.000, 0.000, 0.000, 0.877], v_f1=0.395
05/30 14:11:16 [INFO] steps=986/2376, v_loss=0.6336, v_f1r=[0.889, 0.961, 0.684, 0.000, 0.000, 0.826, 0.851], v_f1=0.601
05/30 14:12:16 [INFO] steps=1317/2376, v_loss=0.6495, v_f1r=[0.854, 0.948, 0.642, 0.000, 0.000, 0.865, 0.832], v_f1=0.592
05/30 14:13:16 [INFO] steps=1648/2376, v_loss=0.6152, v_f1r=[0.852, 0.947, 0.603, 0.000, 0.935, 0.865, 0.819], v_f1=0.717
05/30 14:14:17 [INFO] steps=1978/2376, v_loss=0.5990, v_f1r=[0.850, 0.947, 0.560, 0.610, 0.943, 0.864, 0.805], v_f1=0.797
05/30 14:15:17 [INFO] steps=2308/2376, v_loss=0.6120, v_f1r=[0.844, 0.945, 0.498, 0.835, 0.937, 0.858, 0.788], v_f1=0.815
05/30 14:15:29 [INFO] eps=1, steps=16262/32524, lr=3.49e-05, t_loss=0.7757, v_f1r=[0.842, 0.944, 0.487, 0.843, 0.935, 0.856, 0.784], v_loss=0.6146*, 

  0%|          | 0/2376 [00:00<?, ?it/s]

05/30 14:53:21 [INFO] steps=325/2376, v_loss=0.5870, v_f1r=[0.931, 0.000, 0.000, 0.000, 0.000, 0.000, 0.853], v_f1=0.255
05/30 14:54:21 [INFO] steps=657/2376, v_loss=0.5597, v_f1r=[0.922, 0.972, 0.000, 0.000, 0.000, 0.000, 0.878], v_f1=0.396
05/30 14:55:21 [INFO] steps=988/2376, v_loss=0.6164, v_f1r=[0.898, 0.965, 0.695, 0.000, 0.000, 0.845, 0.854], v_f1=0.608
05/30 14:56:21 [INFO] steps=1318/2376, v_loss=0.6312, v_f1r=[0.871, 0.954, 0.650, 0.000, 0.000, 0.877, 0.837], v_f1=0.598
05/30 14:57:22 [INFO] steps=1649/2376, v_loss=0.5945, v_f1r=[0.869, 0.953, 0.619, 0.000, 0.937, 0.878, 0.824], v_f1=0.726
05/30 14:58:22 [INFO] steps=1979/2376, v_loss=0.5765, v_f1r=[0.868, 0.953, 0.578, 0.618, 0.946, 0.877, 0.811], v_f1=0.807
05/30 14:59:22 [INFO] steps=2309/2376, v_loss=0.5855, v_f1r=[0.864, 0.951, 0.515, 0.845, 0.940, 0.871, 0.796], v_f1=0.826
05/30 14:59:34 [INFO] eps=1, steps=24393/32524, lr=3.26e-05, t_loss=0.7210, v_f1r=[0.863, 0.951, 0.505, 0.854, 0.939, 0.870, 0.793], v_loss=0.5873*, 

W0530 15:36:46.757000 128577 site-packages/torch/_inductor/utils.py:1408] [0/2] Not enough SMs to use max_autotune_gemm mode


  0%|          | 0/2376 [00:00<?, ?it/s]

05/30 15:38:09 [INFO] steps=324/2376, v_loss=0.5450, v_f1r=[0.939, 0.000, 0.000, 0.000, 0.000, 0.000, 0.853], v_f1=0.256
05/30 15:39:09 [INFO] steps=655/2376, v_loss=0.5293, v_f1r=[0.931, 0.973, 0.000, 0.000, 0.000, 0.000, 0.888], v_f1=0.399
05/30 15:40:09 [INFO] steps=986/2376, v_loss=0.5872, v_f1r=[0.910, 0.969, 0.703, 0.000, 0.000, 0.856, 0.861], v_f1=0.614
05/30 15:41:09 [INFO] steps=1316/2376, v_loss=0.6002, v_f1r=[0.882, 0.961, 0.657, 0.000, 0.000, 0.886, 0.841], v_f1=0.604
05/30 15:42:09 [INFO] steps=1647/2376, v_loss=0.5721, v_f1r=[0.880, 0.961, 0.619, 0.000, 0.939, 0.886, 0.828], v_f1=0.730
05/30 15:43:09 [INFO] steps=1976/2376, v_loss=0.5585, v_f1r=[0.879, 0.960, 0.579, 0.646, 0.946, 0.885, 0.814], v_f1=0.816
05/30 15:44:09 [INFO] steps=2306/2376, v_loss=0.5667, v_f1r=[0.875, 0.959, 0.528, 0.860, 0.942, 0.879, 0.797], v_f1=0.834
05/30 15:44:22 [INFO] eps=1, steps=32524/32524, lr=3.04e-05, t_loss=0.6764, v_f1r=[0.874, 0.958, 0.518, 0.868, 0.941, 0.878, 0.793], v_loss=0.5680*, 

Epoch 2:   0%|          | 0/32524 [00:00<?, ?it/s]

05/30 15:45:24 [INFO] eps=2, steps=150/32524, lr=3.04e-05, t_loss=0.6968, t_f1r=[0.814, 0.940, 0.724, 0.779, 0.902, 0.764, 0.798], t_f1=0.817
05/30 15:46:26 [INFO] eps=2, steps=380/32524, lr=3.03e-05, t_loss=0.6933, t_f1r=[0.835, 0.946, 0.716, 0.779, 0.889, 0.767, 0.790], t_f1=0.817
05/30 15:47:27 [INFO] eps=2, steps=610/32524, lr=3e-05, t_loss=0.7065, t_f1r=[0.822, 0.941, 0.705, 0.770, 0.883, 0.773, 0.791], t_f1=0.812
05/30 15:48:28 [INFO] eps=2, steps=840/32524, lr=2.97e-05, t_loss=0.7075, t_f1r=[0.820, 0.941, 0.691, 0.777, 0.886, 0.771, 0.786], t_f1=0.810
05/30 15:49:29 [INFO] eps=2, steps=1070/32524, lr=2.92e-05, t_loss=0.7042, t_f1r=[0.819, 0.943, 0.693, 0.778, 0.884, 0.770, 0.794], t_f1=0.812
05/30 15:50:31 [INFO] eps=2, steps=1300/32524, lr=2.87e-05, t_loss=0.7049, t_f1r=[0.815, 0.943, 0.689, 0.777, 0.885, 0.767, 0.794], t_f1=0.810
05/30 15:51:32 [INFO] eps=2, steps=1530/32524, lr=2.8e-05, t_loss=0.7004, t_f1r=[0.821, 0.942, 0.691, 0.781, 0.883, 0.772, 0.798], t_f1=0.812
05/30 1

  0%|          | 0/2376 [00:00<?, ?it/s]

05/30 16:22:40 [INFO] steps=324/2376, v_loss=0.5166, v_f1r=[0.946, 0.000, 0.000, 0.000, 0.000, 0.000, 0.863], v_f1=0.258
05/30 16:23:40 [INFO] steps=656/2376, v_loss=0.5090, v_f1r=[0.938, 0.976, 0.000, 0.000, 0.000, 0.000, 0.892], v_f1=0.401
05/30 16:24:40 [INFO] steps=988/2376, v_loss=0.5688, v_f1r=[0.916, 0.970, 0.732, 0.000, 0.000, 0.865, 0.868], v_f1=0.622
05/30 16:25:41 [INFO] steps=1319/2376, v_loss=0.5795, v_f1r=[0.889, 0.959, 0.693, 0.000, 0.000, 0.893, 0.851], v_f1=0.612
05/30 16:26:41 [INFO] steps=1650/2376, v_loss=0.5495, v_f1r=[0.887, 0.959, 0.662, 0.000, 0.944, 0.894, 0.839], v_f1=0.741
05/30 16:27:41 [INFO] steps=1981/2376, v_loss=0.5348, v_f1r=[0.885, 0.958, 0.624, 0.668, 0.952, 0.893, 0.827], v_f1=0.829
05/30 16:28:41 [INFO] steps=2312/2376, v_loss=0.5447, v_f1r=[0.882, 0.956, 0.563, 0.864, 0.947, 0.888, 0.810], v_f1=0.844
05/30 16:28:53 [INFO] eps=2, steps=8131/32524, lr=2.84e-05, t_loss=0.6409, v_f1r=[0.881, 0.956, 0.553, 0.871, 0.946, 0.887, 0.807], v_loss=0.5463*, v

  0%|          | 0/2376 [00:00<?, ?it/s]

05/30 17:06:53 [INFO] steps=323/2376, v_loss=0.5040, v_f1r=[0.949, 0.000, 0.000, 0.000, 0.000, 0.000, 0.868], v_f1=0.260
05/30 17:07:53 [INFO] steps=655/2376, v_loss=0.4976, v_f1r=[0.941, 0.976, 0.000, 0.000, 0.000, 0.000, 0.897], v_f1=0.402
05/30 17:08:54 [INFO] steps=987/2376, v_loss=0.5528, v_f1r=[0.921, 0.971, 0.742, 0.000, 0.000, 0.872, 0.874], v_f1=0.626
05/30 17:09:54 [INFO] steps=1318/2376, v_loss=0.5617, v_f1r=[0.895, 0.962, 0.704, 0.000, 0.000, 0.899, 0.856], v_f1=0.617
05/30 17:10:54 [INFO] steps=1650/2376, v_loss=0.5386, v_f1r=[0.893, 0.961, 0.668, 0.000, 0.943, 0.900, 0.842], v_f1=0.744
05/30 17:11:54 [INFO] steps=1980/2376, v_loss=0.5280, v_f1r=[0.892, 0.961, 0.627, 0.670, 0.949, 0.898, 0.828], v_f1=0.832
05/30 17:12:54 [INFO] steps=2311/2376, v_loss=0.5364, v_f1r=[0.888, 0.960, 0.571, 0.868, 0.945, 0.892, 0.811], v_f1=0.848
05/30 17:13:06 [INFO] eps=2, steps=16262/32524, lr=2.66e-05, t_loss=0.6135, v_f1r=[0.888, 0.959, 0.562, 0.875, 0.944, 0.892, 0.808], v_loss=0.5377*, 

  0%|          | 0/2376 [00:00<?, ?it/s]

05/30 17:51:07 [INFO] steps=323/2376, v_loss=0.5095, v_f1r=[0.948, 0.000, 0.000, 0.000, 0.000, 0.000, 0.865], v_f1=0.259
05/30 17:52:07 [INFO] steps=655/2376, v_loss=0.5059, v_f1r=[0.940, 0.977, 0.000, 0.000, 0.000, 0.000, 0.895], v_f1=0.402
05/30 17:53:07 [INFO] steps=987/2376, v_loss=0.5535, v_f1r=[0.923, 0.973, 0.743, 0.000, 0.000, 0.876, 0.873], v_f1=0.627
05/30 17:54:07 [INFO] steps=1319/2376, v_loss=0.5566, v_f1r=[0.901, 0.967, 0.701, 0.000, 0.000, 0.905, 0.858], v_f1=0.619
05/30 17:55:07 [INFO] steps=1650/2376, v_loss=0.5299, v_f1r=[0.900, 0.966, 0.670, 0.000, 0.946, 0.905, 0.846], v_f1=0.748
05/30 17:56:08 [INFO] steps=1982/2376, v_loss=0.5160, v_f1r=[0.899, 0.966, 0.633, 0.677, 0.953, 0.904, 0.835], v_f1=0.838
05/30 17:57:08 [INFO] steps=2313/2376, v_loss=0.5213, v_f1r=[0.896, 0.965, 0.584, 0.876, 0.948, 0.899, 0.821], v_f1=0.856
05/30 17:57:19 [INFO] eps=2, steps=24393/32524, lr=2.49e-05, t_loss=0.5857, v_f1r=[0.896, 0.965, 0.576, 0.884, 0.948, 0.898, 0.818], v_loss=0.5219*, 

  0%|          | 0/2376 [00:00<?, ?it/s]

05/30 18:35:20 [INFO] steps=325/2376, v_loss=0.4806, v_f1r=[0.952, 0.000, 0.000, 0.000, 0.000, 0.000, 0.874], v_f1=0.261
05/30 18:36:20 [INFO] steps=657/2376, v_loss=0.4813, v_f1r=[0.945, 0.977, 0.000, 0.000, 0.000, 0.000, 0.904], v_f1=0.404
05/30 18:37:20 [INFO] steps=988/2376, v_loss=0.5378, v_f1r=[0.927, 0.974, 0.759, 0.000, 0.000, 0.883, 0.880], v_f1=0.632
05/30 18:38:20 [INFO] steps=1319/2376, v_loss=0.5402, v_f1r=[0.908, 0.968, 0.726, 0.000, 0.000, 0.913, 0.863], v_f1=0.625
05/30 18:39:20 [INFO] steps=1650/2376, v_loss=0.5147, v_f1r=[0.906, 0.968, 0.698, 0.000, 0.948, 0.913, 0.851], v_f1=0.755
05/30 18:40:21 [INFO] steps=1981/2376, v_loss=0.5022, v_f1r=[0.905, 0.967, 0.664, 0.690, 0.955, 0.911, 0.839], v_f1=0.847
05/30 18:41:21 [INFO] steps=2312/2376, v_loss=0.5081, v_f1r=[0.902, 0.967, 0.615, 0.881, 0.951, 0.906, 0.823], v_f1=0.864
05/30 18:41:32 [INFO] eps=2, steps=32524/32524, lr=2.33e-05, t_loss=0.5655, v_f1r=[0.902, 0.967, 0.607, 0.889, 0.950, 0.904, 0.821], v_loss=0.5088*, 

Epoch 3:   0%|          | 0/32524 [00:00<?, ?it/s]

05/30 18:42:35 [INFO] eps=3, steps=150/32524, lr=2.32e-05, t_loss=0.5527, t_f1r=[0.887, 0.954, 0.828, 0.860, 0.927, 0.830, 0.841], t_f1=0.875
05/30 18:43:36 [INFO] eps=3, steps=380/32524, lr=2.31e-05, t_loss=0.5705, t_f1r=[0.887, 0.956, 0.792, 0.828, 0.918, 0.823, 0.840], t_f1=0.863
05/30 18:44:37 [INFO] eps=3, steps=610/32524, lr=2.3e-05, t_loss=0.5699, t_f1r=[0.882, 0.956, 0.802, 0.835, 0.917, 0.835, 0.839], t_f1=0.867
05/30 18:45:39 [INFO] eps=3, steps=840/32524, lr=2.27e-05, t_loss=0.5702, t_f1r=[0.884, 0.960, 0.800, 0.834, 0.916, 0.839, 0.836], t_f1=0.867
05/30 18:46:40 [INFO] eps=3, steps=1070/32524, lr=2.24e-05, t_loss=0.5749, t_f1r=[0.879, 0.959, 0.788, 0.833, 0.916, 0.834, 0.840], t_f1=0.864
05/30 18:47:41 [INFO] eps=3, steps=1300/32524, lr=2.2e-05, t_loss=0.5759, t_f1r=[0.880, 0.960, 0.788, 0.830, 0.917, 0.832, 0.839], t_f1=0.864
05/30 18:48:43 [INFO] eps=3, steps=1530/32524, lr=2.15e-05, t_loss=0.5760, t_f1r=[0.878, 0.961, 0.788, 0.829, 0.918, 0.831, 0.844], t_f1=0.864
05/30

  0%|          | 0/2376 [00:00<?, ?it/s]

05/30 19:19:50 [INFO] steps=323/2376, v_loss=0.4788, v_f1r=[0.957, 0.000, 0.000, 0.000, 0.000, 0.000, 0.877], v_f1=0.262
05/30 19:20:50 [INFO] steps=655/2376, v_loss=0.4880, v_f1r=[0.949, 0.977, 0.000, 0.000, 0.000, 0.000, 0.899], v_f1=0.404
05/30 19:21:50 [INFO] steps=986/2376, v_loss=0.5404, v_f1r=[0.932, 0.974, 0.760, 0.000, 0.000, 0.889, 0.878], v_f1=0.633
05/30 19:22:51 [INFO] steps=1318/2376, v_loss=0.5411, v_f1r=[0.911, 0.968, 0.726, 0.000, 0.000, 0.915, 0.863], v_f1=0.626
05/30 19:23:51 [INFO] steps=1649/2376, v_loss=0.5138, v_f1r=[0.909, 0.968, 0.702, 0.000, 0.950, 0.915, 0.853], v_f1=0.757
05/30 19:24:51 [INFO] steps=1980/2376, v_loss=0.4999, v_f1r=[0.908, 0.967, 0.670, 0.690, 0.957, 0.914, 0.841], v_f1=0.849
05/30 19:25:51 [INFO] steps=2312/2376, v_loss=0.5040, v_f1r=[0.905, 0.966, 0.625, 0.886, 0.952, 0.908, 0.827], v_f1=0.867
05/30 19:26:03 [INFO] eps=3, steps=8131/32524, lr=2.18e-05, t_loss=0.5373, v_f1r=[0.904, 0.966, 0.618, 0.893, 0.951, 0.907, 0.825], v_loss=0.5045*, v

  0%|          | 0/2376 [00:00<?, ?it/s]

05/30 20:04:01 [INFO] steps=323/2376, v_loss=0.4846, v_f1r=[0.953, 0.000, 0.000, 0.000, 0.000, 0.000, 0.875], v_f1=0.261
05/30 20:05:01 [INFO] steps=654/2376, v_loss=0.4739, v_f1r=[0.947, 0.979, 0.000, 0.000, 0.000, 0.000, 0.911], v_f1=0.405


# 모델 앙상블 및 추론

In [None]:
# Drop every reference left over from the training cell so that nothing keeps
# the fold data, loaders, model or optimizer alive while inference runs.
train_fold_df = valid_fold_df = None
train_dataset = val_dataset = None
train_loader = val_loader = None
optimizer = scheduler = model = None
run = None

# With the references gone, collect garbage and empty the CUDA cache.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Test-set table; the inference loop below reads its `img_path` column.
test_df = pd.read_csv('./test.csv')


In [None]:
import re


def _ckpt_field(pattern, fname):
    """Return group 1 of ``pattern`` in ``fname``.

    Raises:
        ValueError: if the pattern does not occur, naming the offending file
            (instead of an opaque "'NoneType' is not subscriptable").
    """
    found = re.search(pattern, fname)
    if found is None:
        raise ValueError(f'checkpoint name {fname!r} does not match {pattern!r}')
    return found[1]


def parse_ckpt_fname(fname):
    """Extract the metadata encoded in a checkpoint file name.

    Args:
        fname: path such as
            ``./ckpt/<model>-fold_idx=<k>-epoch=<e>-val_loss=<x>-val_score=<y>.ckpt``.

    Returns:
        dict with ``model_name`` (str), ``img_size`` (int, 0 when the name has
        no ``patch<N>_<size>`` part), ``is_ema`` (1 if the name ends with
        ``ema.ckpt`` else 0), ``fold_idx`` (int), ``val_loss`` and
        ``val_score`` (float).

    Raises:
        ValueError: if a required field is missing from the name.
    """
    return {
        'model_name': _ckpt_field(r'./ckpt/(.*?)-fold', fname),
        # The appended sentinel guarantees a match, so names without an image
        # size parse to 0 and can be filtered out later.
        'img_size': int(_ckpt_field(r'patch[0-9]+_([0-9]+)', fname + 'patch0_0')),
        'is_ema': int(fname.endswith('ema.ckpt')),
        # Multi-digit fold indexes are accepted.
        'fold_idx': int(_ckpt_field(r'fold_idx=([0-9]+)-', fname)),
        # Any non-negative decimal is accepted: a val_loss >= 1 (or a score of
        # exactly 1.0000) must not break parsing.
        'val_loss': float(_ckpt_field(r'val_loss=([0-9]+\.[0-9]+)', fname)),
        'val_score': float(_ckpt_field(r'val_score=([0-9]+\.[0-9]+)', fname)),
    }


# One row per checkpoint on disk. `columns=` keeps the frame well-formed
# (empty but with every column) when no checkpoint exists yet.
ckpt_df = pd.DataFrame(
    [
        {'fname': fname, 'mtime': int(os.stat(fname).st_mtime), **parse_ckpt_fname(fname)}
        for fname in glob('./ckpt/*.ckpt')
    ],
    columns=['fname', 'mtime', 'model_name', 'img_size', 'is_ema', 'fold_idx', 'val_loss', 'val_score'],
)

In [None]:
# Keep only checkpoints whose name carried an image size (parsed value != 0),
# then order them newest-first by modification time.
has_img_size = ckpt_df['img_size'].ne(0)
ckpt_df = (
    ckpt_df.loc[has_img_size]
    .sort_values(by='mtime', ascending=False)
    .reset_index(drop=True)
)

MAX_SIZE = 1  ## how many starting checkpoints to keep

# Row labels of the newest checkpoints that belong to the highest fold index.
in_last_fold = ckpt_df['fold_idx'].eq(ckpt_df['fold_idx'].max())
ckpt_indexes = ckpt_df.loc[in_last_fold].index[:MAX_SIZE]

In [None]:
preds = []
preds_score = []

for ckpt_start_index in ckpt_indexes:
    # Free memory left over from the previous iteration.
    gc.collect()
    torch.cuda.empty_cache()

    logger.info(f'{ckpt_df.fname[ckpt_start_index]} loading')
    ## create_model() reads the image size from CFG, so set it per checkpoint.
    CFG['IMG_SIZE'] = int(ckpt_df.img_size[ckpt_start_index])
    assert CFG['IMG_SIZE'] in ( 196, 224, 448 )
    logger.info(CFG['IMG_SIZE'])

    test_dataset = CustomDataset(
        test_df['img_path'].values, None,
        interpolation=CFG['INTERPOLATION'], load_img_size=CFG['IMG_SIZE'],
        shuffle=False, transforms=test_transform)
    test_loader = DataLoader(test_dataset, batch_size=CFG['BATCH_SIZE']*2, shuffle=False, num_workers=4)

    model_name = ckpt_df.model_name[ckpt_start_index]
    model = create_model(model_name)
    # Probably irrelevant for inference, but mirror the training-time setup
    # so it cannot make a difference.
    if (CFG['GRADIENT_CHECKPOINT']==True) and ('set_grad_checkpointing' in dir(model)):
        model.set_grad_checkpointing() ## trades compute for lower memory use
    else:
        model.compile()

    # EMA checkpoints are loaded into an AveragedModel wrapper.
    if ckpt_df.is_ema[ckpt_start_index]:
        model = torch.optim.swa_utils.AveragedModel(model)
    #-----------------------------
    # NOTE(review): this walks `fold_idx.max() + 1` consecutive rows of the
    # mtime-sorted frame and so assumes they hold one checkpoint per fold of
    # the same model -- confirm before ensembling more than one fold.
    for i in range(ckpt_start_index, ckpt_start_index + ckpt_df.fold_idx.max() + 1 ):
        checkpoint_path = ckpt_df.fname[i]
        logger.info(f'{checkpoint_path} loading')
        # Load to CPU first: load_state_dict copies the tensors into the
        # model's own parameters, so a GPU-saved checkpoint does not need a
        # second full copy of the weights on the GPU.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict( checkpoint['model'] )
        del checkpoint

        preds_score.append( ckpt_df.val_score[i] )
        preds.append( prediction(model, test_loader, device) )

preds = np.array(preds)
preds_score = np.array(preds_score)

In [None]:
# Ensemble by summing the collected predictions over the first axis (an
# unweighted sum; `preds_score` is not applied), pick the arg-max along the
# last axis and map the class indexes back to label names.
ensemble_scores = preds.sum(axis=0)
best_class_idx = ensemble_scores.argmax(axis=-1)
preds_labels = le.inverse_transform(best_class_idx)
print(preds_labels)

In [None]:
from datetime import datetime

# Timestamped output path so earlier submissions are never overwritten.
dt_str = datetime.now().strftime('%Y%m%d_%H%M')
submit_path = f'./basslibrary_submit_{dt_str}.csv'

# Fill the sample submission's `rock_type` column with the predicted labels.
submit = pd.read_csv('./sample_submission.csv')
submit['rock_type'] = preds_labels

submit.to_csv(submit_path, index=False)
logger.info(f'{submit_path} saved')

In [None]:
import subprocess

# Send the per-class prediction counts as a notification.
submit_value_counts = str(submit['rock_type'].value_counts())

# Pass the message as a single argv entry instead of going through the shell:
# with the `!` magic the Python `f'...'` prefix reached the shell literally
# (the message started with a stray "f") and the multi-line text depended on
# shell quoting.
try:
    subprocess.run(
        ['python', os.path.expanduser('~/send_telegram.py'), f'submit created: {submit_value_counts}'],
        check=False,  # best-effort: a failing notifier must not fail the notebook
    )
except OSError as exc:
    logger.warning(f'telegram notification skipped: {exc!r}')