# Step 1: fine-tune the task LLM using the top result from the (fixed) ranker

In [None]:
def launch():
    """Run step 1 by spawning one ``StepOne.step_one`` worker per CUDA device.

    ``MODEL_TYPE`` and ``BATCH_SIZE`` are written to ``os.environ`` before the
    workers are spawned (presumably ``StepOne`` reads them, together with
    ``AUGMENT`` -- confirm against that module). The call blocks until every
    worker has finished (``join=True``).

    Raises:
        RuntimeError: if ``torch.cuda.device_count()`` reports no device, so
            that a misconfigured machine fails loudly instead of spawning
            zero workers.
    """
    import os
    import StepOne
    import torch

    os.environ['MODEL_TYPE'] = 'base'
    os.environ['BATCH_SIZE'] = '16' # '32'

    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError('launch() needs at least one visible CUDA device')

    # Each worker is called as step_one(rank, world_size); spawn supplies the
    # rank as the first positional argument.
    torch.multiprocessing.spawn(StepOne.step_one,
                                args=(world_size,),
                                nprocs=world_size,
                                join=True)
    
# Run step 1 as soon as this cell executes.
launch()

# Step 2: learn the ranker using the (fixed, already fine-tuned) task LLM

In [None]:
def learn_ranker(*, max_iteration, k):
    """Step 2: train the reward predictor ("rhat" adapter) with the task LLM fixed.

    For every batch the function
      1. embeds the article and all profile texts and draws randomized
         size-``k`` subsets of profile entries (Gumbel-perturbed top-k),
      2. scores the resulting prompts with the reward predictor,
      3. picks one "exploit" prompt per example plus, on training batches,
         the "explore" prompts chosen by ``SimpleRegretHypercubeSampler``,
      4. generates with the task LLM and scores the generations with ROUGE-1,
      5. on training batches, fits the reward predictor to the observed
         rewards.
    After each outer iteration the progress table is printed and a checkpoint
    is written under the output directory.

    Configuration is read from the environment:
      ``STEP1_ITER``      suffix of the step-1 adapter to load (default ``'0_augment8'``)
      ``AUGMENT``         passed to ``train_loader`` and used in the checkpoint name (default ``1``)
      ``GAMMA``           passed to ``SimpleRegretHypercubeSampler`` (default ``1``)
      ``AMLT_OUTPUT_DIR`` directory in which checkpoints are saved (default ``'.'``)
    ``TOKENIZERS_PARALLELISM`` is set to ``'true'`` as a side effect.

    Args:
        max_iteration: number of outer passes over the interleaved train/dev data.
        k: number of profile entries per candidate prompt; also part of the
            adapter name that is loaded and of the checkpoint name.
    """
    import evaluate
    import os
    from itertools import accumulate, chain
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedNews import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretHypercubeSampler
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave, set_directory, GPUMonitor

    torch.manual_seed(8675309)

    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    step1_iter = os.environ.get('STEP1_ITER', '0_augment8')
    augment = int(os.environ.get('AUGMENT', '1'))
    gamma = float(os.environ.get('GAMMA', '1'))
    output_dir = os.environ.get('AMLT_OUTPUT_DIR', '.')

    # All batch sizes below scale with the number of visible GPUs.
    ngpu = torch.cuda.device_count()

    train = train_loader(batch_size=8 * ngpu, augment=augment)
    dev = dev_loader(batch_size=16 * ngpu)

    # One T5 backbone carries two adapters: the step-1 task adapter and a
    # freshly initialised IA3 adapter for the reward predictor.
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter(f'User_keq{k}_t5base_step1_iter{step1_iter}', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")
    rouge_metric = evaluate.load('rouge')

    gumbel = torch.distributions.gumbel.Gumbel(0,1)

    def randomized_similarity(embeddings, nsamples):
        """Return the distinct top-``k`` index rows found in ``nsamples`` noisy rankings.

        Row 0 of ``embeddings`` is the query; the remaining rows are the
        candidates it is scored against.
        """
        scores = embeddings[0,:] @ embeddings[1:,:].T
        # Noise scale: gap between the first score and the score at index 4.
        # The index is clamped to the last valid position so that profiles
        # with four or fewer entries no longer index past the end of `scores`.
        # NOTE(review): presumably relies on the profile already being ordered
        # by relevance -- confirm against the data loader.
        reference = min(scores.shape[0] - 1, 4)
        temperature = scores[0].item() - scores[reference].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape).to(scores.device)
        noisy_topk = torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices
        return torch.unique(noisy_topk, sorted=False, dim=0)

    def inner_batch(func, inner_batch_size, inputs):
        """Call ``func`` on aligned chunks of every sequence in ``inputs``; return the list of results."""
        from more_itertools import chunked
        return [ func(*ib) for ib in zip(*[ chunked(g, inner_batch_size) for g in inputs ]) ]

    def flatten(nested):
        """Concatenate an iterable of lists into one list in linear time."""
        return list(chain.from_iterable(nested))

    def cumsum(z):
        """Return ``[0, z0, z0 + z1, ...]``, i.e. slice boundaries for consecutive groups."""
        return list(accumulate(z, initial=0))

    # Kept referenced for the lifetime of the function.
    # NOTE(review): presumably a background GPU usage monitor -- see Util.
    monitor = GPUMonitor(delay=60, maxcount=5)

    print(f'************ augment = {augment} gamma = {gamma} step1_iter = {step1_iter} *************')
    with ProgressPrinter('iter', f'{k} loss', f'{k} rouge1', f'{k} rouge1 (dev)', 'samps') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                with torch.no_grad():
                    # Per example: the article followed by every profile text.
                    texts_to_embed = [ [ ex['article'] ] + [ v['text'] for v in ex['profile'] ]
                                       for ex in examples
                                     ]
                    embeddings = torch.cat(inner_batch(func = dev.embed,
                                                       inner_batch_size = 128 * ngpu,
                                                       inputs = (flatten(texts_to_embed),)
                                                      ),
                                           dim=0)
                    splits = cumsum(map(len, texts_to_embed))
                    randos = [ randomized_similarity(embeddings[a:b,:], 64) for a, b in zip(splits, splits[1:]) ]

                    # One candidate prompt per sampled subset of profile entries.
                    prompts = [ [ dev.prepend_to_prompt(ex, [ ex['profile'][ind] for ind in indices ])
                                  for indices in rando.to('cpu').tolist()
                                ]
                                for ex, rando in zip(examples, randos)
                              ]
                    rhats = torch.cat(inner_batch(func = rewardpredictor.predict,
                                                  inner_batch_size = 128 * ngpu,
                                                  inputs = (flatten(prompts),)
                                                 ),
                                      dim=0)
                    splits = cumsum(map(len, prompts))
                    samples = [ SimpleRegretHypercubeSampler(rhats[a:b].view(1, -1), gamma=gamma)
                                for a, b in zip(splits, splits[1:])
                              ]

                    # The exploit action always comes first; exploration
                    # actions are added on training batches only.
                    actionind = []
                    for exploit, exploreraw in samples:
                        explore = exploreraw[0].tolist() if istrain else []
                        actionind.append([ exploit.item() ] +
                                         [ n for n, observed in enumerate(explore) if observed > 0 ])
                    nsamps = [ len(aind) for aind in actionind ]

                    guessprompts = [ [ prompt[a] for a in aind ] for prompt, aind in zip(prompts, actionind) ]
                    guesses = flatten(inner_batch(func = taskllm.generate,
                                                  inner_batch_size = 64 * ngpu,
                                                  inputs = (flatten(guessprompts),)
                                                 ))
                    splits = cumsum(map(len, guessprompts))
                    rewards = flatten(rouge_metric.compute(predictions=guesses[a:b],
                                                           references=[label]*(b-a),
                                                           use_aggregator=False)['rouge1']
                                      for a, b, label in zip(splits, splits[1:], labels))
                    # The first guess of every example is the exploit (greedy) one.
                    greedyrewards = rouge_metric.compute(predictions=[guesses[a] for a in splits[:-1]],
                                                         references = labels,
                                                         use_aggregator=False)['rouge1']

                if istrain:
                    def learn_chunk(chunk_prompts, chunk_rewards):
                        """Fit the reward predictor on one chunk; return (chunk size, loss)."""
                        targets = torch.tensor([ [ r ] for r in chunk_rewards ], dtype=torch.float32)
                        return len(chunk_prompts), rewardpredictor.learn(chunk_prompts, targets)

                    predlosses = inner_batch(func = learn_chunk,
                                             inner_batch_size = 32 * ngpu,
                                             inputs = (flatten(guessprompts), rewards)
                                            )
                    # Size-weighted mean of the per-chunk losses.
                    predloss = sum(n * v for n, v in predlosses) / sum(n for n, v in predlosses)
                else:
                    predloss = None

                greedyreward = torch.tensor(greedyrewards, dtype=torch.float32).mean().item()
                mean_nsamps = torch.tensor(nsamps, dtype=torch.float32).mean().item() if istrain else None

                printer.addobs(iteration,
                               predloss,
                               greedyreward if istrain else None,
                               greedyreward if not istrain else None,
                               mean_nsamps)

            printer.print()
            printer.autoprint = False
            with set_directory(output_dir):
                # NOTE(review): this step trains the reward predictor, yet the
                # checkpoint is written through `taskllm`. Confirm that
                # TaskLLM.save_pretrained also persists the "rhat" adapter, or
                # save via `rewardpredictor` instead.
                taskllm.save_pretrained(f'User_keq{k}_t5base_step2_iter{iteration}_augment{augment}')

# Run step 2 with k=4 for 5 outer iterations.
learn_ranker(k=4, max_iteration=5)