In [0]:
# Path (under "My Drive") of the pretrained base model directory.
# NOTE(review): no cell visible in this notebook reads this constant — confirm it is still needed.
PRETRAINED_MODEL_DIR = "My Drive/MLLU/pretrained_models/base"

In [0]:
# Directory handed to the T5 model wrapper when the model is constructed.
MODEL_DIR = "My Drive/MLLU/model"
# Directory passed to TFDS (`data_dir` / `tfds_data_dir`) for the MRPC dataset.
DATA_DIR = "My Drive/MLLU/data"

In [0]:
!pip install -qU t5 

In [0]:
import functools
import t5
import torch
import transformers
# Use the GPU when PyTorch reports one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hugging Face PyTorch wrapper for "t5-base", given MODEL_DIR and the chosen device.
model = t5.models.HfPyTorchModel("t5-base", MODEL_DIR, device)

In [0]:
import tensorflow_datasets as tfds

# Fetch GLUE/MRPC through TFDS. The raw download is kept in a local folder
# for preprocessing so that it does not use GCS space.
ds = tfds.load(
    "glue/mrpc",
    data_dir=DATA_DIR,
    download_and_prepare_kwargs=dict(download_dir="./downloads"),
)

# Preview the first 20 untouched validation examples.
print("A few raw validation examples...")
for ex in tfds.as_numpy(ds["validation"].take(20)):
  print(ex)

In [0]:
def label_preprocessor(ds):
  """Rename each example's fields to the "inputs"/"targets" keys.

  Args:
    ds: a dataset-like object exposing `map`; every example must provide
      the "input", "output" and "idx" entries.

  Returns:
    Whatever `ds.map` returns: examples keyed by "inputs", "targets" and
    "idx", with the values carried over unchanged.
  """

  def rename_fields(example):
    # Pure re-keying: no value is modified.
    renamed = {
        "inputs": example["input"],
        "targets": example["output"],
        "idx": example["idx"],
    }
    return renamed

  return ds.map(
      rename_fields,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)


In [0]:
def mrpc_extract(ds):
  """Convert raw GLUE/MRPC examples into text "input"/"output" pairs.

  Args:
    ds: a dataset whose examples provide "sentence1", "sentence2", "label"
      and "idx".

  Returns:
    The result of `ds.map`. Each example has:
      "input":  "mrpc sentence 1: <sentence1>  sentence 2: <sentence2>"
      "output": "unequal" when the label equals 0, otherwise "equal"
      "idx":    the original "idx" value, unchanged
  """

  def extract_io(ex):
    # Choose the label text with tensor ops. A Python `==` / conditional
    # expression on a tensor only works when tensor equality and AutoGraph
    # are active; with TF1-style semantics `==` is an identity check, so
    # every example would silently be labelled "equal".
    is_unequal = tf.equal(ex["label"], 0)
    return {
        # A space now follows "sentence 2:" so the second sentence is
        # separated from its prefix the same way the first one is.
        "input": ("mrpc sentence 1: " + ex["sentence1"] +
                  "  sentence 2: " + ex["sentence2"]),
        "output": tf.where(is_unequal, "unequal", "equal"),
        "idx": ex["idx"],
    }

  return ds.map(extract_io, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Drop any existing "mrpc" entry before adding it again — presumably so that
# re-running this cell does not clash with an earlier registration.
t5.data.TaskRegistry.remove("mrpc")
t5.data.TaskRegistry.add(
    "mrpc",
    # TfdsTask is configured with a TFDS dataset name rather than a
    # function returning a tf.data.Dataset.
    t5.data.TfdsTask,
    tfds_name="glue/mrpc:1.0.0",
    tfds_data_dir=DATA_DIR,
    # Listed in order: build the "input"/"output" text first, then rename
    # those fields to "inputs"/"targets".
    text_preprocessor=[mrpc_extract, label_preprocessor],
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    metric_fns=[t5.evaluation.metrics.accuracy],
    postprocess_fn=t5.data.postprocessors.lower_text,
)

# TensorFlow's v1-compat API is imported as `tf` for the preprocessing functions; the next cell loads and prints a few examples.
import tensorflow.compat.v1 as tf


In [0]:
# Build the preprocessed validation split of the registered "mrpc" task.
mrpc_task = t5.data.TaskRegistry.get("mrpc")
ds = mrpc_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 4})
print("A few preprocessed validation examples...")
# Print only the first 20 examples (the same number as the raw preview)
# instead of the whole split, but still scan every example to find the
# longest "inputs" entry.
max_ = 0
for example_index, ex in enumerate(tfds.as_numpy(ds)):
    if example_index < 20:
        print(ex)
    max_ = max(max_, len(ex["inputs"]))
# Longest "inputs" length seen across the whole validation split.
print(max_)

In [0]:
# Baseline: score the model on the "mrpc" task before any fine-tuning is run.
model.eval(
    "mrpc",
    batch_size=128,
    sequence_length={"inputs": 128, "targets": 4},
)

In [0]:
# Fine-tune on the "mrpc" train split for 2000 steps with save_steps=200
# (presumably one checkpoint every 200 steps).
model.train(
    mixture_or_task_name="mrpc",
    split="train",
    sequence_length={"inputs": 128, "targets": 4},
    batch_size=32,
    steps=2000,
    save_steps=200,
    # The optimizer is passed as a factory with the learning rate already bound.
    optimizer=functools.partial(transformers.AdamW, lr=1e-4),
)

In [0]:
# Post-training evaluation on "mrpc", with checkpoint_steps="all".
model.eval(
    "mrpc",
    batch_size=128,
    sequence_length={"inputs": 128, "targets": 4},
    checkpoint_steps="all",
)