# Step 1: fine-tune LLM using top-k results from (fixed) ranker

In [None]:
def step_one(*, k, max_iteration):
    """Fine-tune the task LLM on prompts augmented by a fixed similarity ranker.

    For every example the review text and each profile text are
    whitespace-normalized, truncated to 256 characters and embedded with
    ``dev.embed``.  The profile entries whose embeddings have the largest dot
    product with the review embedding are passed to ``dev.prepend_to_prompt``,
    and a LoRA adapter (``r=5``) on ``google/flan-t5-base`` is trained on the
    resulting prompts.

    Args:
        k: number of profile entries retrieved per example.  If a profile
            holds fewer than ``k`` entries, all of them are used.
        max_iteration: number of outer iterations over the interleaved
            train/dev stream.

    Side effects:
        Sets the torch default device to ``'cuda'``, seeds the torch RNG, and
        saves the adapter after every outer iteration under
        ``User_keq{k}_t5base_step1_iter{iteration}_loratruncaug``.
    """
    from TaskLLM import TaskLLM
    from PersonalizedProductRating import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import LoraConfig, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=8, augment=True)
    dev = dev_loader(batch_size=16)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_config = LoraConfig(r=5, task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    def squash(text):
        # Collapse every whitespace run to one space, then keep the first 256 characters.
        return ' '.join(text.split())[:256]

    def retrieve(ex):
        # Return the profile entries most similar to the review, best first.
        profile = ex['profile']
        # Row 0 is the review; rows 1.. follow the profile order.
        embeddings = dev.embed([squash(ex['review'])] + [squash(v['text']) for v in profile])
        scores = embeddings[0, :] @ embeddings[1:, :].T
        # torch.topk raises if asked for more items than exist, so clamp to the profile size.
        nretrieve = min(k, scores.shape[0])
        index = torch.topk(scores, dim=0, k=nretrieve).indices.to('cpu').tolist()
        return [profile[ind] for ind in index]

    with ProgressPrinter('iter', f'{k} loss', f'{k} MAE', f'{k} MAE (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                # The MAE is measured before the update on this batch, so train
                # batches are scored by the model that has not yet seen them.
                with torch.no_grad():
                    prompts, targets = [], []

                    for ex, label in zip(examples, labels):
                        prompts.append(dev.prepend_to_prompt(ex, retrieve(ex)))
                        # Labels are shifted down by one to become zero-based class indices.
                        targets.append(int(label) - 1)

                    targets = torch.Tensor(targets).long().to(device)
                    # NOTE(review): predict() presumably returns per-class log-probabilities;
                    # under that assumption the guess is the median class (first index whose
                    # cumulative probability reaches 0.5) -- confirm against TaskLLM.
                    cumul = taskllm.predict(prompts).exp().cumsum(dim=1)
                    guesses = (cumul >= 0.5).long().argmax(dim=1)
                    mae = torch.abs(guesses - targets).float().mean().item()

                # Only train batches update the adapter; dev batches are evaluation-only.
                loss = taskllm.learn(prompts, targets) if istrain else None
                printer.addobs(iteration, loss, mae if istrain else None, mae if not istrain else None)

            printer.print()
            printer.autoprint = False
            taskllm.save_pretrained(f'User_keq{k}_t5base_step1_iter{iteration}_loratruncaug')

# Run step one retrieving 4 profile entries per example for 5 outer iterations;
# step_one saves an adapter checkpoint after each iteration.
step_one(k=4, max_iteration=5)

n              iter (since)      4 loss (since)       4 MAE (since) 4 MAE (dev) (since)      dt
1             0.000 (0.000)       0.954 (0.954)       0.438 (0.438)       0.000 (0.000)  6.77 s
2             0.000 (0.000)       0.734 (0.513)       0.344 (0.250)       0.000 (0.000)    12 s
4             0.000 (0.000)       0.975 (1.217)       0.578 (0.812)       0.000 (0.000)  18.4 s
8             0.000 (0.000)       0.733 (0.491)       0.406 (0.234)       0.000 (0.000)  31.5 s
16            0.000 (0.000)       0.697 (0.661)       0.363 (0.320)       0.000 (0.000)  57.7 s
32            0.000 (0.000)       0.631 (0.566)       0.311 (0.258)       0.000 (0.000)  1.77 m
64            0.000 (0.000)       0.664 (0.697)       0.329 (0.348)       0.000 (0.000)   3.5 m
128           0.000 (0.000)       0.641 (0.617)       0.313 (0.297)       0.000 (0.000)  7.08 m
256           0.000 (0.000)       0.636 (0.631)       0.307 (0.301)       0.000 (0.000)    14 m
512           0.000 (0.000)       0.627 

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

In [None]:
def learn_ranker(*, step1_iter, max_iteration, k):
    """Train a reward predictor (ranker) against a fixed, pre-fine-tuned task LLM.

    For each example, 64 Gumbel-perturbed top-k selections of profile entries
    are drawn and de-duplicated.  The reward predictor scores every candidate
    selection, ``SimpleRegretHypercubeSampler`` picks which candidates to
    evaluate, the task LLM is run on those candidates, and the reward
    ``1 - |guess - target| / 4`` is used as the regression target for the
    reward predictor.  Only train batches update the reward predictor.

    Args:
        step1_iter: suffix interpolated into the task-LLM adapter path
            ``User_keq{k}_t5base_step1_iter{step1_iter}``.
        max_iteration: number of outer iterations over the interleaved
            train/dev stream.
        k: number of profile entries per candidate selection.  If a profile
            holds fewer than ``k`` entries, all of them are used.

    Side effects:
        Sets the torch default device to ``'cuda'``, seeds the torch RNG, and
        saves the reward-predictor adapter after every outer iteration under
        ``User_keq{k}_t5base_step2_iter{iteration}_loratruncaug_hyper``.
    """
    from more_itertools import chunked
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedProductRating import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretHypercubeSampler
    from peft import LoraConfig, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(8675309)

    train = train_loader(batch_size=8, augment=True)
    dev = dev_loader(batch_size=16)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter(f'User_keq{k}_t5base_step1_iter{step1_iter}', 'taskllm')

    rhat_config = LoraConfig(r=5, task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")

    def squash(text):
        # Collapse every whitespace run to one space, then keep the first 256 characters.
        return ' '.join(text.split())[:256]

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(ex, nsamples):
        # Return the distinct top-k index sets found among `nsamples` noisy rankings.
        # Row 0 is the review; rows 1.. follow the profile order.
        embeddings = dev.embed([squash(ex['review'])] + [squash(v['text']) for v in ex['profile']])
        scores = embeddings[0,:] @ embeddings[1:,:].T
        nprofile = scores.shape[0]
        # The last valid index is nprofile - 1; clamping to nprofile (as before)
        # indexed out of bounds whenever the profile had 4 or fewer entries.
        # NOTE(review): `scores` is in profile order, not sorted, so this gap can be
        # negative; if the intent was "best score minus 5th-best score", sort first.
        temperature = scores[0].item() - scores[min(nprofile - 1, 4)].item()
        gumbel_shape = torch.Size([nsamples, nprofile])
        gumbels = temperature * gumbel.sample(gumbel_shape)
        noisy = scores.unsqueeze(0) + gumbels
        # torch.topk raises if asked for more items than exist, so clamp to the profile size.
        selections = torch.topk(noisy, dim=1, k=min(k, nprofile)).indices
        return torch.unique(selections, sorted=False, dim=0).to('cpu')

    # Reward-predictor updates are applied in chunks of this many candidates.
    inner_batch_size = 4

    with ProgressPrinter('iter', f'{k} loss', f'{k} MAE', f'{k} MAE (dev)', 'samps') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                greedymaes, allloss, nsamps = [], [], []

                for ex, label in zip(examples, labels):
                    with torch.no_grad():
                        randos = randomized_similarity(ex, 64)
                        review_line = f"Review: {squash(ex['review'])}"

                        prompts, rhatprompts = [], []
                        for rando in randos:
                            profile_examples = [ ex['profile'][ind] for ind in rando ]
                            prompts.append(dev.prepend_to_prompt(ex, profile_examples))
                            # The reward predictor sees the chosen examples with their
                            # scores, followed by the review being rated.
                            rhatprompts.append('\n'.join(
                                [ f"Example: {squash(v['text'])}\nScore: {v['score']}"
                                  for v in profile_examples
                                ] + [ review_line ]
                            ))

                        rhats = rewardpredictor.predict(rhatprompts)
                        # NOTE(review): presumably `exploit` is the greedy candidate index and
                        # `explore` a per-candidate indicator of extra candidates to evaluate
                        # -- confirm against SimpleRegret.
                        exploit, explore = SimpleRegretHypercubeSampler(rhats.view(1, -1), gamma=4)
                        exploit = [ exploit.item() ]
                        # Exploration is disabled on dev batches.
                        explore = explore[0].tolist() if istrain else []
                        actionind = exploit + [ n for n, observed in enumerate(explore) if observed > 0 ]
                        nsamps.append(len(actionind))
                        cumul = taskllm.predict([ prompts[a] for a in actionind ]).exp().cumsum(dim=1)
                        guesses = (cumul>=0.5).long().argmax(dim=1)
                        # Labels are shifted down by one to become zero-based class indices.
                        target = int(label) - 1
                        rewards = (1 - torch.abs((guesses - target)/4).float()).tolist()
                        # The exploit action is first in actionind, so guesses[0] is the greedy guess.
                        greedymae = torch.abs(guesses[0] - target).item()
                        greedymaes.append(greedymae)

                    if istrain:
                        # Size-weighted mean of the per-chunk losses.
                        loss = sum(
                            len(action_chunk) *
                            rewardpredictor.learn([ rhatprompts[a] for a in action_chunk ],
                                                  torch.Tensor([ [ r ] for r in reward_chunk ]).to(device)
                                                 )
                            for action_chunk, reward_chunk in zip(chunked(actionind, inner_batch_size),
                                                                  chunked(rewards, inner_batch_size))
                        ) / len(actionind)
                        allloss.append(loss)

                greedymae = torch.Tensor(greedymaes).float().mean().item()
                predloss = torch.Tensor(allloss).mean().item() if istrain else None
                nsamps = torch.Tensor(nsamps).float().mean().item() if istrain else None

                printer.addobs(iteration, predloss, greedymae if istrain else None, greedymae if not istrain else None, nsamps)

            printer.print()
            printer.autoprint = False
            rewardpredictor.save_pretrained(f'User_keq{k}_t5base_step2_iter{iteration}_loratruncaug_hyper')

# step1_iter carries the '_loratruncaug' suffix because learn_ranker builds the
# adapter path as f'User_keq{k}_t5base_step1_iter{step1_iter}', while step_one
# saves its checkpoints with that suffix appended to the iteration number.
learn_ranker(k=4, max_iteration=12, step1_iter='3_loratruncaug')