<a href="https://colab.research.google.com/github/yoichi1484/subspace/blob/main/notebook/subspace.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Subspace Representations for Soft Set Operations and Sentence Similarities
Yoichi Ishibashi, Sho Yokoi, Katsuhito Sudoh, Satoshi Nakamura: [Subspace Representations for Soft Set Operations and Sentence Similarities](https://arxiv.org/abs/2210.13034) (NAACL, 2024)


## Setup

In [None]:
!git clone https://github.com/yoichi1484/subspace.git

fatal: destination path 'subspace' already exists and is not an empty directory.


In [None]:
cd subspace

/content/subspace


In [None]:
!pip install -r requirements.txt



## Set similarity
Our subspace-based sentence (set of words) similarity can be easily computed as follows.

In [None]:
from subspace.tool import SubspaceBERTScore

scorer = SubspaceBERTScore(device='cpu', model_name_or_path='bert-base-uncased')

sentences_a = ["A man with a hard hat is dancing.", "A young child is riding a horse."]
sentences_b = ["A man wearing a hard hat is dancing.", "A child is riding a horse."]

scorer(sentences_a, sentences_b)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(tensor([0.9848, 0.9338]), tensor([0.9838, 0.9249]), tensor([0.9843, 0.9293]))

### STS task
Evaluation experiments on the STS task can be conducted with `SentEval`.
The first step is to download the evaluation data.

In [None]:
cd ./SentEval/data/downstream/

/content/subspace/SentEval/data/downstream


In [None]:
!bash download_dataset.sh

--2024-05-31 09:23:41--  https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/senteval.tar
Resolving huggingface.co (huggingface.co)... 13.33.30.49, 13.33.30.76, 13.33.30.23, ...
Connecting to huggingface.co (huggingface.co)|13.33.30.49|:443... connected.
HTTP request sent, awaiting response... 302 Found

The evaluation scripts and the calculation of correlation coefficients are based on the [SimCSE](https://github.com/princeton-nlp/SimCSE) code by Gao et al.
Here is how to run the script:

In [None]:
cd ../../../

/content/subspace


In [None]:
!bash run_sts.sh

2024-05-31 09:23:53,986 : NumExpr defaulting to 2 threads.
Pooler and similarity:  hidden_states_subspace_bert_score_F
2024-05-31 09:23:54,287 : Starting new HTTPS connection (1): huggingface.co:443
2024-05-31 09:23:54,599 : https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/config.json HTTP/1.1" 200 0
2024-05-31 09:23:55,078 : Starting new HTTPS connection (1): huggingface.co:443
2024-05-31 09:23:55,390 : https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model t
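
STS results are commonly reported as a rank correlation between predicted similarities and human ratings. The sketch below computes Spearman's correlation in plain Python to show what such a coefficient measures. The scores are made-up illustrative values, the helper names are hypothetical, and which coefficient `run_sts.sh` actually prints is not shown in the log above.

```python
import math

def average_ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # average of positions i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks

def pearson(x, y):
    """Pearson correlation of two equal-length sequences."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

def spearman(x, y):
    """Spearman correlation = Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))

# Hypothetical model similarities and human ratings for four sentence pairs
predicted = [0.98, 0.93, 0.41, 0.65]
gold = [4.8, 4.2, 1.0, 3.5]
rho = spearman(predicted, gold)
print(f"Spearman correlation: {rho:.4f}")
```

Because only the ordering matters, the similarity scores do not need to be on the same scale as the human ratings.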


## Other set operations
Other subspace-based set operations such as union, intersection, orthogonal complement, and soft membership can be computed as follows using torch.


In [None]:
import torch
from subspace.operations import *

torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
A = torch.rand((50, 300), device=device) # 50 stacked 300-dimensional word vectors
B = torch.rand((80, 300), device=device) # 80 stacked 300-dimensional word vectors

Compute bases of the subspace

In [None]:
SA = subspace(A)
SA.shape # torch.Size([50, 300])

torch.Size([50, 300])

Compute bases of the orthogonal complement

In [None]:
A_NOT = orthogonal_complement(A)
A_NOT.shape # torch.Size([250, 300])

torch.Size([250, 300])
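
To see where the 250 rows come from, the following toy sketch builds an orthogonal complement in plain Python: it orthonormalizes the input vectors with Gram-Schmidt, extends that basis to the whole space, and keeps only the added directions. The helpers `gram_schmidt` and `orthogonal_complement_toy` are hypothetical names for illustration and are not necessarily how `orthogonal_complement` is implemented.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def gram_schmidt(rows, tol=1e-10):
    """Modified Gram-Schmidt: orthonormal basis (as rows) for the span of `rows`."""
    basis = []
    for row in rows:
        r = list(row)
        for b in basis:
            c = dot(r, b)                          # component of r along b
            r = [x - c * y for x, y in zip(r, b)]
        norm = math.sqrt(dot(r, r))
        if norm > tol:                             # skip linearly dependent rows
            basis.append([x / norm for x in r])
    return basis

def orthogonal_complement_toy(rows, dim):
    """Orthonormal basis of the vectors orthogonal to every row in `rows`."""
    span = gram_schmidt(rows)
    standard = [[1.0 if i == j else 0.0 for j in range(dim)] for i in range(dim)]
    full = gram_schmidt(span + standard)           # extend span's basis to R^dim
    return full[len(span):]                        # keep only the added directions

A_toy = [[1.0, 1.0, 0.0]]                          # one vector in R^3
complement = orthogonal_complement_toy(A_toy, dim=3)
print(len(complement))                             # 3 - 1 = 2 basis vectors
```

The same counting applies to the cell above: 50 linearly independent vectors in 300 dimensions leave 300 - 50 = 250 orthogonal directions.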

Compute bases of the intersection

In [None]:
A_AND_B = intersection(A, B)
A_AND_B.shape # torch.Size([1, 300])


torch.Size([1, 300])

Compute bases of the sum space



In [None]:
A_OR_B = sum_space(A, B)
A_OR_B.shape # torch.Size([130, 300])

torch.Size([130, 300])
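
In exact linear algebra, the sum space is spanned by the stacked rows of both matrices, and its dimension satisfies dim(A + B) = dim(A) + dim(B) - dim(A ∩ B). The toy sketch below checks this with an exact rank computation in plain Python. `rank` is a hypothetical helper, and the library functions may treat near-overlaps softly rather than exactly.

```python
from fractions import Fraction

def rank(rows):
    """Rank of a matrix (list of rows) by exact Gaussian elimination."""
    m = [[Fraction(x) for x in row] for row in rows]
    n_cols = len(m[0]) if m else 0
    r = 0                                          # number of pivots found so far
    for col in range(n_cols):
        pivot = next((i for i in range(r, len(m)) if m[i][col] != 0), None)
        if pivot is None:
            continue                               # no pivot in this column
        m[r], m[pivot] = m[pivot], m[r]
        for i in range(r + 1, len(m)):
            factor = m[i][col] / m[r][col]
            m[i] = [a - factor * b for a, b in zip(m[i], m[r])]
        r += 1
    return r

A_toy = [[1, 0, 0, 0], [0, 1, 0, 0]]               # spans the x-y plane of R^4
B_toy = [[0, 1, 0, 0], [0, 0, 1, 0]]               # spans the y-z plane of R^4

dim_sum = rank(A_toy + B_toy)                      # stacked rows span A + B
dim_intersection = rank(A_toy) + rank(B_toy) - dim_sum
print(dim_sum, dim_intersection)
```

In the cell above, 130 = 50 + 80, which is the row count expected when the stacked vectors are jointly linearly independent.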

Compute soft membership degree

In [None]:
v = torch.rand(300, device=device)
soft_membership(A, v) # tensor(0.8875)

tensor(0.8875, device='cuda:0')

# Exploring Word Embeddings using Subspaces

This experiment aims to explore the relationships between word embeddings using subspaces. It demonstrates how to create subspaces for color words and fruit words, and then analyzes the intersection of these subspaces. The code also provides functions to sample random vectors from a subspace, find similar words within a subspace, and calculate the soft membership of evaluation words to a given subspace.


In [None]:
import torch
import gensim.downloader as api
from gensim.utils import simple_preprocess

# Use GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Color words that span the color subspace (evaluation words are defined later)
color_words_span = ['red', 'yellow', 'amber', 'ochre', 'gold', 'brown',
                    'sienna', 'bronze', 'copper', 'rust', 'crimson', 'scarlet',
                    'vermilion', 'tangerine', 'apricot', 'peach']

# Fruit words that span the fruit subspace (evaluation words are defined later)
fruit_words_span = ['tangerine', 'clementine', 'mandarin', 'citrus',
                    'grapefruit', 'kumquat', 'persimmon', 'mango', 'papaya',
                    'cantaloupe', 'apricot', 'nectarine', 'peach', 'plum',
                    'cherry', 'pomegranate']

# Load the word2vec model (downloaded on first use; the file is large)
w2v_model = api.load('word2vec-google-news-300')

# Function to represent a set of words as a matrix with one word2vec vector per row
def words_to_matrix(words):
    # Words missing from the word2vec vocabulary are skipped
    vectors = [torch.tensor(w2v_model[word]) for word in words
               if word in w2v_model.key_to_index]
    return torch.stack(vectors).to(device)

# Represent color word sets as matrices
color_matrix_span = words_to_matrix(color_words_span)

# Represent fruit word sets as matrices
fruit_matrix_span = words_to_matrix(fruit_words_span)

# Compute the color and fruit subspaces and their intersection
color_subspace = subspace(color_matrix_span)
fruit_subspace = subspace(fruit_matrix_span)
# Tip: Adjust the threshold to allow more overlap between subspaces
color_and_fruit_subspace = intersection(color_matrix_span, fruit_matrix_span, threshold=0.5)
assert color_and_fruit_subspace.shape[0] > 0, "Increase the threshold to allow more overlap between subspaces"
color_and_fruit_subspace.shape

torch.Size([3, 300])


Tips:
- If the subspaces do not overlap, it could be due to the small size of the word sets. Consider increasing the size of the word sets to improve the chances of overlap.
- Adjust the threshold for the intersection subspace to allow more overlap between subspaces when the intersection is likely to be empty.

## Evaluation
Soft membership measures how strongly a word belongs to a set of words. It takes a value between 0 and 1, where:

- values near 1 mean the word is strongly related to the set
- values near 0 mean the word is weakly related to the set

Soft membership helps us:

- See how closely a word matches a concept or category
- Find words that don't fit well with the set
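
One way to read these values geometrically, assuming soft membership is the cosine of the angle between the word vector and the subspace, is as the length of the vector's projection onto the subspace divided by the vector's own length. The sketch below illustrates this in plain Python on a toy plane. `soft_membership_toy` is a hypothetical helper and the library's `soft_membership` may differ in its details.

```python
import math

def soft_membership_toy(orthonormal_basis, v):
    """Length of the projection of v onto the subspace divided by the length of v."""
    # Coordinates of v along each orthonormal basis row
    coords = [sum(a * b for a, b in zip(v, row)) for row in orthonormal_basis]
    projection_norm = math.sqrt(sum(c * c for c in coords))
    return projection_norm / math.sqrt(sum(x * x for x in v))

# Orthonormal basis of the x-y plane in R^3
plane = [[1.0, 0.0, 0.0],
         [0.0, 1.0, 0.0]]

inside = soft_membership_toy(plane, [3.0, 4.0, 0.0])      # vector lies in the plane
orthogonal = soft_membership_toy(plane, [0.0, 0.0, 2.0])  # vector is perpendicular to it
between = soft_membership_toy(plane, [1.0, 0.0, 1.0])     # 45 degrees from the plane
print(inside, orthogonal, round(between, 4))
```

Under this reading, a vector inside the subspace scores 1, a vector orthogonal to it scores 0, and everything else falls in between.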


In [None]:
# Function to sample random vectors from a subspace
def sample_vector_from_subspace(subspace, num_samples=1):
    # Random linear combinations of the basis rows lie in the subspace
    coefficients = torch.randn(num_samples, subspace.shape[0], device=subspace.device)
    return coefficients @ subspace

# Function to calculate and display soft membership of evaluation words to a subspace
def evaluate_soft_membership(subspace, eval_words, subspace_name):
    for word in eval_words:
        if word in w2v_model.key_to_index:
            word_vector = torch.tensor(w2v_model[word]).to(device)
            membership = soft_membership(subspace, word_vector)
            print(f"Soft membership of '{word}' to the {subspace_name} subspace: {membership:.4f}")
        else:
            print(f"The word '{word}' is not found in the word2vec model.")

## Example 1
Words that share characteristics with both the color and fruit word sets, such as "orange," tend to have high soft membership values in both subspaces.

In [None]:
color_words_eval = ['orange', 'coral', 'salmon', 'persimmon']
fruit_words_eval = ['orange', 'lemon', 'lime', 'citron']

# Experiment with color subspace
print("Color Subspace:")
evaluate_soft_membership(color_subspace, color_words_eval, "color")

# Experiment with fruit subspace
print("\nFruit Subspace:")
evaluate_soft_membership(fruit_subspace, fruit_words_eval, "fruit")

# Experiment with intersection subspace
print("\nIntersection Subspace:")
evaluate_soft_membership(color_and_fruit_subspace, color_words_eval + fruit_words_eval, "intersection")

Color Subspace:
Soft membership of 'orange' to the color subspace: 0.7759
Soft membership of 'coral' to the color subspace: 0.4909
Soft membership of 'salmon' to the color subspace: 0.3981
Soft membership of 'persimmon' to the color subspace: 0.6431

Fruit Subspace:
Soft membership of 'orange' to the fruit subspace: 0.6002
Soft membership of 'lemon' to the fruit subspace: 0.7296
Soft membership of 'lime' to the fruit subspace: 0.6405
Soft membership of 'citron' to the fruit subspace: 0.7488

Intersection Subspace:
Soft membership of 'orange' to the intersection subspace: 0.5879
Soft membership of 'coral' to the intersection subspace: 0.3895
Soft membership of 'salmon' to the intersection subspace: 0.3088
Soft membership of 'persimmon' to the intersection subspace: 0.6148
Soft membership of 'orange' to the intersection subspace: 0.5879
Soft membership of 'lemon' to the intersection subspace: 0.6702
Soft membership of 'lime' to the intersection subspace: 0.5715
Soft membership of 'citron

## Example 2
The words below are not particularly related to colors or fruits, so they tend to have lower soft membership values.

In [None]:
# Unrelated word set for evaluation
unrelated_words = ['computer', 'book', 'guitar', 'coffee', 'dog', 'rain', 'castle', 'forest', 'ocean', 'moon']

In [None]:
# Experiment with color subspace
print("Color Subspace:")
evaluate_soft_membership(color_subspace, unrelated_words, "color")

# Experiment with fruit subspace
print("\nFruit Subspace:")
evaluate_soft_membership(fruit_subspace, unrelated_words, "fruit")

# Experiment with intersection subspace
print("\nIntersection Subspace:")
evaluate_soft_membership(color_and_fruit_subspace, unrelated_words, "intersection")

Color Subspace:
Soft membership of 'computer' to the color subspace: 0.2263
Soft membership of 'book' to the color subspace: 0.2674
Soft membership of 'guitar' to the color subspace: 0.2656
Soft membership of 'coffee' to the color subspace: 0.4121
Soft membership of 'dog' to the color subspace: 0.3692
Soft membership of 'rain' to the color subspace: 0.3345
Soft membership of 'castle' to the color subspace: 0.2753
Soft membership of 'forest' to the color subspace: 0.3059
Soft membership of 'ocean' to the color subspace: 0.2824
Soft membership of 'moon' to the color subspace: 0.3550

Fruit Subspace:
Soft membership of 'computer' to the fruit subspace: 0.2205
Soft membership of 'book' to the fruit subspace: 0.3211
Soft membership of 'guitar' to the fruit subspace: 0.2340
Soft membership of 'coffee' to the fruit subspace: 0.4622
Soft membership of 'dog' to the fruit subspace: 0.2811
Soft membership of 'rain' to the fruit subspace: 0.3184
Soft membership of 'castle' to the fruit subspace: 0