In [1]:
import os

from fastai.vision.all import *

from tqdm.notebook import tqdm


In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

In [3]:
from shopee_utils import *
from train_utils import *

In [4]:
import debugpy
# Development aid: open a debug-adapter endpoint so an IDE can attach to this kernel.
# The address is spelled out in full; it is the same loopback endpoint the port-only form binds to.
debugpy.listen(('127.0.0.1', 5678))

('127.0.0.1', 5678)

In [5]:
# Tokenizer of the Helsinki-NLP opus-mt-id-en checkpoint, used by the translation and exploration cells below.
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-id-en')

In [6]:
# Seq2seq model from the same checkpoint as the tokenizer; used below for generation and hidden-state inspection.
txt_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-id-en')

In [7]:
# Move the model to the GPU and put it in eval mode before the generation cells run.
txt_model=txt_model.cuda().eval()


In [4]:
# Dataset root. Defaults to the original local directory, but can be redirected with the
# SHOPEE_DATA_DIR environment variable so the notebook also runs on other machines.
PATH = Path(os.environ.get('SHOPEE_DATA_DIR', '/home/slex/data/shopee'))
# Training split; later cells read its `title` and `label_group` columns.
train_df = pd.read_csv(PATH/'train_split.csv')

In [9]:
# Translate every title in `train_df` in fixed-size batches and collect the decoded strings.
trans_texts = []
CHUNK = 50

print('translating texts')
# Number of batches: ceil(len(train_df) / CHUNK).
CTS = (len(train_df) + CHUNK - 1)//CHUNK
for j in tqdm(range(CTS), total=CTS):
    a = j*CHUNK
    b = min(a + CHUNK, len(train_df))
    batch = tokenizer(list(train_df.iloc[a:b].title.values), return_tensors="pt", truncation=True, padding=True)
    input_ids = batch.input_ids.cuda()
    # Titles in a batch are padded to a common length; pass the mask explicitly so the
    # padding positions are never attended to, instead of relying on generate() inferring it.
    attention_mask = batch.attention_mask.cuda()
    outputs = txt_model.generate(input_ids=input_ids, attention_mask=attention_mask, num_return_sequences=1)
    val = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    trans_texts.extend(val)

# One translation per input row is required by the column assignment that follows.
assert len(trans_texts) == len(train_df), "number of translations does not match number of rows"

translating texts


  0%|          | 0/685 [00:00<?, ?it/s]

In [11]:
# Sanity check: how many translated titles the loop collected.
len(trans_texts)

34250

In [13]:
# Attach the translations as a new `en_title` column (mutates train_df in place).
train_df['en_title']=trans_texts

In [9]:
# Persist the frame, including the translated column, under the data directory; the row index is not written.
train_df.to_csv(PATH/'train_trans.csv', index=False)

In [19]:

# Reload the saved translations from disk and show the original and translated titles side by side.
trans_df = pd.read_csv(PATH.joinpath('train_trans.csv'))
trans_df.loc[:, ['title', 'en_title']]

Unnamed: 0,title,en_title
0,Paper Bag Victoria Secret,Paper Bag Victoria Secret
1,"Double Tape 3M VHB 12 mm x 4,5 m ORIGINAL / DOUBLE FOAM TAPE",Double Tape 3M VHB 12 mm x 4.5 m ORIGINAL / DOUBLE FOAM TAPE
2,Maling TTS Canned Pork Luncheon Meat 397 gr,TTS Thief Canned Pork Luncheon Meat 397 gr
3,Daster Batik Lengan pendek - Motif Acak / Campur - Leher Kancing (DPT001-00) Batik karakter Alhadi,Batik Short Arm - Random Motive / Mixed - Snap Neck (DPT001-00) Batik character Alhadi
4,Nescafe \xc3\x89clair Latte 220ml,Nescafe \xc3\x89clair Latte 220ml
...,...,...
34245,Masker Bahan Kain Spunbond Non Woven 75 gsm 3 ply lapis Bisa Dicuci,Spunbond Material Mask Non Woven 75 gsm 3 ply layer washable
34246,MamyPoko Pants Royal Soft - S 70 - Popok Celana,MamyPoko Pants Royal Soft - S 70 - Popok Pants
34247,KHANZAACC Robot RE101S 1.2mm Subwoofer Bass Metal Wired Headset,KHANZAACC Robot RE101S 1.2mm Subwoofer Bass Metal Wired Headset
34248,"Kaldu NON MSG HALAL Mama Kamu Ayam Kampung , Sapi Lokal, Jamur (Bkn Alsultan / Biocell)","Your Mama's NON MSG Hall Chicken Village, Local Cows, Mushrooms."


In [10]:
# Tokenize the first four titles into one padded batch on the GPU for the hidden-state experiments below.
input_ids = tokenizer(list(train_df.iloc[0:4].title.values), return_tensors="pt", truncation=True, padding=True).input_ids.cuda()

In [11]:
# Shape of the sample batch: (number of titles, padded token length).
input_ids.shape

torch.Size([4, 32])

In [17]:
# Run a deliberately short, single-beam generation (max_length=2) whose only purpose is to
# return the encoder/decoder hidden states in a dict-style output for inspection.
outputs = txt_model.generate(
    input_ids,
    output_hidden_states=True,
    num_beam_groups=1,
    num_beams=1,
    return_dict_in_generate=True,
    max_length=2,
)

In [25]:
# How many encoder hidden-state tensors the model returned.
len(outputs.encoder_hidden_states)

7

In [37]:
# Average every encoder hidden state over dim 1 and concatenate the results along dim 1.
# Keep the tensor itself in `enc_state` (previously the name was bound to its shape) and
# display the shape as the cell output.
enc_state = torch.cat([t.mean(dim=1) for t in outputs.encoder_hidden_states], dim=1)
enc_state.shape


In [41]:
# Concatenate the tensors in the first entry of decoder_hidden_states along dim 2,
# drop singleton dims, and check the resulting size.
torch.cat(outputs.decoder_hidden_states[0], dim=2).squeeze().shape

torch.Size([4, 3584])

In [5]:
class ArcFaceClassifier(nn.Module):
    """Cosine-similarity classification head.

    Inputs pass through batch norm and dropout, are L2-normalised per row, and are then
    multiplied by a weight matrix whose columns (one per class) are L2-normalised too.
    The result is a (batch, output_classes) matrix of cosine similarities.
    """

    def __init__(self, in_features, output_classes):
        super().__init__()
        self.initial_layers = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Dropout(.25),
        )
        # One column of weights per class, Kaiming-uniform initialised.
        class_weights = torch.empty(in_features, output_classes)
        nn.init.kaiming_uniform_(class_weights)
        self.W = nn.Parameter(class_weights)

    def forward(self, x):
        """Return the cosine similarity between each input row and each class column."""
        feats = F.normalize(self.initial_layers(x))
        class_dirs = F.normalize(self.W, dim=0)
        return feats @ class_dirs
    


In [6]:
class TitleTransform(Transform):
    """Item transform that tokenizes a row's `title` with the opus-mt-id-en tokenizer.

    `encodes` returns a 3-tuple: token ids, attention mask (both padded/truncated to
    50 tokens) and a one-element tensor holding the constant id 54795.
    """

    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-id-en')

    def encodes(self, row):
        encoded = self.tokenizer(
            row.title,
            padding='max_length',
            max_length=50,
            truncation=True,
            return_tensors='pt',
        )
        # The tokenizer output is squeezed so each item is a flat tensor rather than a batch of one.
        token_ids = encoded['input_ids'].squeeze()
        mask = encoded['attention_mask'].squeeze()
        # SillyModel.forward feeds this as `decoder_input_ids`.
        # NOTE(review): 54795 is presumably the model's decoder start / pad token id — confirm against the model config.
        start_token = torch.tensor(54795).view(1)
        return token_ids, mask, start_token

In [7]:
# Build the DataLoaders: x is the tokenized title tuple produced by TitleTransform,
# y is the `label_group` category.
tfm = TitleTransform()

title_block = TransformBlock(type_tfms=tfm)
# The category vocabulary is built from the label_group values of the full frame.
label_block = CategoryBlock(vocab=train_df.label_group.to_list())

data_block = DataBlock(
    blocks=(title_block, label_block),
    # Train/validation split comes from ColSplitter's default column in train_df.
    splitter=ColSplitter(),
    get_y=ColReader('label_group'),
)
dls = data_block.dataloaders(train_df, bs=256, num_workers=16)


In [40]:
class SillyModel(nn.Module):
    """Title embedding model built on a seq2seq translation network.

    The first-position encoder output and the concatenated decoder hidden states are
    joined, projected to 1024 dimensions by `pooler`, and classified by an ArcFace head.

    Args:
        text_model: seq2seq model called with `input_ids`, `attention_mask` and
            `decoder_input_ids`; must return `encoder_last_hidden_state` and
            `decoder_hidden_states`.
        n_classes: number of output classes. Defaults to `dls.c` from the notebook's
            global DataLoaders when omitted.

    Set `outputEmbs = True` to make `forward` return the pooled embeddings instead of
    the classifier output.
    """
    def __init__(self, text_model, n_classes=None):
        super().__init__()
        self.text_model = text_model
        # 4096 must equal the width of [encoder slice | concatenated decoder hidden states].
        # The exploration cell reports 3584 for the decoder part, which leaves 512 for the encoder slice.
        embs_dim1, embs_dim2 = 4096, 1024
        self.pooler = nn.Sequential(
            nn.Linear(embs_dim1, embs_dim2),
            nn.BatchNorm1d(embs_dim2),
            nn.Dropout())
        if n_classes is None:
            n_classes = dls.c
        self.classifier = ArcFaceClassifier(embs_dim2, n_classes)
        self.outputEmbs = False
    def forward(self, x):
        """Return class similarities, or pooled embeddings when `outputEmbs` is set.

        `x` is the (input_ids, attention_mask, decoder_input_ids) tuple produced by TitleTransform.
        """
        input_ids, attention_mask, decoder_input_ids = x
        outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask,
                                  decoder_input_ids=decoder_input_ids, output_hidden_states=True)
        # Encoder summary: the output at the first token position.
        enc_states = outputs['encoder_last_hidden_state'][:,0,:]
        # The decoder sees a single token, so dim 1 has length 1. Squeeze only that axis:
        # a bare squeeze() would also drop the batch axis for a batch of one and break the cat below.
        dec_states = torch.cat(outputs.decoder_hidden_states, dim=2).squeeze(1)
        embeddings = torch.cat([enc_states, dec_states], dim=1)
        embeddings = self.pooler(embeddings)
        if self.outputEmbs:
            return embeddings
        return self.classifier(embeddings)

In [9]:
# Grab a single (inputs, targets) batch from the DataLoaders for interactive inspection.
bx, by = dls.one_batch()

In [41]:
def new_model():
    """Load fresh opus-mt-id-en weights, wrap them in a SillyModel and move it to the GPU."""
    backbone = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-id-en')
    model = SillyModel(backbone)
    return model.cuda()

In [42]:
def split_2way(model):
    """Split parameters into two groups: the text backbone, then classifier + pooler."""
    backbone_params = params(model.text_model)
    head_params = params(model.classifier) + params(model.pooler)
    return L(backbone_params, head_params)

In [43]:
# ArcFace loss with margin m=.5; the Learner uses the two-group splitter defined above
# and reports the embedding-based F1 computed by the F1FromEmbs callback.
loss_func = functools.partial(arcface_loss, m=.5)
learn = Learner(
    dls,
    new_model(),
    splitter=split_2way,
    loss_func=loss_func,
    cbs=F1FromEmbs,
    metrics=FakeMetric(),
    train_bn=True,
)

In [44]:
# Two epochs with the first parameter group frozen, then 20 epochs with everything trainable.
learn.fine_tune(epochs=20, base_lr=1e-2, freeze_epochs=2, lr_mult=10)

epoch,train_loss,valid_loss,F1 embeddings,time
0,22.795553,,0.565353,00:13
1,20.683207,,0.647583,00:13


epoch,train_loss,valid_loss,F1 embeddings,time
0,18.402208,,0.681869,00:20
1,17.481159,,0.697676,00:20
2,16.569492,,0.720971,00:21
3,15.624162,,0.73355,00:20
4,14.886375,,0.748251,00:20
5,14.215281,,0.755205,00:21
6,13.530007,,0.760552,00:20
7,12.908889,,0.761603,00:20
8,12.341012,,0.761122,00:20
9,11.792464,,0.765182,00:21
