
Revisit Overconfidence for OOD Detection: Reassigned Contrastive Learning with Adaptive Class-dependent Threshold

This repository is the official implementation of Revisit Overconfidence for OOD Detection: Reassigned Contrastive Learning with Adaptive Class-dependent Threshold (NAACL 2022) by Yanan Wu, Keqing He, Yuanmeng Yan, Qixiang Gao, Zhiyuan Zeng, Fujia Zheng, Lulu Zhao, Huixing Jiang, Wei Wu, and Weiran Xu.

Introduction

An OOD detection model based on reassigned contrastive learning with an adaptive class-dependent threshold.

The architecture of the proposed model:

An example of a confused label pair:

Dependencies

We use Anaconda to create the Python environment (substitute your own environment name):

conda create --name <env_name> python=3.6

Install all required libraries:

pip install -r requirements.txt

How to run

  1. Baseline: cross-entropy loss
    sh run.sh 0

  2. Baseline: supervised contrastive loss
    sh run.sh 1

  3. Reassigned contrastive learning
    sh run.sh 2

Parameters

The main parameters to specify:

  • dataset, required, the dataset to use: CLINC_OOD_full, ATIS, SNIPS...
  • mode, optional, the running mode, options:
    • train: only train the model
    • test: only predict with an already trained model
    • both: train the model, then use it to predict
  • detect_method, required when mode is 'test' or 'both', the method used to detect OOD samples, options:
    • energylocal: energy-based prediction with an adaptive class-dependent threshold
  • model_dir, required when mode is 'test', the directory containing the trained model file
  • unseen_classes, the class(es) held out as OOD, e.g. oos
  • confused_pre_epoches, default = 5
  • global_pre_epoches, default = 25
  • ce_pre_epoches, default = 30

The parameters that have default values (in general, these can stay fixed):

  • gpu_device, default = 1
  • output_dir, default = ./outputs
  • embedding_file, default = /glove/glove.6B.300d.txt
  • embedding_dim, default = 300
  • max_seq_len, default = None
  • max_num_words, default = 10000
  • max_epoches, default = 200
  • patience, default = 20
  • batch_size, default = 200
  • seed, default = 2022
  • T, default = 1
  • train_class_num, default = n
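As a rough illustration of how the options above fit together, the following argparse sketch mirrors the parameter list (names and defaults are taken from this README; the actual definitions in train.py may differ, and only a subset of flags is shown):

```python
import argparse

def build_parser():
    # Hypothetical sketch of the options listed above; see train.py for the real definitions.
    p = argparse.ArgumentParser(description="RCL OOD detection (illustrative)")
    p.add_argument("--dataset", required=True,
                   help="CLINC_OOD_full, ATIS, SNIPS, ...")
    p.add_argument("--mode", choices=["train", "test", "both"], default="both")
    p.add_argument("--detect_method", default="energylocal")
    p.add_argument("--model_dir", default=None,
                   help="required when mode is 'test'")
    p.add_argument("--unseen_classes", default="oos")
    p.add_argument("--confused_pre_epoches", type=int, default=5)
    p.add_argument("--global_pre_epoches", type=int, default=25)
    p.add_argument("--ce_pre_epoches", type=int, default=30)
    p.add_argument("--batch_size", type=int, default=200)
    p.add_argument("--seed", type=int, default=2022)
    p.add_argument("--T", type=float, default=1.0)
    return p

args = build_parser().parse_args(["--dataset", "CLINC_OOD_full", "--mode", "both"])
```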

Code introduction by file

1. model.py

  • BiLSTM_LMCL main model
  • LMCL Loss
  • SCL Loss
  • RCL Loss
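For intuition about the contrastive objectives listed above, here is a minimal NumPy sketch of a supervised contrastive loss of the kind SCL (and its reassigned variant RCL) builds on. This is an illustrative re-derivation, not the repository's model.py code:

```python
import numpy as np

def supervised_contrastive_loss(z, labels, T=1.0):
    """Supervised contrastive loss over a batch of features z.

    For each anchor i, positives are the other samples sharing its label;
    the anchor is contrasted against every other sample in the batch.
    """
    z = np.asarray(z, dtype=float)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # work in cosine-similarity space
    sim = z @ z.T / T                                  # temperature-scaled pairwise similarities
    n = len(labels)
    losses = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        pos = [j for j in others if labels[j] == labels[i]]
        if not pos:
            continue  # anchors with no positive pair contribute nothing
        log_denom = np.log(np.sum(np.exp(sim[i, others])))
        # -mean over positives of log( exp(sim) / sum over all others of exp(sim) )
        losses.append(-np.mean([sim[i, p] - log_denom for p in pos]))
    return float(np.mean(losses))
```

As a sanity check, a batch whose same-label samples are identical vectors yields a lower loss than the same features with mismatched labels.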

2. train.py

The main experiment pipeline, including:

  • Data processing
  • Selecting OOD classes randomly or specifying them
  • Training the model (BiLSTM + LMCL)
    • CE
    • SCL
    • RCL
  • Prediction
    • energylocal
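The energy score used for detection is the standard E(x) = -T * logsumexp(logits / T), with lower energy meaning "more in-domain". The sketch below shows energy-based detection with a class-dependent (local) threshold in the spirit of energylocal; the threshold values and function names are made up for illustration and are not the repository's code:

```python
import numpy as np

def energy_score(logits, T=1.0):
    # E(x) = -T * logsumexp(logits / T); lower energy => more in-domain.
    return -T * np.log(np.sum(np.exp(np.asarray(logits) / T), axis=-1))

def detect_ood(logits, class_thresholds, T=1.0, ood_label=-1):
    """Assign each sample its argmax class, or ood_label if its energy
    exceeds the threshold belonging to that predicted class."""
    logits = np.asarray(logits)
    pred = np.argmax(logits, axis=-1)
    energy = energy_score(logits, T)
    thresholds = np.asarray(class_thresholds)[pred]   # class-dependent cut-off
    return np.where(energy > thresholds, ood_label, pred)
```

A confident sample (one large logit) gets very low energy and keeps its predicted class; a flat, low-confidence sample gets high energy and is rejected as OOD whenever it crosses its predicted class's threshold.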

3. utils.py

  • get_score, computes f1, f1_seen, and f1_unseen from the confusion matrix
  • confidence, calculates per-class confidence based on Mahalanobis or Euclidean distance
  • get_test_info, collects prediction details: text, label, softmax probability, softmax prediction, softmax confidence, (if using LOF) the LOF prediction, and (if using GDA) the GDA Mahalanobis distance and GDA confidence
  • log_pred_results, reads the './output_dir/.../result.txt' file, prints the results in the specified format, and converts them to a 'results.json' file
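For intuition, seen/unseen F1 scores can be derived from a confusion matrix along these lines. This is an illustrative sketch, not the exact utils.py implementation; here the OOD class is assumed to be the last row/column:

```python
import numpy as np

def f1_from_confusion(cm, ood_index=-1):
    """Per-class F1 from a confusion matrix (rows = true, cols = predicted),
    plus macro-F1 over seen (IND) classes and F1 of the unseen (OOD) class."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    precision = tp / np.maximum(cm.sum(axis=0), 1e-12)  # column sums = predicted counts
    recall = tp / np.maximum(cm.sum(axis=1), 1e-12)     # row sums = true counts
    f1 = 2 * precision * recall / np.maximum(precision + recall, 1e-12)
    f1_unseen = f1[ood_index]
    f1_seen = float(np.mean(np.delete(f1, ood_index)))  # macro-F1 over IND classes
    return f1, f1_seen, f1_unseen
```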

Results

Each cell group reports OOD F1, OOD Recall, and IND F1 under a global threshold and under our adaptive class-dependent local threshold.

CLINC-Full:

| Method | Global OOD F1 | Global OOD Recall | Global IND F1 | Local (ours) OOD F1 | Local (ours) OOD Recall | Local (ours) IND F1 |
|---|---|---|---|---|---|---|
| MSP + CE | 54.70 | 44.50 | 86.28 | 64.35 | 62.00 | 87.06 |
| MSP + SCL | 56.98 | 46.34 | 87.94 | 65.88 | 64.31 | 88.00 |
| MSP + RCL (ours) | 61.71 | 53.90 | 88.45 | 67.43 | 64.92 | 88.76 |
| GDA + CE | 65.79 | 64.14 | 87.90 | 68.06 | 73.50 | 87.95 |
| GDA + SCL | 68.04 | 66.92 | 88.60 | 70.85 | 73.40 | 88.63 |
| GDA + RCL (ours) | 72.61 | 70.00 | 88.98 | 73.88 | 75.10 | 89.03 |
| Energy + CE | 68.87 | 66.30 | 88.02 | 71.67 | 72.50 | 88.78 |
| Energy + SCL | 71.12 | 71.01 | 88.59 | 73.15 | 73.20 | 88.98 |
| Energy + RCL (ours) | 74.30 | 72.03 | 89.56 | 75.32 | 78.60 | 89.67 |

Snips:

| Method | Global OOD F1 | Global OOD Recall | Global IND F1 | Local (ours) OOD F1 | Local (ours) OOD Recall | Local (ours) IND F1 |
|---|---|---|---|---|---|---|
| MSP + CE | 74.39 | 80.19 | 88.37 | 78.03 | 82.46 | 90.21 |
| MSP + SCL | 80.57 | 89.60 | 79.25 | 78.03 | 82.57 | 91.30 |
| MSP + RCL (ours) | 81.00 | 81.52 | 91.71 | 83.53 | 82.94 | 93.28 |
| GDA + CE | 77.33 | 79.23 | 90.08 | 81.96 | 84.52 | 90.11 |
| GDA + SCL | 80.27 | 82.46 | 91.19 | 83.45 | 87.20 | 92.58 |
| GDA + RCL (ours) | 85.24 | 86.95 | 93.89 | 87.91 | 88.57 | 94.65 |
| Energy + CE | 78.75 | 79.27 | 91.00 | 82.65 | 84.70 | 92.58 |
| Energy + SCL | 81.72 | 81.99 | 91.27 | 85.04 | 85.83 | 95.42 |
| Energy + RCL (ours) | 86.41 | 87.16 | 94.40 | 89.21 | 89.45 | 95.42 |

Citation

@inproceedings{wu-etal-2022-revisit, 
    title = "Revisit Overconfidence for {OOD} Detection: Reassigned Contrastive Learning with Adaptive Class-dependent Threshold", 
    author = "Wu, Yanan  and
      He, Keqing  and
      Yan, Yuanmeng  and
      Gao, Qixiang  and
      Zeng, Zhiyuan  and
      Zheng, Fujia  and
      Zhao, Lulu  and
      Jiang, Huixing  and
      Wu, Wei  and
      Xu, Weiran",
    booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies",
    month = jul,
    year = "2022",
    address = "Seattle, United States",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2022.naacl-main.307",
    doi = "10.18653/v1/2022.naacl-main.307",
    pages = "4165--4179",
    abstract = "Detecting Out-of-Domain (OOD) or unknown intents from user queries is essential in a task-oriented dialog system. A key challenge of OOD detection is the overconfidence of neural models. In this paper, we comprehensively analyze overconfidence and classify it into two perspectives: over-confident OOD and in-domain (IND). Then according to intrinsic reasons, we respectively propose a novel reassigned contrastive learning (RCL) to discriminate IND intents for over-confident OOD and an adaptive class-dependent local threshold mechanism to separate similar IND and OOD intents for over-confident IND. Experiments and analyses show the effectiveness of our proposed method for both aspects of overconfidence issues.",
}
