In [1]:
#imports
from datasets import Dataset
from transformers import AutoTokenizer, Trainer, TrainingArguments, EsmForMaskedLM, DataCollatorForLanguageModeling
import torch
import esm
from torch.optim import AdamW
import plotly.express as px
import random
import pandas as pd
from Bio import SeqIO
import numpy as np
import scipy
import math
import datetime

In [2]:
# Hugging Face checkpoint identifier shared by the tokenizer and the model
model_type = "facebook/esm2_t33_650M_UR50D"

# Tokenizer loaded from that checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_type)

# Masked-LM model loaded from the same checkpoint (pretrained weights at this point)
model = EsmForMaskedLM.from_pretrained(model_type)

In [3]:
def tokenize_function(row, max_length=566): #tokenizer and params including special_tokens_mask required for mlm
    """Tokenize the ``'seq'`` entry of a dataset row/batch for masked-LM training.

    Relies on the notebook-level ``tokenizer``. Every sequence is padded or
    truncated to exactly ``max_length`` tokens, and the special-tokens mask is
    requested because masked-LM collation needs it (per the original note).

    Args:
        row: Mapping with a ``'seq'`` key. When called through
            ``Dataset.map(batched=True)`` the value is a list of sequences.
        max_length: Fixed token length of every output. Defaults to 566, the
            value used throughout this notebook, so existing calls are unchanged.

    Returns:
        The object returned by ``tokenizer(...)`` for these arguments.
    """
    return tokenizer(
        row['seq'],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_special_tokens_mask=True)

In [4]:
def reset_weights(module):
    """Re-initialise, in place, every (sub-)module that defines ``reset_parameters``.

    The walk is depth-first and includes ``module`` itself, so passing a bare
    leaf layer (e.g. ``torch.nn.Linear``) re-initialises it instead of silently
    doing nothing. Children are processed before their parent, which lets a
    parent's own ``reset_parameters`` have the final say over its sub-layers.

    Parameters registered directly on a module that has no
    ``reset_parameters`` method are left untouched.

    Args:
        module: A ``torch.nn.Module`` (anything exposing ``children()``).

    Returns:
        None. ``module`` is modified in place.
    """
    for layer in module.children():
        reset_weights(layer)  # Recursively go deeper

    # Handle the module itself last (this also covers the root of the walk)
    reset = getattr(module, 'reset_parameters', None)
    if callable(reset):
        reset()


In [5]:
# Use the first CUDA device when one is visible; otherwise stay on the CPU so
# that `device` always names a device that actually exists.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
if device.type == "cuda":
    print("Transferred model to GPU")

Transferred model to GPU


In [7]:
# Re-initialise, in place, every layer of the loaded model that exposes
# reset_parameters() (see reset_weights), replacing the loaded values
reset_weights(model)

In [8]:
# Display the module tree of the model after the reset
model

EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((1280,), eps=1e-05, 

In [9]:
# Save the reset model on its own, before any training is run on it

model.save_pretrained("/media/spyros/HD-ADU3/spyros/model_weights_topublish/reset_weights_ESM2/esm2_reset")


In [2]:
# Parse the all-HA FASTA file into a dict of sequence records, then show how many were loaded
haall_train_seqs = SeqIO.to_dict(SeqIO.parse('ncbiflu_HA_all_110424_noX_clu99_filt_HA-all.fas', 'fasta'))
len(haall_train_seqs)

9890

In [13]:
# Plain sequence strings, in the order the records were loaded
seqdat = [str(rec.seq) for rec in haall_train_seqs.values()]

# Hold out 20% of the sequences for validation. A dedicated, seeded RNG makes
# the draw repeatable across kernel restarts (13 matches the TrainingArguments seed).
split_rng = random.Random(13)
valseqs = split_rng.sample(seqdat, round(0.2*len(seqdat)))

# Training set = every distinct sequence that was not drawn for validation.
# Filtering an order-preserving de-duplicated list (rather than taking a set
# difference) keeps the row order identical from run to run.
held_out = set(valseqs)
trainseqs = [seq for seq in dict.fromkeys(seqdat) if seq not in held_out]

train_df = pd.DataFrame({'seq': trainseqs})
train_dataset = Dataset.from_pandas(train_df[['seq']])
print(train_dataset)

val_df = pd.DataFrame({'seq': valseqs})
val_dataset = Dataset.from_pandas(val_df[['seq']])
print(val_dataset)

train_dataset = train_dataset.map( #apply the tokenizer to the dataset
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=train_dataset.column_names,
)

val_dataset = val_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=val_dataset.column_names,
)

Dataset({
    features: ['seq'],
    num_rows: 7912
})
Dataset({
    features: ['seq'],
    num_rows: 1978
})


Map (num_proc=4):   0%|          | 0/7912 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/1978 [00:00<?, ? examples/s]

In [14]:
# Per-batch collation for masked-LM training: random masking with
# mlm_probability=0.15, returning PyTorch tensors
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
    return_tensors='pt',
)

In [15]:
# Training configuration for the all-HA run
training_args = TrainingArguments(
    # output location / checkpoint retention
    output_dir="models",
    overwrite_output_dir=True,
    save_total_limit=2,
    # schedule
    num_train_epochs=10,
    evaluation_strategy="epoch",
    # batching and data loading
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    dataloader_num_workers=4,
    # reproducibility / progress display
    seed=13,
    disable_tqdm=False,
)

# Wire the model, tokenised datasets and masking collator into the Trainer
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)



In [16]:
# Run training as configured in training_args (10 epochs, evaluation on val_dataset every epoch)
trainer.train()



Epoch,Training Loss,Validation Loss
1,No log,0.212838
2,1.169100,0.145361
3,0.170100,0.125289
4,0.135200,0.112351
5,0.118800,0.108128
6,0.109800,0.103779
7,0.102700,0.09549
8,0.095900,0.093172
9,0.089500,0.088045
10,0.085000,0.085678




TrainOutput(global_step=4950, training_loss=0.21688221941090594, metrics={'train_runtime': 5128.9481, 'train_samples_per_second': 15.426, 'train_steps_per_second': 0.965, 'total_flos': 1.7491845197320128e+17, 'train_loss': 0.21688221941090594, 'epoch': 10.0})

In [17]:
# Persist the model held by the trainer after the run above
trainer.save_model("models/esm2_t33-weight_reset-HA_all_110424_clu99_e10")

In [6]:
# Per-batch collation for masked-LM training: random masking with
# mlm_probability=0.15, returning PyTorch tensors
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
    return_tensors='pt',
)

In [7]:
#clu99 individual serotypes
# For each serotype: load ESM-2, re-initialise its layers, train on that
# serotype's sequences with an 80/20 train/validation split, and write the
# metrics and log history next to the checkpoints.

# The target device does not change between serotypes, so resolve it once.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for sero in ['H7', 'H5', 'H1', 'H3']:

    # Fresh model instance per serotype so nothing carries over between runs
    model = EsmForMaskedLM.from_pretrained(model_type)
    model = model.to(device)
    if device.type == "cuda":
        print("Transferred model to GPU")

    #reset weights
    reset_weights(model)

    #load protein sequences from fasta file into a list
    seqdat = [str(rec.seq) for rec in SeqIO.parse('../evotuning/ncbiflu_HA_all_110424_noX_clu99_filt_%s.fas'%sero, 'fasta')]

    #separate 80-20 training-validation dataset
    # A dedicated, seeded RNG makes the draw repeatable (13 matches the
    # TrainingArguments seed below).
    split_rng = random.Random(13)
    valseqs = split_rng.sample(seqdat, round(0.2*len(seqdat)))
    # Training set = every distinct sequence not drawn for validation; the
    # order-preserving filter keeps row order identical between runs, which a
    # set difference does not.
    held_out = set(valseqs)
    trainseqs = [seq for seq in dict.fromkeys(seqdat) if seq not in held_out]

    train_df = pd.DataFrame({'seq': trainseqs})
    train_dataset = Dataset.from_pandas(train_df[['seq']])
    print(train_dataset)

    val_df = pd.DataFrame({'seq': valseqs})
    val_dataset = Dataset.from_pandas(val_df[['seq']])
    print(val_dataset)

    #apply the tokenizer to the dataset
    train_dataset = train_dataset.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=train_dataset.column_names,
    )

    val_dataset = val_dataset.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=val_dataset.column_names,
    )

    # Per-serotype output directory (the trailing slash is relied on below)
    output_dir = "/media/spyros/HD-ADU3/spyros/model_weights_topublish/reset_weights_ESM2_%s/"%sero

    # Log, evaluate and checkpoint once per epoch
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs = 10,
        logging_strategy='epoch',
        save_strategy='epoch',
        per_device_train_batch_size=4,
        per_device_eval_batch_size=8,
        evaluation_strategy="epoch",
        save_total_limit=5,
        seed=13,
        dataloader_num_workers=4,
        disable_tqdm=False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )

    train_result = trainer.train()

    metrics = train_result.metrics

    # save train results
    # The first argument of log_metrics/save_metrics is a split label, not a
    # file path. With "train" the printed header is readable and the metrics
    # are written as train_results.json inside output_dir (previously
    # ereset_weights_ESM2_<sero>_metrics_results.json in the same directory).
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    # Full per-epoch log history, written as the str() of the list
    with open(output_dir + "ereset_weights_ESM2_%s.log"%sero, 'w') as out:
        out.write(str(trainer.state.log_history))

Transferred model to GPU
Dataset({
    features: ['seq'],
    num_rows: 244
})
Dataset({
    features: ['seq'],
    num_rows: 61
})


Map (num_proc=4):   0%|          | 0/244 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/61 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss
1,3.1303,2.918915
2,2.9172,2.908961
3,2.9129,2.905982
4,2.9057,2.908129
5,2.9046,2.895317
6,2.8841,2.844552
7,3.2138,2.898546
8,2.8544,2.815137
9,2.785,2.751624
10,2.7262,2.701269




***** /media/spyros/HD-ADU3/spyros/model_weights_topublish/reset_weights_ESM2_H7/ereset_weights_ESM2_H7_metrics metrics *****
  epoch                    =       10.0
  total_flos               =  5023880GF
  train_loss               =     2.9234
  train_runtime            = 0:09:42.27
  train_samples_per_second =       4.19
  train_steps_per_second   =      0.275
Transferred model to GPU
Dataset({
    features: ['seq'],
    num_rows: 782
})
Dataset({
    features: ['seq'],
    num_rows: 196
})


Map (num_proc=4):   0%|          | 0/782 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/196 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss
1,2.9774,2.926667
2,2.9885,2.827787
3,2.599,2.161012
4,1.2257,0.524831
5,0.3368,0.242247
6,0.2114,0.197506
7,0.1692,0.161766
8,0.1469,0.14602
9,0.1328,0.128923
10,0.1208,0.134946




***** /media/spyros/HD-ADU3/spyros/model_weights_topublish/reset_weights_ESM2_H5/ereset_weights_ESM2_H5_metrics metrics *****
  epoch                    =       10.0
  total_flos               = 16101125GF
  train_loss               =     1.0909
  train_runtime            = 0:15:20.33
  train_samples_per_second =      8.497
  train_steps_per_second   =      0.532
Transferred model to GPU
Dataset({
    features: ['seq'],
    num_rows: 1829
})
Dataset({
    features: ['seq'],
    num_rows: 457
})


Map (num_proc=4):   0%|          | 0/1829 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/457 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss
1,2.89,2.5004
2,0.9536,0.279377
3,0.2168,0.185655
4,0.167,0.154044
5,0.1458,0.135566
6,0.1295,0.121268
7,0.1194,0.121911
8,0.114,0.1092
9,0.1072,0.105588
10,0.1003,0.103732




***** /media/spyros/HD-ADU3/spyros/model_weights_topublish/reset_weights_ESM2_H1/ereset_weights_ESM2_H1_metrics metrics *****
  epoch                    =       10.0
  total_flos               = 37658514GF
  train_loss               =     0.4943
  train_runtime            = 0:26:24.32
  train_samples_per_second =     11.544
  train_steps_per_second   =      0.726
Transferred model to GPU
Dataset({
    features: ['seq'],
    num_rows: 3393
})
Dataset({
    features: ['seq'],
    num_rows: 848
})


Map (num_proc=4):   0%|          | 0/3393 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/848 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss
1,1.9258,0.264197
2,0.1736,0.142358
3,0.1217,0.115795
4,0.1033,0.105411
5,0.094,0.096424
6,0.0874,0.096016
7,0.0834,0.086758
8,0.0792,0.087518
9,0.0766,0.086929
10,0.07,0.078468




***** /media/spyros/HD-ADU3/spyros/model_weights_topublish/reset_weights_ESM2_H3/ereset_weights_ESM2_H3_metrics metrics *****
  epoch                    =       10.0
  total_flos               = 69860764GF
  train_loss               =     0.2815
  train_runtime            = 0:42:50.51
  train_samples_per_second =       13.2
  train_steps_per_second   =      0.829
