# Step 1: fine-tune LLM using top result from (fixed) ranker

In [None]:
def step_one(*, k, max_iteration):
    import os
    from PersonalizedProductRating import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import LoraConfig, TaskType
    from TaskLLM import TaskLLM
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave, set_directory

    torch.manual_seed(2112)

    augment = int(os.environ.get('AUGMENT', '4'))
    train = train_loader(batch_size=8, augment=augment)
    dev = dev_loader(batch_size=16)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    taskllm_config = LoraConfig(r=5, task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(taskllm_config, "taskllm")
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    def inner_batch(func, inner_batch_size, inputs):
        from more_itertools import chunked
        return [ func(*ib) for ib in zip(*[ chunked(g, inner_batch_size) for g in inputs ]) ]

    print(f'************ augment = {augment} *************')
    with ProgressPrinter('iter', f'{k} loss', f'{k} MAE', f'{k} MAE (dev)') as printer:
        cumsum = lambda z, acc=0: [0] + [ acc := acc + v for v in z ]
        
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                with torch.no_grad():
                    texts_to_embed = [ [ text[:256]
                                         for text in (' '.join(ex['review'].split()), ) 
                                       ] + 
                                       [ text[:256]
                                         for v in ex['profile']
                                         for text in (' '.join(v['text'].split()), )
                                       ]
                                       for ex in examples
                                     ]
                    embeddings = torch.cat(inner_batch(func = lambda t: dev.embed(t),
                                                       inner_batch_size = 64 * torch.cuda.device_count(),
                                                       inputs = (sum(texts_to_embed, []),)
                                                      ),
                                           dim=0)
                    splits = cumsum(map(len, texts_to_embed))
                    indices = [ torch.topk(embeddings[a,:] @ embeddings[a+1:b,:].T, dim=0, k=k).indices for a, b in zip(splits, splits[1:]) ]
                    prompts = [ dev.prepend_to_prompt(ex, [ ex['profile'][ind] for ind in index.to('cpu').tolist() ])
                                for ex, index in zip(examples, indices) ]
                    targets = [ int(label) - 1 for label in labels ]
                    cumul = taskllm.predict(prompts).exp().cumsum(dim=1)
                    guesses = (cumul>=0.5).long().argmax(dim=1)
                    targets = torch.Tensor(targets).long().to(guesses.device)
                    mae = torch.abs(guesses - targets).float().mean().item()

                loss = taskllm.learn(prompts, targets) if istrain else None
                printer.addobs(iteration, loss, mae if istrain else None, mae if not istrain else None)

            printer.print()
            printer.autoprint = False
            output_dir = os.environ.get('AMLT_OUTPUT_DIR', '.')
            with set_directory(output_dir):
                taskllm.save_pretrained(f'User_keq{k}_t5base_step1_iter{iteration}_augment{augment}')

step_one(k=4, max_iteration=5)

************ augment = 4 *************
n              iter (since)      4 loss (since)       4 MAE (since) 4 MAE (dev) (since)      dt
1             0.000 (0.000)       0.858 (0.858)       0.550 (0.550)       0.000 (0.000)  32.1 s
2             0.000 (0.000)       0.711 (0.564)       0.400 (0.250)       0.000 (0.000)  44.7 s
4             0.000 (0.000)       0.759 (0.807)       0.381 (0.362)       0.000 (0.000)  1.21 m
8             0.000 (0.000)       0.770 (0.780)       0.378 (0.375)       0.000 (0.000)  2.06 m
16            0.000 (0.000)       0.739 (0.709)       0.355 (0.331)       0.000 (0.000)  3.82 m
32            0.000 (0.000)       0.684 (0.628)       0.325 (0.295)       0.000 (0.000)  7.27 m
64            0.000 (0.000)       0.672 (0.661)       0.320 (0.315)       0.000 (0.000)  14.3 m
128           0.000 (0.000)       0.652 (0.631)       0.309 (0.298)       0.000 (0.000)  28.1 m
256           0.000 (0.000)       0.613 (0.574)       0.292 (0.275)       0.000 (0.000)  55.7 m
5

# Step 2: learn ranker using (fixed pre-finetuned) task LLM

In [1]:
def learn_ranker(*, max_iteration, k):
    import os
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedProductRating import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from SimpleRegret import SimpleRegretHypercubeSampler
    from peft import LoraConfig, TaskType
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave, set_directory, GPUMonitor

    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    step1_iter = os.environ.get('STEP1_ITER', '2_augment4')
    augment = int(os.environ.get('AUGMENT', '1'))
    gamma = float(os.environ.get('GAMMA', '1'))
    output_dir = os.environ.get('AMLT_OUTPUT_DIR', '.')
    
    torch.manual_seed(8675309)

    train = train_loader(batch_size=8, augment=augment)
    dev = dev_loader(batch_size=16)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    t5.load_adapter(f'User_keq{k}_t5base_step1_iter{step1_iter}', 'taskllm')

    rhat_config = LoraConfig(r=1, task_type=TaskType.SEQ_2_SEQ_LM)
    t5.add_adapter(rhat_config, "rhat")
    t5.enable_adapters()
    
    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat")

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(embeddings, nsamples):
        scores = embeddings[0,:] @ embeddings[1:,:].T
        temperature = scores[0].item() - scores[min(scores.shape[0], 4)].item()
        gumbel_shape = torch.Size([nsamples, scores.shape[0]])
        gumbels = temperature * gumbel.sample(gumbel_shape).to(scores.device)
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=k).indices, sorted=False, dim=0)

    def inner_batch(func, inner_batch_size, inputs):
        from more_itertools import chunked
        return [ func(*ib) for ib in zip(*[ chunked(g, inner_batch_size) for g in inputs ]) ]

    monitor = GPUMonitor(delay=60, maxcount=5)

    print(f'************ augment = {augment} gamma = {gamma} step1_iter = {step1_iter} *************')
    with ProgressPrinter('iter', f'{k} loss', f'{k} MAE', f'{k} MAE (dev)', 'samps') as printer:
        cumsum = lambda z, acc=0: [0] + [ acc := acc + v for v in z ]
        
        for iteration in range(max_iteration):
            for istrain, (examples, labels) in interleave(train, dev, sequential=True):
                with torch.no_grad():
                    texts_to_embed = [ [ text[:256]
                                         for text in (' '.join(ex['review'].split()), ) 
                                       ] + 
                                       [ text[:256]
                                         for v in ex['profile']
                                         for text in (' '.join(v['text'].split()), )
                                       ]
                                       for ex in examples
                                     ]
                    embeddings = torch.cat(inner_batch(func = lambda t: dev.embed(t),
                                                       inner_batch_size = 64 * torch.cuda.device_count(),
                                                       inputs = (sum(texts_to_embed, []),)
                                                      ),
                                           dim=0)
                    splits = cumsum(map(len, texts_to_embed))
                    randos = [ randomized_similarity(embeddings[a:b,:], 64) for a, b in zip(splits, splits[1:]) ]
                    prompts = [ [ dev.prepend_to_prompt(ex, [ ex['profile'][ind] for ind in indices ]) 
                                  for indices in rando.to('cpu').tolist() 
                                ]
                                for ex, rando in zip(examples, randos) 
                              ]
                    rhats = torch.cat(inner_batch(func = lambda p: rewardpredictor.predict(p),
                                                  inner_batch_size = 64 * torch.cuda.device_count(),
                                                  inputs = (sum(prompts, []),)
                                                 ),
                                      dim=0)
                    splits = cumsum(map(len, prompts))
                    samples = [ SimpleRegretHypercubeSampler(rhats[a:b].view(1, -1), gamma=gamma) for a, b in zip(splits, splits[1:]) ]
                    actionind = [ [ exploit.item() ] + [ n for n, observed in enumerate(explore) if observed > 0 ]
                                  for exploit, exploreraw in samples
                                  for explore in (exploreraw[0].tolist() if istrain else [], )
                                ]
                    nsamps = [ len(aind) for aind in actionind ]
                    guessprompts = [ [ prompt[a] for a in aind ] for prompt, aind in zip(prompts, actionind) ]
                    cumul = torch.cat(inner_batch(func = lambda p: taskllm.predict(p).exp().cumsum(dim=1),
                                                  inner_batch_size = 64 * torch.cuda.device_count(),
                                                  inputs = (sum(guessprompts, []),)
                                                 ),
                                      dim=0)
                    splits = cumsum(map(len, guessprompts))
                    guesses = [ (cumul[a:b,:]>=0.5).long().argmax(dim=1) for a, b in zip(splits, splits[1:]) ]
                    targets = [ int(label) - 1 for label in labels ]
                    rewards = [ (1 - torch.abs((g - target)/4).float()).tolist() for g, target in zip(guesses, targets) ]
                    greedymaes = [ torch.abs(g[0] - target).item() for g, target in zip(guesses, targets) ] 

                if istrain:
                    rhatprompts = sum(guessprompts, [])
                    rhattargets = sum(rewards, [])
                    predlosses = inner_batch(func = lambda a, b: (len(a), rewardpredictor.learn(a, torch.Tensor([ [ r ] for r in b ]))),
                                             inner_batch_size = 16 * torch.cuda.device_count(),
                                             inputs = (rhatprompts, rhattargets)
                                            )
                    predloss = sum(n * v for n, v in predlosses) / sum(n for n, v in predlosses)
                else:
                    predloss = None

                greedymae = torch.Tensor(greedymaes, device='cpu').float().mean().item()
                nsamps = torch.Tensor(nsamps, device='cpu').float().mean().item() if istrain else None

                printer.addobs(iteration, predloss, greedymae if istrain else None, greedymae if not istrain else None, nsamps)

            printer.print()
            printer.autoprint = False
            with set_directory(output_dir):
                taskllm.save_pretrained(f'User_keq{k}_t5base_step2_iter{iteration}_augment{augment}')
                
learn_ranker(k=4, max_iteration=5)

************ augment = 1 gamma = 1.0 step1_iter = 2_augment4 *************
n              iter (since)      4 loss (since)       4 MAE (since) 4 MAE (dev) (since)       samps (since)      dt
| ID | GPU | MEM |
------------------
|  0 | 40% | 10% |
| ID | GPU  | MEM |
-------------------
|  0 | 100% | 85% |
1             0.000 (0.000)       0.269 (0.269)       0.188 (0.188)       0.000 (0.000)      59.562 (59.562)  1.94 m
| ID | GPU | MEM |
------------------
|  0 | 88% | 88% |
2             0.000 (0.000)       0.306 (0.344)       0.219 (0.250)       0.000 (0.000)      30.906 (2.250)  2.69 m
| ID | GPU  | MEM |
-------------------
|  0 | 100% | 88% |
| ID | GPU  | MEM |
-------------------
|  0 | 100% | 88% |
4             0.000 (0.000)       0.230 (0.153)       0.172 (0.125)       0.000 (0.000)      17.094 (3.281)   4.3 m


KeyboardInterrupt: 