
Commit

Merge branch 'main' of github.com:princeton-nlp/SimCSE
gaotianyu1350 committed Oct 10, 2022
2 parents 8f3ef2f + d868602 commit 511c99d
Showing 5 changed files with 62 additions and 7 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/stale.yml
@@ -0,0 +1,29 @@
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
#
# You can adjust the behavior by modifying this file.
# For more information, see:
# https://github.com/actions/stale
name: Mark stale issues and pull requests

on:
  schedule:
    - cron: '18 9 * * *'

jobs:
  stale:

    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write

    steps:
      - uses: actions/stale@v5
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          stale-issue-message: 'Stale issue message'
          stale-pr-message: 'Stale pull request message'
          stale-issue-label: 'no-issue-activity'
          stale-pr-label: 'no-pr-activity'
          days-before-stale: 30
          days-before-close: 5
12 changes: 6 additions & 6 deletions requirements.txt
@@ -1,9 +1,9 @@
 transformers==4.2.1
-scipy==1.5.4
-datasets==1.2.1
-pandas==1.1.5
-scikit-learn==0.24.0
-prettytable==2.1.0
+scipy
+datasets
+pandas
+scikit-learn
+prettytable
 gradio
 torch
-setuptools==49.3.0
+setuptools
27 changes: 26 additions & 1 deletion simcse/tool.py
@@ -176,6 +176,31 @@ def build_index(self, sentences_or_file_path: Union[str, List[str]],
            self.is_faiss_index = False
        self.index["index"] = index
        logger.info("Finished")

    def add_to_index(self, sentences_or_file_path: Union[str, List[str]],
                        device: str = None,
                        batch_size: int = 64):

        # if the input sentence is a string, we assume it's the path of a file that stores various sentences
        if isinstance(sentences_or_file_path, str):
            sentences = []
            with open(sentences_or_file_path, "r") as f:
                logging.info("Loading sentences from %s ..." % (sentences_or_file_path))
                for line in tqdm(f):
                    sentences.append(line.rstrip())
            sentences_or_file_path = sentences

        logger.info("Encoding embeddings for sentences...")
        embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True)

        if self.is_faiss_index:
            self.index["index"].add(embeddings.astype(np.float32))
        else:
            self.index["index"] = np.concatenate((self.index["index"], embeddings))
        self.index["sentences"] += sentences_or_file_path
        logger.info("Finished")



    def search(self, queries: Union[str, List[str]],
               device: str = None,
@@ -186,7 +211,7 @@ def search(self, queries: Union[str, List[str]],
             if isinstance(queries, list):
                 combined_results = []
                 for query in queries:
-                    results = self.search(query, device)
+                    results = self.search(query, device, threshold, top_k)
                     combined_results.append(results)
                 return combined_results

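For readers of this diff, here is a minimal sketch of how the new add_to_index method might be used alongside the existing build_index and search, following the README-style usage of the SimCSE tool class. The checkpoint name and sentences are illustrative and not part of this commit.

# Illustrative usage sketch; assumes the simcse package is installed and the
# checkpoint name resolves to a valid SimCSE model.
from simcse import SimCSE

model = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")

# Build an initial index from a small set of sentences.
model.build_index(["A man is playing a guitar.", "A woman is cooking dinner."])

# Later, extend the existing index in place instead of rebuilding it.
model.add_to_index(["A kid is riding a bicycle."])

# With the fix in this commit, search forwards threshold and top_k when given
# a list of queries, so both parameters apply to every query in the batch.
results = model.search(["Someone plays an instrument."], threshold=0.6, top_k=3)
print(results)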
Binary file added slides/emnlp2021_slides.pdf
Binary file not shown.
1 change: 1 addition & 0 deletions train.py
@@ -500,6 +500,7 @@ def mask_tokens(
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
inputs = inputs.clone()
labels = inputs.clone()
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
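The one-line train.py change matters because mask_tokens later writes [MASK] and random token ids into inputs, so without the clone those writes would mutate the caller's original batch in place. A toy sketch of the difference follows (illustrative only, not the project's code; the token ids and the mask id 103 are made up).

import torch

def mask_in_place(inputs, mask_id=103):
    # Old behavior: writes directly into the caller's tensor.
    inputs[:, 1] = mask_id
    return inputs

def mask_on_copy(inputs, mask_id=103):
    # Fixed behavior: work on a private copy, as in `inputs = inputs.clone()`.
    inputs = inputs.clone()
    inputs[:, 1] = mask_id
    return inputs

batch = torch.tensor([[101, 2054, 102]])
mask_in_place(batch)
print(batch)   # tensor([[101, 103, 102]]) -- the original batch was modified

batch = torch.tensor([[101, 2054, 102]])
mask_on_copy(batch)
print(batch)   # tensor([[101, 2054, 102]]) -- the original batch is preserved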
