In [145]:
!pip install transformers==3.3.0
!pip install datasets



# New Section

In [146]:
import logging
import os
import argparse

import numpy as np
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

In [147]:
from sfda.models import sfdaTargetRobertaNegation
from sfda.trainer import sfdaTrainer
from sfda.DataProcessor import NegationDataset

In [148]:
logger = logging.getLogger(__name__)

In [149]:

# Text file fed to load_dataset(), and the directory used as output_dir for the SFDA run.
data_file = "practice_text/train.tsv"
output_dir = "/outputs/negation/model/"

In [150]:
# Identifier of the pretrained checkpoint; the config, tokenizer and model
# below are all loaded from this same name.
model_name = "tmills/roberta_sfda_sharpseed"# Base Model
# Load the checkpoint's configuration first so it can be passed to the tokenizer (and later the model).
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
 

In [151]:
# Build the SFDA target model from the source checkpoint through the project's
# `from_pretrained_source` constructor, reusing the config loaded above.
model = sfdaTargetRobertaNegation.from_pretrained_source(model_name,config=config)

Some weights of the model checkpoint at tmills/roberta_sfda_sharpseed were not used when initializing sfdaTargetRobertaNegation: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']



In [152]:
from typing import Callable, Dict, Optional, List, Union
from transformers.data.metrics import acc_and_f1
from transformers import EvalPrediction
def build_compute_metrics_fn() -> Callable[[EvalPrediction], Dict]:
    """Create the metrics callback that is handed to the trainers.

    Returns:
        A function taking an ``EvalPrediction``. It picks the arg-max class
        along axis 1 of ``predictions`` and returns the result of
        ``acc_and_f1`` for those classes against ``label_ids``.
    """

    def _score(p: EvalPrediction):
        predicted_classes = np.argmax(p.predictions, axis=1)
        gold_labels = p.label_ids
        return acc_and_f1(predicted_classes, gold_labels)

    return _score

In [153]:
from datasets import load_dataset
# Load the raw training file with the "text" dataset builder; the resulting
# rows are read through their "text" column by tokenize_function.
dataset = load_dataset("text", data_files=data_file)

Using custom data configuration default
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-b355bebfede2ba65/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)


In [154]:
def tokenize_function(examples):
    """Tokenize the ``"text"`` entry of ``examples`` with the notebook-level tokenizer.

    The special-tokens mask is requested from the tokenizer and returned
    together with its other outputs.
    """
    raw_text = examples["text"]
    return tokenizer(raw_text, return_special_tokens_mask=True)


In [155]:
# Apply tokenize_function to the dataset in batches (no extra worker processes).
tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=None)

Loading cached processed dataset at /root/.cache/huggingface/datasets/text/default-b355bebfede2ba65/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-a7cb13c62ec18b77.arrow


In [156]:
# Language-modeling collator built on the shared tokenizer (mlm_probability = 0.3).
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.3,
)

# Trainer for the language-modeling pass over the tokenized "train" split.
trainer_mlm = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="save_run/",
        learning_rate=5e-6,
    ),
    train_dataset=tokenized_dataset["train"],
    data_collator=data_collator,
    compute_metrics=build_compute_metrics_fn(),
)

In [157]:
# Run the language-modeling training pass with trainer_mlm; the returned
# TrainOutput is shown as the cell result.
trainer_mlm.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=361.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=361.0, style=ProgressStyle(description_wi…

{'loss': 9.5276591796875, 'learning_rate': 2.6915974145891046e-06, 'epoch': 1.3850415512465375, 'total_flos': 269090336422188, 'step': 500}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=361.0, style=ProgressStyle(description_wi…

{'loss': 8.3948779296875, 'learning_rate': 3.831948291782087e-07, 'epoch': 2.770083102493075, 'total_flos': 540047943878400, 'step': 1000}




TrainOutput(global_step=1083, training_loss=8.907064252943213)

In [136]:
import dataclasses
import logging

import os
import argparse
from os.path import basename, dirname
import sys
from dataclasses import dataclass, field
import numpy as np
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
    HfArgumentParser,
    EvalPrediction,
)
from typing import Callable, Dict, Optional, List, Union
logger = logging.getLogger(__name__)
from transformers.data.metrics import acc_and_f1

from enum import Enum
from sfda.models import sfdaTargetRobertaNegation
from sfda.trainer import sfdaTrainer
from sfda.DataProcessor import sfdaNegationDataset
from sfda.DataProcessor import NegationDataset

In [158]:
@dataclass
class sfdaTrainingArguments:
    """Hyper-parameters for the SFDA-specific parts of training.

    ``sfdaTrainer.__init__`` reads every field of this class: ``APM_Strategy``,
    ``top_k``, ``cf_ratio``, ``update_freq`` and ``alpha_routine``.
    """

    # Forwarded to APM_update as ``flag``.
    APM_Strategy: str = field(
        default="top_k", metadata={"help": "APM update strategy, use top_k for updating APM with top_k from each label and thresh for specifying it with a threshold score."}
    )
    # Forwarded to APM_update as ``k``.
    top_k: int = field(
        default=100, metadata={"help": "[For top_k APM update strategy], the number of prototypes extracted for each label"}
    )
    # Forwarded both to APM_update and to the model's forward call.
    cf_ratio: float = field(
        default=100.0, metadata={"help": "The minimum ratio of min similarity  of  the closest class to the max similarity point of the farthest class to be eligible for consideration as High Confidence point"}
    )
    # Prototypes are rebuilt whenever the global step is a multiple of this value.
    update_freq: int = field(
        default = 100,
        metadata={"help": "The number of global steps after which  APM prototypes are updated "}
    )
    # Selects which alpha schedule sfdaTrainer uses; the help text lists every
    # routine the trainer accepts, including "sin".
    alpha_routine: str = field(
        default="exp", metadata={"help": "The alpha update strategy. Choose from \"exp\": Exponential routine, \"sqr\": Square routine, \"lin\": Linear routine, \"cube\": Cube routine, \"sin\": Sine routine"}
    )


In [159]:
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    Each ``*_file`` / ``*_pred`` pair is passed together to
    ``sfdaNegationDataset.from_tsv`` to build one dataset.
    """

    # Only allowed task is Negation, don't need this field from Glue
    #task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
    train_file: str = field(
        default = "practice_text/train.tsv", 
        metadata={"help": "The input train file"}
    )
    train_pred: str = field(
        default = "practice_text/train_pred.tsv",
        metadata={"help": "A file containing the generated pseudo labels for the train file "}
    )
    eval_file:str = field(
        default = "practice_text/dev.tsv",
        metadata={"help": "A file to evaluate on."}
    )
    # The help text used to repeat the eval_file description; this field is the
    # label file that accompanies eval_file.
    eval_pred:str = field(
        default = "practice_text/dev_labels.txt",
        metadata={"help": "A file containing the labels for the eval file."}
    )
    # NOTE(review): max_seq_length and overwrite_cache are not read by any
    # cell visible in this notebook - confirm whether the dataset code uses them.
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )


In [160]:
@dataclass
class ModelArguments:
    """
    Settings naming the source model, its optional config/tokenizer overrides,
    the download cache and the location the model is saved to.
    """

    # Source checkpoint: a local path or a huggingface.co model identifier.
    src_model_name_or_pth: str = field(
        default="tmills/roberta_sfda_sharpseed",
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"),
    )
    # Defaults to the notebook-level ``output_dir`` at class-definition time.
    model_save_path: str = field(
        default=output_dir,
        metadata=dict(help="Save path for model"),
    )
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"),
    )
    

In [140]:
# Instantiate the four argument groups. Only the Hugging Face TrainingArguments
# receive explicit values (output directory and learning rate); the other
# three use their declared defaults.
model_args = ModelArguments()
data_args = DataTrainingArguments()
training_args = TrainingArguments(output_dir=output_dir, learning_rate=5e-5)
sfda_args = sfdaTrainingArguments()

In [141]:
from transformers.trainer_utils import nested_concat,nested_numpify,Any
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn as nn
import torch
from transformers.file_utils import is_torch_tpu_available
from typing import NamedTuple, Union,Tuple,Optional,Dict
import numpy as np
import logging
from transformers import Trainer ,EvalPrediction
from sfda.APM import APM_update
from tqdm.auto import tqdm, trange

logger = logging.getLogger(__name__)
# Result tuple returned by sfdaTrainer.prediction_loop / sfdaTrainer.predict:
# predictions, label ids, metrics dict and the collected feature matrix.
sfdaPredictionOutput = NamedTuple(
    "sfdaPredictionOutput",
    [
        ("predictions", Union[np.ndarray, Tuple[np.ndarray]]),
        ("label_ids", Optional[np.ndarray]),
        ("metrics", Optional[Dict[str, float]]),
        ("feat_matrix", Optional[np.ndarray]),
    ],
)

class sfdaTrainer(Trainer):
    """``Trainer`` subclass that adds the SFDA-specific training behaviour.

    On top of the base ``Trainer`` it

    * rebuilds the prototypes ``prototype_p`` / ``prototype_f`` with
      ``APM_update`` whenever the global step is a multiple of ``update_freq``,
    * recomputes the weight ``alpha`` from the global step before every
      training step, and
    * combines the two losses returned by the model as
      ``(1 - alpha) * s2t_loss + alpha * t_loss``.

    The prediction loop additionally collects ``last_hidden_state[:, 0, :]``
    for every batch and returns it as ``feat_matrix``.
    """

    def __init__(
        self,
        sfda_args=None,
        **kwargs,
    ):
        """Set up the base ``Trainer`` and the SFDA state.

        Args:
            sfda_args: Object exposing ``update_freq``, ``APM_Strategy``,
                ``top_k``, ``cf_ratio`` and ``alpha_routine`` (for example a
                ``sfdaTrainingArguments`` instance). When ``None`` only a
                warning is logged and no SFDA state is initialised.
            **kwargs: Forwarded unchanged to ``Trainer.__init__``.

        Raises:
            ValueError: If ``sfda_args.alpha_routine`` is not one of
                ``exp``, ``sqr``, ``lin``, ``cube`` or ``sin`` (case-insensitive).
        """
        super(sfdaTrainer, self).__init__(**kwargs)

        if sfda_args is not None:
            self.prototype_p, self.prototype_f = None, None
            self.update_freq = sfda_args.update_freq
            self.last_update_epoch = 0
            # Builtin float replaces the removed ``np.float`` alias.
            self.alpha = float(0)
            self.APM_Strategy = sfda_args.APM_Strategy
            self.top_k = sfda_args.top_k
            self.cf_ratio = sfda_args.cf_ratio

            alpha_routines = {
                "exp": self._update_alpha_exp,
                "sqr": self._update_alpha_sqr,
                "lin": self._update_alpha_lin,
                "cube": self._update_alpha_cube,
                "sin": self._update_alpha_sin,
            }
            routine = sfda_args.alpha_routine.lower()
            if routine not in alpha_routines:
                # Raising a plain string is itself a TypeError in Python 3, so
                # raise a real exception carrying the same message.
                raise ValueError(F"Invalid alpha routine {sfda_args.alpha_routine}")
            self._update_alpha = alpha_routines[routine]
        else:
            # NOTE(review): training_step and compute_loss read attributes that
            # are only set in the branch above (``_update_alpha``,
            # ``update_freq``, ``prototype_p`` ...), so training without
            # sfda_args does not look supported as written - confirm.
            logger.warning("sfda_args not initialised : Only classifier_t will be used for training and inference!!!")

    def _update_prototypes(self):
        """Rebuild ``prototype_p`` / ``prototype_f`` from a pass over the training data.

        Runs ``prediction_loop`` on the training dataloader with
        ``ret_feats=True`` and hands the result to ``APM_update``.
        """
        train_output = self.prediction_loop(
            self.get_train_dataloader(),
            description=F"APM Update @Global step {self.global_step}",
            ret_feats=True,
        )
        self.prototype_p, self.prototype_f, _ = APM_update(
            train_output, flag=self.APM_Strategy, k=self.top_k, cf_ratio=self.cf_ratio
        )

    def _schedule_horizon(self):
        """Return ``num_train_epochs * len(train_dataset) // train_batch_size``.

        This is the step count every alpha schedule normalises the global step by.
        """
        return self.args.num_train_epochs * len(self.train_dataset) // self.args.train_batch_size

    def _update_alpha_exp(self):
        """Exponential schedule: ``2 / (1 + exp(-10 * step / half_horizon)) - 1``."""
        half_horizon = float((self._schedule_horizon() + 1) // 2)
        self.alpha = float(2.0 / (1.0 + np.exp(-10 * self.global_step / half_horizon)) - 1.0)

    def _update_alpha_sin(self):
        """Sine schedule: ``sin(pi / 2 * step / half_horizon)``."""
        half_horizon = float((self._schedule_horizon() + 1) // 2)
        self.alpha = np.sin(0.5 * np.pi * float(self.global_step / half_horizon))

    def _update_alpha_sqr(self):
        """Square schedule: ``(step / horizon) ** 2``."""
        self.alpha = float((self.global_step / float(self._schedule_horizon())) ** 2)

    def _update_alpha_lin(self):
        """Linear schedule: ``step / horizon``."""
        self.alpha = float(self.global_step / float(self._schedule_horizon()))

    def _update_alpha_cube(self):
        """Cube schedule: ``(step / horizon) ** 3``."""
        self.alpha = float((self.global_step / float(self._schedule_horizon())) ** 3)

    def predict(self, test_dataset: Dataset, ret_feats: Optional[bool] = None) -> sfdaPredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.
            ret_feats (:obj:`bool`, `optional`):
                Forwarded to :obj:`prediction_loop`.

        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
            feat_matrix (:obj:`np.ndarray`, `optional`):
                The features collected by :obj:`prediction_loop`.
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(test_dataloader, description="Prediction", ret_feats=ret_feats)

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ret_feats: Optional[bool] = None,
    ) -> sfdaPredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels. Besides predictions, labels and
        metrics, the per-batch features returned by :obj:`prediction_step` are
        concatenated and returned as ``feat_matrix``.

        NOTE(review): ``pl``, ``xm``, ``nested_xla_mesh_reduce``,
        ``distributed_concat`` and ``distributed_broadcast_scalars`` are not
        among the names imported next to this class; the TPU and distributed
        branches need them in scope - confirm before using those modes.
        """
        if hasattr(self, "_prediction_loop"):
            # ``warnings`` is not imported alongside this class, so import it here.
            import warnings

            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        assert not getattr(
            self.model.config, "output_attentions", False
        ), "The prediction loop does not work with `output_attentions=True`."
        assert not getattr(
            self.model.config, "output_hidden_states", False
        ), "The prediction loop does not work with `output_hidden_states=True`."

        # multi-gpu eval
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        model = torch.nn.DataParallel(self.model) if self.args.n_gpu > 1 else self.model

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses = []
        preds = None
        label_ids = None
        feat_mat = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels, feats = self.prediction_step(model, inputs, prediction_loss_only, ret_feats=ret_feats)
            # Batch size is read from the first input tensor of this batch.
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            if loss is not None:
                eval_losses.extend([loss] * batch_size)
            if logits is not None:
                preds = logits if preds is None else nested_concat(preds, logits, dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
            if feats is not None:
                feat_mat = feats if feat_mat is None else nested_concat(feat_mat, feats)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
            if feat_mat is not None:
                feat_mat = distributed_concat(feat_mat, num_total_examples=self.num_examples(dataloader))

        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = nested_xla_mesh_reduce(preds, "eval_preds")
            if label_ids is not None:
                label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
            if feat_mat is not None:
                feat_mat = nested_xla_mesh_reduce(feat_mat, "eval_feat_mat")
            if eval_losses is not None:
                eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = nested_numpify(preds)
        if label_ids is not None:
            label_ids = nested_numpify(label_ids)
        if feat_mat is not None:
            feat_mat = nested_numpify(feat_mat)

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            if self.args.local_rank != -1:
                metrics["eval_loss"] = (
                    distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
                    .mean()
                    .item()
                )
            else:
                metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return sfdaPredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics, feat_matrix=feat_mat)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ret_feats: bool,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run the model on one batch without gradients.

        Args:
            model: Module to call; it is invoked with ``train_mode="sfda"``.
            inputs: Batch dictionary, moved/prepared via ``_prepare_inputs``.
            prediction_loss_only: Accepted for interface compatibility; not
                consulted by this implementation.
            ret_feats: Accepted for interface compatibility; features are
                always returned.

        Returns:
            ``(loss, logits, labels, feats)`` where ``feats`` is
            ``outputs.last_hidden_state[:, 0, :]`` detached. When every name in
            ``args.label_names`` is present in ``inputs`` the loss is reduced
            with ``.mean().item()`` and the labels are returned; otherwise the
            loss is returned as produced by the model and ``labels`` is ``None``.
        """
        has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs = model(**inputs, train_mode="sfda")
            loss = outputs.loss
            logits = outputs.logits
            feats = outputs.last_hidden_state[:, 0, :].detach()
            labels = None
            if has_labels:
                # The .mean() is to reduce in case of distributed training
                loss = loss.mean().item()
                labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
                if len(labels) == 1:
                    labels = labels[0]
            return (loss, logits, labels, feats)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Perform one training step and return the detached loss.

        Before the forward/backward pass the alpha weight is refreshed and,
        when the global step is a multiple of ``update_freq`` (including step
        0), the prototypes are rebuilt.

        NOTE(review): ``_use_native_amp``, ``autocast``, ``_use_apex`` and
        ``amp`` are not among the names imported next to this class; they are
        only evaluated when ``args.fp16`` is true - confirm before enabling fp16.
        """
        self._update_alpha()
        if self.global_step % self.update_freq == 0:
            self._update_prototypes()
            self.last_update_epoch = self.epoch
        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.args.fp16 and _use_native_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16 and _use_native_amp:
            self.scaler.scale(loss).backward()
        elif self.args.fp16 and _use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.detach()

    def compute_loss(self, model, inputs):
        """
        Compute the SFDA training loss for one batch.

        The current prototypes are converted to tensors on ``args.device`` and
        passed to the model together with ``cf_ratio`` and
        ``train_mode="sfda"``. The model's ``loss`` is unpacked into
        ``s2t_loss`` and ``t_loss`` and blended as
        ``(1 - alpha) * s2t_loss + alpha * t_loss``.
        """
        prototype_p = torch.Tensor(self.prototype_p).to(self.args.device)
        prototype_f = torch.Tensor(self.prototype_f).to(self.args.device)
        outputs = model(
            **inputs,
            prototype_p=prototype_p,
            prototype_f=prototype_f,
            cf_ratio=self.cf_ratio,
            train_mode="sfda",
        )
        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # ``outputs.loss`` holds the pair of losses produced in sfda mode.
        [s2t_loss, t_loss] = outputs.loss
        return (1 - self.alpha) * s2t_loss + self.alpha * t_loss



In [142]:
# Build the train and eval datasets, each from a text file plus its companion
# prediction/label file, using the shared tokenizer.
train_dataset = sfdaNegationDataset.from_tsv(
    data_args.train_file,
    data_args.train_pred,
    tokenizer,
)
eval_dataset = sfdaNegationDataset.from_tsv(
    data_args.eval_file,
    data_args.eval_pred,
    tokenizer,
)

trainer = sfdaTrainer(
    model=model,
    args=training_args,
    sfda_args=sfda_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=build_compute_metrics_fn(),
)

# A model_path is only passed to train() when the source model is a local directory.
trainer.train(
    model_path=(
        model_args.src_model_name_or_pth
        if os.path.isdir(model_args.src_model_name_or_pth)
        else None
    )
)

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=361.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 0', max=361.0, style=ProgressStyl…


Conf_mask 35.0 / 2886
tensor([False, False, False,  ..., False, False, False])


HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 100', max=361.0, style=ProgressSt…


Conf_mask 2860.0 / 2886
tensor([True, True, True,  ..., True, True, True])


HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 200', max=361.0, style=ProgressSt…


Conf_mask 2879.0 / 2886
tensor([True, True, True,  ..., True, True, True])


HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 300', max=361.0, style=ProgressSt…


Conf_mask 2869.0 / 2886
tensor([True, True, True,  ..., True, True, True])



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=361.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 400', max=361.0, style=ProgressSt…


Conf_mask 2883.0 / 2886
tensor([True, True, True,  ..., True, True, True])
{'loss': 0.03313962936401367, 'learning_rate': 2.6915974145891044e-05, 'epoch': 1.3850415512465375, 'total_flos': 505055363655168, 'step': 500}


HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 500', max=361.0, style=ProgressSt…


Conf_mask 2886.0 / 2886
tensor([True, True, True,  ..., True, True, True])


HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 600', max=361.0, style=ProgressSt…


Conf_mask 2885.0 / 2886
tensor([True, True, True,  ..., True, True, True])


HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 700', max=361.0, style=ProgressSt…


Conf_mask 2884.0 / 2886
tensor([True, True, True,  ..., True, True, True])



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=361.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 800', max=361.0, style=ProgressSt…


Conf_mask 2886.0 / 2886
tensor([True, True, True,  ..., True, True, True])


HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 900', max=361.0, style=ProgressSt…


Conf_mask 2886.0 / 2886
tensor([True, True, True,  ..., True, True, True])
{'loss': 0.00021408843994140625, 'learning_rate': 3.831948291782087e-06, 'epoch': 2.770083102493075, 'total_flos': 1010110727310336, 'step': 1000}


HBox(children=(FloatProgress(value=0.0, description='APM Update @Global step 1000', max=361.0, style=ProgressS…


Conf_mask 2886.0 / 2886
tensor([True, True, True,  ..., True, True, True])




TrainOutput(global_step=1083, training_loss=0.015402439008239014)

In [143]:
# Evaluate on the dev set and write the metrics into the output directory.
eval_result = trainer.evaluate(eval_dataset)
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
if trainer.is_world_process_zero():
    # The directory is not guaranteed to exist yet (save_model() only runs
    # below), so create it before opening the results file.
    os.makedirs(training_args.output_dir, exist_ok=True)
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results *****")
        for key, value in eval_result.items():
            logger.info("  %s = %s", key, value)
            writer.write("%s = %s\n" % (key, value))
trainer.save_model()

# Predict on the same dev set and reduce the raw outputs to class indices.
predictions = np.argmax(trainer.predict(eval_dataset).predictions, axis=1)


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=694.0, style=ProgressStyle(description_w…


{'eval_loss': 5.051169447043224, 'eval_acc': 0.27989179440937784, 'eval_f1': 0.3554479418886199, 'eval_acc_and_f1': 0.3176698681489989, 'epoch': 3.0, 'total_flos': 1093739204233728, 'step': 1083}


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=694.0, style=ProgressStyle(description_w…




In [144]:
# Write one predicted label name per line for the dev-set predictions.
save_path = os.path.join(training_args.output_dir,F"dev_pred_sfda_{sfda_args.top_k}.csv")
# Look the label names up once instead of on every row
# (assumes get_labels() returns the same list on each call - TODO confirm).
label_names = train_dataset.get_labels()
with open(save_path, "w") as writer:
    logger.info("***** Test results *****")
    for label_index in predictions:
        writer.write("%s\n" % (label_names[label_index]))