In [20]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, EsmForSequenceClassification, Trainer, TrainingArguments

# Example training data. Each row pairs a protein sequence with its one-hot
# label: position 0 is LABEL_0 (e.g. Mitochondrion), position 1 is LABEL_1
# (e.g. Cytoplasm). The rows are unzipped into two parallel lists.
train_sequences, train_labels = map(list, zip(
    ("MALWMRLLPLLALLALWGPGPGLSGLALLLAVAP", [1, 0]),   # Mitochondrion
    ("MGLSDGEWQLVLNVWGKVEADIPGHGQEVLIRLFK", [0, 1]),  # Cytoplasm
    # Add further (sequence, one-hot label) rows here
))

# Example evaluation data, laid out the same way as the training rows.
eval_sequences, eval_labels = map(list, zip(
    ("MLAKKKPQKPLLPLTPEELPAELTDLT", [1, 0]),  # Mitochondrion
    ("MDDIAALVVDNGSGMCKAGFAGDDAPR", [0, 1]),  # Cytoplasm
    # Add further (sequence, one-hot label) rows here
))

class ProteinDataset(Dataset):
    """Map-style dataset pairing protein sequences with label vectors.

    Each item is tokenized on access. Sequences are returned unpadded, so
    batching sequences of different lengths requires a padding collate
    function.

    Args:
        sequences: Indexable collection of protein sequences.
        labels: Indexable collection of label vectors, aligned with
            ``sequences`` by index.
        tokenizer: Callable invoked as ``tokenizer(sequence, **kwargs)`` that
            returns a mapping of tensors with a leading batch dimension of 1.
        max_length: Optional cap on the tokenized length. When ``None``
            (default) sequences are not truncated.

    Raises:
        ValueError: If ``sequences`` and ``labels`` differ in length.
    """

    def __init__(self, sequences, labels, tokenizer, max_length=None):
        if len(sequences) != len(labels):
            raise ValueError(
                f"sequences and labels must have the same length "
                f"(got {len(sequences)} and {len(labels)})"
            )
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the tokenized sequence at ``idx`` together with its labels.

        Returns:
            dict: The tokenizer's output tensors with the batch dimension
            removed, plus ``"labels"`` as a float tensor.
        """
        sequence = self.sequences[idx]
        label = self.labels[idx]

        # Padding a single sequence is a no-op, and truncation without a
        # max_length only triggers a tokenizer warning, so truncation is
        # requested only when a limit was actually given.
        tokenizer_kwargs = {"return_tensors": "pt"}
        if self.max_length is not None:
            tokenizer_kwargs["truncation"] = True
            tokenizer_kwargs["max_length"] = self.max_length

        inputs = self.tokenizer(sequence, **tokenizer_kwargs)
        # Drop the batch dimension added by return_tensors="pt".
        item = {key: val.squeeze(0) for key, val in inputs.items()}
        # Labels are stored as floats.
        item["labels"] = torch.tensor(label, dtype=torch.float)
        return item

# Tokenizer of the ESM-2 checkpoint, shared by both datasets.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Build the training and evaluation datasets from their (sequences, labels) pairs.
train_dataset, eval_dataset = (
    ProteinDataset(split_sequences, split_labels, tokenizer)
    for split_sequences, split_labels in (
        (train_sequences, train_labels),
        (eval_sequences, eval_labels),
    )
)

def collate_fn(batch, pad_token_id=1):
    """Collate dataset items into one right-padded batch.

    Args:
        batch: List of dicts, each holding 1-D ``input_ids`` and
            ``attention_mask`` tensors and a ``labels`` tensor.
        pad_token_id: Value used to fill ``input_ids`` up to the longest
            sequence in the batch. Defaults to 1, the ``pad_token_id`` of the
            ESM-2 checkpoint used in this notebook.

    Returns:
        dict: ``input_ids`` and ``attention_mask`` padded to a common length
        (batch first), and ``labels`` stacked along a new first dimension.
    """
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    # Fill token ids with the pad token rather than pad_sequence's default
    # of 0, which is not the pad id for this vocabulary.
    input_ids = pad_sequence(
        [item['input_ids'] for item in batch],
        batch_first=True,
        padding_value=pad_token_id,
    )
    # Padded positions get mask value 0.
    attention_mask = pad_sequence(
        [item['attention_mask'] for item in batch],
        batch_first=True,
        padding_value=0,
    )
    labels = torch.stack([item['labels'] for item in batch])
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}


In [21]:
# Hyper-parameters for the Trainer run.
training_args = TrainingArguments(
    output_dir="./results",          # directory passed as the Trainer output location
    # `eval_strategy` is the current name of the deprecated
    # `evaluation_strategy` argument; "epoch" evaluates after every epoch.
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
)


In [22]:
# ESM-2 encoder with a two-output classification head, configured for
# multi-label classification.
model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    problem_type="multi_label_classification",
    num_labels=2,
)

# Wire the model, hyper-parameters, datasets and padding collator together.
trainer = Trainer(
    data_collator=collate_fn,
    eval_dataset=eval_dataset,  # evaluation dataset
    train_dataset=train_dataset,
    args=training_args,
    model=model,
)

# Run training; as the last expression of the cell, the result is displayed.
trainer.train()


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.

                                     

[A[A                               
  0%|          | 0/3 [02:01<?, ?it/s]        
[A
[A

{'eval_loss': 0.6912202835083008, 'eval_runtime': 0.041, 'eval_samples_per_second': 48.812, 'eval_steps_per_second': 24.406, 'epoch': 1.0}



                                     

[A[A                               
  0%|          | 0/3 [02:02<?, ?it/s]        
[A
[A

{'eval_loss': 0.6912007331848145, 'eval_runtime': 0.029, 'eval_samples_per_second': 68.873, 'eval_steps_per_second': 34.436, 'epoch': 2.0}



                                     

[A[A                               
  0%|          | 0/3 [02:02<?, ?it/s]        
[A
                                     
100%|██████████| 3/3 [00:00<00:00,  4.56it/s]

{'eval_loss': 0.6911584734916687, 'eval_runtime': 0.026, 'eval_samples_per_second': 76.939, 'eval_steps_per_second': 38.469, 'epoch': 3.0}
{'train_runtime': 0.6543, 'train_samples_per_second': 9.171, 'train_steps_per_second': 4.585, 'train_loss': 0.6931606928507487, 'epoch': 3.0}





TrainOutput(global_step=3, training_loss=0.6931606928507487, metrics={'train_runtime': 0.6543, 'train_samples_per_second': 9.171, 'train_steps_per_second': 4.585, 'total_flos': 9992508156.0, 'train_loss': 0.6931606928507487, 'epoch': 3.0})

In [23]:
# Example protein sequence used for a quick prediction check
test_sequence = "MALWMRLLPLLALLALWGPGPGLSGLALLLAVAP"

# Tokenize the sequence and move the tensors to the model's device: the
# Trainer may have placed the model on an accelerator, while the tokenizer
# output is created on the CPU.
inputs = tokenizer(test_sequence, return_tensors="pt").to(model.device)

# Predict in evaluation mode and without gradient tracking; bring the logits
# back to the CPU so the indexing and printing below work on any device.
model.eval()
with torch.no_grad():
    logits = model(**inputs).logits.cpu()

# Multi-label decision: keep every class whose sigmoid probability exceeds 0.5
predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]

# Show the raw scores and the selected class ids
print("Logits:", logits)
print("Predicted Class IDs:", predicted_class_ids)

# Translate class ids into the label names stored in the model config
labels = [model.config.id2label[class_id.item()] for class_id in predicted_class_ids]
print("Predicted Labels:", labels)


Logits: tensor([[0.0322, 0.0120]])
Predicted Class IDs: tensor([0, 1])
Predicted Labels: ['LABEL_0', 'LABEL_1']


In [25]:
from transformers import AutoConfig

# Fetch the configuration published with the base checkpoint.
# NOTE(review): this reads the hub checkpoint's config, not `model.config` of
# the model built earlier in the notebook — confirm that is what is intended.
config = AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Pull out both directions of the label mapping
id2label, label2id = config.id2label, config.label2id

print("ID to Label mapping:", id2label)
print("Label to ID mapping:", label2id)

# Dump the complete configuration
print(config)


ID to Label mapping: {0: 'LABEL_0', 1: 'LABEL_1'}
Label to ID mapping: {'LABEL_0': 0, 'LABEL_1': 1}
EsmConfig {
  "_name_or_path": "facebook/esm2_t6_8M_UR50D",
  "architectures": [
    "EsmForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "emb_layer_norm_before": false,
  "esmfold_config": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 320,
  "initializer_range": 0.02,
  "intermediate_size": 1280,
  "is_folding_model": false,
  "layer_norm_eps": 1e-05,
  "mask_token_id": 32,
  "max_position_embeddings": 1026,
  "model_type": "esm",
  "num_attention_heads": 20,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "rotary",
  "token_dropout": true,
  "torch_dtype": "float32",
  "transformers_version": "4.43.2",
  "use_cache": true,
  "vocab_list": null,
  "vocab_size": 33
}

