In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import albumentations
from transformers import *
import geffnet
import cv2
from tqdm import tqdm

In [3]:
from utils import *

In [4]:
train = pd.read_csv(PATH / 'train.csv')
# label_group -> unique posting_ids in that group. A posting's ground-truth
# match set is its whole group, which always includes the posting itself.
tmp = (
    train
    .groupby('label_group')['posting_id']
    .agg('unique')
    .to_dict()
)
train['target'] = train['label_group'].map(tmp)
print('train shape is', train.shape)
train.head()

train shape is (34250, 6)


Unnamed: 0,posting_id,image,image_phash,title,label_group,target
0,train_129225211,0000a68812bc7e98c42888dfb1c07da0.jpg,94974f937d4c2433,Paper Bag Victoria Secret,249114794,"[train_129225211, train_2278313361]"
1,train_3386243561,00039780dfc94d01db8676fe789ecd05.jpg,af3f9460c2838f0f,"Double Tape 3M VHB 12 mm x 4,5 m ORIGINAL / DOUBLE FOAM TAPE",2937985045,"[train_3386243561, train_3423213080]"
2,train_2288590299,000a190fdd715a2a36faed16e2c65df7.jpg,b94cb00ed3e50f78,Maling TTS Canned Pork Luncheon Meat 397 gr,2395904891,"[train_2288590299, train_3803689425]"
3,train_2406599165,00117e4fc239b1b641ff08340b429633.jpg,8514fc58eafea283,Daster Batik Lengan pendek - Motif Acak / Campur - Leher Kancing (DPT001-00) Batik karakter Alhadi,4093212188,"[train_2406599165, train_3342059966]"
4,train_3369186413,00136d1cf4edede0203f32f05f660588.jpg,a6f319f924ad708c,Nescafe \xc3\x89clair Latte 220ml,3648931069,"[train_3369186413, train_921438619]"


In [5]:

def f1score(row, col):
    """Per-row F1 between the ground-truth matches and a prediction column.

    Parameters
    ----------
    row : pandas.Series
        One row; must expose ``row.target`` (true matching posting_ids) and
        ``row[col]`` (predicted posting_ids).
    col : str
        Name of the column that holds the predictions.

    Returns
    -------
    float
        ``2 * n / (len(target) + len(pred))``, where ``n`` is the number of
        unique values present in both. Raises ZeroDivisionError when both
        collections are empty.
    """
    truth = row.target
    guess = row[col]
    overlap = np.intersect1d(truth, guess).size
    return 2 * overlap / (len(truth) + len(guess))


## EfficientNet-B0 + BERT model

In [6]:
def get_transforms(img_size=256):
    """Build the inference-time image pipeline.

    Resizes to a square ``img_size`` x ``img_size`` image, then applies
    albumentations' ``Normalize`` with its default arguments.
    """
    pipeline = [
        albumentations.Resize(img_size, img_size),
        albumentations.Normalize(),
    ]
    return albumentations.Compose(pipeline)

class ImageTextDataset(Dataset):
    """Dataset yielding ``(image, input_ids, attention_mask)`` per posting.

    Each row of ``csv`` must provide ``filepath`` (image on disk) and
    ``title`` (text passed to the tokenizer).

    Parameters
    ----------
    csv : pandas.DataFrame
        Source rows; ``reset_index()`` is applied so rows are addressed
        positionally.
    transforms : callable, optional
        albumentations-style pipeline, called as ``transforms(image=...)`` and
        returning a dict with key ``'image'``. Defaults to
        ``get_transforms(img_size=256)``, built per instance.
    tokenizer : callable
        Hugging Face-style tokenizer; must be set before items are fetched.
    max_length : int, default 16
        Token length every title is padded / truncated to.
    """

    def __init__(self, csv, transforms=None, tokenizer=None, max_length=16):
        self.csv = csv.reset_index()
        # A default argument is evaluated only once, when the class is defined,
        # so build the default pipeline here, per instance.
        self.transform = transforms if transforms is not None else get_transforms(img_size=256)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.csv.shape[0]

    def __getitem__(self, index):
        row = self.csv.iloc[index]

        image = cv2.imread(row.filepath)
        # cv2.imread returns None (it does not raise) for a missing or
        # unreadable file; fail loudly with the offending path.
        if image is None:
            raise FileNotFoundError(f"could not read image: {row.filepath}")
        # Swap channel order BGR -> RGB (same result as image[:, :, ::-1],
        # but without a negative-stride view).
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = self.transform(image=image)['image'].astype(np.float32)
        image = image.transpose(2, 0, 1)  # HWC -> CHW

        encoded = self.tokenizer(
            row.title,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" on a single string adds a leading dim of 1;
        # drop it so the DataLoader can batch samples itself.
        input_ids = encoded['input_ids'][0]
        attention_mask = encoded['attention_mask'][0]

        return torch.tensor(image), input_ids, attention_mask
        

In [7]:
# Same local bert-base-uncased directory that enet_arcface_FINAL loads its BERT from.
tokenizer = AutoTokenizer.from_pretrained(PATH / 'bert-base-uncased')

In [8]:
# Score the model on the full train set, where the true match groups are known.
# NOTE(review): the checkpoint name says fold0; if it was trained on some of
# these rows, the CV score computed below is optimistic - confirm the split.
df_test = train.copy()
df_test['filepath'] = df_test['image'].map(lambda name: str(PATH / 'train_images' / name))
dataset_test = ImageTextDataset(
    df_test,
    transforms=get_transforms(img_size=256),
    tokenizer=tokenizer,
)
# Shuffling stays off (DataLoader default): later cells map embedding rows
# back to df_test by position.
test_loader = DataLoader(dataset_test, batch_size=64, num_workers=16)

print(len(dataset_test), dataset_test[0])

34250 (tensor([[[ 0.6563,  0.4508,  0.4679,  ...,  0.3309,  0.3652,  0.4508],
         [ 0.0741,  0.0227,  0.0227,  ...,  0.3652,  0.3994,  0.3994],
         [ 0.4337,  0.5193,  0.5022,  ..., -0.1657,  0.1083,  0.3309],
         ...,
         [ 1.1529,  0.6049,  0.7419,  ...,  1.1700,  1.1872,  1.1872],
         [ 0.7077,  0.8104,  0.7762,  ...,  1.1187,  1.3242,  1.1187],
         [ 0.5193,  0.8789,  0.8104,  ...,  1.0844,  1.3242,  1.1015]],

        [[-0.4251, -0.6352, -0.6176,  ..., -0.3901, -0.5126, -0.5476],
         [-1.0203, -1.0728, -1.0728,  ..., -0.2500, -0.3725, -0.4951],
         [-0.6527, -0.5651, -0.5826,  ..., -0.4776, -0.3725, -0.2850],
         ...,
         [ 0.2577, -0.4951, -0.2850,  ..., -0.0399,  0.3102,  0.1176],
         [-0.3025, -0.2500, -0.1625,  ..., -0.1450,  0.3452,  0.0301],
         [-0.4601, -0.0924, -0.1275,  ..., -0.1975,  0.2752,  0.0126]],

        [[-0.0267, -0.2358, -0.2184,  ...,  0.0431, -0.0615, -0.1138],
         [-0.6193, -0.6715, -0.6715,  

In [9]:
class ArcMarginProduct_subcenter(nn.Module):
    """Sub-center cosine head: per-class max cosine over ``k`` sub-centers.

    Holds ``out_features * k`` weight vectors. The score for class ``c`` is
    the largest cosine similarity between the L2-normalised input and any of
    that class's ``k`` L2-normalised sub-center vectors.

    Parameters
    ----------
    in_features : int
        Size of the incoming feature vector.
    out_features : int
        Number of classes.
    k : int, default 3
        Sub-centers per class.
    """

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features*k, in_features))
        self.k = k
        self.out_features = out_features
        # torch.FloatTensor(rows, cols) allocates *uninitialised* memory, which
        # can hold NaN/inf; give the weights a defined starting point.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weights uniformly in [-1/sqrt(in_features), 1/sqrt(in_features)]."""
        stdv = 1.0 / (self.weight.size(1) ** 0.5)
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)

    def forward(self, features):
        """Return per-class cosine scores; ``(batch, out_features)`` for 2-D input."""
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        # Row i*k + j of ``weight`` is sub-center j of class i.
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine
    
# Module-level sigmoid kept for any external users; Swish below uses torch.sigmoid.
sigmoid = torch.nn.Sigmoid()

class Swish(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward.

    Only the input is saved for backward; the sigmoid is recomputed there.
    """

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # ``saved_variables`` is a deprecated alias; ``saved_tensors`` is the
        # supported accessor for what save_for_backward stored.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class Swish_module(nn.Module):
    """nn.Module wrapper so the Swish autograd function can be used as a layer."""

    def forward(self, x):
        return Swish.apply(x)

    
 
    
class enet_arcface_FINAL(nn.Module):
    """EfficientNet image tower + BERT text tower fused into a 512-d embedding.

    ``forward`` returns ``(embedding, logits)``:
    - ``embedding`` is the L2-normalised 512-d fused feature;
    - ``logits`` are the sub-center cosine scores from ArcMarginProduct_subcenter.

    Parameters
    ----------
    enet_type : str
        geffnet model name; dashes are converted to underscores.
    out_dim : int
        Number of classes for the cosine head.
    """

    def __init__(self, enet_type, out_dim):
        super().__init__()
        self.bert = AutoModel.from_pretrained(PATH/'bert-base-uncased')
        self.enet = geffnet.create_model(enet_type.replace('-', '_'), pretrained=None)
        # Width of [image features, text features]; read the classifier's
        # in_features before that classifier is replaced below.
        fused_dim = self.enet.classifier.in_features + self.bert.config.hidden_size
        self.feat = nn.Linear(fused_dim, 512)
        self.swish = Swish_module()
        # NOTE(review): registered but never applied in forward().
        self.dropout = nn.Dropout(0.5)
        self.metric_classify = ArcMarginProduct_subcenter(512, out_dim)
        # Replace the classifier with Identity so self.enet(x) yields the
        # features that self.feat was sized for.
        self.enet.classifier = nn.Identity()

    def forward(self, x, input_ids, attention_mask):
        image_feat = self.enet(x)
        # Element [1] of the BERT output is used as the text feature
        # (presumably the pooled output - confirm against the transformers version).
        text_feat = self.bert(input_ids=input_ids, attention_mask=attention_mask)[1]
        fused = self.swish(self.feat(torch.cat([image_feat, text_feat], 1)))
        return F.normalize(fused), self.metric_classify(fused)
    
def load_model(model, model_file, map_location="cpu", strict=True):
    """Load a checkpoint into ``model`` and switch it to eval mode.

    Accepts either a bare state dict or a dict that wraps one under
    ``"model_state_dict"``. A leading ``"module."`` on a key (typically added
    when saving from an nn.DataParallel / DDP wrapper) is stripped.

    Parameters
    ----------
    model : torch.nn.Module
        Model to load into; modified in place and also returned.
    model_file : str or path-like
        Checkpoint path passed to ``torch.load``.
    map_location : optional
        Forwarded to ``torch.load``. Defaults to ``"cpu"`` so the checkpoint is
        not first materialised on the device it was saved from;
        ``load_state_dict`` then copies the tensors onto the model's device.
    strict : bool, default True
        Forwarded to ``load_state_dict``.

    Returns
    -------
    torch.nn.Module
        The same ``model``, in eval mode.
    """
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(model_file, map_location=map_location)
    if "model_state_dict" in state_dict:
        state_dict = state_dict["model_state_dict"]
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict, strict=strict)
    print(f"loaded {model_file}")
    model.eval()
    return model

In [10]:
# Trained weights; the file name presumably encodes backbone (b0_ns), image size
# (256), fold (0) and epoch (27) - confirm against the training run.
WGT = PATH / 'b0ns_256_bert_20ep_fold0_epoch27.pth'

In [11]:
# out_dim must match the checkpoint's metric_classify head (loaded with strict=True).
# NOTE(review): 11014 is presumably the number of distinct label_groups in train.csv - confirm.
model = load_model(
    enet_arcface_FINAL('tf_efficientnet_b0_ns', out_dim=11014).cuda(),
    WGT,
)


loaded /home/slex/data/shopee/b0ns_256_bert_20ep_fold0_epoch27.pth


In [12]:
%%time
embeds = []

# Batched inference. The fused image+text embeddings (already L2-normalised by
# the model) are kept on the GPU in fp16, in DataLoader (= df_test) order.
with torch.no_grad():
    for batch in tqdm(test_loader):
        img, input_ids, attention_mask = (t.cuda() for t in batch)
        feat, _ = model(img, input_ids, attention_mask)
        image_embeddings = feat.half()
        embeds.append(image_embeddings)

100%|██████████| 536/536 [00:34<00:00, 15.54it/s]

CPU times: user 27.5 s, sys: 4.83 s, total: 32.4 s
Wall time: 34.5 s





In [13]:
# Stack the per-batch embeddings into one tensor, one row per df_test posting.
image_embeddings = torch.cat(embeds, dim=0)
print('image embeddings shape', image_embeddings.shape)

image embeddings shape torch.Size([34250, 512])


In [14]:
%%time
preds = [[] for _ in range(len(df_test))]

# Embeddings are L2-normalised, so the dot product is the cosine similarity;
# pairs above this threshold are predicted to be the same product.
SIM_THRESHOLD = 0.5
CHUNK = 10000
posting_ids = df_test.posting_id.values  # loop-invariant: hoisted out of the chunk loop
for start in range(0, len(df_test), CHUNK):
    cos_sim = image_embeddings[start:start+CHUNK] @ image_embeddings.T
    idxa, idxb = torch.where(cos_sim > SIM_THRESHOLD)
    # Bring indices to the host once per chunk. Iterating the CUDA tensor
    # directly would force a device sync for every single match.
    idxa = idxa.cpu().tolist()
    dfb = posting_ids[idxb.cpu().numpy()]
    for a, b in zip(idxa, dfb):
        preds[a + start].append(b)

CPU times: user 3.31 s, sys: 20.1 ms, total: 3.33 s
Wall time: 3.33 s


In [15]:
df_test['preds_b0bert'] = preds

# Mean per-posting F1 of the thresholded matches against the label-group targets.
df_test['b0bert_score'] = df_test.apply(lambda row: f1score(row, col='preds_b0bert'), axis=1)
print('CV score for baseline =', df_test.b0bert_score.mean())

CV score for baseline = 0.9088021070935636


In [27]:
# Space-separated submission strings. Built from the stored list column rather
# than by rebinding `preds` from itself, so re-running this cell is idempotent
# (a second pass over already-joined strings would space out every character).
preds = [' '.join(p) for p in df_test['preds_b0bert']]
preds[:5]

['train_129225211 train_2278313361',
 'train_3386243561 train_3423213080',
 'train_2288590299 train_3803689425',
 'train_2406599165',
 'train_3369186413 train_921438619']

In [26]:
# After the join cell above, each entry of `preds` is already a string;
# ' '.join on it would insert a space between every character.
preds[0]

'train_129225211 train_2278313361'