# Step 1: fine-tune LLM using top result from (fixed) ranker

In [None]:
def step_one(*, k, max_iteration):
    """Fine-tune an IA3-adapted flan-t5-base task LLM on retrieval-augmented prompts.

    For every example, the article and all of the user's profile texts are
    embedded with ``dev.embed``; the profile entries with the highest dot
    product against the article embedding are selected and prepended to the
    prompt via ``dev.prepend_to_prompt``. The task LLM first generates guesses
    (scored with ROUGE-1) and then, on training batches only, takes a learning
    step on the same prompts.

    Args:
        k: Maximum number of profile entries to retrieve per example. Also
            used in the progress column headers and the checkpoint name.
        max_iteration: Number of passes over ``interleave(train, dev)``. A
            checkpoint is saved after each pass.

    Side effects:
        Sets the torch default device to ``'cuda'``, seeds torch with 2112,
        and writes one adapter checkpoint per iteration named
        ``User_keq{k}_t5base_step1_iter{iteration}``.
    """
    # Imports stay function-local and in their original order.
    # NOTE(review): presumably so heavy / CUDA-touching modules are loaded
    # only in the process that actually runs this step -- confirm.
    import evaluate
    from TaskLLM import TaskLLM
    from PersonalizedNews import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=1, augment=True)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    rouge_metric = evaluate.load('rouge')

    def build_prompt(ex):
        """Return the prompt for ``ex`` with its best-matching profile entries prepended."""
        # Row 0 is the article; rows 1.. are the profile texts, in profile order.
        embeddings = dev.embed([ex['article']] + [v['text'] for v in ex['profile']])
        similarities = embeddings[0, :] @ embeddings[1:, :].T
        # torch.topk raises if k exceeds the number of candidates, so clamp k
        # for users whose profile has fewer than k entries.
        top_k = min(k, similarities.shape[0])
        index = torch.topk(similarities, dim=0, k=top_k).indices.to('cpu').tolist()
        profile_examples = [ex['profile'][ind] for ind in index]
        return dev.prepend_to_prompt(ex, profile_examples)

    with ProgressPrinter('iter', f'{k} loss', f'{k} rouge1', f'{k} rouge1 (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                # Retrieval, generation and scoring need no gradients.
                with torch.no_grad():
                    prompts = [build_prompt(ex) for ex in examples]
                    guesses = taskllm.generate(prompts)
                    scores = rouge_metric.compute(predictions=guesses, references=labels)['rouge1']

                # Only training batches update the model; dev batches are score-only.
                loss = taskllm.learn(prompts, labels) if istrain else None
                # The ROUGE-1 score goes to the train or the dev column, never both.
                printer.addobs(iteration, loss, scores if istrain else None, scores if not istrain else None)

            printer.print()
            # autoprint is switched off once the first iteration has been printed.
            printer.autoprint = False
            taskllm.save_pretrained(f'User_keq{k}_t5base_step1_iter{iteration}')

from Fork import SubProcess
from Util import BadPipe
# Launch step one only where `process.parent` is falsy.
# NOTE(review): presumably that is the forked child process -- confirm
# against Fork.SubProcess.
with BadPipe(), SubProcess() as process:
    if not process.parent:
        step_one(k=1, max_iteration=5)

# Step 2: learn ranker using (fixed pre-finetuned) task LLM