In [1]:
import gc
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import torch
from PIL import Image
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import LabelEncoder
from timm import create_model, list_models
from timm.data import create_transform
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Sampler
from tqdm import tqdm
import random 
import wandb
import glob
import timm

# Notebook display defaults: wide figures and roomy DataFrame previews.
plt.rcParams.update({'figure.figsize': (20, 5)})
pd.set_option('display.max_rows', 100, 'display.max_columns', 1000)

# Trade a little float32 matmul precision for speed where the GPU supports it.
torch.set_float32_matmul_precision('high')
# torch._dynamo.config.suppress_errors = True


In [2]:
# ====================================================
# Directory settings
# ====================================================
import os

# Root of the competition data (must end with a path separator because later
# paths are built as f'{data_dir}input/...'). Override with AIORNOT_DATA_DIR.
data_dir = os.path.join(os.environ.get('AIORNOT_DATA_DIR', '/home/rashmi/Documents/kaggle/aiornot/'), '')
EXP_NAME = 'exp25'
# Per-experiment checkpoint/output directory, with a trailing separator.
OUTPUT_DIR = os.path.join(data_dir, 'src', f'models_{EXP_NAME}', '')

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [3]:
import timm
# timm.list_models("convnext*")
# timm.list_models()

In [4]:
# Config

class CFG:
    """Experiment configuration. lr, epochs, model, dropout and FOLDS can be overridden via env vars."""

    # Optimisation / schedule
    one_cycle = True
    one_cycle_pct_start = 0.1
    adamw = False
    adamw_decay = 0.024
    one_cycle_max_lr = float(os.environ.get('lr', '1e-5'))
    epochs = int(os.environ.get('epochs', 10))
    model_type = os.environ.get('model', 'timm/convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384')
                                # 'timm/convnext_large_mlp.clip_laion2b_augreg_ft_in1k')
                                # 'convnextv2_huge.fcmae_ft_in22k_in1k_512') # 'convnextv2_large')
    # 'convnext_large_384_in22ft1k') # 'tf_efficientnetv2_m') # 'seresnext50_32x4d') convnext_xlarge_in22k
    # NOTE(review): in this notebook dropout only appears in the run name; it is
    # never passed to create_model.
    dropout = float(os.environ.get('dropout', 0.2))
    AUG = False
    batch_size = 8

    # Switchers
    debug = False
    wandb_set = False
    wandb_sweep = False
    train = True
    oof = True

    seed = 42
    n_folds = 5
    # Folds to train in this run, e.g. FOLDS=0,2
    folds = np.array(os.environ.get('FOLDS', '0,1,2,3,4').split(',')).astype(int)

    TRAIN_IMAGES_PATH = f'{data_dir}input/train/'
    TEST_IMAGES_PATH = f'{data_dir}input/test/'
    TRAIN_CSV = f'{data_dir}input/train.csv'
    TEST_CSV = f'{data_dir}input/sample_submission.csv'

    # One checkpoint per fold (model-f0 ... model-f{n_folds-1}); follows n_folds
    # instead of a hand-maintained list. range(n_folds) is evaluated in class scope.
    models = [f'{OUTPUT_DIR}model-f{i}' for i in range(n_folds)]

    predict_max_batches = 1e9
    max_eval_batches = 4000
    # os.cpu_count() may return None; fall back to in-process loading.
    num_workers = os.cpu_count() or 0
    img_size = (768, 768)  # (640, 640)


# Human-readable run id that encodes the main hyper-parameters.
wandb_run_name = '_'.join([
    CFG.model_type,
    f'lr{CFG.one_cycle_max_lr}',
    f'ep{CFG.epochs}',
    f'bs{CFG.batch_size}',
    "adamw" if CFG.adamw else "adam",
    "aug" if CFG.AUG else "noaug",
    f'drop{CFG.dropout}',
])
print('run', wandb_run_name, 'folds', CFG.folds)


if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

run timm/convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384_lr1e-05_ep10_bs8_adam_noaug_drop0.2 folds [0 1 2 3 4]


In [5]:
# Training labels (append .head(10000) for a quick subset).
df_train = pd.read_csv(filepath_or_buffer=CFG.TRAIN_CSV)
display(df_train.head())
df_train['label'].value_counts()


Unnamed: 0,id,label
0,0.jpg,1
1,1.jpg,1
2,2.jpg,1
3,3.jpg,0
4,4.jpg,1


1    10330
0     8288
Name: label, dtype: int64

In [6]:
from sklearn.model_selection import StratifiedKFold, KFold
import numpy as np
# Stratified K-fold assignment: each row gets the index of the fold in which it
# is held out. Fold count and seed come from CFG so they cannot drift apart.
skf = StratifiedKFold(n_splits=CFG.n_folds, shuffle=True, random_state=CFG.seed)
df_train['kfold'] = -1
for fold_idx, (_, val_idx) in enumerate(skf.split(df_train['id'], df_train['label'])):
    # split() yields positional indices; map them to index labels for .loc.
    df_train.loc[df_train.index[val_idx], 'kfold'] = fold_idx
assert (df_train['kfold'] >= 0).all(), 'every row must be assigned to a fold'

df_train.to_csv('train_folds.csv', index=False)
display(pd.crosstab(df_train.kfold, df_train.label))
df_train.kfold.value_counts()

label,0,1
kfold,Unnamed: 1_level_1,Unnamed: 2_level_1
0,1658,2066
1,1658,2066
2,1658,2066
3,1657,2066
4,1657,2066


2    3724
0    3724
1    3724
4    3723
3    3723
Name: kfold, dtype: int64

In [7]:
import torchvision

def get_transforms(aug=False):
    """Build the image preprocessing callable used by ``AIOrNotDataset``.

    Args:
        aug: if True, apply random horizontal flip, a +/-5 degree rotation and a
            random resized crop to ``CFG.img_size``; otherwise only resize to
            ``CFG.img_size``.

    Returns:
        A function mapping a PIL image to an ImageNet-normalized float tensor.
    """
    # TODO: the CLIP-pretrained backbone used
    # OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) and
    # OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711); the ImageNet
    # statistics below are kept so existing checkpoints stay consistent.
    if aug:
        geometric = [
            torchvision.transforms.RandomHorizontalFlip(0.5),
            torchvision.transforms.RandomRotation(degrees=(-5, 5)),
            # NOTE(review): ratio=(0.45, 0.55) selects crops about half as wide as
            # they are tall, which are then resized to CFG.img_size — confirm intended.
            torchvision.transforms.RandomResizedCrop(CFG.img_size, scale=(0.8, 1), ratio=(0.45, 0.55)),
        ]
    else:
        geometric = [
            torchvision.transforms.Resize(CFG.img_size),
        ]

    # Build the pipeline once instead of once per image; random transforms
    # still sample fresh parameters on every call.
    pipeline = torchvision.transforms.Compose(geometric + [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    def transforms(img):
        return pipeline(img.convert('RGB'))

    return transforms

if CFG.debug:
    # Visual sanity check: the raw image, then 8 augmented variants of it.
    tfm = get_transforms(aug=True)
    img = Image.open(f"{CFG.TRAIN_IMAGES_PATH}10000.jpg")
    print(img.size)
    plt.imshow(np.array(img), cmap='gray')
    plt.show()

    plt.figure(figsize=(20, 20))
    for i in range(8):
        v = tfm(img).permute(1, 2, 0)
        # Min-max rescale the normalized tensor into [0, 1] for imshow.
        v = v - v.min()
        v = v / v.max()
        plt.subplot(2, 4, i + 1).imshow(v)
    plt.tight_layout()

In [8]:
class AIOrNotDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding the image and 0/1 label for each row of ``df``.

    ``df`` needs ``id`` (file name) and ``label`` columns and an index addressable
    as 0..n-1, because rows are looked up with ``df.loc[idx, ...]``.
    """

    def __init__(self, df, path, transforms=None):
        super().__init__()
        self.df, self.path, self.transforms = df, path, transforms

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        img_path = f'{self.path}{self.df.loc[idx, "id"]}'
        image = Image.open(img_path)
        label = self.df.loc[idx, "label"]
        return {
            "img_path": img_path,
            "image": image if self.transforms is None else self.transforms(image),
            "label": label,
        }
    
if CFG.debug:
    # Smoke-test the dataset and a DataLoader on fold 4's training split.
    fold = 4
    ds_train = AIOrNotDataset(df_train[df_train.kfold != fold].reset_index(drop=True),
                              CFG.TRAIN_IMAGES_PATH, get_transforms(True))
    print(df_train[df_train.kfold != fold].shape)
    data = ds_train[1]
    print(data['img_path'], data['image'].shape, data['label'])
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=CFG.batch_size, shuffle=False,
                                           num_workers=1,  # CFG.num_workers
                                           pin_memory=True, drop_last=True)
    print(len(ds_train), len(dl_train))
    # Inspect only the first five batches.
    for idx, batch in zip(range(5), dl_train):
        print(idx, len(batch), len(batch['img_path']), batch['image'].shape, batch['label'].shape)

In [9]:
class AIdentifierModel(torch.nn.Module):
    """Binary "AI-generated or not" classifier: a timm backbone plus a linear head.

    ``forward`` returns one logit per image; ``predict`` returns its sigmoid.

    NOTE(review): ``create_model(..., num_classes=1)`` already ends in a
    1-output classifier, so ``backbone_dim`` is 1 and ``fc`` is a Linear(1, 1)
    stacked on that logit. Kept unchanged so existing checkpoints still load.
    NOTE(review): presumably relies on the backbone's own forward not using an
    attribute named ``fc``; for backbones whose classifier *is* ``fc`` this
    assignment would replace it — confirm before switching CFG.model_type.
    """

    def __init__(self, model_type):
        super().__init__()
        self.model = create_model(model_type, pretrained=True, num_classes=1)
        # Probe the output width with a dummy batch in eval mode and without
        # autograd: no graph is kept in memory and any BatchNorm running
        # statistics are not updated with random-noise activations. The
        # original train/eval mode is restored afterwards.
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            self.backbone_dim = self.model(torch.randn(1, 3, 512, 512)).shape[-1]
        self.model.train(was_training)

        self.model.fc = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_dim, 1),
        )

    def forward(self, x):
        """Return raw logits of shape ``(batch, 1)`` for an image batch ``x``."""
        x = self.model(x)
        x = self.model.fc(x)
        return x

    def predict(self, x):
        """Return sigmoid probabilities, same shape as ``forward``."""
        x = self.forward(x)
        return torch.sigmoid(x)
    

if CFG.debug:
    # Smoke-test model construction and a 2-image forward pass.
    with torch.no_grad():
        model = AIdentifierModel(model_type=CFG.model_type)
        dummy_batch = torch.randn(2, 3, 512, 512)
        pred = model.predict(dummy_batch)
        print('efficientnet_b2', pred.shape, pred)

    del model

In [10]:
def save_model(name, model, model_type):
    """Save ``model``'s weights together with its timm ``model_type`` string.

    Args:
        name: destination file path.
        model: module whose ``state_dict()`` is written.
        model_type: backbone identifier, stored so the architecture can be rebuilt.
    """
    torch.save({'model': model.state_dict(), 'model_type': model_type}, f'{name}')


def load_model(name, dir='.', model=None):
    """Load a checkpoint written by ``save_model``.

    Args:
        name: checkpoint file name, joined onto ``dir`` (``os.path.join``
            ignores ``dir`` when ``name`` is absolute).
        dir: directory prefix.
        model: optional pre-built module to load the weights into. When None,
            an ``AIdentifierModel`` is rebuilt from the stored ``model_type``.

    Returns:
        ``(model, model_type)``.
    """
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    data = torch.load(os.path.join(dir, f'{name}'), map_location=DEVICE)
    if model is None:
        # Rebuild this notebook's architecture from the saved backbone id.
        model = AIdentifierModel(data['model_type'])
    model.load_state_dict(data['model'])
    return model, data['model_type']


if CFG.debug:
    # Quick save/load round-trip test. save_model takes no threshold and
    # load_model returns (model, model_type).
    model = torch.nn.Linear(2, 1)
    save_model(f'{OUTPUT_DIR}testmodel', model, model_type='abc')

    model1, model_type = load_model(f'{OUTPUT_DIR}testmodel', model=torch.nn.Linear(2, 1))
    assert torch.all(
        next(iter(model1.parameters())) == next(iter(model.parameters()))
    ).item(), "Loading/saving is inconsistent!"
    print(model_type)

    

In [11]:

from sklearn.metrics import accuracy_score, f1_score, log_loss

def optimal_metric(truth, preds):
    """Score binary probability predictions.

    Args:
        truth: ground-truth 0/1 labels.
        preds: predicted probabilities of class 1; thresholded at 0.5 for
            accuracy and F1.

    Returns:
        dict with ``accuracy``, macro-averaged ``f1`` and ``log_loss``.
    """
    preds = np.asarray(preds)
    hard = preds > 0.5
    return {
        "accuracy": accuracy_score(truth, hard),
        "f1": f1_score(truth, hard, average="macro"),
        # labels=[0, 1] keeps log_loss defined when a (partial) eval subset
        # happens to contain only one class.
        "log_loss": log_loss(truth, preds, labels=[0, 1]),
    }



In [12]:
"""# Random seed"""

def seed_everything(seed, use_cuda = True):
    np.random.seed(seed) # cpu vars
    torch.manual_seed(seed) # cpu  vars
    random.seed(seed) # Python
    os.environ['PYTHONHASHSEED'] = str(seed) # Python hash building
    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed) # gpu vars
        torch.backends.cudnn.deterministic = True  #needed
        torch.backends.cudnn.benchmark = False


def get_logger(filename=OUTPUT_DIR+'train'):
    """Return this module's INFO logger, writing bare messages to stderr and ``<filename>.log``.

    Handlers left over from a previous call are removed (and closed) first, so
    re-running this cell does not print every message several times or leak
    open log files.
    """
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = Formatter("%(message)s")
    for handler in (StreamHandler(), FileHandler(filename=f"{filename}.log")):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

LOGGER = get_logger()

seed_everything(CFG.seed)

In [13]:
def evaluate_model(model: AIdentifierModel, ds, max_batches=CFG.predict_max_batches, shuffle=False, config=CFG):
    """Run ``model`` over ``ds`` and score it.

    Args:
        model: classifier whose ``forward`` returns one logit per image.
        ds: dataset yielding dicts with ``image`` and ``label``.
        max_batches: evaluate at most this many batches.
        shuffle: shuffle the evaluation DataLoader.
        config: unused; kept for call-site compatibility.

    Returns:
        ``(mean BCE loss, predicted probabilities, accuracy, macro F1, log loss)``.

    Raises:
        ValueError: if no batch was evaluated (empty dataset or max_batches <= 0).
    """
    torch.manual_seed(42)
    model = model.to(DEVICE)
    dl_test = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=shuffle, num_workers=CFG.num_workers, pin_memory=False)
    probs, losses, targets = [], [], []

    model.eval()
    with torch.no_grad(), tqdm(dl_test, desc='Eval', mininterval=30) as progress:
        for i, X in enumerate(progress):
            # Check first so exactly max_batches batches are evaluated.
            if i >= max_batches:
                break
            y_true = X['label']
            with autocast(enabled=True):
                logits = model.forward(X['image'].to(DEVICE)).reshape(-1)
            # The loss takes logits, the metrics take probabilities; compute
            # both in float32 outside autocast.
            logits = logits.float()
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits,
                y_true.to(DEVICE).float(),
            )
            losses.append(loss.item())
            probs.append(torch.sigmoid(logits).cpu())
            targets.append(y_true.cpu().numpy())

    if not losses:
        raise ValueError('evaluate_model: no batches were evaluated')
    targets = np.concatenate(targets)
    pred = torch.nan_to_num(torch.cat(probs)).numpy()
    metric_dict = optimal_metric(targets, pred)  # {"accuracy", "f1", "log_loss"}
    return np.mean(losses), pred, metric_dict['accuracy'], metric_dict['f1'], metric_dict['log_loss']
        
# Quick smoke test of evaluate_model on two batches (requires the debug ds_train).
if CFG.debug:
    m = AIdentifierModel(model_type=CFG.model_type)
    closs, pred, acc, f1, logloss = evaluate_model(
        m,
        ds_train,
        max_batches=2,
    )
    del m

In [14]:
def _build_optimizer(model, config):
    """Create plain Adam, or AdamW without weight decay on biases / 1-D params, at ``config.one_cycle_max_lr``."""
    if not config.adamw:
        return torch.optim.Adam(model.parameters(), lr=config.one_cycle_max_lr)
    decay, no_decay = [], []
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or param_name.endswith('.bias'):
            no_decay.append(param)
        else:
            decay.append(param)
    return torch.optim.AdamW(
        [{'params': decay, 'weight_decay': config.adamw_decay},
         {'params': no_decay, 'weight_decay': 0.0}],
        lr=config.one_cycle_max_lr, betas=(0.9, 0.999))


def train_model(ds_train, ds_eval, logger, name, config=CFG, do_save_model=True):
    """Train one ``AIdentifierModel``, checkpointing the epoch with the lowest eval log loss.

    Args:
        ds_train: training dataset (dicts with ``image`` / ``label``).
        ds_eval: validation dataset, or None to skip evaluation.
        logger: ``logging.Logger``; when ``config.wandb_set`` it must instead
            expose ``.log(dict)`` for per-step metrics.
        name: checkpoint path passed to ``save_model``.
        config: hyper-parameter namespace (defaults to ``CFG``).
        do_save_model: write the best checkpoint to ``name``.

    Returns:
        The model after the final epoch (not necessarily the saved best one).
    """
    seed_everything(config.seed)

    # NOTE(review): training batches are not shuffled (shuffle=True was
    # disabled), so rows arrive in the same order every epoch — confirm intended.
    dl_train = torch.utils.data.DataLoader(ds_train,
                                           batch_size=config.batch_size,
                                           shuffle=False,
                                           num_workers=config.num_workers,
                                           pin_memory=True,
                                           drop_last=True)

    model = AIdentifierModel(config.model_type).to(DEVICE)
    # model = torch.compile(model)
    optim = _build_optimizer(model, config)

    scheduler = None
    if config.one_cycle:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, max_lr=config.one_cycle_max_lr, epochs=config.epochs,
                                                        steps_per_epoch=len(dl_train),
                                                        pct_start=config.one_cycle_pct_start)

    scaler = GradScaler()
    best_eval_score = float('inf')  # lowest eval log loss seen so far
    for epoch in tqdm(range(config.epochs), desc='Epoch'):

        model.train()
        with tqdm(dl_train, desc='Train', mininterval=30) as progress:
            for batch_idx, X in enumerate(progress):
                optim.zero_grad()
                # Mixed-precision forward pass.
                with autocast():
                    y_pred = model.forward(X['image'].to(DEVICE)).reshape(-1)
                    loss = torch.nn.functional.binary_cross_entropy_with_logits(
                        y_pred,
                        X['label'].to(DEVICE).float(),
                    )

                # Scale the loss so fp16 gradients do not underflow; the
                # scheduler is stepped after the optimizer/scaler update.
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
                if scheduler is not None:
                    scheduler.step()

                lr = optim.param_groups[0]['lr']
                if config.wandb_set:
                    logger.log({'loss': loss.item(),
                                'lr': lr,
                                'epoch': epoch})
                elif batch_idx % 100 == 0:
                    logger.info({'loss': loss.item(),
                                 'lr': lr,
                                 'batch': batch_idx,
                                 'epoch': epoch})

        if ds_eval is not None and config.max_eval_batches > 0:
            closs, pred, acc, f1, logloss_v = evaluate_model(
                model, ds_eval, max_batches=config.max_eval_batches, shuffle=False, config=config)

            if logloss_v < best_eval_score:
                best_eval_score = logloss_v
                if do_save_model:
                    save_model(name, model, config.model_type)

            logger.info(
                {
                    'eval_f1': f1,
                    'eval_acc': acc,
                    'eval_logloss': logloss_v,
                    'best_eval_logloss': best_eval_score,
                    'eval_loss': closs,  # mean BCE over the eval set
                    'epoch': epoch,
                }
            )

    return model


# N-fold models. Can be used to estimate accurate CV score and in ensembled submissions.
if CFG.train:
    for fold in CFG.folds:
        # NOTE(review): `name` is not used; checkpoints go to OUTPUT_DIR/model-f{fold}.
        name = f'{wandb_run_name}-f{fold}'

        LOGGER.info(f"========== fold: {fold} training ==========")
        gc.collect()
        ds_train = AIOrNotDataset(
            df_train[df_train['kfold'] != fold].reset_index(drop=True),
            CFG.TRAIN_IMAGES_PATH,
            get_transforms(aug=CFG.AUG),
        )
        ds_eval = AIOrNotDataset(
            df_train[df_train['kfold'] == fold].reset_index(drop=True),
            CFG.TRAIN_IMAGES_PATH,
            get_transforms(aug=False),
        )
        train_model(ds_train, ds_eval, LOGGER, f'{OUTPUT_DIR}model-f{fold}')
        torch.cuda.empty_cache()
        gc.collect()

{'loss': 0.6692510209977627, 'lr': 4.000068467582758e-07, 'batch': 0, 'epoch': 0}
{'loss': 0.711988378316164, 'lr': 4.696745815798406e-07, 'batch': 100, 'epoch': 0}
{'loss': 0.41743691358715296, 'lr': 6.739693142648165e-07, 'batch': 200, 'epoch': 0}
{'loss': 0.5230465875938535, 'lr': 1.0070767330887727e-06, 'batch': 300, 'epoch': 0}
{'loss': 0.1586799727519974, 'lr': 1.4595164645476667e-06, 'batch': 400, 'epoch': 0}
{'loss': 0.10851336817722768, 'lr': 2.0184118886687477e-06, 'batch': 500, 'epoch': 0}
{'loss': 0.07327374203305226, 'lr': 2.6678566128862896e-06, 'batch': 600, 'epoch': 0}
{'loss': 0.09383594582322985, 'lr': 3.3893671744856208e-06, 'batch': 700, 'epoch': 0}
{'loss': 0.10379178395669442, 'lr': 4.1624090875185584e-06, 'batch': 800, 'epoch': 0}
{'loss': 0.041370889739482664, 'lr': 4.964981262679721e-06, 'batch': 900, 'epoch': 0}
{'loss': 0.08407772698410554, 'lr': 5.774242167314729e-06, 'batch': 1000, 'epoch': 0}
{'loss': 0.0035717234895855654, 'lr': 6.567159904790218e-06, 'ba

In [15]:
def evaluate_model(model: AIdentifierModel, ds, max_batches=CFG.predict_max_batches, shuffle=False, config=CFG):
    """Run ``model`` over ``ds`` and score it.

    Args:
        model: classifier whose ``forward`` returns one logit per image.
        ds: dataset yielding dicts with ``image`` and ``label``.
        max_batches: evaluate at most this many batches.
        shuffle: shuffle the evaluation DataLoader.
        config: unused; kept for call-site compatibility.

    Returns:
        ``(mean BCE loss, predicted probabilities, accuracy, macro F1, log loss)``.

    Raises:
        ValueError: if no batch was evaluated (empty dataset or max_batches <= 0).
    """
    torch.manual_seed(42)
    model = model.to(DEVICE)
    dl_test = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=shuffle, num_workers=CFG.num_workers, pin_memory=False)
    probs, losses, targets = [], [], []

    model.eval()
    with torch.no_grad(), tqdm(dl_test, desc='Eval', mininterval=30) as progress:
        for i, X in enumerate(progress):
            # Check first so exactly max_batches batches are evaluated.
            if i >= max_batches:
                break
            y_true = X['label']
            with autocast(enabled=True):
                logits = model.forward(X['image'].to(DEVICE)).reshape(-1)
            # BCE-with-logits must see logits (not sigmoid outputs); the
            # metrics see probabilities. Both are computed in float32.
            logits = logits.float()
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits,
                y_true.to(DEVICE).float(),
            )
            losses.append(loss.item())
            probs.append(torch.sigmoid(logits).cpu())
            targets.append(y_true.cpu().numpy())

    if not losses:
        raise ValueError('evaluate_model: no batches were evaluated')
    targets = np.concatenate(targets)
    pred = torch.nan_to_num(torch.cat(probs)).numpy()
    metric_dict = optimal_metric(targets, pred)  # {"accuracy", "f1", "log_loss"}
    return np.mean(losses), pred, metric_dict['accuracy'], metric_dict['f1'], metric_dict['log_loss']

In [16]:
### Generate OOF data and score
def gen_predictions(models, df_train):
    """Predict each fold's held-out rows with that fold's model (out-of-fold predictions).

    Args:
        models: list indexed by fold number; ``None`` entries are skipped.
        df_train: training frame with a ``kfold`` column.

    Returns:
        The evaluated folds' rows, concatenated, with an added
        ``aiornot_pred_proba`` column.

    Raises:
        ValueError: if a fold's predictions were truncated, or no model was given.
    """
    df_train_predictions = []
    with tqdm(enumerate(models), total=len(models), desc='Folds') as progress:
        for fold, model in progress:
            if model is None:
                continue
            df_fold = df_train[df_train['kfold'] == fold].reset_index(drop=True)
            ds_eval = AIOrNotDataset(df_fold, CFG.TRAIN_IMAGES_PATH, get_transforms(aug=False))
            closs, pred, acc, f1, logloss = evaluate_model(
                model, ds_eval, max_batches=CFG.max_eval_batches, shuffle=False, config=CFG)

            # evaluate_model stops after max_batches; a short prediction vector
            # would otherwise be silently NaN-padded when attached to the rows.
            if len(pred) != len(df_fold):
                raise ValueError(f'fold {fold}: got {len(pred)} predictions for {len(df_fold)} rows; '
                                 'increase CFG.max_eval_batches')

            progress.set_description(f'Eval fold:{fold} pF1:{f1:.02f}')
            df_train_predictions.append(df_fold.assign(aiornot_pred_proba=pred))

    if not df_train_predictions:
        raise ValueError('gen_predictions: no models were provided')
    return pd.concat(df_train_predictions)



if CFG.oof:
    # One model per fold, loaded into a freshly built architecture.
    models = [
        load_model(ckpt_path, '', AIdentifierModel(CFG.model_type))[0]
        for ckpt_path in CFG.models
    ]
    df_pred = gen_predictions(models, df_train)
    df_pred.to_csv(f'{OUTPUT_DIR}oof_df.csv', index=False)
    display(df_pred)
    # NOTE: f1_score here uses sklearn's default binary average, unlike the
    # macro F1 reported per fold.
    print('logloss:', log_loss(df_pred['label'], df_pred['aiornot_pred_proba']))
    print('f1_score:', f1_score(df_pred['label'], df_pred['aiornot_pred_proba'] > 0.5))
    print('accuracy:', accuracy_score(df_pred['label'], df_pred['aiornot_pred_proba'] > 0.5))

Eval: 100%|██████████| 466/466 [01:05<00:00,  7.11it/s]
Eval: 100%|██████████| 466/466 [01:06<00:00,  7.05it/s], 65.85s/it]
Eval: 100%|██████████| 466/466 [01:06<00:00,  7.05it/s], 66.18s/it]
Eval: 100%|██████████| 466/466 [01:06<00:00,  7.03it/s], 66.27s/it]
Eval: 100%|██████████| 466/466 [01:07<00:00,  6.88it/s], 66.39s/it]
Eval fold:4 pF1:1.00: 100%|██████████| 5/5 [05:33<00:00, 66.65s/it]


Unnamed: 0,id,label,kfold,aiornot_pred_proba
0,2.jpg,1,0,1.000000e+00
1,5.jpg,1,0,1.000000e+00
2,8.jpg,1,0,1.000000e+00
3,9.jpg,1,0,1.000000e+00
4,16.jpg,1,0,1.000000e+00
...,...,...,...,...
3718,18579.jpg,1,4,1.000000e+00
3719,18581.jpg,0,4,9.536743e-07
3720,18583.jpg,1,4,1.000000e+00
3721,18601.jpg,1,4,1.000000e+00


logloss: 0.0153137135536813
f1_score: 0.9957960860111138
accuracy: 0.9953271028037384


### Inference

In [17]:
# Test ids are taken from the sample submission file (CFG.TEST_CSV).
df_test = pd.read_csv(filepath_or_buffer=CFG.TEST_CSV)

In [18]:
import torchvision
from PIL import Image

def get_transforms(aug=False):
    """Build the image preprocessing callable used by ``AIOrNotDataset``.

    Must match the training-time preprocessing so saved checkpoints see the
    same input distribution.

    Args:
        aug: if True, apply random horizontal flip, a +/-5 degree rotation and a
            random resized crop to ``CFG.img_size``; otherwise only resize to
            ``CFG.img_size``.

    Returns:
        A function mapping a PIL image to an ImageNet-normalized float tensor.
    """
    if aug:
        geometric = [
            torchvision.transforms.RandomHorizontalFlip(0.5),
            torchvision.transforms.RandomRotation(degrees=(-5, 5)),
            # NOTE(review): ratio=(0.45, 0.55) selects crops about half as wide as
            # they are tall, which are then resized to CFG.img_size — confirm intended.
            torchvision.transforms.RandomResizedCrop(CFG.img_size, scale=(0.8, 1), ratio=(0.45, 0.55)),
        ]
    else:
        geometric = [
            torchvision.transforms.Resize(CFG.img_size),
        ]

    # Build the pipeline once instead of once per image; random transforms
    # still sample fresh parameters on every call.
    pipeline = torchvision.transforms.Compose(geometric + [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    def transforms(img):
        return pipeline(img.convert('RGB'))

    return transforms

In [19]:
class AIOrNotDataset(torch.utils.data.Dataset):
    """Inference dataset: yields the (optionally transformed) image for each row of ``df``.

    Unlike the training variant it returns no ``label``. ``df`` needs an ``id``
    column and an index addressable as 0..n-1 via ``df.loc[idx, ...]``.
    """

    def __init__(self, df, path, transforms=None):
        super().__init__()
        self.df, self.path, self.transforms = df, path, transforms

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        img_path = f'{self.path}{self.df.loc[idx, "id"]}'
        image = Image.open(img_path)
        return {
            "img_path": img_path,
            "image": image if self.transforms is None else self.transforms(image),
        }

In [20]:
class AIdentifierModel(torch.nn.Module):
    """Binary "AI-generated or not" classifier: a timm backbone plus a linear head.

    ``forward`` returns one logit per image; ``predict`` returns its sigmoid.

    NOTE(review): ``create_model(..., num_classes=1)`` already ends in a
    1-output classifier, so ``backbone_dim`` is 1 and ``fc`` is a Linear(1, 1)
    stacked on that logit. Kept unchanged so existing checkpoints still load.
    NOTE(review): presumably relies on the backbone's own forward not using an
    attribute named ``fc``; for backbones whose classifier *is* ``fc`` this
    assignment would replace it — confirm before switching model types.
    """

    def __init__(self, model_type):
        super().__init__()
        self.model = create_model(model_type, pretrained=True, num_classes=1)
        # Probe the output width with a dummy batch in eval mode and without
        # autograd: no graph is kept in memory and any BatchNorm running
        # statistics are not updated with random-noise activations. The
        # original train/eval mode is restored afterwards.
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            self.backbone_dim = self.model(torch.randn(1, 3, 512, 512)).shape[-1]
        self.model.train(was_training)

        self.model.fc = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_dim, 1),
        )

    def forward(self, x):
        """Return raw logits of shape ``(batch, 1)`` for an image batch ``x``."""
        x = self.model(x)
        x = self.model.fc(x)
        return x

    def predict(self, x):
        """Return sigmoid probabilities, same shape as ``forward``."""
        x = self.forward(x)
        return torch.sigmoid(x)
    

    

In [21]:
# Sorted so the ensemble member order is deterministic across runs/filesystems.
MODELS_PATH = sorted(glob.glob(os.path.join(OUTPUT_DIR, 'model*')))
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 64  # NOTE(review): not used below; models_predict sets its own batch size.


def load_model(model_path, dir='', model=None):
    """Load a checkpoint written by ``save_model``.

    Also accepts the ``(name, dir, model)`` call form used earlier in this
    notebook, so redefining it here does not break that cell on re-run.

    Args:
        model_path: checkpoint path (joined onto ``dir``).
        dir: optional directory prefix; ignored when ``model_path`` is absolute.
        model: optional pre-built module to load weights into; when None an
            ``AIdentifierModel`` is rebuilt from the stored ``model_type``.

    Returns:
        ``(model, model_type)``.
    """
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(os.path.join(dir, model_path), map_location=DEVICE)
    if model is None:
        model = AIdentifierModel(state_dict['model_type'])
    model.load_state_dict(state_dict['model'])
    return model, state_dict['model_type']

In [22]:
# Load every fold checkpoint onto DEVICE.
models = []
for fname in tqdm(MODELS_PATH):
    model, model_type = load_model(fname)
    models.append(model.to(DEVICE))

100%|██████████| 5/5 [00:14<00:00,  2.82s/it]


In [25]:
def models_predict(models, ds, max_batches=1e9):
    """Average the ``predict`` probabilities of all ``models`` over ``ds``.

    Args:
        models: non-empty list of ensemble members exposing ``predict``.
        ds: dataset yielding dicts with an ``image`` tensor.
        max_batches: process at most this many batches.

    Returns:
        1-D float32 numpy array, one averaged probability per processed sample.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError('models_predict: need at least one model')
    dl_test = torch.utils.data.DataLoader(ds, batch_size=24, shuffle=False, num_workers=os.cpu_count())
    for m in models:
        m.eval()

    predictions = []
    with torch.no_grad():
        # Distinct loop names: reusing `idx` for the model loop used to clobber
        # the batch counter that the max_batches check relies on.
        for batch_idx, X in enumerate(tqdm(dl_test, mininterval=30)):
            if batch_idx >= max_batches:
                break
            images = X['image'].to(DEVICE)  # copied to the device once per batch
            # reshape(-1) keeps a 1-D vector even for a final batch of size 1.
            per_model = torch.stack(
                [m.predict(images).reshape(-1).float().cpu() for m in models], dim=1)
            predictions.append(per_model.mean(dim=-1))

    if not predictions:
        return np.empty(0, dtype=np.float32)
    return torch.cat(predictions).numpy()

In [None]:
ds_test = AIOrNotDataset(df=df_test, path=CFG.TEST_IMAGES_PATH, transforms=get_transforms(aug=False))
models_pred = models_predict(models=models, ds=ds_test)

In [None]:
print(df_test.shape, models_pred.shape)
assert len(models_pred) == len(df_test), 'prediction count does not match the test set'
df_test['label'] = models_pred
# Locale-independent, filesystem-safe timestamp (no spaces or colons, unlike %X).
timestamp = pd.Timestamp.now().strftime('%Y-%m-%d_%H-%M-%S')
df_test.to_csv(f'{OUTPUT_DIR}submission_{timestamp}.csv', index=False)

(43442, 2) (43442,)
