# Step 1: fine-tune LLM using top result from (fixed) ranker

In [2]:
def step_one(*, k, max_iteration):
    """Fine-tune the "taskllm" LoRA adapter using retrieval-augmented prompts.

    For every example the review and all profile texts are whitespace-normalized,
    truncated to 256 characters and embedded with ``dev.embed``.  The profile
    entries with the highest dot-product score against the review are passed to
    ``dev.prepend_to_prompt`` to build the prompt.  Train batches are used for a
    learning step; dev batches are only scored.

    Args:
        k: Number of retrieved profile entries to put in each prompt.  Clamped
            per example to the number of profile entries available.
        max_iteration: Number of passes over the interleaved train/dev stream.
            The adapter is saved after every pass.

    Returns:
        None.  Progress is reported through ``ProgressPrinter`` and adapters are
        written with ``taskllm.save_pretrained``.
    """
    from TaskLLM import TaskLLM
    from PersonalizedProductRating import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import LoraConfig, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=1, augment=True)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_config = LoraConfig(r=5, task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    def _clean(text, maxlen=256):
        # Collapse every whitespace run to a single space, then truncate.
        return ' '.join(text.split())[:maxlen]

    with ProgressPrinter('iter', f'{k} loss', f'{k} MAE', f'{k} MAE (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                # Retrieval and evaluation need no gradients.
                with torch.no_grad():
                    prompts, targets = [], []

                    for ex, label in zip(examples, labels):
                        # Row 0 is the review; rows 1.. are the profile entries in order.
                        embeddings = dev.embed([_clean(ex['review'])] +
                                               [_clean(v['text']) for v in ex['profile']])
                        # topk raises if k exceeds the number of candidates, so clamp
                        # (same guard learn_ranker applies with its `effk`).
                        effk = min(k, embeddings.shape[0] - 1)
                        scores = embeddings[0, :] @ embeddings[1:, :].T
                        index = torch.topk(scores, dim=0, k=effk).indices.to('cpu').tolist()
                        profile_examples = [ex['profile'][ind] for ind in index]
                        prompts.append(dev.prepend_to_prompt(ex, profile_examples))
                        # Labels are shifted to be zero-based class indices.
                        targets.append(int(label) - 1)

                    targets = torch.tensor(targets, dtype=torch.long, device=device)
                    # exp() + cumsum gives a CDF over classes (presumably predict()
                    # returns log-probabilities -- confirm against TaskLLM).
                    cumul = taskllm.predict(prompts).exp().cumsum(dim=1)
                    # First class whose cumulative mass reaches 0.5, i.e. the median.
                    guesses = (cumul >= 0.5).long().argmax(dim=1)
                    mae = torch.abs(guesses - targets).float().mean().item()

                # The learning step must run outside no_grad.
                loss = taskllm.learn(prompts, targets) if istrain else None
                printer.addobs(iteration, loss, mae if istrain else None, mae if not istrain else None)

            printer.print()
            printer.autoprint = False
            taskllm.save_pretrained(f'User_keq{k}_t5base_step1_iter{iteration}_loratruncaug')

from Fork import SubProcess
from Util import BadPipe
# Launch step one only where `process.parent` is falsy (presumably the forked
# child side of SubProcess -- confirm against Fork.SubProcess).
with BadPipe(), SubProcess() as process:
    if not process.parent:
        step_one(k=1, max_iteration=5)

n              iter (since)      1 loss (since)       1 MAE (since) 1 MAE (dev) (since)      dt
1             0.000 (0.000)       0.418 (0.418)       0.000 (0.000)       0.000 (0.000)  2.15 s
2             0.000 (0.000)       1.283 (2.148)       0.500 (1.000)       0.000 (0.000)  2.68 s
4             0.000 (0.000)       2.030 (2.778)       0.875 (1.250)       0.000 (0.000)  4.13 s
8             0.000 (0.000)       1.573 (1.117)       0.812 (0.750)       0.000 (0.000)  7.97 s
16            0.000 (0.000)       1.165 (0.756)       0.562 (0.312)       0.000 (0.000)  19.5 s
32            0.000 (0.000)       0.827 (0.489)       0.375 (0.188)       0.000 (0.000)    34 s
64            0.000 (0.000)       0.733 (0.638)       0.297 (0.219)       0.000 (0.000)  1.06 m
128           0.000 (0.000)       0.699 (0.666)       0.316 (0.336)       0.000 (0.000)  2.03 m
256           0.000 (0.000)       0.647 (0.594)       0.311 (0.305)       0.000 (0.000)  3.86 m
512           0.000 (0.000)       0.672 

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

In [None]:
def learn_ranker(*, step1_iter, rank, max_iteration):
    """Train the "rhat" reward-predictor adapter against a fixed task LLM.

    For each example, profile entries whose text equals the review are dropped,
    the remaining candidates are embedded with ``dev.embed`` and the top
    ``rank`` by dot-product score against the review are kept.  Each kept
    candidate is used alone to build a task-LLM prompt; the task LLM's guess
    gives a reward ``1 - |guess - target| / 4`` which the reward predictor is
    trained to predict on train batches.  The reported MAE is the task LLM's
    error when using the candidate the reward predictor scores highest.

    Args:
        step1_iter: Iteration number of the step-1 adapter checkpoint to load.
        rank: Maximum number of retrieved candidates scored per example.
        max_iteration: Number of passes over the interleaved train/dev stream.
            The reward-predictor adapter is saved after every pass.

    Returns:
        None.  Progress is reported through ``ProgressPrinter`` and adapters are
        written with ``rewardpredictor.save_pretrained``.
    """
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedProductRating import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(8675309)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    # NOTE(review): step_one in this file saves to
    # 'User_keq{k}_t5base_step1_iter{n}_loratruncaug'; this path has no such
    # suffix -- confirm the checkpoint directory being loaded is the intended one.
    t5.load_adapter(f'User_keq1_t5base_step1_iter{step1_iter}', 'taskllm')

    rhat_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")

    # Maximum number of characters of a profile text shown to the reward predictor.
    maxlen = 256

    with ProgressPrinter('iter', f'{rank} loss', f'{rank} MAE', f'{rank} MAE (dev)') as printer:
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                greedymaes, allloss = [], []
                for ex, label in zip(examples, labels):
                    with torch.no_grad():
                        # Exclude profile entries identical to the review itself.
                        candidates = [v for v in ex['profile'] if v['text'] != ex['review']]
                        if not candidates:
                            # Nothing to rank for this example.
                            continue

                        # Row 0 is the review; rows 1.. line up with `candidates`.
                        embeddings = dev.embed([ex['review']] + [v['text'] for v in candidates])
                        effk = min(rank, embeddings.shape[0] - 1)
                        scores = embeddings[0, :] @ embeddings[1:, :].T
                        index = torch.topk(scores, dim=0, k=effk).indices.to('cpu').tolist()

                        prompts, rhatprompts = [], []
                        for ind in index:
                            # Index the *filtered* list: the top-k indices refer to
                            # rows of `embeddings`, which were built from `candidates`,
                            # not from the unfiltered ex['profile'].
                            v = candidates[ind]
                            prompts.append(dev.prepend_to_prompt(ex, [v]))
                            text = ' '.join(v['text'].split())
                            rhatprompts.append(
                                f"Example: {text:.{maxlen}s}\nScore: {v['score']}\n"
                                f"Review: {ex['review']}"
                            )

                        # exp() + cumsum gives a CDF over classes (presumably predict()
                        # returns log-probabilities -- confirm against TaskLLM).
                        cumul = taskllm.predict(prompts).exp().cumsum(dim=1)
                        # First class whose cumulative mass reaches 0.5, i.e. the median.
                        guesses = (cumul >= 0.5).long().argmax(dim=1)
                        target = int(label) - 1
                        # Reward is 1 for an exact guess, decreasing by 1/4 per unit
                        # of error (presumably labels span 1..5 -- confirm).
                        rewards = 1 - torch.abs((guesses - target) / 4).float().unsqueeze(1)
                        rhats = rewardpredictor.predict(rhatprompts)
                        # Candidate the reward predictor currently prefers.
                        greedy = torch.argmax(rhats, dim=0).item()
                        greedymae = torch.abs(guesses[greedy] - target).item()
                        greedymaes.append(greedymae)

                    # The learning step must run outside no_grad.
                    loss = rewardpredictor.learn(rhatprompts, rewards) if istrain else None
                    allloss.append(loss)

                if not greedymaes:
                    # Every example in the batch was skipped; nothing to report.
                    continue

                greedymae = torch.Tensor(greedymaes).float().mean().item()
                predloss = torch.Tensor(allloss).mean().item() if istrain else None

                printer.addobs(iteration, predloss, greedymae if istrain else None, greedymae if not istrain else None)

            printer.print()
            printer.autoprint = False
            rewardpredictor.save_pretrained(f'User_keq1_t5base_step2_iter{iteration}_rankeq{rank}')

from Fork import SubProcess
from Util import BadPipe
# Launch ranker training only where `process.parent` is falsy (presumably the
# forked child side of SubProcess -- confirm against Fork.SubProcess).
with BadPipe(), SubProcess() as process:
    if not process.parent:
        learn_ranker(step1_iter=2, rank=8, max_iteration=8)