In [None]:
import math

import torch
from transformers import AdamW, get_scheduler, Trainer, TrainingArguments
from transformers import T5Tokenizer

from collator import T2TDataCollator
from model import T5PromptTuningLM
from prepare_data import create_or_load

In [None]:
# Load the t5-base tokenizer and build (or reload) the train/validation splits.
# create_or_load receives the tokenizer — presumably to tokenize the raw data;
# verify in prepare_data.py.
tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path="t5-base")
(train_dataset, valid_dataset) = create_or_load(tokenizer)

In [None]:
# if you want to train
model = T5PromptTuningLM.from_pretrained(
    "t5-base",
    n_tokens=args.n_prompt_tokens,
    initialize_from_vocab=args.init_from_vocab)

# if you want to use an existing prompt to do inference
# model = T5PromptTuningLM.from_pretrained('t5-base', 
#                                           return_dict=False, 
#                                           soft_prompt_path='soft_prompts/soft_prompt.model')

In [None]:
# Set up training arguments, optimizers, etc

class Config:
    # Prompt-tuning
    n_prompt_tokens = 10
    init_from_vocab = True
    # random_range = 0.5
args = Config()

optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n == "soft_prompt.weight"],
    }
]
optimizer = AdamW(optimizer_grouped_parameters)
lr_scheduler = get_scheduler(
    name='cosine',
    num_warmup_steps=0,
    optimizer=optimizer,
    num_training_steps=3,
)

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    per_device_train_batch_size=8,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    logging_dir='./logs',            # directory for storing logs
    logging_steps=100,
    save_steps=3000,
    report_to='tensorboard',
    prediction_loss_only=True,
)

In [None]:
# Initialize trainer. It uses the optimizer and LR scheduler built in the
# training-setup cell, so TrainingArguments' own optimizer settings are not used.
collator = T2TDataCollator()
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    optimizers=(optimizer, lr_scheduler),
)

In [None]:
# start training — the result returned by Trainer.train is shown as the cell output
train_result = trainer.train()
train_result

In [None]:
# start evaluate — runs the eval loop on the trainer's eval_dataset and shows the metrics
eval_metrics = trainer.evaluate()
eval_metrics

In [None]:
# making predictions

# question = 'In what country is Normandy located?'
question = 'When were the Normans in Normandy?'
context = 'The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.'

# question = 'Who was the duke in the battle of Hastings?'
# context = 'The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands.'

# question = "When was the Latin version of the word Norman first recorded?"
# context = "The English name 'Normans' comes from the French words Normans/Normanz, plural of Normant, modern French normand, which is itself borrowed from Old Low Franconian Nortmann 'Northman' or directly from Old Norse Norðmaðr, Latinized variously as Nortmannus, Normannus, or Nordmannus (recorded in Medieval Latin, 9th century) to mean 'Norseman, Viking'."

# question = 'When was the Duchy of Normandy founded?'
# context = 'In the course of the 10th century, the initially destructive incursions of Norse war bands into the rivers of France evolved into more permanent encampments that included local women and personal property. The Duchy of Normandy, which began in 911 as a fiefdom, was established by the treaty of Saint-Clair-sur-Epte between King Charles III of West Francia and the famed Viking ruler Rollo, and was situated in the former Frankish kingdom of Neustria. The treaty offered Rollo and his men the French lands between the river Epte and the Atlantic coast in exchange for their protection against further Viking incursions. The area corresponded to the northern part of present-day Upper Normandy down to the river Seine, but the Duchy would eventually extend west beyond the Seine. The territory was roughly equivalent to the old province of Rouen, and reproduced the Roman administrative structure of Gallia Lugdunensis II (part of the former Gallia Lugdunensis).'

input_ids = tokenizer.encode('question: %s  context: %s </s>' % (question, context), 
                             return_tensors='pt')
# Trainer places the model on its training device (GPU when available), so
# move the inputs to wherever the model lives rather than hard-coding .cuda().
input_ids = input_ids.to(model.device)

model.eval()  # disable dropout for inference, whatever mode training left it in
with torch.no_grad():  # inference only; avoid building an autograd graph
    # to_encoder_only is a T5PromptTuningLM-specific flag: run only the
    # (prompt-prefixed) encoder and feed its outputs to generate().
    encoder_outputs = model.forward(input_ids, to_encoder_only=True)
    indices = model.generate(
        encoder_outputs=encoder_outputs,
        # min_length=10,
        # repetition_penalty=2.,
    )
# Strip <pad>/</s> so only the generated answer text is shown.
[tokenizer.decode(idx, skip_special_tokens=True) for idx in indices]