# Step 1: fine-tune the task LLM using the top result from the (fixed) ranker

In [2]:
# NB: amulet dropped some of the log output when this was run (?)
def launch():
    """Fine-tune the task LLM (Step 1) with one ``StepOne.step_one`` worker per visible GPU.

    Settings are handed to the workers through environment variables, which
    the spawned processes inherit: ``MODEL_TYPE``, ``BATCH_SIZE`` and
    ``AUGMENT`` (defaults to 4; an already-set value is kept).

    Raises:
        ValueError: if ``AUGMENT`` is set to something that is not an integer.
        RuntimeError: if no CUDA device is visible.
    """
    import os
    import StepOne
    import torch

    os.environ['MODEL_TYPE'] = 'xxl'
    # Validate AUGMENT once in the parent and export the normalized value, so
    # every worker sees the same setting even when the variable was unset.
    augment = int(os.environ.get('AUGMENT', '4'))
    os.environ['AUGMENT'] = str(augment)
    os.environ['BATCH_SIZE'] = '1'

    world_size = torch.cuda.device_count()
    if world_size == 0:
        # spawn(nprocs=0) starts no workers and returns normally, so without
        # this check the cell would appear to succeed while doing nothing.
        raise RuntimeError('launch: no CUDA devices visible; Step 1 needs at least one GPU')
    torch.multiprocessing.spawn(StepOne.step_one,
                                args=(world_size,),
                                nprocs=world_size,
                                join=True)

launch()

******** augment = 4 max_iteration = 5 model_type = xxl *********
n              iter (since)      4 loss (since)       4 MAE (since) 4 MAE (dev) (since)      dt
8             0.000 (0.000)       0.374 (0.698)       0.500 (1.000)       0.000 (0.000)  22.1 s
16            0.000 (0.000)       0.331 (0.288)       0.250 (0.000)       0.000 (0.000)  43.5 s
32            0.000 (0.000)       0.374 (0.416)       0.219 (0.188)       0.000 (0.000)  1.41 m
64            0.000 (0.000)       0.640 (0.905)       0.297 (0.375)       0.000 (0.000)  2.76 m
128           0.000 (0.000)       0.654 (0.669)       0.328 (0.359)       0.000 (0.000)  5.57 m
256           0.000 (0.000)       0.635 (0.615)       0.297 (0.266)       0.000 (0.000)  11.1 m
512           0.000 (0.000)       0.553 (0.472)       0.258 (0.219)       0.000 (0.000)    22 m
1024          0.000 (0.000)       0.541 (0.529)       0.237 (0.217)       0.000 (0.000)  44.6 m
2048          0.000 (0.000)       0.532 (0.524)       0.238 (0.239)   

# Step 2: learn the ranker using the (fixed, previously fine-tuned) task LLM

In [None]:
def launch():
    """Learn the ranker (Step 2) with one ``StepTwo.step_two`` worker per visible GPU.

    Settings are handed to the workers through environment variables, which
    the spawned processes inherit: ``MODEL_TYPE``, ``BATCH_SIZE``,
    ``LEARN_BATCH_SIZE`` and ``STEP1_ITER``.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    import os
    import StepTwo
    import torch

    # NB: these settings are sized for GPUs with 80 GB of memory.
    os.environ['MODEL_TYPE'] = 'xxl'
    os.environ['BATCH_SIZE'] = '1'
    os.environ['LEARN_BATCH_SIZE'] = '4'
    # Presumably selects which Step 1 task-LLM checkpoint StepTwo loads
    # -- TODO confirm in StepTwo.
    os.environ['STEP1_ITER'] = '2_augment4'

    world_size = torch.cuda.device_count()
    if world_size == 0:
        # spawn(nprocs=0) starts no workers and returns normally, so without
        # this check the cell would appear to succeed while doing nothing.
        raise RuntimeError('launch: no CUDA devices visible; Step 2 needs at least one GPU')
    torch.multiprocessing.spawn(StepTwo.step_two,
                                args=(world_size,),
                                nprocs=world_size,
                                join=True)

launch()

******** augment = 0 max_iteration = 5 model_type = xxl *********
n              iter (since)      4 loss (since)       4 MAE (since) 4 MAE (dev) (since)      nsamps (since)      dt
1             0.000 (0.000)       0.234 (0.234)       0.000 (0.000)       0.000 (0.000)      64.000 (64.000)  37.5 s
2             0.000 (0.000)       0.118 (0.001)       0.000 (0.000)       0.000 (0.000)      33.000 (2.000)    45 s
4             0.000 (0.000)       1.786 (3.455)       0.250 (0.500)       0.000 (0.000)      17.750 (2.500)  56.4 s
8             0.000 (0.000)       1.210 (0.634)       0.250 (0.250)       0.000 (0.000)      19.125 (20.500)  1.71 m
16            0.000 (0.000)       0.928 (0.647)       0.312 (0.375)       0.000 (0.000)      15.000 (10.875)  3.29 m
32            0.000 (0.000)       0.670 (0.412)       0.219 (0.125)       0.000 (0.000)      11.750 (8.500)  5.37 m
64            0.000 (0.000)       0.592 (0.514)       0.219 (0.219)       0.000 (0.000)       8.859 (5.969)  9.25 m
128

# Step 3: Prepare Submission Files

In [None]:
def prepare_submission_probensemble(*, nvoters, step2_iter, step1_iter, k):
    """Write LaMP-3 dev/test prediction files from a ranker-selected voter ensemble.

    For each example:
      1. Embed the review and its profile items.
      2. Draw up to 64 deduplicated k-item profile subsets, using similarity
         scores perturbed with scaled Gumbel noise.
      3. Build one prompt per subset and score the prompts with the ``rhat``
         (ranker) adapter.
      4. Run the ``nvoters`` highest-scoring prompts through the ``taskllm``
         adapter.
      5. Predict the first rating class at which the pooled cumulative class
         mass reaches half of the total.

    Args:
        nvoters: maximum number of top-ranked prompts whose predictions are pooled.
        step2_iter: suffix of the ranker adapter checkpoint
            (``User_keq{k}_t5xxl_step2_iter{step2_iter}``).
        step1_iter: suffix of the task-LLM adapter checkpoint
            (``User_keq{k}_t5xxl_step1_iter{step1_iter}``).
        k: number of profile items per prompt; also part of the checkpoint
            and output file names.

    Side effects:
        Prints running dev MAE/MSE. Writes ``lamp3u_devgolds_...json`` and
        ``lamp3u_testgolds_...json`` to the current directory.
    """
    import json
    from itertools import accumulate
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from PersonalizedProductRating import dev_loader, test_loader
    from ProgressPrinter import ProgressPrinter
    from transformers import T5ForConditionalGeneration
    import torch
    from Util import interleave

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(8675309)

    dev = dev_loader(batch_size=1)
    test = test_loader(batch_size=1)

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl')
    t5.load_adapter(f'User_keq{k}_t5xxl_step1_iter{step1_iter}', 'taskllm')
    t5.load_adapter(f'User_keq{k}_t5xxl_step2_iter{step2_iter}', 'rhat')
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat", model_id=f'User_keq{k}_t5xxl_step2_iter{step2_iter}')

    gumbel = torch.distributions.gumbel.Gumbel(0,1)
    def randomized_similarity(embeddings, nsamples):
        """Sample ``nsamples`` noisy top-k selections of profile items; return deduplicated index rows.

        Row 0 of ``embeddings`` is the review and rows 1.. are the profile
        items, so the returned indices are positions in ``ex['profile']``.
        """
        scores = embeddings[0,:] @ embeddings[1:,:].T
        nprofile = scores.shape[0]
        # NOTE(review): ``scores`` is in profile order (not sorted), so this is
        # the gap between profile items 0 and k rather than between the best and
        # the (k+1)-th best item. Confirm before changing it: the ranker may have
        # been trained with the same rule.
        # Clamp to the last valid index so profiles with at most k items do not
        # index out of range.
        temperature = scores[0].item() - scores[min(nprofile - 1, k)].item()
        gumbel_shape = torch.Size([nsamples, nprofile])
        gumbels = temperature * gumbel.sample(gumbel_shape).to(scores.device)
        # topk cannot select more items than the profile contains.
        return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=min(k, nprofile)).indices, sorted=False, dim=0)

    def inner_batch(func, inner_batch_size, inputs):
        """Call ``func`` on aligned ``inner_batch_size`` chunks of each sequence in ``inputs``; return the results as a list."""
        from more_itertools import chunked
        return [ func(*ib) for ib in zip(*[ chunked(g, inner_batch_size) for g in inputs ]) ]

    def offsets(lengths):
        """Prefix sums with a leading 0, i.e. slice boundaries: [2, 3] -> [0, 2, 5]."""
        return list(accumulate(lengths, initial=0))

    print(f'*** step1_iter: {step1_iter} step2_iter: {step2_iter} nvoters {nvoters} ***')

    with ProgressPrinter(f'{k} MAE (dev)', f'{k} MSE (dev)') as printer:
        devgolds, testgolds = [], []

        # interleave(...) yields (isdev, (examples, labels)) for batches of both loaders.
        for isdev, (examples, labels) in interleave(dev, test, sequential=True):
            with torch.no_grad():
                # Per example: the review, then every profile text; each is
                # whitespace-normalized and truncated to 256 characters.
                texts_to_embed = [ [ text[:256]
                                     for text in (' '.join(ex['review'].split()), )
                                   ] +
                                   [ text[:256]
                                     for v in ex['profile']
                                     for text in (' '.join(v['text'].split()), )
                                   ]
                                   for ex in examples
                                 ]
                embeddings = torch.cat(inner_batch(func = dev.embed,
                                                   inner_batch_size = 128,
                                                   inputs = (sum(texts_to_embed, []),)
                                                  ),
                                       dim=0)
                emb_offsets = offsets(map(len, texts_to_embed))
                randos = [ randomized_similarity(embeddings[a:b,:], 64) for a, b in zip(emb_offsets, emb_offsets[1:]) ]
                # One candidate prompt per sampled profile subset.
                prompts = [ [ dev.prepend_to_prompt(ex, [ ex['profile'][ind] for ind in indices ])
                              for indices in rando.to('cpu').tolist()
                            ]
                            for ex, rando in zip(examples, randos)
                          ]
                rhats = torch.cat(inner_batch(func = rewardpredictor.predict,
                                              inner_batch_size = 128,
                                              inputs = (sum(prompts, []),)
                                             ),
                                  dim=0)
                prompt_offsets = offsets(map(len, prompts))
                # Keep the (at most) nvoters prompts the ranker scores highest.
                votingprompts = [ [ prompt[v] for v in torch.topk(rhats[a:b].view(-1), k=min(nvoters, b-a)).indices.to('cpu').tolist() ]
                                    for a, b, prompt in zip(prompt_offsets, prompt_offsets[1:], prompts)
                                ]
                predicts = torch.cat(inner_batch(func = taskllm.predict,
                                                 inner_batch_size = 128,
                                                 inputs = (sum(votingprompts, []),)
                                                ),
                                     dim=0)
                vote_offsets = offsets(map(len, votingprompts))
                # Sum exp(predicts) over the voters, then pick the first class
                # whose cumulative mass reaches half of (b - a).
                # Assumes taskllm.predict returns per-class log-probabilities
                # (each row sums to 1 after exp); if so, this is the median of the
                # pooled distribution -- TODO confirm in TaskLLM.
                guesses = torch.cat([ (predicts[a:b,:].logsumexp(dim=0, keepdim=True).exp().cumsum(dim=1) >= 0.5 * (b-a)).long().argmax(dim=1)
                                      for a, b in zip(vote_offsets, vote_offsets[1:])
                                    ],
                                    dim=0)

                if isdev:
                    # Ratings are parsed with int() and shifted to 0-based class
                    # indices; the written output adds the 1 back.
                    targets = torch.tensor([ int(label) - 1 for label in labels ],
                                           dtype=torch.long, device=guesses.device)
                    mae = torch.abs(guesses - targets).float().mean().item()
                    mse = torch.square(guesses - targets).float().mean().item()
                else:
                    mae, mse = None, None

                for ex, guess in zip(examples, guesses.tolist()):
                    (devgolds if isdev else testgolds).append({ 'id': ex['id'], 'output': f'{1+guess}' })

            printer.addobs(mae, mse)

        printer.print()
        printer.autoprint = False

        for wut, golds in ( ('dev', devgolds), ('test', testgolds) ):
            with open(f'lamp3u_{wut}golds_t5xxl_keq{k}_step1_iter{step1_iter}_step2_iter{step2_iter}_nvoters{nvoters}.json', 'w') as jsonfile:
                json.dump({ 'task': 'LaMP_3', 'golds': golds }, jsonfile)
            
# Run each ensemble size in its own process: without a complete teardown
# between runs, memory is not released and later runs run out of memory.
from multiprocessing import Process

# (nvoters, exitcode) of every child that did not exit cleanly.
failed_nvoters = []
for nvoters in [1, 3, 5, 7]:
    p = Process(target=prepare_submission_probensemble,
                kwargs = { 'k': 4,
                           'step1_iter': '2_augment4',
                           'step2_iter': '0_augment0',
                           'nvoters': nvoters
                         })
    p.start()
    p.join()
    # A crashed child (e.g. out of memory) would otherwise go unnoticed; keep
    # sweeping so the remaining sizes still run, and report at the end.
    if p.exitcode != 0:
        failed_nvoters.append((nvoters, p.exitcode))

if failed_nvoters:
    raise RuntimeError(f'prepare_submission_probensemble failed for (nvoters, exitcode): {failed_nvoters}')

*** step1_iter: 2_augment4 step2_iter: 0_augment0 nvoters 1 ***
n       4 MAE (dev) (since) 4 MSE (dev) (since)      dt
1             0.000 (0.000)       0.000 (0.000)  1.26 s
2             0.500 (1.000)       0.500 (1.000)  10.2 s
4             0.500 (0.500)       0.500 (0.500)  44.2 s
8             0.375 (0.250)       0.375 (0.250)  1.61 m
16            0.312 (0.250)       0.312 (0.250)  3.34 m
32            0.219 (0.125)       0.219 (0.125)  6.81 m
64            0.234 (0.250)       0.234 (0.250)  12.9 m
128           0.242 (0.250)       0.258 (0.281)  27.2 m
256           0.207 (0.172)       0.215 (0.172)  53.5 m
512           0.227 (0.246)       0.258 (0.301)  1.77 h
1024          0.227 (0.227)       0.256 (0.254)  3.57 h
2048          0.213 (0.200)       0.233 (0.210)  7.11 h
4096          0.211 (0.199)       0.231 (0.221)  14.4 h
*** step1_iter: 2_augment4 step2_iter: 0_augment0 nvoters 3 ***
n       4 MAE (dev) (since) 4 MSE (dev) (since)      dt
1             0.000 (0.000)     

```
Hi, this is the results of your submissions to the LaMP benchmark:

LaMP-3:
{"MAE": 0.1932, "RMSE": 0.46604720790924176}
```