# Transformers préentraînés et clustering

Clustering des embeddings CLS d'un modèle BERT léger avec PyTorch et Hugging Face

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.cluster import KMeans

# Pour gérer les données textuelles
import numpy as np


In [None]:

# Checkpoint identifier; the same name is passed to both from_pretrained calls
# so that the tokenizer and the model come from the same checkpoint.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()  # Put the model in evaluation mode



In [None]:
### 3. Préparer les données textuelles

# Small toy corpus: sentences about machine learning / AI mixed with
# sentences about nature and the outdoors.
texts = [
    "I love machine learning.",
    "Deep learning is a subset of machine learning.",
    "I enjoy hiking and outdoor adventures.",
    "The mountains are beautiful this time of year.",
    "Artificial intelligence is fascinating.",
    "Nature provides peace and tranquility."
]



In [None]:

### 4. Extraire les embeddings CLS

def get_cls_embeddings(texts):
    """Return the [CLS] embedding of every text as a NumPy array.

    Each text is tokenized and encoded on its own with the module-level
    ``tokenizer`` and ``model``; the last-layer hidden state of the first
    token (the [CLS] position) is kept as the representation of the text.

    Args:
        texts: Iterable of strings to embed.

    Returns:
        ``np.ndarray`` with one row per text, in the same order as ``texts``.
    """
    embeddings = []
    # Follow the model: works unchanged whether it lives on the CPU or a GPU.
    device = model.device
    with torch.no_grad():  # inference only, no gradient tracking needed
        for text in texts:
            # A single sentence is encoded per call, so padding=True has nothing to pad here.
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
            inputs = inputs.to(device)
            outputs = model(**inputs)
            cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze(0)  # The CLS token is at index 0
            # store the representations
            # <CORRECTION>
            embeddings.append(cls_embedding.cpu().numpy())
            # </CORRECTION>
    return np.array(embeddings)

# Compute the CLS embeddings of all the texts.
cls_embeddings = get_cls_embeddings(texts)



In [None]:

### 5. Appliquer le clustering avec KMeans

num_clusters = 2  # Number of clusters to look for; adjust as needed
# <CORRECTION>
# n_init is given explicitly: its default value differs between scikit-learn
# releases, so pinning it keeps the clustering identical across versions.
kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=42)
# </CORRECTION>

# Fit on the CLS embeddings and get one cluster label per text.
clusters = kmeans.fit_predict(cls_embeddings)


In [None]:

### 6. Afficher les résultats

# Show every sentence prefixed by the cluster it was assigned to.
for cluster, text in zip(clusters, texts):
    print(f"Cluster {cluster}: {text}")



# Construction du sujet à partir de la correction

In [1]:
### <CORRECTION> ###
import re
# Turn this corrected notebook into the student version: every span enclosed
# by an opening and a closing correction tag is replaced by a TODO placeholder.

fname = "5_3_transf_pretrained-kmeans-corr.ipynb" # this file
fout  = fname.replace("-corr","")

# The tag strings are assembled at runtime so that this cell's source does not
# itself contain a literal tag. Otherwise the substitution covering this cell
# would stop in the middle of the pattern and leave code fragments behind in
# the student notebook.
tag = "CORRECTION"
pattern = "<" + tag + ">.*?(</" + tag + ">)"

# Notebook files are UTF-8 JSON: state the encoding explicitly instead of
# relying on the platform default, and let `with` close the files.
with open(fname, "r", encoding="utf-8") as f:
    txt = f.read()

with open(fout, "w", encoding="utf-8") as f2:
    f2.write(re.sub(pattern, " TODO ", txt, flags=re.DOTALL))

### </CORRECTION> ###