In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import esm
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from evaluate import load

In [None]:
# Load the PCP table and coerce the regression target column `pcp` to float
# in the same expression, so re-running this cell always starts from the file.
pcp = pd.read_csv("../data/pcp_res_all.csv").astype({"pcp": float})

In [None]:
# Hugging Face hub identifier of the pretrained checkpoint that the tokenizer
# and the model cells load.
model_checkpoint = "facebook/esm2_t6_8M_UR50D"

# Model inputs (`seq` column) and regression targets (`pcp` column) as plain lists.
sequences = list(pcp["seq"])
labels = list(pcp["pcp"])

# 70/30 shuffled hold-out split; the fixed random_state keeps the partition
# identical across kernel restarts.
split_options = {
    "test_size": 0.3,
    "shuffle": True,
    "random_state": 20230421,
}
train_sequences, test_sequences, train_labels, test_labels = train_test_split(
    sequences, labels, **split_options
)

In [None]:
from datasets import Dataset

# Tokenizer matching the checkpoint; downloaded files are cached under ./data.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, cache_dir="./data")

# Tokenize each split. No padding or truncation arguments are passed here, so
# the encodings keep one variable-length entry per input sequence.
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

# Wrap the encodings in `datasets.Dataset` objects and attach the regression
# targets under the "labels" column.
train_dataset = Dataset.from_dict(train_tokenized).add_column("labels", train_labels)
test_dataset = Dataset.from_dict(test_tokenized).add_column("labels", test_labels)

In [None]:
from transformers import (
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)

# Seed Python / NumPy / PyTorch *before* the model is built. The single-output
# regression head requested below is not part of the pretrained checkpoint and
# is therefore randomly initialised by `from_pretrained`; without a seed set
# beforehand, that initialisation (and hence every result downstream) changes
# on each kernel restart. 42 is the `TrainingArguments.seed` default, so the
# seed used during training itself is the same as before this change.
train_seed = 42
set_seed(train_seed)

# `num_labels=1` together with `problem_type="regression"` configures a
# single scalar output per sequence.
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, num_labels=1, problem_type="regression"
)

batch_size = 64
args = TrainingArguments(
    output_dir = "pcp_tunning5",
    # Evaluate and checkpoint once per epoch; only the two most recent
    # checkpoints are kept on disk.
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    save_total_limit = 2,
    learning_rate = 2e-4,
    per_device_train_batch_size = batch_size,
    per_device_eval_batch_size = batch_size,
    # Best-checkpoint selection is driven by the evaluation loss, where lower
    # is better; the best checkpoint is reloaded when training ends.
    greater_is_better = False,
    num_train_epochs = 15,
    weight_decay = 0.001,
    load_best_model_at_end = True,
    metric_for_best_model = "loss",
    push_to_hub = False,
    optim = "adamw_torch",
    logging_steps=300,
    # NOTE(review): 30 loader workers is machine-specific; lower it on hosts
    # with fewer CPU cores.
    dataloader_num_workers = 30,
    # Explicit (rather than implicit-default) seed, kept in sync with set_seed above.
    seed = train_seed,
)

In [None]:
def compute_metrics(eval_pred):
    """Compute R² between exponentiated predictions and labels.

    Both arrays are raised as powers of two (``2.0 ** x``) before scoring, so
    the metric is evaluated on the exponentiated scale rather than on the raw
    model outputs.
    NOTE(review): presumably the targets are stored as log2 values and this
    undoes that transform -- confirm against how the `pcp` column is produced.

    Args:
        eval_pred: Object that unpacks into ``(predictions, labels)``, as
            handed to ``compute_metrics`` by the ``Trainer``.

    Returns:
        dict: ``{"r2": <coefficient of determination>}``.
    """
    raw_predictions, raw_labels = eval_pred
    y_pred = 2.0 ** raw_predictions
    y_true = 2.0 ** raw_labels
    return {"r2": r2_score(y_true, y_pred)}

In [None]:
# Assemble the Trainer from the objects built in the previous cells and start
# fine-tuning.
# NOTE(review): the held-out test split is also the per-epoch evaluation set,
# and `load_best_model_at_end=True` picks the checkpoint on it -- metrics
# reported on `test_dataset` are therefore optimistic. Consider carving out a
# separate validation split.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()