In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from fastai.vision.all import *
import sklearn.metrics as skm
from tqdm.notebook import tqdm
import sklearn.feature_extraction.text
from transformers import (BertTokenizer, BertModel,
                          AutoConfig, AutoModel)
import gc
import codecs
import sklearn.feature_extraction.text

In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

In [None]:
from shopee_utils import *

In [None]:
# NOTE(review): only referenced by the commented-out ArcFaceClassifier in ResnetArcFace.
OUTPUT_CLASSES=11014
# Used by find_threshold: the rerank threshold is chosen so that about this many
# entries per row lie at or above it.
RECIPROCAL_PER_ROW = 4.2
# NOTE(review): not referenced anywhere in this notebook — presumably a leftover tuning value.
CLUSTER_BREAK = .29

In [None]:
# True: also score on the validation split and replace test_df with a train-derived
# fake test set (see the `if TRIAL_RUN:` cells). False: only run on the real test set.
TRIAL_RUN=False

In [None]:
def eca_nfnet_l0(pretrained):
    """Build timm's ``eca_nfnet_l0``; ``pretrained`` is forwarded to ``timm.create_model``."""
    # Keyword name must stay `pretrained`: fastai's create_body calls arch(pretrained=...).
    arch_name = "eca_nfnet_l0"
    return timm.create_model(arch_name, pretrained=pretrained)


def eca_nfnet_l1(pretrained):
    """Build timm's ``eca_nfnet_l1``; ``pretrained`` is forwarded to ``timm.create_model``."""
    arch_name = "eca_nfnet_l1"
    return timm.create_model(arch_name, pretrained=pretrained)

In [None]:
# Checkpoint file names. image_model_files and img_model_archs are paired by position
# (zip in generate_image_embs).
# NOTE(review): both image checkpoints are named "nfnetl0_*" but the first is paired with
# eca_nfnet_l1 — confirm which architecture 'nfnetl0_336_noval525.pth' was trained with.
text_model_name='bert_large_novalid.pth'
img_model_names=['nfnetl0_336_noval525.pth', 'nfnetl0_336_noval422.pth']
img_model_archs = [eca_nfnet_l1, eca_nfnet_l0]
img_models_dir=Path('../input/shopee-image-models')

# Kaggle input locations: competition data, text checkpoints, and the local BERT
# config/tokenizer directory (read via from_pretrained).
PATH = Path('../input/shopee-product-matching')
text_models_path = Path('../input/shopee-models')
BERT_PATH = '../input/bertlarge-config'


image_model_files = [img_models_dir/model_name for model_name in img_model_names]
text_model_file = text_models_path / text_model_name

In [None]:

# Training metadata. add_splits comes from shopee_utils; it presumably adds the boolean
# `is_valid` column that ColSplitter() reads by default — confirm.
train_df = pd.read_csv(PATH/'train.csv')
train_df = add_splits(train_df)

In [None]:
def embs_from_models(models, dl):
    """Run every model over ``dl`` and collect L2-normalized embeddings.

    Each batch input is converted with ``.half()`` before being fed to the models,
    so they are expected to be half precision (as ``load_image_model`` produces).

    Returns ``(per_model_embs, ys)``:
      * ``per_model_embs`` — list with one concatenated CPU tensor per model, in
        the same order as ``models``;
      * ``ys`` — concatenated targets. Batches without a target (e.g. from a
        ``test_dl``) contribute a single ``0`` each.
    """
    per_model = [[] for _ in range(len(models))]
    targets = []
    for batch in tqdm(dl):
        if len(batch) != 2:
            (xb,) = batch
            yb = torch.zeros(1)
        else:
            xb, yb = batch
        with torch.no_grad():
            xb_half = xb.half()
            for bucket, model in zip(per_model, models):
                bucket.append(F.normalize(model(xb_half)).cpu())
        targets.append(yb)
    stacked = [torch.cat(chunks) for chunks in per_model]
    return stacked, torch.cat(targets)

## TEXT

In [None]:
class EmbsModel(nn.Module):
    """Wrap an encoder so calling it returns the first-position hidden state.

    The wrapped model is called with the unpacked input tuple
    (``bert_model(*x)``) and must return an object exposing
    ``last_hidden_state``; index 0 along dim 1 of that tensor is returned
    (presumably the [CLS] token for BERT tokenization — confirm).
    """

    def __init__(self, bert_model):
        super().__init__()
        # The attribute name is also the state_dict prefix — keep it stable.
        self.bert_model = bert_model

    def forward(self, x):
        hidden = self.bert_model(*x).last_hidden_state
        first_token = hidden[:, 0, :]
        return first_token

In [None]:
def string_escape(s, encoding='utf-8'):
    r"""Decode backslash escape sequences that appear literally in ``s``.

    Titles may contain escape text such as ``\xc3\xa9``; this turns it into
    the characters those bytes encode in ``encoding`` (U+00E9 for UTF-8).

    If ``s`` cannot be round-tripped, ``s`` is returned unchanged instead of
    raising. This happens when it contains characters outside Latin-1, a
    malformed escape such as a trailing ``\x``, or unescaped non-ASCII text
    whose bytes are not valid ``encoding``. A single odd title therefore no
    longer aborts tokenization of the whole dataset.

    Args:
        s: text to unescape.
        encoding: encoding of the bytes spelled out by the escapes.

    Returns:
        The unescaped string, or ``s`` itself when decoding fails.
    """
    try:
        return (
            s.encode('latin1')           # str -> bytes; 'unicode-escape' decodes bytes
            .decode('unicode-escape')    # interpret \xNN, \n, ... sequences
            .encode('latin1')            # 1:1 code point -> byte mapping back to bytes
            .decode(encoding)            # decode the recovered bytes in the original encoding
        )
    except UnicodeError:
        # UnicodeEncodeError (non-Latin-1 input) or UnicodeDecodeError (bad escape /
        # invalid byte sequence): keep the original text rather than crash.
        return s
class TitleTransform(Transform):
    """fastai Transform: tokenize a row's ``title`` into fixed-length BERT inputs.

    Produces ``(input_ids, attention_mask, token_type_ids)``, each padded or
    truncated to 100 tokens, with the batch dimension squeezed away.
    """

    def __init__(self):
        super().__init__()
        # Tokenizer files are read from the local BERT_PATH directory.
        self.tokenizer = BertTokenizer.from_pretrained(BERT_PATH)

    def encodes(self, row):
        # `row` is deliberately un-annotated: fastai dispatches `encodes` on annotations.
        cleaned = string_escape(row.title)
        enc = self.tokenizer(
            cleaned,
            padding='max_length',
            max_length=100,
            truncation=True,
            return_tensors='pt',
        )
        wanted = ('input_ids', 'attention_mask', 'token_type_ids')
        return tuple(enc[name].squeeze() for name in wanted)

In [None]:
def get_text_dls():
    """Build fastai DataLoaders (batch size 256) over ``train_df`` titles.

    Inputs are tokenized by ``TitleTransform``. Targets are ``label_group``
    categories whose vocab covers every group in ``train_df``. The split uses
    ``ColSplitter()`` (fastai default column ``is_valid``).
    """
    title_tfm = TitleTransform()
    label_vocab = train_df.label_group.to_list()
    text_block = DataBlock(
        blocks=(TransformBlock(type_tfms=title_tfm),
                CategoryBlock(vocab=label_vocab)),
        splitter=ColSplitter(),
        get_y=ColReader('label_group'),
    )
    return text_block.dataloaders(train_df, bs=256)


In [None]:
def generate_text_embs(dl):
    """Embed the titles in ``dl`` with the fine-tuned BERT text model.

    The architecture is built from the config in ``BERT_PATH`` and the weights
    are loaded from ``text_model_file``. The model is wrapped in ``EmbsModel``
    (first-position hidden state), moved to CUDA in eval mode, and run over ``dl``.

    Returns ``(embs, ys)`` with ``embs`` moved to the CPU.

    NOTE(review): ``embs_from_model`` (singular) is not defined in this notebook;
    presumably it comes from ``shopee_utils``. Its normalization and dtype handling
    are not visible here. Unlike the image models, this model is not ``.half()``-ed
    here — confirm the dtype of the returned ``embs``.
    """
    model = AutoModel.from_config(AutoConfig.from_pretrained(BERT_PATH))
    # torch.load unpickles the checkpoint: only load trusted files.
    state = torch.load(text_model_file)
    model.load_state_dict(state)
    model = EmbsModel(model).cuda().eval()
    embs, ys = embs_from_model(model, dl)
    return embs.cpu(), ys

def generate_text_pairs(dl):
    """Text-only candidate pairs plus ground-truth groups for ``dl``.

    Returns ``(pairs, groups)``:
      * ``pairs`` — the first ``10 * len(embs)`` entries of ``sorted_pairs``
        over the nearest neighbours from ``get_nearest`` (both from shopee_utils;
        ordering semantics not visible here);
      * ``groups`` — for each row, the indices of all rows with the same label.
    """
    embs, ys = generate_text_embs(dl)
    same_label = ys[:, None] == ys[None, :]
    groups = [row.nonzero(as_tuple=True)[0].tolist() for row in same_label]
    dists, inds = get_nearest(embs, do_chunk(embs))
    keep = len(embs) * 10
    return sorted_pairs(dists, inds)[:keep], groups

In [None]:
# pairs, groups = generate_text_pairs(text_dls.valid)
# _=build_from_pairs(pairs, groups, True)

## TF-IDF


In [None]:
def csr_matrix_to_tensor(csr, device='cuda'):
    """Convert a SciPy sparse matrix into a torch sparse COO tensor on ``device``.

    The row and column indices are stacked into one int64 ``(2, nnz)`` array
    before being handed to torch. Building a tensor from a *list* of NumPy
    arrays can be slow and may emit a torch UserWarning. Values keep the
    matrix's dtype.

    Args:
        csr: any SciPy sparse matrix supporting ``.tocoo()``.
        device: target device; defaults to ``'cuda'`` as before.

    Returns:
        A sparse COO tensor with shape ``csr.shape``.
    """
    coo = csr.tocoo()
    indices = np.vstack((coo.row, coo.col)).astype(np.int64)
    return torch.sparse_coo_tensor(indices, coo.data, csr.shape).to(device)

def get_tfid_embs(data, idxs):
    """Fit a binary TF-IDF model on every ``data.title`` and return rows ``idxs``.

    The vocabulary is learned from all of ``data`` (English stop words removed,
    at most 25,000 features). Only the requested rows of the resulting sparse
    matrix are returned. ``idxs`` are positional row indices.
    """
    vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(
        stop_words='english',
        binary=True,
        max_features=25_000,
    )
    all_rows = vectorizer.fit_transform(data.title)
    return all_rows[idxs]

def generate_tfid_D(text_embeddings, out):
    """Write clipped TF-IDF dot products into ``out`` in float16, 100 rows at a time.

    For row block ``[lo, hi)`` this sets
    ``out[lo:hi] = clip(T[lo:hi] @ T.T, 0, 1)``, where ``T`` is
    ``text_embeddings``. The full matrix stays sparse; only the current block is
    densified.
    """
    n_rows = text_embeddings.shape[0]
    all_rows = csr_matrix_to_tensor(text_embeddings)
    block = 100
    for lo in range(0, n_rows, block):
        hi = min(lo + block, n_rows)
        dense_block = csr_matrix_to_tensor(text_embeddings[lo:hi]).to_dense()
        sims = (all_rows @ dense_block.T).clip_(0, 1)   # (n_rows, hi - lo)
        out[lo:hi] = sims.half().T

## IMAGE


In [None]:
class ResnetArcFace(nn.Module):
    """Image embedding model: fastai body -> concat pooling -> flatten -> BatchNorm.

    With ``outputEmbs`` True, ``forward`` returns the ``2 * nf`` pooled
    embeddings. Otherwise it passes them to ``self.classifier``.

    NOTE(review): ``classifier`` is initialized to None, so ``forward`` with
    ``outputEmbs=False`` raises ``TypeError`` unless a classifier is assigned first.
    The attribute names (``body``, ``after_conv``) define the state_dict keys that
    saved checkpoints are loaded against; do not rename them.
    """
    def __init__(self, arch):
        super().__init__()
        # cut=-2 keeps all but the last two top-level children of arch's model
        # (presumably its pooling/classifier head — confirm for nfnet).
        self.body = create_body(arch, cut=-2, pretrained=False)
        # Channel count of the body output.
        nf = num_features_model(nn.Sequential(*self.body.children()))
        self.after_conv=nn.Sequential(
            AdaptiveConcatPool2d(),  # concatenates two poolings -> 2 * nf features
            Flatten(),
            nn.BatchNorm1d(nf*2),
#             nn.Dropout(.25),
#             nn.Linear(nf*2,512),
#             nn.BatchNorm1d(512)
        )
        self.classifier = None#ArcFaceClassifier(512, OUTPUT_CLASSES)
        self.outputEmbs = False
    def forward(self, x):
        """Return embeddings if ``outputEmbs`` is set, else ``classifier(embeddings)``."""
        x = self.body(x)
        embeddings = self.after_conv(x)
        if self.outputEmbs:
            return embeddings
        return self.classifier(embeddings)

In [None]:
def get_img_file(row):
    """Path of ``row.image``: ``train_images`` if the file exists there, else ``test_images``."""
    train_fn = PATH / 'train_images' / row.image
    return train_fn if train_fn.is_file() else PATH / 'test_images' / row.image

In [None]:
def get_image_dls(size, bs):
    """Build image DataLoaders over ``train_df`` with batch size ``bs``.

    Per item: resize to ``2 * size`` with bicubic resampling. Per batch:
    ``aug_transforms`` (output ``size``, ``min_scale=0.75``) followed by
    ImageNet-stats normalization. Targets are ``label_group`` categories and the
    split comes from ``ColSplitter()``.
    """
    bicubic = (Image.BICUBIC, Image.BICUBIC)
    label_vocab = train_df.label_group.to_list()
    per_batch = aug_transforms(size=size, min_scale=0.75) + [Normalize.from_stats(*imagenet_stats)]
    image_block = DataBlock(
        blocks=(ImageBlock(), CategoryBlock(vocab=label_vocab)),
        splitter=ColSplitter(),
        get_y=ColReader('label_group'),
        get_x=get_img_file,
        item_tfms=Resize(size * 2, resamples=bicubic),
        batch_tfms=per_batch,
    )
    return image_block.dataloaders(train_df, bs=bs)


In [None]:
def load_image_model(fname, arch):
    """Load a ``ResnetArcFace`` checkpoint as a half-precision CUDA embedding model.

    The weights are loaded while the model is still float32. The model is then
    switched to eval mode, moved to CUDA, converted to half, and set to return
    embeddings (``outputEmbs=True``).
    """
    # torch.load unpickles the checkpoint: only load trusted files.
    weights = torch.load(fname)
    net = ResnetArcFace(arch)
    net.classifier = None
    net.load_state_dict(weights)
    net = net.eval().cuda().half()
    net.outputEmbs = True
    return net

In [None]:
def generate_image_embs(dl):
    """Embed ``dl`` with every configured image model.

    Returns ``(embs, ys)``, where ``embs`` is a list with one tensor per entry
    of ``image_model_files``/``img_model_archs`` (paired by position).
    """
    models = [
        load_image_model(model_file, model_arch)
        for model_file, model_arch in zip(image_model_files, img_model_archs)
    ]
    return embs_from_models(models, dl)

In [None]:
# pairs, groups = generate_image_pairs(dls_image.valid)
# _=build_from_pairs(pairs, groups, True)

## Combined similarity across all embedding sets

In [None]:
def gen_sim_and(embs_list, res, device='cuda', step=100):
    """Fold pairwise similarities from several embedding sets into ``res`` in place.

    For each embedding matrix ``E`` in ``embs_list``, ``S = clip(E @ E.T, 0, 1)``
    is computed ``step`` rows at a time and ``res`` is multiplied element-wise by
    ``1 - S``. At the end ``res`` is replaced with ``1 - res``. So if ``res``
    starts as ``1 - S0`` (as the notebook prepares it from TF-IDF similarities),
    the result is ``1 - (1 - S0) * prod_k (1 - S_k)``. A pair scores high if any
    source finds it similar. Rows of ``E`` are presumably L2-normalized, which
    makes ``S`` a cosine similarity.

    Args:
        embs_list: sequence of ``(N, d_k)`` tensors. Widths and dtypes may differ
            between entries.
        res: ``(N, N)`` tensor on ``device``, updated in place.
        device: where the products are computed (default ``'cuda'``, as before).
        step: rows per chunk; bounds the scratch buffer to ``(step, N)``.
    """
    for embs in embs_list:
        embs = embs.to(device)
        n = len(embs)
        # Scratch buffer typed after *this* entry: `matmul(..., out=)` requires the
        # output dtype to match its inputs, so one buffer typed after the first
        # entry breaks as soon as the entries differ in dtype.
        cache = torch.empty((step, n), device=device, dtype=embs.dtype)
        for start in range(0, n, step):
            stop = min(start + step, n)
            tmp = cache[:stop - start]
            torch.matmul(embs[start:stop], embs.T, out=tmp)
            tmp.clamp_(0, 1).neg_().add_(1)      # tmp <- 1 - clip(sim, 0, 1)
            res[start:stop].mul_(tmp)
    res.neg_().add_(1)                           # res <- 1 - res

## Helper code


In [None]:
def find_threshold(D, per_row=None):
    """Value at or above which about ``per_row * len(D)`` entries of ``D`` lie.

    Returns the k-th smallest element of ``D`` with
    ``k = D.numel() - int(per_row * len(D))``. ``k`` is clamped to
    ``[1, D.numel()]`` because ``kthvalue`` requires ``1 <= k <= numel``. Without
    the clamp, small matrices (e.g. 4x4 with per_row 4.2 gives k = 0) raise.
    ``reshape`` is used instead of ``view`` so non-contiguous input also works.

    Args:
        D: non-empty tensor whose first dimension counts the rows.
        per_row: entries per row to keep above the threshold; defaults to the
            notebook constant ``RECIPROCAL_PER_ROW``.

    Returns:
        A 0-dim tensor with the threshold value.
    """
    if per_row is None:
        per_row = RECIPROCAL_PER_ROW
    flat = D.reshape(-1)
    k = flat.numel() - int(per_row * len(D))
    k = min(max(k, 1), flat.numel())
    return flat.kthvalue(k).values
def rerank(D):
    """Rewrite each row ``i`` of ``D`` in place with ``reciprocal_probs(D, i, threshold)``.

    ``threshold`` is 0 for fewer than 4 rows. Otherwise it comes from
    ``find_threshold`` on at most the first 20000 rows (presumably a subsample to
    bound ``kthvalue``'s cost) and is printed.

    NOTE(review): rows are overwritten one at a time. If ``reciprocal_probs`` reads
    other rows of ``D``, later rows see already-updated earlier rows; confirm this
    is intended. ``reciprocal_probs`` is not defined in this notebook (presumably
    shopee_utils).
    """
    if len(D)<4: threshold =0
    else: 
        threshold = find_threshold(D[:20000])
        print("threshold", threshold)
    for i in range(len(D)):
        D[i]=reciprocal_probs(D, i,threshold)

In [None]:
def score_cluster(D, clust):
    """Mean of ``D`` over the square sub-block indexed by the members of ``clust``.

    ``clust`` is a mask/indicator vector: nonzero entries mark members, and the
    result averages ``D[i, j]`` over all member pairs (including ``i == j``).
    """
    members = clust.nonzero().view(-1)
    return D[members][:, members].mean()

def clusters_shape(clusters):
    """Normalized histogram of cluster sizes.

    ``clusters`` holds a non-negative integer label per item. Sizes of the
    non-empty labels are binned with ``histc(bins=51, min=1, max=51)``, which
    puts size ``s`` (1..51) into bin ``s - 1``. The counts are divided by their
    total.
    """
    sizes = torch.bincount(clusters)
    sizes = sizes[sizes > 0].float()
    hist = torch.histc(sizes, bins=51, min=1, max=51)
    return hist / hist.sum()

def dist_to_edges(dist):
    """Turn a square score matrix into weighted directed edges, best first.

    For each row ``x`` the ``K = min(51, len(dist))`` largest entries are kept
    and self-pairs ``(x, x)`` are dropped. Returns ``(x, y, value)`` tuples sorted
    by ``value``, highest first. The sort is stable, so equal values keep row order.

    All rows are ranked with one batched ``topk`` call and copied to the host
    once, instead of one ``topk`` and two ``tolist`` transfers per row.
    """
    n = len(dist)
    if n == 0:
        return []
    K = min(51, n)
    top_vals, top_idx = dist.topk(K, dim=1)
    edges = [
        (x, y, v)
        for x, (vals, ys) in enumerate(zip(top_vals.tolist(), top_idx.tolist()))
        for v, y in zip(vals, ys)
        if x != y
    ]
    edges.sort(key=lambda e: -e[2])
    return edges

def clusters_to_groups(C):
    """For each position ``i``, list every index ``j`` (ascending) with ``C[j] == C[i]``.

    Indices are bucketed by label in a single pass over ``C``. The original
    scanned the whole label vector once per element, which is O(n^2) calls.
    Each position gets its own list object, so callers may mutate one group
    without affecting others.

    Args:
        C: 1-D tensor of cluster labels.

    Returns:
        A list of ``len(C)`` index lists.
    """
    labels = C.tolist()
    members = {}
    for j, label in enumerate(labels):
        members.setdefault(label, []).append(j)
    return [list(members[label]) for label in labels]

def reciprocal_cluster(D, C, x, scaled=False, include_x=False):
    """Average the rows of ``D`` that belong to ``x``'s cluster.

    Members are the indices ``j`` with ``C[j] == C[x]``.

    Args:
        D: square score matrix.
        C: 1-D tensor of cluster labels.
        x: index whose cluster is averaged.
        scaled: weight each member row ``j`` by ``D[x, j]`` before averaging.
        include_x: add ``D[x]`` once more to the sum and divide by
            ``members + 1``.

    Returns:
        A 1-D tensor of length ``D.shape[1]``.
    """
    members = (C == C[x]).nonzero(as_tuple=True)[0]
    rows = D[members]
    if scaled:
        rows = D[x, members].unsqueeze(1) * rows
    if not include_x:
        return rows.mean(dim=0)
    return (rows.sum(dim=0) + D[x]) / (members.numel() + 1)

In [None]:
def generate_clusters(DD, edges, cluster_min_score):
    """Greedy agglomerative clustering driven by a best-first edge list.

    Every index starts in its own cluster (labels live on CUDA). The first
    ``10 * len(DD)`` edges are walked in order. For edge ``(x, y, _)`` the
    clusters of ``x`` and ``y`` are merged when they differ, the union has at
    most 51 members, and its mean pairwise score in ``DD`` (``score_cluster``)
    is at least ``cluster_min_score``. Members of ``y``'s cluster then take
    ``x``'s label.

    Returns:
        ``(clusters, big_frac)``: the label tensor and
        ``clusters_shape(clusters)[2:].sum()``, i.e. the fraction of clusters
        with 3-51 members. ``big_frac`` is also printed.
    """
    clusters = torch.arange(0, len(DD)).cuda()
    for x, y, _ in edges[:10 * len(DD)]:
        A, B = clusters[x], clusters[y]
        if A == B:
            continue
        a_idxs = clusters == A
        b_idxs = clusters == B
        combined = a_idxs.logical_or(b_idxs)
        if combined.sum() > 51:
            continue
        if score_cluster(DD, combined) < cluster_min_score:
            continue
        # A is a 0-dim view into `clusters`; clone before the masked write.
        clusters[b_idxs] = A.clone()
    # Compute the shape statistic once (it was previously computed three times,
    # one result unused).
    big_frac = clusters_shape(clusters)[2:].sum()
    print(big_frac)
    return clusters, big_frac


In [None]:
def bin_search_clusters(DD, edges, target_shape):
    """Bisect ``cluster_min_score`` over [0.3, 1] for 9 rounds.

    If the fraction of 3+-member clusters is below ``target_shape``, the upper
    bound drops, making merges easier next round. Otherwise the lower bound
    rises. Returns the clusters from the last round and the score used for them.
    """
    lo, hi = 0.3, 1
    for _ in range(9):
        mid = (lo + hi) / 2
        clusters, big_frac = generate_clusters(DD, edges, mid)
        if big_frac < target_shape:
            hi = mid
        else:
            lo = mid
    return clusters, mid

In [None]:
def get_group_probs(groups, D):
    """Reorder each group by its score in ``D``, highest first.

    For group ``x`` the score of member ``j`` is ``D[x][j]``. Returns
    ``(sorted_groups, sorted_scores)`` with the same nesting as ``groups``. The
    scores are the 0-dim tensors taken from ``D``.
    """
    sorted_groups, sorted_scores = [], []
    for x, members in enumerate(groups):
        ranked = sorted(zip(members, D[x][members]), key=lambda pair: -pair[1])
        sorted_groups.append([idx for idx, _ in ranked])
        sorted_scores.append([score for _, score in ranked])
    return sorted_groups, sorted_scores

In [None]:
def edges_to_groups(edges, N):
    """Collect each node's outgoing edges into aligned neighbour and score lists.

    Edges are consumed in the given order (e.g. the best-first output of
    ``dist_to_edges``). Self-loops are ignored and each node keeps at most 51
    neighbours. Returns ``(groups, groups_p)``, both of length ``N``.
    """
    groups, groups_p = [], []
    for _ in range(N):
        groups.append([])
        groups_p.append([])
    for src, dst, score in edges:
        is_full = len(groups[src]) >= 51
        if is_full or src == dst:
            continue
        groups[src].append(dst)
        groups_p[src].append(score)
    return groups, groups_p

## Check on validation set

In [None]:
def show_groups(groups, targets):
    """Plot predicted vs. target group sizes (bins 1..51) titled with the overall score."""
    pred_sizes = [len(g) for g in groups]
    true_sizes = [len(t) for t in targets]
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.hist((pred_sizes, true_sizes), bins=list(range(1, 52)), label=['preds', 'targets'])
    ax.legend()
    ax.set_title(f'score: {score_all_groups(groups, targets):.3f}')
    plt.show()

In [None]:
def pipeline(D, targets=None):
    """Turn a similarity matrix into match groups, optionally plotting against targets.

    Steps: ``rerank`` ``D`` in place, build best-first edges, collect up to 51
    neighbours per row, and reshape group sizes with ``chisel2`` (position 2) or
    ``chisel`` (other positions) for each ``(pos, size_pct)`` from
    ``get_targets_shape(train_df)``. Finally append each row's own index to its
    group.

    Args:
        D: square similarity matrix; modified in place by ``rerank``.
        targets: ground-truth groups for ``show_groups``. When omitted, the
            notebook-level ``targets`` is used if it exists (the original always
            read that global, raising NameError when it was missing). With no
            targets at all, plotting is skipped.

    Returns:
        A list of groups (positional indices), each including its own row.
    """
    if targets is None:
        targets = globals().get('targets')
    rerank(D)
    edges = dist_to_edges(D)
    groups, groups_p = edges_to_groups(edges, len(D))
    for pos, size_pct in get_targets_shape(train_df):
        target_count = int(size_pct * len(groups))
        if pos == 2:
            chisel2(groups, groups_p, target_count)
        else:
            chisel(groups, groups_p, pos - 1, target_count)
    groups = [g + [i] for i, g in enumerate(groups)]
    if targets is not None:
        show_groups(groups, targets)
    return groups

In [None]:
def top3(i, groups=None):
    """Return ``i`` followed by (at most) the first two members of ``groups[i]`` as a tensor.

    ``groups`` defaults to the notebook-level ``groups`` variable, which the
    original always read implicitly. Pass it explicitly to avoid depending on
    hidden notebook state.
    """
    if groups is None:
        groups = globals()['groups']
    return torch.tensor([i] + groups[i][:2])

def trip_score(i, groups, groups_p):
    """Noisy-OR evidence that ``i`` and its best neighbour belong to a larger group.

    With ``second = groups[i][0]``, ``p1`` is the score of ``i``'s second
    neighbour. ``p2`` is ``second``'s 2nd score when ``second``'s first
    neighbour is ``i``, otherwise its 1st score. Returns ``p1 + p2 - p1 * p2``.

    Robustness (the original raised IndexError in these cases):
      * if ``i`` has fewer than two neighbours, returns ``+inf``. ``chisel2``
        skips such rows anyway, and ``+inf`` keeps them out of its
        lowest-score candidates;
      * if ``second`` lacks the needed entry, ``p2`` is taken as 0.
    """
    if len(groups[i]) < 2:
        return float('inf')
    second = groups[i][0]
    p1 = groups_p[i][1]
    partner_group, partner_p = groups[second], groups_p[second]
    pos = 1 if partner_group and partner_group[0] == i else 0
    p2 = partner_p[pos] if pos < len(partner_p) else 0.0
    return p1 + p2 - p1 * p2

def chisel2(groups, groups_p, target_count):
    """Cut the weakest candidate groups down to a single neighbour, in place.

    The ``target_count`` rows with the lowest ``trip_score`` are visited in that
    order. Rows already holding at most one neighbour are skipped. Otherwise the
    row is truncated to its first neighbour, and so is that neighbour's group if
    it still has more than one entry. Stops once ``target_count`` groups have
    been truncated.
    """
    scores = torch.tensor([trip_score(i, groups, groups_p) for i in range(len(groups))])
    order = (-scores.cuda()).topk(target_count).indices.tolist()
    trimmed = 0
    for i in order:
        if len(groups[i]) <= 1:
            continue
        groups[i], groups_p[i] = groups[i][:1], groups_p[i][:1]
        trimmed += 1
        partner = groups[i][0]
        if len(groups[partner]) > 1:
            groups[partner], groups_p[partner] = groups[partner][:1], groups_p[partner][:1]
            trimmed += 1
        if trimmed >= target_count:
            break

In [None]:
if TRIAL_RUN:
    # Validation-split embeddings. The second assignment overwrites `ys`.
    # NOTE(review): assumes both dataloaders yield the valid rows in the same order.
    img_embs,ys = generate_image_embs(get_image_dls(336, 128).valid)

    text_embs,ys = generate_text_embs(get_text_dls().valid)

    # Ground truth: for each row, indices of all rows sharing its label.
    target_matrix = ys[:,None]==ys[None,:]
    targets = [torch.where(t)[0].tolist() for t in target_matrix]


    emb_size = img_embs[0].shape[0]
    D = torch.empty((emb_size, emb_size), device = 'cuda', dtype=torch.float16)

    # NOTE(review): index labels are used as positional rows of the TF-IDF matrix;
    # this is only correct if train_df keeps a default RangeIndex.
    tfid_embs=  get_tfid_embs(train_df, train_df[train_df.is_valid].index.tolist())
    generate_tfid_D(tfid_embs,D)

    # D <- 1 - tfidf_sim; gen_sim_and folds in the text and image similarities
    # and flips back, giving 1 - prod(1 - sim_k).
    D.mul_(-1)
    D.add_(1)
    gen_sim_and([text_embs]+img_embs,D)

    print(score_distances(D,targets))

    # pipeline plots against the `targets` defined above.
    pipeline(D)

## Run on the test set

In [None]:
# Real test set; replaced by a train-derived fake in the next cell when TRIAL_RUN is set.
test_df = pd.read_csv(PATH/'test.csv')

In [None]:
if TRIAL_RUN:
    # Fake test set: train_df stacked twice, so every posting has at least its duplicate
    # as a true match. add_target_groups (shopee_utils) presumably adds the `target`
    # column read by the commented-out scoring cell — confirm.
    fake_test_df = train_df[['posting_id', 'image', 'image_phash', 'title', 'label_group']].copy()
    fake_test_df = pd.concat([fake_test_df, fake_test_df])
    fake_test_df = add_target_groups(fake_test_df)
    test_df = fake_test_df

In [None]:
#img_embs = [F.normalize(torch.rand((68500, 4200), device='cuda', dtype=torch.float16)).cpu()]*2

In [None]:
#text_embs = F.normalize(torch.rand((68500, 1024), device='cuda', dtype=torch.float16)).cpu()

In [None]:
# Text embeddings for the test rows; labels are not needed.
text_embs,_ = generate_text_embs(get_text_dls().test_dl(test_df))

In [None]:
# One embedding tensor per image model (a list), image size 336, batch size 64.
img_embs,_ = generate_image_embs(get_image_dls(336, 64).test_dl(test_df))

In [None]:
# Free unreferenced Python objects and cached CUDA blocks (e.g. from the embedding
# models) before allocating the N x N matrix.
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()

In [None]:
# Fit TF-IDF on test + train titles. Test rows come first, so range(len(test_df))
# selects exactly the test rows.
tfid_embs=  get_tfid_embs(pd.concat([test_df,train_df]), range(len(test_df)))


In [None]:
# Dense N x N float16 matrix on the GPU (N = len(test_df)); e.g. N = 68,500 needs ~9.4 GB.
emb_size = len(test_df)
D = torch.empty((emb_size, emb_size), device = 'cuda', dtype=torch.float16)

In [None]:
# Fill D with clipped TF-IDF dot products, then drop the sparse matrix.
generate_tfid_D(tfid_embs,D)
del tfid_embs

In [None]:
# D <- 1 - tfidf_sim; gen_sim_and multiplies in (1 - sim) for the text and each image
# embedding set and flips back, so D = 1 - prod_k(1 - sim_k).
D.mul_(-1)
D.add_(1)
gen_sim_and([text_embs]+img_embs,D)

In [None]:
# Embeddings are no longer needed once folded into D.
del img_embs, text_embs
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()

In [None]:
# Rewrites every row of D in place.
rerank(D)

In [None]:
%%time
# Directed (x, y, score) edges from each row's top entries, best first.
edges = dist_to_edges(D)

In [None]:
# Same grouping steps as `pipeline`, minus rerank (already done above) and plotting.
groups, groups_p = edges_to_groups(edges, len(D))
    

# Reshape group sizes toward the training distribution from get_targets_shape (shopee_utils).
for pos, size_pct in get_targets_shape(train_df):
    if pos==2: 
        chisel2(groups, groups_p, int(size_pct * len(groups)))
    else:
        chisel(groups, groups_p, pos-1, int(size_pct * len(groups)))
# Every posting matches itself.
groups = [g+[i] for i,g in enumerate(groups)]

In [None]:
# if 'target' in test_df.columns.to_list():
#     print(score_all_groups(groups, test_df.target.to_list()))

In [None]:
#matches =test_df.posting_id.to_list()

In [None]:
# Groups hold positional indices into test_df (D has len(test_df) rows), so iloc maps
# them to posting_ids; each group already includes its own row.
matches = [' '.join(test_df.iloc[g].posting_id.to_list()) for g in groups]
test_df['matches'] = matches

test_df[['posting_id','matches']].to_csv('submission.csv',index=False)

In [None]:
# Sanity check: first rows of the written submission.
pd.read_csv('submission.csv').head()