In [36]:
# load sample list
import os
import sys
sys.path.append("..")
os.environ["CUDA_VISIBLE_DEVICES"]="1"
from openTSNE import TSNE
import torch

from train_router_mdeberta import RouterDataset, RouterModule
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer, DebertaV2Model

# Training splits to embed, one JSON file per benchmark. The position of a
# path in this list is also passed to RouterDataset as its `dataset_id`.
dataset_paths = [
    "../datasets/split2_model7/mmlu_train.json",
    "../datasets/split2_model7/gsm8k-train.json",
    "../datasets/split2_model7/cmmlu_train.json",
    "../datasets/split2_model7/arc_challenge_train.json",
    "../datasets/split2_model7/humaneval_train.json",
]

# `data_type` handed to RouterDataset for each file, aligned index-by-index
# with `dataset_paths`.
data_types = ["multi_attempt", "probability", "probability", "probability", "multi_attempt"]
if len(data_types) != len(dataset_paths):
    # Fail loudly instead of silently ignoring extra entries or hitting an
    # IndexError half-way through dataset construction.
    raise ValueError(
        f"dataset_paths has {len(dataset_paths)} entries but data_types has {len(data_types)}"
    )

tokenizer = AutoTokenizer.from_pretrained("microsoft/mdeberta-v3-base", truncation_side='left', padding=True)
encoder_model = DebertaV2Model.from_pretrained("microsoft/mdeberta-v3-base").to("cuda")

# Passed as `size` to every RouterDataset.
# NOTE(review): presumably an upper bound on the samples kept per file — confirm
# against RouterDataset.
number_per_dataset = 2000

router_datasets = [
    RouterDataset(data_path, data_type=data_type, dataset_id=i, size=number_per_dataset)
    for i, (data_path, data_type) in enumerate(zip(dataset_paths, data_types))
]
for single_dataset in router_datasets:
    single_dataset.register_tokenizer(tokenizer)

# ConcatDataset keeps the datasets in list order and the DataLoader does not
# shuffle by default, so batches come out in `dataset_paths` order.
router_dataset = ConcatDataset(router_datasets)
router_dataloader = DataLoader(router_dataset, batch_size=64)

# The encoder already lives on the GPU; move the whole router there in a single
# step rather than bouncing it through the CPU first.
router_model = RouterModule(
    encoder_model,
    hidden_state_dim=768,
    node_size=len(router_datasets[0].router_node),
    similarity_function="cos",
).to("cuda")


# Run the router over every sample and collect, per sample: the router output
# (`predicts`), the hidden state it returns, and the dataset / cluster ids that
# the dataloader yields alongside the inputs.

# Put the router (and all of its sub-modules) in inference mode so that layers
# such as dropout behave deterministically while extracting features.
router_model.eval()

all_hidden_states = []
dataset_set_ids = []
cluster_ids = []
predicts = []
with torch.no_grad():
    for batch in router_dataloader:
        # Second element of the batch is not needed here.
        inputs, _, dataset_id, cluster_id = batch
        # Keep the return value: this stays correct whether `.to` moves the
        # tensors in place or returns a new object.
        inputs = inputs.to("cuda")
        predict, hidden_states = router_model(**inputs)
        dataset_set_ids.append(dataset_id)
        cluster_ids.append(cluster_id)
        predicts.append(predict)
        all_hidden_states.append(hidden_states)


# Stitch the per-batch pieces back together along the sample dimension.
all_hidden_states = torch.concat(all_hidden_states)
predicts = torch.concat(predicts)
cluster_ids = torch.concat(cluster_ids).numpy()
# Index of the highest-scoring router output for each sample.
_, max_index = torch.max(predicts, dim=1)
dataset_set_ids = torch.concat(dataset_set_ids).numpy()



In [43]:
from MulticoreTSNE import MulticoreTSNE as M_TSNE
from openTSNE import TSNE
# Reduce the router hidden states to a 5-component t-SNE embedding.
tsne_model = M_TSNE(n_components=5, n_jobs=12)
# t-SNE works on NumPy arrays, so copy the tensor to host memory first.
np_hidden_states = all_hidden_states.cpu().numpy()
tsne_result = tsne_model.fit_transform(np_hidden_states)

In [44]:
from sklearn.cluster import KMeans, DBSCAN
import numpy as np
import random as random
import json

from matplotlib import pyplot as plt
from matplotlib.patches import Patch
from matplotlib.colors import Normalize

# Cluster the t-SNE embedding with k-means and write each dataset back out as
# JSON with a `cluster_id` field added to every sample.
n_clusters_list = [8]

seed = 42
random.seed(seed)
np.random.seed(seed)

# NOTE(review): inputs are read from "../datasets" but results are written
# under "./datasets" — confirm this is the intended location.
# NOTE(review): the path does not depend on n_clusters, so with more than one
# entry in n_clusters_list later runs overwrite the files of earlier ones.
base_output_path = "./datasets/split2_model7_cluster"
os.makedirs(base_output_path, exist_ok=True)

# Row range of every source dataset inside `tsne_result`. The rows follow the
# order of the concatenated datasets, so the boundaries are the running sum of
# the real dataset lengths. Using the real lengths (rather than assuming every
# dataset holds exactly `number_per_dataset` rows) keeps the labels aligned
# when a dataset is smaller than the requested size.
dataset_offsets = [0]
for single_dataset in router_datasets:
    dataset_offsets.append(dataset_offsets[-1] + len(single_dataset))
if dataset_offsets[-1] != len(tsne_result):
    raise ValueError(
        f"tsne_result has {len(tsne_result)} rows but the datasets hold "
        f"{dataset_offsets[-1]} samples; re-run the embedding cells"
    )

for n_clusters in n_clusters_list:
    x = tsne_result
    # Explicit random_state makes every run reproducible on its own instead of
    # depending on how much of the global NumPy RNG was consumed beforehand.
    kmeans = KMeans(n_clusters=n_clusters, max_iter=1000, random_state=seed)

    kmeans.fit(x)

    # Plain Python ints so the labels are JSON-serializable.
    kmeans_labels = kmeans.labels_.tolist()

    labels_split = [
        kmeans_labels[start:stop]
        for start, stop in zip(dataset_offsets[:-1], dataset_offsets[1:])
    ]

    for i, data_path in enumerate(dataset_paths):
        # Deliberately not called `cluster_ids`: that name holds the array
        # collected from the dataloader in the embedding cell.
        dataset_labels = labels_split[i]

        if not data_path.endswith('.json'):
            # Previously this case left `sample_list` undefined (or stale from
            # the previous file) and failed later in a confusing way.
            raise ValueError(f"Unsupported dataset file (expected .json): {data_path}")
        with open(data_path, 'r', encoding='utf-8') as f:
            sample_list = json.load(f)

        # Pair the leading samples of the file with their labels; zip stops at
        # the last labelled sample.
        # NOTE(review): assumes the dataset keeps the file's samples in their
        # original order, starting from the first — confirm against RouterDataset.
        new_sample_list = []
        for sample, label in zip(sample_list, dataset_labels):
            sample['cluster_id'] = label
            new_sample_list.append(sample)

        output_path = os.path.join(base_output_path, os.path.basename(data_path))
        with open(output_path, "w", encoding='utf-8') as f:
            json.dump(new_sample_list, f)