In [1]:
import json
import os
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

  from .autonotebook import tqdm as notebook_tqdm


### Variables

In [5]:
# Experiment configuration.
# backbone / run_name select the checkpoint directory and the results directory;
# inference_mode selects which time condition is appended to the prompt and
# which reference response is stored next to the model output.
backbone = "cosmo-xl"
run_name = "prepend_later_drop0.5_immediate-augmented"
inference_mode = "delayed"
assert inference_mode in ["delayed", "immediate"]

# Prompt construction draws from np.random and generation samples tokens, so
# seed both libraries to make a fresh "Restart & Run All" reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

### Load model

In [6]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The tokenizer is loaded by its Hub name, while the weights come from the
# local fine-tuned checkpoint selected by backbone / run_name.
tokenizer = AutoTokenizer.from_pretrained("allenai/cosmo-xl")
model = AutoModelForSeq2SeqLM.from_pretrained(
    f"../checkpoints/{backbone}/{run_name}/final"
)
model = model.to(device)

In [12]:
# Upload the tokenizer to the "drop0.5_augmented" revision of the Hub repo.
# The access token is read from the HF_TOKEN environment variable instead of
# being written into the notebook; when the variable is unset, None is passed
# and credential lookup is left to the library.
# NOTE(review): a token was previously hardcoded in this cell, so it must be
# treated as leaked and revoked on the Hub.
tokenizer.push_to_hub(
    "seongbo/cosmo-3b-finetuned",
    token=os.environ.get("HF_TOKEN"),
    safe_serialization=False,
    revision="drop0.5_augmented",
)

CommitInfo(commit_url='https://huggingface.co/seongbo/cosmo-3b-finetuned/commit/f6e6b16b7463a06f894b31b24920c34d88e5c64a', commit_message='Upload tokenizer', commit_description='', oid='f6e6b16b7463a06f894b31b24920c34d88e5c64a', pr_url=None, pr_revision=None, pr_num=None)

In [7]:
# Upload the fine-tuned weights to the "drop0.5_augmented" revision of the Hub
# repo, stored as PyTorch .bin shards (safe_serialization=False).
# The access token is read from the HF_TOKEN environment variable instead of
# being written into the notebook; when the variable is unset, None is passed
# and credential lookup is left to the library.
# NOTE(review): a token was previously hardcoded in this cell, so it must be
# treated as leaked and revoked on the Hub.
model.push_to_hub(
    "seongbo/cosmo-3b-finetuned",
    token=os.environ.get("HF_TOKEN"),
    safe_serialization=False,
    revision="drop0.5_augmented",
)

pytorch_model-00002-of-00003.bin:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

[A[A
pytorch_model-00002-of-00003.bin:   0%|          | 65.5k/4.98G [00:00<2:37:00, 528kB/s]

pytorch_model-00002-of-00003.bin:   0%|          | 475k/4.98G [00:00<34:16, 2.42MB/s]  

pytorch_model-00002-of-00003.bin:   0%|          | 1.25M/4.98G [00:00<18:19, 4.53MB/s]

pytorch_model-00002-of-00003.bin:   0%|          | 1.88M/4.98G [00:00<18:57, 4.37MB/s]

pytorch_model-00002-of-00003.bin:   0%|          | 2.34M/4.98G [00:00<21:12, 3.91MB/s]

[A[A

pytorch_model-00002-of-00003.bin:   0%|          | 5.34M/4.98G [00:01<17:58, 4.61MB/s]

pytorch_model-00002-of-00003.bin:   0%|          | 7.54M/4.98G [00:01<09:58, 8.31MB/s]

pytorch_model-00002-of-00003.bin:   0%|          | 9.08M/4.98G [00:01<08:30, 9.73MB/s]

pytorch_model-00002-of-00003.bin:   0%|          | 10.6M/4.98G [00:01<07:39, 10.8MB/s]

pytorch_model-00002-of-00003.bin:   0%|          | 12.2M/4.98G [00:01<07:09, 11.6MB/s]

pytorch_model-00002-of-0

CommitInfo(commit_url='https://huggingface.co/seongbo/cosmo-3b-finetuned/commit/90cf593ffc29f4ca20331876a16c3d4bb8310a0c', commit_message='Upload T5ForConditionalGeneration', commit_description='', oid='90cf593ffc29f4ca20331876a16c3d4bb8310a0c', pr_url=None, pr_revision=None, pr_num=None)

### Generation functions

In [4]:
def sample_minute() -> int:
    """
    Draw the number of minutes that passed before an immediate reply.

    The value is one of 0..5: 0 is drawn with probability 0.5 and each of
    1..5 with probability 0.1. The draw uses NumPy's global random state.

    :return: random minute for immediate response
    """
    minutes = list(range(6))
    weights = [0.5] + [0.1] * 5
    return np.random.choice(minutes, p=weights)

In [5]:
def set_input(
    conversation_history: List[str],
    time_elapsed: str,
    turn_separator: str = " <turn> ",
    immediate_dropout: float = 0.0
) -> str:
    """
    Build the encoder input for one conversation.

    Every turn is prefixed with ``turn_separator``. Between consecutive turns a
    " <sep> N minute(s) later" marker with a sampled N is inserted, unless it is
    dropped with probability ``immediate_dropout``. The prompt ends with a final
    time marker followed by ``turn_separator``; which marker is used depends on
    the module-level ``inference_mode``.

    :param conversation_history: turns of the dialogue, oldest first (must be non-empty)
    :param time_elapsed: elapsed-time phrase placed before "later" in delayed mode
    :param turn_separator: string placed in front of every turn
    :param immediate_dropout: probability of omitting a between-turn time marker
    :return: the assembled prompt string
    """
    first_turn, later_turns = conversation_history[0], conversation_history[1:]
    pieces = [f"{turn_separator}{first_turn}"]

    for utterance in later_turns:
        # One uniform draw per turn decides whether the marker is kept; the
        # minute is only sampled when it is.
        keep_marker = not (np.random.rand() < immediate_dropout)
        if keep_marker:
            gap = sample_minute()
            unit = "minutes" if gap not in [0, 1] else "minute"
            pieces.append(f" <sep> {gap} {unit} later")
        pieces.append(f"{turn_separator}{utterance}")

    # The time condition for the response to generate closes the encoder input.
    if inference_mode == "delayed":
        pieces.append(f" <sep> {time_elapsed} later{turn_separator}")
    else:
        pieces.append(f" <sep> 0 minute later{turn_separator}")

    return "".join(pieces)

In [6]:
def batch_generate(
    model,
    conversation_histories: List[List[str]],
    time_elapseds: List[str],
    turn_separator: str = " <turn> ",
    immediate_dropout: float = 0.0
) -> Tuple[List[str], List[str]]:
    """
    Generate one sampled response per conversation in a batch.

    Prompts are built with ``set_input``, tokenized with the module-level
    ``tokenizer``, moved to the module-level ``device`` and decoded with
    nucleus sampling (top_p=0.95, temperature=1.0, at most 128 new tokens).

    :param model: seq2seq model exposing ``generate``
    :param conversation_histories: one list of turns per conversation
    :param time_elapseds: elapsed-time phrase for each conversation
    :param turn_separator: forwarded to ``set_input``
    :param immediate_dropout: forwarded to ``set_input``
    :return: ``(responses, input_texts)`` - the decoded responses and the
        prompts they were generated from, in input order
    """
    input_texts = [
        set_input(history, elapsed, turn_separator, immediate_dropout)
        for history, elapsed in zip(conversation_histories, time_elapseds)
    ]

    # Pad only to the longest prompt in the batch instead of the tokenizer's
    # maximum length, and hand the attention mask to generate() explicitly so
    # padding is masked by construction rather than inferred from token ids.
    inputs = tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt").to(device)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=128,
        temperature=1.0,
        top_p=.95,
        do_sample=True,
    )
    responses = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    return responses, input_texts

### Load test set

In [7]:
# Name of the evaluation set: it selects ../resources/data/<data_name>.json
# and is also used as the prefix of the results CSV file name.
data_name = "test_mc_taco"

In [8]:
# Read the test set. The encoding is given explicitly so that non-ASCII
# characters in the dialogues decode the same way on every platform instead of
# depending on the locale's default encoding.
with open(f"../resources/data/{data_name}.json", encoding="utf-8") as f:
    data = json.load(f)
print(len(data))

324


### Inference

In [9]:
# Number of test examples passed to batch_generate per inference step.
batch_size = 32

In [48]:
# Run the model over the test set in chunks of batch_size and record, for each
# example, the prompt that was fed in, the reference response matching the
# current inference_mode, and the sampled model response.
reference_key = "delayed_response" if inference_mode == "delayed" else "immediate_response"

dics = []
for start in tqdm(range(0, len(data), batch_size)):
    batch = data[start:start + batch_size]
    conversations = [example["context"] for example in batch]
    time_elapseds = [example["time_elapsed"] for example in batch]

    model_responses, input_texts = batch_generate(model, conversations, time_elapseds)
    dics.extend(
        {
            "context": prompt,
            "reference": example[reference_key],
            "model_response": reply,
        }
        for example, reply, prompt in zip(batch, model_responses, input_texts)
    )

100%|██████████| 11/11 [00:45<00:00,  4.17s/it]


### Print results and save

In [49]:
# One row per test example, with columns context / reference / model_response.
df = pd.DataFrame(dics)

In [50]:
# Preview the first rows with the raw " <turn> " / " <sep> " markers still in place.
df.head()

Unnamed: 0,context,reference,model_response
0,<turn> What are you doing right now? <sep> 0 ...,Definitely trying the seafood! Excited for the...,Have a good time on your trip!
1,<turn> Are you still out looking for Max? <se...,"Heading back in now, gonna start my bath. Real...",Done looking around the block. He must be arou...
2,<turn> Just finishing dinner. Gonna look for ...,"Lol, will do. Max was playing hide and seek wi...",What's up? Anything new?
3,<turn> What are you doing now? <sep> 1 minute...,"Thanks! I need to hurry, don't want to be late...","Thanks, I hope so too. I'll let you know if I ..."
4,<turn> Guess what? I'm on a mission to find M...,"Going to check the attic first. Hopefully, he'...",I’m off to look for Max! Wish me luck.


In [51]:
# Make the prompts readable: each turn goes on its own line and each time
# marker is set off with a tab and a pipe. regex=False pins literal matching,
# because the default of Series.str.replace differs between pandas versions.
# Leading/trailing whitespace left by the outer separators is stripped.
df["context"] = (
    df["context"]
    .str.replace(" <turn> ", "\n", regex=False)
    .str.replace(" <sep> ", "\t| ", regex=False)
    .str.strip()
)

In [52]:
# Preview again now that the separators in "context" have been replaced by newline / tab markers.
df.head()

Unnamed: 0,context,reference,model_response
0,What are you doing right now?\t| 0 minute late...,Definitely trying the seafood! Excited for the...,Have a good time on your trip!
1,Are you still out looking for Max?\t| 4 minute...,"Heading back in now, gonna start my bath. Real...",Done looking around the block. He must be arou...
2,Just finishing dinner. Gonna look for Max befo...,"Lol, will do. Max was playing hide and seek wi...",What's up? Anything new?
3,What are you doing now?\t| 1 minute later\nJus...,"Thanks! I need to hurry, don't want to be late...","Thanks, I hope so too. I'll let you know if I ..."
4,Guess what? I'm on a mission to find Max again...,"Going to check the attic first. Hopefully, he'...",I’m off to look for Max! Wish me luck.


In [53]:
# Write the results table to ../results/<backbone>/<run_name>/, creating the
# directory if needed. The file name records the data set and inference mode.
result_dir = f"../results/{backbone}/{run_name}"
os.makedirs(result_dir, exist_ok=True)

output_path = os.path.join(result_dir, f"{data_name}_{inference_mode}.csv")
df.to_csv(output_path, index=False)