# Import Packages

In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')

In [None]:
import os
import cv2
import math
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import timm
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset,DataLoader

import gc
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer,CountVectorizer
from sklearn.preprocessing import normalize
from sklearn.preprocessing import LabelEncoder

In [None]:
import transformers
from transformers import (BertTokenizer, BertModel,
                          DistilBertTokenizer, DistilBertModel)
from transformers import AutoTokenizer, AutoModel, BertTokenizer
from collections import OrderedDict
from typing import Tuple, Union, Dict

# Config

In [None]:
# Directory of the pretrained DistilBERT checkpoint; handed to
# `from_pretrained` for both the encoder and its tokenizer.
TEXT_MODEL = "../input/distilbert-base-indonesian"
MAX_LEN = 32 # Default `max_length` passed to the tokenizer (titles are truncated to it)
# Not referenced by the visible code; presumably the encoder hidden size -- TODO confirm.
EMBED_DIM = 768

In [None]:
# Input locations and the output directory used by this notebook.
DATA_DIR = '../input/shopee-product-matching/train_images'
TRAIN_CSV = '../input/shopee-product-matching/train.csv'
MODEL_PATH = './'


class CFG:
    """Run configuration, read (and in debug mode overwritten) as class attributes."""

    debug = False  # when True, `df` is truncated and `epochs` is set to 1
    cv = True
    divide_fold = False
    epochs = 6
    seed = 54  # passed to seed_torch
    batch_size = 20  # DataLoader batch size
    classes = 11014  # presumably the number of label groups -- TODO confirm
    # Use the GPU when torch can see one, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Checkpoint whose weights are loaded into TextModel.
    text_model_path = '../input/shopee-pytorch-models/arcface_distilbert_model_512.pt'


# Module-level alias kept for code that reads `device` directly. It is derived
# from CFG so the two values can never disagree.
device = CFG.device

In [None]:
# Load the training table from TRAIN_CSV; later cells read its 'title' column.
df = pd.read_csv(TRAIN_CSV)
# Disabled preprocessing, kept for reference: a per-title word count and an
# integer encoding of 'label_group'.
# df['length'] = df['title'].apply(lambda x: len(x.split()))
# labelencoder= LabelEncoder()
# df['label_group'] = labelencoder.fit_transform(df['label_group'])
print(df.shape)
# Last expression of the cell: shows a preview of the first rows in the notebook.
df.head()

In [None]:
# Debug mode: shrink the data and the schedule so the notebook runs end to end quickly.
if CFG.debug:
    # Five full batches plus three extra rows (presumably so that a final,
    # partial batch is exercised too).
    df = df.head(5 * CFG.batch_size + 3)
    CFG.epochs = 1

# Utils

In [None]:
def seed_torch(seed=42):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to Python's `random`, NumPy and PyTorch
            (CPU generator and all CUDA devices).
    """
    random.seed(seed)
    # Read by Python at interpreter start-up, so this only influences
    # processes launched after this point, not the running one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU rather than only the current one; this is a
    # no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels are only reproducible if cuDNN is also kept
    # from auto-tuning (benchmarking) a different algorithm on each run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# Seed the RNGs once, with the configured seed, before any model or data code runs.
seed_torch(CFG.seed)

# Text Model

In [None]:
class TextModel(nn.Module):
    """DistilBERT wrapper that returns the encoder's per-token hidden states."""

    def __init__(self):
        super().__init__()
        # The attribute name is part of every state-dict key, so it must stay
        # `bert_model` for checkpoints to line up.
        self.bert_model = DistilBertModel.from_pretrained(TEXT_MODEL)

    def get_bert_features(self, batch):
        """Run the encoder on a tokenised batch.

        Args:
            batch: Mapping providing 'input_ids' and 'attention_mask'.

        Returns:
            The encoder output's `last_hidden_state` for every token; no
            first-token (CLS) slicing is applied. Expected shape is
            (batch_size, seq_length, hidden_dim).
        """
        encoder_out = self.bert_model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
        )
        return encoder_out.last_hidden_state

    def forward(self, batch):
        """Delegate to `get_bert_features` so the module can be called directly."""
        return self.get_bert_features(batch)

# Data

In [None]:
class ShopeeTextDataset(torch.utils.data.Dataset):
    """Product titles tokenised once up front and served row by row.

    The constructor tokenises every title in a single call; `__getitem__`
    only wraps the pre-computed token lists for one row in tensors.
    """

    def __init__(self, df, tokenizer=None, max_length=MAX_LEN):
        """Tokenise all titles in `df`.

        Args:
            df: DataFrame with a 'title' column; each value is coerced to `str`.
            tokenizer: Callable tokenizer returning a mapping from field name
                to per-row token lists. When omitted, the DistilBERT tokenizer
                for `TEXT_MODEL` is loaded here -- at call time rather than at
                class-definition time, so defining the class does no I/O.
            max_length: Truncation length forwarded to the tokenizer.
        """
        if tokenizer is None:
            tokenizer = DistilBertTokenizer.from_pretrained(TEXT_MODEL)
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        titles = [str(title) for title in df['title']]
        # One call over the whole table: with padding=True every row is padded
        # to the longest title, so rows can be stacked into batches later.
        self.encodings = tokenizer(
            titles,
            padding=True,
            truncation=True,
            max_length=max_length,
        )

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the encoded fields of row `idx` as a dict of tensors."""
        return {
            key: torch.tensor(values[idx])
            for key, values in self.encodings.items()
        }

# Inference: Extract Text Embeddings

In [None]:
# Build the encoder in eval mode and restore the fine-tuned weights.
model = TextModel()
model.eval()
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
state_dict = torch.load(CFG.text_model_path, map_location=torch.device('cpu'))
# strict=False tolerates key mismatches, which can silently leave layers at
# their pretrained values; report what was skipped instead of hiding it.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'load_state_dict(strict=False): '
          f'{len(load_result.missing_keys)} missing and '
          f'{len(load_result.unexpected_keys)} unexpected keys')
del state_dict
model = model.to(CFG.device)

# No shuffling is requested and drop_last is False, so the i-th embedding
# corresponds to the i-th row of `df`.
text_dataset = ShopeeTextDataset(df)
text_loader = torch.utils.data.DataLoader(
    text_dataset,
    batch_size=CFG.batch_size,
    pin_memory=True,
    drop_last=False,
    num_workers=2
)

embeds = []
with torch.no_grad():
    for data in tqdm(text_loader):
        # Move the batch to the same device the model was placed on.
        data = {k: v.to(CFG.device) for k, v in data.items()}
        features = model.get_bert_features(data)
        # Keep the accumulated per-token features in float16.
        embeds.append(features.half())

# `model` and `embeds` are deliberately left alive after this point.
text_embeddings = torch.cat(embeds, dim=0)
print(f'Our text embeddings shape is {text_embeddings.shape}')
torch.save(text_embeddings, 'text_embeddings.pt')