# trace

In [20]:
def trace(k):
    from MegaT5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    fewshot = PeftT5Classifier(train.num_labels, peft_config, t5=t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, open(f'trace{k}.csv', 'w', newline='') as csvfile:
        import csv
        writer = csv.writer(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        inputs = []
                        trace = []
                        target = torch.Tensor([ int(label == train.choices[1]) for label in labels ]).long().to(device)
        
                        for ex in examples:
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=k).indices.to('cpu')
                            titles = [ f'{ex["profile"][ind]["title"]}' for ind in index.tolist() ]
                            concat_titles = ' and '.join([f'"{v}"' for v in titles])
                            input = train.append_to_title(ex, concat_titles)
                            inputs.append(input)
                            trace.append((ex['title'], ex['ref1'], ex['ref2'], titles))
        
                        fewshotacc = (fewshot.predict(inputs).argmax(dim=1) == target).int()
                        avfewshotacc = fewshotacc.float().mean().item()
    
                    fewloss = fewshot.learn(inputs, target) if istrain else None
                    printer.addobs(iteration, fewloss, avfewshotacc if istrain else None, avfewshotacc if not istrain else None)

                    if istrain:
                        for (title, ref1, ref2, titles), acc in zip(trace, fewshotacc.tolist()):
                            writer.writerow([acc, title, ref1, ref2] + titles)

            printer.print()
            printer.autoprint = False

 
from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or peft_t5_baselines(k)

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0        0.68        0.68           1           1           0           0        1.04
2                     0           0        1.43        2.18         0.5           0           0           0        1.42
4                     0           0        1.23       0.821       0.333           0           0           0        2.02
8                     0           0       0.925       0.698         0.5       0.625           0           0        3.45
16                    0           0       0.878       0.823       0.423       0.333         0.5        0.75        5.98
32                    0           0       0.798       0.717       0.462         0.5       0.583       0.667          12
64                    0           0       0.746       0.692        0.52        0.58       0.731       0.857        22.7
128                   0           0     

In [49]:
# GPT2
def model_trace(k):
    from GPT2 import PeftGPT2Classifier
    from more_itertools import chunked
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM
    import torch

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    peft_config = IA3Config(task_type=TaskType.CAUSAL_LM, fan_in_fan_out=True)
    model = PeftGPT2Classifier(1, peft_config)
    best_const, best_const_n = 0

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'bc loss') as printer:
        for iteration in range(2):
            with open(f'trace{k}.csv', 'r', newline='') as csvfile:
                import csv
                reader = csv.reader(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
    
                for batch in chunked(reader, 2):
                    bc_pred = []
                    inputs = []
                    target = []
                    for row in batch:
                        acc, title, ref1, ref2 = row[0:4]
                        extra = '\n'.join([ f"Extra: {extra}" for extra in row[4:] ])
                        prompt = f"Title: {title}\nRef1: {ref1}\nRef2: {ref2}\n{extra}"
                        inputs.append(prompt)
                        target.append([float(acc)])
                        if iteration == 0:
                            best_const += float(acc)
                            best_const_n += 1
                        bc_pred.append([best_const / best_const_n])
                            
                    target = torch.Tensor(target).to(device)
                        
                    with torch.no_grad():
                        acc = ((model.predict(inputs) > 1/2).float() == target).float().mean()
                        best_const_loss = torch.nn.functional.binary_cross_entropy(torch.Tensor(bc_pred).to(device), target)
        
                    loss = model.learn(inputs, target) if iteration == 0 else None
                    printer.addobs(iteration, loss, acc, best_const_loss)

            printer.print()
            printer.autoprint = False
 
from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or model_trace(k)

n              iter     since    1 loss     since     1 acc     since   bc loss     since    dt (s)
1                 0         0     0.705     0.705         0         0         0         0     0.547
2                 0         0      3.05       5.4         0         0     0.448     0.896     0.626
4                 0         0      1.95      0.85     0.375      0.75     0.503     0.558     0.769
8                 0         0       1.2     0.456     0.562      0.75     0.584     0.665     0.996
16                0         0      1.12      1.04     0.438     0.312      0.63     0.676      1.45
32                0         0     0.988     0.853       0.5     0.562     0.657     0.684      2.35
64                0         0     0.878     0.768     0.547     0.594     0.671     0.686      4.13
128               0         0     0.833     0.788     0.504     0.461     0.681     0.691       7.7
256               0         0     0.795     0.757      0.49     0.477     0.687     0.692      14.9


In [51]:
# flan-t5-base
def model_trace(k):
    from T5 import PeftT5Classifier
    from more_itertools import chunked
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    model = PeftT5Classifier(1, peft_config, t5=t5)
    best_const, best_const_n = 0, 0

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'bc loss') as printer:
        for iteration in range(2):
            with open(f'trace{k}.csv', 'r', newline='') as csvfile:
                import csv
                reader = csv.reader(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
    
                for batch in chunked(reader, 2):
                    bc_pred = []
                    inputs = []
                    target = []
                    for row in batch:
                        acc, title, ref1, ref2 = row[0:4]
                        extra = '\n'.join([ f"Extra: {extra}" for extra in row[4:] ])
                        prompt = f"Title: {title}\nRef1: {ref1}\nRef2: {ref2}\n{extra}"
                        inputs.append(prompt)
                        target.append([float(acc)])
                        if iteration == 0:
                            best_const += float(acc)
                            best_const_n += 1
                        bc_pred.append([best_const / best_const_n])
                            
                    target = torch.Tensor(target).to(device)
                        
                    with torch.no_grad():
                        acc = ((model.predict(inputs) > 1/2).float() == target).float().mean()
                        best_const_loss = torch.nn.functional.binary_cross_entropy(torch.Tensor(bc_pred).to(device), target)
        
                    loss = model.learn(inputs, target) if iteration == 0 else None
                    printer.addobs(iteration, loss, acc, best_const_loss)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or model_trace(k)

n              iter     since    1 loss     since     1 acc     since   bc loss     since    dt (s)
1                 0         0     0.683     0.683         1         1         0         0     0.685
2                 0         0      0.71     0.737       0.5         0     0.448     0.896     0.812
4                 0         0     0.704     0.698     0.375      0.25     0.503     0.558      1.06
8                 0         0       0.7     0.695     0.438       0.5     0.584     0.665      1.56
16                0         0     0.701     0.703     0.406     0.375      0.63     0.676      2.54
32                0         0       0.7     0.699     0.406     0.406     0.657     0.684      4.47
64                0         0     0.697     0.694     0.461     0.516     0.671     0.686      8.39
128               0         0     0.698     0.698     0.465     0.469     0.681     0.691      16.3
256               0         0     0.698     0.698     0.471     0.477     0.687     0.692      31.9


In [53]:
# flan-t5-xl
def model_trace(k):
    from T5 import PeftT5Classifier
    from more_itertools import chunked
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch
    import warnings

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = prepare_model_for_kbit_training(T5ForConditionalGeneration.from_pretrained('google/flan-t5-xl', load_in_8bit=True))
    model = PeftT5Classifier(1, peft_config, t5=t5)
    best_const, best_const_n = 0, 0

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'bc loss') as printer, warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*MatMul8bitLt.*")
        for iteration in range(2):
            with open(f'trace{k}.csv', 'r', newline='') as csvfile:
                import csv
                reader = csv.reader(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
    
                for batch in chunked(reader, 2):
                    bc_pred = []
                    inputs = []
                    target = []
                    for row in batch:
                        acc, title, ref1, ref2 = row[0:4]
                        extra = '\n'.join([ f"Extra: {extra}" for extra in row[4:] ])
                        prompt = f"Title: {title}\nRef1: {ref1}\nRef2: {ref2}\n{extra}"
                        inputs.append(prompt)
                        target.append([float(acc)])
                        if iteration == 0:
                            best_const += float(acc)
                            best_const_n += 1
                        bc_pred.append([best_const / best_const_n])
                            
                    target = torch.Tensor(target).to(device)
                        
                    with torch.no_grad():
                        acc = ((model.predict(inputs) > 1/2).float() == target).float().mean()
                        best_const_loss = torch.nn.functional.binary_cross_entropy(torch.Tensor(bc_pred).to(device), target)
        
                    loss = model.learn(inputs, target) if iteration == 0 else None
                    printer.addobs(iteration, loss, acc, best_const_loss)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or model_trace(k)

n              iter     since    1 loss     since     1 acc     since   bc loss     since    dt (s)
1                 0         0     0.682     0.682         1         1         0         0      1.72
2                 0         0     0.725     0.768       0.5         0     0.448     0.896      2.85
4                 0         0     0.711     0.698     0.375      0.25     0.503     0.558      5.03
8                 0         0     0.706       0.7     0.438       0.5     0.584     0.665      9.31
16                0         0     0.707     0.709     0.406     0.375      0.63     0.676      17.9
32                0         0     0.705     0.703     0.406     0.406     0.657     0.684      35.1
64                0         0     0.701     0.696     0.453       0.5     0.671     0.686      69.3
128               0         0       0.7     0.699     0.469     0.484     0.681     0.691       138
256               0         0     0.699     0.698     0.475      0.48     0.687     0.692       277


# mega trace (no finetuning)

In [58]:
# no finetuning ...
def mega_trace(k):
    from MegaT5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch

    assert k == 1
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    fewshot = PeftT5Classifier(train.num_labels, peft_config, t5=t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, open(f'megatrace{k}.csv', 'w', newline='') as csvfile:
        import csv
        writer = csv.writer(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        inputs = []
                        trace = []
                        target = []
        
                        for ex, label in zip(examples, labels):
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=8).indices.to('cpu')
                            for oneind in index.tolist():
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                inputs.append(input)
                                trace.append((ex['title'], ex['ref1'], ex['ref2'], titles))
                                target.append(int(label == train.choices[1]))

                        target = torch.Tensor(target).long().to(device)
                        fewshotacc = (fewshot.predict(inputs).argmax(dim=1) == target).int()
                        avfewshotacc = fewshotacc.float().mean().item()
    
                    fewloss = None # fewshot.learn(inputs, target) if istrain else None
                    printer.addobs(iteration, fewloss, avfewshotacc if istrain else None, avfewshotacc if not istrain else None)

                    if istrain:
                        for (title, ref1, ref2, titles), acc in zip(trace, fewshotacc.tolist()):
                            writer.writerow([acc, title, ref1, ref2] + titles)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or mega_trace(k)

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0           0           0        0.75        0.75           0           0        1.03
2                     0           0           0           0       0.438       0.125           0           0        1.38
4                     0           0           0           0       0.625           1           0           0        2.08
8                     0           0           0           0       0.554         0.5           0           0        3.43
16                    0           0           0           0        0.49       0.417       0.354       0.531        6.05
32                    0           0           0           0       0.529       0.567        0.24       0.125        12.3
64                    0           0           0           0       0.501       0.472       0.346       0.438        23.6
128                   0           0     

In [62]:
# flan-t5-base
def model_trace(k):
    from T5 import PeftT5Classifier
    from more_itertools import chunked
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import random
    import torch

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)
    random.seed(8675309)
    
    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    model = PeftT5Classifier(1, peft_config, t5=t5)
    best_const, best_const_n = 0, 0

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'bc loss') as printer:
        for iteration in range(2):
            with open(f'megatrace{k}.csv', 'r', newline='') as csvfile:
                import csv
                reader = csv.reader(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)

                rows = [row for row in reader]
                random.shuffle(rows)
                for batch in chunked(rows, 2):
                    bc_pred = []
                    inputs = []
                    target = []
                    for row in batch:
                        acc, title, ref1, ref2 = row[0:4]
                        extra = '\n'.join([ f"Extra: {extra}" for extra in row[4:] ])
                        prompt = f"Title: {title}\nRef1: {ref1}\nRef2: {ref2}\n{extra}"
                        inputs.append(prompt)
                        target.append([float(acc)])
                        if iteration == 0:
                            best_const += float(acc)
                            best_const_n += 1
                        bc_pred.append([best_const / best_const_n])
                            
                    target = torch.Tensor(target).to(device)
                        
                    with torch.no_grad():
                        acc = ((model.predict(inputs) > 1/2).float() == target).float().mean()
                        best_const_loss = torch.nn.functional.binary_cross_entropy(torch.Tensor(bc_pred).to(device), target)
        
                    loss = model.learn(inputs, target) if iteration == 0 else None
                    printer.addobs(iteration, loss, acc, best_const_loss)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or model_trace(k)

n              iter     since    1 loss     since     1 acc     since   bc loss     since    dt (s)
1                 0         0     0.703     0.703         0         0         0         0      1.02
2                 0         0     0.687     0.671       0.5         1         0         0      1.15
4                 0         0     0.678     0.669     0.625      0.75      0.26      0.52      1.39
8                 0         0     0.683     0.688     0.625     0.625     0.469     0.678      1.88
16                0         0     0.681      0.68     0.625     0.625     0.562     0.655      2.85
32                0         0     0.696      0.71     0.531     0.438     0.649     0.737      4.76
64                0         0     0.696     0.696     0.461     0.391     0.668     0.687      8.64
128               0         0     0.697     0.699      0.48       0.5      0.68     0.692      16.3
256               0         0     0.696     0.696       0.5      0.52     0.684     0.689      31.7


# ultra trace (with finetuning)

In [64]:
# with finetuning ...
def ultra_trace(k):
    from MegaT5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch

    assert k == 1
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    fewshot = PeftT5Classifier(train.num_labels, peft_config, t5=t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, open(f'ultratrace{k}.csv', 'w', newline='') as csvfile:
        import csv
        writer = csv.writer(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        inputs = []
                        trace = []
                        target = []
                        traininputs = []
                        traintarget = []
        
                        for ex, label in zip(examples, labels):
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=8).indices.to('cpu')
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                inputs.append(input)
                                trace.append((ex['title'], ex['ref1'], ex['ref2'], titles))
                                target.append(int(label == train.choices[1]))
                                if n == 0:
                                    traininputs.append(input)
                                    traintarget.append(int(label == train.choices[1]))

                        target = torch.Tensor(target).long().to(device)
                        traintarget = torch.Tensor(traintarget).long().to(device)
                        fewshotacc = (fewshot.predict(inputs).argmax(dim=1) == target).int()
                        avfewshotacc = fewshotacc.float().mean().item()
    
                    fewloss = fewshot.learn(traininputs, traintarget) if istrain else None
                    printer.addobs(iteration, fewloss, avfewshotacc if istrain else None, avfewshotacc if not istrain else None)

                    if istrain:
                        for (title, ref1, ref2, titles), acc in zip(trace, fewshotacc.tolist()):
                            writer.writerow([acc, title, ref1, ref2] + titles)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or ultra_trace(k)

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0        0.68        0.68        0.75        0.75           0           0        1.14
2                     0           0        1.43        2.18       0.375           0           0           0        1.61
4                     0           0        1.23       0.821        0.25           0           0           0         2.4
8                     0           0       0.925       0.698       0.455       0.609           0           0         4.2
16                    0           0       0.878       0.823       0.399       0.333         0.5        0.75        7.46
32                    0           0       0.798       0.717       0.466       0.534       0.573       0.646          15
64                    0           0       0.746       0.692       0.522        0.58       0.726       0.857        29.8
128                   0           0     

In [65]:
# flan-t5-base
def model_trace(k):
    from T5 import PeftT5Classifier
    from more_itertools import chunked
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import random
    import torch

    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)
    random.seed(8675309)
    
    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    model = PeftT5Classifier(1, peft_config, t5=t5)
    best_const, best_const_n = 0, 0

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'bc loss') as printer:
        for iteration in range(2):
            with open(f'ultratrace{k}.csv', 'r', newline='') as csvfile:
                import csv
                reader = csv.reader(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)

                rows = [row for row in reader]
                random.shuffle(rows)
                for batch in chunked(rows, 2):
                    bc_pred = []
                    inputs = []
                    target = []
                    for row in batch:
                        acc, title, ref1, ref2 = row[0:4]
                        extra = '\n'.join([ f"Extra: {extra}" for extra in row[4:] ])
                        prompt = f"Title: {title}\nRef1: {ref1}\nRef2: {ref2}\n{extra}"
                        inputs.append(prompt)
                        target.append([float(acc)])
                        if iteration == 0:
                            best_const += float(acc)
                            best_const_n += 1
                        bc_pred.append([best_const / best_const_n])
                            
                    target = torch.Tensor(target).to(device)
                        
                    with torch.no_grad():
                        acc = ((model.predict(inputs) > 1/2).float() == target).float().mean()
                        best_const_loss = torch.nn.functional.binary_cross_entropy(torch.Tensor(bc_pred).to(device), target)
        
                    loss = model.learn(inputs, target) if iteration == 0 else None
                    printer.addobs(iteration, loss, acc, best_const_loss)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or model_trace(k)

n              iter     since    1 loss     since     1 acc     since   bc loss     since    dt (s)
1                 0         0     0.693     0.693       0.5       0.5     0.347     0.347     0.966
2                 0         0     0.693     0.693       0.5       0.5     0.448     0.549      1.09
4                 0         0      0.69     0.688     0.625      0.75     0.503     0.558      1.33
8                 0         0     0.696     0.702     0.562       0.5     0.584     0.665      1.82
16                0         0     0.699     0.701     0.469     0.375      0.63     0.676      2.78
32                0         0     0.698     0.697     0.516     0.562     0.653     0.676      4.71
64                0         0     0.698     0.698     0.531     0.547     0.668     0.683      8.59
128               0         0     0.691     0.684     0.559     0.586     0.673     0.677      16.3
256               0         0     0.674     0.658     0.598     0.637     0.666     0.659      31.7


# giga trace (finetuning on k=1, save at end)

In [2]:
# with finetuning and data doubling ...
def giga_trace(k):
    from MegaT5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import IA3Config, TaskType, prepare_model_for_kbit_training
    from transformers import T5ForConditionalGeneration
    import torch

    assert k == 1
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    fewshot = PeftT5Classifier(train.num_labels, peft_config, t5=t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer, open(f'gigatrace{k}.csv', 'w', newline='') as csvfile:
        import csv
        writer = csv.writer(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    with torch.no_grad():
                        inputs = []
                        trace = []
                        target = []
                        traininputs = []
                        traintarget = []
        
                        for ex, label in zip(examples, labels):
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=8).indices.to('cpu')
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                inputs.append(input)
                                trace.append((ex['title'], ex['ref1'], ex['ref2'], titles))
                                target.append(int(label == train.choices[1]))
                                if n == 0:
                                    traininputs.append(input)
                                    traintarget.append(int(label == train.choices[1]))

                        target = torch.Tensor(target).long().to(device)
                        traintarget = torch.Tensor(traintarget).long().to(device)
                        fewshotacc = (fewshot.predict(inputs).argmax(dim=1) == target).int()
                        trainfewshotacc = (fewshot.predict(traininputs).argmax(dim=1) == traintarget).float().mean()
    
                    fewloss = fewshot.learn(traininputs, traintarget) if istrain else None
                    printer.addobs(iteration, fewloss, trainfewshotacc if istrain else None, trainfewshotacc if not istrain else None)

                    if istrain:
                        for (title, ref1, ref2, titles), acc in zip(trace, fewshotacc.tolist()):
                            writer.writerow([acc, title, ref1, ref2] + titles)

            printer.print()
            printer.autoprint = False
            if iteration == 0:
                fewshot.save_pretrained('gigafewshot')

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or giga_trace(k)

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0       0.748       0.748           0           0           0           0        1.24
2                     0           0       0.913        1.08        0.25         0.5           0           0        1.77
4                     0           0       0.781        0.65         0.5        0.75           0           0        2.81
8                     0           0       0.908        1.08       0.286           0           1           1        4.59
16                    0           0       0.793       0.693         0.4         0.5           1           0        9.05
32                    0           0       0.765       0.736       0.431       0.464       0.667         0.5        17.2
64                    0           0       0.721       0.675       0.509       0.589         0.5       0.375        33.6
128                   0           0     

In [None]:
def wazzup():
    from MegaT5 import PeftT5Classifier
    from transformers import T5ForConditionalGeneration
    from peft import PeftConfig

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    peft_config = PeftConfig.from_pretrained('gigafewshot')
    fewshot = PeftT5Classifier(2, peft_config, t5=t5, model_id='gigafewshot')
    print(fewshot.predict(['yo']))
    fewshot = PeftT5Classifier(2, peft_config, t5=t5)
    print(fewshot.predict(['yo']))

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or wazzup()

In [None]:
# check save/load works ... yes but not if the underlying T5 is reused ... figure this out ... 
# https://stackoverflow.com/questions/76197574/loading-multiple-lora-bins
def check_save_load(k):
    from MegaT5 import PeftT5Classifier as TaskLLM
    from T5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch

    assert k == 1
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    peft_config = PeftConfig.from_pretrained('gigafewshot')
    taskllm = TaskLLM(train.num_labels, peft_config, t5=t5, model_id='gigafewshot')
    rhat_peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    rhat_t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    rewardpredictor = PeftT5Classifier(1, rhat_peft_config, t5=rhat_t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=8).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                      
                            guesses = taskllm.predict(prompts).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts)
                            #greedy = torch.argmax(rhats, dim=0).item()
                            greedy = 0
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or check_save_load(k)

In [3]:
# learn over top 8 results retrieved by ranker using frozen fine-tuned task llm
def learn_ranker(k):
    from MegaT5 import PeftT5Classifier as TaskLLM
    from T5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch

    assert k == 1
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    peft_config = PeftConfig.from_pretrained('gigafewshot')
    taskllm = TaskLLM(train.num_labels, peft_config, t5=t5, model_id='gigafewshot')
    rhat_peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    rhat_t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    rewardpredictor = PeftT5Classifier(1, rhat_peft_config, t5=rhat_t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=8).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                      
                            guesses = taskllm.predict(prompts).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or learn_ranker(k)

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0       0.697       0.697           0           0           0           0         1.3
2                     0           0        0.69       0.683           0           0           0           0           2
4                     0           0       0.706       0.722        0.25         0.5           0           0        3.38
8                     0           0        0.69       0.668       0.571           1           1           1         5.8
16                    0           0       0.697       0.703       0.467       0.375           1           0          12
32                    0           0       0.672       0.645       0.586       0.714       0.667         0.5          23
64                    0           0       0.652       0.632       0.667        0.75       0.571         0.5        45.8
128                   0           0     

In [None]:
# learn over top 16 results retrieved by ranker using frozen fine-tuned task llm
def learn_ranker(k):
    from MegaT5 import PeftT5Classifier as TaskLLM
    from T5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch

    assert k == 1
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    peft_config = PeftConfig.from_pretrained('gigafewshot')
    taskllm = TaskLLM(train.num_labels, peft_config, t5=t5, model_id='gigafewshot')
    rhat_peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    rhat_t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    rewardpredictor = PeftT5Classifier(1, rhat_peft_config, t5=rhat_t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=16).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                      
                            guesses = taskllm.predict(prompts).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or learn_ranker(k)

n                  iter       since      1 loss       since       1 acc       since 1 acc (dev)       since      dt (s)
1                     0           0       0.666       0.666           0           0           0           0        1.69
2                     0           0       0.667       0.668           0           0           0           0        2.78
4                     0           0       0.692       0.717        0.25         0.5           0           0        4.99
8                     0           0       0.686       0.679       0.571           1           1           1        8.89
16                    0           0       0.693       0.699         0.6       0.625           1           0        18.9
32                    0           0       0.673       0.651       0.621       0.643       0.667         0.5        36.8
64                    0           0        0.65       0.626       0.684        0.75       0.714        0.75        73.7
128                   0           0     

In [None]:
# learn over top 24 results retrieved by ranker using frozen fine-tuned task llm
def learn_ranker(k):
    from MegaT5 import PeftT5Classifier as TaskLLM
    from T5 import PeftT5Classifier
    from PersonalizedCitation import train_loader, dev_loader
    from ProgressPrinter import ProgressPrinter
    from peft import PeftConfig, IA3Config, TaskType
    from transformers import T5ForConditionalGeneration
    import torch

    assert k == 1
    
    device = 'cuda'
    torch.set_default_device(device)
    torch.manual_seed(2112)

    train = train_loader(batch_size=2, double_data=True)
    dev = dev_loader(batch_size=2)

    def interleave(a, b):
        from math import inf
        
        atot, btot = a.num_examples, b.num_examples
        aiter, biter = a.__iter__(), b.__iter__()
        aelem, belem = next(aiter), next(biter)
        anum, bnum = 1, 1

        while anum != inf and bnum != inf:
            if anum * btot <= bnum * atot:
                yield (True, aelem)
                try:
                    aelem = next(aiter)
                    anum += 1
                except StopIteration:
                    anum = inf
            else:
                yield (False, belem)
                try:
                    belem = next(biter)
                    bnum += 1
                except StopIteration:
                    bnum = inf

    t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    peft_config = PeftConfig.from_pretrained('gigafewshot')
    taskllm = TaskLLM(train.num_labels, peft_config, t5=t5, model_id='gigafewshot')
    rhat_peft_config = IA3Config(task_type=TaskType.SEQ_2_SEQ_LM)
    rhat_t5 = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    rewardpredictor = PeftT5Classifier(1, rhat_peft_config, t5=rhat_t5)

    with ProgressPrinter('iter', f'{k} loss', f'{k} acc', f'{k} acc (dev)') as printer:
        for iteration in range(2):
            for istrain, (examples, labels) in interleave(train, dev):
                if iteration == 0 or not istrain:
                    for ex, label in zip(examples, labels):
                        greedyrewards = []
                        allloss = []
                        with torch.no_grad():
                            embeddings = train.embed( [ ex['ref1'], ex['ref2'] ] + 
                                                      [ v['title'] 
                                                       for v in ex['profile']
                                                       if v['title'] != ex['title'] 
                                                     ])
                            scores = torch.max(embeddings[[0,1],:] @ embeddings[2:,:].T, dim=0).values
                            index = torch.topk(scores, dim=0, k=24).indices.to('cpu')
                            prompts = []
                            rhatprompts = []
                            for n, oneind in enumerate(index.tolist()):
                                titles = [ f'{ex["profile"][ind]["title"]}' for ind in (oneind,) ]
                                concat_titles = ' and '.join([f'"{v}"' for v in titles])
                                input = train.append_to_title(ex, concat_titles)
                                prompts.append(input)
                                rhatprompt = f"Title: {ex['title']}\nRef1: {ex['ref1']}\nRef2: {ex['ref2']}\nExtra: {titles[0]}"
                                rhatprompts.append(rhatprompt)
                      
                            guesses = taskllm.predict(prompts).argmax(dim=1)
                            target = int(label == train.choices[1])
                            rewards = (guesses == target).float().unsqueeze(1)
                            rhats = rewardpredictor.predict(rhatprompts)
                            greedy = torch.argmax(rhats, dim=0).item()
                            greedyreward = rewards[greedy, 0].item()
                            greedyrewards.append(greedyreward)
                            
                        loss = rewardpredictor.learn(rhatprompts, rewards) if istrain else None
                        allloss.append(loss)

                    greedyacc = torch.Tensor(greedyrewards).float().mean().item()
                    predloss = torch.Tensor(allloss).mean().item() if istrain else None

                    printer.addobs(iteration, predloss, greedyacc if istrain else None, greedyacc if not istrain else None)

            printer.print()
            printer.autoprint = False

from Fork import SubProcess
for k in range(1, 2):
    with SubProcess() as process: process.parent or learn_ranker(k)