# Step 1 Dev Set Labels

In [None]:
def _inner_batch(func, inner_batch_size, inputs):
    """Call `func` on successive chunks of the parallel sequences in `inputs`.

    `inputs` is a tuple of equally long sequences; each call receives one chunk of
    at most `inner_batch_size` items from every sequence as positional arguments.
    Returns the list of per-chunk results, in order.
    """
    from more_itertools import chunked
    return [func(*chunks) for chunks in zip(*[chunked(seq, inner_batch_size) for seq in inputs])]


def _cumsum_offsets(lengths):
    """Return segment boundaries [0, l0, l0+l1, ...] for concatenated sequences of the given lengths."""
    from itertools import accumulate
    return [0, *accumulate(lengths)]


def _embedding_texts(examples, max_chars=256):
    """Per example, return [review, *profile texts], whitespace-collapsed and truncated to `max_chars`.

    Each example is read as a dict with a 'review' string and a 'profile' list of
    dicts that carry a 'text' string.
    """
    def clean(text):
        return ' '.join(text.split())[:max_chars]
    return [[clean(ex['review'])] + [clean(item['text']) for item in ex['profile']]
            for ex in examples]


def _topk_profile_indices(embeddings, splits, k):
    """For each segment [a, b) of `embeddings`, rank rows a+1..b-1 against query row a.

    Returns one index tensor per segment. Each tensor holds the positions (relative
    to the profile, i.e. row a+1 is index 0) of the at most `k` rows with the
    largest dot product with the query row. Profiles shorter than `k` yield all
    of their entries.
    """
    import torch
    indices = []
    for a, b in zip(splits, splits[1:]):
        scores = embeddings[a, :] @ embeddings[a + 1:b, :].T
        indices.append(torch.topk(scores, dim=0, k=max(0, min(k, b - a - 1))).indices)
    return indices


def step1_dev_set_labels(*, step1_iter, k):
    """Label the dev set with the step-1 task LLM and write the guesses to JSON.

    For every dev example, the review and each profile text are embedded with
    `dev.embed`. The (up to) `k` profile entries with the highest dot-product
    similarity to the review are passed to `dev.prepend_to_prompt`. The guess is
    the first class index whose cumulative predicted probability reaches 0.5.
    This assumes `taskllm.predict` returns per-class log-probabilities, since they
    are exponentiated below. Running MAE/MSE against the gold labels are reported
    through `ProgressPrinter`, and every guess is written to
    'lamp3u_xxl_step1_dev_golds.json'.

    Args:
        step1_iter: suffix of the step-1 adapter directory
            'User_keq{k}_t5xxl_step1_iter{step1_iter}'.
        k: number of profile entries retrieved per example; also part of the
            adapter name.
    """
    import json
    from TaskLLM import TaskLLM
    from peft import prepare_model_for_kbit_training
    from PersonalizedProductRating import dev_loader
    from ProgressPrinter import ProgressPrinter
    from transformers import T5ForConditionalGeneration
    import torch
    import warnings

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(8675309)

    dev = dev_loader(batch_size=1)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    taskllm_model_id = f'User_keq{k}_t5xxl_step1_iter{step1_iter}'
    t5.load_adapter(taskllm_model_id, 'taskllm')

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")

    print(f'*** step1_iter: {step1_iter} ***')

    devgolds = []
    with ProgressPrinter(f'{k} MAE (dev)', f'{k} MSE (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")

        for examples, labels in dev:
            with torch.no_grad():
                texts_to_embed = _embedding_texts(examples)
                flat_texts = [text for texts in texts_to_embed for text in texts]
                embeddings = torch.cat(_inner_batch(dev.embed, 128, (flat_texts,)), dim=0)
                # splits[i]:splits[i+1] is example i's block of rows: review first, then its profile.
                splits = _cumsum_offsets(map(len, texts_to_embed))
                indices = _topk_profile_indices(embeddings, splits, k)
                prompts = [dev.prepend_to_prompt(ex, [ex['profile'][ind] for ind in index.to('cpu').tolist()])
                           for ex, index in zip(examples, indices)]
                cumul = taskllm.predict(prompts).exp().cumsum(dim=1)
                # First class whose cumulative probability reaches 0.5, i.e. the median of the predicted distribution.
                guesses = (cumul >= 0.5).long().argmax(dim=1)
                # Labels look like 1-based rating strings (shifted to 0-based here, back to 1-based on output) — TODO confirm.
                targets = torch.tensor([int(label) - 1 for label in labels], dtype=torch.long, device=guesses.device)
                mae = torch.abs(guesses - targets).float().mean().item()
                mse = torch.square(guesses - targets).float().mean().item()

                # One host transfer for the whole batch instead of one per element.
                devgolds.extend({'id': ex['id'], 'output': f'{1 + guess}'}
                                for ex, guess in zip(examples, guesses.tolist()))

            printer.addobs(mae, mse)

    with open('lamp3u_xxl_step1_dev_golds.json', 'w') as jsonfile:
        # The data comes from the PersonalizedProductRating loader, i.e. LaMP-3 (product rating);
        # the task tag must say so for the golds to be scored with the rating metrics.
        json.dump({'task': 'LaMP_3', 'golds': devgolds}, jsonfile)
            
# Run step-1 dev labelling: adapter iteration '2_augment4', k=4 retrieved profile entries.
step1_dev_set_labels(step1_iter='2_augment4', k=4)

*** step1_iter: 2_augment4 ***
n       4 MAE (dev) (since) 4 MSE (dev) (since)      dt
1             0.000 (0.000)       0.000 (0.000)  1.31 s
2             0.500 (1.000)       0.500 (1.000)   1.8 s
4             0.500 (0.500)       0.500 (0.500)  2.95 s
8             0.250 (0.000)       0.250 (0.000)  5.43 s
16            0.188 (0.125)       0.188 (0.125)  10.5 s
32            0.156 (0.125)       0.156 (0.125)  19.4 s
64            0.188 (0.219)       0.188 (0.219)  39.8 s
128           0.188 (0.188)       0.203 (0.219)   1.3 m
256           0.180 (0.172)       0.195 (0.188)  2.62 m
512           0.209 (0.238)       0.244 (0.293)  5.19 m
1024          0.213 (0.217)       0.242 (0.240)  10.4 m
2048          0.205 (0.197)       0.227 (0.211)  20.9 m
2500          0.204 (0.199)       0.228 (0.235)  25.5 m


In [None]:
def _inner_batch(func, inner_batch_size, inputs):
    """Call `func` on successive chunks of the parallel sequences in `inputs`.

    `inputs` is a tuple of equally long sequences; each call receives one chunk of
    at most `inner_batch_size` items from every sequence as positional arguments.
    Returns the list of per-chunk results, in order.
    """
    from more_itertools import chunked
    return [func(*chunks) for chunks in zip(*[chunked(seq, inner_batch_size) for seq in inputs])]


def _cumsum_offsets(lengths):
    """Return segment boundaries [0, l0, l0+l1, ...] for concatenated sequences of the given lengths."""
    from itertools import accumulate
    return [0, *accumulate(lengths)]


def _embedding_texts(examples, max_chars=256):
    """Per example, return [review, *profile texts], whitespace-collapsed and truncated to `max_chars`.

    Each example is read as a dict with a 'review' string and a 'profile' list of
    dicts that carry a 'text' string.
    """
    def clean(text):
        return ' '.join(text.split())[:max_chars]
    return [[clean(ex['review'])] + [clean(item['text']) for item in ex['profile']]
            for ex in examples]


def _randomized_profile_subsets(embeddings, nsamples, k, gumbel):
    """Sample distinct top-k profile subsets by Gumbel-perturbing the similarity scores.

    Row 0 of `embeddings` is the query; rows 1.. are the profile entries. Each of
    `nsamples` draws adds scaled Gumbel noise to the query/profile dot products
    and keeps the top min(k, n_profile) indices. Duplicate rows are then removed
    by `torch.unique`. Returns a long tensor of shape (m, min(k, n_profile)) with
    1 <= m <= nsamples. An empty profile (or k == 0) yields a single empty subset
    of shape (1, 0).
    """
    import torch
    scores = embeddings[0, :] @ embeddings[1:, :].T
    nprofile = scores.shape[0]
    kk = min(k, nprofile)
    if kk <= 0:
        return scores.new_zeros((1, 0), dtype=torch.long)
    # NOTE(review): `scores` is in profile order, not sorted, so this is the gap between
    # profile items 0 and 4 rather than between the best and fifth-best match, and it can
    # be negative. Kept as-is to match the original sampler — confirm intent.
    # The index is clamped so that profiles with <= 4 entries no longer raise IndexError.
    temperature = scores[0].item() - scores[min(nprofile - 1, 4)].item()
    gumbels = temperature * gumbel.sample(torch.Size([nsamples, nprofile])).to(scores.device)
    return torch.unique(torch.topk(scores.unsqueeze(0) + gumbels, dim=1, k=kk).indices, sorted=False, dim=0)


def _mixture_median_guesses(logprobs, splits):
    """For each row segment [a, b) of `logprobs`, return the median class of the mixture.

    The rows in a segment are summed in probability space (logsumexp then exp). The
    guess is the first class whose cumulative mass reaches half of the segment's
    total (0.5 per row, assuming each row is a normalized log-distribution).
    Returns a 1-D long tensor with one guess per segment.
    """
    import torch
    return torch.cat([(logprobs[a:b, :].logsumexp(dim=0, keepdim=True).exp().cumsum(dim=1) >= 0.5 * (b - a)).long().argmax(dim=1)
                      for a, b in zip(splits, splits[1:])],
                     dim=0)


def step2_dev_set_labels(*, step2_iter, step1_iter, k, nvoters=1, nsamples=64):
    """Label the dev set using the step-2 reward predictor to pick retrieved profiles.

    For every dev example:
      1. The review and profile texts are embedded with `dev.embed`.
      2. Up to `nsamples` distinct k-subsets of the profile are sampled by
         Gumbel-perturbed similarity.
      3. Each subset is turned into a prompt via `dev.prepend_to_prompt` and scored
         with `rewardpredictor.predict`.
      4. The `nvoters` highest-scoring prompts are fed to the task LLM, and the
         guess is the median class of their averaged predicted distributions.

    Running MAE/MSE are reported through `ProgressPrinter`, and every guess is
    written to 'lamp3u_xxl_step2_dev_golds.json'.

    Args:
        step2_iter: suffix of the reward-predictor adapter
            'User_keq{k}_t5xxl_step2_iter{step2_iter}'.
        step1_iter: suffix of the task-LLM adapter
            'User_keq{k}_t5xxl_step1_iter{step1_iter}'.
        k: number of profile entries per prompt; also part of the adapter names.
        nvoters: number of top-rhat prompts whose predictions are pooled
            (default 1, the previously hard-coded value).
        nsamples: number of Gumbel draws per example (default 64, the previously
            hard-coded value).

    Raises:
        ValueError: if `nvoters` is less than 1.
    """
    import json
    from RewardPredictor import RewardPredictor
    from TaskLLM import TaskLLM
    from peft import prepare_model_for_kbit_training
    from PersonalizedProductRating import dev_loader
    from ProgressPrinter import ProgressPrinter
    from transformers import T5ForConditionalGeneration
    import torch
    import warnings

    if nvoters < 1:
        raise ValueError(f'nvoters must be >= 1, got {nvoters}')

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(8675309)

    dev = dev_loader(batch_size=1)

    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xxl', load_in_8bit=True))
    t5.load_adapter(f'User_keq{k}_t5xxl_step1_iter{step1_iter}', 'taskllm')
    t5.load_adapter(f'User_keq{k}_t5xxl_step2_iter{step2_iter}', 'rhat')
    t5.enable_adapters()

    taskllm = TaskLLM(t5=t5, adapter_name="taskllm")
    rewardpredictor = RewardPredictor(t5=t5, adapter_name="rhat", model_id=f'User_keq{k}_t5xxl_step2_iter{step2_iter}')

    gumbel = torch.distributions.gumbel.Gumbel(0, 1)

    print(f'*** step1_iter: {step1_iter} step2_iter: {step2_iter} ***')

    devgolds = []
    with ProgressPrinter(f'{k} MAE (dev)', f'{k} MSE (dev)') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")

        for examples, labels in dev:
            with torch.no_grad():
                texts_to_embed = _embedding_texts(examples)
                flat_texts = [text for texts in texts_to_embed for text in texts]
                embeddings = torch.cat(_inner_batch(dev.embed, 128, (flat_texts,)), dim=0)
                splits = _cumsum_offsets(map(len, texts_to_embed))
                randos = [_randomized_profile_subsets(embeddings[a:b, :], nsamples, k, gumbel)
                          for a, b in zip(splits, splits[1:])]
                prompts = [[dev.prepend_to_prompt(ex, [ex['profile'][ind] for ind in indices])
                            for indices in rando.to('cpu').tolist()]
                           for ex, rando in zip(examples, randos)]
                flat_prompts = [prompt for candidates in prompts for prompt in candidates]
                rhats = torch.cat(_inner_batch(rewardpredictor.predict, 128, (flat_prompts,)), dim=0)
                # Keep, per example, the nvoters candidate prompts with the highest predicted reward.
                splits = _cumsum_offsets(map(len, prompts))
                votingprompts = [[candidates[v] for v in torch.topk(rhats[a:b].view(-1), k=min(nvoters, b - a)).indices.to('cpu').tolist()]
                                 for a, b, candidates in zip(splits, splits[1:], prompts)]
                flat_voters = [prompt for voters in votingprompts for prompt in voters]
                predicts = torch.cat(_inner_batch(taskllm.predict, 128, (flat_voters,)), dim=0)
                splits = _cumsum_offsets(map(len, votingprompts))
                guesses = _mixture_median_guesses(predicts, splits)

                # Labels look like 1-based rating strings (shifted to 0-based here, back to 1-based on output) — TODO confirm.
                targets = torch.tensor([int(label) - 1 for label in labels], dtype=torch.long, device=guesses.device)
                mae = torch.abs(guesses - targets).float().mean().item()
                mse = torch.square(guesses - targets).float().mean().item()

                devgolds.extend({'id': ex['id'], 'output': f'{1 + guess}'}
                                for ex, guess in zip(examples, guesses.tolist()))

            printer.addobs(mae, mse)

        printer.print()
        printer.autoprint = False

    with open('lamp3u_xxl_step2_dev_golds.json', 'w') as jsonfile:
        # The data comes from the PersonalizedProductRating loader, i.e. LaMP-3 (product rating);
        # the task tag must say so for the golds to be scored with the rating metrics.
        json.dump({'task': 'LaMP_3', 'golds': devgolds}, jsonfile)

# Run step-2 dev labelling: task-LLM iteration '2_augment4', reward-predictor iteration '0_augment0', k=4.
step2_dev_set_labels(step1_iter='2_augment4', step2_iter='0_augment0', k=4)

*** step1_iter: 2_augment4 step2_iter: 0_augment0 ***
n       4 MAE (dev) (since) 4 MSE (dev) (since)      dt
1             0.000 (0.000)       0.000 (0.000)  1.63 s
2             0.500 (1.000)       0.500 (1.000)  5.55 s
4             0.500 (0.500)       0.500 (0.500)  19.9 s
8             0.250 (0.000)       0.250 (0.000)  40.6 s
16            0.188 (0.125)       0.188 (0.125)   1.4 m
32            0.156 (0.125)       0.156 (0.125)  2.83 m
64            0.188 (0.219)       0.188 (0.219)  5.42 m
128           0.211 (0.234)       0.227 (0.266)  11.3 m
256           0.191 (0.172)       0.199 (0.172)  22.3 m
512           0.221 (0.250)       0.260 (0.320)  44.3 m
1024          0.226 (0.230)       0.259 (0.258)  1.48 h
2048          0.214 (0.203)       0.242 (0.225)  2.96 h
2500          0.210 (0.190)       0.240 (0.230)  3.63 h
