I tried a graph neural network (RGCN) approach using the Deep Graph Library (DGL).  
Please upvote if this notebook is useful.

In [None]:
!conda install -c dglteam dgl-cuda11.0 -y
!conda install -c conda-forge swifter -y

In [None]:
import dgl
import dgl.nn as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np
import tqdm
import joblib
from annoy import AnnoyIndex
import swifter
#from scipy.spatial import cKDTree
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import recall_score, roc_auc_score

In [None]:
class Config:
    """Notebook-wide settings: input paths, date cut-off, graph naming and model sizes."""

    # Input files.
    transaction_path = "../input/h-and-m-personalized-fashion-recommendations/transactions_train.csv"
    transaction_2020_path = "../input/h-and-m-split-dataset-by-year/transactions_train_2020.csv"
    transaction_2019_path = "../input/h-and-m-split-dataset-by-year/transactions_train_2019.csv"
    customer_path = "../input/h-and-m-personalized-fashion-recommendations/customers.csv"
    article_path = "../input/h-and-m-personalized-fashion-recommendations/articles.csv"
    image_feat_path = "../input/h-and-m-swint-image-embedding/swin_tiny_patch4_window7_224_emb.csv.gz"
    sample_submission_path = "../input/h-and-m-personalized-fashion-recommendations/sample_submission.csv"

    output_dir = "../output/"
    # Only transactions with t_dat >= start_date are kept when the log is loaded.
    #start_date = '2020-08-01'
    start_date = '2020-09-01'

    # Dimensions of the image/text feature vectors (not referenced by the visible cells).
    image_feat_dim = 768
    text_feat_dim = 384
    
    # Training settings. `epoch` is the number of passes made by train();
    # n_fold is not referenced by the visible cells.
    #n_fold = 2
    n_fold = 5
    #epoch = 50
    epoch = 100

    # Seed passed to the Annoy index when the submission is built.
    seed = 2022
    
    # Node-type and edge-type names used when building the heterograph.
    customer_node = "customer"
    article_node = "article"
    buy_edge = "buy"
    bought_by_edge = "bought_by"

    # Edge names for the per-sales-channel graph variant (only used in commented-out code).
    buy_store_edge = "buy_store"
    bought_by_store_edge = "bought_by_store"
    buy_online_edge = "buy_online"
    bought_by_online_edge = "bought_by_online"
    age_same = "age"
    age_same_by = "age_by"
    
    
    # Layer widths: dense projection output, RGCN hidden layer, final embedding.
    in_feat_dim = 200
    hidden_feat_dim = 500
    out_feat_dim = 200

    # Use the GPU when one is available, otherwise fall back to the CPU.
    #device=torch.device("cpu")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # File the trained state_dict is saved to / loaded from.
    model_path = "model.pth"

In [None]:
# Read the transaction log with article_id as a string, then keep only the
# rows dated on or after Config.start_date (t_dat is compared as read from the CSV).
df_trans = pd.read_csv(
    Config.transaction_path,
    dtype={'article_id': 'str'},
)
df_trans = df_trans.loc[df_trans["t_dat"] >= Config.start_date]
df_trans.head()

In [None]:
df_trans.shape

In [None]:
df_trans.t_dat.head()

In [None]:
# Exploratory look at article popularity: rows per article (the count ends up in
# the "customer_id" column), most-purchased first, keeping counts above 10.
df_ranking = (
    df_trans[["article_id", "customer_id"]]
    .groupby("article_id")
    .count()
    .reset_index()
    .sort_values("customer_id", ascending=False)
)
df_ranking = df_ranking.loc[df_ranking["customer_id"] > 10]
df_ranking.head()

In [None]:
# Keep only the articles that have at least `th` transactions.

def transaction_count_filter(df, th=10):
    """Drop transactions of rarely purchased articles and remove duplicate rows.

    Args:
        df: transaction frame with at least "article_id" and "customer_id" columns.
        th: minimum number of rows (non-null customer_id) an article needs in `df`
            to be kept.

    Returns:
        A new frame with the rows of `df` whose article meets the threshold,
        exact duplicate rows removed, and a fresh 0..n-1 index.
    """
    # Count on the frame that was passed in; the previous version silently
    # ranked the global `df_trans` regardless of the argument.
    rows_per_article = df.groupby("article_id")["customer_id"].count()
    frequent_articles = rows_per_article[rows_per_article >= th].index

    kept = df[df["article_id"].isin(frequent_articles)]
    return kept.drop_duplicates().reset_index(drop=True)


In [None]:
df_trans.shape

In [None]:
transaction_count_filter(df_trans).shape

In [None]:
df_trans = transaction_count_filter(df_trans)

In [None]:
df_trans.head()

In [None]:
df_submission = pd.read_csv(Config.sample_submission_path)
df_submission.head()

In [None]:
df_customer = pd.read_csv(Config.customer_path)
df_customer.head()

In [None]:
# One graph node per distinct customer in the filtered transactions;
# the node id is the position of the customer in `users`.
users = df_trans["customer_id"].unique().tolist()
df_customer_node = pd.DataFrame({
    "customer_id": users,
    "customer_node_id": list(range(len(users))),
})

In [None]:
# One graph node per distinct article in the filtered transactions;
# the node id is the position of the article in `article`.
article = df_trans["article_id"].unique().tolist()
df_article_node = pd.DataFrame({
    "article_id": article,
    "article_node_id": list(range(len(article))),
})

In [None]:
# Article metadata, restricted (inner join) to the articles that became graph nodes.
df_article = pd.read_csv(
    Config.article_path,
    dtype={'article_id': 'str'},
).merge(df_article_node[["article_id"]], on="article_id", how="inner")


In [None]:
# Attach the customer and article node ids to every transaction row.
# NOTE: this reassigns df_trans, so re-running the cell merges the id columns again.
df_trans = (
    df_trans
    .merge(df_customer_node, how='inner', on="customer_id")
    .merge(df_article_node, how="inner", on="article_id")
)

In [None]:
df_trans.head()

In [None]:
# Alternative graph (disabled): separate edge types per sales channel
# (store = sales_channel_id 1, online = sales_channel_id 2).
# cond_store = df_trans["sales_channel_id"] == 1
# cond_online = df_trans["sales_channel_id"] == 2

# graph = dgl.heterograph({
#     (Config.customer_node, Config.buy_store_edge, Config.article_node): 
#                (df_trans[cond_store].loc[:, "customer_node_id"].tolist(), df_trans[cond_store].loc[:, "article_node_id"].tolist()),
#     (Config.article_node, Config.bought_by_store_edge, Config.customer_node): 
#               (df_trans[cond_store].loc[:, "article_node_id"].tolist(), df_trans[cond_store].loc[:, "customer_node_id"].tolist() ),
#     (Config.customer_node, Config.buy_online_edge, Config.article_node): 
#                (df_trans[cond_online].loc[:, "customer_node_id"].tolist(), df_trans[cond_online].loc[:, "article_node_id"].tolist()),
#     (Config.article_node, Config.bought_by_online_edge, Config.customer_node): 
#               (df_trans[cond_online].loc[:, "article_node_id"].tolist(), df_trans[cond_online].loc[:, "customer_node_id"].tolist())
# })

# Bipartite customer/article heterograph: every transaction row contributes one
# customer -> article "buy" edge and the mirrored article -> customer "bought_by" edge.
# Both edge lists are built from the same rows in the same order.
graph = dgl.heterograph({
    (Config.customer_node, Config.buy_edge, Config.article_node): 
               (df_trans.loc[:, "customer_node_id"].tolist(), df_trans.loc[:, "article_node_id"].tolist()),
    (Config.article_node, Config.bought_by_edge, Config.customer_node): 
              (df_trans.loc[:, "article_node_id"].tolist(), df_trans.loc[:, "customer_node_id"].tolist() ),  
})

In [None]:
graph

In [None]:
class RGCN(nn.Module):
    """Two stacked heterogeneous graph-convolution layers with a ReLU in between.

    Each layer holds one `GraphConv(norm='right')` per relation name and
    combines them through `HeteroGraphConv`.

    Args:
        in_feat: size of the input node features.
        hidden_feat: size of the features produced by the first layer.
        out_feat: size of the features produced by the second layer.
        rel_names: iterable of relation (edge type) names.
    """

    def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
        super().__init__()
        self.conv1 = self._build_layer(in_feat, hidden_feat, rel_names)
        self.conv2 = self._build_layer(hidden_feat, out_feat, rel_names)

    @staticmethod
    def _build_layer(n_in, n_out, rel_names):
        """Create one HeteroGraphConv with a GraphConv module per relation."""
        per_relation = {}
        for rel in rel_names:
            per_relation[rel] = dglnn.GraphConv(n_in, n_out, norm='right')
        return dglnn.HeteroGraphConv(per_relation)

    def forward(self, blocks, x):
        """Run `x` (dict: node type -> features) through both layers.

        `blocks[0]` feeds the first layer and `blocks[1]` the second.
        Returns the dict produced by the second layer.
        """
        hidden = self.conv1(blocks[0], x)
        activated = {ntype: F.relu(feat) for ntype, feat in hidden.items()}
        return self.conv2(blocks[1], activated)

In [None]:
# class Dense(nn.Module):
    
#     def __init__(self, col, image_mask, text_mask, image_dim, text_dim):
#         super().__init__()
#         self.col = col
#         self.image_mask
#         self.text_mask
#         self.image_dence = nn.Linear(image_dim, 250)
#         self.text_dence = nn.Linear(text_dim, 250)
#         self.other_dence = nn.Liner()
        
#     def forward(self, x):
#         _image = self.image_dence(x[self.col][:, self.image_mask])
#         _text = self.text_dence(x[self.col][:, self.text_mask])
#         _x = torch.cat([_image, _text])
#         x[self.col] = _x
        
#         return _x

In [None]:
class ScorePredictor(nn.Module):
    """Scores every edge of a (sub)graph by the dot product of its endpoint embeddings."""

    def forward(self, edge_subgraph, x):
        """Return the per-edge 'score' data of `edge_subgraph`.

        Args:
            edge_subgraph: graph whose edges are to be scored.
            x: node embeddings, assigned to the graph's node data under 'h'.
        """
        # local_scope keeps the temporary 'h' / 'score' entries from persisting on the graph.
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = x
            for canonical_etype in edge_subgraph.canonical_etypes:
                dot_product = dgl.function.u_dot_v('h', 'h', 'score')
                edge_subgraph.apply_edges(dot_product, etype=canonical_etype)
            scores = edge_subgraph.edata['score']
            return scores

class Model(nn.Module):
    """Dense input projections followed by a two-layer RGCN, plus an edge scorer.

    Args:
        in_features: width both node types are projected to before the RGCN.
        hidden_features: width of the RGCN hidden layer.
        out_features: width of the final node embeddings.
        item_dim: number of raw article features.
        user_dim: number of raw customer features.
        etypes: relation names handed to the RGCN.
        item_col: node-type key of the articles in the feature dict.
        user_col: node-type key of the customers in the feature dict.
    """

    def __init__(self, 
                 in_features, 
                 hidden_features, 
                 out_features,
                 item_dim,
                 user_dim,                 
                 etypes,
                 item_col=Config.article_node,
                 user_col=Config.customer_node):
        
        super().__init__()
        
        self.item_dence = nn.Linear(item_dim, in_features)
        self.user_dence = nn.Linear(user_dim, in_features)
        self.item_col = item_col
        self.user_col = user_col
        
        # The misspelled attribute names are kept because inference() reads them;
        # correctly spelled aliases are provided for new code.
        self.hidden_featuers = hidden_features
        self.out_featuers = out_features
        self.hidden_features = hidden_features
        self.out_features = out_features

        self.rgcn = RGCN(in_features, hidden_features, out_features, etypes)
        
        self.score = ScorePredictor()

    def forward(self, blocks, x):
        """Project the raw features of both node types and run the RGCN.

        Args:
            blocks: the message-flow blocks consumed by the RGCN (two are indexed).
            x: dict mapping node type -> raw feature tensor. It is not modified.

        Returns:
            dict mapping node type -> embedding, as returned by the RGCN.
        """
        # Work on a shallow copy so the caller's feature dict keeps its raw tensors.
        projected = dict(x)
        projected[self.user_col] = F.relu(self.user_dence(x[self.user_col]))
        projected[self.item_col] = F.relu(self.item_dence(x[self.item_col]))
        return self.rgcn(blocks, projected)
    

In [None]:
def compute_loss(pos_score, neg_score, canonical_etypes):
    """Hinge (margin) ranking loss averaged over edge types.

    For every positive edge the loss is max(0, 1 - pos + neg) over its negative
    samples, so positives are pushed above negatives by a margin of 1.

    Args:
        pos_score: dict edge type -> scores of the positive edges, E entries.
        neg_score: dict edge type -> scores of the negative edges, E * k entries.
            Assumes the k negatives of each positive are stored contiguously
            (reshaped to (E, k)) -- TODO confirm against the negative sampler.
        canonical_etypes: edge types to include.

    Returns:
        Scalar tensor: mean of the per-edge-type losses. Edge types without
        positive edges are skipped.
    """
    per_type_losses = []
    for given_type in canonical_etypes:
        n_edges = pos_score[given_type].shape[0]
        if n_edges == 0:
            continue
        # (E, 1) against (E, k): one margin term per positive/negative pair.
        # The earlier version had the sign flipped (1 - neg + pos) and broadcast
        # an (E, 1, 1) tensor against (E, k).
        pos = pos_score[given_type].reshape(n_edges, 1)
        neg = neg_score[given_type].reshape(n_edges, -1)
        per_type_losses.append((1 - pos + neg).clamp(min=0).mean())
    return torch.stack(per_type_losses, dim=0).mean()

def compute_loss_bce(pos_score, neg_score, canonical_etypes):
    """Binary cross-entropy (with logits) over positive and negative edge scores.

    Positive edges are labelled 1 and negative edges 0; the loss is computed
    per edge type and the per-type losses are averaged.

    Args:
        pos_score: dict edge type -> (E, 1) scores of the positive edges.
        neg_score: dict edge type -> (N, 1) scores of the negative edges.
        canonical_etypes: edge types to include.

    Returns:
        Scalar tensor with the mean loss. Edge types with no scores at all are
        skipped (they would otherwise contribute NaN).
    """
    criterion = torch.nn.BCEWithLogitsLoss()
    per_type_losses = []
    for given_type in canonical_etypes:
        pos_logits = pos_score[given_type].squeeze(1)
        neg_logits = neg_score[given_type].squeeze(1)

        logits = torch.cat([pos_logits, neg_logits])
        if logits.numel() == 0:
            continue

        # Build the labels from the scores themselves so they always share the
        # scores' device and dtype instead of depending on a global device setting.
        labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
        per_type_losses.append(criterion(logits, labels))

    return torch.stack(per_type_losses, dim=0).mean()

def compute_auc(pos_score, neg_score, canonical_etypes):
    """Mean ROC AUC over edge types, treating positive edges as class 1 and negatives as class 0.

    Scores are detached and moved to the CPU before being handed to
    `roc_auc_score`.
    """

    def _auc_for(etype):
        positives = pos_score[etype].squeeze(1).to("cpu").detach()
        negatives = neg_score[etype].squeeze(1).to("cpu").detach()
        scores = torch.cat([positives, negatives])
        targets = torch.cat([torch.ones(len(positives)), torch.zeros(len(negatives))])
        return roc_auc_score(targets.numpy(), scores.numpy())

    return np.mean([_auc_for(etype) for etype in canonical_etypes])
        
        

In [None]:
def train(train_graph, X_dic, train_dataloader, model):    
    """Train `model` for Config.epoch epochs with Adam and the BCE link-prediction loss.

    Args:
        train_graph: graph whose node types and canonical edge types are used.
        X_dic: dict node type -> feature tensor covering all nodes, indexed by node id.
        train_dataloader: yields (input_nodes, positive_graph, negative_graph, blocks).
        model: the model to optimise; moved to Config.device.

    Returns:
        The trained model (on Config.device, in training mode).
    """
    model = model.to(Config.device)
    opt = torch.optim.Adam(model.parameters())

    model.train()
    for i in range(Config.epoch):
        # Collect per-batch metrics so the log reflects the whole epoch rather
        # than only the final mini-batch.
        batch_losses, batch_aucs = [], []

        for input_nodes, positive_graph, negative_graph, blocks in train_dataloader:

            blocks = [b.to(Config.device) for b in blocks]
            positive_graph = positive_graph.to(Config.device)
            negative_graph = negative_graph.to(Config.device)

            # Slice out the features of the nodes this batch needs as input.
            feature = {
                ntype: X_dic[ntype][input_nodes[ntype]].to(Config.device) for ntype in train_graph.ntypes
            }

            emb_dict = model(blocks, feature)
            pos_score = model.score(positive_graph, emb_dict)
            neg_score = model.score(negative_graph, emb_dict)

            loss = compute_loss_bce(pos_score, neg_score, train_graph.canonical_etypes)
            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_losses.append(loss.item())
            batch_aucs.append(compute_auc(pos_score, neg_score, train_graph.canonical_etypes))

        if not batch_losses:
            # An empty dataloader previously ended in a NameError on `loss`.
            print(f"epoch: {i} | no training batches")
            continue

        print(f"epoch: {i} | train loss:{np.mean(batch_losses)} | AUC {np.mean(batch_aucs)}")

    return model


def evaluate(train_graph, graph, X_dic, model, valid_eid_dict):
    """Score held-out edges with dot products of the inferred node embeddings.

    Args:
        train_graph: graph the embeddings are inferred on.
        graph: graph that contains the validation edges.
        X_dic: dict node type -> raw feature tensor.
        model: trained model.
        valid_eid_dict: dict canonical edge type -> ids of the validation edges in `graph`.

    Returns:
        dict canonical edge type -> 1-D tensor of edge scores. The dict is also printed.
    """
    emb = inference(train_graph, X_dic, model)

    scores = {}
    for canonical_etype in graph.canonical_etypes:
        # Embeddings are keyed by node type, so the endpoints of an edge type
        # must be looked up through its source / destination node types.
        src_ntype, _, dst_ntype = canonical_etype
        src, dst = graph.find_edges(valid_eid_dict[canonical_etype], etype=canonical_etype)
        scores[canonical_etype] = (emb[src_ntype][src] * emb[dst_ntype][dst]).sum(1)

    print(scores)
    return scores
        


def inference(graph, X_dic, model):
    """Compute embeddings for every node of `graph`, one layer at a time.

    Layer 0 applies the dense projections and `conv1` (+ ReLU) to all nodes using
    their full one-hop neighbourhood; layer 1 applies `conv2` to the stored
    layer-0 outputs. Results are accumulated on the CPU.

    Args:
        graph: heterograph to run on.
        X_dic: dict node type -> raw feature tensor indexed by node id.
        model: trained Model; moved to Config.device and put in eval mode.

    Returns:
        dict node type -> (num_nodes, model.out_featuers) CPU tensor.
    """
    model = model.to(Config.device)
    model.eval()

    # Same loader class as training (NodeDataLoader is the deprecated spelling).
    # Order is irrelevant because outputs are written back by node id, so the
    # batches are not shuffled.
    dataloader = dgl.dataloading.DataLoader(
                graph,
                {
                    Config.article_node: torch.arange(graph.number_of_nodes(ntype=Config.article_node)),
                    Config.customer_node: torch.arange(graph.number_of_nodes(ntype=Config.customer_node))
                },
                dgl.dataloading.MultiLayerFullNeighborSampler(1),
                batch_size=1024,
                shuffle=False,
                drop_last=False,
            )

    layer_out_dims = (model.hidden_featuers, model.out_featuers)

    with torch.no_grad():
        for n_layer, out_dim in enumerate(layer_out_dims):
            y = {ntype: torch.zeros(graph.number_of_nodes(ntype), out_dim)
                 for ntype in graph.ntypes}

            for input_nodes, output_nodes, blocks in dataloader:
                block = blocks[0].to(Config.device)

                x = {
                    ntype: X_dic[ntype][input_nodes[ntype]].to(Config.device) for ntype in graph.ntypes
                }

                if n_layer == 0:
                    x[model.user_col] = F.relu(model.user_dence(x[model.user_col]))
                    x[model.item_col] = F.relu(model.item_dence(x[model.item_col]))
                    h = model.rgcn.conv1(block, x)
                    h = {key: F.relu(val) for key, val in h.items()}
                else:
                    h = model.rgcn.conv2(block, x)

                # Only write the node types this batch actually produced output for;
                # a batch may contain destination nodes of a single type.
                for ntype, feat in h.items():
                    y[ntype][output_nodes[ntype]] = feat.cpu()

            # The outputs of this layer are the inputs of the next one.
            X_dic = y

    return y
    
    
def validation(train_graph, graph, valid_eid_dict, x_dict, model, batch_size, fanout, num_workers):
    """Recall of the model on held-out (positive) edges, averaged over edge types.

    An edge counts as predicted when sigmoid(dot(src_emb, dst_emb)) > 0.5.

    Args:
        train_graph: graph the embeddings are inferred on.
        graph: graph that contains the validation edges.
        valid_eid_dict: dict edge type name -> ids of the validation edges in `graph`.
        x_dict: dict node type -> raw feature tensor.
        model: trained model.
        batch_size, fanout, num_workers: accepted for signature compatibility; unused.

    Returns:
        Mean recall over the canonical edge types of `graph`.
    """
    # Compute the embeddings from the arguments instead of relying on a
    # notebook-global `emb` that may be stale or undefined.
    emb = inference(train_graph, x_dict, model)

    recalls = []
    for src_ntype, etype, dst_ntype in graph.canonical_etypes:
        src, dst = graph.find_edges(valid_eid_dict[etype], etype=etype)
        src_emb = emb[src_ntype][src]
        dst_emb = emb[dst_ntype][dst]
        # This dot product was missing: `score` was read before it was assigned.
        logits = (src_emb * dst_emb).sum(1)
        predicted = (torch.sigmoid(logits) > 0.5).long().numpy()
        # Every validation edge is a true positive example.
        label = np.ones(len(predicted), dtype=int)
        recalls.append(recall_score(label, predicted))

    return np.mean(recalls)

In [None]:
df_customer_node.head()

In [None]:
def create_customer_feat(df, df_node, df_trans):
    """Build the customer feature table, one row per customer node, ordered by node id.

    Args:
        df: customer table (customer_id, FN, Active, club_member_status,
            fashion_news_frequency, age, postal_code).
        df_node: mapping customer_id -> customer_node_id.
        df_trans: transactions with customer_id, price, sales_channel_id and t_dat.

    Returns:
        Feature frame without the customer_id / customer_node_id columns.
    """
    categorical_cols = ["club_member_status", "fashion_news_frequency"]
    missing_defaults = {
        "FN": 0,
        "Active": 0,
        "club_member_status": "NONE",
        "fashion_news_frequency": "NONE",
        "age": 0,
    }

    feat = df.drop(columns=["postal_code"]).fillna(value=missing_defaults)
    feat["age"] = np.log1p(feat["age"])
    feat = pd.get_dummies(feat, columns=categorical_cols)

    # Mean purchase price per customer on a log scale.
    mean_price = df_trans.groupby("customer_id", as_index=False)["price"].mean()
    mean_price["price"] = np.log(mean_price["price"] * 1000000)
    feat = feat.merge(mean_price, on="customer_id", how="left")

    # Transaction counts per customer: store (channel 1), online (channel 2), and both.
    channel_subsets = (
        ("count_offline", df_trans[df_trans["sales_channel_id"] == 1]),
        ("count_online", df_trans[df_trans["sales_channel_id"] == 2]),
        ("count_both", df_trans),
    )
    for count_name, subset in channel_subsets:
        counts = subset.groupby("customer_id")["t_dat"].count().rename(count_name).reset_index()
        # fillna(0) covers the whole frame, exactly as each merge step did before.
        feat = feat.merge(counts, on="customer_id", how="left").fillna(0)

    # Keep only customers that are graph nodes and line the rows up with the node ids.
    feat = (
        feat.merge(df_node, on="customer_id", how="inner")
        .sort_values("customer_node_id")
        .reset_index(drop=True)
    )
    return feat.drop(["customer_id", "customer_node_id"], axis=1)


def get_article_table_feat(df):
    """Turn the article metadata table into model features.

    Drops the index_code, prod_name, detail_desc and department_name columns
    and one-hot encodes the categorical "*_name" columns; every other column
    is passed through unchanged.
    """
    dropped_cols = ["index_code", "prod_name", "detail_desc", "department_name"]

    # department_name is deliberately not encoded (it is dropped above).
    one_hot_cols = [
        "product_type_name",
        "product_group_name",
        "graphical_appearance_name",
        "colour_group_name",
        "perceived_colour_value_name",
        "perceived_colour_master_name",
        "index_name",
        "index_group_name",
        "section_name",
        "garment_group_name",
    ]

    remaining = df.drop(dropped_cols, axis=1)
    return pd.get_dummies(remaining, columns=one_hot_cols)

def get_article_image_feat(df):
    """Placeholder for article image features; not implemented, returns None."""
    return None

def get_article_text_feat(df):
    """Placeholder for article text features; not implemented, returns None."""
    return None

def create_article_feat(df, df_node, df_trans):
    """Build the article feature table, one row per article node, ordered by node id.

    Args:
        df: article metadata table.
        df_node: mapping article_id -> article_node_id.
        df_trans: transactions with article_id, price, sales_channel_id and t_dat.

    Returns:
        Feature frame without the article_id / article_node_id columns.
    """
    df_table_feat = get_article_table_feat(df)

    # Mean selling price per article on a log scale.
    df_price_mean = df_trans[['article_id', 'price']].groupby("article_id").mean().reset_index()
    df_price_mean["price"] = np.log(df_price_mean["price"] * 1000000)
    df_table_feat = df_table_feat.merge(df_price_mean, on="article_id", how="left")

    # Number of transactions per article: store (channel 1), online (channel 2), and both.
    channel_subsets = (
        ("count_offline", df_trans[df_trans["sales_channel_id"] == 1]),
        ("count_online", df_trans[df_trans["sales_channel_id"] == 2]),
        ("count_both", df_trans),
    )
    for count_name, subset in channel_subsets:
        counts = subset.groupby("article_id")["t_dat"].count().rename(count_name).reset_index()
        df_table_feat = df_table_feat.merge(counts, on="article_id", how="left").fillna(0)

    df = df_table_feat.merge(df_node, on="article_id", how="inner")
    # drop=True is required here: without it the old row index was added as an
    # extra "index" column and ended up in the article features.
    df = df.sort_values("article_node_id").reset_index(drop=True)
    df = df.drop(["article_id", "article_node_id"], axis=1)

    return df

def create_graph_feature(df_article, df_article_node, df_customer, df_customer_node, df_trans):    
    """Create the min-max scaled feature tensors for both node types.

    Prints, for each feature table, which columns contain missing values, then
    returns a dict keyed by Config.customer_node and Config.article_node.
    The scaler is re-fitted separately for each table.
    """
    customer_feat = create_customer_feat(df_customer, df_customer_node, df_trans)
    print(customer_feat.isna().any())
    article_feat = create_article_feat(df_article, df_article_node, df_trans)
    print(article_feat.isna().any())

    scaler = MinMaxScaler()
    customer_tensor = torch.Tensor(scaler.fit_transform(customer_feat.values))
    article_tensor = torch.Tensor(scaler.fit_transform(article_feat.values))

    return {
        Config.customer_node: customer_tensor,
        Config.article_node: article_tensor,
    }


In [None]:
df_trans.head()

In [None]:
X_dic = create_graph_feature(df_article, df_article_node, df_customer, df_customer_node, df_trans)

In [None]:
X_dic["customer"].shape

In [None]:
X_dic["article"].shape

In [None]:
X_dic["customer"].shape, X_dic["article"].shape

In [None]:
# Model hyper-parameters; the input widths are taken from the feature tensors.
args = dict(
    in_features=Config.in_feat_dim,
    hidden_features=Config.hidden_feat_dim,
    out_features=Config.out_feat_dim,
    item_dim=X_dic[Config.article_node].shape[1],
    user_dim=X_dic[Config.customer_node].shape[1],
    # Per-channel alternative:
    # [Config.buy_online_edge, Config.bought_by_online_edge, Config.buy_store_edge, Config.bought_by_store_edge]
    etypes=[Config.buy_edge, Config.bought_by_edge],
)

model = Model(**args)

In [None]:
# https://zqfang.github.io/2021-08-12-graph-linkpredict/


# train/ validation split

# https://github.com/dglai/WWW20-Hands-on-Tutorial/blob/master/_legacy/basic_apps/BasicTasks_pytorch.ipynb

# def train_valid_split(graph, train_rate=0.8):
#     _train_eid_dict = {}
#     _valid_eid_dict = {}

#     for etype in graph.canonical_etypes:
#         eids = np.random.permutation(graph.num_edges(etype))    
#         train_eids = eids[:int(len(eids) * train_rate)]
#         valid_eids = eids[int(len(eids) * train_rate):]

#         _train_eid_dict[etype] = train_eids
#         _valid_eid_dict[etype] = valid_eids
        
#     train_graph = graph.edge_subgraph(_train_eid_dict, relabel_nodes=False, store_ids=True)
#     valid_graph = graph.edge_subgraph(_valid_eid_dict, relabel_nodes=False, store_ids=True)

#     return train_graph, valid_graph, _train_eid_dict, _valid_eid_dict

# train_graph, valid_graph, _train_eid_dict, _valid_eid_dict = train_valid_split(graph)

In [None]:
#_valid_eid_dict

In [None]:
#graph.find_edges(126293, etype=('article','bought_by_online','customer'))

In [None]:
# Train on every edge of both relations (no hold-out split is made here).
train_eid_dict = {
    #Config.buy_store_edge: torch.arange(graph.num_edges(Config.buy_store_edge)),
    #Config.bought_by_store_edge: torch.arange(graph.num_edges(Config.bought_by_store_edge)),
    #Config.buy_online_edge: torch.arange(graph.num_edges(Config.buy_online_edge)),
    #Config.bought_by_online_edge: torch.arange(graph.num_edges(Config.bought_by_online_edge)),
    Config.buy_edge: torch.arange(graph.num_edges(Config.buy_edge)),
    Config.bought_by_edge: torch.arange(graph.num_edges(Config.bought_by_edge)),
   
}

# Each relation mapped to its mirror relation; handed to the sampler together
# with exclude='reverse_types' below.
reverse_types = {
    #Config.buy_store_edge: Config.bought_by_store_edge,
    #Config.bought_by_store_edge: Config.buy_store_edge,
    #Config.buy_online_edge: Config.bought_by_online_edge,
    #Config.bought_by_online_edge: Config.buy_online_edge
    Config.buy_edge: Config.bought_by_edge,
    Config.bought_by_edge: Config.buy_edge,
}


# Two-hop full-neighbour sampling, matching the two RGCN layers.
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

# Wrap the node sampler for link prediction with a Uniform(1) negative sampler.
# NOTE(review): per the DGL docs, exclude='reverse_types' removes the sampled edges
# and their reverse-type counterparts from message passing -- confirm for the DGL version in use.
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler,
    exclude='reverse_types',
    reverse_etypes=reverse_types,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(1),
    
)

# Mini-batches of 1024 training edges, reshuffled by the loader.
train_dataloader = dgl.dataloading.DataLoader(
    graph, 
    train_eid_dict, 
    sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=2
)

In [None]:
model = train(graph, X_dic, train_dataloader, model)

In [None]:
torch.save(model.to("cpu").state_dict(), Config.model_path)

In [None]:
model.load_state_dict(torch.load(Config.model_path))

In [None]:
emb = inference(graph, X_dic, model)

In [None]:
X_dic[Config.article_node].shape

In [None]:
emb[Config.article_node].shape

In [None]:
graph.num_nodes(Config.article_node)

In [None]:
df_article_node.shape

In [None]:
X_dic[Config.customer_node].shape

In [None]:
graph.num_nodes(Config.customer_node)

In [None]:
emb[Config.customer_node].shape

In [None]:
df_customer_node.shape

In [None]:
def create_emb_dataframe(emb, df_article_node, df_customer_node):
    """Attach the learned embeddings to the article and customer node tables.

    The embedding columns are concatenated column-wise (aligned on the frame
    index) to the right of each node table.

    Returns:
        (df_article_emb, df_customer_emb)
    """

    def _with_embedding(df_node, node_type):
        vectors = pd.DataFrame(emb[node_type].numpy())
        return pd.concat([df_node, vectors], axis=1)

    df_article_emb = _with_embedding(df_article_node, Config.article_node)
    df_customer_emb = _with_embedding(df_customer_node, Config.customer_node)
    return df_article_emb, df_customer_emb

In [None]:
# Join the embeddings with their ids and persist both tables so the submission
# step can be re-run without repeating inference.
df_article_emb, df_customer_emb = create_emb_dataframe(emb, df_article_node, df_customer_node)

df_article_emb.to_pickle("article_emb.pkl")
df_customer_emb.to_pickle("customer_emb.pkl")

In [None]:
df_article_emb.head()

In [None]:
df_customer_emb.head()

In [None]:
df_customer_emb.isna().any()

In [None]:
# class NearestNeighborSearch:
    
#     def __init__(self, n_dim, seed):
#         self.t = AnnoyIndex(n_dim, 'angular')  
#         self.t.set_seed(seed)

#     def create_nearest_neighbor_search(self, df_emb, emb_col, n_trees):                
#         self.df_emb = df_emb

#         for i, v in tqdm.tqdm(enumerate(df_emb[emb_col].values), total=len(df_emb)):
#             self.t.add_item(i, v)

#         self.t.build(n_trees)

#     def get_nerest_negihbor(self, x: np.array, n: int, target_col="article_id"):
#         nn_index_list = self.t.get_nns_by_vector(x, n)
#         return self.df_emb.iloc[nn_index_list, :][target_col].tolist()




# def create_submission(df_submission, df_article_emb, df_customer_emb, n_trees=10):
#     nns = NearestNeighborSearch(Config.out_feat_dim, seed=Config.seed)
#     nns.create_nearest_neighbor_search(df_article_emb, list(range(Config.out_feat_dim)), n_trees)

#     for customer_id, emb in tqdm.tqdm(zip(df_customer_emb["customer_id"], df_customer_emb.loc[:, list(range(Config.out_feat_dim))].values), total=len(df_customer_emb)):        
#         nn_article_list = nns.get_nerest_negihbor(emb, 12)        
#         df_submission.loc[df_submission["customer_id"] == customer_id, "prediciton"] = " ".join([str(x) for x in nn_article_list])
    
#     return df_submission


# use ckdtree for multi processing        

#from joblib import wrap_non_picklable_objects

#@wrap_non_picklable_objects
def _task(customer_id, emb, df_article_id, annoy_path, n_dim):
    """Look up the 12 nearest articles for one customer embedding.

    The Annoy index ('dot' metric, `n_dim` dimensions) is loaded from
    `annoy_path` on every call; the returned item numbers are used as row
    positions into `df_article_id`.

    Returns:
        (customer_id, space-separated article ids)
    """
    index = AnnoyIndex(n_dim, 'dot')
    index.load(annoy_path)
    neighbor_rows = index.get_nns_by_vector(emb, 12)
    article_ids = df_article_id.iloc[neighbor_rows, :]["article_id"].tolist()
    prediction = " ".join(str(article_id) for article_id in article_ids)
    return (customer_id, prediction)

def create_submission_mp(df_article_emb, df_customer_emb, annoy_path="h_and_m.ann", n_trees=10):
    """Recommend articles for every customer via nearest-neighbour search on the embeddings.

    Builds an Annoy index ('dot' metric) over the article embeddings, saves it
    to `annoy_path`, then queries it once per customer through joblib.

    Args:
        df_article_emb: article ids plus embedding columns named 0..out_feat_dim-1.
        df_customer_emb: customer ids plus embedding columns named 0..out_feat_dim-1.
        annoy_path: file the index is written to and re-loaded from by `_task`.
        n_trees: number of trees for the Annoy index.

    Returns:
        DataFrame with columns "customer_id" and "prediction_2".
    """
    index = AnnoyIndex(Config.out_feat_dim, 'dot')
    index.set_seed(Config.seed)

    emb_col = list(range(Config.out_feat_dim))
    article_vectors = df_article_emb[emb_col].values
    # Item number == row position in df_article_emb; _task maps it back with iloc.
    for row_no, vector in tqdm.tqdm(enumerate(article_vectors), total=len(df_article_emb)):
        index.add_item(row_no, vector)

    index.build(n_trees, n_jobs=1)
    index.save(annoy_path)

    df_article_id = df_article_emb[["article_id"]]

    # https://github.com/spotify/annoy/issues/499
    # https://github.com/pavlin-policar/openTSNE/blob/872e8df89d7700bc650e1c2b40a41c0c5a9c1a54/openTSNE/nearest_neighbors.py#L264-L268
    customer_ids = df_customer_emb["customer_id"].tolist()
    customer_vectors = df_customer_emb.loc[:, emb_col].values
    jobs = (
        joblib.delayed(_task)(customer_id, emb, df_article_id, annoy_path, Config.out_feat_dim)
        for customer_id, emb in zip(customer_ids, customer_vectors)
    )
    result_list = joblib.Parallel(n_jobs=-1, verbose=1, require="sharedmem")(jobs)

    collected = [(customer_id, prediction) for customer_id, prediction in tqdm.tqdm(result_list)]

    df_pred = pd.DataFrame(
        {
            "customer_id": [customer_id for customer_id, _ in collected],
            "prediction_2": [prediction for _, prediction in collected],
        }
    )

    return df_pred

#def create_submission()

In [None]:

# Reload the embedding tables written earlier in this notebook.
# NOTE: read_pickle executes arbitrary code from the file -- only load pickles you produced yourself.
df_article_emb = pd.read_pickle("article_emb.pkl")
df_customer_emb = pd.read_pickle("customer_emb.pkl")

In [None]:
df_customer_emb.shape

In [None]:
df_prediction = create_submission_mp(df_article_emb, df_customer_emb)

In [None]:
df_prediction.head()

In [None]:
df_prediction.head(1).T

In [None]:
_df_submission = df_submission.merge(df_prediction, on="customer_id", how="left")

In [None]:
_df_submission.head()

In [None]:
_df_submission[_df_submission.prediction_2.isna()].shape, _df_submission.shape

In [None]:
# Use the model's recommendation where there is one; customers without an
# embedding (prediction_2 is missing after the left merge) keep the prediction
# from the sample submission. Vectorised fillna replaces the row-wise apply,
# which relied on `np.NaN` (removed in NumPy 2), an identity test against NaN
# and positional row indexing.
_df_submission["prediction"] = _df_submission["prediction_2"].fillna(_df_submission["prediction"])
_df_submission.head()

In [None]:
_df_submission = _df_submission[["customer_id", "prediction"]]
_df_submission.head()

In [None]:
_df_submission.to_csv("submission.csv",index=None)