In [1]:
# Restrict this kernel to one physical GPU. The variable has to be in the
# environment before any CUDA-using library (torch, faiss) initialises CUDA,
# which is why this sits in the very first cell.
import os

os.environ.update(CUDA_VISIBLE_DEVICES='7')

In [42]:
import random

import torch as th
from torch.utils.data import Dataset, Sampler, DataLoader
from transformers import BertTokenizer, BertModel

from sentence_transformers import SentenceTransformer

import numpy as np
import scipy.linalg as linalg

from tqdm.notebook import tqdm
import matplotlib as mp
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# import umap
from sklearn.manifold import TSNE, MDS, Isomap
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, DBSCAN
from sklearn.pipeline import Pipeline
import fastcluster
import pandas as pd
import seaborn as sns

import faiss

from examples.speech_to_text.data_utils import load_df_from_tsv

In [3]:
# Dataset root: set COVOST2_ROOT in the environment to point elsewhere instead of
# editing this machine-specific absolute path.
root = os.environ.get('COVOST2_ROOT', '/mnt/raid0/siqi/datasets/covost2')

# Source languages, in the order their feature blocks are stacked later on.
langs = ["fr", "de", "es", "fa", "it", "ru", "pt", "zh-CN", "tr", "ar", "et", "mn", "nl", "sv-SE", "lv", "sl", "ta", "ja", "id"]
# Later cells locate languages with `langs.index`, which silently returns the
# first match, so duplicates would mis-index feature blocks.
assert len(set(langs)) == len(langs), 'langs must not contain duplicates'

# Cached per-language feature files (resources/train_<lang>.npy) are read from here.
os.makedirs('resources', exist_ok=True)
device = 'cuda'

In [37]:
# Stack every language's cached training features row-wise, in `langs` order.
# n_labels[k] is the number of rows contributed by langs[k], so a language's
# slice of `all_features` starts at sum(n_labels[:k]).
all_features, n_labels = [], []
for lang in tqdm(langs):
    features = np.load('resources/train_{}.npy'.format(lang))
    all_features.append(features)
    n_labels.append(features.shape[0])
all_features = np.concatenate(all_features, axis=0)

  0%|          | 0/19 [00:00<?, ?it/s]

In [38]:
# English target texts of each language's training split, in `langs` order.
# Later cells index these lists with *feature-row* indices, so this assumes the
# cached .npy features were extracted from the same TSV in the same row order —
# at minimum the row counts must agree, which is checked here.
assert len(n_labels) == len(langs), 'run the feature-loading cell first'
tgt_texts_per_lang = []
for lang, n_rows in zip(langs, n_labels):
    tsv_path = os.path.join(root, lang, 'train_st_{}_en.tsv'.format(lang))
    tgt_texts = load_df_from_tsv(tsv_path)['tgt_text'].tolist()
    assert len(tgt_texts) == n_rows, (
        f'{lang}: {len(tgt_texts)} target texts but {n_rows} feature rows')
    tgt_texts_per_lang.append(tgt_texts)

In [62]:
def _lang_block(lang):
    """Locate `lang`'s rows inside `all_features`.

    Blocks are stacked in `langs` order with sizes `n_labels`, so the block of
    the language at position `pos` starts at sum(n_labels[:pos]).

    Returns:
        (position in `langs`, row count, first row, one-past-last row)
    """
    pos = langs.index(lang)
    first_row = sum(n_labels[:pos])
    return pos, n_labels[pos], first_row, first_row + n_labels[pos]


# The pair of source languages compared below (`low_*` vs `high_*`).
low_lang = 'sv-SE'
low_idx, n_low, low_start, low_end = _lang_block(low_lang)

high_lang = 'de'
high_idx, n_high, high_start, high_end = _lang_block(high_lang)

# Target texts as arrays so they can be fancy-indexed with feature-row ids.
tgt_texts_high = np.array(tgt_texts_per_lang[high_idx])
tgt_texts_low = np.array(tgt_texts_per_lang[low_idx])

# Each language's feature rows (views into all_features).
high_features = all_features[high_start:high_end]
low_features = all_features[low_start:low_end]

In [51]:
# Exact (brute-force) inner-product search: for every low-language row, the ids
# of the k high-language rows with the largest inner product. This equals cosine
# similarity only if the features are L2-normalised — TODO confirm.
k_neighbors = 5
index = faiss.IndexFlatIP(high_features.shape[1])
index.add(high_features)
# search() returns (scores, ids); only the ids, shape (n_low, k), are kept.
high_neighbors = index.search(low_features, k_neighbors)[1]

In [63]:
# Distinct high-language rows that are a nearest neighbour of at least one
# low-language row, sorted ascending (np.unique flattens its input).
high_neighbors = np.unique(high_neighbors)
# faiss pads results with -1 when it finds fewer than k neighbours; drop those
# ids so they cannot silently select the last row of high_features.
high_neighbors = high_neighbors[high_neighbors >= 0]
extracted_high_features = high_features[high_neighbors]

In [123]:
# Cluster the extracted high-language neighbours together with all low-language
# rows. Row order is "extracted high first, then low" — the same order used when
# building `tgt_texts` later — so `labels[r]` is the cluster of that row.
# Bound to `features` deliberately: otherwise that name still refers to the last
# per-language array left over from the feature-loading loop.
features = np.concatenate([extracted_high_features, low_features], axis=0)

# Complete linkage with a distance threshold: clusters stop merging once the
# largest pairwise (Euclidean, sklearn's default metric) distance would reach max_dist.
max_dist = 1
clusterer = AgglomerativeClustering(n_clusters=None, linkage='complete', distance_threshold=max_dist)
labels = clusterer.fit_predict(features)

In [109]:
# Row mask over the clustered rows: True for the extracted high-language
# neighbours (which come first), False for the low-language rows after them.
# Sized from its two parts explicitly rather than from `features`, which may
# still be the last per-language array left over from the loading loop.
n_extracted_high = extracted_high_features.shape[0]
n_clustered = n_extracted_high + low_features.shape[0]
extracted_high_mask = np.zeros(n_clustered, dtype=bool)
extracted_high_mask[:n_extracted_high] = True
assert labels.shape[0] == n_clustered, (
    f'{labels.shape[0]} cluster labels but {n_clustered} extracted-high + low rows; '
    'labels must come from clustering extracted_high_features followed by low_features')

# The -1 counts look for noise labels (a DBSCAN convention; DBSCAN is imported
# above). AgglomerativeClustering only emits labels 0..n_clusters-1, so with the
# clusterer above these are expected to be 0.
(labels == -1).sum(), extracted_high_mask.sum(), (labels[~extracted_high_mask] == -1).sum(), low_features.shape[0]

(0, 7367, 0, 2160)

In [114]:
np.sum(labels == 0)  # number of rows assigned to cluster 0

(6, 10)

In [111]:
# Cluster ids run 0..k-1, so the largest id plus one is the cluster count.
n_cluster = np.max(labels) + 1
n_cluster

2277

In [112]:
# English target texts aligned row-for-row with the clustered rows: extracted
# high-language neighbours first (in high_neighbors order), then every
# low-language row.
tgt_texts = np.concatenate([tgt_texts_high[high_neighbors], tgt_texts_low])
# Fail with a readable message (instead of a boolean-index IndexError) if the
# labels were produced from a different set of rows.
assert tgt_texts.shape[0] == labels.shape[0], (
    f'{tgt_texts.shape[0]} target texts but {labels.shape[0]} cluster labels; '
    'labels must come from clustering the same extracted-high + low rows')

# tgt_texts_per_cluster[k] holds, in row order, the texts of every row in cluster k.
tgt_texts_per_cluster = [tgt_texts[labels == cluster_id] for cluster_id in range(n_cluster)]

In [116]:
# Rows per cluster, then how many clusters are singletons.
cluster_sizes = np.array([len(tgt_texts_per_cluster[cid]) for cid in range(n_cluster)])
np.sum(cluster_sizes == 1)

684