In [None]:
import os
import gc
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from glob import glob
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from torchvision.transforms import v2
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset, DataLoader
from transformers import Swinv2Model, ConvNextV2Model, AutoModel
import timm
from PIL import Image

# Matmul precision: 'highest' | 'high' | 'medium'. 'high' lets float32 matmuls
# use faster reduced-precision kernels (e.g. TF32) where the GPU supports them.
torch.set_float32_matmul_precision('high')
# Weights & Biases runs offline. To sync online instead, export WANDB_MODE='online',
# WANDB_PROJECT='basslibrary240210' and WANDB_API_KEY in the environment rather
# than hard-coding the key in the notebook.
os.environ.update(WANDB_MODE='offline')

######## logger ########
# Root logger writes to stdout and to run.log (rotated at ~500 KB, 4 backups kept).
# `logging.handlers` is a submodule that `import logging` does not load, so it is
# imported explicitly for RotatingFileHandler.
import sys, logging, logging.handlers, IPython
logger = logging.getLogger()
# NOTE: basicConfig does nothing if the root logger already has handlers
# (e.g. when this cell is re-run), so run.log is attached only once.
logging.basicConfig( handlers=[ logging.StreamHandler(stream=sys.stdout), logging.handlers.RotatingFileHandler(filename='run.log', mode='a', maxBytes=512000, backupCount=4) ] )
logging_fomatter = logging.Formatter( '%(asctime)s [%(levelname)-4.4s] %(message)s', datefmt='%m/%d %H:%M:%S' )
for _handler in logger.handlers:
    _handler.setFormatter(logging_fomatter)
logger.setLevel(logging.INFO)

def showtraceback(self, exc_tuple=None, *args, **kwargs):
    """Replacement for IPython's ``InteractiveShell.showtraceback``.

    Routes cell exceptions through ``logger`` so they reach both stdout and
    run.log. If IPython passes an explicit ``(type, value, traceback)`` tuple,
    that exception is logged; otherwise the exception currently being handled
    is used.
    """
    logger.exception('-------Exception----------', exc_info=exc_tuple or True)
IPython.core.interactiveshell.InteractiveShell.showtraceback = showtraceback
logger.info('program started')

In [None]:
# Run configuration. Several keys (N_SPLIT, LABEL_SMOOTHING, OPTIMIZER, PRECISION,
# MODEL_NAME, LR) appear to be carried over from training and are unused below.
CFG = {
    'SEED': 42,
    'N_SPLIT': 5,
    'LABEL_SMOOTHING': 0.05,
    'OPTIMIZER': 'AdamW',
    'INTERPOLATION': 'robidouxsharp',
    'PRECISION': '16',
    #----------------------------------
    'MODEL_NAME': "timm/beitv2_large_patch16_224.in1k_ft_in22k_in1k",
    'IMG_SIZE': 224,
    'BATCH_SIZE': 48,  # 48 fits in 16 GB (with EMA); 14 fits in 8 GB
}
# Learning rate scaled by sqrt(batch size), plus a second, fixed value.
CFG['LR'] = [0.25e-5 * np.sqrt(CFG['BATCH_SIZE']), 1e-6]
#----------------------------------

######################################
# The training resolution defaults to the model input resolution.
CFG.setdefault('IMG_TRAIN_SIZE', CFG['IMG_SIZE'])
logger.info(CFG)

In [None]:
# This notebook requires a GPU; stop immediately if none is visible.
assert torch.cuda.is_available()
# Past the assert a CPU fallback is unreachable, so the device is fixed.
device = 'cuda'
# Tensors created without an explicit device (such as the dummy input in
# create_model) now default to the GPU.
torch.set_default_device(device)
logger.info(device)

In [None]:
def seed_everything(seed, benchmark=False):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: integer seed applied to every RNG.
        benchmark: value for ``torch.backends.cudnn.benchmark``. Defaults to
            False, because benchmark mode picks cuDNN algorithms by timing them,
            which can differ between runs and undermines ``cudnn.deterministic``.
            Pass True to trade reproducibility for speed.
    """
    logger.info(f'seed_everything : {seed}')

    import random

    random.seed(seed)
    # Only affects interpreters started after this point; the hash seed of the
    # running process is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe to call when CUDA is unavailable
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = benchmark

# Seed every RNG before any data shuffling or model construction.
seed_everything(seed=CFG['SEED'])

In [None]:
class CustomDataset(Dataset):
    """Image dataset that resizes each image once and caches the result as PNG.

    Resized images are stored under
    ``img_cached/{img_size}_{interpolation}/<image path without extension>.png``
    and are read back from there on later accesses.

    Args:
        img_path_list: image file paths.
        label_list: per-image labels, or None for unlabeled (test) data.
        load_img_size: side length of the square image produced before ``transforms``.
        shuffle: shuffle the sample order once at construction (no fixed random_state).
        transforms: optional callable applied to the PIL image.
        interpolation: 'pil_lanczos', 'cv2_lanczos4', or any other value, which
            is passed to Wand (ImageMagick) ``resize`` as the filter name.
    """
    def __init__(self, img_path_list, label_list, load_img_size, shuffle=False, transforms=None, interpolation='robidouxsharp' ):
        self.df = pd.DataFrame({'img_path_list': img_path_list})
        self.interpolation = interpolation
        self.load_img_size = load_img_size
        logger.info(f'load_img_size={load_img_size}')
        if label_list is not None:
            self.df['label_list'] = label_list
        if shuffle:
            self.df = self.df.sample(frac=1.0).reset_index(drop=True)
        self.transforms = transforms

    def get_interpolated_image(self, img, new_image_size):
        """Resize ``img`` (NumPy array or PIL Image) to a square PIL Image."""
        if self.interpolation == 'pil_lanczos':
            if isinstance(img, np.ndarray):
                img = Image.fromarray(img)
            return img.resize((new_image_size, new_image_size), Image.LANCZOS)
        if self.interpolation == 'cv2_lanczos4':
            if not isinstance(img, np.ndarray):
                img = np.array(img)
            # cv2.resize works per channel, so no RGB<->BGR round trip is needed.
            img = cv2.resize(img, (new_image_size, new_image_size), interpolation=cv2.INTER_LANCZOS4)
            return Image.fromarray(img)
        # Any other value is treated as an ImageMagick filter name.
        if not isinstance(img, np.ndarray):
            img = np.array(img)
        from wand import image
        with image.Image.from_array(img) as src:
            src.resize(new_image_size, new_image_size, filter=self.interpolation)
            return Image.fromarray(np.array(src))

    def _cache_path(self, img_path, img_size):
        """Return the PNG cache path for ``img_path`` resized to ``img_size``."""
        # splitext strips only the final extension, so dots elsewhere in the path
        # (e.g. 'a.b/c.jpg') no longer truncate the name and collide in the cache.
        stem = os.path.splitext(img_path.replace('./', ''))[0]
        return f'img_cached/{img_size}_{self.interpolation}/{stem}.png'

    def get_image_from_index(self, index, img_size):
        """Return sample ``index`` as a PIL Image of size ``img_size``, using the disk cache."""
        img_path = self.df.img_path_list[index]
        full_fname = self._cache_path(img_path, img_size)
        if os.path.exists(full_fname):
            return Image.open(full_fname)
        os.makedirs(os.path.dirname(full_fname), exist_ok=True)
        # Close the source file deterministically; the resized image is a new object.
        with Image.open(img_path) as src_img:
            img = self.get_interpolated_image(src_img, img_size)
        img.save(full_fname)
        return img

    def __getitem__(self, index):
        image = self.get_image_from_index(index, self.load_img_size)
        if self.transforms is not None:
            image = self.transforms(image)
        if 'label_list' in self.df.columns:
            return {'pixel_values': image, 'label': self.df.label_list[index]}
        return {'pixel_values': image}

    def __len__(self):
        return len(self.df)

In [None]:
class CustomModel(nn.Module):
    """Backbone followed by a lazily-sized linear classification head.

    Args:
        model: backbone module. Its output is either a tensor or an object with
            a ``pooler_output`` attribute (as Hugging Face model outputs have).
        num_classes: number of output logits (default 25, the previously
            hard-coded head size).
    """
    def __init__(self, model, num_classes=25):
        super().__init__()
        self.model = model
        # in_features is inferred on the first forward pass.
        self.clf = nn.LazyLinear(num_classes)

    def forward(self, x):
        x = self.model(x)
        if not isinstance(x, torch.Tensor):
            x = x.pooler_output
        # The head can be disabled (set to None) to expose backbone features.
        if self.clf is not None:
            x = self.clf(x)
        return x

In [None]:
image_size = CFG['IMG_SIZE']

# Training augmentation pipeline (not used by the inference cells in this notebook).
train_transform_list = [
    v2.TrivialAugmentWide(interpolation=v2.InterpolationMode.BICUBIC),
    v2.RandomErasing(),
]
# Resize only when the training resolution differs from the model input size.
if CFG['IMG_SIZE'] != CFG['IMG_TRAIN_SIZE']:
    train_transform_list.append(
        v2.Resize(size=(image_size, image_size), interpolation=v2.InterpolationMode.LANCZOS, antialias=True)
    )
train_transform_list.extend([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_transform = v2.Compose(train_transform_list)

# Test-time pipeline: PIL image -> float32 tensor scaled to [0, 1] -> ImageNet normalisation.
test_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
def prediction(model, test_loader, device):
    """Run ``model`` over every batch of ``test_loader`` and collect its raw outputs.

    Args:
        model: module called as ``model(pixel_values)``; it is moved to ``device``.
        test_loader: iterable of dicts holding a ``'pixel_values'`` tensor.
        device: device for the model and the inputs.

    Returns:
        A flat list with one entry per output row (one per sample, assuming the
        model returns a (batch, classes) tensor), each a list of floats.
        No softmax is applied.

    The model's train/eval mode is restored on exit, even if an exception is
    raised part-way through.
    """
    model = model.to(device)
    was_training = model.training
    model.eval()
    preds = []
    try:
        with torch.no_grad():
            for batch in tqdm(test_loader):
                pixel_values = batch['pixel_values'].to(device)
                # F.softmax(output) was considered here; probably pointless.
                pred = model(pixel_values)
                preds += pred.detach().cpu().numpy().tolist()
    finally:
        model.train(was_training)
    return preds

In [None]:
def create_model(model_name, img_size=None):
    """Build a ``CustomModel`` around a pretrained backbone and materialise its head.

    The backbone source is chosen from ``model_name``:

    * names without '/' get a 'timm/' prefix and are loaded with timm;
    * names starting with './' load timm's 'nextvit_large' with weights from that
      local checkpoint (``import nextvit`` presumably registers the model with
      timm -- confirm);
    * 'facebook/hiera_*' names use ``hiera.Hiera.from_pretrained``;
    * 'timm/*' names use ``timm.create_model(..., pretrained=True)``;
    * anything else uses ``transformers.AutoModel.from_pretrained``.

    Args:
        model_name: backbone identifier or local checkpoint path (see above).
        img_size: side of the square dummy input used to initialise the
            ``LazyLinear`` head. Defaults to ``CFG['IMG_SIZE']``.

    Returns:
        The model, left in eval mode.
    """
    import timm
    from transformers import AutoModel

    logger.info(f'create_model: {model_name}')
    if '/' not in model_name:
        model_name = 'timm/' + model_name

    if model_name.startswith('./'):
        import nextvit  # noqa: F401 -- imported only for its side effects
        backbone = timm.create_model('nextvit_large', pretrained=True, checkpoint_path=model_name)
    elif model_name.startswith('facebook/hiera_'):
        from hiera import Hiera  ## pip install hiera-transformer
        backbone = Hiera.from_pretrained(model_name)
    elif model_name.startswith('timm/'):
        backbone = timm.create_model(model_name, pretrained=True)
    else:
        backbone = AutoModel.from_pretrained(model_name)
    model = CustomModel(backbone)

    if img_size is None:
        img_size = CFG['IMG_SIZE']
    model.eval()
    # One forward pass materialises the LazyLinear head (in_features is taken
    # from the backbone output). no_grad avoids recording an autograd graph
    # through the whole backbone for this throwaway pass.
    with torch.no_grad():
        model(torch.rand((1, 3, img_size, img_size), dtype=torch.float32))
    return model

In [None]:
# Fit the label encoder on the training labels; `le` later maps predicted
# class indices back to label strings.
le = LabelEncoder()
train_df = pd.read_csv('./train.csv')
train_df['class'] = le.fit_transform(train_df['label'])

In [None]:
# Test set; its 'img_path' column drives the inference DataLoader.
test_df = pd.read_csv(filepath_or_buffer='./test.csv')

In [None]:
import re

CKPT_META_COLUMNS = ['model_name', 'img_size', 'is_ema', 'fold_idx', 'val_loss', 'val_score']


def _ckpt_field(pattern, fname):
    """Return capture group 1 of ``pattern`` in ``fname``; raise ValueError if absent."""
    match = re.search(pattern, fname)
    if match is None:
        raise ValueError(f'cannot parse {pattern!r} from checkpoint name {fname!r}')
    return match[1]


def parse_ckpt_name(fname):
    """Extract the metadata encoded in a checkpoint filename.

    Names are expected to look like
    ``./ckpt/<model_name>-fold_idx=<k>-...val_loss=<x>...val_score=<y>...ckpt``,
    with EMA checkpoints ending in ``ema.ckpt``. ``img_size`` comes from a
    ``patch<P>_<size>`` token in the name and is 0 when there is none.
    """
    img_size = re.search(r'patch[0-9]+_([0-9]+)', fname)
    return {
        'model_name': _ckpt_field(r'./ckpt/(.*?)-fold', fname),
        'img_size': int(img_size[1]) if img_size else 0,
        'is_ema': int(fname.endswith('ema.ckpt')),
        # Accept multi-digit fold indices and losses/scores >= 1.
        'fold_idx': int(_ckpt_field(r'fold_idx=([0-9]+)-', fname)),
        'val_loss': float(_ckpt_field(r'val_loss=([0-9]+\.[0-9]+)', fname)),
        'val_score': float(_ckpt_field(r'val_score=([0-9]+\.[0-9]+)', fname)),
    }


ckpt_df = pd.DataFrame({'fname': glob('./ckpt/*.ckpt')})
ckpt_df['mtime'] = ckpt_df.fname.apply(lambda x: int(os.stat(x).st_mtime))
ckpt_df = ckpt_df.join(pd.DataFrame(
    [parse_ckpt_name(f) for f in ckpt_df.fname],
    columns=CKPT_META_COLUMNS,
    index=ckpt_df.index,
))

In [None]:
# Checkpoints saved under the short name 'swinv2' are mapped to the full
# Hugging Face model id and its 192px input size (presumably their names lack
# a 'patch*_<size>' token, so img_size would otherwise be 0).
is_swinv2 = ckpt_df.model_name == 'swinv2'
ckpt_df.loc[is_swinv2, 'model_name'] = 'microsoft/swinv2-large-patch4-window12-192-22k'
ckpt_df.loc[is_swinv2, 'img_size'] = 192

In [None]:
# Keep non-EMA checkpoints with a known input size, newest first.
ckpt_df = (
    ckpt_df[(ckpt_df.img_size != 0) & (ckpt_df.is_ema == 0)]
    .sort_values('mtime', ascending=False)
    .reset_index(drop=True)
)
display(ckpt_df)
# Rows holding the highest fold index mark the start of a run (presumably its
# newest checkpoint); keep the 4 most recent such rows.
last_fold = ckpt_df.fold_idx.max()
ckpt_indexes = ckpt_df[ckpt_df.fold_idx == last_fold].index[:4]
display(ckpt_indexes)

In [None]:
preds = []
preds_score = []
# Number of folds per run, e.g. 5 when fold_idx ranges over 0..4.
n_folds = ckpt_df.fold_idx.max() + 1

for ckpt_start_index in ckpt_indexes:
    logger.info(f'{ckpt_df.fname[ckpt_start_index]} loading')
    ## input image size taken from the checkpoint name
    CFG['IMG_SIZE'] = ckpt_df.img_size[ckpt_start_index]
    assert CFG['IMG_SIZE'] in ( 192, 196, 224, )
    logger.info(CFG['IMG_SIZE'])

    test_dataset = CustomDataset(
        test_df['img_path'].values, None,
        interpolation=CFG['INTERPOLATION'], load_img_size=CFG['IMG_SIZE'],
        shuffle=False, transforms=test_transform)
    test_loader = DataLoader(test_dataset, batch_size=CFG['BATCH_SIZE']*2, shuffle=False, num_workers=0)

    model_name = ckpt_df.model_name[ckpt_start_index]
    model = create_model(model_name)
    if ckpt_df.is_ema[ckpt_start_index]:
        model = torch.optim.swa_utils.AveragedModel(model)
    #-----------------------------
    # The other folds of this run are assumed to be the next n_folds rows in
    # mtime-descending order. Verify they are exactly one checkpoint per fold of
    # the same model, instead of silently loading another run's weights or
    # running past the end of ckpt_df.
    fold_rows = ckpt_df.iloc[ckpt_start_index:ckpt_start_index + n_folds]
    assert len(fold_rows) == n_folds, f'expected {n_folds} checkpoints from row {ckpt_start_index}, got {len(fold_rows)}'
    assert (fold_rows.model_name == model_name).all(), f'mixed models in rows {list(fold_rows.index)}'
    assert sorted(fold_rows.fold_idx) == list(range(n_folds)), f'unexpected folds {list(fold_rows.fold_idx)}'

    for i in fold_rows.index:
        checkpoint_path = ckpt_df.fname[i]
        logger.info(f'{checkpoint_path} loading')
        # Load onto CPU first; load_state_dict copies tensors to the model's device.
        model.load_state_dict( torch.load(checkpoint_path, map_location='cpu')['model'] )

        preds_score.append( ckpt_df.val_score[i] )
        preds.append( prediction(model, test_loader, device) )

    # Release this model before the next one is built, so two large backbones
    # are never held in GPU memory at the same time.
    del model
    gc.collect()
    torch.cuda.empty_cache()

preds = np.array(preds)
preds_score = np.array(preds_score)

In [None]:
# ### Weighted average of the per-checkpoint predictions
def score_to_coef(scores):
    """Turn validation scores into ensemble weights that sum to 1.

    With ``error_i = 1 - score_i`` (accuracy-style error) and ``E = sum(error)``,
    each weight is proportional to ``1 - error_i / E``. For n checkpoints every
    weight therefore lies in ``[0, 1/(n-1)]``, so the weighting is mild.

    The formula is undefined (0/0) for a single checkpoint or when the total
    error is 0; uniform weights are returned in those cases instead of NaN.
    NaN weights would otherwise make every argmax collapse to class 0.
    """
    scores = np.asarray(scores, dtype=float)
    errors = 1 - scores
    total_error = errors.sum()
    if len(scores) < 2 or total_error == 0:
        return np.full(len(scores), 1 / max(len(scores), 1))
    inv_share = 1 - errors / total_error
    return inv_share / inv_share.sum()


preds_coef = score_to_coef(preds_score)

logger.info(f'{preds_score=}')
logger.info(f'{preds_coef=}')
# preds is presumably (n_checkpoints, n_samples, n_classes): weight each
# checkpoint's outputs, then sum over checkpoints and take the top class.
preds2 = preds_coef[:, None, None] * preds
preds_labels = le.inverse_transform(preds2.sum(0).argmax(-1))
print(preds_labels)

In [None]:
# Fill the sample-submission template with the predicted labels and save it
# under a timestamped name. Assumes sample_submission rows follow test.csv
# order -- TODO confirm.
submit = pd.read_csv('./sample_submission.csv')
submit['label'] = preds_labels
from datetime import datetime
dt_str = f'{datetime.now():%Y%m%d_%H%M}'
submit_path = f'./basslibrary_submit_{dt_str}.csv'
submit.to_csv(submit_path, index=False)
logger.info(f'{submit_path} saved')

In [None]:
# Sanity check: distribution of predicted labels.
submit['label'].value_counts()

In [None]:
# !python ~/send_telegram.py 'basslibrary_submit_{dt_str}.csv saved'