In [2]:
import os
import time
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import pandas as pd
import numpy as np
from collections import OrderedDict

from model import ISLRModelV6, ISLRModelArcFaceCE
from dataset import ISLRDataSetV2, collate_func

# Restrict this kernel to physical GPU 2. PCI_BUS_ID ordering numbers devices
# by PCI bus ID (the order nvidia-smi shows). These variables only take effect
# if no CUDA call has happened yet in this kernel.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="2",
)

In [31]:
# Fold whose validation split is scored in this notebook.
cur_fold = 4
val_idx_path = os.path.join('/sources/dataset', "cv", f"val_idx_f{cur_fold}.npy")
val_idx = np.load(val_idx_path)

# All per-sample transform switches are turned off for evaluation
# (they look like training-time augmentations — confirm in dataset.py).
transform_flags = dict.fromkeys(
    ("random_noise", "flip_x", "flip_x_v2", "rotate", "drop_lm", "interpolate"),
    False,
)
# `indicies` (sic) is the keyword name ISLRDataSetV2 itself expects.
dataset = ISLRDataSetV2(max_len=64, ver='v0_93', indicies=val_idx, **transform_flags)

# shuffle=False keeps the collected predictions in dataset order.
dataloader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=collate_func,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=False,
)


In [32]:
# Rebuild the ArcFace+CE model. These hyper-parameters are assumed to match the
# training run of the checkpoint below; load_state_dict (strict by default)
# raises if the architecture disagrees.
model = ISLRModelArcFaceCE(
                embed_dim=256,
                n_head=4,
                ff_dim=256,
                dropout=0.2,
                cls_dropout=0.2,
                max_len=64,
                n_layers=5,
                input_dim=1194,
                s=32.0,
                m=0.2,
                k=3)


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    Checkpoints saved through an ``nn.DataParallel`` wrapper carry a ``module.``
    prefix on every parameter name. Only the *leading* occurrence is removed, so
    keys whose inner segments merely contain ``module.`` (e.g. ``submodule.weight``)
    are preserved instead of being mangled.

    Args:
        state_dict: mapping of parameter name -> tensor.
        prefix: key prefix to remove when present at the start of a key.

    Returns:
        OrderedDict with the same values and de-prefixed keys, in original order.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


CKPT_PATH = f'/sources/ckpts/test-arcface-ce-sd-ab1002-m02/2023-04-28T00-44-27/{cur_fold}/best.pth.tar'

# Load onto the CPU first. The checkpoint may reference a CUDA device index that
# is not visible under the CUDA_VISIBLE_DEVICES mask. load_state_dict copies the
# values into the model's own parameters either way.
trained_state_dict = torch.load(CKPT_PATH, map_location="cpu")["state_dict"]
new_state_dict = strip_module_prefix(trained_state_dict)
model.load_state_dict(new_state_dict)
model.eval()

# DataParallel also scatters CPU input tensors onto the GPU in forward(). The
# inference loop passes `batch` to the model without moving it, so it
# presumably relies on this.
try:
    model = nn.DataParallel(model).cuda()
except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
    # Fallback: plain .cuda() model. Note that this path does not move input
    # batches to the GPU.
    model = model.cuda()
cudnn.benchmark = True


In [33]:
# Weights for blending the model's two outputs (output[0], output[1]).
# These are presumably the ArcFace and CE heads — confirm in model.py and
# record where these weights were tuned.
HEAD0_WEIGHT = 0.84
HEAD1_WEIGHT = 0.16

model.eval()
preds = []   # one softmax probability row per validation sample
labels = []  # ground-truth labels, in the same order as preds

with torch.no_grad():
    # Iterating the DataLoader directly lets tqdm read len(dataloader) and show
    # a total/ETA; wrapping it in enumerate() hid that.
    for batch in tqdm(dataloader):
        # Read the label before the forward pass, as before, in case the model
        # mutates the batch dict.
        y = batch['label']
        outputs = model(batch)
        logit = HEAD0_WEIGHT * outputs[0] + HEAD1_WEIGHT * outputs[1]
        probs = F.softmax(logit, dim=-1)

        # Autograd is already off under no_grad(), so no .detach() is needed.
        preds.extend(probs.tolist())
        labels.extend(y.tolist())

280it [01:16,  3.64it/s]


In [34]:
# Sanity check before saving: every prediction row must have a matching label.
assert len(preds) == len(labels), (len(preds), len(labels))
len(preds)

17884

In [35]:
# Save this fold's predicted probabilities (X) and labels (y), presumably as
# inputs to a later stacking model.
STACK_DIR = '/sources/dataset/stack/arcface'
os.makedirs(STACK_DIR, exist_ok=True)  # np.save does not create missing directories
# The explicit ".npy" suffix yields the same file name np.save would have appended.
np.save(os.path.join(STACK_DIR, f'X_fold{cur_fold}.npy'), preds)
np.save(os.path.join(STACK_DIR, f'y_fold{cur_fold}.npy'), labels)

In [36]:
# Number of collected labels (first dimension of the label list).
np.shape(labels)[0]

17884

In [38]:
# Length of one probability row, i.e. the size of the logit's last dimension.
np.shape(preds[0])[0]

250