Skip to content

天池阿里灵杰问天引擎电商搜索算法赛 pytorch baseline,非官方。附从 0.05 到 0.26 分的 trick。

Notifications You must be signed in to change notification settings

yuanjie-ai/E-commerce-Search-Recall

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

11 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

电商搜索召回

一个毫无 NLP 经验的人的比赛(挖坑填坑)之旅。

  1. 实现 DSSM baseline,直接优化距离结果很差,得分 0.057
  2. 实现 CoSENT,余弦距离得分 0.159
  3. 实现 SimCSE,得分 0.227

tools 里面是精度转换和结果文件检查。

Trick

  1. 在 model.py 中使用 first-last-avg 融合大概从 0.22 提升到 0.245 左右。

    Details
    def forward(self, input_ids, attention_mask, token_type_ids):
        out = self.extractor(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             output_hidden_states=True)
    
        first = out.hidden_states[1].transpose(1, 2)
        last = out.hidden_states[-1].transpose(1, 2)
        first_avg = torch.avg_pool1d(
            first, kernel_size=last.shape[-1]).squeeze(-1)  # [batch, 768]
        last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(
            -1)  # [batch, 768]
        avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)),
                        dim=1)  # [batch, 2, 768]
        out = torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)
        x = self.fc(out)
        x = F.normalize(x, p=2, dim=-1)
        return x
  2. unilm 文件夹下,进行 UniLM 预训练,大概 0.265 左右,损失在 1.3x 左右。预训练模型下载: YunwenTechnology/Unilm

参考

致谢

本仓库中的工作得到西安电子科技大学高性能计算校级公共平台的支持. Supported by High-performance Computing Platform of XiDian University.

About

天池阿里灵杰问天引擎电商搜索算法赛 pytorch baseline,非官方。附从 0.05 到 0.26 分的 trick。

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%